Coverage for pydantic/_internal/_discriminated_union.py: 96.99%
197 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-03 19:29 +0000
1from __future__ import annotations as _annotations 1akblcmdneopqKLrstuvwxyzABCMNOPQRSTUVfDgEhFiGjHIJ
3from typing import TYPE_CHECKING, Any, Hashable, Sequence 1akblcmdneopqKLrstuvwxyzABCMNOPQRSTUVfDgEhFiGjHIJ
5from pydantic_core import CoreSchema, core_schema 1akblcmdneopqKLrstuvwxyzABCMNOPQRSTUVfDgEhFiGjHIJ
7from ..errors import PydanticUserError 1akblcmdneopqKLrstuvwxyzABCMNOPQRSTUVfDgEhFiGjHIJ
8from . import _core_utils 1akblcmdneopqKLrstuvwxyzABCMNOPQRSTUVfDgEhFiGjHIJ
9from ._core_utils import ( 1akblcmdneopqKLrstuvwxyzABCMNOPQRSTUVfDgEhFiGjHIJ
10 CoreSchemaField,
11 collect_definitions,
12)
14if TYPE_CHECKING: 1akblcmdneopqKLrstuvwxyzABCMNOPQRSTUVfDgEhFiGjHIJ
15 from ..types import Discriminator
# Metadata key under which `set_discriminator_in_metadata` stores a pending discriminator;
# `apply_discriminators` later pops the value stored under this key and applies it.
CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY: str = 'pydantic.internal.union_discriminator'
class MissingDefinitionForUnionRef(Exception):
    """Signals that a discriminated-union choice points at a schema ref whose definition is not available yet.

    Attributes:
        ref: The schema ref that could not be resolved.
    """

    def __init__(self, ref: str) -> None:
        message = f'Missing definition for ref {ref!r}'
        super().__init__(message)
        self.ref = ref
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
    """Stash `discriminator` in `schema`'s metadata so that `apply_discriminators` can apply it later.

    The value is stored under `CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY`. `schema` is
    modified in place; if it has no metadata entry (or the entry is `None`), a fresh dict is created.

    Args:
        schema: The core schema to annotate.
        discriminator: The discriminator to apply later (stored as-is).
    """
    # A single lookup; unlike the previous `setdefault` + `assert` combination this neither relies on
    # an assert (stripped under `python -O`) nor fails when `metadata` is explicitly present as `None`.
    metadata = schema.get('metadata')
    if metadata is None:
        metadata = schema['metadata'] = {}
    metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator
def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
    """Apply every pending discriminator found in `schema`'s metadata.

    Pending discriminators are stored under `CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY`
    (see `set_discriminator_in_metadata`). Each one is popped from its schema's metadata and
    applied with `apply_discriminator`. Schemas that are already tagged unions are returned unchanged.

    Args:
        schema: The core schema to transform.

    Returns:
        The transformed core schema.

    Raises:
        Any error raised by `apply_discriminator` is propagated.
    """
    # We recursively walk through the `schema` passed to `apply_discriminators`, applying discriminators
    # where necessary at each level. During this recursion, we allow references to be resolved from the definitions
    # that are originally present on the original, outermost `schema`. Before `apply_discriminators` is called,
    # `simplify_schema_references` is called on the schema (in the `clean_schema` function),
    # which often puts the definitions in the outermost schema.
    global_definitions: dict[str, CoreSchema] = collect_definitions(schema)

    def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
        # `global_definitions` is only read here, so no `nonlocal` declaration is needed.
        # Recurse first so nested schemas are processed before this one is inspected.
        s = recurse(s, inner)
        if s['type'] == 'tagged-union':
            return s

        # `or {}` also covers a metadata entry that is present but `None`.
        metadata = s.get('metadata') or {}
        discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
        if discriminator is not None:
            s = apply_discriminator(s, discriminator, global_definitions)
        return s

    return _core_utils.walk_core_schema(schema, inner)
def apply_discriminator(
    schema: core_schema.CoreSchema,
    discriminator: str | Discriminator,
    definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
    """Build a new core schema from `schema` with `discriminator` applied.

    Args:
        schema: The input schema.
        discriminator: Either the name of the field that serves as the discriminator, or a
            `Discriminator` instance. A `Discriminator` wrapping a string is treated like that string;
            any other `Discriminator` is delegated to its own `_convert_schema` method.
        definitions: A mapping of schema ref to schema, used to resolve referenced choices.
            `None` is treated as an empty mapping.

    Returns:
        The new core schema.

    Raises:
        TypeError:
            - If `discriminator` is used with invalid union variant.
            - If `discriminator` is used with `Union` type with one variant.
            - If `discriminator` value mapped to multiple choices.
        MissingDefinitionForUnionRef:
            If the definition for ref is missing.
        PydanticUserError:
            - If a model in union doesn't have a discriminator field.
            - If discriminator field has a non-string alias.
            - If discriminator fields have different aliases.
            - If discriminator field not of type `Literal`.
    """
    # Imported locally; at module level this name is only imported under TYPE_CHECKING.
    from ..types import Discriminator

    if isinstance(discriminator, Discriminator):
        wrapped = discriminator.discriminator
        if not isinstance(wrapped, str):
            # Non-string discriminators know how to build their own schema.
            return discriminator._convert_schema(schema)
        discriminator = wrapped

    applier = _ApplyInferredDiscriminator(discriminator, definitions or {})
    return applier.apply(schema)
class _ApplyInferredDiscriminator:
    """This class is used to convert an input schema containing a union schema into one where that union is
    replaced with a tagged-union, with all the associated debugging and performance benefits.

    This is done by:
    * Validating that the input schema is compatible with the provided discriminator
    * Introspecting the schema to determine which discriminator values should map to which union choices
    * Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more

    I have chosen to implement the conversion algorithm in this class, rather than a function,
    to make it easier to maintain state while recursively walking the provided CoreSchema.

    Instances are single-use: `apply` asserts that it has not been called before.
    """

    def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]):
        # `discriminator` should be the name of the field which will serve as the discriminator.
        # It must be the python name of the field, and *not* the field's alias. Note that as of now,
        # all members of a discriminated union _must_ use a field with the same name as the discriminator.
        # This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
        self.discriminator = discriminator

        # `definitions` should contain a mapping of schema ref to schema for all schemas which might
        # be referenced by some choice
        self.definitions = definitions

        # `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
        #
        # Note: following the v1 implementation, we currently disallow the use of different aliases
        # for different choices. This is not a limitation of pydantic_core, but if we try to handle
        # this, the inference logic gets complicated very quickly, and could result in confusing
        # debugging challenges for users making subtle mistakes.
        #
        # Rather than trying to do the most powerful inference possible, I think we should eventually
        # expose a way to more-manually control the way the TaggedUnionSchema is constructed through
        # the use of a new type which would be placed as an Annotation on the Union type. This would
        # provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for
        # more complex cases, without over-complicating the inference logic for the common cases.
        self._discriminator_alias: str | None = None

        # `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
        # If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
        # constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
        # Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
        # that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
        # python side, but resolves the issue that `None` cannot correspond to any discriminator values.
        self._should_be_nullable = False

        # `_is_nullable` is used to track if the final produced schema will definitely be nullable;
        # we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
        # as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
        # the final value in another nullable schema.
        #
        # This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
        # to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
        self._is_nullable = False

        # `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
        # from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
        # algorithm may add more choices to this stack as (nested) unions are encountered.
        self._choices_to_handle: list[core_schema.CoreSchema] = []

        # `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
        # in the output TaggedUnionSchema that will replace the union from the input schema
        self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}

        # `_used` is changed to True after applying the discriminator to prevent accidental re-use
        self._used = False

    def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
        to this class.

        Args:
            schema: The input schema.

        Returns:
            The new core schema.

        Raises:
            TypeError:
                - If `discriminator` is used with invalid union variant.
                - If `discriminator` is used with `Union` type with one variant.
                - If `discriminator` value mapped to multiple choices.
            MissingDefinitionForUnionRef:
                If the definition for ref is missing.
            PydanticUserError:
                - If a model in union doesn't have a discriminator field.
                - If discriminator field has a non-string alias.
                - If discriminator fields have different aliases.
                - If discriminator field not of type `Literal`.
        """
        assert not self._used
        schema = self._apply_to_root(schema)
        if self._should_be_nullable and not self._is_nullable:
            schema = core_schema.nullable_schema(schema)
        self._used = True
        return schema

    def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """This method handles the outer-most stage of recursion over the input schema:
        unwrapping nullable or definitions schemas, and calling the `_handle_choice`
        method iteratively on the choices extracted (recursively) from the possibly-wrapped union.
        """
        if schema['type'] == 'nullable':
            self._is_nullable = True
            wrapped = self._apply_to_root(schema['schema'])
            nullable_wrapper = schema.copy()
            nullable_wrapper['schema'] = wrapped
            return nullable_wrapper

        if schema['type'] == 'definitions':
            wrapped = self._apply_to_root(schema['schema'])
            definitions_wrapper = schema.copy()
            definitions_wrapper['schema'] = wrapped
            return definitions_wrapper

        if schema['type'] != 'union':
            # If the schema is not a union, it probably means it just had a single member and
            # was flattened by pydantic_core.
            # However, it still may make sense to apply the discriminator to this schema,
            # as a way to get discriminated-union-style error messages, so we allow this here.
            schema = core_schema.union_schema([schema])

        # Reverse the choices list before extending the stack so that they get handled in the order they occur
        choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]]
        self._choices_to_handle.extend(choices_schemas)
        while self._choices_to_handle:
            choice = self._choices_to_handle.pop()
            self._handle_choice(choice)

        if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator:
            # * We need to annotate `discriminator` as a union here to handle both branches of this conditional
            # * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
            #   invariance of list, and because list[list[str | int]] is the type of the discriminator argument
            #   to tagged_union_schema below
            # * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
            #   interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
            #   is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
            discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]]
        else:
            discriminator = self.discriminator
        return core_schema.tagged_union_schema(
            choices=self._tagged_union_choices,
            discriminator=discriminator,
            custom_error_type=schema.get('custom_error_type'),
            custom_error_message=schema.get('custom_error_message'),
            custom_error_context=schema.get('custom_error_context'),
            strict=False,
            from_attributes=True,
            ref=schema.get('ref'),
            metadata=schema.get('metadata'),
            serialization=schema.get('serialization'),
        )

    def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
        """This method handles the "middle" stage of recursion over the input schema.
        Specifically, it is responsible for handling each choice of the outermost union
        (and any "coalesced" choices obtained from inner unions).

        Here, "handling" entails:
        * Coalescing nested unions and compatible tagged-unions
        * Tracking the presence of 'none' and 'nullable' schemas occurring as choices
        * Validating that each allowed discriminator value maps to a unique choice
        * Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
        """
        if choice['type'] == 'definition-ref':
            if choice['schema_ref'] not in self.definitions:
                raise MissingDefinitionForUnionRef(choice['schema_ref'])

        if choice['type'] == 'none':
            self._should_be_nullable = True
        elif choice['type'] == 'definitions':
            self._handle_choice(choice['schema'])
        elif choice['type'] == 'nullable':
            self._should_be_nullable = True
            self._handle_choice(choice['schema'])  # unwrap the nullable schema
        elif choice['type'] == 'union':
            # Reverse the choices list before extending the stack so that they get handled in the order they occur
            choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
            self._choices_to_handle.extend(choices_schemas)
        elif choice['type'] not in {
            'model',
            'typed-dict',
            'tagged-union',
            'lax-or-strict',
            'dataclass',
            'dataclass-args',
            'definition-ref',
        } and not _core_utils.is_function_with_inner_schema(choice):
            # Any other schema type cannot be a discriminated union variant.
            raise TypeError(
                f'{choice["type"]!r} is not a valid discriminated union variant;'
                ' should be a `BaseModel` or `dataclass`'
            )
        else:
            if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
                # In this case, this inner tagged-union is compatible with the outer tagged-union,
                # and its choices can be coalesced into the outer TaggedUnionSchema.
                subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
                # Reverse the choices list before extending the stack so that they get handled in the order they occur
                self._choices_to_handle.extend(subchoices[::-1])
                return

            inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
            self._set_unique_choice_for_values(choice, inferred_discriminator_values)

    def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool:
        """This method returns a boolean indicating whether the discriminator for the `choice`
        is the same as that being used for the outermost tagged union. This is used to
        determine whether this TaggedUnionSchema choice should be "coalesced" into the top level,
        or whether it should be treated as a separate (nested) choice.
        """
        inner_discriminator = choice['discriminator']
        return inner_discriminator == self.discriminator or (
            isinstance(inner_discriminator, list)
            and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
        )

    @staticmethod
    def _dedupe_discriminator_values(values: list[str | int]) -> list[str | int]:
        """Return `values` without duplicates, sorted when the values are mutually orderable.

        Sorting keeps the output deterministic for the common case (all `str` or all `int` values). Mixed
        types such as `Literal['a', 1]`, or unorderable values such as plain `Enum` members, would make
        `sorted` raise `TypeError`; in that case the first-seen order is kept instead of crashing.
        """
        unique = list(dict.fromkeys(values))
        try:
            return sorted(unique)
        except TypeError:
            return unique

    def _infer_discriminator_values_for_choice(  # noqa C901
        self, choice: core_schema.CoreSchema, source_name: str | None
    ) -> list[str | int]:
        """This function recurses over `choice`, extracting all discriminator values that should map to this choice.

        `source_name` (the model or dataclass name, when known) is accepted for the purpose of producing
        useful error messages.
        """
        if choice['type'] == 'definitions':
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
        elif choice['type'] == 'function-plain':
            raise TypeError(
                f'{choice["type"]!r} is not a valid discriminated union variant;'
                ' should be a `BaseModel` or `dataclass`'
            )
        elif _core_utils.is_function_with_inner_schema(choice):
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
        elif choice['type'] == 'lax-or-strict':
            return self._dedupe_discriminator_values(
                self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
                + self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
            )

        elif choice['type'] == 'tagged-union':
            values: list[str | int] = []
            # Ignore str/int "choices" since these are just references to other choices
            subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
            for subchoice in subchoices:
                subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None)
                values.extend(subchoice_values)
            return values

        elif choice['type'] == 'union':
            values = []
            for subchoice in choice['choices']:
                subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice
                subchoice_values = self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None)
                values.extend(subchoice_values)
            return values

        elif choice['type'] == 'nullable':
            self._should_be_nullable = True
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None)

        elif choice['type'] == 'model':
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)

        elif choice['type'] == 'dataclass':
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)

        elif choice['type'] == 'model-fields':
            return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name)

        elif choice['type'] == 'dataclass-args':
            return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name)

        elif choice['type'] == 'typed-dict':
            return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name)

        elif choice['type'] == 'definition-ref':
            schema_ref = choice['schema_ref']
            if schema_ref not in self.definitions:
                raise MissingDefinitionForUnionRef(schema_ref)
            return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
        else:
            raise TypeError(
                f'{choice["type"]!r} is not a valid discriminated union variant;'
                ' should be a `BaseModel` or `dataclass`'
            )

    def _infer_discriminator_values_for_typed_dict_choice(
        self, choice: core_schema.TypedDictSchema, source_name: str | None = None
    ) -> list[str | int]:
        """This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema
        for the sake of readability.
        """
        source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}'
        field = choice['fields'].get(self.discriminator)
        if field is None:
            raise PydanticUserError(
                f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
            )
        return self._infer_discriminator_values_for_field(field, source)

    def _infer_discriminator_values_for_model_choice(
        self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None
    ) -> list[str | int]:
        """Extract the discriminator values of a ModelFieldsSchema choice (raises if the field is missing)."""
        source = 'ModelFields' if source_name is None else f'Model {source_name!r}'
        field = choice['fields'].get(self.discriminator)
        if field is None:
            raise PydanticUserError(
                f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
            )
        return self._infer_discriminator_values_for_field(field, source)

    def _infer_discriminator_values_for_dataclass_choice(
        self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
    ) -> list[str | int]:
        """Extract the discriminator values of a DataclassArgsSchema choice (raises if the field is missing)."""
        source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}'
        for field in choice['fields']:
            if field['name'] == self.discriminator:
                break
        else:
            raise PydanticUserError(
                f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
            )
        return self._infer_discriminator_values_for_field(field, source)

    def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]:
        """Validate the discriminator field's alias and extract its allowed values from the field's schema."""
        if field['type'] == 'computed-field':
            # This should never occur as a discriminator, as it is only relevant to serialization
            return []
        alias = field.get('validation_alias', self.discriminator)
        if not isinstance(alias, str):
            raise PydanticUserError(
                f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
            )
        if self._discriminator_alias is None:
            self._discriminator_alias = alias
        elif self._discriminator_alias != alias:
            raise PydanticUserError(
                f'Aliases for discriminator {self.discriminator!r} must be the same '
                f'(got {alias}, {self._discriminator_alias})',
                code='discriminator-alias',
            )
        return self._infer_discriminator_values_for_inner_schema(field['schema'], source)

    def _infer_discriminator_values_for_inner_schema(
        self, schema: core_schema.CoreSchema, source: str
    ) -> list[str | int]:
        """When inferring discriminator values for a field, we typically extract the expected values from a literal
        schema. This function does that, but also handles nested unions and defaults.
        """
        if schema['type'] == 'literal':
            return schema['expected']

        elif schema['type'] == 'union':
            # Generally when multiple values are allowed they should be placed in a single `Literal`, but
            # we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
            # For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
            values: list[Any] = []
            for choice in schema['choices']:
                choice_schema = choice[0] if isinstance(choice, tuple) else choice
                choice_values = self._infer_discriminator_values_for_inner_schema(choice_schema, source)
                values.extend(choice_values)
            return values

        elif schema['type'] == 'default':
            # This will happen if the field has a default value; we ignore it while extracting the discriminator values
            return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)

        elif schema['type'] == 'function-after':
            # After validators don't affect the discriminator values
            return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)

        elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}:
            validator_type = repr(schema['type'].split('-')[1])
            raise PydanticUserError(
                f'Cannot use a mode={validator_type} validator in the'
                f' discriminator field {self.discriminator!r} of {source}',
                code='discriminator-validator',
            )

        else:
            raise PydanticUserError(
                f'{source} needs field {self.discriminator!r} to be of type `Literal`',
                code='discriminator-needs-literal',
            )

    def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None:
        """This method updates `self._tagged_union_choices` so that all provided (discriminator) `values` map to the
        provided `choice`, validating that none of these values already map to another (different) choice.

        Raises:
            TypeError: If a value is already mapped to a choice that is not equal to `choice`.
        """
        for discriminator_value in values:
            if discriminator_value in self._tagged_union_choices:
                # It is okay if `value` is already in tagged_union_choices as long as it maps to an equal choice
                # (e.g. when the same choice is reached more than once). This method only ever stores choice
                # schemas as values, so comparing directly against the stored entry is sufficient.
                existing_choice = self._tagged_union_choices[discriminator_value]
                if existing_choice != choice:
                    raise TypeError(
                        f'Value {discriminator_value!r} for discriminator '
                        f'{self.discriminator!r} mapped to multiple choices'
                    )
            else:
                self._tagged_union_choices[discriminator_value] = choice