Coverage for sqlmodel/_compat.py: 93%
290 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 00:02 +0000
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 00:02 +0000
1import types 1deabcf
2from contextlib import contextmanager 1deabcf
3from contextvars import ContextVar 1deabcf
4from dataclasses import dataclass 1deabcf
5from typing import ( 1deabcf
6 TYPE_CHECKING,
7 AbstractSet,
8 Any,
9 Callable,
10 Dict,
11 ForwardRef,
12 Generator,
13 Mapping,
14 Optional,
15 Set,
16 Type,
17 TypeVar,
18 Union,
19)
21from pydantic import VERSION as P_VERSION 1deabcf
22from pydantic import BaseModel 1deabcf
23from pydantic.fields import FieldInfo 1deabcf
24from typing_extensions import get_args, get_origin 1deabcf
26# Reassign variable to make it reexported for mypy
27PYDANTIC_VERSION = P_VERSION 1deabcf
28IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") 1deabcf
31if TYPE_CHECKING: 1deabcf
32 from .main import RelationshipInfo, SQLModel
34UnionType = getattr(types, "UnionType", Union) 1deabcf
35NoneType = type(None) 1deabcf
36T = TypeVar("T") 1deabcf
37InstanceOrType = Union[T, Type[T]] 1deabcf
38_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") 1deabcf
41class FakeMetadata: 1deabcf
42 max_length: Optional[int] = None 1deabcf
43 max_digits: Optional[int] = None 1deabcf
44 decimal_places: Optional[int] = None 1deabcf
47@dataclass 1deabcf
48class ObjectWithUpdateWrapper: 1eabcf
49 obj: Any 1deabcf
50 update: Dict[str, Any] 1deabcf
52 def __getattribute__(self, __name: str) -> Any: 1deabcf
53 update = super().__getattribute__("update") 1dabc
54 obj = super().__getattribute__("obj") 1dabc
55 if __name in update: 1dabc
56 return update[__name] 1dabc
57 return getattr(obj, __name) 1dabc
60def _is_union_type(t: Any) -> bool: 1deabcf
61 return t is UnionType or t is Union 1dabc
64finish_init: ContextVar[bool] = ContextVar("finish_init", default=True) 1deabcf
67@contextmanager 1deabcf
68def partial_init() -> Generator[None, None, None]: 1deabcf
69 token = finish_init.set(False) 1dabc
70 yield 1dabc
71 finish_init.reset(token) 1dabc
74if IS_PYDANTIC_V2: 1deabcf
75 from annotated_types import MaxLen 1dabc
76 from pydantic import ConfigDict as BaseConfig 1dabc
77 from pydantic._internal._fields import PydanticMetadata 1dabc
78 from pydantic._internal._model_construction import ModelMetaclass 1dabc
79 from pydantic._internal._repr import Representation as Representation 1dabc
80 from pydantic_core import PydanticUndefined as Undefined 1dabc
81 from pydantic_core import PydanticUndefinedType as UndefinedType 1dabc
83 # Dummy for types, to make it importable
84 class ModelField: 1dabc
85 pass 1dabc
87 class SQLModelConfig(BaseConfig, total=False): 1dabc
88 table: Optional[bool] 1dabc
89 registry: Optional[Any] 1dabc
91 def get_config_value( 1abc
92 *, model: InstanceOrType["SQLModel"], parameter: str, default: Any = None
93 ) -> Any:
94 return model.model_config.get(parameter, default) 1dabc
96 def set_config_value( 1abc
97 *,
98 model: InstanceOrType["SQLModel"],
99 parameter: str,
100 value: Any,
101 ) -> None:
102 model.model_config[parameter] = value # type: ignore[literal-required] 1dabc
104 def get_model_fields(model: InstanceOrType[BaseModel]) -> Dict[str, "FieldInfo"]: 1dabc
105 return model.model_fields 1dabc
107 def get_fields_set( 1abc
108 object: InstanceOrType["SQLModel"],
109 ) -> Union[Set[str], Callable[[BaseModel], Set[str]]]:
110 return object.model_fields_set 1dabc
112 def init_pydantic_private_attrs(new_object: InstanceOrType["SQLModel"]) -> None: 1dabc
113 object.__setattr__(new_object, "__pydantic_fields_set__", set()) 1dabc
114 object.__setattr__(new_object, "__pydantic_extra__", None) 1dabc
115 object.__setattr__(new_object, "__pydantic_private__", None) 1dabc
117 def get_annotations(class_dict: Dict[str, Any]) -> Dict[str, Any]: 1dabc
118 return class_dict.get("__annotations__", {}) 1dabc
120 def is_table_model_class(cls: Type[Any]) -> bool: 1dabc
121 config = getattr(cls, "model_config", {}) 1dabc
122 if config: 1dabc
123 return config.get("table", False) or False 1dabc
124 return False 1dabc
126 def get_relationship_to( 1abc
127 name: str,
128 rel_info: "RelationshipInfo",
129 annotation: Any,
130 ) -> Any:
131 origin = get_origin(annotation) 1dabc
132 use_annotation = annotation 1dabc
133 # Direct relationships (e.g. 'Team' or Team) have None as an origin
134 if origin is None: 1dabc
135 if isinstance(use_annotation, ForwardRef): 1dabc
136 use_annotation = use_annotation.__forward_arg__ 1dabc
137 else:
138 return use_annotation 1dabc
139 # If Union (e.g. Optional), get the real field
140 elif _is_union_type(origin): 1dabc
141 use_annotation = get_args(annotation) 1dabc
142 if len(use_annotation) > 2: 1dabc
143 raise ValueError(
144 "Cannot have a (non-optional) union as a SQLAlchemy field"
145 )
146 arg1, arg2 = use_annotation 1dabc
147 if arg1 is NoneType and arg2 is not NoneType: 1dabc
148 use_annotation = arg2
149 elif arg2 is NoneType and arg1 is not NoneType: 1dabc
150 use_annotation = arg1 1dabc
151 else:
152 raise ValueError(
153 "Cannot have a Union of None and None as a SQLAlchemy field"
154 )
156 # If a list, then also get the real field
157 elif origin is list: 1dabc
158 use_annotation = get_args(annotation)[0] 1dabc
160 return get_relationship_to( 1dabc
161 name=name, rel_info=rel_info, annotation=use_annotation
162 )
164 def is_field_noneable(field: "FieldInfo") -> bool: 1dabc
165 if getattr(field, "nullable", Undefined) is not Undefined: 1dabc
166 return field.nullable # type: ignore 1dabc
167 origin = get_origin(field.annotation) 1dabc
168 if origin is not None and _is_union_type(origin): 1dabc
169 args = get_args(field.annotation) 1dabc
170 if any(arg is NoneType for arg in args): 1dabc
171 return True 1dabc
172 if not field.is_required(): 1dabc
173 if field.default is Undefined: 1dabc
174 return False
175 if field.annotation is None or field.annotation is NoneType: # type: ignore[comparison-overlap] 1dabc
176 return True
177 return False 1dabc
178 return False 1dabc
180 def get_type_from_field(field: Any) -> Any: 1dabc
181 type_: Any = field.annotation 1dabc
182 # Resolve Optional fields
183 if type_ is None: 1dabc
184 raise ValueError("Missing field type")
185 origin = get_origin(type_) 1dabc
186 if origin is None: 1dabc
187 return type_ 1dabc
188 if _is_union_type(origin): 1dabc
189 bases = get_args(type_) 1dabc
190 if len(bases) > 2: 1dabc
191 raise ValueError(
192 "Cannot have a (non-optional) union as a SQLAlchemy field"
193 )
194 # Non optional unions are not allowed
195 if bases[0] is not NoneType and bases[1] is not NoneType: 1dabc
196 raise ValueError( 1dabc
197 "Cannot have a (non-optional) union as a SQLAlchemy field"
198 )
199 # Optional unions are allowed
200 return bases[0] if bases[0] is not NoneType else bases[1] 1dabc
201 return origin 1dabc
203 def get_field_metadata(field: Any) -> Any: 1dabc
204 for meta in field.metadata: 1dabc
205 if isinstance(meta, (PydanticMetadata, MaxLen)): 1dabc
206 return meta 1dabc
207 return FakeMetadata() 1dabc
209 def post_init_field_info(field_info: FieldInfo) -> None: 1dabc
210 return None 1dabc
212 # Dummy to make it importable
213 def _calculate_keys( 1abc
214 self: "SQLModel", 1abc
215 include: Optional[Mapping[Union[int, str], Any]], 1abc
216 exclude: Optional[Mapping[Union[int, str], Any]], 1abc
217 exclude_unset: bool, 1abc
218 update: Optional[Dict[str, Any]] = None, 1dabc
219 ) -> Optional[AbstractSet[str]]: # pragma: no cover 1dabc
220 return None
222 def sqlmodel_table_construct( 1abc
223 *,
224 self_instance: _TSQLModel,
225 values: Dict[str, Any],
226 _fields_set: Union[Set[str], None] = None,
227 ) -> _TSQLModel:
228 # Copy from Pydantic's BaseModel.construct()
229 # Ref: https://github.com/pydantic/pydantic/blob/v2.5.2/pydantic/main.py#L198
230 # Modified to not include everything, only the model fields, and to
231 # set relationships
232 # SQLModel override to get class SQLAlchemy __dict__ attributes and
233 # set them back in after creating the object
234 # new_obj = cls.__new__(cls)
235 cls = type(self_instance) 1dabc
236 old_dict = self_instance.__dict__.copy() 1dabc
237 # End SQLModel override
239 fields_values: Dict[str, Any] = {} 1dabc
240 defaults: Dict[ 1abc
241 str, Any
242 ] = {} # keeping this separate from `fields_values` helps us compute `_fields_set`
243 for name, field in cls.model_fields.items(): 1dabc
244 if field.alias and field.alias in values: 1dabc
245 fields_values[name] = values.pop(field.alias)
246 elif name in values: 1dabc
247 fields_values[name] = values.pop(name) 1dabc
248 elif not field.is_required(): 1dabc
249 defaults[name] = field.get_default(call_default_factory=True) 1dabc
250 if _fields_set is None: 1dabc
251 _fields_set = set(fields_values.keys()) 1dabc
252 fields_values.update(defaults) 1dabc
254 _extra: Union[Dict[str, Any], None] = None 1dabc
255 if cls.model_config.get("extra") == "allow": 1dabc
256 _extra = {}
257 for k, v in values.items():
258 _extra[k] = v
259 # SQLModel override, do not include everything, only the model fields
260 # else:
261 # fields_values.update(values)
262 # End SQLModel override
263 # SQLModel override
264 # Do not set __dict__, instead use setattr to trigger SQLAlchemy
265 # object.__setattr__(new_obj, "__dict__", fields_values)
266 # instrumentation
267 for key, value in {**old_dict, **fields_values}.items(): 1dabc
268 setattr(self_instance, key, value) 1dabc
269 # End SQLModel override
270 object.__setattr__(self_instance, "__pydantic_fields_set__", _fields_set) 1dabc
271 if not cls.__pydantic_root_model__: 1dabc
272 object.__setattr__(self_instance, "__pydantic_extra__", _extra) 1dabc
274 if cls.__pydantic_post_init__: 1dabc
275 self_instance.model_post_init(None)
276 elif not cls.__pydantic_root_model__: 1dabc
277 # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist
278 # Since it doesn't, that means that `__pydantic_private__` should be set to None
279 object.__setattr__(self_instance, "__pydantic_private__", None) 1dabc
280 # SQLModel override, set relationships
281 # Get and set any relationship objects
282 for key in self_instance.__sqlmodel_relationships__: 1dabc
283 value = values.get(key, Undefined) 1dabc
284 if value is not Undefined: 1dabc
285 setattr(self_instance, key, value) 1dabc
286 # End SQLModel override
287 return self_instance 1dabc
289 def sqlmodel_validate( 1abc
290 cls: Type[_TSQLModel],
291 obj: Any,
292 *,
293 strict: Union[bool, None] = None,
294 from_attributes: Union[bool, None] = None,
295 context: Union[Dict[str, Any], None] = None,
296 update: Union[Dict[str, Any], None] = None,
297 ) -> _TSQLModel:
298 if not is_table_model_class(cls): 1dabc
299 new_obj: _TSQLModel = cls.__new__(cls) 1dabc
300 else:
301 # If table, create the new instance normally to make SQLAlchemy create
302 # the _sa_instance_state attribute
303 # The wrapper of this function should use with _partial_init()
304 with partial_init(): 1dabc
305 new_obj = cls() 1dabc
306 # SQLModel Override to get class SQLAlchemy __dict__ attributes and
307 # set them back in after creating the object
308 old_dict = new_obj.__dict__.copy() 1dabc
309 use_obj = obj 1dabc
310 if isinstance(obj, dict) and update: 1dabc
311 use_obj = {**obj, **update}
312 elif update: 1dabc
313 use_obj = ObjectWithUpdateWrapper(obj=obj, update=update) 1dabc
314 cls.__pydantic_validator__.validate_python( 1dabc
315 use_obj,
316 strict=strict,
317 from_attributes=from_attributes,
318 context=context,
319 self_instance=new_obj,
320 )
321 # Capture fields set to restore it later
322 fields_set = new_obj.__pydantic_fields_set__.copy() 1dabc
323 if not is_table_model_class(cls): 1dabc
324 # If not table, normal Pydantic code, set __dict__
325 new_obj.__dict__ = {**old_dict, **new_obj.__dict__} 1dabc
326 else:
327 # Do not set __dict__, instead use setattr to trigger SQLAlchemy
328 # instrumentation
329 for key, value in {**old_dict, **new_obj.__dict__}.items(): 1dabc
330 setattr(new_obj, key, value) 1dabc
331 # Restore fields set
332 object.__setattr__(new_obj, "__pydantic_fields_set__", fields_set) 1dabc
333 # Get and set any relationship objects
334 if is_table_model_class(cls): 1dabc
335 for key in new_obj.__sqlmodel_relationships__: 1dabc
336 value = getattr(use_obj, key, Undefined) 1dabc
337 if value is not Undefined: 1dabc
338 setattr(new_obj, key, value)
339 return new_obj 1dabc
341 def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None: 1dabc
342 old_dict = self.__dict__.copy() 1dabc
343 if not is_table_model_class(self.__class__): 1dabc
344 self.__pydantic_validator__.validate_python( 1dabc
345 data,
346 self_instance=self,
347 )
348 else:
349 sqlmodel_table_construct( 1dabc
350 self_instance=self,
351 values=data,
352 )
353 object.__setattr__( 1dabc
354 self,
355 "__dict__",
356 {**old_dict, **self.__dict__},
357 )
359else:
360 from pydantic import BaseConfig as BaseConfig # type: ignore[assignment] 1ef
361 from pydantic.errors import ConfigError 1ef
362 from pydantic.fields import ( # type: ignore[attr-defined, no-redef] 1ef
363 SHAPE_SINGLETON,
364 ModelField,
365 )
366 from pydantic.fields import ( # type: ignore[attr-defined, no-redef] 1ef
367 Undefined as Undefined, # noqa
368 )
369 from pydantic.fields import ( # type: ignore[attr-defined, no-redef] 1ef
370 UndefinedType as UndefinedType,
371 )
372 from pydantic.main import ( # type: ignore[no-redef] 1ef
373 ModelMetaclass as ModelMetaclass,
374 )
375 from pydantic.main import validate_model 1ef
376 from pydantic.typing import resolve_annotations 1ef
377 from pydantic.utils import ROOT_KEY, ValueItems 1ef
378 from pydantic.utils import ( # type: ignore[no-redef] 1ef
379 Representation as Representation,
380 )
382 class SQLModelConfig(BaseConfig): # type: ignore[no-redef] 1ef
383 table: Optional[bool] = None # type: ignore[misc] 1ef
384 registry: Optional[Any] = None # type: ignore[misc] 1ef
386 def get_config_value( 1ef
387 *, model: InstanceOrType["SQLModel"], parameter: str, default: Any = None
388 ) -> Any:
389 return getattr(model.__config__, parameter, default) # type: ignore[union-attr] 1ef
391 def set_config_value( 1ef
392 *,
393 model: InstanceOrType["SQLModel"],
394 parameter: str,
395 value: Any,
396 ) -> None:
397 setattr(model.__config__, parameter, value) # type: ignore 1ef
399 def get_model_fields(model: InstanceOrType[BaseModel]) -> Dict[str, "FieldInfo"]: 1ef
400 return model.__fields__ # type: ignore 1ef
402 def get_fields_set( 1ef
403 object: InstanceOrType["SQLModel"],
404 ) -> Union[Set[str], Callable[[BaseModel], Set[str]]]:
405 return object.__fields_set__ 1ef
407 def init_pydantic_private_attrs(new_object: InstanceOrType["SQLModel"]) -> None: 1ef
408 object.__setattr__(new_object, "__fields_set__", set()) 1ef
410 def get_annotations(class_dict: Dict[str, Any]) -> Dict[str, Any]: 1ef
411 return resolve_annotations( # type: ignore[no-any-return] 1ef
412 class_dict.get("__annotations__", {}),
413 class_dict.get("__module__", None),
414 )
416 def is_table_model_class(cls: Type[Any]) -> bool: 1ef
417 config = getattr(cls, "__config__", None) 1ef
418 if config: 1ef
419 return getattr(config, "table", False) 1ef
420 return False
422 def get_relationship_to( 1ef
423 name: str,
424 rel_info: "RelationshipInfo",
425 annotation: Any,
426 ) -> Any:
427 temp_field = ModelField.infer( # type: ignore[attr-defined] 1ef
428 name=name,
429 value=rel_info,
430 annotation=annotation,
431 class_validators=None,
432 config=SQLModelConfig,
433 )
434 relationship_to = temp_field.type_ 1ef
435 if isinstance(temp_field.type_, ForwardRef): 1ef
436 relationship_to = temp_field.type_.__forward_arg__ 1ef
437 return relationship_to 1ef
439 def is_field_noneable(field: "FieldInfo") -> bool: 1ef
440 if not field.required: # type: ignore[attr-defined] 1ef
441 # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
442 return field.allow_none and ( # type: ignore[attr-defined] 1ef
443 field.shape != SHAPE_SINGLETON or not field.sub_fields # type: ignore[attr-defined]
444 )
445 return field.allow_none # type: ignore[no-any-return, attr-defined] 1ef
447 def get_type_from_field(field: Any) -> Any: 1ef
448 if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: 1ef
449 return field.type_ 1ef
450 raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") 1ef
452 def get_field_metadata(field: Any) -> Any: 1ef
453 metadata = FakeMetadata() 1ef
454 metadata.max_length = field.field_info.max_length 1ef
455 metadata.max_digits = getattr(field.type_, "max_digits", None) 1ef
456 metadata.decimal_places = getattr(field.type_, "decimal_places", None) 1ef
457 return metadata 1ef
459 def post_init_field_info(field_info: FieldInfo) -> None: 1ef
460 field_info._validate() # type: ignore[attr-defined] 1ef
462 def _calculate_keys( 1ef
463 self: "SQLModel",
464 include: Optional[Mapping[Union[int, str], Any]],
465 exclude: Optional[Mapping[Union[int, str], Any]],
466 exclude_unset: bool,
467 update: Optional[Dict[str, Any]] = None,
468 ) -> Optional[AbstractSet[str]]:
469 if include is None and exclude is None and not exclude_unset: 1ef
470 # Original in Pydantic:
471 # return None
472 # Updated to not return SQLAlchemy attributes
473 # Do not include relationships as that would easily lead to infinite
474 # recursion, or traversing the whole database
475 return ( 1ef
476 self.__fields__.keys() # noqa
477 ) # | self.__sqlmodel_relationships__.keys()
479 keys: AbstractSet[str]
480 if exclude_unset: 1ef
481 keys = self.__fields_set__.copy() # noqa 1ef
482 else:
483 # Original in Pydantic:
484 # keys = self.__dict__.keys()
485 # Updated to not return SQLAlchemy attributes
486 # Do not include relationships as that would easily lead to infinite
487 # recursion, or traversing the whole database
488 keys = (
489 self.__fields__.keys() # noqa
490 ) # | self.__sqlmodel_relationships__.keys()
491 if include is not None: 1ef
492 keys &= include.keys()
494 if update: 1ef
495 keys -= update.keys()
497 if exclude: 1ef
498 keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)}
500 return keys 1ef
502 def sqlmodel_validate( 1ef
503 cls: Type[_TSQLModel],
504 obj: Any,
505 *,
506 strict: Union[bool, None] = None,
507 from_attributes: Union[bool, None] = None,
508 context: Union[Dict[str, Any], None] = None,
509 update: Union[Dict[str, Any], None] = None,
510 ) -> _TSQLModel:
511 # This was SQLModel's original from_orm() for Pydantic v1
512 # Duplicated from Pydantic
513 if not cls.__config__.orm_mode: # type: ignore[attr-defined] # noqa 1ef
514 raise ConfigError(
515 "You must have the config attribute orm_mode=True to use from_orm"
516 )
517 if not isinstance(obj, Mapping): 1ef
518 obj = ( 1ef
519 {ROOT_KEY: obj}
520 if cls.__custom_root_type__ # type: ignore[attr-defined] # noqa
521 else cls._decompose_class(obj) # type: ignore[attr-defined] # noqa
522 )
523 # SQLModel, support update dict
524 if update is not None: 1ef
525 obj = {**obj, **update} 1ef
526 # End SQLModel support dict
527 if not getattr(cls.__config__, "table", False): # noqa 1ef
528 # If not table, normal Pydantic code
529 m: _TSQLModel = cls.__new__(cls) 1ef
530 else:
531 # If table, create the new instance normally to make SQLAlchemy create
532 # the _sa_instance_state attribute
533 m = cls() 1ef
534 values, fields_set, validation_error = validate_model(cls, obj) 1ef
535 if validation_error: 1ef
536 raise validation_error
537 # Updated to trigger SQLAlchemy internal handling
538 if not getattr(cls.__config__, "table", False): # noqa 1ef
539 object.__setattr__(m, "__dict__", values) 1ef
540 else:
541 for key, value in values.items(): 1ef
542 setattr(m, key, value) 1ef
543 # Continue with standard Pydantic logic
544 object.__setattr__(m, "__fields_set__", fields_set) 1ef
545 m._init_private_attributes() # type: ignore[attr-defined] # noqa 1ef
546 return m 1ef
548 def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None: 1ef
549 values, fields_set, validation_error = validate_model(self.__class__, data) 1ef
550 # Only raise errors if not a SQLModel model
551 if ( 1e
552 not is_table_model_class(self.__class__) # noqa
553 and validation_error
554 ):
555 raise validation_error 1ef
556 if not is_table_model_class(self.__class__): 1ef
557 object.__setattr__(self, "__dict__", values) 1ef
558 else:
559 # Do not set values as in Pydantic, pass them through setattr, so
560 # SQLAlchemy can handle them
561 for key, value in values.items(): 1ef
562 setattr(self, key, value) 1ef
563 object.__setattr__(self, "__fields_set__", fields_set) 1ef
564 non_pydantic_keys = data.keys() - values.keys() 1ef
566 if is_table_model_class(self.__class__): 1ef
567 for key in non_pydantic_keys: 1ef
568 if key in self.__sqlmodel_relationships__: 1ef
569 setattr(self, key, data[key]) 1ef