Coverage for src/metador_core/schema/core.py: 100%

212 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-02 09:33 +0000

1"""Core Metadata schemas for Metador that are essential for the container API.""" 

2 

3 

4from __future__ import annotations 

5 

6import contextlib 

7from collections import ChainMap 

8from functools import partial 

9from typing import Any, ClassVar, Dict, Optional, Set, Type, cast 

10 

11import wrapt 

12from pydantic import Extra, root_validator 

13 

14from ..plugin.metaclass import PluginMetaclassMixin, UndefVersion 

15from ..util import cache, is_public_name 

16from ..util.models import field_atomic_types, traverse_typehint 

17from ..util.typing import ( 

18 get_annotations, 

19 get_type_hints, 

20 is_classvar, 

21 is_instance_of, 

22 is_subclass_of, 

23 is_subtype, 

24) 

25from .base import BaseModelPlus 

26from .encoder import DynEncoderModelMetaclass 

27from .inspect import ( 

28 FieldInspector, 

29 LiftedRODict, 

30 WrappedLiftedDict, 

31 lift_dict, 

32 make_field_inspector, 

33) 

34 

35# from .jsonschema import finalize_schema_extra, schema_of 

36from .partial import PartialFactory, is_mergeable_type 

37from .types import to_semver_str 

38 

39 

def add_missing_field_descriptions(schema, model):
    """Fill in empty JSON schema property descriptions from the model's field info.

    Only properties without a truthy "description" entry are touched. A property
    is left as it is when looking it up through ``model.Fields`` raises KeyError
    or when the description found there is empty. ``schema`` is modified in place.
    """
    properties = schema.get("properties", {})
    for prop_name, prop_def in properties.items():
        if prop_def.get("description"):
            continue  # already described, nothing to add
        try:
            description = model.Fields[prop_name].description
        except KeyError:
            continue  # no field info available for this property
        if description:
            prop_def["description"] = description

47 

48 

49KEY_SCHEMA_PG = "$metador_plugin" 

"""Key in JSON schema to put metador plugin name and version."""

51 

52KEY_SCHEMA_CONSTFLDS = "$metador_constants" 

53"""Key in JSON schema to put metador 'constant fields'.""" 

54 

55 

class SchemaBase(BaseModelPlus):
    # Common base of schema models. The three class-level attributes below are
    # only declared here; SchemaMagic.__init__ assigns them for every schema class.

    __constants__: ClassVar[Dict[str, Any]]
    """Constant model fields, usually added with a decorator, ignored on input."""

    __overrides__: ClassVar[Set[str]]
    """Field names explicitly overriding inherited field type.

    Those not listed (by @overrides decorator) must, if they are overridden,
    be strict subtypes of the inherited type."""

    __types_checked__: ClassVar[bool]
    """Helper flag used by check_types to avoid re-checking."""

    class Config:
        @staticmethod
        def schema_extra(schema: Dict[str, Any], model: Type[BaseModelPlus]) -> None:
            """Post-process the JSON schema dict generated for ``model`` (in place)."""
            # work on the underlying class if we got an UndefVersion wrapper
            model = UndefVersion._unwrap(model) or model

            # Disabled: custom extra key connecting back to the metador schema.
            # TODO: does not always work?!
            # if pgi := model.__dict__.get("Plugin"):
            #     schema[KEY_SCHEMA_PG] = pgi.ref().copy(exclude={"group"}).json_dict()

            # pull in descriptions that the generated schema is missing
            if model is not MetadataSchema:
                add_missing_field_descriptions(schema, model)

            # "constant fields" need to be advertised explicitly
            constants = model.__constants__
            if constants:
                exported = schema[KEY_SCHEMA_CONSTFLDS] = {}
                properties = schema["properties"]
                for const_name, const_value in constants.items():
                    # listed, so they are not rejected even with additionalProperties=False
                    properties[const_name] = True
                    # the constant value itself travels alongside the schema
                    exported[const_name] = const_value

            # Disabled (TODO: fix/rewrite):
            # finalize_schema_extra(schema, model, base_model=MetadataSchema)

    # NOTE: custom JSON schema feature is broken
    # @classmethod
    # def schema(cls, *args, **kwargs):
    #     """Return customized JSONSchema for this model."""
    #     return schema_of(UndefVersion._unwrap(cls) or cls, *args, **kwargs)

    @root_validator(pre=True)
    def override_consts(cls, values):
        """Force the schema constants into the incoming values.

        Whatever the input supplied under a constant's name is replaced, so
        constants are present on dump but effectively ignored on load.
        """
        constants = cls.__constants__
        values.update(constants)
        return values

109 

110 

# Whitelist consulted by SchemaMagic.__new__: any other public attribute set on
# a schema's inner ``Config`` class makes class creation fail with TypeError.
ALLOWED_SCHEMA_CONFIG_FIELDS = {"title", "extra", "allow_mutation"}
"""Allowed pydantic Config fields to be overridden in schema models."""

113 

114 

class SchemaMagic(DynEncoderModelMetaclass):
    """Metaclass validating schema class definitions and adding introspection helpers.

    On class creation it enforces single inheritance, protects the reserved
    attributes annotated on SchemaBase, restricts changes to the pydantic
    ``Config`` and rejects definitions that would break compatibility with
    the parent schema (constant fields, extra-field handling).
    On the created classes it provides ``Fields``, ``Partial`` and a
    human-readable ``str()``.
    """

    def __new__(cls, name, bases, dct):
        """Create the schema class after validating its definition.

        Raises:
            TypeError: if more than one base is given, if ``dct`` defines an
                attribute that is annotated on SchemaBase, if the inner
                ``Config`` sets a public field outside of
                ALLOWED_SCHEMA_CONFIG_FIELDS, if an annotation re-defines a
                constant field of the parent, or if the parent forbids extra
                fields while the new class allows them or adds new fields.
        """
        # enforce single inheritance
        if len(bases) > 1:
            raise TypeError("A schema can only have one parent schema!")
        baseschema = bases[0]

        # only allow inheriting from other schemas:
        # NOTE: can't normally happen (this metaclass won't be triggered)
        # if not issubclass(baseschema, SchemaBase):
        #     raise TypeError(f"Base class {baseschema} is not a MetadataSchema!")

        # prevent user from defining special schema fields by hand
        for atr in SchemaBase.__annotations__.keys():
            if atr in dct:
                raise TypeError(f"{name}: Invalid attribute '{atr}'")

        # prevent most changes to pydantic config
        if conf := dct.get("Config"):
            for conffield in conf.__dict__:
                if (
                    is_public_name(conffield)
                    and conffield not in ALLOWED_SCHEMA_CONFIG_FIELDS
                ):
                    raise TypeError(f"{name}: {conffield} must not be set or changed!")

        # generate pydantic model of schema (further checks are easier that way)
        # can't do these checks in __init__, because in __init__ the bases could be mangled
        ret = super().__new__(cls, name, bases, dct)

        # prevent user defining fields that are constants in a parent
        if base_consts := set(getattr(baseschema, "__constants__", {}).keys()):
            new_defs = set(get_annotations(ret).keys())
            if illegal := new_defs.intersection(base_consts):
                msg = (
                    f"{name}: Cannot define {illegal}, defined as const field already!"
                )
                raise TypeError(msg)

        # prevent parent-compat breaking change of extra handling / new fields:
        parent_forbids_extras = baseschema.__config__.extra is Extra.forbid
        if parent_forbids_extras:
            # if parent forbids, child does not -> problem (child can parse, parent can't)
            extra = ret.__config__.extra
            if extra is not Extra.forbid:
                msg = (
                    f"{name}: cannot {extra.value} extra fields if parent forbids them!"
                )
                raise TypeError(msg)

            # parent forbids extras, child has new fields -> same problem
            if new_flds := set(ret.__fields__.keys()) - set(
                baseschema.__fields__.keys()
            ):
                msg = f"{name}: Cannot define new fields {new_flds} if parent forbids extra fields!"
                raise TypeError(msg)

        # everything looks ok
        return ret

    def __init__(self, name, bases, dct):
        """Reset the per-class bookkeeping attributes of the new schema class.

        The type-check flag, the override markers and (if missing) the
        annotations are set freshly on this class; constant fields of the
        bases are copied into a new dict owned by this class.
        """
        self.__types_checked__ = False  # marker used by check_types (for performance)

        # prevent implicit inheritance of class-specific internal/meta stuff:
        # should be taken care of by plugin metaclass
        assert self.Plugin is None or self.Plugin != bases[0].Plugin

        # also prevent inheriting override marker
        self.__overrides__ = set()

        # and manually prevent inheriting annotations (for Python < 3.10)
        if "__annotations__" not in self.__dict__:
            self.__annotations__ = {}

        # "constant fields" are inherited, but copied - not shared
        self.__constants__ = {}
        for b in bases:
            self.__constants__.update(getattr(b, "__constants__", {}))

    @property  # type: ignore
    @cache  # noqa: B019
    def _typehints(self):
        """Return typehints of this class."""
        return get_type_hints(self)

    @property  # type: ignore
    @cache  # noqa: B019
    def _base_typehints(self):
        """Return typehints accumulated from base class chain."""
        # only direct bases that are schemas contribute; the ChainMap looks
        # them up in the order of __bases__
        return ChainMap(
            *(b._typehints for b in self.__bases__ if issubclass(b, SchemaBase))
        )

    # ---- for public use ----

    def __str__(self):
        """Show schema and field documentation."""
        # use the docstring of the underlying class if this is an UndefVersion wrapper
        unwrapped = UndefVersion._unwrap(self)
        schema = unwrapped or self
        defstr = f"Schema {super().__str__()}"
        # underline the heading with '=' characters of the same length
        defstr = f"{defstr}\n{'='*len(defstr)}"
        descstr = ""
        if schema.__doc__ is not None and schema.__doc__.strip():
            desc = schema.__doc__
            descstr = f"\nDescription:\n------------\n\t{desc}"
        fieldsstr = f"Fields:\n-------\n\n{str(self.Fields)}"
        return "\n".join([defstr, descstr, fieldsstr])

    @property
    def Fields(self: Any) -> Any:
        """Access the field introspection interface."""
        fields = make_schema_inspector(self)
        # make sure that subschemas accessed in a schema without explicit version
        # are also marked so we can check if someone uses them illegally
        if issubclass(self, UndefVersion):
            return WrappedLiftedDict(fields, UndefVersionFieldInspector)
        else:
            return fields

    @property
    def Partial(self):
        """Access the partial schema based on the current schema."""
        return PartialSchemas.get_partial(self)

240 

241 

# Metaclass actually used by MetadataSchema; PluginMetaclassMixin comes first
# in the MRO, SchemaMagic second.
class SchemaMetaclass(PluginMetaclassMixin, SchemaMagic):
    """Combine schema magic with general plugin magic."""

244 

245 

# Root of the schema class hierarchy: check_types() stops recursing when it
# reaches this class and the field inspectors use it as their `bound`.
# NOTE: the docstring is shown by SchemaMagic.__str__, so treat it as output text.
class MetadataSchema(SchemaBase, metaclass=SchemaMetaclass):
    """Extends Pydantic models with custom serializers and functions."""

    Plugin: ClassVar[Optional[Type]]  # user-defined inner class (for schema plugins)

250 

251 

252# ---- 

253 

254 

255def _indent_text(txt, prefix="\t"): 

256 """Add indentation at new lines in given string.""" 

257 return "\n".join(map(lambda line: f"{prefix}{line}", txt.split("\n"))) 

258 

259 

class SchemaFieldInspector(FieldInspector):
    """Field inspector specialised for MetadataSchema models.

    On top of the base inspector it offers a readable ``repr`` and exposes
    the subschemas nested in a field's type through ``schemas``.
    """

    schemas: LiftedRODict
    _origin_name: str

    def _get_description(self):
        """Return the field description, or None if the hint is a bare primitive type."""
        hint = self.origin._typehints[self.name]
        primitives = (None, type(None), bool, int, float, str)
        if any(hint is prim for prim in primitives):
            return None
        return super()._get_description()

    def __init__(self, model: Type[MetadataSchema], name: str, hint: str):
        super().__init__(model, name, hint)

        origin = self.origin

        # label of the defining class, extended by plugin name and version
        # in case of registered plugin schemas
        label = f"{origin.__module__}.{origin.__qualname__}"
        plugin_info = origin.Plugin
        if plugin_info:
            label += (
                f" (plugin: {plugin_info.name} {to_semver_str(plugin_info.version)})"
            )
        self._origin_name = label

        # sub-entities/schemas reachable from the field type, keyed by class name
        nested = set(field_atomic_types(origin.__fields__[name], bound=MetadataSchema))
        self.schemas = lift_dict("Schemas", {sub.__name__: sub for sub in nested})

    def __repr__(self) -> str:
        description = self.description
        desc_part = (
            f"description:\n{_indent_text(description)}\n" if description else ""
        )

        schemas_part = ""
        if self.schemas:
            schemas_part = f"schemas: {', '.join(iter(self.schemas))}\n"

        info = "".join(
            (
                f"type: {str(self.type)}\n",
                f"origin: {self._origin_name}\n",
                schemas_part,
                desc_part,
            )
        )
        return f"{self.name}\n{_indent_text(info)}"

297 

298 

299def _is_schema_field(schema, key: str): 

300 """Return whether a given key name is a non-constant model field.""" 

301 return key in schema.__fields__ and key not in schema.__constants__ 

302 

303 

@cache
def make_schema_inspector(schema):
    """Create the "Fields" inspector for ``schema``.

    Only non-constant model fields are exposed (see ``_is_schema_field``) and
    each of them is wrapped in a ``SchemaFieldInspector``.
    """
    only_schema_fields = partial(_is_schema_field, schema)
    inspector = make_field_inspector(
        schema,
        "Fields",
        bound=MetadataSchema,
        key_filter=only_schema_fields,
        i_cls=SchemaFieldInspector,
    )
    return inspector

313 

314 

315# ---- 

316# wrapper to "infect" nested schemas with UndefVersion 

317 

318 

class UndefVersionFieldInspector(wrapt.ObjectProxy):
    # Proxy around a field inspector: the classes reachable through `schemas`
    # are passed through UndefVersion._mark_class, everything else is forwarded
    # to the wrapped object.

    @property
    def schemas(self):
        inner = self.__wrapped__
        return WrappedLiftedDict(inner.schemas, UndefVersion._mark_class)

    def __repr__(self):
        # show the wrapped inspector, not the proxy
        inner = self.__wrapped__
        return repr(inner)

326 

327 

328# ---- 

329 

330 

class PartialSchemas(PartialFactory):
    """Partial model for MetadataSchema model.

    Needed for harvesters to work (which can provide validated but partial metadata).
    """

    base_model = SchemaBase

    # override to ignore "constant fields" for partials
    @classmethod
    def _get_field_vals(cls, obj):
        """Return the ``(name, value)`` pairs of ``obj`` without its constant fields."""
        return (
            (k, v)
            for k, v in super()._get_field_vals(obj)
            if k not in obj.__constants__
        )

    # override to add some fixes to partials
    @classmethod
    def _create_partial(cls, mcls, *, typehints=...):
        """Create the partial for ``mcls`` and attach schema-specific attributes.

        The ``typehints`` argument is not forwarded; the type hints are always
        taken from ``mcls._typehints`` (None if that attribute is missing).

        Returns:
            The ``(partial, nested)`` pair produced by the parent implementation,
            with ``__constants__`` (and ``Parser``, if present) set on the partial.
        """
        th = getattr(mcls, "_typehints", None)
        # build the partial from the underlying class, not an UndefVersion wrapper
        unw = UndefVersion._unwrap(mcls) or mcls
        ret, nested = super()._create_partial(unw, typehints=th)
        # attach constant field mapping for field filtering
        # (fallback is an empty dict, matching the declared Dict[str, Any] type)
        ret.__constants__ = getattr(mcls, "__constants__", {})
        # copy custom parser to partial
        # (otherwise partial can't parse correctly with parser mixin)
        if parser := getattr(mcls, "Parser", None):
            ret.Parser = parser
        return (ret, nested)

361 

362 

363# --- delayed checks (triggered during schema loading) --- 

364# if we do it earlier, might lead to problems with forward refs and circularity 

365 

366 

def check_types(schema: Type[MetadataSchema], *, recheck: bool = False):
    """Run the delayed type checks for ``schema`` and everything it depends on.

    Base schemas and schemas nested in field types are checked first
    (recursively), then ``check_allowed_types`` and ``check_overrides`` are
    run on ``schema`` itself.

    Args:
        schema: Schema class to check; ``MetadataSchema`` itself is skipped.
        recheck: If True, also check schemas whose ``__types_checked__`` flag
            is already set. Each schema is still visited at most once per
            call, so schemas referencing each other cannot cause endless
            recursion.

    Raises:
        TypeError, ValueError: propagated from the individual checks.
    """
    _check_types_visit(schema, recheck, set())


def _check_types_visit(schema, recheck, visited):
    """Depth-first worker of check_types; ``visited`` holds schemas entered in this run."""
    if schema is MetadataSchema:
        return
    # identify a schema by its underlying class, so a wrapper and the wrapped
    # class count as the same node
    key = UndefVersion._unwrap(schema) or schema
    if key in visited:
        return
    if schema.__types_checked__ and not recheck:
        return
    visited.add(key)
    schema.__types_checked__ = True

    # recursively check compositional and inheritance dependencies
    for b in schema.__bases__:
        if issubclass(b, MetadataSchema):
            _check_types_visit(b, recheck, visited)

    schemaFields = cast(Any, schema.Fields)
    for f in schemaFields:  # type: ignore
        for sname in schemaFields[f].schemas:
            s = schemaFields[f].schemas[sname]
            if s is not schema and issubclass(s, MetadataSchema):
                _check_types_visit(s, recheck, visited)

    check_allowed_types(schema)
    check_overrides(schema)

386 

387 

def check_allowed_types(schema: Type[MetadataSchema]):
    """Verify that the public field types of ``schema`` can be deep-merged.

    Raises:
        TypeError: if a public field's hint is rejected by ``is_mergeable_type``
            or contains a class derived from ``UndefVersion``.
    """
    is_a_type = is_instance_of(type)
    is_undef_version = is_subclass_of(UndefVersion)

    type_hints = cast(Any, schema._typehints)
    for field, hint in type_hints.items():
        if not is_public_name(field):
            continue  # private field

        if not is_mergeable_type(hint):
            msg = f"{schema}:\n\ttype of '{field}' contains a forbidden pattern!"
            raise TypeError(msg)

        # check that no nested schemas from undefVersion plugins are used in field definitions
        # (the Plugin metaclass cannot check this, but it checks for inheritance)
        classes_in_hint = (node for node in traverse_typehint(hint) if is_a_type(node))
        illegal = next(
            (klass for klass in classes_in_hint if is_undef_version(klass)), None
        )
        if illegal:
            msg = f"{schema}:\n\ttype of '{field}' contains an illegal subschema:\n\t\t{illegal}"
            raise TypeError(msg)

409 

410 

def infer_parent(plugin: Type[MetadataSchema]) -> Optional[Type[MetadataSchema]]:
    """Find the nearest ancestor schema that is itself a plugin.

    The MRO of ``plugin`` is walked (excluding ``plugin`` itself); ancestors
    that are no MetadataSchema or that do not define their own ``Plugin``
    are skipped. Returns None if no ancestor qualifies.
    """
    for ancestor in plugin.__mro__[1:]:
        if issubclass(ancestor, MetadataSchema) and ancestor.__dict__.get("Plugin"):
            return ancestor
    return None

423 

424 

def is_pub_instance_field(schema, name, hint):
    """Tell whether ``name`` is a public, non-ClassVar, non-constant field of ``schema``."""
    public = is_public_name(name)
    if not public:
        return public
    if is_classvar(hint):
        return False
    return name not in schema.__constants__

432 

433 

def detect_field_overrides(schema: Type[MetadataSchema]):
    """Return the names of public instance fields that ``schema`` re-annotates.

    A field counts as overridden if it is annotated on ``schema`` itself and a
    type hint of the same name exists in the base class chain.
    """
    own_annotations = get_annotations(schema)
    inherited = cast(Any, schema._base_typehints)
    own_fields = {
        name
        for name, hint in own_annotations.items()
        if is_pub_instance_field(schema, name, hint)
    }
    return set(inherited.keys()) & own_fields

439 

440 

def check_overrides(schema: Type[MetadataSchema]):
    """Validate the field overrides of ``schema`` against its parent.

    Raises:
        ValueError: if a field is declared as override although no parent field
            of that name exists, or although the schema does not re-define it.
        TypeError: if a field is overridden without being declared and its new
            type is not a subtype of the inherited one.
    """
    own_hints = cast(Any, schema._typehints)
    inherited_hints = cast(Any, schema._base_typehints)
    declared = schema.__overrides__
    detected = detect_field_overrides(schema)

    unreal_override = declared - set(inherited_hints.keys())
    if unreal_override:
        msg = f"{schema.__name__}: No parent field to override: {unreal_override}"
        raise ValueError(msg)

    miss_override = declared - detected
    if miss_override:
        msg = f"{schema.__name__}: Missing claimed field overrides: {miss_override}"
        raise ValueError(msg)

    # all undeclared overrides must be strict subtypes of the inherited type:
    for fname in detected - declared:
        hint = own_hints[fname]
        parent_hint = inherited_hints[fname]
        if is_subtype(hint, parent_hint):
            continue

        # name the schema the offending field was inherited from
        parent = infer_parent(schema)
        if parent:
            parent_name = parent.Fields[fname]._origin_name
        else:
            parent_name = schema.__base__.__name__

        msg = f"""The type assigned to field '{fname}'
in schema {repr(schema)}:

 {hint}

does not look like a valid subtype of the inherited type:

 {parent_hint}

from schema {parent_name}.

If you are ABSOLUTELY sure that this is a false alarm,
use the @overrides decorator to silence this error
and live forever with the burden of responsibility.
"""
        raise TypeError(msg)

482 

483 

484# def detect_field_overrides(schema: Type[MetadataSchema]) -> Set[str]: 

485# return {n 

486# for n in updated_fields(schema) 

487# if is_public_name(n) and not is_classvar(schema._typehints[n]) and n not in schema.__constants__ 

488# }