1import argparse 2import datetime 3import re 4import sys 5import warnings 6from collections import defaultdict 7 8import torch 9from torch._C import parse_schema 10 11 12# How to run this test locally: 13# 1 Have two virtual environments (eg conda env), one without PyTorch installed (venv_nightly) 14# one with your local changes (venv_yours). 15# In venv_nightly: 16# 2. First ensure that Pytorch is uninstalled, but all prereqs are installed 17# 3. Install torch nightly build with 18# `pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html` 19# 4. Generate original schemas with 20# `python test/forward_backward_compatibility/dump_all_function_schemas.py --filename nightly_schemas.txt` 21# Now in venv_yours: 22# 5. Run this test with 23# `python test/forward_backward_compatibility/check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt` 24 25# The date specifies how long the allowlist exclusion should apply to. 26# 27# - If we NEVER give BC guarantee for an operator, you can put the 28# date arbitrarily far in the future. 29# - Otherwise, pick a date that is far enough in the future that you 30# believe you can land your diff before then. 31# 32# Allowlist entries can be removed after the date listed on them passes. 33# 34# Allowlist item format: 35# [ 36# 0: function name regex 37# 1: date until which the allowlist entry is valid 38# 2: (optional) function argument regex 39# ] 40# 41# NB: function name DOES NOT include overload name! 
ALLOW_LIST = [
    ("c10_experimental", datetime.date(9999, 1, 1)),
    # Internal
    ("static", datetime.date(9999, 1, 1)),
    ("prim::ModuleDictIndex", datetime.date(9999, 1, 1)),
    ("prim::MKLDNNRelu6", datetime.date(9999, 1, 1)),
    ("prim::MKLDNNRelu6_", datetime.date(9999, 1, 1)),
    ("prim::is_ort", datetime.date(9999, 1, 1)),
    ("prim::Concat", datetime.date(9999, 1, 1)),
    ("aten::_NestedTensor_GeneralizedBMM", datetime.date(9999, 1, 1)),
    # Internal, profiler-specific ops
    ("profiler::_call_end_callbacks_on_jit_fut*", datetime.date(9999, 1, 1)),
    ("profiler::_record_function_enter", datetime.date(9999, 1, 1)),
    ("aten::_cholesky_helper", datetime.date(9999, 1, 1)),
    ("aten::_lstsq_helper", datetime.date(9999, 1, 1)),
    ("aten::_syevd_helper", datetime.date(9999, 1, 1)),
    ("aten::_linalg_solve_out_helper_", datetime.date(9999, 1, 1)),
    ("aten::select_backward", datetime.date(9999, 1, 1)),
    ("aten::lstsq", datetime.date(9999, 1, 1)),
    ("aten::lstsq.X", datetime.date(9999, 1, 1)),
    ("aten::slice_backward", datetime.date(9999, 1, 1)),
    ("aten::diagonal_backward", datetime.date(9999, 1, 1)),
    ("aten::rowwise_prune", datetime.date(9999, 1, 1)),
    ("aten::eig", datetime.date(9999, 1, 1)),
    ("aten::eig.e", datetime.date(9999, 1, 1)),
    ("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)),
    ("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
    ("aten::matrix_rank", datetime.date(9999, 1, 1)),
    ("aten::matrix_rank.tol", datetime.date(9999, 1, 1)),
    ("aten::randperm", datetime.date(9999, 1, 1)),
    ("aten::solve", datetime.date(9999, 1, 1)),
    ("aten::solve.solution", datetime.date(9999, 1, 1)),
    ("aten::_solve_helper", datetime.date(9999, 1, 1)),
    ("aten::_convolution_nogroup", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_backward", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_backward_bias", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_backward_input", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_backward_weight", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_transpose_backward", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_transpose_backward_input", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_transpose_backward_weight", datetime.date(9999, 1, 1)),
    ("aten::miopen_depthwise_convolution_backward", datetime.date(9999, 1, 1)),
    ("aten::miopen_depthwise_convolution_backward_input", datetime.date(9999, 1, 1)),
    ("aten::miopen_depthwise_convolution_backward_weight", datetime.date(9999, 1, 1)),
    ("aten::_nested_tensor", datetime.date(9999, 1, 1)),
    ("prepacked::unpack_prepacked_sizes_conv2d", datetime.date(9999, 1, 1)),
    ("prepacked::unpack_prepacked_sizes_linear", datetime.date(9999, 1, 1)),
    ("aten::_symeig_helper", datetime.date(9999, 1, 1)),
    ("aten::symeig", datetime.date(9999, 1, 1)),
    ("aten::symeig.e", datetime.date(9999, 1, 1)),
    ("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
    ("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
    ("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)),
    ("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
    ("prim::infer_squeeze_size.dim", datetime.date(9999, 1, 1)),
    ("prim::infer_squeeze_size", datetime.date(9999, 1, 1)),
    ("aten::_weight_norm_cuda_interface", datetime.date(9999, 1, 1)),
    ("aten::_weight_norm_cuda_interface_backward", datetime.date(9999, 1, 1)),
    ("aten::empty.SymInt", datetime.date(9999, 1, 1)),
    # nested tensor temporary auxiliary ops
    ("aten::_reshape_nested", datetime.date(9999, 1, 1)),
    ("aten::_reshape_nested_backward", datetime.date(9999, 1, 1)),
    ("aten::mps_linear", datetime.date(9999, 1, 1)),
    ("aten::_mps_linear", datetime.date(9999, 1, 1)),
    ("aten::_mps_max_pool2d", datetime.date(9999, 1, 1)),
    ("aten::_mps_max_pool2d.out", datetime.date(9999, 1, 1)),
    ("aten::mps_max_pool2d_backward", datetime.date(9999, 1, 1)),
    ("aten::mps_max_pool2d_backward.out", datetime.date(9999, 1, 1)),
    # TODO: FIXME: prims shouldn't be checked
    ("prims::.*", datetime.date(9999, 1, 1)),
    ("aten::_flash_attention_forward", datetime.date(2023, 12, 30)),
    ("aten::_flash_attention_backward", datetime.date(2023, 12, 30)),
    ("aten::_scaled_dot_product_cudnn_attention", datetime.date(9999, 1, 1)),
    ("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
    # BetterTransformer 1.0 internal operators
    ("aten::_transformer_decoder_only_layer_fwd", datetime.date(9999, 1, 1)),
    ("aten::_native_decoder_only_multi_head_attention", datetime.date(9999, 1, 1)),
    ("c10d::_allgather_base_", datetime.date(2023, 12, 30)),
    ("c10d::_reduce_scatter_base_", datetime.date(2023, 12, 30)),
    ("c10d::broadcast_", datetime.date(2023, 12, 30)),
    ("c10d::scatter_", datetime.date(2023, 12, 30)),
    # These ops were moved to python under the c10d_functional namespace
    ("aten::wait_tensor", datetime.date(9999, 1, 30)),
    ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
    ("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
    ("aten::all_reduce", datetime.date(9999, 1, 30)),
    ("aten::to_sparse.out", datetime.date(2023, 12, 31)),
    ("aten::to_sparse.sparse_dim_out", datetime.date(2023, 12, 31)),
    ("aten::to_sparse_bsc.out", datetime.date(2023, 12, 31)),
    ("aten::to_sparse_bsr.out", datetime.date(2023, 12, 31)),
    ("aten::to_sparse_csc.out", datetime.date(2023, 12, 31)),
    ("aten::to_sparse_csr.out", datetime.date(2023, 12, 31)),
    ("aten::_structured_sparse_linear", datetime.date(2023, 12, 31)),
    ("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)),
    ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
    ("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
    ("aten::_efficient_attention_forward", datetime.date(2024, 7, 1)),
    ("aten::_efficient_attention_backward", datetime.date(2024, 7, 1)),
    ("onednn::qconv1d_pointwise", datetime.date(2024, 12, 31)),
    ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)),
    ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)),
    ("onednn::qconv2d_pointwise.binary", datetime.date(2024, 12, 31)),
    ("onednn::qlinear_pointwise.binary", datetime.date(2024, 12, 31)),
    ("onednn::qlinear_pointwise.binary_tensor", datetime.date(2024, 12, 31)),
    ("aten::_scaled_mm.out", datetime.date(2024, 12, 31)),
    ("aten::_scaled_mm", datetime.date(2024, 12, 31)),
    ("aten::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)),
    ("aten::wrapped_linear_prepack", datetime.date(2024, 12, 31)),
    ("_quantized::wrapped_linear_prepack", datetime.date(2024, 12, 31)),
    ("_quantized::wrapped_linear_prepacked", datetime.date(2024, 12, 31)),
    ("_quantized::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)),
    # BC-breaking change in can_cast signature: 'from' -> 'from_'
    ("aten::can_cast", datetime.date(2024, 5, 31)),
]

# Compiled form of the still-active ALLOW_LIST entries, as
# (name_regex, valid_until, argument_regex_or_None) triples.
# Entries whose date has already passed are dropped here, at import time.
ALLOW_LIST_COMPILED = [
    (
        re.compile(item[0]),
        item[1],
        re.compile(item[2]) if len(item) > 2 else None,
    )
    for item in ALLOW_LIST
    if item[1] >= datetime.date.today()
]


def allow_listed(schema):
    """Return True if ``schema`` is exempted by an active allowlist entry.

    Each entry's name regex is *searched* (not fully matched) against the
    whole ``str(schema)``. If the entry also has an argument regex, that regex
    must match as well. If it does not, the remaining entries are still
    consulted.
    """
    schema_str = str(schema)
    for name_re, _valid_until, args_re in ALLOW_LIST_COMPILED:
        if not name_re.search(schema_str):
            continue
        # A failed argument regex only disqualifies *this* entry; it must not
        # hide a later (e.g. broader) entry that also matches the name.
        if args_re is None or args_re.search(schema_str):
            return True
    return False


# The nightly will fail to parse newly added syntax to schema declarations
# Add new schemas that will fail the nightly here
dont_parse_list = [
    ("_TorchScriptTesting.*", datetime.date(2099, 9, 17)),
    ("test_backend", datetime.date(2099, 9, 17)),
    ("dist_c10d", datetime.date(2099, 9, 17)),
    ("__backends__.nnc", datetime.date(2099, 9, 17)),
]

# Precompiled (regex, valid_until) pairs for dont_parse_list, so dont_parse
# does not recompile every pattern for every schema line it inspects.
_DONT_PARSE_COMPILED = [
    (re.compile(pattern), valid_until) for pattern, valid_until in dont_parse_list
]


def has_valid_upgraders(schema, version_map):
    """Return True if ``schema`` appears among the registered upgrader schemas.

    Args:
        schema: an existing FunctionSchema; only ``schema.name`` (operator
            name without overload) is used for the lookup.
        version_map: the output of ``process_version_map``, i.e.
            ``{operator_name: {full_key: [old_schema, ...]}}``. Because the
            map is keyed per overload, every overload's schemas are searched.
    """
    overloads = version_map.get(schema.name)
    if not overloads:
        return False
    return any(
        old_schema == schema
        for upgrader_schemas in overloads.values()
        for old_schema in upgrader_schemas
    )


def dont_parse(schema_line):
    """Return True if ``schema_line`` matches a non-expired dont_parse_list regex."""
    today = datetime.date.today()
    return any(
        valid_until >= today and regexp.search(schema_line)
        for regexp, valid_until in _DONT_PARSE_COMPILED
    )


def load_schemas_to_dict():
    """Group the operator and custom-class schemas of the running torch by name.

    Returns:
        A ``defaultdict(list)`` mapping ``schema.name`` to every schema with
        that name (i.e. all overloads).
    """
    new_schemas = torch._C._jit_get_all_schemas()
    new_schemas += torch._C._jit_get_custom_class_schemas()
    new_schema_dict = defaultdict(list)
    for s in new_schemas:
        new_schema_dict[s.name].append(s)
    return new_schema_dict


def process_version_map(version_map):
    """Re-key the torch operator version map by operator name.

    The input maps a full key (``"name"`` or ``"name.overload"``) to a list of
    upgrader entries. Since callers only know the schema name (no overload),
    the result is ``Dict[schema_name, Dict[full_key, List[schema]]]``. Each
    entry's ``old_schema`` string is parsed into a FunctionSchema.
    """
    output = defaultdict(dict)
    for key, entries in version_map.items():
        operator_name = key.split(".")[0]
        schema_entries = [parse_schema(entry.old_schema) for entry in entries]
        output[operator_name][key] = schema_entries
    return output


def check_bc(existing_schemas):
    """Check that every existing schema is still served backward-compatibly.

    A schema passes if one of these holds:
    - it is allow-listed;
    - it has a registered upgrader;
    - a current schema with the same name ``is_backward_compatible_with`` it.

    Prints a report.

    Returns:
        True if every existing schema passes, False otherwise.
    """
    new_schema_dict = load_schemas_to_dict()
    version_map = process_version_map(torch._C._get_operator_version_map())
    broken_ops = []
    for existing_schema in existing_schemas:
        if allow_listed(existing_schema):
            print("schema: ", str(existing_schema), " found on allowlist, skipping")
            continue
        if has_valid_upgraders(existing_schema, version_map):
            print("schema: ", str(existing_schema), " has valid upgrader, skipping")
            continue
        print("processing existing schema: ", str(existing_schema))
        matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
        if any(
            candidate.is_backward_compatible_with(existing_schema)
            for candidate in matching_new_schemas
        ):
            continue
        print(
            "Can NOT find backward compatible schemas after changes "
            "for schema {} from the following candidates:\n[\n{}\n]".format(
                str(existing_schema),
                "\n\t".join(str(s) for s in matching_new_schemas),
            )
        )
        # TODO Print out more details about why candidates don't match.
        broken_ops.append(str(existing_schema))

    is_bc = not broken_ops
    if is_bc:
        print("Found backward compatible schemas for all existing schemas")
    else:
        print(
            "The PR is introducing backward incompatible changes to the "
            "operator library. Please contact PyTorch team to confirm "
            "whether this change is wanted or not. \n\nBroken ops: "
            "[\n\t{}\n]".format("\n\t".join(broken_ops))
        )
    return is_bc


def check_fc(existing_schemas):
    """Check that every existing schema has a forward-compatible successor.

    A schema passes if it is allow-listed, or if some current schema with the
    same name reports ``check_forward_compatible_with`` success. Failure
    reasons from the rejected candidates are printed. Overall failure is
    reported with ``warnings.warn``; nothing is raised.

    Returns:
        True if every existing schema passes, False otherwise (mirrors
        ``check_bc``).
    """
    new_schema_dict = load_schemas_to_dict()
    broken_ops = []
    for existing_schema in existing_schemas:
        if allow_listed(existing_schema):
            print("schema: ", str(existing_schema), " found on allowlist, skipping")
            continue
        print("processing existing schema: ", str(existing_schema))
        matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
        found = False
        possible_failure_reasons = []
        for matching_new_schema in matching_new_schemas:
            is_compatible, reason = matching_new_schema.check_forward_compatible_with(
                existing_schema
            )
            if is_compatible:
                found = True
                break
            if reason != "":
                possible_failure_reasons.append(reason)
        if found:
            continue
        print(
            "Can NOT find forward compatible schemas after changes "
            "for schema {} from the following candidates:\n[\n{}\n]".format(
                str(existing_schema),
                "\n\t".join(str(s) for s in matching_new_schemas),
            )
        )
        print(
            "Refer to following reasons for failure "
            "to find FC schema:\n[\n{}\n]".format(
                "\n\t".join(str(r) for r in possible_failure_reasons)
            )
        )
        broken_ops.append(str(existing_schema))

    is_fc = not broken_ops
    if is_fc:
        print("Found forward compatible schemas for all existing schemas")
    else:
        warnings.warn(
            "The PR is introducing a potentially forward incompatible changes to the "
            "operator library. Please contact PyTorch team to confirm "
            "whether this change is wanted or not. \n\nBroken ops: "
            "[\n\t{}\n]".format("\n\t".join(broken_ops))
        )
    return is_fc


def load_existing_schemas(filename):
    """Parse a schema dump (one schema per line) into a list of FunctionSchema.

    Blank lines are ignored. Lines matched by ``dont_parse`` are reported and
    skipped.
    """
    schemas = []
    with open(filename, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # An empty string is not a valid schema; skip stray blank
                # lines instead of handing them to parse_schema.
                continue
            if dont_parse(line):
                print("Not parsing schema line: ", line)
                continue
            schemas.append(parse_schema(line))
    return schemas


def main():
    """CLI entry point: run FC (warn-only) and BC (fatal) checks."""
    parser = argparse.ArgumentParser(
        description="Check that the current operator schemas are forward and "
        "backward compatible with a previously dumped set of schemas."
    )
    parser.add_argument(
        "--existing-schemas",
        help="filename to load existing schemas",
        type=str,
        default="schemas.txt",
    )
    args = parser.parse_args()
    slist = load_existing_schemas(args.existing_schemas)

    # TODO in case there is FC breaking changes,
    # we just warn for now until there is a policy.
    check_fc(slist)

    if not check_bc(slist):
        sys.exit(1)


if __name__ == "__main__":
    main()