xref: /aosp_15_r20/external/pigweed/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Tools for compiling and importing Python protos on the fly."""
15
16from __future__ import annotations
17
18from collections.abc import Mapping
19import importlib.util
20import logging
21import os
22from pathlib import Path
23import subprocess
24import shlex
25import tempfile
26from types import ModuleType
27from typing import (
28    Generic,
29    Iterable,
30    Iterator,
31    NamedTuple,
32    Set,
33    TypeVar,
34)
35
# black is an optional dependency: this module only uses it in proto_repr() to
# wrap output, so a missing installation must not prevent importing the module.
try:
    # pylint: disable=wrong-import-position
    import black

    # Formatting mode handed to black.format_str() by proto_repr(); string
    # normalization is turned off.
    black_mode: black.Mode | None = black.Mode(string_normalization=False)

    # pylint: enable=wrong-import-position
except ImportError:
    # Both names are set to None so proto_repr() can detect that black is
    # unavailable and return unwrapped output instead.
    black = None  # type: ignore
    black_mode = None

# Module-level logger; compile_protos() reports protoc invocations through it.
_LOG = logging.getLogger(__name__)
48
49
def _find_protoc() -> str:
    """Returns the protoc executable to invoke when compiling protos.

    The value of the PROTOC environment variable is used when it is set.
    Otherwise the bare name `protoc` is returned, which leaves locating the
    binary to the system PATH.
    """
    return os.environ.get('PROTOC', 'protoc')
57
58
def compile_protos(
    output_dir: Path | str,
    proto_files: Iterable[Path | str],
    includes: Iterable[Path | str] = (),
) -> None:
    """Compiles proto files for Python by invoking the protobuf compiler.

    Proto files not covered by one of the provided include paths will have their
    directory added as an include path.

    Include directories are passed to protoc in a stable order: first the
    caller's `includes` in the order given, then any directories added for
    uncovered proto files in the order those files were listed.

    Args:
      output_dir: directory passed to protoc's --python_out
      proto_files: paths to the .proto files to compile
      includes: include directories to pass to protoc with -I

    Raises:
      subprocess.CalledProcessError: protoc exited with a non-zero status; the
          command line and protoc's stderr are logged before raising.
    """
    proto_paths: list[Path] = [Path(f).resolve() for f in proto_files]

    # A dict is used as an insertion-ordered set. protoc searches -I
    # directories in order, so the order must not depend on hash iteration.
    include_paths: dict[Path, None] = dict.fromkeys(
        Path(d).resolve() for d in includes
    )

    for path in proto_paths:
        # Add the file's own directory only if no include is an ancestor of it.
        if include_paths.keys().isdisjoint(path.parents):
            include_paths[path.parent] = None

    cmd: tuple[Path | str, ...] = (
        _find_protoc(),
        '--experimental_allow_proto3_optional',
        '--python_out',
        os.path.abspath(output_dir),
        *(f'-I{d}' for d in include_paths),
        *proto_paths,
    )
    cmd_str = ' '.join(shlex.quote(str(c)) for c in cmd)

    _LOG.debug('%s', cmd_str)
    # The exit status is checked below so the failure can be logged first.
    process = subprocess.run(cmd, capture_output=True, check=False)

    if process.returncode:
        _LOG.error(
            'protoc invocation failed!\n%s\n%s',
            cmd_str,
            # Never let undecodable compiler output mask the real failure.
            process.stderr.decode(errors='replace'),
        )
        process.check_returncode()
95
96
def _import_module(name: str, path: str) -> ModuleType:
    """Loads the Python source file at `path` as a module called `name`.

    The module is executed and returned; this function does not add it to
    sys.modules.

    Args:
      name: the module name to assign to the loaded module
      path: location of the file to load

    Raises:
      ImportError: no import spec or loader could be created for `path`
    """
    spec = importlib.util.spec_from_file_location(name, path)

    # Raise explicitly rather than assert: asserts are stripped under -O, which
    # would turn this into an opaque AttributeError on None further down.
    if spec is None or spec.loader is None:
        raise ImportError(
            f'Unable to load module {name!r} from {str(path)!r}',
            name=name,
            path=str(path),
        )

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
103
104
def import_modules(directory: Path | str) -> Iterator:
    """Imports modules in a directory and yields them.

    The directory tree is walked recursively and every file ending in `.py` is
    imported. Module names are dotted paths relative to the parent of
    `directory`, so a file `sub/foo.py` inside directory `out` is imported as
    `out.sub.foo`.

    Args:
      directory: root of the tree to search for Python files

    Yields:
      each imported module
    """
    # Normalize before taking the parent: dirname() of 'out/' is 'out' itself
    # and dirname() of '.' is '', both of which would otherwise produce module
    # names that start with '.' instead of the directory name.
    parent = os.path.dirname(os.path.abspath(directory))

    for dirpath, _, files in os.walk(directory):
        package = '.'.join(os.path.relpath(dirpath, parent).split(os.sep))

        for file in files:
            name, ext = os.path.splitext(file)

            if ext == '.py':
                yield _import_module(
                    f'{package}.{name}',
                    os.path.join(dirpath, file),
                )
119                )
120
121
def compile_and_import(
    proto_files: Iterable[Path | str],
    includes: Iterable[Path | str] = (),
    output_dir: Path | str | None = None,
) -> Iterator:
    """Runs protoc on the given files, then imports what it generated.

    Args:
      proto_files: paths to .proto files to compile
      includes: include paths to use for .proto compilation
      output_dir: directory to receive the generated modules; when this is
          omitted (or otherwise falsy) a temporary directory is used, and it
          only exists while this generator is being consumed

    Yields:
      the generated protobuf Python modules
    """
    if not output_dir:
        with tempfile.TemporaryDirectory(prefix='compiled_protos_') as tempdir:
            compile_protos(tempdir, proto_files, includes)
            yield from import_modules(tempdir)
        return

    compile_protos(output_dir, proto_files, includes)
    yield from import_modules(output_dir)
145            yield from import_modules(tempdir)
146
147
def compile_and_import_file(
    proto_file: Path | str,
    includes: Iterable[Path | str] = (),
    output_dir: Path | str | None = None,
):
    """Compiles one .proto file and returns the first module imported for it.

    Only the first module produced by compile_and_import() is returned.
    StopIteration propagates if no module is produced.
    """
    modules = compile_and_import([proto_file], includes, output_dir)
    return next(modules)
154    return next(iter(compile_and_import([proto_file], includes, output_dir)))
155
156
def compile_and_import_strings(
    contents: Iterable[str],
    includes: Iterable[Path | str] = (),
    output_dir: Path | str | None = None,
) -> Iterator:
    """Compiles protos in one or more strings.

    Each string is written to its own .proto file in a temporary source
    directory, and the files are then compiled and imported together.

    Args:
      contents: .proto source text; a single string is accepted and treated as
          one proto file
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated modules; forwarded to
          compile_and_import

    Yields:
      the generated protobuf Python modules
    """
    sources = [contents] if isinstance(contents, str) else contents

    with tempfile.TemporaryDirectory(prefix='proto_sources_') as path:
        protos = []

        for proto in sources:
            # Use a hash of the proto so the same contents map to the same file
            # name. The protobuf package complains if it sees the same contents
            # in files with different names.
            proto_path = Path(path, f'protobuf_{hash(proto):x}.proto')

            # Write UTF-8 explicitly instead of relying on the locale's default
            # encoding, so non-ASCII proto text is written the same everywhere.
            proto_path.write_text(proto, encoding='utf-8')
            protos.append(proto_path)

        yield from compile_and_import(protos, includes, output_dir)
177        yield from compile_and_import(protos, includes, output_dir)
178
179
T = TypeVar('T')


class _NestedPackage(Generic[T]):
    """Facilitates navigating protobuf packages as attributes.

    Each instance represents one package. It holds its direct subpackages and a
    list of items (for example, compiled proto modules) that belong to it.
    Attribute access first looks for a subpackage with that name, then for an
    attribute of that name on any of the items.
    """

    def __init__(self, package: str):
        # Direct children, keyed by the last component of their package name.
        self._packages: dict[str, _NestedPackage[T]] = {}
        # Objects whose attributes are exposed through this package.
        self._items: list[T] = []
        # Fully qualified name of this package.
        self._package = package

    def _add_package(self, subpackage: str, package: _NestedPackage) -> None:
        """Registers `package` as the child named `subpackage`."""
        self._packages[subpackage] = package

    def _add_item(self, item) -> None:
        """Adds an item whose attributes are exposed through this package."""
        if item not in self._items:  # Don't store the same item multiple times.
            self._items.append(item)

    def _member_names(self) -> list[str]:
        """Lists the public, non-module attribute names of all items."""
        names = []

        for item in self._items:
            for attr, value in vars(item).items():
                # Exclude private variables and modules.
                if not attr.startswith('_') and not isinstance(
                    value, ModuleType
                ):
                    names.append(attr)

        return names

    def __getattr__(self, attr: str):
        """Look up subpackages or package members."""
        # __getattr__ only runs when normal lookup fails. If one of the
        # instance's own fields is missing, the object was never initialized
        # (copy and pickle create such objects); reading self._packages below
        # would then re-enter this method without bound.
        if attr in ('_packages', '_items', '_package'):
            raise AttributeError(attr)

        if attr in self._packages:
            return self._packages[attr]

        for item in self._items:
            if hasattr(item, attr):
                return getattr(item, attr)

        raise AttributeError(
            f'Proto package "{self._package}" does not contain "{attr}"'
        )

    def __getitem__(self, subpackage: str) -> _NestedPackage[T]:
        """Support accessing nested packages by name.

        `subpackage` may be a dotted path, which is resolved one component at a
        time. KeyError is raised if a component is not a known subpackage.
        """
        result = self

        for package in subpackage.split('.'):
            result = result._packages[package]

        return result

    def __dir__(self) -> list[str]:
        """List subpackages and members of modules as attributes."""
        return list(self._packages) + self._member_names()

    def __iter__(self) -> Iterator['_NestedPackage[T]']:
        """Iterate over nested packages."""
        return iter(self._packages.values())

    def __repr__(self) -> str:
        msg = [f'ProtoPackage({self._package!r}']

        # Members come from the items. The instance's own attributes are all
        # private, so filtering vars(self) could never list anything.
        # dict.fromkeys drops duplicate names while preserving order.
        public_members = [
            name
            for name in dict.fromkeys(self._member_names())
            if name not in self._packages
        ]
        if public_members:
            msg.append(f'members={str(public_members)}')

        if self._packages:
            msg.append(f'subpackages={str(list(self._packages))}')

        return ', '.join(msg) + ')'

    def __str__(self) -> str:
        return self._package
256
257
class Packages(NamedTuple):
    """Items in a protobuf package structure; returned from as_packages."""

    # Maps each package name to the list of items added under that name.
    items_by_package: dict[str, list]
    # Root of the nested package tree, navigable by attribute access.
    packages: _NestedPackage
263
264
def as_packages(
    items: Iterable[tuple[str, T]], packages: Packages | None = None
) -> Packages:
    """Arranges items into a tree of proto packages navigable by attributes.

    Every item is recorded twice: in the flat items_by_package mapping under
    its package name, and in the nested package node reached by following the
    dot-separated components of that name from the root. Missing intermediate
    nodes are created on the way. An empty package name has the single
    component '', so such items land in a child node keyed by ''.

    Args:
      items: (package, item) tuples to insert into the package structure
      packages: if provided, update this Packages instead of creating a new one

    Returns:
      the Packages that was updated or created
    """
    result = Packages({}, _NestedPackage('')) if packages is None else packages

    for package, item in items:
        result.items_by_package.setdefault(package, []).append(item)

        node = result.packages
        qualified: list[str] = []

        # pylint: disable=protected-access
        for part in package.split('.'):
            # Track the fully qualified name of the node being visited.
            qualified.append(part)

            if part not in node._packages:
                node._add_package(part, _NestedPackage('.'.join(qualified)))

            node = node._packages[part]

        node._add_item(item)
        # pylint: enable=protected-access

    return result
295    return packages
296
297
# Either a path (as str or Path) or an already imported module.
PathOrModule = str | Path | ModuleType
299
300
class Library:
    """A collection of protocol buffer modules sorted by package.

    In Python, each .proto file is compiled into a Python module. The Library
    class makes it simple to navigate a collection of Python modules
    corresponding to .proto files, without relying on the location of these
    compiled modules.

    Proto messages and other types can be directly accessed by their protocol
    buffer package name. For example, the foo.bar.Baz message can be accessed
    in a Library called `protos` as:

      protos.packages.foo.bar.Baz

    A Library also provides the modules_by_package dictionary, for looking up
    the list of modules in a particular package, and the modules() generator
    for iterating over all modules.
    """

    @classmethod
    def from_paths(cls, protos: Iterable[str | Path | ModuleType]) -> Library:
        """Creates a Library from paths to proto files or proto modules.

        Entries that are str or Path are compiled and imported together; any
        other entry is used as an already imported module. Modules passed in
        directly come before the compiled ones.
        """
        paths: list[Path | str] = []
        modules: list[ModuleType] = []

        for proto in protos:
            if isinstance(proto, (Path, str)):
                paths.append(proto)
            else:
                modules.append(proto)

        if paths:
            modules += compile_and_import(paths)

        # Construct through cls, as from_strings does, so that subclasses get
        # an instance of their own type back.
        return cls(modules)

    @classmethod
    def from_strings(
        cls,
        contents: Iterable[str],
        includes: Iterable[Path | str] = (),
        output_dir: Path | str | None = None,
    ) -> Library:
        """Creates a proto library from protos in the provided strings."""
        return cls(compile_and_import_strings(contents, includes, output_dir))

    def __init__(self, modules: Iterable[ModuleType]):
        """Constructs a Library from an iterable of modules.

        A Library can be constructed with modules dynamically compiled by
        compile_and_import. For example:

            protos = Library(compile_and_import(list_of_proto_files))
        """
        # Each module is filed under the package named by its DESCRIPTOR.
        self.modules_by_package, self.packages = as_packages(
            (m.DESCRIPTOR.package, m)  # type: ignore[attr-defined]
            for m in modules
        )

    def modules(self) -> Iterable:
        """Iterates over all protobuf modules in this library."""
        for module_list in self.modules_by_package.values():
            yield from module_list

    def messages(self) -> Iterable:
        """Iterates over all protobuf messages in this library.

        Top-level messages of each module are yielded, each followed by the
        messages nested inside it.
        """
        for module in self.modules():
            yield from _nested_messages(
                module, module.DESCRIPTOR.message_types_by_name
            )
369            )
370
371
def _nested_messages(scope, message_names: Iterable[str]) -> Iterator:
    """Yields the named messages of `scope` and everything nested in them.

    Each message is looked up as an attribute of its enclosing scope and is
    yielded before the messages nested inside it (depth-first, pre-order).
    Nested names are read from the message's DESCRIPTOR.nested_types_by_name.
    """
    # Each entry pairs a scope with an iterator over the names still to visit
    # in it; the innermost scope is at the end of the list.
    pending = [(scope, iter(message_names))]

    while pending:
        owner, names = pending[-1]

        for name in names:
            message = getattr(owner, name)
            yield message
            # Descend into this message before resuming its siblings.
            nested = message.DESCRIPTOR.nested_types_by_name
            pending.append((message, iter(nested)))
            break
        else:
            # No names left in this scope.
            pending.pop()
377
378
def _repr_char(char: int) -> str:
    r"""Returns an ASCII char or the \x code for non-printable values.

    Printable means the range from space to tilde. The single quote and the
    backslash are returned with a leading backslash so that the result stays
    valid inside a single-quoted bytes literal.
    """
    if not ord(' ') <= char <= ord('~'):
        return f'\\x{char:02X}'

    text = chr(char)

    # An unescaped quote would end the literal early, and an unescaped
    # backslash would swallow or alter whatever follows it.
    if text in ("'", '\\'):
        return '\\' + text

    return text


def bytes_repr(value: bytes) -> str:
    r"""Prints bytes as mixed ASCII only if at least half are printable.

    When fewer than half of the bytes are printable, every byte is written as
    a \x escape, including the printable ones.

    Returns:
      the text of a single-quoted Python bytes literal for `value`
    """
    ascii_char_count = sum(ord(' ') <= c <= ord('~') for c in value)
    if ascii_char_count >= len(value) / 2:
        contents = ''.join(_repr_char(c) for c in value)
    else:
        contents = ''.join(f'\\x{c:02X}' for c in value)

    return f"b'{contents}'"
396
397
def _field_repr(field, value) -> str:
    """Formats a single field value according to the field's declared type.

    Bytes go through bytes_repr() and messages through proto_repr(). Enum
    numbers are shown as `<enum full name>.<value name>` when the number is a
    known value of the enum. Everything else, including unrecognized enum
    numbers, uses the built-in repr().
    """
    kind = field.type

    if kind == field.TYPE_BYTES:
        return bytes_repr(value)

    if kind == field.TYPE_MESSAGE:
        return proto_repr(value)

    if kind != field.TYPE_ENUM:
        return repr(value)

    try:
        enum_value = field.enum_type.values_by_number[value]
    except KeyError:
        # Not a declared value of this enum; show the raw number.
        return repr(value)

    return f'{field.enum_type.full_name}.{enum_value.name}'
412    return repr(value)
413
414
def _proto_repr(message) -> Iterator[str]:
    """Yields one `name=value` string for each populated field of a message.

    Fields are visited in the order of message.DESCRIPTOR.fields. A field is
    left out when HasField() reports it as absent, when HasField() is not
    supported for it and a non-repeated value equals the field's default, or
    when it is repeated and empty. Repeated fields are shown as a list, or as
    a dict when the value is a Mapping.
    """
    for field in message.DESCRIPTOR.fields:
        name = field.name
        value = getattr(message, name)
        repeated = field.label == field.LABEL_REPEATED

        try:
            present = message.HasField(name)
        except ValueError:
            # HasField() is not supported for this field. Treat a non-repeated
            # field as present only if it differs from its default value;
            # repeated fields are checked for emptiness below.
            present = repeated or value != field.default_value

        if not present:
            continue

        if not repeated:
            yield f'{name}={_field_repr(field, value)}'
            continue

        if not value:
            continue

        if isinstance(value, Mapping):
            # The message type behind a map field has exactly two fields,
            # describing the key and the value.
            key_desc, value_desc = field.message_type.fields
            entries = ', '.join(
                f'{_field_repr(key_desc, k)}: {_field_repr(value_desc, v)}'
                for k, v in value.items()
            )
            yield f'{name}={{{entries}}}'
        else:
            elements = ', '.join(_field_repr(field, v) for v in value)
            yield f'{name}=[{elements}]'
446            yield f'{field.name}={_field_repr(field, value)}'
447
448
def proto_repr(message, *, wrap: bool = True) -> str:
    """Builds a repr-style string that looks like a call constructing a proto.

    The result has the form `<full message name>(<field>=<value>, ...)`. In an
    interactive console that imports proto objects into the namespace, it can
    be used as Python source to create a proto object.

    Args:
      message: The protobuf message to format
      wrap: If true and black is available, the output is formatted with black
          and surrounding whitespace is stripped; otherwise the single-line
          form is returned unchanged.
    """
    type_name = message.DESCRIPTOR.full_name
    arguments = ", ".join(_proto_repr(message))
    text = f'{type_name}({arguments})'

    if not wrap or black is None or black_mode is None:
        return text

    return black.format_str(text, mode=black_mode).strip()
466    return raw
467