xref: /aosp_15_r20/external/bazelbuild-rules_rust/proto/prost/private/prost.bzl (revision d4726bddaa87cc4778e7472feed243fa4b6c267f)
1"""Rules for building protos in Rust with Prost and Tonic."""
2
3load("@rules_proto//proto:defs.bzl", "ProtoInfo", "proto_common")
4load("@rules_proto//proto:proto_common.bzl", proto_toolchains = "toolchains")
5load("//proto/prost:providers.bzl", "ProstProtoInfo")
6load("//rust:defs.bzl", "rust_common")
7
8# buildifier: disable=bzl-visibility
9load("//rust/private:providers.bzl", "RustAnalyzerGroupInfo", "RustAnalyzerInfo")
10
11# buildifier: disable=bzl-visibility
12load("//rust/private:rust.bzl", "RUSTC_ATTRS")
13
14# buildifier: disable=bzl-visibility
15load("//rust/private:rust_analyzer.bzl", "write_rust_analyzer_spec_file")
16
17# buildifier: disable=bzl-visibility
18load("//rust/private:rustc.bzl", "rustc_compile_action")
19
20# buildifier: disable=bzl-visibility
21load("//rust/private:utils.bzl", "can_build_metadata")
22
# Rust edition used when compiling the crates generated from protos by this file.
RUST_EDITION = "2021"

# Toolchain type of the Prost toolchain; requested by `rust_prost_aspect` and
# `current_prost_runtime`.
TOOLCHAIN_TYPE = "@rules_rust//proto/prost:toolchain_type"
26
def _create_proto_lang_toolchain(ctx, prost_toolchain):
    """Builds the `ProtoLangToolchainInfo` used to drive protoc for Prost.

    The "proto compiler" handed to `proto_common.compile` is the
    `_prost_process_wrapper` executable rather than protoc itself; the Prost
    plugin and its flag format come from the Prost toolchain.

    Args:
        ctx (ctx): The aspect context; must expose `attr._prost_process_wrapper`.
        prost_toolchain (ToolchainInfo): The resolved Prost toolchain.

    Returns:
        ProtoLangToolchainInfo: Toolchain info for `proto_common.compile`.
    """
    plugin_files = prost_toolchain.prost_plugin[DefaultInfo].files_to_run
    wrapper_files = ctx.attr._prost_process_wrapper[DefaultInfo].files_to_run

    return proto_common.ProtoLangToolchainInfo(
        mnemonic = "ProstGenProto",
        progress_message = "ProstGenProto %{label}",
        proto_compiler = wrapper_files,
        protoc_opts = prost_toolchain.protoc_opts,
        plugin = plugin_files,
        plugin_format_flag = prost_toolchain.prost_plugin_flag,
        out_replacement_format_flag = "--prost_out=%s",
        runtime = prost_toolchain.prost_runtime,
        provided_proto_sources = depset(),
    )
40    return proto_lang_toolchain
41
def _compile_proto(ctx, crate_name, proto_info, deps, prost_toolchain, rustfmt_toolchain = None):
    """Generates Rust sources for a proto target by running protoc through the Prost wrapper.

    Args:
        ctx (ctx): The aspect context.
        crate_name (str): Name of the crate the generated sources belong to; written
            into the package info output together with its path.
        proto_info (ProtoInfo): The `ProtoInfo` of the target being compiled.
        deps (list[Target]): Direct proto dependencies; each must provide `ProstProtoInfo`.
        prost_toolchain (ToolchainInfo): The resolved Prost toolchain.
        rustfmt_toolchain (ToolchainInfo, optional): When set, its `rustfmt` binary is
            passed to the wrapper and its files are added as action tools.

    Returns:
        tuple: `(lib_rs, package_info_file)` - the generated crate root and the
        package info file declared for this target.
    """

    # Package info files of the direct proto deps. Their paths are listed in
    # `deps_info_file` (passed via `--deps_info`) and the files themselves are
    # added as action inputs below, so build the list only once.
    dep_package_infos = [dep[ProstProtoInfo].package_info for dep in deps]

    deps_info_file = ctx.actions.declare_file(ctx.label.name + ".prost_deps_info")
    ctx.actions.write(
        output = deps_info_file,
        content = "\n".join([info.path for info in dep_package_infos]),
    )

    package_info_file = ctx.actions.declare_file(ctx.label.name + ".prost_package_info")
    lib_rs = ctx.actions.declare_file("{}.lib.rs".format(ctx.label.name))

    proto_compiler = prost_toolchain.proto_compiler
    tools = depset([proto_compiler.executable])

    additional_args = ctx.actions.args()

    # Prost process wrapper specific args
    additional_args.add("--protoc={}".format(proto_compiler.executable.path))
    additional_args.add("--label={}".format(ctx.label))
    additional_args.add("--out_librs={}".format(lib_rs.path))
    additional_args.add("--package_info_output={}={}".format(crate_name, package_info_file.path))
    additional_args.add("--deps_info={}".format(deps_info_file.path))
    additional_args.add("--prost_opt=compile_well_known_types")
    additional_args.add("--descriptor_set={}".format(proto_info.direct_descriptor_set.path))
    additional_args.add_all(prost_toolchain.prost_opts, format_each = "--prost_opt=%s")

    # Tonic is optional: only wire up its plugin and options when the toolchain has one.
    if prost_toolchain.tonic_plugin:
        tonic_plugin = prost_toolchain.tonic_plugin[DefaultInfo].files_to_run
        additional_args.add(prost_toolchain.tonic_plugin_flag % tonic_plugin.executable.path)
        additional_args.add("--tonic_opt=no_include")
        additional_args.add("--tonic_opt=compile_well_known_types")
        additional_args.add("--is_tonic")
        additional_args.add_all(prost_toolchain.tonic_opts, format_each = "--tonic_opt=%s")
        tools = depset([tonic_plugin.executable], transitive = [tools])

    # Formatting is optional as well; all rustfmt toolchain files become tools.
    if rustfmt_toolchain:
        additional_args.add("--rustfmt={}".format(rustfmt_toolchain.rustfmt.path))
        tools = depset(transitive = [tools, rustfmt_toolchain.all_files])

    additional_inputs = depset([deps_info_file, proto_info.direct_descriptor_set] + dep_package_infos)

    proto_common.compile(
        actions = ctx.actions,
        proto_info = proto_info,
        additional_tools = tools.to_list(),
        additional_inputs = additional_inputs,
        additional_args = additional_args,
        generated_files = [lib_rs, package_info_file],
        proto_lang_toolchain_info = _create_proto_lang_toolchain(ctx, prost_toolchain),
        plugin_output = ctx.bin_dir.path,
    )

    return lib_rs, package_info_file
94    return lib_rs, package_info_file
95
def _get_crate_info(providers):
    """Returns the CrateInfo provider from a list of providers.

    A provider is treated as CrateInfo when it has a `name` field; the first
    such provider in `providers` is returned.

    Args:
        providers (list): Providers, e.g. those returned by `rustc_compile_action`.

    Returns:
        The first provider exposing a `name` field. Fails if there is none.
    """
    matches = [p for p in providers if hasattr(p, "name")]
    if not matches:
        fail("Couldn't find a CrateInfo in the list of providers")
    return matches[0]
101    fail("Couldn't find a CrateInfo in the list of providers")
102
def _get_dep_info(providers):
    """Returns the DepInfo provider from a list of providers.

    A provider is treated as DepInfo when it has a `direct_crates` field; the
    first such provider in `providers` is returned.

    Args:
        providers (list): Providers, e.g. those returned by `rustc_compile_action`.

    Returns:
        The first provider exposing `direct_crates`. Fails if there is none.
    """
    matches = [p for p in providers if hasattr(p, "direct_crates")]
    if not matches:
        fail("Couldn't find a DepInfo in the list of providers")
    return matches[0]
108    fail("Couldn't find a DepInfo in the list of providers")
109
def _get_cc_info(providers):
    """Returns the CcInfo provider from a list of providers.

    A provider is treated as CcInfo when it has a `linking_context` field; the
    first such provider in `providers` is returned.

    Args:
        providers (list): Providers, e.g. those returned by `rustc_compile_action`.

    Returns:
        The first provider exposing `linking_context`. Fails if there is none.
    """
    matches = [p for p in providers if hasattr(p, "linking_context")]
    if not matches:
        fail("Couldn't find a CcInfo in the list of providers")
    return matches[0]
115    fail("Couldn't find a CcInfo in the list of providers")
116
def _compile_rust(ctx, attr, crate_name, src, deps, edition):
    """Compiles a Rust source file.

    Args:
      ctx (RuleContext): The rule context.
      attr (Attrs): The current rule's attributes (`ctx.attr` for rules, `ctx.rule.attr` for aspects)
      crate_name (str): The crate module name to use.
      src (File): The crate root source file to be compiled.
      deps (List of DepVariantInfo): A list of dependencies needed.
      edition (str): The Rust edition to use.

    Returns:
      A DepVariantInfo provider.
    """
    toolchain = ctx.toolchains["@rules_rust//rust:toolchain_type"]

    # Derived from the source file's path; embedded in the output file names
    # below and forwarded to rustc as `output_hash`.
    output_hash = repr(hash(src.path + ".prost"))

    lib = ctx.actions.declare_file("lib{name}-{lib_hash}.rlib".format(
        name = crate_name,
        lib_hash = output_hash,
    ))

    # A metadata (.rmeta) output is only declared when `can_build_metadata`
    # allows it for rlibs; otherwise `metadata` is passed as None.
    rmeta = None
    if can_build_metadata(toolchain, ctx, "rlib"):
        rmeta = ctx.actions.declare_file("lib{name}-{lib_hash}.rmeta".format(
            name = crate_name,
            lib_hash = output_hash,
        ))

    providers = rustc_compile_action(
        ctx = ctx,
        attr = attr,
        toolchain = toolchain,
        crate_info_dict = dict(
            name = crate_name,
            type = "rlib",
            root = src,
            srcs = depset([src]),
            deps = depset(deps),
            proc_macro_deps = depset([]),
            aliases = {},
            output = lib,
            metadata = rmeta,
            edition = edition,
            is_test = False,
            rustc_env = {},
            compile_data = depset([]),
            compile_data_targets = depset([]),
            owner = ctx.label,
        ),
        output_hash = output_hash,
    )

    return rust_common.dep_variant_info(
        crate_info = _get_crate_info(providers),
        dep_info = _get_dep_info(providers),
        cc_info = _get_cc_info(providers),
        build_info = None,
    )
193    )
194
def _collect_runtime_dep_variant_infos(prost_toolchain):
    """Collects DepVariantInfos for the Prost runtime and the optional Tonic runtime.

    Args:
        prost_toolchain (ToolchainInfo): The resolved Prost toolchain.

    Returns:
        list[DepVariantInfo]: One entry per runtime crate; crate groups are flattened.
    """
    runtime_deps = []
    for runtime in [prost_toolchain.prost_runtime, prost_toolchain.tonic_runtime]:
        # Unset runtimes (e.g. no Tonic configured) are skipped.
        if not runtime:
            continue
        if rust_common.crate_group_info in runtime:
            crate_group_info = runtime[rust_common.crate_group_info]
            runtime_deps.extend(crate_group_info.dep_variant_infos.to_list())
        else:
            runtime_deps.append(rust_common.dep_variant_info(
                crate_info = runtime[rust_common.crate_info] if rust_common.crate_info in runtime else None,
                dep_info = runtime[rust_common.dep_info] if rust_common.dep_info in runtime else None,
                cc_info = runtime[CcInfo] if CcInfo in runtime else None,
                build_info = None,
            ))
    return runtime_deps

def _rust_prost_aspect_impl(target, ctx):
    """Generates and compiles a Rust crate for the visited proto target.

    Runs protoc with the Prost (and, if configured, Tonic) plugin over the
    target's protos, compiles the generated `lib.rs` into an rlib, and writes a
    rust-analyzer spec for the resulting crate.

    Args:
        target (Target): The visited target; must provide `ProtoInfo` unless it
            already provides `ProstProtoInfo`.
        ctx (ctx): The aspect context.

    Returns:
        list: `ProstProtoInfo`, `RustAnalyzerInfo` and an `OutputGroupInfo` exposing the
        generated source as `rust_generated_srcs`; an empty list when the target
        already provides `ProstProtoInfo`.
    """
    if ProstProtoInfo in target:
        return []

    rustfmt_toolchain = ctx.toolchains["@rules_rust//rust/rustfmt:toolchain_type"]
    prost_toolchain = ctx.toolchains[TOOLCHAIN_TYPE]
    runtime_deps = _collect_runtime_dep_variant_infos(prost_toolchain)

    # Every proto dependency is expected to carry ProstProtoInfo, either from
    # this aspect (which propagates along `deps`) or provided directly.
    proto_deps = getattr(ctx.rule.attr, "deps", [])

    direct_deps = []
    transitive_deps = [depset(runtime_deps)]
    rust_analyzer_deps = []
    for proto_dep in proto_deps:
        dep_prost_info = proto_dep[ProstProtoInfo]

        direct_deps.append(dep_prost_info.dep_variant_info)
        transitive_deps.append(depset(
            [dep_prost_info.dep_variant_info],
            transitive = [dep_prost_info.transitive_dep_infos],
        ))

        if RustAnalyzerInfo in proto_dep:
            rust_analyzer_deps.append(proto_dep[RustAnalyzerInfo])

    deps = runtime_deps + direct_deps

    # Rust crate names cannot contain `-` or `/`, which label names may.
    crate_name = ctx.label.name.replace("-", "_").replace("/", "_")

    proto_info = target[ProtoInfo]

    lib_rs, package_info_file = _compile_proto(
        ctx = ctx,
        crate_name = crate_name,
        proto_info = proto_info,
        deps = proto_deps,
        prost_toolchain = prost_toolchain,
        rustfmt_toolchain = rustfmt_toolchain,
    )

    dep_variant_info = _compile_rust(
        ctx = ctx,
        attr = ctx.rule.attr,
        crate_name = crate_name,
        src = lib_rs,
        deps = deps,
        edition = RUST_EDITION,
    )

    # Always add `test` & `debug_assertions`. See rust-analyzer source code:
    # https://github.com/rust-analyzer/rust-analyzer/blob/2021-11-15/crates/project_model/src/workspace.rs#L529-L531
    cfgs = ["test", "debug_assertions"]

    rust_analyzer_info = write_rust_analyzer_spec_file(ctx, ctx.rule.attr, ctx.label, RustAnalyzerInfo(
        aliases = {},
        crate = dep_variant_info.crate_info,
        cfgs = cfgs,
        env = dep_variant_info.crate_info.rustc_env,
        deps = rust_analyzer_deps,
        crate_specs = depset(transitive = [dep.crate_specs for dep in rust_analyzer_deps]),
        proc_macro_dylib_path = None,
        build_info = dep_variant_info.build_info,
    ))

    return [
        ProstProtoInfo(
            dep_variant_info = dep_variant_info,
            transitive_dep_infos = depset(transitive = transitive_deps),
            package_info = package_info_file,
        ),
        rust_analyzer_info,
        OutputGroupInfo(rust_generated_srcs = [lib_rs]),
    ]
281    ]
282
rust_prost_aspect = aspect(
    implementation = _rust_prost_aspect_impl,
    doc = "An aspect used to generate and compile proto files with Prost.",
    # Propagate through `deps` so each transitive proto dependency is visited too.
    attr_aspects = ["deps"],
    # Private attributes of this aspect, merged with RUSTC_ATTRS (presumably
    # required by `rustc_compile_action`); RUSTC_ATTRS entries win on conflict.
    attrs = dict({
        "_collect_cc_coverage": attr.label(
            default = Label("//util:collect_coverage"),
            cfg = "exec",
            executable = True,
        ),
        "_grep_includes": attr.label(
            default = Label("@bazel_tools//tools/cpp:grep-includes"),
            allow_single_file = True,
            cfg = "exec",
        ),
        # Used as the `proto_compiler` of the ProtoLangToolchainInfo built in
        # `_create_proto_lang_toolchain`.
        "_prost_process_wrapper": attr.label(
            doc = "The wrapper script for the Prost protoc plugin.",
            default = Label("//proto/prost/private:protoc_wrapper"),
            cfg = "exec",
            executable = True,
        ),
    }, **RUSTC_ATTRS),
    fragments = ["cpp"],
    toolchains = [
        TOOLCHAIN_TYPE,
        "@bazel_tools//tools/cpp:toolchain_type",
        "@rules_rust//rust:toolchain_type",
        "@rules_rust//rust/rustfmt:toolchain_type",
    ],
)
312)
313
def _rust_prost_library_impl(ctx):
    """Re-exports the crate that `rust_prost_aspect` generated for `ctx.attr.proto`.

    Returns:
        list: `DefaultInfo` with the compiled crate output, a crate group holding the
        generated crate plus its transitive dependencies, and a
        `RustAnalyzerGroupInfo` wrapping the proto target's `RustAnalyzerInfo`.
    """
    target = ctx.attr.proto
    prost_info = target[ProstProtoInfo]
    variant = prost_info.dep_variant_info

    all_variants = depset([variant], transitive = [prost_info.transitive_dep_infos])
    crate_outputs = depset([variant.crate_info.output])

    return [
        DefaultInfo(files = crate_outputs),
        rust_common.crate_group_info(dep_variant_infos = all_variants),
        RustAnalyzerGroupInfo(deps = [target[RustAnalyzerInfo]]),
    ]
328    ]
329
rust_prost_library = rule(
    implementation = _rust_prost_library_impl,
    doc = "A rule for generating a Rust library using Prost.",
    attrs = {
        # Code generation and compilation happen in `rust_prost_aspect`, which is
        # attached to this attribute; the rule forwards the aspect's results.
        "proto": attr.label(
            doc = "A `proto_library` target for which to generate Rust gencode.",
            mandatory = True,
            providers = [ProtoInfo],
            aspects = [rust_prost_aspect],
        ),
        "_collect_cc_coverage": attr.label(
            default = Label("@rules_rust//util:collect_coverage"),
            cfg = "exec",
            executable = True,
        ),
    },
)
346)
347
def _rust_prost_toolchain_impl(ctx):
    """Validates the Prost/Tonic attributes and returns the Prost `ToolchainInfo`.

    Fails when only one of `tonic_plugin` / `tonic_runtime` is set, or when Tonic
    is enabled while `tonic_plugin_flag` is empty.

    Returns:
        list: A single `platform_common.ToolchainInfo` carrying the Prost and Tonic
        plugins, runtimes and options, plus the proto compiler to use.
    """

    # `tonic_plugin_flag` has a non-empty default, so it is truthy unless a user
    # clears it. Including it in an any()/all() check made every toolchain
    # without Tonic fail. Tonic is enabled by its plugin and runtime, which must
    # be set together.
    has_tonic_plugin = ctx.attr.tonic_plugin != None
    has_tonic_runtime = ctx.attr.tonic_runtime != None
    if has_tonic_plugin != has_tonic_runtime:
        fail("When one tonic attribute is added, all must be added")

    # `_compile_proto` %-formats this flag with the plugin path, so it must not be
    # empty when Tonic is enabled.
    if has_tonic_plugin and not ctx.attr.tonic_plugin_flag:
        fail("When one tonic attribute is added, all must be added")

    if ctx.attr.proto_compiler:
        # buildifier: disable=print
        print("WARN: rust_prost_toolchain's proto_compiler attribute is deprecated. Make sure your rules_proto dependency is at least version 6.0.0 and stop setting proto_compiler")

    proto_toolchain = proto_toolchains.find_toolchain(
        ctx,
        legacy_attr = "_legacy_proto_toolchain",
        toolchain_type = "@rules_proto//proto:toolchain_type",
    )

    # `_compile_proto` reads `proto_compiler.executable`, which exists on a
    # FilesToRunProvider but not on a Target, so unwrap the deprecated attribute.
    if ctx.attr.proto_compiler:
        proto_compiler = ctx.attr.proto_compiler[DefaultInfo].files_to_run
    else:
        proto_compiler = proto_toolchain.proto_compiler

    return [platform_common.ToolchainInfo(
        prost_opts = ctx.attr.prost_opts,
        prost_plugin = ctx.attr.prost_plugin,
        prost_plugin_flag = ctx.attr.prost_plugin_flag,
        prost_runtime = ctx.attr.prost_runtime,
        prost_types = ctx.attr.prost_types,
        proto_compiler = proto_compiler,
        protoc_opts = ctx.fragments.proto.experimental_protoc_opts,
        tonic_opts = ctx.attr.tonic_opts,
        tonic_plugin = ctx.attr.tonic_plugin,
        tonic_plugin_flag = ctx.attr.tonic_plugin_flag,
        tonic_runtime = ctx.attr.tonic_runtime,
    )]
375    )]
376
rust_prost_toolchain = rule(
    implementation = _rust_prost_toolchain_impl,
    doc = "Rust Prost toolchain rule.",
    fragments = ["proto"],
    attrs = dict({
        "prost_opts": attr.string_list(
            doc = "Additional options to add to Prost.",
        ),
        "prost_plugin": attr.label(
            doc = "The Prost protoc plugin (`protoc-gen-prost`) used to generate Rust code.",
            cfg = "exec",
            executable = True,
            mandatory = True,
        ),
        "prost_plugin_flag": attr.string(
            doc = "Prost plugin flag format. (e.g. `--plugin=protoc-gen-prost=%s`)",
            default = "--plugin=protoc-gen-prost=%s",
        ),
        "prost_runtime": attr.label(
            doc = "The Prost runtime crates to use.",
            providers = [[rust_common.crate_info], [rust_common.crate_group_info]],
            mandatory = True,
        ),
        "prost_types": attr.label(
            doc = "The Prost types crates to use.",
            providers = [[rust_common.crate_info], [rust_common.crate_group_info]],
            mandatory = True,
        ),
        "proto_compiler": attr.label(
            doc = "The protoc compiler to use. Note that this attribute is deprecated - prefer to use --incompatible_enable_proto_toolchain_resolution.",
            cfg = "exec",
            executable = True,
        ),
        "tonic_opts": attr.string_list(
            doc = "Additional options to add to Tonic.",
        ),
        "tonic_plugin": attr.label(
            doc = "The Tonic protoc plugin (`protoc-gen-tonic`). Must be set together with `tonic_runtime`.",
            cfg = "exec",
            executable = True,
        ),
        "tonic_plugin_flag": attr.string(
            doc = "Tonic plugin flag format. (e.g. `--plugin=protoc-gen-tonic=%s`)",
            default = "--plugin=protoc-gen-tonic=%s",
        ),
        "tonic_runtime": attr.label(
            doc = "The Tonic runtime crates to use. Must be set together with `tonic_plugin`.",
            providers = [[rust_common.crate_info], [rust_common.crate_group_info]],
        ),
    }, **proto_toolchains.if_legacy_toolchain({
        # Wrapped in Label() like the other defaults in this file so it resolves
        # relative to this repository rather than the consuming one.
        "_legacy_proto_toolchain": attr.label(
            default = Label("//proto/protobuf:legacy_proto_toolchain"),
        ),
    })),
    toolchains = proto_toolchains.use_toolchain("@rules_proto//proto:toolchain_type"),
)
432)
433
def _current_prost_runtime_impl(ctx):
    """Exposes the Prost runtime and Prost types crates of the resolved Prost toolchain.

    Returns:
        list: A single `crate_group_info` whose depset holds one DepVariantInfo per
        runtime crate; targets that are already crate groups are flattened.
    """
    prost_toolchain = ctx.toolchains[TOOLCHAIN_TYPE]

    variants = []
    for runtime_target in (prost_toolchain.prost_runtime, prost_toolchain.prost_types):
        if rust_common.crate_group_info in runtime_target:
            group = runtime_target[rust_common.crate_group_info]
            variants.extend(group.dep_variant_infos.to_list())
            continue

        crate_info = runtime_target[rust_common.crate_info] if rust_common.crate_info in runtime_target else None
        dep_info = runtime_target[rust_common.dep_info] if rust_common.dep_info in runtime_target else None
        cc_info = runtime_target[CcInfo] if CcInfo in runtime_target else None
        variants.append(rust_common.dep_variant_info(
            crate_info = crate_info,
            dep_info = dep_info,
            cc_info = cc_info,
            build_info = None,
        ))

    return [rust_common.crate_group_info(dep_variant_infos = depset(variants))]
453    )]
454
# Exposes the Prost runtime crates of the resolved toolchain as a crate group.
current_prost_runtime = rule(
    implementation = _current_prost_runtime_impl,
    doc = "A rule for accessing the current Prost toolchain components needed by the process wrapper",
    toolchains = [TOOLCHAIN_TYPE],
    provides = [rust_common.crate_group_info],
)
460)
461