Coverage for src/metador_core/schema/core.py: 100%

214 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-08-08 10:29 +0000

1"""Core Metadata schemas for Metador that are essential for the container API.""" 

2 

3 

4from __future__ import annotations 

5 

6import contextlib 

7from collections import ChainMap 

8from functools import partial 

9from typing import Any, ClassVar, Dict, Optional, Set, Type, cast 

10 

11import wrapt 

12from pydantic import Extra, root_validator 

13 

14from ..plugin.metaclass import PluginMetaclassMixin, UndefVersion 

15from ..util import cache, is_public_name 

16from ..util.models import field_atomic_types, traverse_typehint 

17from ..util.typing import ( 

18 get_annotations, 

19 get_type_hints, 

20 is_classvar, 

21 is_instance_of, 

22 is_subclass_of, 

23 is_subtype, 

24) 

25from .base import BaseModelPlus 

26from .encoder import DynEncoderModelMetaclass 

27from .inspect import ( 

28 FieldInspector, 

29 LiftedRODict, 

30 WrappedLiftedDict, 

31 lift_dict, 

32 make_field_inspector, 

33) 

34 

35# from .jsonschema import finalize_schema_extra, schema_of 

36from .partial import PartialFactory, is_mergeable_type 

37from .types import to_semver_str 

38 

39 

def add_missing_field_descriptions(schema, model):
    """Fill in empty property descriptions of a JSON schema from ``model.Fields``.

    For every entry in ``schema["properties"]`` that has no (or an empty)
    ``"description"``, the description of the equally named entry in
    ``model.Fields`` is copied over, provided it exists and is non-empty.
    The schema dict is modified in place; nothing is returned.
    """
    properties = schema.get("properties", {})
    for field_name, field_schema in properties.items():
        if field_schema.get("description"):
            continue  # already documented, leave untouched

        try:
            found = model.Fields[field_name].description
        except KeyError:
            # no introspection info available for this property -> skip silently
            continue

        if found:
            field_schema["description"] = found

47 

48 

49KEY_SCHEMA_PG = "$metador_plugin" 

50"""Key in JSON schema to put metador plugin name an version.""" 

51 

52KEY_SCHEMA_CONSTFLDS = "$metador_constants" 

53"""Key in JSON schema to put metador 'constant fields'.""" 

54 

55 

class SchemaBase(BaseModelPlus):
    # Common base of all schemas: declares the special class-level attributes
    # (which SchemaMagic manages and protects), customizes the generated JSON
    # schema, and forces "constant fields" into every parsed input.

    __constants__: ClassVar[Dict[str, Any]]
    """Constant model fields, usually added with a decorator, ignored on input."""

    __overrides__: ClassVar[Set[str]]
    """Field names explicitly overriding inherited field type.

    Those not listed (by @overrides decorator) must, if they are overridden,
    be strict subtypes of the inherited type."""

    __types_checked__: ClassVar[bool]
    """Helper flag used by check_overrides to avoid re-checking."""

    class Config:
        @staticmethod
        def schema_extra(schema: Dict[str, Any], model: Type[BaseModelPlus]) -> None:
            """Post-process the JSON schema that pydantic generated for ``model``.

            The ``schema`` dict is modified in place.
            """
            # work on the actual schema class, not on an UndefVersion wrapper
            model = UndefVersion._unwrap(model) or model

            # custom extra key to connect back to metador schema
            # (only if the class itself defines a Plugin, inherited ones are ignored)
            plugin_info = model.__dict__.get("Plugin")
            if plugin_info:
                plugin_ref = plugin_info.ref().copy(exclude={"group"})
                schema[KEY_SCHEMA_PG] = plugin_ref.json_dict()

            # enrich schema with descriptions retrieved from e.g. docstrings
            if model is not MetadataSchema:
                add_missing_field_descriptions(schema, model)

            # special handling for "constant fields"
            if consts := model.__constants__:
                const_section: Dict[str, Any] = {}
                schema[KEY_SCHEMA_CONSTFLDS] = const_section
                for const_name, const_value in consts.items():
                    # list them (so they are not rejected even with additionalProperties=False)
                    schema["properties"][const_name] = True
                    # store the constant alongside the schema
                    const_section[const_name] = const_value

            # do magic (TODO: fix/rewrite)
            # finalize_schema_extra(schema, model, base_model=MetadataSchema)

    # NOTE: custom JSON schema feature is broken
    # @classmethod
    # def schema(cls, *args, **kwargs):
    #     """Return customized JSONSchema for this model."""
    #     return schema_of(UndefVersion._unwrap(cls) or cls, *args, **kwargs)

    @root_validator(pre=True)
    def override_consts(cls, values):
        """Inject the schema constants into the raw input values.

        Runs before field validation, so whatever the input provided for a
        constant field is replaced: constants are present on dump, but
        ignored on load.
        """
        values.update(cls.__constants__)
        return values

108 

109 

110ALLOWED_SCHEMA_CONFIG_FIELDS = {"title", "extra", "allow_mutation"} 

111"""Allowed pydantic Config fields to be overridden in schema models.""" 

112 

113 

class SchemaMagic(DynEncoderModelMetaclass):
    """Metaclass for doing some magic.

    At class creation it validates the schema definition (single parent,
    no hand-written special attributes, restricted pydantic Config, no
    clashes with inherited constants, no loosening of a parent that forbids
    extra fields) and initializes the per-class bookkeeping attributes.
    """

    def __new__(cls, name, bases, dct):
        """Create the schema class, rejecting invalid definitions.

        Raises:
            TypeError: if any of the definition rules listed below is violated.
        """
        # enforce single inheritance
        if len(bases) > 1:
            raise TypeError("A schema can only have one parent schema!")
        parent = bases[0]

        # NOTE: checking that `parent` is a schema is not done here, because it
        # can't normally happen otherwise (this metaclass won't be triggered).

        # prevent user from defining special schema fields by hand
        for reserved in SchemaBase.__annotations__:
            if reserved in dct:
                raise TypeError(f"{name}: Invalid attribute '{reserved}'")

        # prevent most changes to pydantic config
        user_config = dct.get("Config")
        if user_config:
            for option in user_config.__dict__:
                if not is_public_name(option):
                    continue
                if option not in ALLOWED_SCHEMA_CONFIG_FIELDS:
                    raise TypeError(f"{name}: {option} must not be set or changed!")

        # generate pydantic model of schema (further checks are easier that way)
        # can't do these checks in __init__, because in __init__ the bases could be mangled
        created = super().__new__(cls, name, bases, dct)

        # prevent user defining fields that are constants in a parent
        inherited_consts = set(getattr(parent, "__constants__", {}))
        if inherited_consts:
            declared = set(get_annotations(created))
            clashes = declared & inherited_consts
            if clashes:
                raise TypeError(
                    f"{name}: Cannot define {clashes}, defined as const field already!"
                )

        # prevent parent-compat breaking change of extra handling / new fields:
        if parent.__config__.extra is Extra.forbid:
            # if parent forbids, child does not -> problem (child can parse, parent can't)
            child_extra = created.__config__.extra
            if child_extra is not Extra.forbid:
                raise TypeError(
                    f"{name}: cannot {child_extra.value} extra fields if parent forbids them!"
                )

            # parent forbids extras, child has new fields -> same problem
            added = set(created.__fields__) - set(parent.__fields__)
            if added:
                raise TypeError(
                    f"{name}: Cannot define new fields {added} if parent forbids extra fields!"
                )

        # everything looks ok
        return created

    def __init__(self, name, bases, dct):
        """Reset the per-class bookkeeping that must not leak in from the parent."""
        # marker used by check_types (for performance)
        self.__types_checked__ = False

        # prevent implicit inheritance of class-specific internal/meta stuff:
        # should be taken care of by plugin metaclass
        assert self.Plugin is None or self.Plugin != bases[0].Plugin

        # also prevent inheriting override marker
        self.__overrides__ = set()

        # and manually prevent inheriting annotations (for Python < 3.10)
        if "__annotations__" not in self.__dict__:
            self.__annotations__ = {}

        # "constant fields" are inherited, but copied - not shared
        inherited = {}
        for parent in bases:
            inherited.update(getattr(parent, "__constants__", {}))
        self.__constants__ = inherited

    @property  # type: ignore
    @cache  # noqa: B019
    def _typehints(self):
        """Return the (cached) resolved type hints of this class."""
        return get_type_hints(self)

    @property  # type: ignore
    @cache  # noqa: B019
    def _base_typehints(self):
        """Return the (cached) type hints of the schema base classes as a ChainMap."""
        schema_bases = [b for b in self.__bases__ if issubclass(b, SchemaBase)]
        return ChainMap(*(b._typehints for b in schema_bases))

    # ---- for public use ----

    def __str__(self):
        """Show schema and field documentation."""
        # the docstring is taken from the real schema if self is an UndefVersion wrapper
        target = UndefVersion._unwrap(self) or self

        title = f"Schema {super().__str__()}"
        heading = f"{title}\n{'=' * len(title)}"

        doc = target.__doc__
        if doc is not None and doc.strip():
            description = f"\nDescription:\n------------\n\t{doc}"
        else:
            description = ""

        fields = f"Fields:\n-------\n\n{str(self.Fields)}"
        return f"{heading}\n{description}\n{fields}"

    @property
    def Fields(self: Any) -> Any:
        """Access the field introspection interface."""
        inspector = make_schema_inspector(self)
        if not issubclass(self, UndefVersion):
            return inspector
        # make sure that subschemas accessed in a schema without explicit version
        # are also marked so we can check if someone uses them illegally
        return WrappedLiftedDict(inspector, UndefVersionFieldInspector)

    @property
    def Partial(self):
        """Access the partial schema based on the current schema."""
        return PartialSchemas.get_partial(self)

239 

240 

# The metaclass actually used by MetadataSchema; it has no body of its own and
# only merges the two parent metaclasses.
class SchemaMetaclass(PluginMetaclassMixin, SchemaMagic):
    """Combine schema magic with general plugin magic."""

243 

244 

# Root class of all schemas in this module; subclasses are created through
# SchemaMetaclass and therefore undergo the checks in SchemaMagic.__new__.
class MetadataSchema(SchemaBase, metaclass=SchemaMetaclass):
    """Extends Pydantic models with custom serializers and functions."""

    Plugin: ClassVar[Optional[Type]]  # user-defined inner class (for schema plugins)

249 

250 

251# ---- 

252 

253 

254def _indent_text(txt, prefix="\t"): 

255 """Add indentation at new lines in given string.""" 

256 return "\n".join(map(lambda line: f"{prefix}{line}", txt.split("\n"))) 

257 

258 

class SchemaFieldInspector(FieldInspector):
    """MetadataSchema-specific field inspector.

    It adds a user-friendly repr and access to nested subschemas.
    """

    schemas: LiftedRODict
    _origin_name: str

    def _get_description(self):
        """Return the field description, or None if the field hint is a plain primitive."""
        hint = self.origin._typehints[self.name]
        for plain in (None, type(None), bool, int, float, str):
            if hint is plain:
                return None
        return super()._get_description()

    def __init__(self, model: Type[MetadataSchema], name: str, hint: str):
        """Set up the inspector and collect origin label and nested schemas."""
        super().__init__(model, name, hint)
        origin = self.origin

        # to show plugin name and version in case of registered plugin schemas:
        label = f"{origin.__module__}.{origin.__qualname__}"
        plugin_info = origin.Plugin
        if plugin_info:
            version = to_semver_str(plugin_info.version)
            label += f" (plugin: {plugin_info.name} {version})"
        self._origin_name = label

        # access to sub-entities/schemas:
        nested = set(field_atomic_types(origin.__fields__[name], bound=MetadataSchema))
        self.schemas = lift_dict("Schemas", {s.__name__: s for s in nested})

    def __repr__(self) -> str:
        """Return the field name followed by indented type/origin/schemas/description."""
        description = self.description
        if description:
            desc_part = f"description:\n{_indent_text(description)}\n"
        else:
            desc_part = ""

        if self.schemas:
            schemas_part = f"schemas: {', '.join(iter(self.schemas))}\n"
        else:
            schemas_part = ""

        details = "".join(
            (
                f"type: {str(self.type)}\n",
                f"origin: {self._origin_name}\n",
                schemas_part,
                desc_part,
            )
        )
        return f"{self.name}\n{_indent_text(details)}"

296 

297 

298def _is_schema_field(schema, key: str): 

299 """Return whether a given key name is a non-constant model field.""" 

300 return key in schema.__fields__ and key not in schema.__constants__ 

301 

302 

@cache
def make_schema_inspector(schema):
    """Build (once per schema, due to caching) the ``Fields`` inspector of ``schema``.

    Only non-constant model fields are exposed, each one wrapped in a
    ``SchemaFieldInspector``.
    """
    only_schema_fields = partial(_is_schema_field, schema)
    return make_field_inspector(
        schema,
        "Fields",
        bound=MetadataSchema,
        key_filter=only_schema_fields,
        i_cls=SchemaFieldInspector,
    )

312 

313 

314# ---- 

315# wrapper to "infect" nested schemas with UndefVersion 

316 

317 

# Proxy around a field inspector that hands out its nested schemas through
# UndefVersion._mark_class; everything else is forwarded to the wrapped object.
class UndefVersionFieldInspector(wrapt.ObjectProxy):
    @property
    def schemas(self):
        # wrap the nested schemas of the proxied inspector, so that each one
        # is passed through UndefVersion._mark_class
        return WrappedLiftedDict(self.__wrapped__.schemas, UndefVersion._mark_class)

    def __repr__(self):
        # show the repr of the wrapped inspector instead of the proxy's own
        return repr(self.__wrapped__)

325 

326 

327# ---- 

328 

329 

class PartialSchemas(PartialFactory):
    """Partial model for MetadataSchema model.

    Needed for harvesters to work (which can provide validated but partial metadata).
    """

    base_model = SchemaBase

    # override to ignore "constant fields" for partials
    @classmethod
    def _get_field_vals(cls, obj):
        """Return a generator of ``(name, value)`` pairs of ``obj`` without constant fields."""
        all_pairs = super()._get_field_vals(obj)
        return ((k, v) for k, v in all_pairs if k not in obj.__constants__)

    # override to add some fixes to partials
    @classmethod
    def _create_partial(cls, mcls, *, typehints=...):
        """Create the partial model for ``mcls`` and attach schema-specific extras.

        The ``typehints`` argument is ignored: the ``_typehints`` attribute of
        ``mcls`` (None if it has none) is always passed on to the parent
        implementation, and the partial is derived from the unwrapped class
        if ``mcls`` is an UndefVersion wrapper.

        Returns:
            The ``(partial, nested)`` tuple produced by the parent implementation.
        """
        hints = getattr(mcls, "_typehints", None)
        unwrapped = UndefVersion._unwrap(mcls) or mcls
        ret, nested = super()._create_partial(unwrapped, typehints=hints)

        # attach constant fields for field filtering; the fallback is an empty
        # dict (not a set), matching the declared type of `__constants__`
        ret.__constants__ = getattr(mcls, "__constants__", {})

        # copy custom parser to partial
        # (otherwise partial can't parse correctly with parser mixin)
        if parser := getattr(mcls, "Parser", None):
            ret.Parser = parser
        return (ret, nested)

360 

361 

362# --- delayed checks (triggered during schema loading) --- 

363# if we do it earlier, might lead to problems with forward refs and circularity 

364 

365 

def check_types(
    schema: Type[MetadataSchema],
    *,
    recheck: bool = False,
    _seen: Optional[Set[int]] = None,
):
    """Run the delayed type checks on ``schema`` and everything it depends on.

    Parent schemas and schemas nested in fields are checked first
    (recursively), then ``check_allowed_types`` and ``check_overrides`` are
    run on ``schema`` itself.

    Args:
        schema: Schema class to check (``MetadataSchema`` itself is skipped).
        recheck: If True, schemas already marked as checked are checked again.
        _seen: Internal bookkeeping for the recursion (ids of the schemas
            visited during the current top-level call); callers should not
            pass it.

    Raises:
        TypeError, ValueError: propagated from the individual checks.
    """
    if schema is MetadataSchema:
        return
    if _seen is None:
        _seen = set()

    # The per-call `_seen` set keeps mutually nested schemas from recursing
    # forever when `recheck=True` (the persistent flag does not stop that case).
    if id(schema) in _seen or (schema.__types_checked__ and not recheck):
        return
    _seen.add(id(schema))
    schema.__types_checked__ = True  # set early, so cyclic references terminate

    try:
        # recursively check compositional and inheritance dependencies
        for base in schema.__bases__:
            if issubclass(base, MetadataSchema):
                check_types(base, recheck=recheck, _seen=_seen)

        fields = cast(Any, schema.Fields)
        for fname in fields:  # type: ignore
            nested = fields[fname].schemas
            for sname in nested:
                sub = nested[sname]
                if sub is not schema and issubclass(sub, MetadataSchema):
                    check_types(sub, recheck=recheck, _seen=_seen)

        check_allowed_types(schema)
        check_overrides(schema)
    except Exception:
        # a schema that failed its checks must not stay marked as checked,
        # otherwise a later call would silently skip it
        schema.__types_checked__ = False
        raise

385 

386 

def check_allowed_types(schema: Type[MetadataSchema]):
    """Check that shape of defined fields is suitable for deep merging.

    Raises:
        TypeError: if the hint of a public field is not a mergeable type, or
            if it contains a class derived from ``UndefVersion``.
    """
    type_hints = cast(Any, schema._typehints)
    for field, hint in type_hints.items():
        if not is_public_name(field):
            continue  # private field

        if not is_mergeable_type(hint):
            raise TypeError(
                f"{schema}:\n\ttype of '{field}' contains a forbidden pattern!"
            )

        # check that no nested schemas from undefVersion plugins are used in field definitions
        # (the Plugin metaclass cannot check this, but it checks for inheritance)
        is_class = is_instance_of(type)
        is_undef_version = is_subclass_of(UndefVersion)
        illegal = next(
            (
                part
                for part in traverse_typehint(hint)
                if is_class(part) and is_undef_version(part)
            ),
            None,
        )
        if illegal:
            raise TypeError(
                f"{schema}:\n\ttype of '{field}' contains an illegal subschema:\n\t\t{illegal}"
            )

408 

409 

def infer_parent(plugin: Type[MetadataSchema]) -> Optional[Type[MetadataSchema]]:
    """Return closest base schema that is a plugin, or None.

    The MRO of ``plugin`` (excluding ``plugin`` itself) is walked in order,
    and the first schema class that defines a truthy ``Plugin`` attribute in
    its own ``__dict__`` wins. This allows to skip over intermediate schemas
    and bases that are not plugins.
    """
    for ancestor in plugin.__mro__[1:]:
        if issubclass(ancestor, MetadataSchema) and ancestor.__dict__.get("Plugin"):
            return ancestor
    return None

422 

423 

def is_pub_instance_field(schema, name, hint):
    """Return whether field `name` in `schema` is a non-constant, public schema instance field.

    A field qualifies if its name is public, its hint is not a ClassVar and
    it is not listed in ``schema.__constants__``.
    """
    public = is_public_name(name)
    if not public:
        return public
    if is_classvar(hint):
        return False
    return name not in schema.__constants__

431 

432 

433def detect_field_overrides(schema: Type[MetadataSchema]): 

434 anns = get_annotations(schema) 

435 base_hints = cast(Any, schema._base_typehints) 

436 new_hints = {n for n, h in anns.items() if is_pub_instance_field(schema, n, h)} 

437 return set(base_hints.keys()).intersection(new_hints) 

438 

439 

def check_overrides(schema: Type[MetadataSchema]):
    """Check that fields are overridden to subtypes or explicitly declared as overridden.

    Raises:
        ValueError: if ``__overrides__`` names a field that no parent has, or
            a field that the schema does not actually override.
        TypeError: if an undeclared override is not a subtype of the
            inherited field type.
    """
    own_hints = cast(Any, schema._typehints)
    inherited_hints = cast(Any, schema._base_typehints)
    detected = detect_field_overrides(schema)
    declared = schema.__overrides__

    unreal_override = declared - set(inherited_hints.keys())
    if unreal_override:
        raise ValueError(
            f"{schema.__name__}: No parent field to override: {unreal_override}"
        )

    miss_override = declared - detected
    if miss_override:
        raise ValueError(
            f"{schema.__name__}: Missing claimed field overrides: {miss_override}"
        )

    # all undeclared overrides must be strict subtypes of the inherited type:
    for fname in detected - declared:
        hint, parent_hint = own_hints[fname], inherited_hints[fname]
        if is_subtype(hint, parent_hint):
            continue

        # name the origin of the inherited field: the closest plugin parent
        # if there is one, otherwise the direct base class
        parent = infer_parent(schema)
        if parent:
            parent_name = parent.Fields[fname]._origin_name
        else:
            parent_name = schema.__base__.__name__

        # NOTE(review): the leading whitespace in front of {hint} and
        # {parent_hint} was reconstructed as two spaces - confirm against
        # the original file.
        msg = f"""The type assigned to field '{fname}'
in schema {repr(schema)}:

  {hint}

does not look like a valid subtype of the inherited type:

  {parent_hint}

from schema {parent_name}.

If you are ABSOLUTELY sure that this is a false alarm,
use the @overrides decorator to silence this error
and live forever with the burden of responsibility.
"""
        raise TypeError(msg)

481 

482 

483# def detect_field_overrides(schema: Type[MetadataSchema]) -> Set[str]: 

484# return {n 

485# for n in updated_fields(schema) 

486# if is_public_name(n) and not is_classvar(schema._typehints[n]) and n not in schema.__constants__ 

487# }