Coverage for pydantic/_internal/_discriminated_union.py: 96.99%
197 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-06-21 17:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2024-06-21 17:00 +0000
1from __future__ import annotations as _annotations 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
3from typing import TYPE_CHECKING, Any, Hashable, Sequence 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
5from pydantic_core import CoreSchema, core_schema 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
7from ..errors import PydanticUserError 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
8from . import _core_utils 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
9from ._core_utils import ( 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
10 CoreSchemaField,
11 collect_definitions,
12)
14if TYPE_CHECKING: 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
15 from ..types import Discriminator
# Metadata key under which a not-yet-applied discriminator is stored on a core schema.
# `set_discriminator_in_metadata` writes it; `apply_discriminators` pops it and applies the discriminator.
CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'
class MissingDefinitionForUnionRef(Exception):
    """Signal that a discriminator could not be applied because a referenced definition is not available yet.

    Attributes are documented inline: `ref` is the schema ref whose definition was missing.
    """

    def __init__(self, ref: str) -> None:
        # Build the message first, then hand it to `Exception` so `str(exc)` carries the offending ref.
        message = f'Missing definition for ref {ref!r}'
        super().__init__(message)
        # Keep the raw ref so callers can inspect which definition was missing.
        self.ref = ref
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
    """Store `discriminator` in the metadata of `schema` under the placeholder key.

    The schema is mutated in place: a `'metadata'` dict is created if the schema does not have one yet,
    and any discriminator previously stored under the placeholder key is overwritten.

    Args:
        schema: The core schema whose metadata should carry the discriminator.
        discriminator: The discriminator value to record.
    """
    # `setdefault` already returns the (possibly freshly created) metadata mapping, so no second lookup is needed.
    metadata = schema.setdefault('metadata', {})
    # An explicit `'metadata': None` entry would survive `setdefault`; fail loudly rather than index into `None`.
    assert metadata is not None
    metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator
def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
    """Walk `schema` and apply every discriminator that was stashed in schema metadata.

    Each visited schema has the discriminator placeholder popped from its metadata (if it has any);
    when a discriminator is found, the schema is replaced by the result of `apply_discriminator`.
    Schemas that are already of type `'tagged-union'` are returned untouched.

    Args:
        schema: The outermost core schema to process.

    Returns:
        The schema produced by walking `schema` with discriminators applied.
    """
    # We recursively walk through the `schema` passed to `apply_discriminators`, applying discriminators
    # where necessary at each level. During this recursion, we allow references to be resolved from the definitions
    # that are originally present on the original, outermost `schema`. Before `apply_discriminators` is called,
    # `simplify_schema_references` is called on the schema (in the `clean_schema` function),
    # which often puts the definitions in the outermost schema.
    global_definitions: dict[str, CoreSchema] = collect_definitions(schema)

    def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
        # `global_definitions` is only read here, so the closure needs no `nonlocal` declaration.
        # Children are processed first so that nested unions are converted before their parents.
        s = recurse(s, inner)
        if s['type'] == 'tagged-union':
            return s

        metadata = s.get('metadata', {})
        # Pop (rather than read) the placeholder so the discriminator is applied at most once per schema.
        discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
        if discriminator is not None:
            s = apply_discriminator(s, discriminator, global_definitions)
        return s

    return _core_utils.walk_core_schema(schema, inner)
def apply_discriminator(
    schema: core_schema.CoreSchema,
    discriminator: str | Discriminator,
    definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
    """Build a new core schema from `schema` with the given discriminator applied.

    Args:
        schema: The input schema.
        discriminator: Either the name of the field which will serve as the discriminator, or a
            `Discriminator` object. A `Discriminator` wrapping a string is treated like that string;
            any other `Discriminator` converts the schema itself via its `_convert_schema` method.
        definitions: A mapping of schema ref to schema. An empty mapping is used when omitted.

    Returns:
        The new core schema.

    Raises:
        TypeError:
            - If `discriminator` is used with invalid union variant.
            - If `discriminator` is used with `Union` type with one variant.
            - If `discriminator` value mapped to multiple choices.
        MissingDefinitionForUnionRef:
            If the definition for ref is missing.
        PydanticUserError:
            - If a model in union doesn't have a discriminator field.
            - If discriminator field has a non-string alias.
            - If discriminator fields have different aliases.
            - If discriminator field not of type `Literal`.
    """
    # Imported at call time; at module level this name is only imported under `TYPE_CHECKING`.
    from ..types import Discriminator

    if isinstance(discriminator, Discriminator):
        if not isinstance(discriminator.discriminator, str):
            # Non-string discriminators know how to convert the schema on their own.
            return discriminator._convert_schema(schema)
        # Unwrap to the plain field name and fall through to the inference path.
        discriminator = discriminator.discriminator

    applier = _ApplyInferredDiscriminator(discriminator, definitions or {})
    return applier.apply(schema)
100class _ApplyInferredDiscriminator: 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
101 """This class is used to convert an input schema containing a union schema into one where that union is
102 replaced with a tagged-union, with all the associated debugging and performance benefits.
104 This is done by:
105 * Validating that the input schema is compatible with the provided discriminator
106 * Introspecting the schema to determine which discriminator values should map to which union choices
107 * Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more
109 I have chosen to implement the conversion algorithm in this class, rather than a function,
110 to make it easier to maintain state while recursively walking the provided CoreSchema.
111 """
113 def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]): 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
114 # `discriminator` should be the name of the field which will serve as the discriminator.
115 # It must be the python name of the field, and *not* the field's alias. Note that as of now,
116 # all members of a discriminated union _must_ use a field with the same name as the discriminator.
117 # This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
118 self.discriminator = discriminator 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
120 # `definitions` should contain a mapping of schema ref to schema for all schemas which might
121 # be referenced by some choice
122 self.definitions = definitions 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
124 # `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
125 #
126 # Note: following the v1 implementation, we currently disallow the use of different aliases
127 # for different choices. This is not a limitation of pydantic_core, but if we try to handle
128 # this, the inference logic gets complicated very quickly, and could result in confusing
129 # debugging challenges for users making subtle mistakes.
130 #
131 # Rather than trying to do the most powerful inference possible, I think we should eventually
132 # expose a way to more-manually control the way the TaggedUnionSchema is constructed through
133 # the use of a new type which would be placed as an Annotation on the Union type. This would
134 # provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for
135 # more complex cases, without over-complicating the inference logic for the common cases.
136 self._discriminator_alias: str | None = None 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
138 # `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
139 # If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
140 # constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
141 # Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
142 # that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
143 # python side, but resolves the issue that `None` cannot correspond to any discriminator values.
144 self._should_be_nullable = False 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
146 # `_is_nullable` is used to track if the final produced schema will definitely be nullable;
147 # we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
148 # as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
149 # the final value in another nullable schema.
150 #
151 # This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
152 # to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
153 self._is_nullable = False 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
155 # `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
156 # from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
157 # algorithm may add more choices to this stack as (nested) unions are encountered.
158 self._choices_to_handle: list[core_schema.CoreSchema] = [] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
160 # `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
161 # in the output TaggedUnionSchema that will replace the union from the input schema
162 self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {} 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
164 # `_used` is changed to True after applying the discriminator to prevent accidental re-use
165 self._used = False 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
167 def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
168 """Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
169 to this class.
171 Args:
172 schema: The input schema.
174 Returns:
175 The new core schema.
177 Raises:
178 TypeError:
179 - If `discriminator` is used with invalid union variant.
180 - If `discriminator` is used with `Union` type with one variant.
181 - If `discriminator` value mapped to multiple choices.
182 ValueError:
183 If the definition for ref is missing.
184 PydanticUserError:
185 - If a model in union doesn't have a discriminator field.
186 - If discriminator field has a non-string alias.
187 - If discriminator fields have different aliases.
188 - If discriminator field not of type `Literal`.
189 """
190 assert not self._used 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
191 schema = self._apply_to_root(schema) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
192 if self._should_be_nullable and not self._is_nullable: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
193 schema = core_schema.nullable_schema(schema) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
194 self._used = True 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
195 return schema 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
197 def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
198 """This method handles the outer-most stage of recursion over the input schema:
199 unwrapping nullable or definitions schemas, and calling the `_handle_choice`
200 method iteratively on the choices extracted (recursively) from the possibly-wrapped union.
201 """
202 if schema['type'] == 'nullable': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
203 self._is_nullable = True 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
204 wrapped = self._apply_to_root(schema['schema']) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
205 nullable_wrapper = schema.copy() 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
206 nullable_wrapper['schema'] = wrapped 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
207 return nullable_wrapper 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
209 if schema['type'] == 'definitions': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
210 wrapped = self._apply_to_root(schema['schema']) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
211 definitions_wrapper = schema.copy() 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
212 definitions_wrapper['schema'] = wrapped 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
213 return definitions_wrapper 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
215 if schema['type'] != 'union': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
216 # If the schema is not a union, it probably means it just had a single member and
217 # was flattened by pydantic_core.
218 # However, it still may make sense to apply the discriminator to this schema,
219 # as a way to get discriminated-union-style error messages, so we allow this here.
220 schema = core_schema.union_schema([schema]) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
222 # Reverse the choices list before extending the stack so that they get handled in the order they occur
223 choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
224 self._choices_to_handle.extend(choices_schemas) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
225 while self._choices_to_handle: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
226 choice = self._choices_to_handle.pop() 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
227 self._handle_choice(choice) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
229 if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
230 # * We need to annotate `discriminator` as a union here to handle both branches of this conditional
231 # * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
232 # invariance of list, and because list[list[str | int]] is the type of the discriminator argument
233 # to tagged_union_schema below
234 # * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
235 # interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
236 # is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
237 discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
238 else:
239 discriminator = self.discriminator 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
240 return core_schema.tagged_union_schema( 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
241 choices=self._tagged_union_choices,
242 discriminator=discriminator,
243 custom_error_type=schema.get('custom_error_type'),
244 custom_error_message=schema.get('custom_error_message'),
245 custom_error_context=schema.get('custom_error_context'),
246 strict=False,
247 from_attributes=True,
248 ref=schema.get('ref'),
249 metadata=schema.get('metadata'),
250 serialization=schema.get('serialization'),
251 )
253 def _handle_choice(self, choice: core_schema.CoreSchema) -> None: 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
254 """This method handles the "middle" stage of recursion over the input schema.
255 Specifically, it is responsible for handling each choice of the outermost union
256 (and any "coalesced" choices obtained from inner unions).
258 Here, "handling" entails:
259 * Coalescing nested unions and compatible tagged-unions
260 * Tracking the presence of 'none' and 'nullable' schemas occurring as choices
261 * Validating that each allowed discriminator value maps to a unique choice
262 * Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
263 """
264 if choice['type'] == 'definition-ref': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
265 if choice['schema_ref'] not in self.definitions: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
266 raise MissingDefinitionForUnionRef(choice['schema_ref']) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
268 if choice['type'] == 'none': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
269 self._should_be_nullable = True 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
270 elif choice['type'] == 'definitions': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
271 self._handle_choice(choice['schema']) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
272 elif choice['type'] == 'nullable': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
273 self._should_be_nullable = True 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
274 self._handle_choice(choice['schema']) # unwrap the nullable schema 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
275 elif choice['type'] == 'union': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
276 # Reverse the choices list before extending the stack so that they get handled in the order they occur
277 choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
278 self._choices_to_handle.extend(choices_schemas) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
279 elif choice['type'] not in { 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
280 'model',
281 'typed-dict',
282 'tagged-union',
283 'lax-or-strict',
284 'dataclass',
285 'dataclass-args',
286 'definition-ref',
287 } and not _core_utils.is_function_with_inner_schema(choice):
288 # We should eventually handle 'definition-ref' as well
289 raise TypeError( 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
290 f'{choice["type"]!r} is not a valid discriminated union variant;'
291 ' should be a `BaseModel` or `dataclass`'
292 )
293 else:
294 if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice): 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
295 # In this case, this inner tagged-union is compatible with the outer tagged-union,
296 # and its choices can be coalesced into the outer TaggedUnionSchema.
297 subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
298 # Reverse the choices list before extending the stack so that they get handled in the order they occur
299 self._choices_to_handle.extend(subchoices[::-1]) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
300 return 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
302 inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
303 self._set_unique_choice_for_values(choice, inferred_discriminator_values) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
305 def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool: 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
306 """This method returns a boolean indicating whether the discriminator for the `choice`
307 is the same as that being used for the outermost tagged union. This is used to
308 determine whether this TaggedUnionSchema choice should be "coalesced" into the top level,
309 or whether it should be treated as a separate (nested) choice.
310 """
311 inner_discriminator = choice['discriminator'] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
312 return inner_discriminator == self.discriminator or ( 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
313 isinstance(inner_discriminator, list)
314 and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
315 )
317 def _infer_discriminator_values_for_choice( # noqa C901 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
318 self, choice: core_schema.CoreSchema, source_name: str | None
319 ) -> list[str | int]:
320 """This function recurses over `choice`, extracting all discriminator values that should map to this choice.
322 `model_name` is accepted for the purpose of producing useful error messages.
323 """
324 if choice['type'] == 'definitions': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
325 return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
326 elif choice['type'] == 'function-plain': 326 ↛ 327line 326 didn't jump to line 327, because the condition on line 326 was never true1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
327 raise TypeError(
328 f'{choice["type"]!r} is not a valid discriminated union variant;'
329 ' should be a `BaseModel` or `dataclass`'
330 )
331 elif _core_utils.is_function_with_inner_schema(choice): 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
332 return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
333 elif choice['type'] == 'lax-or-strict': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
334 return sorted( 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
335 set(
336 self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
337 + self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
338 )
339 )
341 elif choice['type'] == 'tagged-union': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
342 values: list[str | int] = [] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
343 # Ignore str/int "choices" since these are just references to other choices
344 subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
345 for subchoice in subchoices: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
346 subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
347 values.extend(subchoice_values) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
348 return values 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
350 elif choice['type'] == 'union': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
351 values = [] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
352 for subchoice in choice['choices']: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
353 subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
354 subchoice_values = self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
355 values.extend(subchoice_values) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
356 return values 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
358 elif choice['type'] == 'nullable': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
359 self._should_be_nullable = True 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
360 return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
362 elif choice['type'] == 'model': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
363 return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
365 elif choice['type'] == 'dataclass': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
366 return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
368 elif choice['type'] == 'model-fields': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
369 return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
371 elif choice['type'] == 'dataclass-args': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
372 return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
374 elif choice['type'] == 'typed-dict': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
375 return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
377 elif choice['type'] == 'definition-ref': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
378 schema_ref = choice['schema_ref'] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
379 if schema_ref not in self.definitions: 379 ↛ 380line 379 didn't jump to line 380, because the condition on line 379 was never true1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
380 raise MissingDefinitionForUnionRef(schema_ref)
381 return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
382 else:
383 raise TypeError( 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
384 f'{choice["type"]!r} is not a valid discriminated union variant;'
385 ' should be a `BaseModel` or `dataclass`'
386 )
388 def _infer_discriminator_values_for_typed_dict_choice( 1akblcmdneopqrstuvwxyGHIJKLMNOfzgAhBiCjD
389 self, choice: core_schema.TypedDictSchema, source_name: str | None = None
390 ) -> list[str | int]:
391 """This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema
392 for the sake of readability.
393 """
394 source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}' 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
395 field = choice['fields'].get(self.discriminator) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
396 if field is None: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
397 raise PydanticUserError( 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
398 f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
399 )
400 return self._infer_discriminator_values_for_field(field, source) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
402 def _infer_discriminator_values_for_model_choice( 1akblcmdneopqrstuvwxyGHIJKLMNOfzgAhBiCjD
403 self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None
404 ) -> list[str | int]:
405 source = 'ModelFields' if source_name is None else f'Model {source_name!r}' 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
406 field = choice['fields'].get(self.discriminator) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
407 if field is None: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
408 raise PydanticUserError( 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
409 f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
410 )
411 return self._infer_discriminator_values_for_field(field, source) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
413 def _infer_discriminator_values_for_dataclass_choice( 1akblcmdneopqrstuvwxyGHIJKLMNOfzgAhBiCjD
414 self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
415 ) -> list[str | int]:
416 source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}' 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
417 for field in choice['fields']: 417 ↛ 421line 417 didn't jump to line 421, because the loop on line 417 didn't complete1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
418 if field['name'] == self.discriminator: 418 ↛ 417line 418 didn't jump to line 417, because the condition on line 418 was always true1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
419 break 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
420 else:
421 raise PydanticUserError(
422 f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
423 )
424 return self._infer_discriminator_values_for_field(field, source) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
426 def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]: 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
427 if field['type'] == 'computed-field': 427 ↛ 429line 427 didn't jump to line 429, because the condition on line 427 was never true1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
428 # This should never occur as a discriminator, as it is only relevant to serialization
429 return []
430 alias = field.get('validation_alias', self.discriminator) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
431 if not isinstance(alias, str): 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
432 raise PydanticUserError( 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
433 f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
434 )
435 if self._discriminator_alias is None: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
436 self._discriminator_alias = alias 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
437 elif self._discriminator_alias != alias: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
438 raise PydanticUserError( 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
439 f'Aliases for discriminator {self.discriminator!r} must be the same '
440 f'(got {alias}, {self._discriminator_alias})',
441 code='discriminator-alias',
442 )
443 return self._infer_discriminator_values_for_inner_schema(field['schema'], source) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
445 def _infer_discriminator_values_for_inner_schema( 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
446 self, schema: core_schema.CoreSchema, source: str
447 ) -> list[str | int]:
448 """When inferring discriminator values for a field, we typically extract the expected values from a literal
449 schema. This function does that, but also handles nested unions and defaults.
450 """
451 if schema['type'] == 'literal': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
452 return schema['expected'] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
454 elif schema['type'] == 'union': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
455 # Generally when multiple values are allowed they should be placed in a single `Literal`, but
456 # we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
457 # For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
458 values: list[Any] = [] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
459 for choice in schema['choices']: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
460 choice_schema = choice[0] if isinstance(choice, tuple) else choice 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
461 choice_values = self._infer_discriminator_values_for_inner_schema(choice_schema, source) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
462 values.extend(choice_values) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
463 return values 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
465 elif schema['type'] == 'default': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
466 # This will happen if the field has a default value; we ignore it while extracting the discriminator values
467 return self._infer_discriminator_values_for_inner_schema(schema['schema'], source) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
469 elif schema['type'] == 'function-after': 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
470 # After validators don't affect the discriminator values
471 return self._infer_discriminator_values_for_inner_schema(schema['schema'], source) 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
473 elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
474 validator_type = repr(schema['type'].split('-')[1]) 1abcdefghij
475 raise PydanticUserError( 1abcdefghij
476 f'Cannot use a mode={validator_type} validator in the'
477 f' discriminator field {self.discriminator!r} of {source}',
478 code='discriminator-validator',
479 )
481 else:
482 raise PydanticUserError( 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
483 f'{source} needs field {self.discriminator!r} to be of type `Literal`',
484 code='discriminator-needs-literal',
485 )
487 def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None: 1akblcmdneoEFpqrstuvwxyGHIJKLMNOfzgAhBiCjD
488 """This method updates `self.tagged_union_choices` so that all provided (discriminator) `values` map to the
489 provided `choice`, validating that none of these values already map to another (different) choice.
490 """
491 for discriminator_value in values: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
492 if discriminator_value in self._tagged_union_choices: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
493 # It is okay if `value` is already in tagged_union_choices as long as it maps to the same value.
494 # Because tagged_union_choices may map values to other values, we need to walk the choices dict
495 # until we get to a "real" choice, and confirm that is equal to the one assigned.
496 existing_choice = self._tagged_union_choices[discriminator_value] 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
497 if existing_choice != choice: 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
498 raise TypeError( 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD
499 f'Value {discriminator_value!r} for discriminator '
500 f'{self.discriminator!r} mapped to multiple choices'
501 )
502 else:
503 self._tagged_union_choices[discriminator_value] = choice 1akblcmdneoEFpqrstuvwxyfzgAhBiCjD