Coverage for src/metador_core/schema/jsonschema.py: 38%

120 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-02 09:33 +0000

1"""Hacks to improve pydantic JSON Schema generation.""" 

2 

3import json 

4from functools import partial 

5from typing import Any, Dict, Iterable, List, Type, Union 

6 

7from pydantic import BaseModel 

8from pydantic import schema_of as pyd_schema_of 

9from pydantic.schema import schema as pyd_schemas 

10from typing_extensions import TypeAlias 

11 

12from ..util.hashsums import hashsum 

13from ..util.models import updated_fields 

14 

15KEY_SCHEMA_DEFS = "$defs" 

16"""JSON schema key to store subschema definitions.""" 

17 

18KEY_SCHEMA_HASH = "$jsonschema_hash" 

19"""Custom key to store schema hashsum.""" 

20 

21# ---- 

22 

23JSON_PRIMITIVE_TYPES = (type(None), bool, int, float, str) 

24 

25# shallow type definitions 

26JSONPrimitive: TypeAlias = Union[None, bool, int, float, str] 

27JSONObject: TypeAlias = Dict[str, Any] 

28JSONArray: TypeAlias = List[Any] 

29JSONType: TypeAlias = Union[JSONPrimitive, JSONArray, JSONObject] 

30 

31JSONSCHEMA_STRIP = { 

32 # these are meta-level keys that should not affect the hash: 

33 "title", 

34 "description", 

35 "examples", 

36 "$comment", 

37 "readOnly", 

38 "writeOnly", 

39 "deprecated", 

40 "$id", 

41 # subschemas have their own hashes, so if referenced schemas 

42 # change, then the $refs change automatically. 

43 "definitions", 

44 KEY_SCHEMA_DEFS, 

45 # we can't hash the hash, it cannot be part of the clean schema. 

46 KEY_SCHEMA_HASH, 

47} 

48"""Fields to be removed for JSON Schema hashsum computation.""" 

49 

50 

51def clean_jsonschema(obj: JSONType, *, _is_properties: bool = False): 

52 if isinstance(obj, JSON_PRIMITIVE_TYPES): 

53 return obj 

54 if isinstance(obj, list): 

55 return list(map(clean_jsonschema, obj)) 

56 if isinstance(obj, dict): 

57 return { 

58 k: clean_jsonschema(v, _is_properties=k == "properties") 

59 for k, v in obj.items() 

60 # must ensure not to touch keys in a properties sub-object! 

61 if _is_properties or k not in JSONSCHEMA_STRIP 

62 } 

63 

64 raise ValueError(f"Object {obj} not of a JSON type: {type(obj)}") 

65 

66 

67def normalized_json(obj: JSONType) -> bytes: 

68 return json.dumps( 

69 obj, 

70 ensure_ascii=True, 

71 allow_nan=False, 

72 indent=None, 

73 sort_keys=True, 

74 separators=(",", ":"), 

75 ).encode("utf-8") 

76 

77 

78def jsonschema_id(schema: JSONType): 

79 """Compute robust semantic schema identifier. 

80 

81 A schema identifier is based on the schema plugin name + version 

82 and its JSON Schema representation, which includes all parent and nested schemas. 

83 """ 

84 return hashsum(normalized_json(clean_jsonschema(schema)), "sha256") 

85 

86 

87# ---- 

88 

89 

90def lift_nested_defs(schema: JSONObject): 

91 """Flatten nested $defs ($defs -> key -> $defs) in-place.""" 

92 if mydefs := schema.get(KEY_SCHEMA_DEFS): 

93 inner = [] 

94 for schema in mydefs.values(): 

95 lift_nested_defs(schema) 

96 if nested := schema.pop(KEY_SCHEMA_DEFS, None): 

97 inner.append(nested) 

98 for nested in inner: 

99 mydefs.update(nested) 

100 

101 

102KEY_PYD_DEFS = "definitions" 

103"""key name where pydantic stores subschema definitions.""" 

104 

105REF_PREFIX = f"#/{KEY_PYD_DEFS}/" 

106"""default $refs prefix of pydantic.""" 

107 

108 

109def merge_nested_defs(schema: JSONObject): 

110 """Merge definitions in-place.""" 

111 if defs := schema.pop(KEY_PYD_DEFS, None): 

112 my_defs = schema.get(KEY_SCHEMA_DEFS) 

113 if not my_defs: 

114 schema[KEY_SCHEMA_DEFS] = {} 

115 my_defs = schema[KEY_SCHEMA_DEFS] 

116 # update, by preserve existing 

117 defs.update(my_defs) 

118 my_defs.update(defs) 

119 

120 

121# ---- 

122 

123 

124def collect_defmap(defs: JSONObject): 

125 """Compute dict mapping current name in $defs to new name based on metador_hash.""" 

126 defmap = {} 

127 for name, subschema in defs.items(): 

128 if KEY_SCHEMA_HASH in subschema: 

129 defmap[name] = subschema[KEY_SCHEMA_HASH].strip("/") 

130 else: 

131 # print("no hashsum: ", name) 

132 defmap[name] = name 

133 

134 return defmap 

135 

136 

137def map_ref(defmap, refstr: str): 

138 """Update the `$ref` string based on defmap. 

139 

140 Will replace `#/definitions/orig` 

141 with `#/$defs/mapped`. 

142 """ 

143 if refstr.startswith(REF_PREFIX): 

144 # print("remap", refstr) 

145 plen = len(REF_PREFIX) 

146 if new_name := defmap.get(refstr[plen:]): 

147 return f"#/{KEY_SCHEMA_DEFS}/{new_name}" 

148 return refstr 

149 

150 

151def update_refs(defmap, obj): 

152 """Recursively update `$ref` in `obj` based on defmap.""" 

153 print("update", obj) 

154 if isinstance(obj, (type(None), bool, int, float, str)): 

155 return obj 

156 elif isinstance(obj, list): 

157 return list(map(partial(update_refs, defmap), obj)) 

158 elif isinstance(obj, dict): 

159 return { 

160 k: (update_refs(defmap, v) if k != "$ref" else map_ref(defmap, v)) 

161 for k, v in obj.items() 

162 } 

163 raise ValueError(f"Object {obj} not of a JSON type: {type(obj)}") 

164 

165 

166def remap_refs(schema): 

167 """Remap the $refs to use metador_hash-based keys. 

168 

169 Input must be a completed schema with a global `$defs` section 

170 that all nested entities use for local references. 

171 """ 

172 defs = schema.pop(KEY_SCHEMA_DEFS, None) 

173 if not defs: # nothing to do 

174 return schema 

175 

176 # get name map, old -> new 

177 defmap = collect_defmap(defs) 

178 # update refs 

179 defs.update(update_refs(defmap, defs)) 

180 schema.update(update_refs(defmap, schema)) 

181 # rename defs 

182 schema[KEY_SCHEMA_DEFS] = {defmap[k]: v for k, v in defs.items()} 

183 

184 

185# ---- 

186 

187 

188def fixup_jsonschema(schema): 

189 merge_nested_defs(schema) # move `definitions` into `$defs` 

190 lift_nested_defs(schema) # move nested `$defs` to top level `$defs` 

191 remap_refs(schema) # "rename" defs from model name to metador hashsum 

192 

193 

194def schema_of(model: Type[BaseModel], *args, **kwargs): 

195 """Return JSON Schema for a model. 

196 

197 Improved version of `pydantic.schema_of`, returns result 

198 in $defs normal form, with $ref pointing to the model. 

199 """ 

200 schema = pyd_schema_of(model, *args, **kwargs) 

201 print(type(schema), schema) 

202 schema.pop("title", None) 

203 fixup_jsonschema(schema) 

204 return schema 

205 

206 

207def schemas(models: Iterable[Type[BaseModel]], *args, **kwargs): 

208 """Return JSON Schema for multiple models. 

209 

210 Improved version of `pydantic.schema.schema`, 

211 returns result in $defs normal form. 

212 """ 

213 schema = pyd_schemas(tuple(models), *args, **kwargs) 

214 fixup_jsonschema(schema) 

215 return schema 

216 

217 

218# ---- 

219 

220 

221def split_model_inheritance(schema: JSONObject, model: Type[BaseModel]): 

222 """Decompose a model into an allOf combination with a parent model. 

223 

224 This is ugly because pydantic does in-place wrangling and caching, 

225 and we need to hack around it. 

226 """ 

227 # NOTE: important - we assume to get the $defs standard form 

228 # print("want schema of", model.__base__.__name__) 

229 base_schema = model.__base__.schema() # type: ignore 

230 

231 # compute filtered properties / required section 

232 schema_new = dict(schema) 

233 ps = schema_new.pop("properties", None) 

234 rq = schema_new.pop("required", None) 

235 

236 lst_fields = updated_fields(model) 

237 ps_new = {k: v for k, v in ps.items() if k in lst_fields} 

238 rq_new = None if not rq else [k for k in rq if k in ps_new] 

239 schema_this = {k: v for k, v in [("properties", ps_new), ("required", rq_new)] if v} 

240 

241 # construct new schema as combination of base schema and remainder schema 

242 schema_new.update( 

243 { 

244 # "rdfs:subClassOf": f"/{base_id}", 

245 "allOf": [{"$ref": base_schema["$ref"]}, schema_this], 

246 } 

247 ) 

248 

249 # we need to add the definitions to/from the base schema as well 

250 if KEY_SCHEMA_DEFS not in schema_new: 

251 schema_new[KEY_SCHEMA_DEFS] = {} 

252 schema_new[KEY_SCHEMA_DEFS].update(base_schema.get(KEY_SCHEMA_DEFS, {})) 

253 

254 schema.clear() 

255 schema.update(schema_new) 

256 

257 

258def finalize_schema_extra( 

259 schema: JSONObject, 

260 model: Type[BaseModel], 

261 *, 

262 base_model: Type[BaseModel] = None, 

263) -> None: 

264 """Perform custom JSON Schema postprocessing. 

265 

266 To be called as last action in custom schema_extra method in the used base model. 

267 

268 Arguments: 

269 schema: The JSON object containing the schema 

270 model: The underlying pydantic model 

271 base_model: The custom base model that this function is called for. 

272 """ 

273 base_model = base_model or BaseModel 

274 assert issubclass(model, base_model) 

275 

276 # a schema should have a specified standard 

277 schema["$schema"] = "https://json-schema.org/draft/2020-12/schema" 

278 

279 if model.__base__ is not base_model: 

280 # tricky part: de-duplicate fields from parent class 

281 split_model_inheritance(schema, model) 

282 

283 # do this last, because it needs everything else to compute the correct hashsum: 

284 schema[KEY_SCHEMA_HASH] = f"{jsonschema_id(schema)}" 

285 fixup_jsonschema(schema)