Coverage for src/metador_core/util/models.py: 100%

19 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-02 09:33 +0000

1from typing import Dict, Iterator, Set, Type 

2 

3from pydantic import BaseModel 

4from pydantic.fields import ModelField 

5 

6from .typing import get_annotations, get_type_hints, is_subclass_of, traverse_typehint 

7 

8 

def field_origins(m: Type[BaseModel], name: str) -> Iterator[Type[BaseModel]]:
    """Lazily yield the model classes in the MRO of `m` that annotate `name`.

    Classes are produced in MRO order, so `m` itself comes first if it carries
    an annotation for `name`. Only pydantic `BaseModel` subclasses are
    considered; a class qualifies when `name` appears in the annotations
    returned for it by `get_annotations`.
    """

    def _annotates(cls) -> bool:
        # Short-circuit: non-model classes (mixins, `object`, ...) are never
        # passed to `get_annotations`.
        return issubclass(cls, BaseModel) and name in get_annotations(cls)

    return filter(_annotates, m.__mro__)

14 

15 

def updated_fields(m: Type[BaseModel]) -> Set[str]:
    """Return subset of fields that are added or overridden by a new type hint.

    A field counts as updated when the first class in the MRO of `m` whose
    annotations contain the field name (see `field_origins`) is `m` itself.

    Fields of `m` for which no class in the MRO provides an annotation are
    not reported. Previously such a field made `next()` raise a bare
    `StopIteration` out of this function. This can presumably happen for
    fields pydantic creates without a type annotation (TODO confirm).
    """
    return {
        name
        for name in m.__fields__
        # default None: a field without any annotating class is simply "not updated"
        if next(field_origins(m, name), None) is m
    }

19 

20 

def new_fields(m: Type[BaseModel]) -> Set[str]:
    """Return the fields of `m` that none of its direct model bases already have.

    The fields of *every* direct base that is a pydantic model are subtracted.
    Looking only at `m.__base__` (the single "solid base" CPython picks) would
    report fields inherited from any further model base as new when `m` uses
    multiple inheritance. If `m` has no model among its direct bases
    (e.g. `BaseModel` itself), all of its fields are considered new.
    """
    inherited: Set[str] = set()
    for base in m.__bases__:
        # Non-model bases (mixins, Generic, ...) contribute no pydantic fields.
        if issubclass(base, BaseModel):
            inherited.update(base.__fields__)
    return set(m.__fields__) - inherited

24 

25 

def field_atomic_types(mf: ModelField, *, bound=object) -> Iterator[Type]:
    """Lazily yield the atomic types nested in the type of a model field.

    Args:
        mf: Pydantic model field; its `type_` is walked with `traverse_typehint`.
        bound: Only types accepted by the predicate `is_subclass_of(bound)`
            are yielded.
    """
    accepts = is_subclass_of(bound)
    return (hint for hint in traverse_typehint(mf.type_) if accepts(hint))

29 

30 

def atomic_types(m: Type[BaseModel], *, bound=object) -> Dict[str, Set[Type]]:
    """Return dict from field name to the atomic types referenced in its definition.

    Args:
        m: Pydantic model class whose `__fields__` are inspected.
        bound: Only types accepted by `is_subclass_of(bound)` are kept.
            The default `object` presumably keeps every atomic type; pass a
            narrower class (e.g. `BaseModel`) to keep only its subclasses,
            such as the model classes referenced by the fields.

    Returns:
        A mapping from each field name of `m` to the set of atomic types
        found in that field's type (see `field_atomic_types`).
    """
    return {
        field_name: set(field_atomic_types(field, bound=bound))
        for field_name, field in m.__fields__.items()
    }

40 

41 

def field_parent_type(m: Type[BaseModel], name: str) -> object:
    """Return type of field assigned in the next parent that provides a type hint.

    The MRO of `m` is walked via `field_origins`, skipping `m` itself. The
    first remaining model class annotating `name` is selected, and the type
    hint for `name` resolved by `get_type_hints` on that class is returned.
    The result is an arbitrary type hint (e.g. a generic alias), not
    necessarily a class.

    Raises:
        ValueError: If no base class of `m` annotates a field called `name`.
    """
    parent = next((b for b in field_origins(m, name) if b is not m), None)
    # Explicit identity check: truthiness of a class object is governed by its
    # metaclass, so "not found" must be tested against the sentinel itself.
    if parent is None:
        raise ValueError(f"No base class of {m} defines a field called '{name}'!")
    return get_type_hints(parent).get(name)