Coverage for src/metador_core/plugin/interface.py: 100%

210 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-02 09:33 +0000

1"""Interface for plugin groups.""" 

2from __future__ import annotations 

3 

4from abc import ABCMeta 

5from typing import ( 

6 Any, 

7 ClassVar, 

8 Dict, 

9 Generic, 

10 Iterator, 

11 List, 

12 Optional, 

13 Set, 

14 Tuple, 

15 Type, 

16 TypeVar, 

17 Union, 

18 cast, 

19 overload, 

20) 

21 

22from importlib_metadata import EntryPoint 

23from typing_extensions import TypeAlias 

24 

25from ..schema.plugins import PluginBase, PluginPkgMeta 

26from ..schema.plugins import PluginRef as AnyPluginRef 

27from ..util import eprint 

28from . import util 

29from .entrypoints import get_group, pkg_meta 

30from .metaclass import UndefVersion 

31from .types import ( 

32 EP_NAME_REGEX, 

33 EPName, 

34 PluginLike, 

35 SemVerTuple, 

36 ep_name_has_namespace, 

37 from_ep_name, 

38 is_pluginlike, 

39 plugin_args, 

40 to_semver_str, 

41) 

42 

# Name of the plugin group whose plugins are the plugin groups (it registers itself, too).
PG_GROUP_NAME: str = "plugingroup"

44 

45 

class PGPlugin(PluginBase):
    # Plugin info model for plugin *groups* (plugins of the "plugingroup" group).
    # Used as `PluginGroup.Plugin.plugin_info_class` and by `create_pg` to parse
    # the inner `Plugin` class of a plugin group.
    # NOTE(review): only `#` comments here, as a class docstring may end up in
    # the model schema if PluginBase is a pydantic model — confirm before adding.

    # Plugin info class used to parse the inner `Plugin` class of plugins in the
    # described group. If a group does not set it, `PluginGroupMeta` derives one.
    plugin_info_class: Optional[Type[PluginBase]] = None
    # Base class that plugins of the described group must subclass; a falsy
    # value disables the check in `PluginGroup._check_common`.
    plugin_class: Optional[Any] = object

49 

50 

# TODO: plugin group inheritance is not checked yet because it adds complications
class PluginGroupMeta(ABCMeta):
    """Metaclass that wires up plugin group classes when they are created.

    For every new plugin group class it:

    * attaches a ``PluginRef`` subclass that auto-fills the group name
      (taken from the inner ``Plugin.name``), and
    * makes sure ``Plugin.plugin_info_class`` has ``group`` set to that name,
      either by patching a class declared directly on ``Plugin`` or by deriving
      a fresh one from ``PluginBase``.

    Raises:
        TypeError: If the new class is not plugin-like.
    """

    def __init__(self, name, bases, dct):
        super().__init__(name, bases, dct)
        # explicit check instead of `assert`, so it is not skipped under `python -O`
        if not is_pluginlike(self, check_group=False):
            raise TypeError(f"{name}: plugin group class is not plugin-like!")

        # attach generated subclass that auto-fills the group for plugin infos
        self.PluginRef: Type[AnyPluginRef] = AnyPluginRef._subclass_for(
            self.Plugin.name
        )

        # only look at the class' own `Plugin.__dict__`, not inherited values
        if pgi_cls := self.Plugin.__dict__.get("plugin_info_class"):
            # attach group name to provided plugin info class
            pgi_cls.group = self.Plugin.name
        else:
            # derive generic plugin info class with the group defined
            class PGInfo(PluginBase):
                group = self.Plugin.name

            self.Plugin.plugin_info_class = PGInfo

        # sanity checks (internal invariants): this magic must not mess with PluginBase
        assert self.Plugin.plugin_info_class is not PluginBase
        assert PluginBase.group == ""

76 

77 

# Type parameter of `PluginGroup`: the (parent) class of the plugins of a group.
T = TypeVar("T", bound=PluginLike)

79 

80 

class PluginGroup(Generic[T], metaclass=PluginGroupMeta):
    """All pluggable entities in metador are subclasses of this class.

    The type parameter is the (parent) class of all loaded plugins.

    Plugin groups must implement ``check_plugin`` and be listed as plugins of
    the plugin group group. The name of their entrypoint defines the name of
    the plugin group.

    Entrypoints are registered on construction, but the plugins they point to
    are only imported, checked and initialized on first access (see ``get``).
    """

    PluginRef: TypeAlias = AnyPluginRef
    """Plugin reference class for this plugin group."""

    _PKG_META: ClassVar[Dict[str, PluginPkgMeta]] = pkg_meta
    """Package name -> package metadata."""

    class Plugin:
        """This is the plugin group plugin group, the first loaded group."""

        name = PG_GROUP_NAME
        version = (0, 1, 0)
        plugin_info_class = PGPlugin
        plugin_class: Type
        # plugin_class = PluginGroup  # can't set that -> check manually

    _ENTRY_POINTS: Dict[EPName, EntryPoint]
    """Dict of entry points of versioned plugins (not loaded)."""

    _VERSIONS: Dict[str, List[AnyPluginRef]]
    """Mapping from plugin name to pluginrefs of available versions."""

    _LOADED_PLUGINS: Dict[AnyPluginRef, Type[T]]
    """Dict from plugin refs to loaded plugins of that pluggable type."""

    def _add_ep(self, epname_str: str, ep_obj: EntryPoint):
        """Register an entrypoint loaded from importlib_metadata (without loading it).

        Args:
            epname_str: Entrypoint name, encoding plugin name and version.
            ep_obj: Entrypoint pointing to the plugin.

        Raises:
            ValueError: If the name does not match ``EP_NAME_REGEX``, or if it has
                no namespace (required in every group except the plugin group group).
        """
        try:
            ep_name = EPName(epname_str)
        except TypeError as err:
            msg = f"{epname_str}: Invalid entrypoint name, must match {EP_NAME_REGEX}"
            raise ValueError(msg) from err
        if type(self) is not PluginGroup and not ep_name_has_namespace(ep_name):
            msg = f"{epname_str}: Plugin name has no qualifying namespace!"
            raise ValueError(msg)

        name, version = from_ep_name(ep_name)
        p_ref = AnyPluginRef(group=self.name, name=name, version=version)

        if ep_name in self._ENTRY_POINTS:
            self._LOADED_PLUGINS.pop(p_ref, None)  # unload, if loaded
            pkg = ep_obj.dist
            # manually constructed entrypoints (e.g. in __post_init__) may lack a dist
            provided_by = f"{pkg.name} {pkg.version}" if pkg is not None else ep_obj.value
            msg = f"WARNING: {ep_name} is probably provided by multiple packages!\n"
            msg += f"The plugin will now be provided by: {provided_by}"
            eprint(msg)
        self._ENTRY_POINTS[ep_name] = ep_obj

        # _VERSIONS is keyed by plugin name (not entrypoint name), so that all
        # versions accumulate; a re-registered entrypoint is listed only once.
        refs = self._VERSIONS.setdefault(name, [])
        if p_ref not in refs:
            refs.append(p_ref)
            refs.sort()  # should be cheap

    def __init__(self, entrypoints: Dict[str, EntryPoint]):
        """Register the given entrypoints (name -> entrypoint) of this group."""
        self._ENTRY_POINTS = {}
        self._VERSIONS = {}
        # must exist before registering entrypoints, `_add_ep` may access it
        self._LOADED_PLUGINS = {}

        for ep_name, ep_obj in entrypoints.items():
            self._add_ep(ep_name, ep_obj)

        self.__post_init__()

    def __post_init__(self):
        """Finish initialization (the plugin group group registers itself here)."""
        if type(self) is PluginGroup:
            # make the magic plugingroup plugin add itself for consistency
            ep_name = util.to_ep_name(self.Plugin.name, self.Plugin.version)
            ep_path = f"{type(self).__module__}:{type(self).__name__}"
            ep = EntryPoint(ep_name, ep_path, self.name)
            self._add_ep(ep_name, ep)

            self_ref = AnyPluginRef(
                group=self.name, name=self.name, version=self.Plugin.version
            )
            self._LOADED_PLUGINS[self_ref] = self
            self.provider(self_ref).plugins[self.name].append(self_ref)

    @property
    def name(self) -> str:
        """Return name of the plugin group."""
        return self.Plugin.name

    @property
    def packages(self) -> Dict[str, PluginPkgMeta]:
        """Return metadata of all packages providing metador plugins."""
        return dict(self._PKG_META)

    def versions(
        self, p_name: str, version: Optional[SemVerTuple] = None
    ) -> List[AnyPluginRef]:
        """Return installed versions of a plugin (compatible with given version)."""
        refs = list(self._VERSIONS.get(p_name) or [])
        if version is None:
            return refs
        requested = self.PluginRef(name=p_name, version=version)
        return [ref for ref in refs if ref.supports(requested)]

    def resolve(
        self, p_name: str, version: Optional[SemVerTuple] = None
    ) -> Optional[AnyPluginRef]:
        """Return most recent compatible version of a plugin, or None."""
        if refs := self.versions(p_name, version):
            return refs[-1]  # latest (compatible) version, lists are kept sorted
        return None

    def provider(self, ref: AnyPluginRef) -> PluginPkgMeta:
        """Return package metadata of Python package providing this plugin."""
        if type(self) is PluginGroup and ref.name == PG_GROUP_NAME:
            # special case - the mother plugingroup plugin is not an EP,
            # so we cheat a bit (schema is in same package, but is an EP)
            return self.provider(self.resolve("schema"))

        ep_name = util.to_ep_name(ref.name, ref.version)
        ep = self._ENTRY_POINTS[ep_name]
        return self._PKG_META[cast(Any, ep).dist.name]

    def is_plugin(self, p_cls) -> bool:
        """Return whether this class is a (possibly marked) installed plugin.

        Never raises for classes unknown to this group, it returns False instead.
        """
        if not isinstance(p_cls, type):
            return False
        base = self.Plugin.plugin_class
        # same guard as in `_check_common`: a falsy plugin_class means "no check"
        if base and not issubclass(p_cls, base):
            return False

        cls = UndefVersion._unwrap(p_cls) or p_cls  # get real underlying class
        # check it is exactly a registered plugin, if it has a (parsed) Plugin section
        info = cls.__dict__.get("Plugin")
        if not info or not isinstance(info, PluginBase):
            return False
        try:
            loaded_p = self._get_unsafe(info.name, info.version)
        except KeyError:
            return False  # no (compatible) version installed in this group
        return loaded_p is cls

    # ----

    def __repr__(self):
        return f"<PluginGroup '{self.name}' {list(self.keys())}>"

    def __str__(self):
        lines = []
        for name, refs in self._VERSIONS.items():
            versions = ", ".join(to_semver_str(ref.version) for ref in refs)
            lines.append(f"\t'{name}' ({versions})")
        pgs = "\n".join(lines)
        return f"Available '{self.name}' plugins:\n{pgs}"

    # ----
    # dict-like interface will provide latest versions of plugins by default

    def __contains__(self, key) -> bool:
        """Return whether a plugin (in the exact version, if given) is installed."""
        name, version = plugin_args(key)
        if pg_versions := self._VERSIONS.get(name):
            if not version:
                return True
            pg = self.PluginRef(name=name, version=version)
            return pg in pg_versions
        return False

    def __getitem__(self, key) -> Type[T]:
        """Like ``get``, but raise KeyError if the plugin is not installed."""
        if key not in self:
            raise KeyError(f"{self.name} not found: {key}")
        return self.get(key)

    def keys(self) -> Iterator[AnyPluginRef]:
        """Return refs of all installed versions of all plugins."""
        for pgs in self._VERSIONS.values():
            yield from pgs

    def values(self) -> Iterator[Type[T]]:
        """Return ``self[key]`` for every key (THIS LOADS ALL PLUGINS!)."""
        return map(self.__getitem__, self.keys())

    def items(self) -> Iterator[Tuple[AnyPluginRef, Type[T]]]:
        """Return pairs ``(key, self[key])`` for every key (THIS LOADS ALL PLUGINS!)."""
        return ((ref, self[ref]) for ref in self.keys())

    # ----

    def _get_unsafe(self, p_name: str, version: Optional[SemVerTuple] = None):
        """Return most recent compatible version of given plugin name, without safety rails.

        Raises KeyError if no (compatible) schema found.

        For internal use only!
        """
        if ref := self.resolve(p_name, version):
            self._ensure_is_loaded(ref)
            return self._LOADED_PLUGINS[ref]

        msg = f"{p_name}"
        if version:
            msg += f": no installed version is compatible with {version}"
        raise KeyError(msg)

    # inspired by this nice trick: https://stackoverflow.com/a/60362860
    PRX = TypeVar("PRX", bound="Type[T]")  # type: ignore

    @overload
    def get(self, key: str, version: Optional[SemVerTuple] = None) -> Optional[Type[T]]:
        ...  # pragma: no cover

    @overload
    def get(self, key: PRX, version: Optional[SemVerTuple] = None) -> Optional[PRX]:
        ...  # pragma: no cover

    def get(
        self, key: Union[str, PRX], version: Optional[SemVerTuple] = None
    ) -> Union[Type[T], PRX, None]:
        """Return the most recent compatible plugin (loading it if needed), or None.

        If no version was passed and none could be inferred from ``key``,
        the returned class is marked via ``UndefVersion._mark_class``.
        """
        key_, version = plugin_args(key, version)

        # retrieve compatible plugin
        try:
            ret = self._get_unsafe(key_, version)
        except KeyError:
            return None

        if version is None:
            # no version constraint was passed or inferred -> mark it
            ret = UndefVersion._mark_class(ret)

        if isinstance(key, str):
            return cast(Type[T], ret)
        return ret

    # ----

    def _ensure_is_loaded(self, ref: AnyPluginRef):
        """Load plugin from entrypoint, if it is not loaded yet.

        If the plugin fails its checks or initialization, it is not kept as loaded.
        """
        assert ref.group == self.name
        if ref in self._LOADED_PLUGINS:
            return  # already loaded, all good

        ep_name = util.to_ep_name(ref.name, ref.version)
        plugin = self._ENTRY_POINTS[ep_name].load()
        # register before finalizing: a dependency cycle leading back to this
        # plugin (via `_load_plugin`) then stops at the early return above
        self._LOADED_PLUGINS[ref] = plugin
        try:
            self._load_plugin(ep_name, plugin)
        except BaseException:
            # do not leave an unchecked / half-initialized plugin registered
            self._LOADED_PLUGINS.pop(ref, None)
            raise

    def _explicit_plugin_deps(self, plugin) -> Set[AnyPluginRef]:
        """Return all plugin dependencies that must be taken into account."""
        def_deps = set(plugin.Plugin.requires)
        extra_deps = set(self.plugin_deps(plugin) or set())
        return def_deps.union(extra_deps)

    def plugin_deps(self, plugin) -> Set[AnyPluginRef]:
        """Return additional automatically inferred dependencies for a plugin.

        To be overridden in plugin groups that can infer dependencies;
        by default, no additional dependencies are inferred.
        """
        return set()

    def _load_plugin(self, ep_name: EPName, plugin):
        """Run checks and finalize loaded plugin.

        Raises:
            TypeError: If the plugin has no inner ``Plugin`` class or fails a check.
        """
        from ..plugins import plugingroups

        # run inner Plugin class checks (with possibly new Fields cls)
        if not plugin.__dict__.get("Plugin"):
            raise TypeError(f"{ep_name}: {plugin} is missing Plugin inner class!")
        # pass ep_name to check that it agrees with the plugin info
        plugin.Plugin = self.Plugin.plugin_info_class.parse_info(
            plugin.Plugin, ep_name=ep_name
        )

        # do general checks first, if they fail no need to continue
        self._check_common(ep_name, plugin)
        self.check_plugin(ep_name, plugin)

        for dep_ref in self._explicit_plugin_deps(plugin):
            dep_grp = plugingroups[dep_ref.group]
            dep_grp._ensure_is_loaded(dep_ref)

        self.init_plugin(plugin)

    def _check_common(self, ep_name: EPName, plugin):
        """Perform the checks common to all plugin groups on a registered plugin.

        Raises a TypeError with message in case of failure.
        """
        # check correct base class of plugin, if stated
        if self.Plugin.plugin_class:
            util.check_is_subclass(ep_name, plugin, self.Plugin.plugin_class)

    def check_plugin(self, ep_name: EPName, plugin: Type[T]):
        """Perform plugin group specific checks on a registered plugin.

        Raises a TypeError with message in case of failure.

        To be overridden in subclasses for plugin group specific checks.
        The implementation here checks plugins of the plugin group group,
        i.e. other plugin groups.

        Args:
            ep_name: Declared entrypoint name.
            plugin: Object the entrypoint is pointing to.
        """
        # NOTE: this is only reached for the "plugingroup" group itself,
        # as all other plugin groups must override check_plugin.
        util.check_is_subclass(ep_name, plugin, PluginGroup)
        util.check_is_subclass(ep_name, self.Plugin.plugin_info_class, PluginBase)
        if plugin is not PluginGroup:  # exclude itself. this IS its check_plugin
            util.check_implements_method(ep_name, plugin, PluginGroup.check_plugin)

        # NOTE: checking that the group's plugin_info_class sets `group` equal to
        # the plugin group name is not needed, as long as PluginGroupMeta sets it.

    def init_plugin(self, plugin: Type[T]):
        """Override this to do something after the plugin has been checked."""
        if type(self) is not PluginGroup:
            return  # is not the "plugingroup" group itself
        create_pg(plugin)  # create plugin group if it does not exist

412 

413 

414# ---- 

415 

# Cache of instantiated plugin groups, filled by `create_pg`, which keys it by
# the group's ref in the "plugingroup" group (an `AnyPluginRef`, not a str).
_plugin_groups: Dict[AnyPluginRef, PluginGroup] = {}
"""Instances of initialized plugin groups."""

418 

419 

def create_pg(pg_cls):
    """Create plugin group instance if it does not exist.

    Instances are cached in ``_plugin_groups``, keyed by the ref of the group
    within the plugin group group.

    Args:
        pg_cls: A plugin group class (subclass of ``PluginGroup``).

    Returns:
        The cached or newly created plugin group instance.
    """
    pg_ref = AnyPluginRef(
        group=PG_GROUP_NAME, name=pg_cls.Plugin.name, version=pg_cls.Plugin.version
    )
    if pg_ref in _plugin_groups:
        return _plugin_groups[pg_ref]

    if not isinstance(pg_cls.Plugin, PluginBase):
        # magic - substitute Plugin class with parsed plugin object
        pg_cls.Plugin = PGPlugin.parse_info(pg_cls.Plugin)

    # TODO: currently we cannot distinguish entrypoints
    # for different versions of the plugin group.
    # should not be problematic for now,
    # as the groups should not change much
    pg = pg_cls(get_group(pg_ref.name))
    _plugin_groups[pg_ref] = pg
    # return the new instance, consistent with the cached path above
    return pg