Coverage for src/metador_core/schema/partial.py: 99%
182 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-02 09:33 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-02 09:33 +0000
1"""Partial pydantic models.
3Partial models replicate another model with all fields made optional.
4However, they cannot be part of the inheritance chain of the original
5models, because semantically it does not make sense or leads to
6a diamond problem.
8Therefore they are independent models, unrelated to the original
9models, but convertable through methods.
11Partial models are implemented as a mixin class to be combined
12with the top level base model that you are using in your model
13hierarchy, e.g. if you use plain `BaseModel`, you can e.g. define:
15```
16class MyPartial(DeepPartialModel, BaseModel): ...
17```
19And use `MyPartial.get_partial` on your models.
21If different compatible model instances are merged,
22the merge will produce an instance of the left type.
24Some theory - partial schemas form a monoid with:
25* the empty partial schema as neutral element
26* merge of the fields as the binary operation
27* associativity follows from associativity of used merge operations
28"""
30from __future__ import annotations
32from functools import reduce
33from typing import (
34 Any,
35 ClassVar,
36 Dict,
37 ForwardRef,
38 Iterator,
39 List,
40 Optional,
41 Set,
42 Tuple,
43 Type,
44 Union,
45 cast,
46)
48from pydantic import BaseModel, ValidationError, create_model, validate_model
49from pydantic.fields import FieldInfo
50from typing_extensions import Annotated
52from ..util import is_public_name
53from ..util import typing as t
def _is_list_or_set(hint):
    """Return whether the type hint `hint` denotes a list or a set type."""
    is_list = t.is_list(hint)
    return is_list or t.is_set(hint)
def _check_type_mergeable(hint, *, allow_none: bool = False) -> bool:
    """Check whether a type is mergeable.

    Terminology:

    An *atomic* type (usually a primitive value, e.g. int/bool/etc.) is:
    * not None,
    * not a List, Set, Union or Optional,
    * not a model.

    A *singular* type is:
    * an atomic type, or
    * a model, or
    * a Union of multiple singular types.

    A *complex* type is either:
    * a singular type, or
    * a List/Set of a singular type.

    A *mergeable* type is either:
    * a complex type, or
    * an Optional complex type (accepted at this level only if `allow_none`).

    Merge semantics:

        merge(None, None) = None
        merge(None, x) = merge(x, None) = x
        merge(x: model1, y: model2)
          | model1 < model2 or model2 < model1 = recursive_merge(x, y)
          | otherwise = y
        merge(x: singular, y: singular) = y
        merge(x: list, y: list) = x ++ y
        merge(x: set, y: set) = x.union(y)
    """
    type_args = t.get_args(hint)

    # lists and sets are acceptable iff their element types are
    if _is_list_or_set(hint):
        return all(_check_type_mergeable(arg) for arg in type_args)

    # neither union nor list/set: either a primitive (overwritten on merge)
    # or a model (merged recursively) -> fine
    if not t.is_union(hint):
        return True

    # an Optional is only acceptable where None is explicitly allowed
    if not allow_none and t.is_optional(hint):
        return False

    # a union containing a list or set may only combine it with None
    contains_collection = any(_is_list_or_set(arg) for arg in type_args)
    only_with_none = len(type_args) == 2 and t.NoneType in type_args
    if contains_collection and not only_with_none:
        return False

    # this level is fine -> check each union member
    return all(_check_type_mergeable(arg) for arg in type_args)
def is_mergeable_type(hint) -> bool:
    """Tell whether values of type `hint` support a deep merge.

    Only hints of a restricted shape qualify (see `_check_type_mergeable`);
    unlike that helper's default, a top-level `Optional[...]` is accepted.
    """
    return _check_type_mergeable(hint=hint, allow_none=True)
def val_from_partial(val):
    """Convert `val` back from partial form, recursing into lists and sets.

    Partial model instances are turned into their original model via
    `from_partial` (which runs full validation); lists and sets are rebuilt
    (as plain `list` / `set`) with converted elements; any other value is
    returned unchanged.
    """
    if isinstance(val, PartialModel):
        return val.from_partial()
    for container_type in (list, set):
        if isinstance(val, container_type):
            return container_type(map(val_from_partial, val))
    return val
class PartialModel:
    """Base partial metadata model mixin.

    Merging (see `merge_with` / `merge`) combines field values:
    missing (`None`) values are filled in, lists are concatenated,
    sets are unionized, compatible nested models are merged recursively,
    and any other value is overwritten by the new one (only if
    `allow_overwrite=True`, otherwise a `ValueError` is raised).
    """

    # Original model class this partial class is based on.
    __partial_src__: ClassVar[Type[BaseModel]]

    # Factory class that created this partial.
    __partial_fac__: ClassVar[Type[PartialFactory]]

    def from_partial(self):
        """Return a new non-partial model instance (will run full validation).

        Nested partial values are converted back first (see `val_from_partial`).

        Raises ValidationError on failure (e.g. if the partial is missing fields).
        """
        fields = {
            k: val_from_partial(v)
            for k, v in self.__partial_fac__._get_field_vals(self)
        }
        return self.__partial_src__.parse_obj(fields)

    @classmethod
    def to_partial(cls, obj, *, ignore_invalid: bool = False):
        """Transform `obj` into a new instance of this partial model.

        If passed object is instance of (a subclass of) this or the original model,
        no validation is performed.

        Any other pydantic model is first converted to a dict (dropping `None`
        values) and then validated like a plain dict.

        Returns partial instance with the successfully parsed fields.

        Raises ValidationError if parsing fails
        (unless `ignore_invalid=True` is passed, in which case only the
        successfully validated fields are kept).
        """
        if isinstance(obj, (cls, cls.__partial_src__)):
            # safe, because subclasses are "stricter"
            return cls.construct(**obj.__dict__)  # type: ignore

        # parse a dict or another pydantic model
        if isinstance(obj, BaseModel):
            # convert first: validate_model (below) expects a mapping,
            # not a model instance
            obj = obj.dict(exclude_none=True)  # type: ignore

        if ignore_invalid:
            # validate data and keep only valid fields
            data, fields, _ = validate_model(cls, obj)  # type: ignore
            return cls.construct(_fields_set=fields, **data)  # type: ignore

        return cls.parse_obj(obj)  # type: ignore

    @classmethod
    def cast(
        cls,
        obj: Union[BaseModel, PartialModel],
        *,
        ignore_invalid: bool = False,
    ):
        """Cast given object into this partial model if needed.

        If it already is an instance, will do nothing.
        Otherwise, will call `to_partial`.
        """
        if isinstance(obj, cls):
            return obj
        return cls.to_partial(obj, ignore_invalid=ignore_invalid)

    def _update_field(
        self,
        v_old,
        v_new,
        *,
        path: Optional[List[str]] = None,
        allow_overwrite: bool = False,
    ):
        """Return merged result of the two passed arguments.

        None is always overwritten by a non-None value,
        lists are concatenated, sets are unionized,
        partial models are recursively merged,
        otherwise the new value overwrites the old one
        (only if `allow_overwrite` is set, otherwise `ValueError` is raised).

        `path` lists the field names leading to this value; it is passed on
        to nested merges and used in the error message.
        """
        path = path or []
        # None -> missing value -> just use the other value (shortcut).
        # Compare with None explicitly: falsy values like 0, False or ""
        # are legitimate values and must not be dropped.
        if v_old is None:
            return v_new
        if v_new is None:
            return v_old

        # list -> new one must also be a list -> concatenate
        if isinstance(v_old, list):
            return v_old + v_new

        # set -> new one must also be a set -> set union
        if isinstance(v_old, set):
            # NOTE: we could try being smarter for sets of partial models
            # https://github.com/Materials-Data-Science-and-Informatics/metador-core/issues/20
            return v_old.union(v_new)  # set union

        # another model -> recursive merge of partials, if compatible
        old_is_model = isinstance(v_old, self.__partial_fac__.base_model)
        new_is_model = isinstance(v_new, self.__partial_fac__.base_model)
        if old_is_model and new_is_model:
            v_old_p = self.__partial_fac__.get_partial(type(v_old)).cast(v_old)
            v_new_p = self.__partial_fac__.get_partial(type(v_new)).cast(v_new)
            new_subclass_old = issubclass(type(v_new_p), type(v_old_p))
            old_subclass_new = issubclass(type(v_old_p), type(v_new_p))
            if new_subclass_old or old_subclass_new:
                try:
                    return v_old_p.merge_with(
                        v_new_p, allow_overwrite=allow_overwrite, _path=path
                    )
                except ValidationError:
                    # casting failed -> proceed to next merge variant
                    # TODO: maybe raise unless "ignore invalid"?
                    pass

        # if we're here, treat it as an opaque value
        if not allow_overwrite:
            msg_title = (
                f"Can't overwrite (allow_overwrite=False) at {' -> '.join(path)}:"
            )
            msg = f"{msg_title}\n\t{repr(v_old)}\n\twith\n\t{repr(v_new)}"
            raise ValueError(msg)
        return v_new

    def merge_with(
        self,
        obj,
        *,
        ignore_invalid: bool = False,
        allow_overwrite: bool = False,
        _path: Optional[List[str]] = None,
    ):
        """Return a new partial model with updated fields (without validation).

        `obj` is first cast into this partial model (see `cast`), then each of
        its non-None public fields is merged into a copy of `self`
        (see `_update_field`).

        Raises `ValidationError` if passed `obj` is not suitable,
        unless `ignore_invalid` is set to `True`.

        Raises `ValueError` if `allow_overwrite=False` and a value would be overwritten.
        """
        _path = _path or []
        obj = self.cast(obj, ignore_invalid=ignore_invalid)  # raises on failure

        ret = self.copy()  # type: ignore
        for f_name, v_new in self.__partial_fac__._get_field_vals(obj):
            v_old = ret.__dict__.get(f_name)
            v_merged = self._update_field(
                v_old, v_new, path=_path + [f_name], allow_overwrite=allow_overwrite
            )
            ret.__dict__[f_name] = v_merged
        return ret

    @classmethod
    def merge(cls, *objs: PartialModel, **kwargs) -> PartialModel:
        """Merge all passed partial models in given order using `merge_with`.

        Recognized keyword arguments are `ignore_invalid` and `allow_overwrite`
        (both default to `False`); they are applied to every cast and merge step.

        Returns an empty instance of this partial model if no objects are passed.
        """
        # NOTE: options are read from **kwargs to keep the existing signature
        ignore_invalid = kwargs.get("ignore_invalid", False)
        allow_overwrite = kwargs.get("allow_overwrite", False)

        if not objs:
            return cls()

        def merge_two(x, y):
            # honor ignore_invalid for the left operand as well
            return cls.cast(x, ignore_invalid=ignore_invalid).merge_with(
                y, ignore_invalid=ignore_invalid, allow_overwrite=allow_overwrite
            )

        # with a single object, reduce returns it unchanged -> cast it here
        return cls.cast(reduce(merge_two, objs), ignore_invalid=ignore_invalid)
307# ----
308# Partial factory
309#
310# PartialModel-specific lookup dicts
311_partials: Dict[Type[PartialFactory], Dict[Type[BaseModel], Type[PartialModel]]] = {}
312_forwardrefs: Dict[Type[PartialFactory], Dict[str, Type[PartialModel]]] = {}
class PartialFactory:
    """Factory class to create and manage partial models.

    The class attributes `base_model` (root of the model hierarchy to create
    partials for) and `partial_mixin` (mixin every generated partial inherits)
    are read via `cls`, so subclasses can override them. Created partials are
    registered per factory class in the module-level `_partials` and
    `_forwardrefs` dicts.
    """

    base_model: Type[BaseModel] = BaseModel
    partial_mixin: Type[PartialModel] = PartialModel

    # TODO: how to configure the whole partial family
    # with default parameters for merge() efficiently?
    # Idea: class-level defaults (e.g. `allow_overwrite: bool = False`,
    # `ignore_invalid: bool = False`) used by merge() if not explicitly passed.

    @classmethod
    def _is_base_subclass(cls, obj: Any) -> bool:
        """Return whether `obj` is a class derived from `base_model`.

        Objects that are not classes (e.g. most type hints) yield `False`.
        """
        return isinstance(obj, type) and issubclass(obj, cls.base_model)

    @classmethod
    def _partial_name(cls, mcls: Type[BaseModel]) -> str:
        """Return the class name used for the partial of model `mcls`."""
        mixin_name = cls.partial_mixin.__name__
        return f"{mcls.__qualname__}.{mixin_name}"

    @classmethod
    def _partial_forwardref_name(cls, mcls: Type[BaseModel]) -> str:
        """Return the ForwardRef string for the partial of model `mcls`.

        Built from module and partial name, with every `.` replaced by `_`.
        """
        qualified = f"__{mcls.__module__}_{cls._partial_name(mcls)}"
        return qualified.replace(".", "_")

    @classmethod
    def _get_field_vals(cls, obj: BaseModel) -> Iterator[Tuple[str, Any]]:
        """Lazily yield `(name, value)` pairs of public, non-None fields of `obj`.

        Unlike `BaseModel.dict`, this reads `obj.__dict__` directly (ignoring
        aliases); it is meant for the "internal representation" only.
        """
        entries = obj.__dict__.items()
        return (
            (name, val) for name, val in entries if is_public_name(name) and val is not None
        )

    @classmethod
    def _nested_models(cls, field_types: Dict[str, t.TypeHint]) -> Set[Type[BaseModel]]:
        """Collect the compatible model classes occurring anywhere in the hints.

        These are the models for which partials are needed as well.
        """
        found: Set[Type[BaseModel]] = set()
        for hint in field_types.values():
            for sub_hint in t.traverse_typehint(hint):
                if cls._is_base_subclass(sub_hint):
                    found.add(cast(Type[BaseModel], sub_hint))
        return found

    @classmethod
    def _model_to_partial_fref(cls, orig_type: t.TypeHint) -> t.TypeHint:
        """Map a compatible model class to a ForwardRef to its partial.

        Any other hint is returned as-is.
        """
        if cls._is_base_subclass(orig_type):
            return ForwardRef(cls._partial_forwardref_name(orig_type))
        return orig_type

    @classmethod
    def _model_to_partial(
        cls, mcls: Type[BaseModel]
    ) -> Type[Union[BaseModel, PartialModel]]:
        """Map a compatible model class to its partial (created on demand).

        Any other class is returned as-is.
        """
        if not cls._is_base_subclass(mcls):
            return mcls
        return cls.get_partial(mcls)

    @classmethod
    def _partial_type(cls, orig_type: t.TypeHint) -> t.TypeHint:
        """Turn a field type hint into the corresponding hint for the partial.

        All compatible nested models are replaced by forward references to
        their partials, and the resulting hint is wrapped in `Optional`.
        """
        substituted = t.map_typehint(orig_type, cls._model_to_partial_fref)
        return Optional[substituted]

    @classmethod
    def _partial_field(cls, orig_type: t.TypeHint) -> Tuple[Type, Optional[FieldInfo]]:
        """Build a `(type, FieldInfo or None)` field declaration for `create_model`.

        An `Annotated[...]` hint is unwrapped, keeping the first pydantic
        `FieldInfo` found in its metadata (if any).
        """
        if t.get_origin(orig_type) is not Annotated:
            return (cls._partial_type(orig_type), None)
        inner, *metadata = t.get_args(orig_type)
        field_info = next((m for m in metadata if isinstance(m, FieldInfo)), None)
        return (cls._partial_type(inner), field_info)

    @classmethod
    def _create_base_partial(cls):
        """Create the partial for `base_model` itself (frozen, to be hashable)."""
        bases = (cls.partial_mixin, cls.base_model)

        class PartialBaseModel(*bases):
            class Config:
                frozen = True  # make sure it's hashable

        return PartialBaseModel

    @classmethod
    def _create_partial(cls, mcls: Type[BaseModel], *, typehints=None):
        """Create a new partial model class based on `mcls`.

        Returns a pair of the new partial class and the nested model classes
        that also need partials. Raises `TypeError` if `mcls` is not derived
        from `base_model`.
        """
        if not cls._is_base_subclass(mcls):
            raise TypeError(f"{mcls} is not subclass of {cls.base_model.__name__}!")
        if mcls is cls.base_model:
            return (cls._create_base_partial(), [])

        # field annotations (passed-in hints avoid re-evaluating them)
        all_hints = typehints or t.get_type_hints(mcls)
        field_types = {
            name: hint for name, hint in all_hints.items() if name in mcls.__fields__
        }
        # nested models whose hints get substituted by partial forward refs
        nested = cls._nested_models(field_types)
        partial_fields = {
            name: cls._partial_field(hint) for name, hint in field_types.items()
        }
        # the partial inherits from the partials of the original bases
        partial_bases = tuple(cls._model_to_partial(base) for base in mcls.__bases__)

        partial: Type[PartialModel] = create_model(
            cls._partial_name(mcls),
            __base__=partial_bases,
            __module__=mcls.__module__,
            __validators__=mcls.__validators__,  # type: ignore
            **partial_fields,
        )
        # link the partial to its original model and to this factory
        partial.__partial_src__ = mcls
        partial.__partial_fac__ = cls
        return partial, nested

    @classmethod
    def get_partial(cls, mcls: Type[BaseModel], *, typehints=None):
        """Return a partial schema with all fields of the given schema optional.

        Original default values are not respected and are set to `None`.

        The use of the returned class is for validating partial data before
        zipping together partial results into a completed one.

        This is a much more fancy version of e.g.
        https://github.com/pydantic/pydantic/issues/1799

        because it recursively substitutes with partial models.
        This allows us to implement smart deep merge for partials.

        Once created, a partial is cached in `_partials` and returned on later
        calls. `typehints` may be passed to avoid re-evaluating the hints of `mcls`.
        """
        # first use of this factory -> set up its registries
        if cls not in _partials:
            _partials[cls] = {}
            _forwardrefs[cls] = {}
        known = _partials[cls]

        existing = known.get(mcls)
        if existing:
            return existing  # already have a partial

        # block the spot before creating the partial
        # NOTE(review): the placeholder is `None`, which `known.get` above
        # cannot tell apart from "missing", so a re-entrant call for a model
        # still in progress creates its partial a second time instead of
        # stopping the recursion - confirm whether this is intended.
        known[mcls] = None

        # resolve forward refs of the original model (no extra namespace)
        mcls.update_forward_refs()

        partial, nested = cls._create_partial(mcls, typehints=typehints)
        # register the result by forward-ref name and by original model
        _forwardrefs[cls][cls._partial_forwardref_name(mcls)] = partial
        known[mcls] = partial

        # make sure partials for all nested models exist
        for nested_model in nested:
            cls.get_partial(nested_model)
        # resolve (possibly circular) references with the partials known so far
        partial.update_forward_refs(**_forwardrefs[cls])  # type: ignore
        return partial