#!/usr/bin/env python3
"""Generate ``serialized_shape_function_registry.cpp`` from the JIT shape functions.

The script loads ``torch/jit/_shape_functions.py`` straight from the current
working directory, serializes every scripted shape function it exposes, and
writes a C++ source file containing the serialized code plus the
schema -> function-name lookup tables.
"""
import os
import sys
from importlib.util import module_from_spec, spec_from_file_location
from itertools import chain
from pathlib import Path


# Manually importing the shape function module based on current directory
# instead of torch imports to avoid needing to recompile Pytorch before
# running the script

file_path = Path.cwd() / "torch" / "jit" / "_shape_functions.py"
module_name = "torch.jit._shape_functions"

err_msg = """Could not find shape functions file, please make sure
you are in the root directory of the Pytorch git repo"""
if not file_path.exists():
    # FileNotFoundError is a subclass of Exception, so callers that caught the
    # previous bare Exception keep working.
    raise FileNotFoundError(err_msg)

spec = spec_from_file_location(module_name, file_path)
assert spec is not None
module = module_from_spec(spec)
# Register the module under its dotted name before executing it.
sys.modules[module_name] = module
assert spec.loader is not None
assert module is not None
spec.loader.exec_module(module)

bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
shape_compute_graph_mapping = module.shape_compute_graph_mapping


SHAPE_HEADER = r"""
/**
 * @generated
 * This is an auto-generated file. Please do not modify it by hand.
 * To re-generate, please run:
 * cd ~/pytorch && python
 * torchgen/shape_functions/gen_jit_shape_functions.py
 */
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/serialized_shape_function_registry.h>

// clang-format off

namespace torch {
namespace jit {


std::string shape_funcs = ""
"""


DECOMP_CENTER = r"""


const std::string& GetSerializedShapeFunctions() {
  return shape_funcs;
}

"""

DECOMP_END = r"""
// clang-format on

} // namespace jit
} // namespace torch
"""


SERIALIZED_SHAPE_UTIL_FILE_NAME = "serialized_shape_function_registry.cpp"

# Threshold used to decide when to start a new C++ string literal.
# technically its higher but give a buffer bc there are weird rules
# around some characters
# TODO: this was the limit I found by googling but it seems way
# too short ?
MAX_MSFT_STR_LEN = 2000


def gen_serialized_decompisitions() -> str:
    """Return the C++ expression that initializes ``shape_funcs``.

    Every scripted function reachable from ``shape_compute_graph_mapping`` and
    ``bounded_compute_graph_mapping`` is emitted once (deduplicated by its
    ``name``, first occurrence wins). The serialized code is packed into
    chunks, each wrapped as ``+ std::string(R"=====(...)=====")``, and the
    whole expression is terminated with ``;``.

    Returns:
        The text to place right after ``std::string shape_funcs = ""`` in the
        generated file.
    """
    # dict preserves insertion order, so setdefault keeps the first function
    # seen for each name while retaining the original iteration order.
    unique_funcs = {}
    all_funcs = chain(
        shape_compute_graph_mapping.values(), *bounded_compute_graph_mapping.values()
    )
    for scripted_func in all_funcs:
        unique_funcs.setdefault(scripted_func.name, scripted_func)

    output_strs = []
    curr_str = ""
    for scripted_func in unique_funcs.values():
        serialized_code = scripted_func.code
        if len(curr_str) + len(serialized_code) <= MAX_MSFT_STR_LEN:
            curr_str += "\n" + serialized_code
        else:
            # Adding this function would cross the threshold: flush the
            # current chunk and start a new one with this function's code.
            output_strs.append(curr_str)
            curr_str = serialized_code
    output_strs.append(curr_str)

    # Windows compiler doesnt correctly handle adjacent
    # string literals
    start = '+ std::string(R"=====('
    end = '\n)=====")\n'
    return "".join(start + output_str + end for output_str in output_strs) + ";"


SHAPE_SCHEMA_START = r"""
const OperatorMap<std::string>& GetShapeFunctionMappings() {
 static const OperatorMap<std::string> shape_mappings {
"""

SHAPE_SCHEMA_END = r"""
  };

  return shape_mappings;
}
"""


def gen_shape_mappings() -> str:
    """Return the C++ definition of ``GetShapeFunctionMappings``.

    One ``{"<schema>", "<function name>"},`` initializer entry is emitted per
    item of ``shape_compute_graph_mapping``.
    """
    shape_mappings = [
        f'    {{"{schema}", "{scripted_func.name}"}},'
        for schema, scripted_func in shape_compute_graph_mapping.items()
    ]
    return SHAPE_SCHEMA_START + "\n".join(shape_mappings) + SHAPE_SCHEMA_END


BOUNDED_SCHEMA_START = r"""
const OperatorMap<std::pair<std::string, std::string>>& GetBoundedShapeMappings() {
 static const OperatorMap<std::pair<std::string, std::string>> shape_mappings {
"""


def gen_bounded_mappings() -> str:
    """Return the C++ definition of ``GetBoundedShapeMappings``.

    One ``{"<schema>", {"<lower name>", "<upper name>"}},`` initializer entry
    is emitted per item of ``bounded_compute_graph_mapping``, whose values are
    unpacked as a (lower, upper) pair of scripted functions.
    """
    bounded_mappings = [
        f'    {{"{schema}", {{"{lower_func.name}", "{upper_func.name}"}}}},'
        for schema, (lower_func, upper_func) in bounded_compute_graph_mapping.items()
    ]
    # The closing lines are the same as for the unbounded map.
    return BOUNDED_SCHEMA_START + "\n".join(bounded_mappings) + SHAPE_SCHEMA_END


def write_decomposition_util_file(path: str) -> None:
    """Write the generated registry source file into directory ``path``.

    Args:
        path: Directory in which ``SERIALIZED_SHAPE_UTIL_FILE_NAME`` is
            created (or overwritten).
    """
    file_components = [
        SHAPE_HEADER,
        gen_serialized_decompisitions(),
        DECOMP_CENTER,
        gen_shape_mappings(),
        gen_bounded_mappings(),
        DECOMP_END,
    ]
    # Compute the destination once so the logged path is exactly the path
    # that gets written.
    out_path = os.path.join(path, SERIALIZED_SHAPE_UTIL_FILE_NAME)
    print("writing file to : ", out_path)
    # Binary mode: the UTF-8 bytes are written as-is, with no platform
    # newline translation.
    with open(out_path, "wb") as out_file:
        out_file.write("".join(file_components).encode("utf-8"))


def main() -> None:
    """Generate the registry file under ``torch/csrc/jit/runtime``.

    The repository root is taken to be two directories above the directory
    containing this script.
    """
    pytorch_dir = Path(__file__).resolve().parents[2]
    runtime_dir = pytorch_dir / "torch" / "csrc" / "jit" / "runtime"
    write_decomposition_util_file(str(runtime_dir))


if __name__ == "__main__":
    main()