xref: /aosp_15_r20/external/pytorch/torchgen/shape_functions/gen_jit_shape_functions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2import os
3import sys
4from importlib.util import module_from_spec, spec_from_file_location
5from itertools import chain
6from pathlib import Path
7
8
# Manually importing the shape function module based on current directory
# instead of torch imports to avoid needing to recompile Pytorch before
# running the script

file_path = Path.cwd() / "torch" / "jit" / "_shape_functions.py"
module_name = "torch.jit._shape_functions"

err_msg = """Could not find shape functions file, please make sure
you are in the root directory of the Pytorch git repo"""
if not file_path.exists():
    # FileNotFoundError is a subclass of Exception, so handlers that catch
    # Exception keep working; it just names the failure precisely.
    raise FileNotFoundError(err_msg)

spec = spec_from_file_location(module_name, file_path)
# Explicit checks instead of `assert`: asserts are stripped under `python -O`,
# which would turn a failed lookup into an obscure AttributeError later on.
if spec is None or spec.loader is None:
    raise ImportError(f"Could not create an import spec for {file_path}")
module = module_from_spec(spec)
# Register under the canonical name before executing the module body, as in
# the importlib "importing a source file directly" recipe.
sys.modules[module_name] = module
spec.loader.exec_module(module)

# The two registries consumed by the generator functions below.
bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
shape_compute_graph_mapping = module.shape_compute_graph_mapping
31
32
# Opening text of the generated C++ file: the @generated banner, the
# includes, the opening of the torch::jit namespaces (closed again by
# DECOMP_END) and the start of the `shape_funcs` initializer. The
# initializer is continued by the text returned from
# gen_serialized_decompisitions(), which ends with the terminating ";".
SHAPE_HEADER = r"""
/**
 * @generated
 * This is an auto-generated file. Please do not modify it by hand.
 * To re-generate, please run:
 * cd ~/pytorch && python
 * torchgen/shape_functions/gen_jit_shape_functions.py
 */
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/serialized_shape_function_registry.h>

// clang-format off

namespace torch {
namespace jit {


std::string shape_funcs = ""
"""


# Emitted right after the `shape_funcs` initializer: defines the accessor
# GetSerializedShapeFunctions(), which returns that string. The operator
# mapping accessors (from gen_shape_mappings / gen_bounded_mappings) follow.
DECOMP_CENTER = r"""


const std::string& GetSerializedShapeFunctions() {
  return shape_funcs;
}

"""

# Closing text of the generated file: re-enables clang-format and closes the
# namespaces opened in SHAPE_HEADER.
DECOMP_END = r"""
// clang-format on

} // namespace jit
} // namespace torch
"""


# File name of the generated C++ source; write_decomposition_util_file()
# writes it into the directory it is given.
SERIALIZED_SHAPE_UTIL_FILE_NAME = "serialized_shape_function_registry.cpp"
74
75
def gen_serialized_decompisitions() -> str:
    """Serialize the source of every distinct shape function as C++ text.

    Scripted functions are collected from ``shape_compute_graph_mapping`` and
    from the (lower, upper) pairs of ``bounded_compute_graph_mapping``. A
    function whose ``name`` was already seen is skipped, so each function is
    serialized once, in first-seen order.

    The functions' ``.code`` is packed into chunks of roughly
    ``MAX_MSFT_STR_LEN`` characters (a single function longer than that gets
    a chunk of its own). Every chunk is emitted as a
    ``+ std::string(R"=====(...)=====")`` term (with a trailing newline inside
    the literal), continuing the ``shape_funcs`` initializer opened by
    SHAPE_HEADER. The result is terminated with ``;``.

    Returns:
        The C++ initializer continuation described above.

    Raises:
        ValueError: If a function's source contains the raw-string terminator
            ``)====="``, which would end the C++ literal early and produce a
            broken file.
    """
    # Windows compilers reject overly long string literals, so the source is
    # split across several literals. The real limit is higher; this value
    # leaves a buffer because there are special rules around some characters.
    # TODO: this looks far below the documented MSVC limit; verify against the
    # MSVC docs before raising it.
    MAX_MSFT_STR_LEN = 2000
    # Terminator of the R"=====( ... )=====" raw string literals emitted below.
    raw_string_terminator = ')====="'

    # dicts preserve insertion order; setdefault keeps the first function
    # registered under each name.
    unique_funcs = {}
    for scripted_func in chain(
        shape_compute_graph_mapping.values(), *bounded_compute_graph_mapping.values()
    ):
        unique_funcs.setdefault(scripted_func.name, scripted_func)

    chunks = []
    curr_str = ""
    for scripted_func in unique_funcs.values():
        serialized_code = scripted_func.code
        if raw_string_terminator in serialized_code:
            raise ValueError(
                f"Source of shape function {scripted_func.name!r} contains "
                f"{raw_string_terminator!r}, which cannot be embedded in the "
                "generated C++ raw string literal"
            )
        # NOTE: the "\n" separator is not counted here, so a chunk can reach
        # MAX_MSFT_STR_LEN + 1 characters; that is well inside the buffer and
        # is kept so the chunk boundaries of the generated file do not shift.
        if len(curr_str) + len(serialized_code) <= MAX_MSFT_STR_LEN:
            curr_str += "\n" + serialized_code
        else:
            # Avoid emitting an empty chunk, e.g. when the very first
            # function alone exceeds the limit.
            if curr_str:
                chunks.append(curr_str)
            curr_str = serialized_code
    if curr_str:
        chunks.append(curr_str)

    # The Windows compiler doesn't correctly handle adjacent string literals,
    # so the chunks are combined with explicit std::string concatenation.
    return (
        "".join('+ std::string(R"=====(' + chunk + '\n)=====")\n' for chunk in chunks)
        + ";"
    )
113
114
# Opening of the generated GetShapeFunctionMappings() accessor; the entries
# produced by gen_shape_mappings() are placed between this text and
# SHAPE_SCHEMA_END.
SHAPE_SCHEMA_START = r"""
const OperatorMap<std::string>& GetShapeFunctionMappings() {
 static const OperatorMap<std::string> shape_mappings {
"""

# Closes the static map initializer and returns the map. Used as the closing
# text by both gen_shape_mappings() and gen_bounded_mappings(); both
# generated maps are named `shape_mappings`.
SHAPE_SCHEMA_END = r"""
  };

  return shape_mappings;
}
"""
126
127
def gen_shape_mappings() -> str:
    """Render the ``GetShapeFunctionMappings()`` C++ accessor.

    Emits one ``{"<schema>", "<function name>"},`` initializer entry per item
    of ``shape_compute_graph_mapping``, wrapped in SHAPE_SCHEMA_START and
    SHAPE_SCHEMA_END.

    Returns:
        The C++ source text of the accessor.
    """

    def _cpp_quote(text: str) -> str:
        # A schema may contain double quotes (e.g. in a string default value);
        # escape backslashes and quotes so the emitted C++ string literal
        # stays well-formed.
        return '"' + text.replace("\\", "\\\\").replace('"', '\\"') + '"'

    entries = [
        "    {" + _cpp_quote(schema) + ", " + _cpp_quote(scripted_func.name) + "},"
        for schema, scripted_func in shape_compute_graph_mapping.items()
    ]
    return SHAPE_SCHEMA_START + "\n".join(entries) + SHAPE_SCHEMA_END
133
134
# Opening of the generated GetBoundedShapeMappings() accessor, whose map
# values are (lower bound, upper bound) pairs of shape function names as
# produced by gen_bounded_mappings(); closed by SHAPE_SCHEMA_END.
BOUNDED_SCHEMA_START = r"""
const OperatorMap<std::pair<std::string, std::string>>& GetBoundedShapeMappings() {
 static const OperatorMap<std::pair<std::string, std::string>> shape_mappings {
"""
139
140
def gen_bounded_mappings() -> str:
    """Render the ``GetBoundedShapeMappings()`` C++ accessor.

    Emits one ``{"<schema>", {"<lower fn>", "<upper fn>"}},`` initializer
    entry per item of ``bounded_compute_graph_mapping`` (whose values are
    (lower bound, upper bound) scripted-function pairs), wrapped in
    BOUNDED_SCHEMA_START and SHAPE_SCHEMA_END.

    Returns:
        The C++ source text of the accessor.
    """

    def _cpp_quote(text: str) -> str:
        # A schema may contain double quotes (e.g. in a string default value);
        # escape backslashes and quotes so the emitted C++ string literal
        # stays well-formed.
        return '"' + text.replace("\\", "\\\\").replace('"', '\\"') + '"'

    entries = []
    for schema, (lower_func, upper_func) in bounded_compute_graph_mapping.items():
        entries.append(
            "    {"
            + _cpp_quote(schema)
            + ", {"
            + _cpp_quote(lower_func.name)
            + ", "
            + _cpp_quote(upper_func.name)
            + "}},"
        )
    return BOUNDED_SCHEMA_START + "\n".join(entries) + SHAPE_SCHEMA_END
155
156
def write_decomposition_util_file(path: str) -> None:
    """Generate the serialized shape-function registry and write it to disk.

    Args:
        path: Directory into which SERIALIZED_SHAPE_UTIL_FILE_NAME is written
            (an existing file is truncated and overwritten). Anything
            accepted by ``os.path.join`` works, including a ``pathlib.Path``.
    """
    decomposition_str = gen_serialized_decompisitions()
    shape_mappings = gen_shape_mappings()
    bounded_mappings = gen_bounded_mappings()
    file_components = [
        SHAPE_HEADER,
        decomposition_str,
        DECOMP_CENTER,
        shape_mappings,
        bounded_mappings,
        DECOMP_END,
    ]
    # Compute the destination once so the reported path is exactly the path
    # that gets written.
    out_path = os.path.join(path, SERIALIZED_SHAPE_UTIL_FILE_NAME)
    print("writing file to : ", out_path)
    # Binary mode with an explicit UTF-8 encode: the bytes written do not
    # depend on the platform's default encoding or newline translation.
    with open(out_path, "wb") as out_file:
        out_file.write("".join(file_components).encode("utf-8"))
173
174
def main() -> None:
    """Regenerate the registry source under <repo>/torch/csrc/jit/runtime."""
    # This script lives in <repo>/torchgen/shape_functions/, so the repository
    # root is two levels above the directory containing this file.
    repo_root = Path(__file__).resolve().parents[2]
    runtime_dir = repo_root.joinpath("torch", "csrc", "jit", "runtime")
    write_decomposition_util_file(str(runtime_dir))


if __name__ == "__main__":
    main()
183