# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A simple precompiler to generate deterministic pyc files for Bazel."""

# NOTE: Imports specific to the persistent worker should only be imported
# when a persistent worker is used. Avoiding the unnecessary imports
# saves significant startup time for non-worker invocations.
import argparse
import py_compile
import sys


def _create_parser() -> "argparse.ArgumentParser":
    parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
    parser.add_argument("--invalidation_mode", default="CHECKED_HASH")
    parser.add_argument("--optimize", type=int, default=-1)
    parser.add_argument("--python_version")
    parser.add_argument("--src", action="append", dest="srcs")
    parser.add_argument("--src_name", action="append", dest="src_names")
    parser.add_argument("--pyc", action="append", dest="pycs")
    parser.add_argument("--persistent_worker", action="store_true")
    parser.add_argument("--log_level", default="ERROR")
    parser.add_argument("--worker_impl", default="async")
    return parser


def _compile(options: "argparse.Namespace") -> None:
    try:
        invalidation_mode = py_compile.PycInvalidationMode[
            options.invalidation_mode.upper()
        ]
    except KeyError as e:
        raise ValueError(
            f"Unknown PycInvalidationMode: {options.invalidation_mode}"
        ) from e

    if not (len(options.srcs) == len(options.src_names) == len(options.pycs)):
        raise AssertionError(
            "Mismatched number of --src, --src_name, and/or --pyc args"
        )

    for src, src_name, pyc in zip(options.srcs, options.src_names, options.pycs):
        py_compile.compile(
            src,
            pyc,
            doraise=True,
            dfile=src_name,
            optimize=options.optimize,
            invalidation_mode=invalidation_mode,
        )
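

# Illustrative non-worker invocation (the script name and paths below are
# hypothetical); each --src is compiled to the matching --pyc, with --src_name
# embedded as the source path in the pyc. Because the parser is created with
# fromfile_prefix_chars="@", the same flags can also be read from an @argsfile:
#
#   python precompiler.py \
#       --src bazel-out/bin/foo.py \
#       --src_name foo.py \
#       --pyc bazel-out/bin/foo.pyc \
#       --invalidation_mode CHECKED_HASH \
#       --optimize -1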


# A stub type alias for readability.
# See the Bazel WorkRequest object definition:
# https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/worker_protocol.proto
JsonWorkerRequest = object

# A stub type alias for readability.
# See the Bazel WorkResponse object definition:
# https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/worker_protocol.proto
JsonWorkerResponse = object


class _SerialPersistentWorker:
    """Simple, synchronous, serial persistent worker."""

    def __init__(self, instream: "typing.TextIO", outstream: "typing.TextIO"):
        self._instream = instream
        self._outstream = outstream
        self._parser = _create_parser()

    def run(self) -> None:
        try:
            while True:
                request = None
                try:
                    request = self._get_next_request()
                    if request is None:
                        _logger.info("Empty request: exiting")
                        break
                    response = self._process_request(request)
                    if response:  # May be None for a cancel request
                        self._send_response(response)
                except Exception:
                    _logger.exception("Unhandled error: request=%s", request)
                    output = (
                        f"Unhandled error:\nRequest: {request}\n"
                        + traceback.format_exc()
                    )
                    request_id = 0 if not request else request.get("requestId", 0)
                    self._send_response(
                        {
                            "exitCode": 3,
                            "output": output,
                            "requestId": request_id,
                        }
                    )
        finally:
            _logger.info("Worker shutting down")

    def _get_next_request(self) -> "JsonWorkRequest | None":
        line = self._instream.readline()
        if not line:
            return None
        return json.loads(line)

    def _process_request(self, request: "JsonWorkRequest") -> "JsonWorkResponse | None":
        if request.get("cancel"):
            return None
        options = self._options_from_request(request)
        _compile(options)
        response = {
            "requestId": request.get("requestId", 0),
            "exitCode": 0,
        }
        return response

    def _options_from_request(
        self, request: "JsonWorkRequest"
    ) -> "argparse.Namespace":
        options = self._parser.parse_args(request["arguments"])
        if request.get("sandboxDir"):
            prefix = request["sandboxDir"]
            options.srcs = [os.path.join(prefix, v) for v in options.srcs]
            options.pycs = [os.path.join(prefix, v) for v in options.pycs]
        return options

    def _send_response(self, response: "JsonWorkResponse") -> None:
        self._outstream.write(json.dumps(response) + "\n")
        self._outstream.flush()
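

# Illustrative worker dialog (one JSON object per line; the field names follow
# the Bazel worker protocol linked above, the argument values are hypothetical):
#
#   request  (stdin):  {"requestId": 1, "arguments": ["--src", "a.py", "--src_name", "a.py", "--pyc", "a.pyc"]}
#   response (stdout): {"requestId": 1, "exitCode": 0}
#
# When a request carries "sandboxDir", the src/pyc paths are re-rooted onto
# that prefix before compiling (see _options_from_request).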


class _AsyncPersistentWorker:
    """Asynchronous, concurrent, persistent worker."""

    def __init__(self, reader: "asyncio.StreamReader", writer: "asyncio.StreamWriter"):
        self._reader = reader
        self._writer = writer
        self._parser = _create_parser()
        self._request_id_to_task = {}
        self._task_to_request_id = {}

    @classmethod
    async def main(cls, instream: "typing.TextIO", outstream: "typing.TextIO") -> None:
        reader, writer = await cls._connect_streams(instream, outstream)
        await cls(reader, writer).run()

    @classmethod
    async def _connect_streams(
        cls, instream: "typing.TextIO", outstream: "typing.TextIO"
    ) -> "tuple[asyncio.StreamReader, asyncio.StreamWriter]":
        loop = asyncio.get_event_loop()
        reader = asyncio.StreamReader()
        protocol = asyncio.StreamReaderProtocol(reader)
        await loop.connect_read_pipe(lambda: protocol, instream)

        w_transport, w_protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, outstream
        )
        writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
        return reader, writer

    async def run(self) -> None:
        while True:
            _logger.info("pending requests: %s", len(self._request_id_to_task))
            request = await self._get_next_request()
            request_id = request.get("requestId", 0)
            task = asyncio.create_task(
                self._process_request(request), name=f"request_{request_id}"
            )
            self._request_id_to_task[request_id] = task
            self._task_to_request_id[task] = request_id
            task.add_done_callback(self._handle_task_done)

    async def _get_next_request(self) -> "JsonWorkRequest":
        _logger.debug("awaiting line")
        line = await self._reader.readline()
        _logger.debug("recv line: %s", line)
        return json.loads(line)

    def _handle_task_done(self, task: "asyncio.Task") -> None:
        request_id = self._task_to_request_id[task]
        _logger.info("task done: %s %s", request_id, task)
        del self._task_to_request_id[task]
        del self._request_id_to_task[request_id]

    async def _process_request(self, request: "JsonWorkRequest") -> None:
        _logger.info("request %s: start: %s", request.get("requestId"), request)
        try:
            if request.get("cancel", False):
                await self._process_cancel_request(request)
            else:
                await self._process_compile_request(request)
        except asyncio.CancelledError:
            _logger.info(
                "request %s: cancel received, stopping processing",
                request.get("requestId"),
            )
            # We don't send a response because we assume the request that
            # triggered the cancellation sent the response.
            raise
        except Exception:
            _logger.exception("Unhandled error: request=%s", request)
            self._send_response(
                {
                    "exitCode": 3,
                    "output": f"Unhandled error:\nRequest: {request}\n"
                    + traceback.format_exc(),
                    "requestId": 0 if not request else request.get("requestId", 0),
                }
            )

    async def _process_cancel_request(self, request: "JsonWorkRequest") -> None:
        request_id = request.get("requestId", 0)
        task = self._request_id_to_task.get(request_id)
        if not task:
            # It must have already completed, so ignore the request, per spec.
            return

        task.cancel()
        self._send_response({"requestId": request_id, "wasCancelled": True})

    async def _process_compile_request(self, request: "JsonWorkRequest") -> None:
        options = self._options_from_request(request)
        # _compile performs a variety of blocking IO calls, so run it separately.
        await asyncio.to_thread(_compile, options)
        self._send_response(
            {
                "requestId": request.get("requestId", 0),
                "exitCode": 0,
            }
        )

    def _options_from_request(self, request: "JsonWorkRequest") -> "argparse.Namespace":
        options = self._parser.parse_args(request["arguments"])
        if request.get("sandboxDir"):
            prefix = request["sandboxDir"]
            options.srcs = [os.path.join(prefix, v) for v in options.srcs]
            options.pycs = [os.path.join(prefix, v) for v in options.pycs]
        return options

    def _send_response(self, response: "JsonWorkResponse") -> None:
        _logger.info("request %s: respond: %s", response.get("requestId"), response)
        self._writer.write(json.dumps(response).encode("utf8") + b"\n")


def main(args: "list[str]") -> int:
    options = _create_parser().parse_args(args)

    # Persistent workers are started with the `--persistent_worker` flag.
    # See the following docs for details on persistent workers:
    # https://bazel.build/remote/persistent
    # https://bazel.build/remote/multiplex
    # https://bazel.build/remote/creating
    if options.persistent_worker:
        global asyncio, json, logging, os, traceback, _logger
        import asyncio
        import json
        import logging
        import os.path
        import traceback

        _logger = logging.getLogger("precompiler")
        # Only configure logging for workers. This prevents non-worker
        # invocations from spamming stderr with logging info.
        logging.basicConfig(level=getattr(logging, options.log_level))
        _logger.info("persistent worker: impl=%s", options.worker_impl)
        if options.worker_impl == "serial":
            _SerialPersistentWorker(sys.stdin, sys.stdout).run()
        elif options.worker_impl == "async":
            asyncio.run(_AsyncPersistentWorker.main(sys.stdin, sys.stdout))
        else:
            raise ValueError(f"Unknown worker impl: {options.worker_impl}")
    else:
        _compile(options)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
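

# Illustrative cancellation dialog for the async worker (values hypothetical):
# a cancel request names an in-flight requestId; the worker cancels the
# matching task and answers with "wasCancelled" instead of "exitCode"
# (see _process_cancel_request):
#
#   request  (stdin):  {"requestId": 2, "cancel": true}
#   response (stdout): {"requestId": 2, "wasCancelled": true}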