Coverage for src/metador_core/schema/partial.py: 99%

182 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-02 09:33 +0000

1"""Partial pydantic models. 

2 

3Partial models replicate another model with all fields made optional. 

4However, they cannot be part of the inheritance chain of the original 

5models, because semantically it does not make sense or leads to 

6a diamond problem. 

7 

8Therefore they are independent models, unrelated to the original 

9models, but convertible through methods.

10 

11Partial models are implemented as a mixin class to be combined 

12with the top level base model that you are using in your model 

13hierarchy, e.g. if you use plain `BaseModel`, you can e.g. define: 

14 

15``` 

16class MyPartial(DeepPartialModel, BaseModel): ... 

17``` 

18 

19And use `MyPartial.get_partial` on your models. 

20 

21If different compatible model instances are merged, 

22the merge will produce an instance of the left type. 

23 

24Some theory - partial schemas form a monoid with: 

25* the empty partial schema as neutral element 

26* merge of the fields as the binary operation 

27* associativity follows from associativity of used merge operations 

28""" 

29 

30from __future__ import annotations 

31 

32from functools import reduce 

33from typing import ( 

34 Any, 

35 ClassVar, 

36 Dict, 

37 ForwardRef, 

38 Iterator, 

39 List, 

40 Optional, 

41 Set, 

42 Tuple, 

43 Type, 

44 Union, 

45 cast, 

46) 

47 

48from pydantic import BaseModel, ValidationError, create_model, validate_model 

49from pydantic.fields import FieldInfo 

50from typing_extensions import Annotated 

51 

52from ..util import is_public_name 

53from ..util import typing as t 

54 

55 

def _is_list_or_set(hint):
    """Return whether `hint` is classified as a list or as a set type hint.

    The classification is delegated to `t.is_list` and `t.is_set`; the set
    check is only consulted when the list check is falsy.
    """
    listish = t.is_list(hint)
    return listish if listish else t.is_set(hint)

58 

59 

def _check_type_mergeable(hint, *, allow_none: bool = False) -> bool:
    """Check whether a type hint has a shape that can be merged.

    Terminology used by this check:

    An atomic type is:
    * not None
    * not a List, Set, Union or Optional
    * not a model

    (i.e. usually a primitive value e.g. int/bool/etc.)

    A singular type is:
    * an atomic type, or
    * a model, or
    * a Union of multiple singular types

    A complex type is:
    * a singular type, or
    * a List/Set of a singular type

    A mergeable type is:
    * a complex type, or
    * an Optional complex type

    Intended merge semantics for values of mergeable types:

    merge(x,y):
        merge(None, None) = None
        merge(None, x) = merge(x, None) = x
        merge(x: model1, y: model2)
            | model1 < model2 or model2 < model1 = recursive_merge(x, y)
            | otherwise = y
        merge(x: singular, y: singular) = y
        merge(x: list, y: list) = x ++ y
        merge(x: set, y: set) = x.union(y)

    Args:
        hint: Type hint to inspect.
        allow_none: Whether an Optional is acceptable at this level. Nested
            levels are always checked with the default (`False`).

    Returns:
        `True` if the hint (and, recursively, its arguments) is acceptable.
    """
    inner = t.get_args(hint)

    # List / Set: the container itself is fine, the element types decide.
    if _is_list_or_set(hint):
        return all(_check_type_mergeable(arg) for arg in inner)

    # Neither container nor Union: a primitive or a (recursively merged) model.
    if not t.is_union(hint):
        return True

    # From here on `hint` is a Union.
    none_forbidden = not allow_none
    if none_forbidden and t.is_optional(hint):
        return False  # contains None where None is not permitted

    # A list/set inside a Union is only acceptable as `Optional[<list/set>]`,
    # i.e. exactly two members, one of them being NoneType.
    contains_collection = any(_is_list_or_set(arg) for arg in inner)
    is_optional_pair = len(inner) == 2 and t.NoneType in inner
    if contains_collection and not is_optional_pair:
        return False

    # This level is fine, so every Union member has to pass as well.
    return all(_check_type_mergeable(arg) for arg in inner)

116 

117 

def is_mergeable_type(hint) -> bool:
    """Tell whether values of the given type hint can be deeply merged.

    Only hints of a certain shape qualify; the actual check is done by
    `_check_type_mergeable`, with an Optional permitted at the outermost level.
    """
    mergeable = _check_type_mergeable(hint, allow_none=True)
    return mergeable

124 

125 

def val_from_partial(val):
    """Convert `val` back from its partial representation, recursively.

    * A `PartialModel` instance is converted via its `from_partial` method.
    * A list or set is rebuilt (as a plain list / set) from its converted items.
    * Anything else is returned unchanged.
    """
    if isinstance(val, PartialModel):
        return val.from_partial()
    if isinstance(val, list):
        return list(map(val_from_partial, val))
    if isinstance(val, set):
        return set(map(val_from_partial, val))
    return val

135 

136 

class PartialModel:
    """Base partial metadata model mixin.

    Instances are merged field by field (see `_update_field`):
    a missing (`None`) value is replaced by the other value, lists are
    concatenated, sets are unioned, compatible nested models are merged
    recursively and any other value is overwritten by the new one
    (only if overwriting is allowed).

    The mixin is meant to be combined with a pydantic model class; it relies
    on `construct`, `copy` and `parse_obj` being provided by that class.
    """

    __partial_src__: ClassVar[Type[BaseModel]]
    """Original model class this partial class is based on."""

    __partial_fac__: ClassVar[Type[PartialFactory]]
    """Factory class that created this partial."""

    def from_partial(self):
        """Return a new non-partial model instance (will run full validation).

        Nested partial values are converted back first (`val_from_partial`),
        then the original model class parses the resulting field dict.

        Raises ValidationError on failure (e.g. if the partial is missing fields).
        """
        fields = {
            k: val_from_partial(v)
            for k, v in self.__partial_fac__._get_field_vals(self)
        }
        return self.__partial_src__.parse_obj(fields)

    @classmethod
    def to_partial(cls, obj, *, ignore_invalid: bool = False):
        """Transform `obj` into a new instance of this partial model.

        If passed object is instance of (a subclass of) this or the original model,
        no validation is performed.

        Returns partial instance with the successfully parsed fields.

        Raises ValidationError if parsing fails
        (unless `ignore_invalid` is set to `True`).
        """
        if isinstance(obj, (cls, cls.__partial_src__)):
            # safe, because subclasses are "stricter"
            return cls.construct(**obj.__dict__)  # type: ignore

        if ignore_invalid:
            # validate data and keep only valid fields
            data, fields, _ = validate_model(cls, obj)  # type: ignore
            return cls.construct(_fields_set=fields, **data)  # type: ignore

        # parse a dict or another pydantic model
        if isinstance(obj, BaseModel):
            obj = obj.dict(exclude_none=True)  # type: ignore
        return cls.parse_obj(obj)  # type: ignore

    @classmethod
    def cast(
        cls,
        obj: Union[BaseModel, PartialModel],
        *,
        ignore_invalid: bool = False,
    ):
        """Cast given object into this partial model if needed.

        If it already is an instance, will do nothing.
        Otherwise, will call `to_partial`.
        """
        if isinstance(obj, cls):
            return obj
        return cls.to_partial(obj, ignore_invalid=ignore_invalid)

    def _update_field(
        self,
        v_old,
        v_new,
        *,
        path: Optional[List[str]] = None,
        allow_overwrite: bool = False,
    ):
        """Return merged result of the two passed arguments.

        None is always overwritten by a non-None value,
        lists are concatenated, sets are unionized,
        partial models are recursively merged,
        otherwise the new value overwrites the old one.

        Args:
            v_old: Current value of the field (or `None` if missing).
            v_new: Incoming value of the field (or `None` if missing).
            path: Field names leading to this value (used in error messages).
            allow_overwrite: Whether an opaque old value may be replaced.

        Raises:
            ValueError: If a value would be overwritten while
                `allow_overwrite` is `False`.
        """
        path = path or []
        # None -> missing value -> just use the other value (shortcut).
        # The check must be on `None` explicitly: falsy values such as
        # 0, False, "" or an empty list are valid values and must be kept.
        if v_old is None or v_new is None:
            return v_old if v_new is None else v_new

        # list -> new one must also be a list -> concatenate
        if isinstance(v_old, list):
            return v_old + v_new

        # set -> new one must also be a set -> set union
        if isinstance(v_old, set):
            # NOTE: we could try being smarter for sets of partial models
            # https://github.com/Materials-Data-Science-and-Informatics/metador-core/issues/20
            return v_old.union(v_new)  # set union

        # another model -> recursive merge of partials, if compatible
        old_is_model = isinstance(v_old, self.__partial_fac__.base_model)
        new_is_model = isinstance(v_new, self.__partial_fac__.base_model)
        if old_is_model and new_is_model:
            v_old_p = self.__partial_fac__.get_partial(type(v_old)).cast(v_old)
            v_new_p = self.__partial_fac__.get_partial(type(v_new)).cast(v_new)
            new_subclass_old = issubclass(type(v_new_p), type(v_old_p))
            old_subclass_new = issubclass(type(v_old_p), type(v_new_p))
            if new_subclass_old or old_subclass_new:
                try:
                    return v_old_p.merge_with(
                        v_new_p, allow_overwrite=allow_overwrite, _path=path
                    )
                except ValidationError:
                    # casting failed -> proceed to next merge variant
                    # TODO: maybe raise unless "ignore invalid"?
                    pass

        # if we're here, treat it as an opaque value
        if not allow_overwrite:
            msg_title = (
                f"Can't overwrite (allow_overwrite=False) at {' -> '.join(path)}:"
            )
            msg = f"{msg_title}\n\t{repr(v_old)}\n\twith\n\t{repr(v_new)}"
            raise ValueError(msg)
        return v_new

    def merge_with(
        self,
        obj,
        *,
        ignore_invalid: bool = False,
        allow_overwrite: bool = False,
        _path: Optional[List[str]] = None,
    ):
        """Return a new partial model with updated fields (without validation).

        Raises `ValidationError` if passed `obj` is not suitable,
        unless `ignore_invalid` is set to `True`.

        Raises `ValueError` if `allow_overwrite=False` and a value would be overwritten.
        """
        _path = _path or []
        obj = self.cast(obj, ignore_invalid=ignore_invalid)  # raises on failure

        # work on a copy, the instance itself stays untouched
        ret = self.copy()  # type: ignore
        for f_name, v_new in self.__partial_fac__._get_field_vals(obj):
            v_old = ret.__dict__.get(f_name)
            v_merged = self._update_field(
                v_old, v_new, path=_path + [f_name], allow_overwrite=allow_overwrite
            )
            # written into __dict__ directly, i.e. bypassing attribute assignment
            ret.__dict__[f_name] = v_merged
        return ret

    @classmethod
    def merge(cls, *objs: PartialModel, **kwargs) -> PartialModel:
        """Merge all passed partial models in given order using `merge_with`.

        Recognized keyword arguments are `ignore_invalid` and `allow_overwrite`
        (both default to `False`); other keyword arguments are ignored.

        Without any objects, an empty instance of this class is returned.
        """
        ignore_invalid = kwargs.get("ignore_invalid", False)
        allow_overwrite = kwargs.get("allow_overwrite", False)

        if not objs:
            return cls()

        def merge_two(x, y):
            return cls.cast(x).merge_with(
                y, ignore_invalid=ignore_invalid, allow_overwrite=allow_overwrite
            )

        return cls.cast(reduce(merge_two, objs))

305 

306 

# ----
# Partial factory
#
# Module-level registries, kept separately for each factory class:
#   _partials:    model class -> partial class created for it
#   _forwardrefs: forward reference name -> partial class it resolves to
_partials: Dict[Type[PartialFactory], Dict[Type[BaseModel], Type[PartialModel]]] = {}
_forwardrefs: Dict[Type[PartialFactory], Dict[str, Type[PartialModel]]] = {}

313 

314 

class PartialFactory:
    """Factory that creates partial model classes and keeps track of them.

    Subclasses can set `base_model` (root of the supported model hierarchy)
    and `partial_mixin` (mixin providing the partial behaviour).
    """

    base_model: Type[BaseModel] = BaseModel
    partial_mixin: Type[PartialModel] = PartialModel

    # TODO: how to configure the whole partial family
    # with default parameters for merge() efficiently?
    # Idea: class-level defaults used when not explicitly passed to merge(), e.g.
    #   allow_overwrite: bool = False
    #   ignore_invalid: bool = False

    @classmethod
    def _is_base_subclass(cls, obj: Any) -> bool:
        """Return whether `obj` is a class derived from the configured base model."""
        # non-classes (probably just type hints) never qualify
        return isinstance(obj, type) and issubclass(obj, cls.base_model)

    @classmethod
    def _partial_name(cls, mcls: Type[BaseModel]) -> str:
        """Return class name for partial of model `mcls`."""
        model_name = mcls.__qualname__
        mixin_name = cls.partial_mixin.__name__
        return f"{model_name}.{mixin_name}"

    @classmethod
    def _partial_forwardref_name(cls, mcls: Type[BaseModel]) -> str:
        """Return ForwardRef string for partial of model `mcls`."""
        dotted = f"__{mcls.__module__}_{cls._partial_name(mcls)}"
        return dotted.replace(".", "_")

    @classmethod
    def _get_field_vals(cls, obj: BaseModel) -> Iterator[Tuple[str, Any]]:
        """Return field values, excluding None and private fields.

        This is different from `BaseModel.dict` as it ignores the defined alias
        and is used here only for "internal representation".
        """
        entries = obj.__dict__.items()
        return (
            (name, value)
            for name, value in entries
            if is_public_name(name) and value is not None
        )

    @classmethod
    def _nested_models(cls, field_types: Dict[str, t.TypeHint]) -> Set[Type[BaseModel]]:
        """Collect all compatible nested model classes (for which we need partials)."""
        found: Set[Type[BaseModel]] = set()
        for field_hint in field_types.values():
            for part in t.traverse_typehint(field_hint):
                if cls._is_base_subclass(part):
                    found.add(cast(Type[BaseModel], part))
        return found

    @classmethod
    def _model_to_partial_fref(cls, orig_type: t.TypeHint) -> t.TypeHint:
        """Substitute type hint with forward reference to partial models.

        Will return unchanged type if passed argument is not suitable.
        """
        if cls._is_base_subclass(orig_type):
            return ForwardRef(cls._partial_forwardref_name(orig_type))
        return orig_type

    @classmethod
    def _model_to_partial(
        cls, mcls: Type[BaseModel]
    ) -> Type[Union[BaseModel, PartialModel]]:
        """Substitute a model class with partial model.

        Will return unchanged type if argument not suitable.
        """
        if cls._is_base_subclass(mcls):
            return cls.get_partial(mcls)
        return mcls

    @classmethod
    def _partial_type(cls, orig_type: t.TypeHint) -> t.TypeHint:
        """Convert a field type hint into a type hint for the partial.

        This will make the field optional and also replace all nested models
        derived from the configured base_model with the respective partial model.
        """
        substituted = t.map_typehint(orig_type, cls._model_to_partial_fref)
        return Optional[substituted]

    @classmethod
    def _partial_field(cls, orig_type: t.TypeHint) -> Tuple[Type, Optional[FieldInfo]]:
        """Return a field declaration tuple for dynamic model creation."""
        if t.get_origin(orig_type) is not Annotated:
            # plain hint: no pydantic Field information attached
            return (cls._partial_type(orig_type), None)

        # Annotated[...]: unwrap the actual hint and pick the first FieldInfo
        args = t.get_args(orig_type)
        inner = args[0]
        field_info = next((a for a in args[1:] if isinstance(a, FieldInfo)), None)
        return (cls._partial_type(inner), field_info)

    @classmethod
    def _create_base_partial(cls):
        """Create the partial class corresponding to the base model itself."""

        class PartialBaseModel(cls.partial_mixin, cls.base_model):
            class Config:
                frozen = True  # make sure it's hashable

        return PartialBaseModel

    @classmethod
    def _create_partial(cls, mcls: Type[BaseModel], *, typehints=None):
        """Create a new partial model class based on `mcls`.

        Returns a pair of the new partial class and the nested model classes
        that still need partials of their own.

        Raises TypeError if `mcls` is not derived from the configured base model.
        """
        if not cls._is_base_subclass(mcls):
            raise TypeError(f"{mcls} is not subclass of {cls.base_model.__name__}!")
        if mcls is cls.base_model:
            return (cls._create_base_partial(), [])

        # field type annotations (passed ones are preferred / for performance)
        all_hints = typehints if typehints else t.get_type_hints(mcls)
        field_types = {
            name: hint for name, hint in all_hints.items() if name in mcls.__fields__
        }
        # dependencies that must be substituted
        nested = cls._nested_models(field_types)
        # field declarations of the partial
        partial_fields = {
            name: cls._partial_field(hint) for name, hint in field_types.items()
        }
        # base classes are replaced by their corresponding partial bases
        partial_bases = tuple(cls._model_to_partial(base) for base in mcls.__bases__)

        partial_cls: Type[PartialModel] = create_model(
            cls._partial_name(mcls),
            __base__=partial_bases,
            __module__=mcls.__module__,
            __validators__=mcls.__validators__,  # type: ignore
            **partial_fields,
        )
        partial_cls.__partial_src__ = mcls  # connect to original model
        partial_cls.__partial_fac__ = cls  # connect to this class
        return partial_cls, nested

    @classmethod
    def get_partial(cls, mcls: Type[BaseModel], *, typehints=None):
        """Return a partial schema with all fields of the given schema optional.

        Original default values are not respected and are set to `None`.

        The use of the returned class is for validating partial data before
        zipping together partial results into a completed one.

        This is a much more fancy version of e.g.
        https://github.com/pydantic/pydantic/issues/1799

        because it recursively substitutes with partial models.
        This allows us to implement smart deep merge for partials.
        """
        if cls not in _partials:  # first use of this partial factory
            _partials[cls] = {}
            _forwardrefs[cls] = {}

        known = _partials[cls]
        existing = known.get(mcls)
        if existing:
            return existing  # already have a partial
        known[mcls] = None  # block the spot (to break recursion)

        # make sure the forward references of the original model are resolved
        mcls.update_forward_refs()

        partial, nested = cls._create_partial(mcls, typehints=typehints)
        ref_name = cls._partial_forwardref_name(mcls)
        # store the result before descending into nested models
        _forwardrefs[cls][ref_name] = partial
        _partials[cls][mcls] = partial
        # create partials for nested models
        for nested_model in nested:
            cls.get_partial(nested_model)
        # resolve possible circular references
        partial.update_forward_refs(**_forwardrefs[cls])  # type: ignore
        return partial