xref: /aosp_15_r20/external/pigweed/pw_build/py/pw_build/bazel_query.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2023 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Parses Bazel rules from a local Bazel workspace."""
15
16import json
17import re
18import subprocess
19
20from pathlib import PurePath, PurePosixPath
21from typing import (
22    Any,
23    Callable,
24    Iterable,
25)
26
27BazelValue = bool | int | str | list[str] | dict[str, str]
28
29LABEL_PAT = re.compile(r'^(?:(?:@([a-zA-Z0-9_]*))?//([^:]*))?(?::([^:]+))?$')
30
31
class ParseError(Exception):
    """Error raised when Bazel data cannot be interpreted.

    Used for malformed labels, attribute data missing required fields, and
    Bazel or git invocations that fail to produce usable output.
    """
34
35
class BazelLabel:
    """Represents a Bazel target identifier.

    The label is split into repo, package, and target parts; `str()` yields
    the canonical "@repo//package:target" form.
    """

    def __init__(
        self,
        label_str: str,
        repo: str | None = None,
        package: str | None = None,
    ) -> None:
        """Creates a Bazel label.

        This method will attempt to parse the repo, package, and target portion
        of the given label string. If the repo portion of the label is omitted,
        it will use the given repo. If the label has no "//" (i.e. both the
        repo and package portions are omitted, as in ":target"), it will use
        the given package, if provided. Note that "//:target" explicitly names
        the root package, so the given package is NOT used for it. If the
        target portion is omitted, the last segment of the package is used.

        Args:
            label_str: Bazel label string, like "@repo//pkg:target".
            repo: Repo to use if omitted from label, e.g. as in "//pkg:target".
            package: Package to use if omitted from label, e.g. as in ":target".

        Raises:
            ParseError: If `label_str` is empty or not a valid label, or if no
                target name can be determined (e.g. "//").
            AssertionError: If the label has no repo and `repo` is not given.
        """
        match = LABEL_PAT.match(label_str) if label_str else None
        if not match:
            raise ParseError(f'invalid label: "{label_str}"')
        repo_part, package_part, target_part = match.groups()

        if repo_part:
            self._repo = repo_part
        else:
            # Both "//pkg:target" and "@//pkg:target" fall back to `repo`.
            assert repo, f'no repo given for label: "{label_str}"'
            self._repo = repo

        # The regex captures a package (possibly '' for the root package) only
        # when "//" is present; None means the package portion was omitted.
        if package_part is not None:
            self._package = package_part
        else:
            self._package = package or ''

        self._target = target_part or PurePosixPath(self._package).name
        if not self._target:
            raise ParseError(f'invalid label: "{label_str}"')

    def __str__(self) -> str:
        """Canonical representation of a Bazel label."""
        return f'@{self._repo}//{self._package}:{self._target}'

    def repo(self) -> str:
        """Returns the repository identifier associated with this label."""
        return self._repo

    def package(self) -> str:
        """Returns the package path associated with this label."""
        return self._package

    def target(self) -> str:
        """Returns the target name associated with this label."""
        return self._target
91
92
def parse_invalid(attr: dict[str, Any]) -> BazelValue:
    """Fallback attribute parser for unsupported Bazel attribute types.

    Never returns normally.

    Args:
        attr: Attribute JSON from a Bazel query; its "type" entry names the
            unsupported type.

    Raises:
        ParseError: Always, describing the unrecognized type.
        KeyError: If `attr` has no "type" entry.
    """
    unsupported = attr['type']
    message = f'unknown type: {unsupported}, expected one of {BazelValue}'
    raise ParseError(message)
97
98
99class BazelRule:
100    """Represents a Bazel rule as parsed from the query results."""
101
102    def __init__(self, kind: str, label: BazelLabel) -> None:
103        """Create a Bazel rule.
104
105        Args:
106            kind: The type of Bazel rule, e.g. cc_library.
107            label: A Bazel label corresponding to this rule.
108        """
109        self._kind = kind
110        self._label = label
111        self._attrs: dict[str, BazelValue] = {}
112        self._types: dict[str, str] = {}
113
114    def kind(self) -> str:
115        """Returns this rule's target type."""
116        return self._kind
117
118    def label(self) -> BazelLabel:
119        """Returns this rule's Bazel label."""
120        return self._label
121
122    def parse_attrs(self, attrs: Iterable[dict[str, Any]]) -> None:
123        """Maps JSON data from a bazel query into this object.
124
125        Args:
126            attrs: A dictionary of attribute names and values for the Bazel
127                rule. These should match the output of
128                `bazel cquery ... --output=jsonproto`.
129        """
130        attr_parsers: dict[str, Callable[[dict[str, Any]], BazelValue]] = {
131            'boolean': lambda attr: attr.get('booleanValue', False),
132            'integer': lambda attr: int(attr.get('intValue', '0')),
133            'string': lambda attr: attr.get('stringValue', ''),
134            'label': lambda attr: attr.get('stringValue', ''),
135            'string_list': lambda attr: attr.get('stringListValue', []),
136            'label_list': lambda attr: attr.get('stringListValue', []),
137            'string_dict': lambda attr: {
138                p['key']: p['value'] for p in attr.get('stringDictValue', [])
139            },
140        }
141        for attr in attrs:
142            if 'explicitlySpecified' not in attr:
143                continue
144            if not attr['explicitlySpecified']:
145                continue
146            try:
147                attr_name = attr['name']
148            except KeyError:
149                raise ParseError(
150                    f'missing "name" in {json.dumps(attr, indent=2)}'
151                )
152            try:
153                attr_type = attr['type'].lower()
154            except KeyError:
155                raise ParseError(
156                    f'missing "type" in {json.dumps(attr, indent=2)}'
157                )
158
159            attr_parser = attr_parsers.get(attr_type, parse_invalid)
160            self._attrs[attr_name] = attr_parser(attr)
161            self._types[attr_name] = attr_type
162
163    def has_attr(self, attr_name: str) -> bool:
164        """Returns whether the rule has an attribute of the given name.
165
166        Args:
167            attr_name: The name of the attribute.
168        """
169        return attr_name in self._attrs
170
171    def attr_type(self, attr_name: str) -> str:
172        """Returns the type of an attribute according to Bazel.
173
174        Args:
175            attr_name: The name of the attribute.
176        """
177        return self._types[attr_name]
178
179    def get_bool(self, attr_name: str) -> bool:
180        """Gets the value of a boolean attribute.
181
182        Args:
183            attr_name: The name of the boolean attribute.
184        """
185        val = self._attrs.get(attr_name, False)
186        assert isinstance(val, bool)
187        return val
188
189    def get_int(self, attr_name: str) -> int:
190        """Gets the value of an integer attribute.
191
192        Args:
193            attr_name: The name of the integer attribute.
194        """
195        val = self._attrs.get(attr_name, 0)
196        assert isinstance(val, int)
197        return val
198
199    def get_str(self, attr_name: str) -> str:
200        """Gets the value of a string attribute.
201
202        Args:
203            attr_name: The name of the string attribute.
204        """
205        val = self._attrs.get(attr_name, '')
206        assert isinstance(val, str)
207        return val
208
209    def get_list(self, attr_name: str) -> list[str]:
210        """Gets the value of a string list attribute.
211
212        Args:
213            attr_name: The name of the string list attribute.
214        """
215        val = self._attrs.get(attr_name, [])
216        assert isinstance(val, list)
217        return val
218
219    def get_dict(self, attr_name: str) -> dict[str, str]:
220        """Gets the value of a string list attribute.
221
222        Args:
223            attr_name: The name of the string list attribute.
224        """
225        val = self._attrs.get(attr_name, {})
226        assert isinstance(val, dict)
227        return val
228
229    def set_attr(self, attr_name: str, value: BazelValue) -> None:
230        """Sets the value of an attribute.
231
232        Args:
233            attr_name: The name of the attribute.
234            value: The value to set.
235        """
236        self._attrs[attr_name] = value
237
238    def filter_attr(self, attr_name: str, remove: Iterable[str]) -> None:
239        """Removes values from a list attribute.
240
241        Args:
242            attr_name: The name of the attribute.
243            remove: The values to remove.
244        """
245        values = self.get_list(attr_name)
246        self._attrs[attr_name] = [v for v in values if v not in remove]
247
248    def generate(self) -> Iterable[str]:
249        """Yields a sequence of strings describing the rule in Bazel."""
250        yield f'{self._kind}('
251        yield f'    name = "{self._label.target()}",'
252        for name in sorted(self._attrs.keys()):
253            if name == 'name':
254                continue
255            attr_type = self._types[name]
256            if attr_type == 'boolean':
257                yield f'    {name} = {self.get_bool(name)},'
258            elif attr_type == 'integer':
259                yield f'    {name} = {self.get_int(name)},'
260            elif attr_type == 'string':
261                yield f'    {name} = "{self.get_str(name)}",'
262            elif attr_type == 'string_list' or attr_type == 'label_list':
263                strs = self.get_list(name)
264                if len(strs) == 1:
265                    yield f'    {name} = ["{strs[0]}"],'
266                elif len(strs) > 1:
267                    yield f'    {name} = ['
268                    for s in strs:
269                        yield f'        "{s}",'
270                    yield '    ],'
271            elif attr_type == 'string_dict':
272                str_dict = self.get_dict(name)
273                yield f'    {name} = {{'
274                for k, v in str_dict.items():
275                    yield f'        {k} = "{v}",'
276                yield '    },'
277        yield ')'
278
279
class BazelWorkspace:
    """Represents a local instance of a Bazel repository.

    Attributes:
        defaults: Attributes automatically applied to every rule in the
                  workspace, which may be ignored when parsing. Maps a list
                  attribute's name to the values filtered out of it on every
                  parsed rule.
        generate: Indicates whether GN should be automatically generated for
                  this workspace or not. When false, `get_http_archives` and
                  `get_rules` yield nothing.
        targets:  A list of the Bazel targets to parse, along with their
                  dependencies.
        options:  Extra flags for `bazel cquery`, mapping a flag name (without
                  the leading "--") to its value.
    """

    def __init__(
        self, repo: str, source_dir: PurePath | None, fetch: bool = True
    ) -> None:
        """Creates an object representing a Bazel workspace at the given path.

        Args:
            repo: The Bazel repository name, like "com_google_pigweed".
            source_dir: Path to the local instance of a Bazel workspace. If
                None, commands run in the current working directory.
            fetch: If true, runs `git fetch` in the workspace immediately.

        Raises:
            ParseError: If `fetch` is true and `git fetch` exits with an error.
        """
        self.defaults: dict[str, list[str]] = {}
        self.generate = True
        self.targets: list[str] = []
        self.options: dict[str, Any] = {}
        self._fetched = False
        self._repo: str = repo
        self._revisions: dict[str, str] = {}
        # Parsed rules, keyed by the rule name from the query output.
        self._rules: dict[str, BazelRule] = {}
        self._source_dir = source_dir

        # Make sure the workspace has up-to-date objects and refs.
        if fetch:
            self._git('fetch')

    def repo(self) -> str:
        """Returns the Bazel repository name for this workspace."""
        return self._repo

    def get_http_archives(self) -> Iterable[BazelRule]:
        """Yields the http_archive rules from a workspace's WORKSPACE file."""
        if not self.generate:
            return
        for result in self._query('kind(http_archive, //external:*)'):
            if result['type'] != 'RULE':
                continue
            yield self._make_rule(result['rule'])

    def get_rules(self, labels: list[BazelLabel]) -> Iterable[BazelRule]:
        """Yields the rules matching the given labels.

        Rules parsed earlier are yielded from a cache as they are encountered;
        the remaining ones are fetched with a single `bazel cquery` and
        yielded afterwards, so the output order may differ from `labels`.
        """
        if not self.generate:
            return
        needed: list[str] = []
        for label in labels:
            print(f'Examining [{label}]                             ')
            short = f'//{label.package()}:{label.target()}'
            rule = self._rules.get(short)
            if rule is not None:
                yield rule
                continue
            needed.append(short)
        if not needed:
            # Everything was cached; an empty cquery expression would fail.
            return
        flags = [f'--{flag}={value}' for flag, value in self.options.items()]
        for result in self._cquery('+'.join(needed), flags):
            yield self._make_rule(result['target']['rule'])

    def revision(self, commitish: str = 'HEAD') -> str:
        """Returns the revision digest of the workspace's git commit-ish.

        If `commitish` cannot be resolved directly, it is treated as part of a
        tag name, and the shortest tag containing it is resolved instead.

        Raises:
            ParseError: If `commitish` matches neither a revision nor a tag.
        """
        try:
            return self._git('rev-parse', commitish)
        except ParseError:
            pass
        tags = self._git('tag', '--list').split()
        candidates = [tag for tag in tags if commitish in tag]
        if not candidates:
            raise ParseError(
                f'{self._repo} has no revision or tag matching "{commitish}"'
            )
        return self._git('rev-parse', min(candidates, key=len))

    def timestamp(self, revision: str) -> str:
        """Returns the timestamp of the workspace's git commit-ish."""
        return self._git('show', '--no-patch', '--format=%ci', revision)

    def url(self) -> str:
        """Returns the git URL of the workspace."""
        return self._git('remote', 'get-url', 'origin')

    def _cquery(self, expr: str, flags: list[str]) -> Iterable[Any]:
        """Invokes `bazel cquery` with the given selector.

        Returns:
            The "results" entries of the jsonproto output, or an empty list if
            the output has no "results" field.
        """
        result = self._bazel('cquery', expr, *flags, '--output=jsonproto')
        # NOTE(review): proto3 JSON may omit an empty repeated field, so an
        # empty result set could lack "results" entirely.
        return json.loads(result).get('results', [])

    def _query(self, expr: str) -> Iterable[Any]:
        """Invokes `bazel query` with the given selector.

        The streamed_jsonproto output is parsed as one JSON object per line;
        blank lines are skipped.
        """
        results = self._bazel('query', expr, '--output=streamed_jsonproto')
        return [
            json.loads(line) for line in results.splitlines() if line.strip()
        ]

    def _make_rule(self, rule_data: Any) -> BazelRule:
        """Make a BazelRule from JSON data returned by query or cquery.

        The rule is cached under its name and has `defaults` filtered out.
        """
        short = rule_data['name']
        label = BazelLabel(short, repo=self._repo)
        rule = BazelRule(rule_data['ruleClass'], label)
        rule.parse_attrs(rule_data['attribute'])
        for attr_name, values in self.defaults.items():
            rule.filter_attr(attr_name, values)
        self._rules[short] = rule
        return rule

    def _bazel(self, *args: str) -> str:
        """Execute a Bazel command in the workspace.

        The exit status is ignored; the command fails only if it prints
        nothing to stdout.
        """
        return self._exec('bazel', *args, '--noshow_progress', check=False)

    def _git(self, *args: str) -> str:
        """Execute a git command in the workspace; fails on non-zero exit."""
        return self._exec('git', *args, check=True)

    def _exec(self, *args: str, check: bool = True) -> str:
        """Execute a command in the workspace.

        Returns the command's stripped stdout. With `check=True`, a non-zero
        exit status is an error and empty stdout yields ''. With
        `check=False`, the exit status is ignored but empty stdout is an error.

        Raises:
            ParseError: If the command fails as described above.
        """
        try:
            result = subprocess.run(
                list(args),
                cwd=self._source_dir,
                check=check,
                capture_output=True,
            )
            if result.stdout:
                # The extra 'str()' facilitates MagicMock.
                return str(result.stdout.decode('utf-8')).strip()
            if check:
                return ''
            errmsg = result.stderr.decode('utf-8')
        except subprocess.CalledProcessError as error:
            errmsg = error.stderr.decode('utf-8')
        cmdline = ' '.join(args)
        raise ParseError(
            f'{self._repo} failed to exec '
            + f'`cd {self._source_dir} && {cmdline}`: {errmsg}'
        )
416