# xref: /aosp_15_r20/external/tensorflow/third_party/systemlibs/grpc.bazel.protobuf.bzl (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1"""Utility functions for generating protobuf code."""
2
# Suffix every input proto file name must end with; _strip_proto_extension
# fails on anything else.
_PROTO_EXTENSION = ".proto"

# Path segment marking files placed under a _virtual_imports directory, which
# is where proto_library targets using import_prefix and/or
# strip_import_prefix put their outputs (see is_in_virtual_imports).
_VIRTUAL_IMPORTS = "/_virtual_imports/"
5
def well_known_proto_libs():
    """Lists the @com_google_protobuf proto_library labels treated as well-known.

    Returns:
      A list of twelve labels of the form "@com_google_protobuf//:<name>_proto",
      in a fixed order.
    """
    short_names = [
        "any",
        "api",
        "compiler_plugin",
        "descriptor",
        "duration",
        "empty",
        "field_mask",
        "source_context",
        "struct",
        "timestamp",
        "type",
        "wrappers",
    ]
    return ["@com_google_protobuf//:{}_proto".format(name) for name in short_names]
21
def get_proto_root(workspace_root):
    """Computes the root protobuf directory for a workspace.

    Args:
      workspace_root: context.label.workspace_root

    Returns:
      "" when workspace_root is empty/falsy; otherwise workspace_root with a
      leading "/" prepended. Generated include paths are relative to this.
    """
    return "/{}".format(workspace_root) if workspace_root else ""
35
def _strip_proto_extension(proto_filename):
    """Returns proto_filename minus its trailing ".proto".

    Calls fail() when proto_filename does not end with ".proto".
    """
    suffix_len = len(_PROTO_EXTENSION)
    if proto_filename.endswith(_PROTO_EXTENSION):
        return proto_filename[:-suffix_len]
    fail('"{}" does not end with "{}"'.format(
        proto_filename,
        _PROTO_EXTENSION,
    ))
43
def proto_path_to_generated_filename(proto_path, fmt_str):
    """Calculates the name of a generated file for a protobuf path.

    The ".proto" suffix is removed and the remainder, including any directory
    components, is substituted into fmt_str. For example, with fmt_str
    "{}.pb.h", "examples/protos/helloworld.proto" maps to
    "examples/protos/helloworld.pb.h"; pass just the basename
    ("helloworld.proto") to get "helloworld.pb.h".

    Args:
      proto_path: The path to the .proto file. Must end with ".proto";
        otherwise the build fails via _strip_proto_extension.
      fmt_str: A format string whose positional "{}" placeholder receives the
        extension-less path. For example, "{}.pb.h" might be used to calculate
        a C++ header filename.

    Returns:
      The generated filename.
    """
    return fmt_str.format(_strip_proto_extension(proto_path))
59
def get_include_directory(source_file):
    """Returns the include directory path for the source_file.

    I.e. all of the include statements within the given source_file
    are calculated relative to the directory returned by this method.

    The returned directory path can be used as the "--proto_path=" argument
    value.

    Three cases are handled:
      * Files under _virtual_imports: the path up to and including the first
        directory below "/_virtual_imports/".
      * Files whose path (after the output root, for generated files) starts
        with "external/<repo>/": the path up to and including "<repo>".
      * Anything else: source_file.root.path, or "." when that is empty.

    Args:
      source_file: A proto file.

    Returns:
      The include directory path for the source_file.
    """
    path = source_file.path

    if is_in_virtual_imports(source_file):
        # Split on the first marker only. With maxsplit=2 a path containing the
        # marker twice yields three parts and the two-name unpacking fails.
        root, relative = path.split(_VIRTUAL_IMPORTS, 1)
        return root + _VIRTUAL_IMPORTS + relative.split("/", 1)[0]

    root_path = source_file.root.path
    prefix_len = 0
    if not source_file.is_source and path.startswith(root_path):
        # Generated files sit under their output root; inspect what follows it.
        prefix_len = len(root_path) + 1

    # Require the full "external/" component: a bare "external" prefix would
    # also match e.g. "external_lib/foo.proto" and mis-slice the path.
    if path.startswith("external/", prefix_len):
        external_separator = path.find("/", prefix_len)
        repository_separator = path.find("/", external_separator + 1)
        # Without a separator after the repo name, slicing with -1 would just
        # chop the last character; fall back to the root instead.
        if repository_separator != -1:
            return path[:repository_separator]

    return root_path if root_path else "."
def get_plugin_args(
        plugin,
        flags,
        dir_out,
        generate_mocks,
        plugin_name = "PLUGIN"):
    """Builds the protoc arguments that register and invoke a plugin.

    Args:
      plugin: An executable file to run as the protoc plugin; only its `path`
        is read.
      flags: The plugin flags to be passed to protoc. Not modified.
      dir_out: The output directory for the plugin.
      generate_mocks: When truthy, "generate_mock_code=true" is appended to
        the flags.
      plugin_name: A name of the plugin, it is required to be unique when there
        are more than one plugin used in a single protoc command.

    Returns:
      A two-element list:
        "--plugin=protoc-gen-<plugin_name>=<plugin.path>" and
        "--<plugin_name>_out=<dir_out>", where dir_out is prefixed by
        "<comma-joined flags>:" whenever there is at least one flag.
    """
    mock_flags = ["generate_mock_code=true"] if generate_mocks else []
    all_flags = list(flags) + mock_flags

    if all_flags:
        out_value = ",".join(all_flags) + ":" + dir_out
    else:
        out_value = dir_out

    plugin_arg = "--plugin=protoc-gen-%s=%s" % (plugin_name, plugin.path)
    out_arg = "--%s_out=%s" % (plugin_name, out_value)
    return [plugin_arg, out_arg]
129
def _get_staged_proto_file(context, source_file):
    """Returns source_file itself, or a copy of it made by a new action.

    The file is used unchanged when its dirname equals context.label.package
    or when it is under _virtual_imports. Otherwise a file with the same
    basename is declared and a "CopySourceProto" shell action copies
    source_file into it.
    """
    already_in_package = source_file.dirname == context.label.package
    if already_in_package or is_in_virtual_imports(source_file):
        return source_file

    staged = context.actions.declare_file(source_file.basename)
    # NOTE(review): both paths are interpolated into the shell command
    # unquoted; a path containing spaces or shell metacharacters would break
    # the copy.
    context.actions.run_shell(
        inputs = [source_file],
        outputs = [staged],
        command = "cp {} {}".format(source_file.path, staged.path),
        mnemonic = "CopySourceProto",
    )
    return staged
143
def protos_from_context(context):
    """Stages the direct proto sources of every dependency of the rule.

    Each file in each dep's ProtoInfo.direct_sources goes through
    _get_staged_proto_file, which may register a copy action for it.

    Args:
      context: The ctx object for the rule.

    Returns:
      A list of the (possibly copied) proto files, ordered by dependency and
      then by position within each dependency's direct_sources.
    """
    return [
        _get_staged_proto_file(context, proto_file)
        for dep in context.attr.deps
        for proto_file in dep[ProtoInfo].direct_sources
    ]
158
def includes_from_deps(deps):
    """Concatenates the transitive proto imports of all deps.

    Args:
      deps: Targets that provide ProtoInfo.

    Returns:
      A list holding, dep by dep, every file of that dep's
      ProtoInfo.transitive_imports. Files shared by several deps appear once
      per dep; no de-duplication is done.
    """
    includes = []
    for dep in deps:
        includes.extend(dep[ProtoInfo].transitive_imports.to_list())
    return includes
166
def get_proto_arguments(protos, genfiles_dir_path):
    """Get the protoc arguments specifying which protos to compile.

    Each proto path has one leading directory removed when it lies inside it:
    the include directory from get_include_directory() for virtual-import
    files, or genfiles_dir_path for all other files. Paths outside that
    directory are passed through unchanged.

    Args:
      protos: The proto files to compile.
      genfiles_dir_path: Path of the genfiles output directory.

    Returns:
      A list of proto paths, one per element of protos, in the same order.
    """
    arguments = []
    for proto in protos:
        path = proto.path
        if is_in_virtual_imports(proto):
            prefix = get_include_directory(proto)
        else:
            prefix = genfiles_dir_path

        # Strip only on a path-component boundary. A bare startswith() would
        # also strip "bin" from "bin2/x.proto" (leaving "2/x.proto"), and an
        # empty prefix would silently drop the first character of the path.
        if prefix and path.startswith(prefix + "/"):
            path = path[len(prefix) + 1:]

        arguments.append(path)

    return arguments
182
def declare_out_files(protos, context, generated_file_format):
    """Declares and returns the files to be generated.

    Ordinary protos produce an output named after the proto's basename.
    Virtual-import protos keep their path from "_virtual_imports/" onward.
    In both cases the ".proto" suffix is replaced via generated_file_format.

    Args:
      protos: The proto files being compiled.
      context: The ctx object for the rule.
      generated_file_format: Format string passed to
        proto_path_to_generated_filename, e.g. "{}.pb.h".

    Returns:
      A list of declared files, one per proto, in the same order.
    """
    declared = []
    for proto in protos:
        if is_in_virtual_imports(proto):
            marker_start = proto.path.index(_VIRTUAL_IMPORTS)
            relative_path = proto.path[marker_start + 1:]
        else:
            relative_path = proto.basename

        generated_name = proto_path_to_generated_filename(
            relative_path,
            generated_file_format,
        )
        declared.append(context.actions.declare_file(generated_name))

    return declared
203
def get_out_dir(protos, context):
    """ Returns the calculated value for --<lang>_out= protoc argument based on
    the input source proto files and current context.

    Args:
        protos: A list of protos to be used as source files in protoc command.
            Either all or none of them must be virtual imports.
        context: A ctx object for the rule.
    Returns:
        A struct with:
          path: The value of --<lang>_out= argument. This is
            context.genfiles_dir.path when no proto is a virtual import
            (including an empty protos list).
          import_path: For virtual imports, the part of `path` starting at
            "_virtual_imports/"; otherwise None.
    """
    virtual_count = len([p for p in protos if is_in_virtual_imports(p)])

    if virtual_count == 0:
        return struct(path = context.genfiles_dir.path, import_path = None)

    # Check the mix regardless of order, so a real proto listed before a
    # virtual one is rejected too rather than used to compute the out dir.
    if virtual_count != len(protos):
        fail("Proto sources must be either all virtual imports or all real")

    out_dir = get_include_directory(protos[0])
    ws_root = protos[0].owner.workspace_root
    if ws_root and ws_root in out_dir:
        # Drop the last occurrence of the workspace root from the path.
        # NOTE(review): when ws_root sits between two "/" this leaves a
        # doubled "/" in out_dir; confirm that is intended.
        out_dir = "".join(out_dir.rsplit(ws_root, 1))

    return struct(
        path = out_dir,
        import_path = out_dir[out_dir.find(_VIRTUAL_IMPORTS) + 1:],
    )
230
def is_in_virtual_imports(source_file, virtual_folder = _VIRTUAL_IMPORTS):
    """Reports whether source_file is a generated file under _virtual_imports.

    proto_library targets that use import_prefix and/or strip_import_prefix
    place their outputs under a _virtual_imports directory.

    Args:
        source_file: A proto file.
        virtual_folder: Path segment to look for (defaults to
            "/_virtual_imports/").
    Returns:
        False for source files; for generated files, whether virtual_folder
        occurs anywhere in source_file.path.
    """
    if source_file.is_source:
        return False
    return virtual_folder in source_file.path
245