xref: /aosp_15_r20/external/pigweed/pw_cli/py/pw_cli/plugins.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Provides general purpose plugin functionality.
15
16As used in this module, a plugin is a Python object associated with a name.
17Plugins are registered in a Registry. The plugin object is typically a function,
18but can be anything.
19
20Plugins may be loaded in a variety of ways:
21
22- Listed in a plugins file in the file system (e.g. as "name module target").
23- Registered in a Python file using a decorator (@my_registry.plugin).
24- Registered directly or by name with function calls on a registry object.
25
26This functionality can be used to create plugins for command line tools,
27interactive consoles, or anything else. Pigweed's pw command uses this module
28for its plugins.
29"""
30
31from __future__ import annotations
32
33import collections
34import collections.abc
35import importlib
36import inspect
37import logging
38from pathlib import Path
39import pkgutil
40import sys
41from textwrap import TextWrapper
42import types
43from typing import Any, Callable, Iterable, Iterator, Set
44
# Label reported as the source of plugins that were not declared in a
# plugins file (i.e. registered directly from Python).
_BUILT_IN: str = '<built-in>'
_LOG: logging.Logger = logging.getLogger(__name__)
47
48
class Error(Exception):
    """Raised when a plugin is invalid or cannot be registered."""

    def __str__(self):
        """Formats the error, appending the chained __cause__ if there is one.

        Showing the cause inline gives useful context without printing a full
        backtrace.
        """
        text = super().__str__()
        cause = self.__cause__
        if cause is not None:
            text = f'{text} ({type(cause).__name__}: {cause})'
        return text
64
65
def _get_module(member: object) -> types.ModuleType:
    """Returns the module that defines member.

    When inspect cannot determine a module, an empty placeholder module named
    '<unknown>' is returned instead, so callers can always read __name__.
    """
    owner = inspect.getmodule(member)
    if owner is None:
        return types.ModuleType('<unknown>')
    return owner
70
71
class Plugin:
    """Represents a Python entity registered as a plugin.

    Each plugin resolves to a Python object, typically a function. A plugin
    records its name, the target object, the module the target comes from,
    and optionally the plugins file that declared it.
    """

    @classmethod
    def from_name(
        cls,
        name: str,
        module_name: str,
        member_name: str,
        source: Path | None,
    ) -> Plugin:
        """Creates a plugin by module and attribute name.

        Args:
          name: the name of the plugin
          module_name: Python module name (e.g. 'foo_pkg.bar')
          member_name: the name of the member in the module
          source: path to the plugins file that declared this plugin, if any

        Raises:
          Error: if the module cannot be imported or has no such member.
        """

        # Attempt to access the module and member. Catch any errors that might
        # occur, since a bad plugin shouldn't be a fatal error.
        try:
            module = importlib.import_module(module_name)
        except Exception as err:
            _LOG.debug(
                'Failed to import module "%s" for "%s" plugin',
                module_name,
                name,
                exc_info=True,
            )
            raise Error(f'Failed to import module "{module_name}"') from err

        try:
            member = getattr(module, member_name)
        except AttributeError as err:
            raise Error(
                f'"{module_name}.{member_name}" does not exist'
            ) from err

        return cls(name, member, source)

    def __init__(
        self, name: str, target: Any, source: Path | None = None
    ) -> None:
        """Creates a plugin for the provided target.

        Args:
          name: the name of the plugin
          target: the object the plugin resolves to
          source: path to the plugins file that declared this plugin; None
              for plugins registered directly from Python
        """
        self.name = name
        self._module = _get_module(target)
        self.target = target
        self.source = source

    @property
    def target_name(self) -> str:
        """The target's module-qualified name, e.g. 'pkg.mod.main'."""
        return (
            f'{self._module.__name__}.'
            f'{getattr(self.target, "__name__", self.target)}'
        )

    @property
    def source_name(self) -> str:
        """The source path as a string, or the built-in label if unset."""
        return _BUILT_IN if self.source is None else str(self.source)

    def run_with_argv(self, argv: Iterable[str]) -> int:
        """Sets sys.argv and calls the plugin function.

        This is used to call a plugin as if from the command line. The
        original sys.argv is restored afterwards, even if the target raises.

        Returns:
          The target's return value, or 0 if it returned a falsy value.
        """
        original_sys_argv = sys.argv
        sys.argv = [f'pw {self.name}', *argv]

        try:
            # If the plugin doesn't return an exit code assume it succeeded.
            return self.target() or 0
        finally:
            sys.argv = original_sys_argv

    def help(self, full: bool = False) -> str:
        """Returns a description of this plugin from its docstring.

        The target's docstring is used if present, otherwise the docstring of
        the module the target comes from.

        Args:
          full: if True, return the whole docstring; otherwise return only the
              summary, i.e. its first non-blank line with surrounding
              whitespace removed.
        """
        docstring = self.target.__doc__ or self._module.__doc__ or ''
        if full:
            return docstring

        # Docstrings written as '"""\n    Summary ...' begin with an empty
        # line; skip blank lines so the summary is still found.
        return next(
            (line.strip() for line in docstring.splitlines() if line.strip()),
            '',
        )

    def details(self, full: bool = False) -> Iterator[str]:
        """Yields labeled lines (help, module, target, source) for display."""
        yield f'help    {self.help(full=full)}'
        yield f'module  {self._module.__name__}'
        yield f'target  {getattr(self.target, "__name__", self.target)}'
        yield f'source  {self.source_name}'

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(name={self.name!r}, '
            f'target={self.target_name}'
            f'{f", source={self.source_name!r}" if self.source is not None else ""})'
        )
168
169
def callable_with_no_args(plugin: Plugin) -> None:
    """Checks that a plugin is callable without arguments.

    May be used for the validator argument to Registry.

    Variadic parameters (*args / **kwargs) never need a value, so they do not
    count as required. Any other parameter without a default does, because
    calling the target with no arguments would fail.

    Args:
      plugin: the plugin whose target is checked.

    Raises:
      Error: if the target is not callable or has required parameters.
    """
    try:
        params = inspect.signature(plugin.target).parameters
    except TypeError:
        raise Error(
            'Plugin functions must be callable, but '
            f'{plugin.target_name} is a '
            f'{type(plugin.target).__name__}'
        )
    except ValueError:
        # The target is callable, but its signature cannot be introspected
        # (e.g. some builtins). There is nothing to check, so accept it.
        return

    variadic = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    required = sum(
        1
        for param in params.values()
        if param.default is param.empty and param.kind not in variadic
    )
    if required:
        raise Error(
            f'Plugin functions cannot have any required positional '
            f'arguments, but {plugin.target_name} has {required}'
        )
190
191
class Registry(collections.abc.Mapping):
    """Manages a set of plugins from Python modules or plugins files.

    The registry behaves as a read-only mapping from plugin name to Plugin.
    Failed registrations are recorded per name (see errors()) and reported in
    the KeyError raised when a failed plugin is looked up.
    """

    def __init__(
        self, validator: Callable[[Plugin], Any] = lambda _: None
    ) -> None:
        """Creates a new, empty plugins registry.

        Args:
          validator: Function that checks whether a plugin is valid and should
              be registered. Must raise plugins.Error if the plugin is invalid.
        """

        self._registry: dict[str, Plugin] = {}
        self._sources: Set[Path] = set()  # Paths to plugins files
        # Registration failures keyed by plugin name; unparsable plugins-file
        # lines are keyed by the line text instead.
        self._errors: dict[str, list[Exception]] = collections.defaultdict(list)
        self._validate_plugin = validator

    def __getitem__(self, name: str) -> Plugin:
        """Accesses a plugin by name; raises KeyError if it does not exist.

        If registering the name failed earlier, the KeyError message lists the
        recorded errors.
        """
        if name in self._registry:
            return self._registry[name]

        if name in self._errors:
            raise KeyError(
                f'Registration for "{name}" failed: '
                + ', '.join(str(e) for e in self._errors[name])
            )

        raise KeyError(f'The plugin "{name}" has not been registered')

    def __iter__(self) -> Iterator[str]:
        return iter(self._registry)

    def __len__(self) -> int:
        return len(self._registry)

    def errors(self) -> dict[str, list[Exception]]:
        """Returns the recorded registration errors.

        This is the registry's internal mapping, not a copy.
        """
        return self._errors

    def run_with_argv(self, name: str, argv: Iterable[str]) -> int:
        """Runs a plugin by name, setting sys.argv to the provided args.

        This is used to run a command as if it were executed directly from the
        command line. The plugin is expected to return an int.

        Raises:
          KeyError if plugin is not registered.
        """
        return self[name].run_with_argv(argv)

    def _should_register(self, plugin: Plugin) -> bool:
        """Determines and logs if a plugin should be registered or not.

        Returns:
          True if no plugin with this name exists yet, or if the existing one
          is built-in (has no source) and is being overridden. False if an
          earlier registration from a plugins file takes precedence.

        Raises:
          Error: if a built-in plugin (no source) reuses a registered name.
          Whatever the validator raises (plugins.Error by convention).
        """

        if plugin.name in self._registry and plugin.source is None:
            raise Error(
                f'Attempted to register built-in plugin "{plugin.name}", but '
                'a plugin with that name was previously registered '
                f'({self[plugin.name]})!'
            )

        # Run the user-provided validation function, which raises exceptions
        # if there are errors.
        self._validate_plugin(plugin)

        existing = self._registry.get(plugin.name)

        if existing is None:
            return True

        if existing.source is None:
            _LOG.debug(
                '%s: Overriding built-in plugin "%s" with %s',
                plugin.source_name,
                plugin.name,
                plugin.target_name,
            )
            return True

        if plugin.source != existing.source:
            _LOG.debug(
                '%s: The plugin "%s" was previously registered in %s; '
                'ignoring registration as %s',
                plugin.source_name,
                plugin.name,
                existing.source,
                plugin.target_name,
            )
        elif plugin.source not in self._sources:
            # The source is only added to _sources after the whole file has
            # been processed, so this branch means a duplicate within one file.
            _LOG.warning(
                '%s: "%s" is registered multiple times in this file! '
                'Only the first registration takes effect',
                plugin.source_name,
                plugin.name,
            )

        return False

    def register(self, name: str, target: Any) -> Plugin | None:
        """Registers an object as a built-in plugin (one with no source).

        Returns:
          The registered Plugin, or None if it was not registered.
        """
        return self._register(Plugin(name, target, None))

    def register_by_name(
        self,
        name: str,
        module_name: str,
        member_name: str,
        source: Path | None = None,
    ) -> Plugin | None:
        """Registers an object from its module and name as a plugin.

        Raises:
          Error: if the module or member cannot be loaded, or registration
              is rejected.
        """
        return self._register(
            Plugin.from_name(name, module_name, member_name, source)
        )

    def _register(self, plugin: Plugin) -> Plugin | None:
        # Prohibit functions not from a plugins file from overriding others.
        if not self._should_register(plugin):
            return None

        self._registry[plugin.name] = plugin
        _LOG.debug(
            '%s: Registered plugin "%s" for %s',
            plugin.source_name,
            plugin.name,
            plugin.target_name,
        )

        return plugin

    def register_config(
        self,
        config: dict,
        path: Path | None = None,
    ) -> None:
        """Registers plugins from a Pigweed config.

        Plugins are read from config['pw']['pw_cli']['plugins'], which maps
        each plugin name to a dict with 'module' and 'function' keys. The
        config itself is not modified.

        Errors while registering an individual plugin (plugins.Error) are
        caught, recorded, and logged.

        Args:
          config: the parsed Pigweed config.
          path: where the config came from, if known; recorded as the source.

        Raises:
          KeyError: if an entry lacks 'module' or 'function'.
          ValueError: if an entry has any other keys.
        """
        plugins = config.get('pw', {}).get('pw_cli', {}).get('plugins', {})
        for name, location in plugins.items():
            # Pop from a copy: popping from the caller's dict would corrupt
            # the config and make a later call with it fail.
            options = dict(location)
            module = options.pop('module')
            function = options.pop('function')
            if options:
                raise ValueError(f'unrecognized plugin options: {options}')

            try:
                self.register_by_name(name, module, function, path)
            except Error as err:
                self._errors[name].append(err)
                _LOG.error(
                    '%s: Failed to register plugin "%s": %s',
                    path,
                    name,
                    err,
                )

    def register_file(self, path: Path) -> None:
        """Registers plugins from a plugins file.

        Each non-blank, non-comment line must be "name module function".
        Malformed lines and registration errors are caught, recorded, and
        logged; errors opening the file are not caught.
        """
        with path.open() as contents:
            for lineno, line in enumerate(contents, 1):
                line = line.strip()
                if not line or line.startswith('#'):
                    continue

                try:
                    name, module, function = line.split()
                except ValueError as err:
                    self._errors[line].append(Error(err))
                    _LOG.error(
                        '%s:%d: Failed to parse plugin entry "%s": '
                        'Expected 3 items (name, module, function), '
                        'got %d',
                        path,
                        lineno,
                        line,
                        len(line.split()),
                    )
                    continue

                try:
                    self.register_by_name(name, module, function, path)
                except Error as err:
                    self._errors[name].append(err)
                    _LOG.error(
                        '%s: Failed to register plugin "%s": %s',
                        path,
                        name,
                        err,
                    )

        self._sources.add(path)

    def register_directory(
        self,
        directory: Path,
        file_name: str,
        restrict_to: Path | None = None,
    ) -> None:
        """Finds and registers plugins from plugins files in a directory.

        Args:
          directory: The directory from which to start searching up.
          file_name: The name of plugins files to look for.
          restrict_to: If provided, do not search higher than this directory.
        """
        if restrict_to is not None:
            # find_all_in_parents yields resolved paths; resolve restrict_to
            # too so relative or symlinked directories compare correctly.
            restrict_to = restrict_to.resolve()

        for path in find_all_in_parents(file_name, directory):
            if not path.is_file():
                continue

            if restrict_to is not None and restrict_to not in path.parents:
                _LOG.debug(
                    "Skipping plugins file %s because it's outside of %s",
                    path,
                    restrict_to,
                )
                continue

            _LOG.debug('Found plugins file %s', path)
            self.register_file(path)

    def short_help(self) -> str:
        """Returns a help string for the registered plugins."""
        width = max(map(len, self._registry), default=0) + 1
        help_items = '\n'.join(
            f'  {name:{width}} {plugin.help()}'
            for name, plugin in sorted(self._registry.items())
        )
        return f'supported plugins:\n{help_items}'

    def detailed_help(self, plugins: Iterable[str] = ()) -> Iterator[str]:
        """Yields lines of detailed information about commands.

        Args:
          plugins: names to describe; all registered plugins if empty.
        """
        # Materialize first: a generator is truthy even when empty, so testing
        # the raw iterable would skip the "all plugins" default.
        selected = sorted(plugins) or sorted(self._registry)

        yield '\ndetailed plugin information:'

        wrapper = TextWrapper(
            width=80, initial_indent='   ', subsequent_indent=' ' * 11
        )

        for name in selected:
            yield f'  [{name}]'

            try:
                for line in self[name].details(full=len(selected) == 1):
                    yield wrapper.fill(line)
            except KeyError as err:
                # str(KeyError) wraps the message in quotes; use the raw one.
                message = err.args[0] if err.args else str(err)
                yield wrapper.fill(f'error   {message}')

            yield ''

        yield 'Plugins files:'

        if self._sources:
            # Sort for deterministic output; _sources is an unordered set.
            yield from (
                f'  [{i}] {file}'
                for i, file in enumerate(sorted(self._sources), 1)
            )
        else:
            yield '  (none found)'

    def plugin(
        self, function: Callable | None = None, *, name: str | None = None
    ) -> Callable[[Callable], Callable]:
        """Decorator that registers a function with this plugin registry.

        Usable bare (@registry.plugin), with arguments
        (@registry.plugin(name='foo')), or called directly
        (registry.plugin(func, name='foo')). The plugin name defaults to the
        function's __name__.
        """

        def decorator(function: Callable) -> Callable:
            self.register(function.__name__ if name is None else name, function)
            return function

        if function is None:
            return decorator

        # Route through decorator so an explicit name is honored here too.
        return decorator(function)
477
478
def find_in_parents(name: str, path: Path) -> Path | None:
    """Searches a path and its parent directories for a file or directory.

    The search starts at the resolved path itself and continues upward, up to
    and including the filesystem root.

    Args:
      name: the file or directory name to look for.
      path: the directory from which to start; it does not need to exist.

    Returns:
      The nearest matching path, or None if there is no match.
    """
    path = path.resolve()

    # path.parents ends with the root, so the root is checked as well.
    # Walking the pure path avoids stat-ing directories that may not exist.
    for directory in (path, *path.parents):
        candidate = directory.joinpath(name)
        if candidate.exists():
            return candidate

    return None


def find_all_in_parents(name: str, path: Path) -> Iterator[Path]:
    """Searches all parent directories of the path for files or directories.

    Yields every match from path upward, nearest first.
    """

    while True:
        result = find_in_parents(name, path)
        if result is None:
            return

        yield result

        directory = result.parent
        if directory == directory.parent:
            # The match is in the filesystem root. There is nowhere higher to
            # look, and searching the root again would yield it forever.
            return
        path = directory.parent
502
503
def import_submodules(
    module: types.ModuleType, recursive: bool = False
) -> None:
    """Imports every submodule of a package.

    Useful for collecting plugins that register themselves with a decorator
    when their modules are imported.

    Args:
      module: the package whose submodules are imported.
      recursive: if True, descend into subpackages as well.
    """
    prefix = f'{module.__name__}.'
    search_path = module.__path__  # type: ignore[attr-defined]
    finder = pkgutil.walk_packages if recursive else pkgutil.iter_modules

    for found in finder(search_path, prefix):
        importlib.import_module(found.name)
520