Coverage for src/metador_core/util/typing.py: 99%
106 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-02 09:33 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-02 09:33 +0000
1import enum
2from collections import ChainMap
3from typing import (
4 Any,
5 Callable,
6 Dict,
7 Iterable,
8 List,
9 Literal,
10 Mapping,
11 Set,
12 Tuple,
13 Type,
14 Union,
15)
17import typing_extensions as te
19# import typing_utils # for issubtype
20from runtype import validation as rv # for is_subtype
21from typing_extensions import Annotated, ClassVar, TypedDict
23get_args = te.get_args # re-export get_args
# NOTE: a bare annotation (``TypeHint: Any``) only records an annotation and
# never binds the name, so importing or referencing ``TypeHint`` would fail at
# runtime. Binding it to ``Any`` makes it usable as a real (documentary) alias.
TypeHint = Any
"""For documentation purposes - to mark type hint arguments."""
def get_type_hints(cls) -> Mapping[str, Any]:
    """Return the type hints of the given class.

    Delegates to ``typing_extensions.get_type_hints`` with
    ``include_extras=True``.

    Args:
        cls: Class whose type hints are requested.
    """
    hints = te.get_type_hints(cls, include_extras=True)
    return hints
def get_annotations(cls, *, all: bool = False) -> Mapping[str, Any]:
    """Return the raw (unparsed) annotations of the given class.

    Args:
        cls: Class to inspect.
        all: If False (default), only annotations declared directly on ``cls``
            are returned. If True, the annotations of every class in the MRO
            are combined in a ``ChainMap`` (earlier MRO entries take precedence).
    """

    def declared_on(klass):
        # Look into the class __dict__ directly so inherited annotations are not picked up.
        return klass.__dict__.get("__annotations__", {})

    if all:
        return ChainMap(*map(declared_on, cls.__mro__))
    return declared_on(cls)
# Maps builtin container classes to their ``typing`` counterparts;
# consulted by get_origin to normalize the origin it reports.
to38hint: Dict[Type, Any] = dict(
    zip(
        (list, set, dict, type),
        (List, Set, Dict, Type),
    )
)
"""Type hint map for consistent behavior between python versions."""
def get_origin(hint):
    """Return the origin of a type hint, normalized through ``to38hint``.

    Origins that are keys of ``to38hint`` (``list``, ``set``, ``dict``,
    ``type``) are replaced by the mapped ``typing`` object; any other origin
    (including ``None``) is returned unchanged.
    """
    # Fixed get_origin to return 3.8-compatible type hints.
    # Need this for pydantic dynamic model generation in 3.8.
    origin = te.get_origin(hint)
    if origin in to38hint:
        return to38hint[origin]
    return origin
57# ----
def is_list(hint):
    """Return whether the normalized origin of ``hint`` is ``typing.List``."""
    origin = get_origin(hint)
    return origin is List
def is_set(hint):
    """Return whether the normalized origin of ``hint`` is ``typing.Set``."""
    origin = get_origin(hint)
    return origin is Set
def is_union(hint):
    """Return whether ``hint`` is a union type hint.

    Recognizes ``typing.Union[...]`` (which also covers ``Optional[...]``) as
    well as unions written with the ``X | Y`` syntax, whose origin is
    ``types.UnionType`` instead of ``typing.Union`` on Python versions that
    distinguish the two.

    Args:
        hint: Type hint to inspect.
    """
    import types  # local import: only needed for the ``X | Y`` check

    origin = get_origin(hint)
    # ``types.UnionType`` is missing on older Python versions; fall back to
    # ``Union`` there so the second comparison is simply redundant.
    return origin is Union or origin is getattr(types, "UnionType", Union)
def is_classvar(hint):
    """Return whether the origin of ``hint`` is ``ClassVar``."""
    origin = get_origin(hint)
    return origin is ClassVar
def is_annotated(hint):
    """Return whether the origin of ``hint`` is ``Annotated``."""
    origin = get_origin(hint)
    return origin is Annotated
def is_literal(hint):
    """Return whether the origin of ``hint`` is ``Literal``."""
    origin = get_origin(hint)
    return origin is Literal
# Class of the ``None`` singleton.
NoneType = type(None)


def is_nonetype(hint):
    """Return whether ``hint`` is exactly the ``NoneType`` class (identity check)."""
    return NoneType is hint
def is_optional(hint):
    """Return whether ``hint`` is a union that has ``NoneType`` among its arguments."""
    # internally, Optional is just sugar for a Union including NoneType.
    if not is_union(hint):
        return False
    return any(is_nonetype(arg) for arg in get_args(hint))
96# ----
def make_typehint(h, *args):
    """Create a new type hint from ``h`` with its arguments replaced.

    The new hint is produced by ``h.copy_with``. How the arguments are
    prepared depends on the kind of ``h``:

    * optional hint: ``NoneType`` is appended to the given arguments,
    * ``Annotated`` hint: only the first given argument is passed on,
    * anything else: the arguments are passed on as given.

    Args:
        h: Type hint serving as the template.
        *args: New arguments for the hint.
    """
    if is_optional(h):
        new_args = (*args, type(None))
    elif is_annotated(h):
        new_args = (args[0],)
    else:
        new_args = args
    return h.copy_with(new_args)
# Template hints: make_typehint is called on these with fresh arguments.
# LIT and TUP serve make_literal, UNION_PROXY serves unoptional.
LIT = Literal[1]
TUP = Tuple[Any]
UNION_PROXY = Union[int, str]
def make_literal(val):
    """Given a JSON object, return type that parses exactly that object.

    Note that dicts can have extra fields that will be ignored
    and that coercion between bool and int might happen.

    Sets and floats are not supported.

    Args:
        val: ``None``, a bool/int/str, an enum member, a tuple/list or a dict
            (nested values are converted recursively).

    Raises:
        ValueError: If ``val`` is none of the supported kinds of value.
    """
    if val is None:
        return type(None)

    if isinstance(val, (bool, int, str)):
        return make_typehint(LIT, val)

    if issubclass(val.__class__, enum.Enum):
        # the literal is built from the enum member's underlying value
        return make_typehint(LIT, val.value)

    if isinstance(val, (tuple, list)):
        # one literal type per element
        return make_typehint(TUP, *[make_literal(item) for item in val])

    if isinstance(val, dict):
        fields = {}
        for key, item in val.items():
            fields[key] = make_literal(item)
        # NOTE: the TypedDict must be from typing_extensions for 3.8!
        return TypedDict("AnonConstDict", fields)  # type: ignore

    raise ValueError(f"Unsupported value: {val}")
def make_tree_traversal(succ_func: Callable[[Any], Iterable]):
    """Build a generator function that walks a tree-shaped object.

    The returned function takes the root object and a boolean keyword
    argument `post_order`. By default each node is emitted before its
    children; with `post_order=True` it is emitted after them.

    Args:
        succ_func: Function to be called on each node returning Iterable of children
    """

    def traverse(obj, *, post_order: bool = False):
        if post_order:
            # children first, then the node itself
            for child in succ_func(obj):
                yield from traverse(child, post_order=post_order)
            yield obj
        else:
            # the node itself first, then its children
            yield obj
            for child in succ_func(obj):
                yield from traverse(child, post_order=post_order)

    return traverse
traverse_typehint = make_tree_traversal(succ_func=get_args)
"""Traverse a type annotation depth-first, descending into the arguments of each hint.

Hints are emitted in pre-order unless `post_order=True` is passed.

Args:
    obj (object): type hint object to be traversed
"""
def make_tree_mapper(node_constructor, succ_func):
    """Build a function that maps over the leaves of a tree-shaped object.

    Args:
        node_constructor: Called as ``node_constructor(obj, *mapped_children)``
            to rebuild an inner node from its already mapped children.
        succ_func: Called on each node; returns its children. A falsy result
            marks the node as a leaf.

    Returns:
        Function ``map_func(obj, leaf_map_func)`` that applies
        ``leaf_map_func`` to every leaf and rebuilds all inner nodes.
    """

    def map_func(obj, leaf_map_func):
        children = succ_func(obj)
        if not children:
            return leaf_map_func(obj)
        mapped = [map_func(child, leaf_map_func) for child in children]
        return node_constructor(obj, *mapped)

    return map_func
# Maps over a type hint: hints without arguments are treated as leaves,
# all other hints are rebuilt with make_typehint from their mapped arguments.
map_typehint = make_tree_mapper(node_constructor=make_typehint, succ_func=get_args)
def unoptional(th):
    """Return type hint that is not optional (if it was optional).

    For an ``Annotated`` hint the first argument is processed recursively and
    the hint is rebuilt around the result. Hints that are not unions are
    returned unchanged.
    """
    if is_annotated(th):
        # remove inner optional, preserve annotation
        inner = get_args(th)[0]
        return make_typehint(th, unoptional(inner))

    if is_union(th):
        # drop NoneType from the Union arguments
        remaining = tuple(arg for arg in get_args(th) if not is_nonetype(arg))
        if len(remaining) == 1:
            # a single alternative is left -> no union needed anymore
            return remaining[0]
        # rebuild the union without NoneType (i.e. not optional)
        return make_typehint(UNION_PROXY, *remaining)

    # all optionals are actually unions -> nothing to do
    return th
202# ----
def is_subtype(sub, base):
    """Return whether type hint ``sub`` is a subtype of type hint ``base``.

    Wraps ``runtype``'s ``is_subtype`` with two workarounds:

    * if only one of the hints is ``Annotated``, or only one of them is a
      ``Literal``, the result is False,
    * for two ``Annotated`` hints only their first arguments are compared;
      the remaining annotation arguments are ignored.

    NOTE: this is only superficial, actually issubtype must be fixed
    or it won't work with nested Annotated types or complicated stuff
    """
    sub_is_annotated = is_annotated(sub)
    if sub_is_annotated != is_annotated(base):
        return False  # not equal on annotated wrapping status
    if is_literal(sub) != is_literal(base):
        return False  # not equal on literal status

    if sub_is_annotated:
        # NOTE: FieldInfo of pydantic is not comparable :( so we ignore it
        # and only compare the wrapped types.
        wrapped_sub = get_args(sub)[0]
        wrapped_base = get_args(base)[0]
        return is_subtype(wrapped_sub, wrapped_base)

    # proceed as usual
    return rv.is_subtype(sub, base)
def is_subtype_of(t: Any) -> Callable[[Any], bool]:
    """Return a predicate to check issubtype for a given type.

    Args:
        t: Type hint that the predicate's argument is compared against.
    """

    def predicate(obj):
        return is_subtype(obj, t)

    return predicate
def is_instance_of(t: Any) -> Callable[[Any], bool]:
    """Return a predicate to check isinstance for a given type.

    Args:
        t: Class (or anything else ``isinstance`` accepts as second argument).
    """

    def predicate(obj):
        return isinstance(obj, t)

    return predicate
def is_subclass_of(t: Any) -> Callable[[Any], bool]:
    """Return a predicate to check issubclass for a given type.

    The predicate returns False for arguments that are not classes instead of
    passing them on to ``issubclass``.

    Args:
        t: Class that the predicate's argument is compared against.
    """

    def predicate(obj):
        if not isinstance(obj, type):
            return False
        return issubclass(obj, t)

    return predicate
# Predicate: True if the argument is a class that is a subclass of enum.Enum.
is_enum = is_subclass_of(t=enum.Enum)