Coverage for src/metador_core/schema/inspect.py: 100%

88 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-02 09:33 +0000

1from collections import ChainMap 

2from dataclasses import dataclass 

3from io import UnsupportedOperation 

4from typing import Any, Callable, List, Mapping, Optional, Set, Type 

5 

6import wrapt 

7from pydantic import BaseModel 

8from simple_parsing.docstring import get_attribute_docstring 

9 

10from ..util.models import field_origins 

11from ..util.typing import get_annotations 

12 

13 

class LiftedRODict(type):
    """Metaclass for read-only classes that expose the keys of a dict as attributes.

    Mostly for aesthetic reasons: meant for cases where the dict really is a
    fixed lookup table, so that ``Cls.key`` can be written instead of ``d["key"]``.

    There are deliberately no ``keys``/``values``/``items`` methods, because
    those names could themselves be keys of the dict. If dynamic access is
    needed, iterate over the class to obtain its keys and use ``Cls[key]``.
    """

    # NOTE: only dunder methods belong here -- any other method name
    # would collide with (and shadow) a key of the lifted dict.

    # the underlying mapping
    _dict: Mapping[str, Any]
    # optional list of keys, fixing the iteration order
    _keys: Optional[List[str]] = None
    # optional callable, called with the class, producing its repr
    _repr: Optional[Callable] = None

    def __repr__(self):
        # precedence: custom repr callable > explicit key order > dict key order
        custom_repr = self._repr
        if custom_repr:
            return custom_repr(self)
        shown_keys = self._keys if self._keys else list(self._dict.keys())
        return repr(shown_keys)

    def __dir__(self):
        # advertise exactly the dict keys (helps tab completion)
        return [*self._dict.keys()]

    def __bool__(self):
        return bool(self._dict)

    def __contains__(self, key):
        return key in self._dict

    def __iter__(self):
        # an explicit (non-empty) key order wins over the dict's own order
        return iter(self._keys or self._dict)

    def __getitem__(self, key):
        return self._dict[key]

    def __getattr__(self, key):
        # only reached when regular attribute lookup fails: consult the dict
        try:
            found = self[key]
        except KeyError as err:
            raise AttributeError(str(err))
        return found

    def __setattr__(self, key, value):
        # lifted classes are read-only
        raise UnsupportedOperation

75 

def lift_dict(name, dct, *, keys=None, repr=None):
    """Return a new class (with metaclass LiftedRODict) exposing ``dct``.

    Args:
        name: Name of the created class.
        dct: Mapping-like object to expose; must support ``__getitem__``.
        keys: Optional iterable fixing the iteration order. It must contain
            exactly the keys of ``dct``. It is materialized into a list, so
            one-shot iterables (e.g. generators) are safe to pass.
        repr: Optional callable that receives the created class and returns
            its repr string.

    Returns:
        The created class, i.e. an instance of LiftedRODict.
    """
    assert hasattr(dct, "__getitem__"), "dct must support item access"
    kwargs = {"_dict": dct}
    if keys is not None:
        # Materialize first: checking a one-shot iterator would exhaust it and
        # leave an empty (but truthy) iterator behind as the key order.
        keys = list(keys)
        assert set(keys) == set(iter(dct)), "keys must match the keys of dct"
        kwargs["_keys"] = keys
    if repr is not None:
        kwargs["_repr"] = repr
    return LiftedRODict(name, (), kwargs)

86 

87 

class WrappedLiftedDict(wrapt.ObjectProxy):
    """Proxy around a LiftedRODict class that post-processes looked-up values.

    Values obtained by item access (``proxy[key]``) or attribute access
    (``proxy.key``) are passed through ``wrapperfun`` before being returned.
    The repr is the repr of the wrapped class.
    """

    def __init__(self, obj, wrapperfun):
        assert isinstance(obj, LiftedRODict)
        super().__init__(obj)
        # wrapt keeps attributes prefixed with `_self_` on the proxy itself
        # rather than forwarding them to the wrapped object
        self._self_wrapperfun = wrapperfun

    def __getitem__(self, key):
        raw_value = self.__wrapped__[key]
        return self._self_wrapperfun(raw_value)

    def __getattr__(self, key):
        # reuse LiftedRODict's KeyError -> AttributeError translation; it
        # dispatches to self[key], i.e. to the wrapping __getitem__ above
        return LiftedRODict.__getattr__(self, key)

    def __repr__(self):
        target = self.__wrapped__
        return repr(target)

104 

105 

@dataclass
class FieldInspector:
    """Inspector for one model field: origin class, name, type hint and a lazily computed description."""

    origin: Type
    name: str
    type: str

    description: str  # annotated so that the generated dataclass repr shows it

    def _get_description(self):
        """Look up a description for the field.

        Uses the docstring written below the field in the source of ``origin``.
        If there is none, falls back to the docstring of the field's type hint
        (looked up in ``origin._typehints``), provided that hint is a class.
        """
        field_doc = get_attribute_docstring(self.origin, self.name).docstring_below
        if field_doc:
            return field_doc

        hints = getattr(self.origin, "_typehints", None)
        if not hints:
            return field_doc
        hint = hints[self.name]
        # only class-like hints carry a usable docstring
        return hint.__doc__ if isinstance(hint, type) else field_doc

    @property  # type: ignore
    def description(self):
        # computed on first access and memoized: finding the docstring
        # parses source code, which can be expensive
        try:
            return self._description
        except AttributeError:
            self._description = self._get_description()
        return self._description

    def __init__(self, model: Type[BaseModel], name: str, hint: str):
        # first class reported by field_origins for this field
        # (presumably the one declaring it -- confirm in util.models)
        declaring_cls = next(field_origins(model, name))
        self.origin = declaring_cls
        self.name = name
        self.type = hint

139 

140 

141def make_field_inspector( 

142 model: Type[BaseModel], 

143 prop_name: str, 

144 *, 

145 bound: Optional[Type[BaseModel]] = BaseModel, 

146 key_filter: Optional[Callable[[str], bool]], 

147 i_cls: Optional[Type[FieldInspector]] = FieldInspector, 

148) -> Type[LiftedRODict]: 

149 """Create a field inspector class for the given model. 

150 

151 This can be used for introspection about fields and also 

152 enables users to access subschemas without extra imports, 

153 improving decoupling of plugins and packages. 

154 

155 To be used in a metaclass for a custom top level model. 

156 

157 Args: 

158 model: Class for which to return the inspector 

159 prop_name: Name of the metaclass property that wraps this function 

160 i_cls: Optional subclass of FieldInspector to customize it 

161 bound: Top level class using the custom metaclass that uses this function 

162 key_filter: Predicate used to filter the annotations that are to be inspectable 

163 Returns: 

164 A fresh inspector class for the fields. 

165 """ 

166 # get hints corresponding to fields that are not inherited 

167 field_hints = { 

168 k: v 

169 for k, v in get_annotations(model).items() 

170 if not key_filter or key_filter(k) 

171 } 

172 

173 # inspectors for fields declared in the given model (for inherited, will reuse/create parent inspectors) 

174 new_inspectors = {k: i_cls(model, k, v) for k, v in field_hints.items()} 

175 # manually compute desired traversal order (from newest overwritten to oldest inherited fields) 

176 # as the default chain map order semantically is not suitable. 

177 inspectors = [new_inspectors] + [ 

178 getattr(b, prop_name)._dict for b in model.__bases__ if issubclass(b, bound) 

179 ] 

180 covered_keys: Set[str] = set() 

181 ordered_keys: List[str] = [] 

182 for d in inspectors: 

183 rem_keys = set(iter(d)) - covered_keys 

184 covered_keys.update(rem_keys) 

185 ordered_keys += [k for k in d if k in rem_keys] 

186 

187 # construct and return the class 

188 return lift_dict( 

189 f"{model.__name__}.{prop_name}", 

190 ChainMap(*inspectors), 

191 keys=ordered_keys, 

192 repr=lambda self: "\n".join(map(str, (self[k] for k in self))), 

193 )