xref: /aosp_15_r20/external/pytorch/tools/lite_interpreter/gen_selected_mobile_ops_header.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3from __future__ import annotations
4
5import argparse
6import os
7
8import yaml
9
10from torchgen.code_template import CodeTemplate
11from torchgen.selective_build.selector import SelectiveBuilder
12
13
14# Safely load fast C Yaml loader/dumper if they are available
15try:
16    from yaml import CSafeLoader as Loader
17except ImportError:
18    from yaml import SafeLoader as Loader  # type: ignore[assignment, misc]
19
20
# Template for one branch of the generated should_include_kernel_dtype() body.
#   $kernel_tag_name: a kernel tag key from the selector's kernel_metadata.
#   $dtype_checks:    a " || "-joined list of
#                     `scalar_type == at::ScalarType::<dtype>` comparisons,
#                     built in get_selected_kernel_dtypes_code().
# Branches are chained with " else " by that function.
if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {
  return $dtype_checks;
}"""
if_condition_template = CodeTemplate(if_condition_template_str)
25
# Template for the C++ at::should_include_kernel_dtype() function written into
# selected_mobile_ops.h. Given a kernel tag string and a scalar type, the
# generated function answers whether that combination was selected.
#   $body: either "return true;" (no dtype filtering) or a chain of
#          if/else-if branches rendered from if_condition_template; it is an
#          empty string when the selector has no kernel_metadata entries.
# The trailing "return false;" handles tags matched by no branch; when $body is
# "return true;" it is unreachable. kernel_tag_sv is tagged C10_UNUSED because
# the "return true;" variant never reads it (presumably to silence an
# unused-variable warning - see c10/macros/Macros.h).
selected_kernel_dtypes_h_template_str = """
#include <c10/core/ScalarType.h>
#include <c10/util/string_view.h>
#include <c10/macros/Macros.h>

namespace at {
inline constexpr bool should_include_kernel_dtype(
  const char *kernel_tag_str,
  at::ScalarType scalar_type
) {
  c10::string_view kernel_tag_sv C10_UNUSED = c10::string_view(kernel_tag_str);
  $body
  return false;
}
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)
43
# Text emitted at the top of every generated selected_mobile_ops.h:
# `#pragma once` plus a banner naming the generator script. Both writer
# functions below prepend it before any #define or generated C++ code.
selected_mobile_ops_preamble = """#pragma once
/**
 * Generated by gen_selected_mobile_ops_header.py
 */

"""
50
51
def extract_root_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return the names of every operator the selector marks as a root operator.

    Args:
        selective_builder: Selector whose ``operators`` mapping
            (operator name -> operator record) is scanned. Only records whose
            ``is_root_operator`` attribute is truthy are kept.

    Returns:
        The set of root-operator names; empty when there are none.
    """
    return {
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.is_root_operator
    }
58
59
def get_selected_kernel_dtypes_code(
    selective_builder: SelectiveBuilder,
) -> str:
    """Render the C++ ``at::should_include_kernel_dtype`` helper for a selector.

    If the selector includes all operators or all non-operator selectives
    (i.e. either flag is not exactly ``False``), the generated function body is
    simply ``return true;``. Otherwise it contains one ``if``/``else if`` branch
    per kernel tag in ``kernel_metadata``; each branch returns whether
    ``scalar_type`` is one of that tag's selected dtypes. Tags with no branch
    fall through to the template's ``return false;``.

    Args:
        selective_builder: The selector describing which kernel dtypes to keep.

    Returns:
        The C++ source text of the helper (without the file preamble).
    """
    body = "return true;"
    if (
        selective_builder.include_all_operators is False
        and selective_builder.include_all_non_op_selectives is False
    ):
        body_parts = []
        for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
            conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes]
            # A tag with no selected dtypes must still produce a valid boolean
            # expression: joining an empty list would emit "return ;", which
            # does not compile inside a bool-returning function.
            dtype_checks = " || ".join(conditions) if conditions else "false"
            body_parts.append(
                if_condition_template.substitute(
                    kernel_tag_name=kernel_tag,
                    dtype_checks=dtype_checks,
                ),
            )
        # Every branch returns, so chaining with "else" keeps the first match.
        body = " else ".join(body_parts)

    return selected_kernel_dtypes_h_template.substitute(body=body)
85
86
def write_selected_mobile_ops(
    output_file_path: str,
    selective_builder: SelectiveBuilder,
) -> None:
    """Write selected_mobile_ops.h for the given selector.

    The header always contains the preamble followed by the generated
    ``should_include_kernel_dtype`` helper. Additionally:

    * when not every operator is included (selective build), it defines
      TORCH_OPERATOR_WHITELIST with the sorted root operators;
    * when, on top of that, ``include_all_non_op_selectives`` is exactly
      ``False`` (tracing-based selective build), it also defines
      TORCH_CUSTOM_CLASS_ALLOWLIST and TORCH_BUILD_FEATURE_ALLOWLIST.

    A macro that is left undefined makes the corresponding selective-build
    check treat every item as selected.

    Args:
        output_file_path: Full path of the header to create or overwrite. The
            text is UTF-8 encoded and written in binary mode.
        selective_builder: Selector describing what to keep.
    """
    selected_roots = extract_root_operators(selective_builder)
    allowed_classes = selective_builder.custom_classes
    allowed_features = selective_builder.build_features

    def _allowlist_macro(macro_name, entries) -> str:
        # "#define NAME a;b;c;" followed by a blank line, entries sorted.
        return f"#define {macro_name} {';'.join(sorted(entries))};\n\n"

    with open(output_file_path, "wb") as header_file:
        sections = [selected_mobile_ops_preamble]
        is_selective = not selective_builder.include_all_operators
        if is_selective:
            sections.append(
                _allowlist_macro("TORCH_OPERATOR_WHITELIST", selected_roots)
            )
            is_tracing_based = (
                selective_builder.include_all_non_op_selectives is False
            )
            if is_tracing_based:
                sections.extend(
                    [
                        _allowlist_macro(
                            "TORCH_CUSTOM_CLASS_ALLOWLIST", allowed_classes
                        ),
                        _allowlist_macro(
                            "TORCH_BUILD_FEATURE_ALLOWLIST", allowed_features
                        ),
                    ]
                )
        sections.append(get_selected_kernel_dtypes_code(selective_builder))
        header_file.write("".join(sections).encode("utf-8"))
123
124
def write_selected_mobile_ops_with_all_dtypes(
    output_file_path: str,
    root_ops: set[str],
) -> None:
    """Write selected_mobile_ops.h restricting operators but not kernel dtypes.

    The header holds the preamble, a TORCH_OPERATOR_WHITELIST define listing
    ``root_ops`` in sorted order, and the ``should_include_kernel_dtype``
    helper generated from ``SelectiveBuilder.get_nop_selector()`` so that all
    kernel dtypes are kept.

    Args:
        output_file_path: Full path of the header to create or overwrite. The
            text is UTF-8 encoded and written in binary mode.
        root_ops: Names of the selected root operators.
    """
    with open(output_file_path, "wb") as header_file:
        whitelist = ";".join(sorted(root_ops))
        header_text = "".join(
            (
                selected_mobile_ops_preamble,
                f"#define TORCH_OPERATOR_WHITELIST {whitelist};\n\n",
                get_selected_kernel_dtypes_code(
                    SelectiveBuilder.get_nop_selector()
                ),
            )
        )
        header_file.write(header_text.encode("utf-8"))
144
145
def main() -> None:
    """Command-line entry point: turn an operator-list YAML into the header.

    Every entry of the loaded YAML document becomes a selected root operator.
    All kernel dtypes are kept. ``selected_mobile_ops.h`` is written into the
    directory given by ``--output-file-path``.
    """
    parser = argparse.ArgumentParser(
        description="Generate selected_mobile_ops.h for selective build."
    )
    parser.add_argument(
        "-p",
        "--yaml-file-path",
        "--yaml_file_path",
        type=str,
        required=True,
        help="Path to the yaml file with a list of operators used by the model.",
    )
    parser.add_argument(
        "-o",
        "--output-file-path",
        "--output_file_path",
        type=str,
        required=True,
        # The two literals are concatenated implicitly, so the first one needs
        # a trailing space (previously rendered as "destinationfolder").
        help="Path to destination "
        "folder where selected_mobile_ops.h will be written.",
    )
    parsed_args = parser.parse_args()
    model_file_name = parsed_args.yaml_file_path

    print("Loading yaml file: ", model_file_name)
    # Loader is CSafeLoader/SafeLoader, so no arbitrary Python objects are
    # constructed from the file contents.
    with open(model_file_name, "rb") as model_file:
        loaded_model = yaml.load(model_file, Loader=Loader)

    # An empty YAML document loads as None; treat it as "no operators" instead
    # of failing with an opaque TypeError from set(None).
    root_operators_set = set(loaded_model or ())
    print("Writing header file selected_mobile_ops.h: ", parsed_args.output_file_path)
    write_selected_mobile_ops_with_all_dtypes(
        os.path.join(parsed_args.output_file_path, "selected_mobile_ops.h"),
        root_operators_set,
    )
181
182
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
185