#!/usr/bin/env python3
"""Generate ``selected_mobile_ops.h`` for selective build.

The generated header optionally defines:

* ``TORCH_OPERATOR_WHITELIST``: the selected root operators;
* ``TORCH_CUSTOM_CLASS_ALLOWLIST`` / ``TORCH_BUILD_FEATURE_ALLOWLIST``: the
  selected custom classes and build features (tracing-based selective build);
* ``at::should_include_kernel_dtype``: an ``inline constexpr`` predicate that
  tells whether a kernel tag should be compiled for a given ``at::ScalarType``.
"""

from __future__ import annotations

import argparse
import os

import yaml

from torchgen.code_template import CodeTemplate
from torchgen.selective_build.selector import SelectiveBuilder


# Safely load fast C Yaml loader/dumper if they are available
try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[assignment, misc]


if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {
  return $dtype_checks;
}"""
if_condition_template = CodeTemplate(if_condition_template_str)

selected_kernel_dtypes_h_template_str = """
#include <c10/core/ScalarType.h>
#include <c10/util/string_view.h>
#include <c10/macros/Macros.h>

namespace at {
inline constexpr bool should_include_kernel_dtype(
  const char *kernel_tag_str,
  at::ScalarType scalar_type
) {
  c10::string_view kernel_tag_sv C10_UNUSED = c10::string_view(kernel_tag_str);
  $body
  return false;
}
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)

selected_mobile_ops_preamble = """#pragma once
/**
 * Generated by gen_selected_mobile_ops_header.py
 */

"""


def extract_root_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return the names of all operators flagged as root operators.

    Args:
        selective_builder: The selective build configuration to inspect.

    Returns:
        The set of keys of ``selective_builder.operators`` whose value has
        ``is_root_operator`` set.
    """
    return {
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.is_root_operator
    }


def get_selected_kernel_dtypes_code(
    selective_builder: SelectiveBuilder,
) -> str:
    """Render the C++ header defining ``at::should_include_kernel_dtype``.

    If ``include_all_operators`` or ``include_all_non_op_selectives`` is not
    ``False``, the generated function body is simply ``return true;`` (every
    dtype of every kernel is kept).

    Otherwise one ``if`` is emitted per entry of
    ``selective_builder.kernel_metadata``. These are chained into an
    if / else-if ladder, and each returns whether ``scalar_type`` equals one of
    that tag's listed dtypes. Tags not listed fall through to ``return false;``.

    Args:
        selective_builder: The selective build configuration.

    Returns:
        The complete header text produced by
        ``selected_kernel_dtypes_h_template``.
    """
    body = "return true;"
    if (
        selective_builder.include_all_operators is False
        and selective_builder.include_all_non_op_selectives is False
    ):
        body_parts = []
        for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
            conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes]
            # A tag with no selected dtypes must still produce a valid boolean
            # expression; an empty join would render the ill-formed
            # `return ;` inside a constexpr bool function.
            dtype_checks = " || ".join(conditions) or "false"
            body_parts.append(
                if_condition_template.substitute(
                    kernel_tag_name=kernel_tag,
                    dtype_checks=dtype_checks,
                ),
            )
        body = " else ".join(body_parts)

    return selected_kernel_dtypes_h_template.substitute(body=body)


def _write_header(output_file_path: str, body_parts: list[str]) -> None:
    """Concatenate ``body_parts`` and write them to ``output_file_path`` as UTF-8.

    The contents are fully built before the file is opened. That way a failure
    while generating does not leave behind a truncated header.
    """
    header_contents = "".join(body_parts)
    with open(output_file_path, "wb") as out_file:
        out_file.write(header_contents.encode("utf-8"))


# Write the file selected_mobile_ops.h with optionally:
# 1. The selected root operators
# 2. The selected kernel dtypes
def write_selected_mobile_ops(
    output_file_path: str,
    selective_builder: SelectiveBuilder,
) -> None:
    """Write ``selected_mobile_ops.h`` for a selective build configuration.

    Args:
        output_file_path: Path of the header file to (over)write.
        selective_builder: The selective build configuration. It supplies the
            root operators, custom classes, build features and kernel dtypes.
    """
    root_ops = extract_root_operators(selective_builder)
    custom_classes = selective_builder.custom_classes
    build_features = selective_builder.build_features

    body_parts = [selected_mobile_ops_preamble]
    # This condition checks whether we are in a selective build. If these lists
    # are not defined, the corresponding selective build macros trivially
    # report that the item in question was selected.
    if not selective_builder.include_all_operators:
        body_parts.append(
            "#define TORCH_OPERATOR_WHITELIST "
            + (";".join(sorted(root_ops)))
            + ";\n\n"
        )
    # This condition checks whether we are in tracing-based selective build.
    if selective_builder.include_all_non_op_selectives is False:
        body_parts.append(
            "#define TORCH_CUSTOM_CLASS_ALLOWLIST "
            + (";".join(sorted(custom_classes)))
            + ";\n\n"
        )
        body_parts.append(
            "#define TORCH_BUILD_FEATURE_ALLOWLIST "
            + (";".join(sorted(build_features)))
            + ";\n\n"
        )

    body_parts.append(get_selected_kernel_dtypes_code(selective_builder))
    _write_header(output_file_path, body_parts)


# root_ops: a set of selected root operators for selective build
# Write the file selected_mobile_ops.h with optionally:
# 1. The selected root operators from root_ops
# 2. All kernel dtypes
def write_selected_mobile_ops_with_all_dtypes(
    output_file_path: str,
    root_ops: set[str],
) -> None:
    """Write ``selected_mobile_ops.h`` listing ``root_ops`` and all kernel dtypes.

    Args:
        output_file_path: Path of the header file to (over)write.
        root_ops: Names of the selected root operators. They are emitted sorted
            in ``TORCH_OPERATOR_WHITELIST``.
    """
    body_parts = [
        selected_mobile_ops_preamble,
        "#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n",
    ]

    # The dtype section is generated from the no-op selector, which (per this
    # function's contract) keeps all kernel dtypes.
    selective_builder = SelectiveBuilder.get_nop_selector()
    body_parts.append(get_selected_kernel_dtypes_code(selective_builder))

    _write_header(output_file_path, body_parts)


def main() -> None:
    """Command-line entry point: read root operators from YAML and write the header."""
    parser = argparse.ArgumentParser(
        description="Generate selected_mobile_ops.h for selective build."
    )
    parser.add_argument(
        "-p",
        "--yaml-file-path",
        "--yaml_file_path",
        type=str,
        required=True,
        help="Path to the yaml file with a list of operators used by the model.",
    )
    parser.add_argument(
        "-o",
        "--output-file-path",
        "--output_file_path",
        type=str,
        required=True,
        # Adjacent literals are concatenated, so the trailing space is required.
        help="Path to destination "
        "folder where selected_mobile_ops.h will be written.",
    )
    parsed_args = parser.parse_args()
    model_file_name = parsed_args.yaml_file_path

    print("Loading yaml file: ", model_file_name)
    # Loader is a *safe* loader (CSafeLoader or SafeLoader), so no arbitrary
    # Python objects are constructed from the YAML input.
    with open(model_file_name, "rb") as model_file:
        loaded_model = yaml.load(model_file, Loader=Loader)

    # An empty YAML document loads as None; treat it as "no root operators"
    # rather than crashing with a TypeError inside set().
    root_operators_set = set(loaded_model) if loaded_model is not None else set()
    print("Writing header file selected_mobile_ops.h: ", parsed_args.output_file_path)
    write_selected_mobile_ops_with_all_dtypes(
        os.path.join(parsed_args.output_file_path, "selected_mobile_ops.h"),
        root_operators_set,
    )


if __name__ == "__main__":
    main()