Coverage for src/metador_core/util/typing.py: 99%

106 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-02 09:33 +0000

1import enum 

2from collections import ChainMap 

3from typing import ( 

4 Any, 

5 Callable, 

6 Dict, 

7 Iterable, 

8 List, 

9 Literal, 

10 Mapping, 

11 Set, 

12 Tuple, 

13 Type, 

14 Union, 

15) 

16 

17import typing_extensions as te 

18 

19# import typing_utils # for issubtype 

20from runtype import validation as rv # for is_subtype 

21from typing_extensions import Annotated, ClassVar, TypedDict 

22 

get_args = te.get_args  # re-export get_args

# Bind the name (not just annotate it): a bare ``TypeHint: Any`` only records
# an entry in the module ``__annotations__`` and never creates the attribute,
# so importing ``TypeHint`` or using it in an evaluated annotation would fail.
TypeHint = Any
"""For documentation purposes - to mark type hint arguments.

This is an alias of ``Any``; it carries no runtime checking.
"""

27 

28 

def get_type_hints(cls) -> Mapping[str, Any]:
    """Return the evaluated type hints of *cls*, keeping ``Annotated`` extras.

    Thin wrapper around ``typing_extensions.get_type_hints`` called with
    ``include_extras=True``, so ``Annotated[...]`` metadata is not stripped.
    """
    hints = te.get_type_hints(cls, include_extras=True)
    return hints

32 

33 

def get_annotations(cls, *, all: bool = False) -> Mapping[str, Any]:
    """Return the raw (unevaluated) annotations of *cls*.

    By default only annotations declared directly on *cls* are returned
    (inherited ones are not included). With ``all=True`` a ``ChainMap`` over
    every class in ``cls.__mro__`` is returned; lookups search the MRO in order,
    so entries of classes closer to *cls* take precedence.
    """

    def own_annotations(klass):
        # Read the class' own namespace so inherited annotations are not picked up.
        return klass.__dict__.get("__annotations__", {})

    if all:
        return ChainMap(*map(own_annotations, cls.__mro__))
    return own_annotations(cls)

39 

40 

to38hint: Dict[Type, Any] = dict(
    zip(
        (list, set, dict, type),  # builtin origins ...
        (List, Set, Dict, Type),  # ... and their typing-module aliases
    )
)
"""Map builtin generic origins to typing aliases, for consistent behavior between python versions."""

48 

49 

def get_origin(hint):
    """Return the origin of *hint*, normalized to 3.8-compatible type hints.

    Behaves like ``typing_extensions.get_origin``, except that builtin origins
    listed in ``to38hint`` (e.g. ``list``) are replaced by their ``typing``
    alias (e.g. ``List``). Needed for pydantic dynamic model generation in 3.8.
    """
    origin = te.get_origin(hint)
    if origin in to38hint:
        return to38hint[origin]
    return origin

55 

56 

57# ---- 

58 

59 

def is_list(hint):
    """Return True if *hint* is a list type hint such as ``List[int]`` or ``list[int]``."""
    origin = get_origin(hint)
    return origin is List

62 

63 

def is_set(hint):
    """Return True if *hint* is a set type hint such as ``Set[int]`` or ``set[int]``."""
    origin = get_origin(hint)
    return origin is Set

66 

67 

def is_union(hint):
    """Return True if *hint* is a union type hint.

    Recognizes ``typing.Union[...]`` (including ``Optional[...]``) and, on
    Python >= 3.10, the PEP 604 spelling ``X | Y``, whose origin is
    ``types.UnionType`` rather than ``typing.Union``.
    """
    import types  # local import keeps this fix self-contained

    origin = get_origin(hint)
    # ``types.UnionType`` does not exist before Python 3.10; falling back to
    # ``Union`` makes the second check a harmless repeat of the first there.
    return origin is Union or origin is getattr(types, "UnionType", Union)

70 

71 

def is_classvar(hint):
    """Return True if *hint* is a parameterized ``ClassVar[...]`` hint."""
    origin = get_origin(hint)
    return origin is ClassVar

74 

75 

def is_annotated(hint):
    """Return True if *hint* is an ``Annotated[...]`` hint."""
    origin = get_origin(hint)
    return origin is Annotated

78 

79 

def is_literal(hint):
    """Return True if *hint* is a ``Literal[...]`` hint (``typing.Literal``)."""
    origin = get_origin(hint)
    return origin is Literal

82 

83 

# The class of ``None`` (``types.NoneType`` only exists on Python 3.10+).
NoneType = type(None)


def is_nonetype(hint):
    """Return True if *hint* is exactly ``NoneType`` (the type, not the value ``None``)."""
    return NoneType is hint

89 

90 

def is_optional(hint):
    """Return True if *hint* is a union that includes ``NoneType``.

    ``Optional[X]`` is only sugar for ``Union[X, None]``, so this is checked by
    looking for ``NoneType`` among the union's arguments.
    """
    if not is_union(hint):
        return False
    return any(is_nonetype(arg) for arg in get_args(hint))

94 

95 

96# ---- 

97 

98 

def make_typehint(h, *args):
    """Rebuild the generic type hint *h* with new arguments *args*.

    * optional hints (unions containing ``NoneType``) get ``NoneType``
      appended to *args*;
    * ``Annotated`` hints take only the first argument (the new wrapped type);
    * every other hint receives *args* unchanged.

    NOTE: relies on the private ``copy_with`` method of typing's generic
    aliases (presumably it keeps ``Annotated`` metadata - confirm per version).
    """
    if is_optional(h):
        new_args = (*args, type(None))
    elif is_annotated(h):
        new_args = (args[0],)
    else:
        new_args = args
    return h.copy_with(tuple(new_args))

108 

109 

# Template hints: they are never used as types in their own right. Their
# ``copy_with`` (via make_typehint) builds new hints of the same kind.
UNION_PROXY = Union[int, str]  # template for Union[...] (used by unoptional)

LIT = Literal[1]  # template for Literal[...] (used by make_literal)
TUP = Tuple[Any]  # template for Tuple[...] (used by make_literal)

114 

115 

def make_literal(val):
    """Given a JSON object, return a type that parses exactly that object.

    Values map to hints as follows (checked in this order):

    * ``None`` -> ``NoneType``
    * ``bool`` / ``int`` / ``str`` -> ``Literal[val]``
    * other ``enum.Enum`` members -> ``Literal[val.value]``
      (int- or str-based enum members already match the previous case)
    * ``tuple`` / ``list`` -> ``Tuple[...]`` of the element hints
    * ``dict`` -> anonymous ``TypedDict`` with one hint per key

    Note that dicts can have extra fields that will be ignored
    and that coercion between bool and int might happen.

    Sets and floats are not supported.

    Raises:
        ValueError: for any value of an unsupported kind.
    """
    if val is None:
        return type(None)
    if isinstance(val, (bool, int, str)):
        return make_typehint(LIT, val)
    if issubclass(val.__class__, enum.Enum):
        return make_typehint(LIT, val.value)
    if isinstance(val, (tuple, list)):
        return make_typehint(TUP, *(make_literal(item) for item in val))
    if isinstance(val, dict):
        fields = {key: make_literal(item) for key, item in val.items()}
        # NOTE: the TypedDict must be from typing_extensions for 3.8!
        return TypedDict("AnonConstDict", fields)  # type: ignore
    raise ValueError(f"Unsupported value: {val}")

138 

139 

def make_tree_traversal(succ_func: Callable[[Any], Iterable]):
    """Return a generator function that walks a tree-shaped object depth-first.

    The returned function takes the root node and a keyword-only boolean
    ``post_order``. By default each parent is emitted before its children;
    with ``post_order=True`` it is emitted after them.

    Args:
        succ_func: Function to be called on each node returning Iterable of children
    """

    def traverse(obj, *, post_order: bool = False):
        if post_order:
            for child in succ_func(obj):
                yield from traverse(child, post_order=post_order)
            yield obj
        else:
            yield obj
            for child in succ_func(obj):
                yield from traverse(child, post_order=post_order)

    return traverse

159 

160 

traverse_typehint = make_tree_traversal(get_args)
"""Traverse a type annotation depth-first, yielding it and all nested arguments.

The children of each node are taken from ``get_args``, so non-type arguments
(e.g. ``Literal`` values or ``Annotated`` metadata) are yielded as nodes too.
Pre-order by default.

Args:
    obj (object): type hint object to be traversed
    post_order (bool, keyword-only): if True, yield each node after its children
"""

167 

168 

def make_tree_mapper(node_constructor, succ_func):
    """Return a function that maps a tree-shaped object bottom-up.

    The returned ``map_func(obj, leaf_map_func)`` applies ``leaf_map_func`` to
    every node whose ``succ_func`` result is empty (falsy). For every other node
    it rebuilds the node via ``node_constructor(obj, *mapped_children)``.
    """

    def map_func(obj, leaf_map_func):
        children = succ_func(obj)
        if not children:
            return leaf_map_func(obj)
        mapped = (map_func(child, leaf_map_func) for child in children)
        return node_constructor(obj, *mapped)

    return map_func

178 

179 

# Map a type hint bottom-up: leaves go through the given leaf function, and
# inner nodes are rebuilt from their mapped arguments.
map_typehint = make_tree_mapper(node_constructor=make_typehint, succ_func=get_args)

181 

182 

def unoptional(th):
    """Return type hint that is not optional (if it was optional).

    * ``Annotated[X, ...]``: ``X`` is made non-optional, the annotation is kept.
    * A union: ``NoneType`` is dropped from its members. A single remaining
      member is returned as-is; otherwise a ``Union`` of the remaining members
      is rebuilt (unions without ``NoneType`` come back unchanged in content).
    * Anything else is returned unchanged.
    """
    if is_annotated(th):
        inner = get_args(th)[0]
        return make_typehint(th, unoptional(inner))

    # all optionals are unions, so anything else needs no work
    if not is_union(th):
        return th

    kept = tuple(arg for arg in get_args(th) if not is_nonetype(arg))
    if len(kept) != 1:
        return make_typehint(UNION_PROXY, *kept)
    return kept[0]

200 

201 

202# ---- 

203 

204 

def is_subtype(sub, base):
    """Check whether type hint *sub* is a subtype of *base*.

    Delegates to runtype's ``is_subtype``, with two superficial workarounds:

    * if exactly one side is ``Annotated[...]``, or exactly one side is a
      ``Literal[...]``, the result is ``False``;
    * if both sides are ``Annotated[...]``, only the wrapped types are compared
      and the metadata is ignored (pydantic's ``FieldInfo`` is not comparable).

    NOTE: this is only superficial; it may not work with nested ``Annotated``
    types or other complicated hints.
    """
    sub_annotated = is_annotated(sub)
    if sub_annotated != is_annotated(base):
        return False  # differ in Annotated-wrapping status
    if is_literal(sub) != is_literal(base):
        return False  # differ in Literal status

    if not sub_annotated:
        return rv.is_subtype(sub, base)

    # both are Annotated: compare the wrapped types and drop the metadata
    return is_subtype(get_args(sub)[0], get_args(base)[0])

224 

225 

def is_subtype_of(t: Any) -> Callable[[Any], bool]:
    """Return a predicate telling whether its argument is a subtype of *t* (via ``is_subtype``)."""

    def predicate(obj):
        return is_subtype(obj, t)

    return predicate

229 

230 

def is_instance_of(t: Any) -> Callable[[Any], bool]:
    """Return a predicate telling whether its argument is an instance of *t*.

    *t* is passed straight to ``isinstance``, so a tuple of types is accepted.
    """

    def predicate(obj):
        return isinstance(obj, t)

    return predicate

234 

235 

def is_subclass_of(t: Any) -> Callable[[Any], bool]:
    """Return a predicate telling whether its argument is a class that subclasses *t*.

    Non-class arguments (e.g. instances) yield ``False`` instead of raising.
    """

    def predicate(obj):
        if not isinstance(obj, type):
            return False
        return issubclass(obj, t)

    return predicate


# Predicate: is the argument an ``enum.Enum`` subclass?
is_enum = is_subclass_of(enum.Enum)