# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for working with args."""

load("@bazel_skylib//lib:structs.bzl", "structs")
load("//cc:cc_toolchain_config_lib.bzl", "flag_group", "variable_with_value")
load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo", "VariableInfo")
load(":collect.bzl", "collect_files", "collect_provider")

visibility([
    "//cc/toolchains",
    "//tests/rule_based_toolchain/...",
])

REQUIRES_MUTUALLY_EXCLUSIVE_ERR = "requires_none, requires_not_none, requires_true, requires_false, and requires_equal are mutually exclusive"
REQUIRES_NOT_NONE_ERR = "requires_not_none only works on options"
REQUIRES_NONE_ERR = "requires_none only works on options"
REQUIRES_TRUE_ERR = "requires_true only works on bools"
REQUIRES_FALSE_ERR = "requires_false only works on bools"
REQUIRES_EQUAL_ERR = "requires_equal only works on strings"
REQUIRES_EQUAL_VALUE_ERR = "When requires_equal is provided, you must also provide requires_equal_value to specify what it should be equal to"
FORMAT_ARGS_ERR = "format_args can only format strings, files, or directories"

_NOT_ESCAPED_FMT = "%% should always either of the form %%s, or escaped with %%%%. Instead, got %r"

# Appended to several error messages to show correct usage of format_arg.
_EXAMPLE = """

cc_args(
    ...,
    args = [format_arg("--foo=%s", "//cc/toolchains/variables:foo")]
)

or

cc_args(
    ...,
    # If foo_list contains ["a", "b"], then this expands to ["--foo", "+a", "--foo", "+b"].
    args = ["--foo", format_arg("+%s")],
    iterate_over = "//toolchains/variables:foo_list",
)
"""

# @unsorted-dict-items.
NESTED_ARGS_ATTRS = {
    "args": attr.string_list(
        doc = """json-encoded arguments to be added to the command-line.

Usage:
cc_args(
    ...,
    args = ["--foo", format_arg("%s", "//cc/toolchains/variables:foo")]
)

This is equivalent to flag_group(flags = ["--foo", "%{foo}"])

Mutually exclusive with nested.
""",
    ),
    "nested": attr.label_list(
        providers = [NestedArgsInfo],
        doc = """nested_args that should be added on the command-line.

Mutually exclusive with args.""",
    ),
    "data": attr.label_list(
        allow_files = True,
        doc = """Files required to add this argument to the command-line.

For example, a flag that sets the header directory might add the headers in that
directory as additional files.
""",
    ),
    "variables": attr.label_list(
        providers = [VariableInfo],
        doc = "Variables to be used in substitutions",
    ),
    "iterate_over": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.iterate_over"),
    "requires_not_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_available"),
    "requires_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_not_available"),
    "requires_true": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_true"),
    "requires_false": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_false"),
    "requires_equal": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_equal"),
    "requires_equal_value": attr.string(
        doc = "The value the requires_equal variable is compared against. Required when requires_equal is set.",
    ),
}

def args_wrapper_macro(*, name, rule, args = [], **kwargs):
    """Invokes a rule by converting args to attributes.

    Plain strings are converted to raw_string metadata. For format_arg values
    that reference a variable, the variable label is moved into the
    `variables` label_list attribute, and the arg keeps only its index into
    that list. A rule can only read providers through label attributes, so
    this indirection is needed. Each distinct variable is listed in
    `variables` once, even if several format_args reference it.

    Args:
        name: (str) The name of the target.
        rule: (rule) The rule to invoke. Either cc_args or cc_nested_args.
        args: (List[str|Formatted]) A list of either strings, or function calls
            from format.bzl. For example:
            ["--foo", format_arg("--sysroot=%s", "//cc/toolchains/variables:sysroot")]
        **kwargs: kwargs to pass through into the rule invocation.
    """
    out_args = []
    vars = []
    if type(args) != "list":
        fail("Args must be a list in %s" % native.package_relative_label(name))

    # Maps each variable's resolved Label to its index in `vars`. Bazel rejects
    # label_list attributes that contain the same label more than once, so a
    # variable referenced by several format_args must appear in `vars` once.
    var_indexes = {}
    for arg in args:
        if type(arg) == "string":
            out_args.append(raw_string(arg))
        elif getattr(arg, "format_type", None) == "format_arg":
            # The None default lets values without a `format_type` field
            # (ints, dicts, unrelated structs) reach the descriptive fail()
            # below, instead of erroring inside getattr.
            arg = structs.to_dict(arg)
            if arg["value"] == None:
                out_args.append(arg)
            else:
                var = arg.pop("value")

                # Resolve the label so that spellings such as ":foo" and
                # "//pkg:foo" are recognised as the same variable.
                key = native.package_relative_label(var)
                if key not in var_indexes:
                    var_indexes[key] = len(vars)
                    vars.append(var)

                # Swap the variable from a label to an index. This allows us to
                # actually get the providers in a rule.
                out_args.append(struct(value = var_indexes[key], **arg))
        else:
            fail("Invalid type of args in %s. Expected either a string or format_args(format_string, variable_label), got value %r" % (native.package_relative_label(name), arg))

    rule(
        name = name,
        args = [json.encode(arg) for arg in out_args],
        variables = vars,
        **kwargs
    )

def _var(target):
    """Returns target[VariableInfo].name, or None if target is None (attribute unset)."""
    if target == None:
        return None
    return target[VariableInfo].name

# TODO: Consider replacing this with a subrule in the future. However, maybe not
# for a long time, since it'll break compatibility with all bazel versions < 7.
def nested_args_provider_from_ctx(ctx):
    """Gets the nested args provider from a rule that has NESTED_ARGS_ATTRS.

    Each json-encoded entry of ctx.attr.args is decoded. Its integer "value"
    field, if non-None, is replaced with the VariableInfo at that position
    of ctx.attr.variables. That index was assigned by args_wrapper_macro.

    Args:
        ctx: The rule context
    Returns:
        NestedArgsInfo
    """
    variables = collect_provider(ctx.attr.variables, VariableInfo)
    args = []
    for arg in ctx.attr.args:
        arg = json.decode(arg)
        if "value" in arg:
            if arg["value"] != None:
                arg["value"] = variables[arg["value"]]
        args.append(struct(**arg))

    return nested_args_provider(
        label = ctx.label,
        args = args,
        nested = collect_provider(ctx.attr.nested, NestedArgsInfo),
        files = collect_files(ctx.attr.data),
        iterate_over = _var(ctx.attr.iterate_over),
        requires_not_none = _var(ctx.attr.requires_not_none),
        requires_none = _var(ctx.attr.requires_none),
        requires_true = _var(ctx.attr.requires_true),
        requires_false = _var(ctx.attr.requires_false),
        requires_equal = _var(ctx.attr.requires_equal),
        requires_equal_value = ctx.attr.requires_equal_value,
    )

def raw_string(s):
    """Constructs metadata for creating a raw string.

    Args:
        s: (str) The string to input.
    Returns:
        Metadata suitable for format_variable.
    """
    return struct(format_type = "raw", format = s)

def format_string_indexes(s, fail = fail):
    """Returns the position of the '%' of every unescaped '%s' in a string.

    '%%' is an escaped percent sign and contributes no index. Calls `fail` in
    three cases: '%{' (the legacy variable syntax), '%' followed by any other
    character, or a trailing lone '%'.

    Args:
        s: (str) The string
        fail: The fail function. Used for tests

    Returns:
        List[int] The indexes of the '%s' in the string
    """
    indexes = []
    escaped = False
    for i in range(len(s)):
        if not escaped and s[i] == "%":
            escaped = True
        elif escaped:
            if s[i] == "{":
                fail('Using the old mechanism for variables, %%{variable}, but we instead use format_arg("--foo=%%s", "//cc/toolchains/variables:<variable>"). Got %r' % s)
            elif s[i] == "s":
                # Record the position of the '%', not the 's'.
                indexes.append(i - 1)
            elif s[i] != "%":
                fail(_NOT_ESCAPED_FMT % s)
            escaped = False
    if escaped:
        return fail(_NOT_ESCAPED_FMT % s)
    return indexes

def format_variable(arg, iterate_over, fail = fail):
    """Converts one formatted argument into a legacy flag_group flag string.

    Eg:
        format_variable(raw_string("--foo"), None) => "--foo"
        format_variable(<format_arg "--bar=%s" whose value is the VariableInfo
                         named "bar">, None) => "--bar=%{bar}"

    Args:
        arg: [Formatted] A single command-line argument, created by raw_string
            or by format_arg. For format_arg, `value` is either None or a
            VariableInfo-like value with a `name` field.
        iterate_over: (Optional[str]) The name of the variable we're iterating
            over. Substituted when a format_arg has no value of its own.
        fail: The fail function. Used for tests

    Returns:
        A string defined to be compatible with flag groups.
    """
    indexes = format_string_indexes(arg.format, fail = fail)
    if arg.format_type == "raw":
        if indexes:
            return fail("Can't use %s with a raw string. Either escape it with %%s or use format_arg, like the following examples:" + _EXAMPLE)
        return arg.format
    else:
        if len(indexes) == 0:
            return fail('format_arg requires a "%%s" in the format string, but got %r' % arg.format)
        elif len(indexes) > 1:
            return fail("Only one %%s can be used in a format string, but got %r" % arg.format)

        if arg.value == None:
            if iterate_over == None:
                return fail("format_arg requires either a variable to format, or iterate_over must be provided. For example:" + _EXAMPLE)
            var = iterate_over
        else:
            var = arg.value.name

        # Replace the two-character "%s" with the legacy "%{var}" syntax.
        index = indexes[0]
        return arg.format[:index] + "%{" + var + "}" + arg.format[index + 2:]

def nested_args_provider(
        *,
        label,
        args = [],
        nested = [],
        files = depset([]),
        iterate_over = None,
        requires_not_none = None,
        requires_none = None,
        requires_true = None,
        requires_false = None,
        requires_equal = None,
        requires_equal_value = "",
        fail = fail):
    """Creates a validated NestedArgsInfo.

    Does not validate types, as you can't know the type of a variable until
    you have a cc_args wrapping it, because the outer layers can change that
    type using iterate_over. Instead, the type requirements are recorded in
    `requires_types`.

    Args:
        label: (Label) The context we are currently evaluating in. Used for
            error messages.
        args: (List[Formatted]) The command-line arguments to add. Each one is
            created by raw_string or format_arg. Each format_arg's `value` is
            None or has a `name` field.
        nested: (List[NestedArgsInfo]) command-line arguments to expand.
        files: (depset[File]) Files required for this set of command-line args.
        iterate_over: (Optional[str]) Variable to iterate over
        requires_not_none: (Optional[str]) If provided, this NestedArgsInfo will
            be ignored if the variable is None
        requires_none: (Optional[str]) If provided, this NestedArgsInfo will
            be ignored if the variable is not None
        requires_true: (Optional[str]) If provided, this NestedArgsInfo will
            be ignored if the variable is false
        requires_false: (Optional[str]) If provided, this NestedArgsInfo will
            be ignored if the variable is true
        requires_equal: (Optional[str]) If provided, this NestedArgsInfo will
            be ignored if the variable is not equal to requires_equal_value.
        requires_equal_value: (str) The value to compare the requires_equal
            variable with
        fail: A fail function. Use only for testing.
    Returns:
        NestedArgsInfo
    """
    if bool(args) == bool(nested):
        fail("Exactly one of args and nested must be provided")

    transitive_files = [ea.files for ea in nested]
    transitive_files.append(files)

    has_value = [var for var in [
        requires_not_none,
        requires_none,
        requires_true,
        requires_false,
        requires_equal,
    ] if var != None]

    # We may want to reconsider this down the line, but it's easier to open up
    # an API than to lock down an API.
    if len(has_value) > 1:
        fail(REQUIRES_MUTUALLY_EXCLUSIVE_ERR)

    kwargs = {}

    # Maps variable name -> list of type constraints to check once the
    # variable's type is known by an outer layer.
    requires_types = {}
    if nested:
        kwargs["flag_groups"] = [ea.legacy_flag_group for ea in nested]

    # Variables whose option wrapper should be unwrapped before evaluating them.
    # requires_none inspects the option itself, so it is not unwrapped.
    unwrap_options = []

    if iterate_over:
        kwargs["iterate_over"] = iterate_over

    if requires_not_none:
        kwargs["expand_if_available"] = requires_not_none
        requires_types.setdefault(requires_not_none, []).append(struct(
            msg = REQUIRES_NOT_NONE_ERR,
            valid_types = ["option"],
            after_option_unwrap = False,
        ))
        unwrap_options.append(requires_not_none)
    elif requires_none:
        kwargs["expand_if_not_available"] = requires_none
        requires_types.setdefault(requires_none, []).append(struct(
            msg = REQUIRES_NONE_ERR,
            valid_types = ["option"],
            after_option_unwrap = False,
        ))
    elif requires_true:
        kwargs["expand_if_true"] = requires_true
        requires_types.setdefault(requires_true, []).append(struct(
            msg = REQUIRES_TRUE_ERR,
            valid_types = ["bool"],
            after_option_unwrap = True,
        ))
        unwrap_options.append(requires_true)
    elif requires_false:
        kwargs["expand_if_false"] = requires_false
        requires_types.setdefault(requires_false, []).append(struct(
            msg = REQUIRES_FALSE_ERR,
            valid_types = ["bool"],
            after_option_unwrap = True,
        ))
        unwrap_options.append(requires_false)
    elif requires_equal:
        if not requires_equal_value:
            fail(REQUIRES_EQUAL_VALUE_ERR)
        kwargs["expand_if_equal"] = variable_with_value(
            name = requires_equal,
            value = requires_equal_value,
        )
        unwrap_options.append(requires_equal)
        requires_types.setdefault(requires_equal, []).append(struct(
            msg = REQUIRES_EQUAL_ERR,
            valid_types = ["string"],
            after_option_unwrap = True,
        ))

    for arg in args:
        if arg.format_type != "raw":
            # A format_arg without its own value formats the iterated variable.
            var_name = arg.value.name if arg.value != None else iterate_over
            requires_types.setdefault(var_name, []).append(struct(
                msg = FORMAT_ARGS_ERR,
                valid_types = ["string", "file", "directory"],
                after_option_unwrap = True,
            ))

    if args:
        kwargs["flags"] = [
            format_variable(arg, iterate_over = iterate_over, fail = fail)
            for arg in args
        ]

    return NestedArgsInfo(
        label = label,
        nested = nested,
        files = depset(transitive = transitive_files),
        iterate_over = iterate_over,
        unwrap_options = unwrap_options,
        requires_types = requires_types,
        legacy_flag_group = flag_group(**kwargs),
    )