xref: /aosp_15_r20/external/bazelbuild-rules_cc/cc/toolchains/impl/nested_args.bzl (revision eed53cd41c5909d05eedc7ad9720bb158fd93452)
1# Copyright 2024 The Bazel Authors. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Helper functions for working with args."""
15
16load("@bazel_skylib//lib:structs.bzl", "structs")
17load("//cc:cc_toolchain_config_lib.bzl", "flag_group", "variable_with_value")
18load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo", "VariableInfo")
19load(":collect.bzl", "collect_files", "collect_provider")
20
# Restrict loading of this .bzl file to the packages listed below: the
# toolchain rules themselves and the rule-based toolchain tests.
visibility([
    "//cc/toolchains",
    "//tests/rule_based_toolchain/...",
])
25
# Reported by nested_args_provider when more than one requires_* condition is
# set on the same target.
REQUIRES_MUTUALLY_EXCLUSIVE_ERR = "requires_none, requires_not_none, requires_true, requires_false, and requires_equal are mutually exclusive"

# The messages below are attached (as `msg`) to the per-variable type
# requirements that nested_args_provider records in
# NestedArgsInfo.requires_types.
REQUIRES_NOT_NONE_ERR = "requires_not_none only works on options"
REQUIRES_NONE_ERR = "requires_none only works on options"
REQUIRES_TRUE_ERR = "requires_true only works on bools"
REQUIRES_FALSE_ERR = "requires_false only works on bools"
REQUIRES_EQUAL_ERR = "requires_equal only works on strings"

# Reported when requires_equal is set but requires_equal_value is empty.
REQUIRES_EQUAL_VALUE_ERR = "When requires_equal is provided, you must also provide requires_equal_value to specify what it should be equal to"

# Attached to the type requirement recorded for every variable substituted
# through a format_arg.
FORMAT_ARGS_ERR = "format_args can only format strings, files, or directories"

# Template used by format_string_indexes for a "%" that is neither "%s" nor
# "%%". It is itself a %-format string: "%%" renders a literal "%", and the
# single "%r" receives the offending format string.
_NOT_ESCAPED_FMT = "%% should always either of the form %%s, or escaped with %%%%. Instead, got %r"
36
# Usage examples appended to error messages raised by format_variable. The
# text is concatenated verbatim (never %-formatted), so "%s" is written
# unescaped here.
_EXAMPLE = """

cc_args(
    ...,
    args = [format_arg("--foo=%s", "//cc/toolchains/variables:foo")]
)

or

cc_args(
    ...,
    # If foo_list contains ["a", "b"], then this expands to ["--foo", "+a", "--foo", "+b"].
    args = ["--foo", format_arg("+%s")],
    iterate_over = "//cc/toolchains/variables:foo_list",
)
"""
52
53# @unsorted-dict-items.
# Attributes shared by every rule that is converted into a NestedArgsInfo by
# nested_args_provider_from_ctx. Rules are normally instantiated through
# args_wrapper_macro, which fills in `args` and `variables`.
NESTED_ARGS_ATTRS = {
    "args": attr.string_list(
        doc = """json-encoded arguments to be added to the command-line.

Usage:
cc_args(
    ...,
    args = ["--foo", format_arg("%s", "//cc/toolchains/variables:foo")]
)

This is equivalent to flag_group(flags = ["--foo", "%{foo}"])

Mutually exclusive with nested.
""",
    ),
    "nested": attr.label_list(
        providers = [NestedArgsInfo],
        doc = """nested_args that should be added on the command-line.

Mutually exclusive with args.""",
    ),
    "data": attr.label_list(
        allow_files = True,
        doc = """Files required to add this argument to the command-line.

For example, a flag that sets the header directory might add the headers in that
directory as additional files.
""",
    ),
    "variables": attr.label_list(
        providers = [VariableInfo],
        doc = "Variables to be used in substitutions",
    ),
    "iterate_over": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.iterate_over"),
    "requires_not_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_available"),
    "requires_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_not_available"),
    "requires_true": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_true"),
    "requires_false": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_false"),
    "requires_equal": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_equal"),
    # Previously the only undocumented attribute in this dict.
    "requires_equal_value": attr.string(
        doc = "The value the requires_equal variable is compared with. Must be provided whenever requires_equal is set.",
    ),
}
95
def args_wrapper_macro(*, name, rule, args = [], **kwargs):
    """Invokes a rule by converting args to attributes.

    Plain strings are wrapped with raw_string. For every format_arg that
    references a variable, the variable label is moved into the rule's
    `variables` attribute and replaced, in the json-encoded argument, by its
    index in that list.

    Args:
        name: (str) The name of the target.
        rule: (rule) The rule to invoke. Either cc_args or cc_nested_args.
        args: (List[str|Formatted]) A list of either strings, or function calls
          from format.bzl. For example:
            ["--foo", format_arg("--sysroot=%s", "//cc/toolchains/variables:sysroot")]
        **kwargs: kwargs to pass through into the rule invocation.
    """
    out_args = []
    vars = []

    # Maps each variable (exactly as the caller spelled it) to its position in
    # `vars`. A variable referenced by several format_arg calls must only be
    # listed once, because label_list attributes reject duplicate labels.
    var_indexes = {}

    if type(args) != "list":
        fail("Args must be a list in %s" % native.package_relative_label(name))
    for arg in args:
        if type(arg) == "string":
            out_args.append(raw_string(arg))
        elif getattr(arg, "format_type", None) == "format_arg":
            # The default of None lets values without a format_type field
            # (ints, dicts, ...) reach the descriptive failure below instead of
            # aborting inside getattr.
            arg = structs.to_dict(arg)
            if arg["value"] == None:
                out_args.append(arg)
            else:
                var = arg.pop("value")
                if var not in var_indexes:
                    var_indexes[var] = len(vars)
                    vars.append(var)

                # Swap the variable from a label to an index. This allows us to
                # actually get the providers in a rule.
                out_args.append(struct(value = var_indexes[var], **arg))
        else:
            fail("Invalid type of args in %s. Expected either a string or format_args(format_string, variable_label), got value %r" % (native.package_relative_label(name), arg))

    rule(
        name = name,
        args = [json.encode(arg) for arg in out_args],
        variables = vars,
        **kwargs
    )
134
def _var(target):
    """Returns the VariableInfo name of `target`, or None when no target is set."""
    return None if target == None else target[VariableInfo].name
139
140# TODO: Consider replacing this with a subrule in the future. However, maybe not
141# for a long time, since it'll break compatibility with all bazel versions < 7.
def nested_args_provider_from_ctx(ctx):
    """Builds a NestedArgsInfo from a rule declared with NESTED_ARGS_ATTRS.

    Each entry of `ctx.attr.args` is json-decoded. When the decoded entry has
    a non-None "value", that value is an index into `ctx.attr.variables` and
    is replaced by the VariableInfo found at that index.

    Args:
        ctx: The rule context
    Returns:
        NestedArgsInfo
    """
    variable_infos = collect_provider(ctx.attr.variables, VariableInfo)

    decoded_args = []
    for encoded in ctx.attr.args:
        fields = json.decode(encoded)

        # Raw strings carry no "value" key at all, so .get() covers both the
        # "missing" and the "explicitly None" cases.
        index = fields.get("value")
        if index != None:
            fields["value"] = variable_infos[index]
        decoded_args.append(struct(**fields))

    attrs = ctx.attr
    return nested_args_provider(
        label = ctx.label,
        args = decoded_args,
        nested = collect_provider(attrs.nested, NestedArgsInfo),
        files = collect_files(attrs.data),
        iterate_over = _var(attrs.iterate_over),
        requires_not_none = _var(attrs.requires_not_none),
        requires_none = _var(attrs.requires_none),
        requires_true = _var(attrs.requires_true),
        requires_false = _var(attrs.requires_false),
        requires_equal = _var(attrs.requires_equal),
        requires_equal_value = attrs.requires_equal_value,
    )
172
def raw_string(s):
    """Wraps a literal string in the struct form consumed by format_variable.

    Args:
      s: (str) The literal text of the argument.
    Returns:
      A struct with format_type "raw" whose format field is `s`.
    """
    metadata = struct(format = s, format_type = "raw")
    return metadata
182
def format_string_indexes(s, fail = fail):
    """Finds every '%s' placeholder in a format string.

    A '%' must be followed by 's' (a placeholder) or by another '%' (an
    escaped percent sign). Anything else, including the legacy '%{variable}'
    syntax and a trailing lone '%', is reported through `fail`.

    Args:
      s: (str) The string
      fail: The fail function. Used for tests

    Returns:
      List[int] The indexes of the '%' that starts each '%s' in the string
    """
    positions = []
    after_percent = False
    for pos in range(len(s)):
        char = s[pos]
        if after_percent:
            # This character completes the two-character '%' sequence.
            after_percent = False
            if char == "s":
                positions.append(pos - 1)
            elif char == "{":
                fail('Using the old mechanism for variables, %%{variable}, but we instead use format_arg("--foo=%%s", "//cc/toolchains/variables:<variable>"). Got %r' % s)
            elif char != "%":
                fail(_NOT_ESCAPED_FMT % s)
        elif char == "%":
            after_percent = True

    # A '%' as the very last character has nothing to pair with.
    if after_percent:
        return fail(_NOT_ESCAPED_FMT % s)
    return positions
209
def format_variable(arg, iterate_over, fail = fail):
    """Renders a single argument as a flag string for a legacy flag_group.

    A raw string is returned unchanged and must not contain an unescaped
    "%s". A format_arg must contain exactly one "%s", which is replaced by
    "%{<variable>}", where <variable> is the name of the arg's own variable
    or, when the arg has none, `iterate_over`.

    Eg, given a variable whose name is "bar":
      raw_string("--foo")              => "--foo"
      format_arg("--bar=%s", <bar>)    => "--bar=%{bar}"

    Args:
      arg: [Formatted] The command-line argument, as created by raw_string or
        the format_arg function.
      iterate_over: (Optional[str]) The name of the variable we're iterating over.
      fail: The fail function. Used for tests

    Returns:
      A string defined to be compatible with flag groups.
    """
    placeholders = format_string_indexes(arg.format, fail = fail)

    if arg.format_type == "raw":
        if placeholders:
            return fail("Can't use %s with a raw string. Either escape it with %%s or use format_arg, like the following examples:" + _EXAMPLE)
        return arg.format

    # From here on the argument is a format_arg: exactly one placeholder.
    if len(placeholders) == 0:
        return fail('format_arg requires a "%%s" in the format string, but got %r' % arg.format)
    if len(placeholders) > 1:
        return fail("Only one %%s can be used in a format string, but got %r" % arg.format)

    # An explicit variable wins; otherwise fall back to the iteration variable.
    if arg.value != None:
        var = arg.value.name
    elif iterate_over != None:
        var = iterate_over
    else:
        return fail("format_arg requires either a variable to format, or iterate_over must be provided. For example:" + _EXAMPLE)

    start = placeholders[0]
    before = arg.format[:start]
    after = arg.format[start + 2:]  # Skip the two characters of "%s".
    return before + "%{" + var + "}" + after
246
def nested_args_provider(
        *,
        label,
        args = [],
        nested = [],
        files = depset([]),
        iterate_over = None,
        requires_not_none = None,
        requires_none = None,
        requires_true = None,
        requires_false = None,
        requires_equal = None,
        requires_equal_value = "",
        fail = fail):
    """Creates a validated NestedArgsInfo.

    Variable types are deliberately not validated here: the type of a
    variable is only known once a cc_args wraps this provider, since outer
    layers can change it through iterate_over. Instead, the expectations are
    recorded in `requires_types` (variable name -> list of requirement
    structs) and `unwrap_options` on the returned provider.

    Args:
        label: (Label) The context we are currently evaluating in. Used for
          error messages.
        args: (List[Formatted]) The command-line arguments to add, as structs
          carrying `format_type` and `format` (and `value` unless raw).
        nested: (List[NestedArgsInfo]) command-line arguments to expand.
        files: (depset[File]) Files required for this set of command-line args.
        iterate_over: (Optional[str]) Variable to iterate over
        requires_not_none: (Optional[str]) If provided, this NestedArgsInfo will
          be ignored if the variable is None
        requires_none: (Optional[str]) If provided, this NestedArgsInfo will
          be ignored if the variable is not None
        requires_true: (Optional[str]) If provided, this NestedArgsInfo will
          be ignored if the variable is false
        requires_false: (Optional[str]) If provided, this NestedArgsInfo will
          be ignored if the variable is true
        requires_equal: (Optional[str]) If provided, this NestedArgsInfo will
          be ignored if the variable is not equal to requires_equal_value.
        requires_equal_value: (str) The value to compare the requires_equal
          variable with
        fail: A fail function. Use only for testing.
    Returns:
        NestedArgsInfo
    """
    if bool(args) == bool(nested):
        fail("Exactly one of args and nested must be provided")

    transitive_files = [child.files for child in nested] + [files]

    # We may want to reconsider this down the line, but it's easier to open up
    # an API than to lock down an API.
    provided_conditions = [
        condition
        for condition in [
            requires_not_none,
            requires_none,
            requires_true,
            requires_false,
            requires_equal,
        ]
        if condition != None
    ]
    if len(provided_conditions) > 1:
        fail(REQUIRES_MUTUALLY_EXCLUSIVE_ERR)

    flag_group_kwargs = {}
    if nested:
        flag_group_kwargs["flag_groups"] = [child.legacy_flag_group for child in nested]
    if iterate_over:
        flag_group_kwargs["iterate_over"] = iterate_over

    # One row per requires_* condition, in priority order:
    # (variable, flag_group keyword, type error, valid types,
    #  whether the type check happens after option unwrapping,
    #  whether the variable is added to unwrap_options)
    condition_specs = [
        (requires_not_none, "expand_if_available", REQUIRES_NOT_NONE_ERR, ["option"], False, True),
        (requires_none, "expand_if_not_available", REQUIRES_NONE_ERR, ["option"], False, False),
        (requires_true, "expand_if_true", REQUIRES_TRUE_ERR, ["bool"], True, True),
        (requires_false, "expand_if_false", REQUIRES_FALSE_ERR, ["bool"], True, True),
        (requires_equal, "expand_if_equal", REQUIRES_EQUAL_ERR, ["string"], True, True),
    ]

    requires_types = {}
    unwrap_options = []
    for variable, expand_key, msg, valid_types, after_option_unwrap, unwraps in condition_specs:
        if not variable:
            continue

        if expand_key == "expand_if_equal":
            # Equality is the only condition that needs a comparison value.
            if not requires_equal_value:
                fail(REQUIRES_EQUAL_VALUE_ERR)
            flag_group_kwargs[expand_key] = variable_with_value(
                name = variable,
                value = requires_equal_value,
            )
        else:
            flag_group_kwargs[expand_key] = variable

        requires_types.setdefault(variable, []).append(struct(
            msg = msg,
            valid_types = valid_types,
            after_option_unwrap = after_option_unwrap,
        ))
        if unwraps:
            unwrap_options.append(variable)

        # Only the first condition that is set takes effect.
        break

    # Every variable substituted by a format_arg must be formattable.
    for arg in args:
        if arg.format_type == "raw":
            continue
        var_name = iterate_over if arg.value == None else arg.value.name
        requires_types.setdefault(var_name, []).append(struct(
            msg = FORMAT_ARGS_ERR,
            valid_types = ["string", "file", "directory"],
            after_option_unwrap = True,
        ))

    if args:
        flag_group_kwargs["flags"] = [
            format_variable(arg, iterate_over = iterate_over, fail = fail)
            for arg in args
        ]

    return NestedArgsInfo(
        label = label,
        nested = nested,
        files = depset(transitive = transitive_files),
        iterate_over = iterate_over,
        unwrap_options = unwrap_options,
        requires_types = requires_types,
        legacy_flag_group = flag_group(**flag_group_kwargs),
    )
388