1# Copyright 2020 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""Tools for compiling and importing Python protos on the fly.""" 15 16from __future__ import annotations 17 18from collections.abc import Mapping 19import importlib.util 20import logging 21import os 22from pathlib import Path 23import subprocess 24import shlex 25import tempfile 26from types import ModuleType 27from typing import ( 28 Generic, 29 Iterable, 30 Iterator, 31 NamedTuple, 32 Set, 33 TypeVar, 34) 35 36try: 37 # pylint: disable=wrong-import-position 38 import black 39 40 black_mode: black.Mode | None = black.Mode(string_normalization=False) 41 42 # pylint: enable=wrong-import-position 43except ImportError: 44 black = None # type: ignore 45 black_mode = None 46 47_LOG = logging.getLogger(__name__) 48 49 50def _find_protoc() -> str: 51 """Locates a protoc binary to use for compiling protos.""" 52 if 'PROTOC' in os.environ: 53 return os.environ['PROTOC'] 54 55 # Fallback is assuming `protoc` is on the system PATH. 56 return 'protoc' 57 58 59def compile_protos( 60 output_dir: Path | str, 61 proto_files: Iterable[Path | str], 62 includes: Iterable[Path | str] = (), 63) -> None: 64 """Compiles proto files for Python by invoking the protobuf compiler. 65 66 Proto files not covered by one of the provided include paths will have their 67 directory added as an include path. 68 """ 69 proto_paths: list[Path] = [Path(f).resolve() for f in proto_files] 70 include_paths: Set[Path] = set(Path(d).resolve() for d in includes) 71 72 for path in proto_paths: 73 if not any(include in path.parents for include in include_paths): 74 include_paths.add(path.parent) 75 76 cmd: tuple[Path | str, ...] = ( 77 _find_protoc(), 78 '--experimental_allow_proto3_optional', 79 '--python_out', 80 os.path.abspath(output_dir), 81 *(f'-I{d}' for d in include_paths), 82 *proto_paths, 83 ) 84 85 _LOG.debug('%s', ' '.join(shlex.quote(str(c)) for c in cmd)) 86 process = subprocess.run(cmd, capture_output=True) 87 88 if process.returncode: 89 _LOG.error( 90 'protoc invocation failed!\n%s\n%s', 91 ' '.join(shlex.quote(str(c)) for c in cmd), 92 process.stderr.decode(), 93 ) 94 process.check_returncode() 95 96 97def _import_module(name: str, path: str) -> ModuleType: 98 spec = importlib.util.spec_from_file_location(name, path) 99 assert spec is not None 100 module = importlib.util.module_from_spec(spec) 101 spec.loader.exec_module(module) # type: ignore[union-attr] 102 return module 103 104 105def import_modules(directory: Path | str) -> Iterator: 106 """Imports modules in a directory and yields them.""" 107 parent = os.path.dirname(directory) 108 109 for dirpath, _, files in os.walk(directory): 110 path_parts = os.path.relpath(dirpath, parent).split(os.sep) 111 112 for file in files: 113 name, ext = os.path.splitext(file) 114 115 if ext == '.py': 116 yield _import_module( 117 f'{".".join(path_parts)}.{name}', 118 os.path.join(dirpath, file), 119 ) 120 121 122def compile_and_import( 123 proto_files: Iterable[Path | str], 124 includes: Iterable[Path | str] = (), 125 output_dir: Path | str | None = None, 126) -> Iterator: 127 """Compiles protos and imports their modules; yields the proto modules. 128 129 Args: 130 proto_files: paths to .proto files to compile 131 includes: include paths to use for .proto compilation 132 output_dir: where to place the generated modules; a temporary directory is 133 used if omitted 134 135 Yields: 136 the generated protobuf Python modules 137 """ 138 139 if output_dir: 140 compile_protos(output_dir, proto_files, includes) 141 yield from import_modules(output_dir) 142 else: 143 with tempfile.TemporaryDirectory(prefix='compiled_protos_') as tempdir: 144 compile_protos(tempdir, proto_files, includes) 145 yield from import_modules(tempdir) 146 147 148def compile_and_import_file( 149 proto_file: Path | str, 150 includes: Iterable[Path | str] = (), 151 output_dir: Path | str | None = None, 152): 153 """Compiles and imports the module for a single .proto file.""" 154 return next(iter(compile_and_import([proto_file], includes, output_dir))) 155 156 157def compile_and_import_strings( 158 contents: Iterable[str], 159 includes: Iterable[Path | str] = (), 160 output_dir: Path | str | None = None, 161) -> Iterator: 162 """Compiles protos in one or more strings.""" 163 164 if isinstance(contents, str): 165 contents = [contents] 166 167 with tempfile.TemporaryDirectory(prefix='proto_sources_') as path: 168 protos = [] 169 170 for proto in contents: 171 # Use a hash of the proto so the same contents map to the same file 172 # name. The protobuf package complains if it seems the same contents 173 # in files with different names. 174 protos.append(Path(path, f'protobuf_{hash(proto):x}.proto')) 175 protos[-1].write_text(proto) 176 177 yield from compile_and_import(protos, includes, output_dir) 178 179 180T = TypeVar('T') 181 182 183class _NestedPackage(Generic[T]): 184 """Facilitates navigating protobuf packages as attributes.""" 185 186 def __init__(self, package: str): 187 self._packages: dict[str, _NestedPackage[T]] = {} 188 self._items: list[T] = [] 189 self._package = package 190 191 def _add_package(self, subpackage: str, package: _NestedPackage) -> None: 192 self._packages[subpackage] = package 193 194 def _add_item(self, item) -> None: 195 if item not in self._items: # Don't store the same item multiple times. 196 self._items.append(item) 197 198 def __getattr__(self, attr: str): 199 """Look up subpackages or package members.""" 200 if attr in self._packages: 201 return self._packages[attr] 202 203 for item in self._items: 204 if hasattr(item, attr): 205 return getattr(item, attr) 206 207 raise AttributeError( 208 f'Proto package "{self._package}" does not contain "{attr}"' 209 ) 210 211 def __getitem__(self, subpackage: str) -> _NestedPackage[T]: 212 """Support accessing nested packages by name.""" 213 result = self 214 215 for package in subpackage.split('.'): 216 result = result._packages[package] 217 218 return result 219 220 def __dir__(self) -> list[str]: 221 """List subpackages and members of modules as attributes.""" 222 attributes = list(self._packages) 223 224 for item in self._items: 225 for attr, value in vars(item).items(): 226 # Exclude private variables and modules from dir(). 227 if not attr.startswith('_') and not isinstance( 228 value, ModuleType 229 ): 230 attributes.append(attr) 231 232 return attributes 233 234 def __iter__(self) -> Iterator['_NestedPackage[T]']: 235 """Iterate over nested packages.""" 236 return iter(self._packages.values()) 237 238 def __repr__(self) -> str: 239 msg = [f'ProtoPackage({self._package!r}'] 240 241 public_members = [ 242 i 243 for i in vars(self) 244 if i not in self._packages and not i.startswith('_') 245 ] 246 if public_members: 247 msg.append(f'members={str(public_members)}') 248 249 if self._packages: 250 msg.append(f'subpackages={str(list(self._packages))}') 251 252 return ', '.join(msg) + ')' 253 254 def __str__(self) -> str: 255 return self._package 256 257 258class Packages(NamedTuple): 259 """Items in a protobuf package structure; returned from as_package.""" 260 261 items_by_package: dict[str, list] 262 packages: _NestedPackage 263 264 265def as_packages( 266 items: Iterable[tuple[str, T]], packages: Packages | None = None 267) -> Packages: 268 """Places items in a proto-style package structure navigable by attributes. 269 270 Args: 271 items: (package, item) tuples to insert into the package structure 272 packages: if provided, update this Packages instead of creating a new one 273 """ 274 if packages is None: 275 packages = Packages({}, _NestedPackage('')) 276 277 for package, item in items: 278 packages.items_by_package.setdefault(package, []).append(item) 279 280 entry = packages.packages 281 subpackages = package.split('.') 282 283 # pylint: disable=protected-access 284 for i, subpackage in enumerate(subpackages, 1): 285 if subpackage not in entry._packages: 286 entry._add_package( 287 subpackage, _NestedPackage('.'.join(subpackages[:i])) 288 ) 289 290 entry = entry._packages[subpackage] 291 292 entry._add_item(item) 293 # pylint: enable=protected-access 294 295 return packages 296 297 298PathOrModule = str | Path | ModuleType 299 300 301class Library: 302 """A collection of protocol buffer modules sorted by package. 303 304 In Python, each .proto file is compiled into a Python module. The Library 305 class makes it simple to navigate a collection of Python modules 306 corresponding to .proto files, without relying on the location of these 307 compiled modules. 308 309 Proto messages and other types can be directly accessed by their protocol 310 buffer package name. For example, the foo.bar.Baz message can be accessed 311 in a Library called `protos` as: 312 313 protos.packages.foo.bar.Baz 314 315 A Library also provides the modules_by_package dictionary, for looking up 316 the list of modules in a particular package, and the modules() generator 317 for iterating over all modules. 318 """ 319 320 @classmethod 321 def from_paths(cls, protos: Iterable[str | Path | ModuleType]) -> Library: 322 """Creates a Library from paths to proto files or proto modules.""" 323 paths: list[Path | str] = [] 324 modules: list[ModuleType] = [] 325 326 for proto in protos: 327 if isinstance(proto, (Path, str)): 328 paths.append(proto) 329 else: 330 modules.append(proto) 331 332 if paths: 333 modules += compile_and_import(paths) 334 return Library(modules) 335 336 @classmethod 337 def from_strings( 338 cls, 339 contents: Iterable[str], 340 includes: Iterable[Path | str] = (), 341 output_dir: Path | str | None = None, 342 ) -> Library: 343 """Creates a proto library from protos in the provided strings.""" 344 return cls(compile_and_import_strings(contents, includes, output_dir)) 345 346 def __init__(self, modules: Iterable[ModuleType]): 347 """Constructs a Library from an iterable of modules. 348 349 A Library can be constructed with modules dynamically compiled by 350 compile_and_import. For example: 351 352 protos = Library(compile_and_import(list_of_proto_files)) 353 """ 354 self.modules_by_package, self.packages = as_packages( 355 (m.DESCRIPTOR.package, m) # type: ignore[attr-defined] 356 for m in modules 357 ) 358 359 def modules(self) -> Iterable: 360 """Iterates over all protobuf modules in this library.""" 361 for module_list in self.modules_by_package.values(): 362 yield from module_list 363 364 def messages(self) -> Iterable: 365 """Iterates over all protobuf messages in this library.""" 366 for module in self.modules(): 367 yield from _nested_messages( 368 module, module.DESCRIPTOR.message_types_by_name 369 ) 370 371 372def _nested_messages(scope, message_names: Iterable[str]) -> Iterator: 373 for name in message_names: 374 msg = getattr(scope, name) 375 yield msg 376 yield from _nested_messages(msg, msg.DESCRIPTOR.nested_types_by_name) 377 378 379def _repr_char(char: int) -> str: 380 r"""Returns an ASCII char or the \x code for non-printable values.""" 381 if ord(' ') <= char <= ord('~'): 382 return r"\'" if chr(char) == "'" else chr(char) 383 384 return f'\\x{char:02X}' 385 386 387def bytes_repr(value: bytes) -> str: 388 """Prints bytes as mixed ASCII only if at least half are printable.""" 389 ascii_char_count = sum(ord(' ') <= c <= ord('~') for c in value) 390 if ascii_char_count >= len(value) / 2: 391 contents = ''.join(_repr_char(c) for c in value) 392 else: 393 contents = ''.join(f'\\x{c:02X}' for c in value) 394 395 return f"b'{contents}'" 396 397 398def _field_repr(field, value) -> str: 399 if field.type == field.TYPE_ENUM: 400 try: 401 enum = field.enum_type.values_by_number[value] 402 return f'{field.enum_type.full_name}.{enum.name}' 403 except KeyError: 404 return repr(value) 405 406 if field.type == field.TYPE_MESSAGE: 407 return proto_repr(value) 408 409 if field.type == field.TYPE_BYTES: 410 return bytes_repr(value) 411 412 return repr(value) 413 414 415def _proto_repr(message) -> Iterator[str]: 416 for field in message.DESCRIPTOR.fields: 417 value = getattr(message, field.name) 418 419 # Skip fields that are not present. 420 try: 421 if not message.HasField(field.name): 422 continue 423 except ValueError: 424 # Skip default-valued fields that don't support HasField. 425 if ( 426 field.label != field.LABEL_REPEATED 427 and value == field.default_value 428 ): 429 continue 430 431 if field.label == field.LABEL_REPEATED: 432 if not value: 433 continue 434 435 if isinstance(value, Mapping): 436 key_desc, value_desc = field.message_type.fields 437 values = ', '.join( 438 f'{_field_repr(key_desc, k)}: {_field_repr(value_desc, v)}' 439 for k, v in value.items() 440 ) 441 yield f'{field.name}={{{values}}}' 442 else: 443 values = ', '.join(_field_repr(field, v) for v in value) 444 yield f'{field.name}=[{values}]' 445 else: 446 yield f'{field.name}={_field_repr(field, value)}' 447 448 449def proto_repr(message, *, wrap: bool = True) -> str: 450 """Creates a repr-like string for a protobuf. 451 452 In an interactive console that imports proto objects into the namespace, the 453 output of proto_repr() can be used as Python source to create a proto 454 object. 455 456 Args: 457 message: The protobuf message to format 458 wrap: If true and black is available, the output is wrapped according to 459 PEP8 using black. 460 """ 461 raw = f'{message.DESCRIPTOR.full_name}({", ".join(_proto_repr(message))})' 462 463 if wrap and black is not None and black_mode is not None: 464 return black.format_str(raw, mode=black_mode).strip() 465 466 return raw 467