import argparse
import datetime
import re
import sys
import warnings
from collections import defaultdict

import torch
from torch._C import parse_schema


# How to run this test locally:
# 1. Have two virtual environments (e.g. conda envs), one without PyTorch
#    installed (venv_nightly) and one with your local changes (venv_yours).
# In venv_nightly:
# 2. First ensure that PyTorch is uninstalled, but all of its prerequisites are installed.
# 3. Install the torch nightly build with
#    `pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html`
# 4. Generate the original schemas with
#    `python test/forward_backward_compatibility/dump_all_function_schemas.py --filename nightly_schemas.txt`
# Now in venv_yours:
# 5. Run this test with
#    `python test/forward_backward_compatibility/check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt`

# The date specifies until when the allowlist exclusion applies.
#
#   - If we NEVER give a BC guarantee for an operator, you can put the
#     date arbitrarily far in the future.
#   - Otherwise, pick a date that is far enough in the future that you
#     believe you can land your diff before then.
#
# Allowlist entries can be removed after the date listed on them passes.
#
# Allowlist item format:
# [
#   0: function name regex
#   1: date until which the allowlist entry is valid
#   2: (optional) function argument regex
# ]
#
# NB: function name DOES NOT include overload name!
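#
# Example (hypothetical entry, for illustration only) using the optional
# argument regex, which narrows the exclusion to schemas whose string form
# also matches it:
#
#   ("aten::some_op", datetime.date(2099, 1, 1), "Tensor self"),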
ALLOW_LIST = [
    ("c10_experimental", datetime.date(9999, 1, 1)),
    # Internal
    ("static", datetime.date(9999, 1, 1)),
    ("prim::ModuleDictIndex", datetime.date(9999, 1, 1)),
    ("prim::MKLDNNRelu6", datetime.date(9999, 1, 1)),
    ("prim::MKLDNNRelu6_", datetime.date(9999, 1, 1)),
    ("prim::is_ort", datetime.date(9999, 1, 1)),
    ("prim::Concat", datetime.date(9999, 1, 1)),
    ("aten::_NestedTensor_GeneralizedBMM", datetime.date(9999, 1, 1)),
    # Internal, profiler-specific ops
    ("profiler::_call_end_callbacks_on_jit_fut*", datetime.date(9999, 1, 1)),
    ("profiler::_record_function_enter", datetime.date(9999, 1, 1)),
    ("aten::_cholesky_helper", datetime.date(9999, 1, 1)),
    ("aten::_lstsq_helper", datetime.date(9999, 1, 1)),
    ("aten::_syevd_helper", datetime.date(9999, 1, 1)),
    ("aten::_linalg_solve_out_helper_", datetime.date(9999, 1, 1)),
    ("aten::select_backward", datetime.date(9999, 1, 1)),
    ("aten::lstsq", datetime.date(9999, 1, 1)),
    ("aten::lstsq.X", datetime.date(9999, 1, 1)),
    ("aten::slice_backward", datetime.date(9999, 1, 1)),
    ("aten::diagonal_backward", datetime.date(9999, 1, 1)),
    ("aten::rowwise_prune", datetime.date(9999, 1, 1)),
    ("aten::eig", datetime.date(9999, 1, 1)),
    ("aten::eig.e", datetime.date(9999, 1, 1)),
    ("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)),
    ("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
    ("aten::matrix_rank", datetime.date(9999, 1, 1)),
    ("aten::matrix_rank.tol", datetime.date(9999, 1, 1)),
    ("aten::randperm", datetime.date(9999, 1, 1)),
    ("aten::solve", datetime.date(9999, 1, 1)),
    ("aten::solve.solution", datetime.date(9999, 1, 1)),
    ("aten::_solve_helper", datetime.date(9999, 1, 1)),
    ("aten::_convolution_nogroup", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_backward", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_backward_bias", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_backward_input", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_backward_weight", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_transpose_backward", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_transpose_backward_input", datetime.date(9999, 1, 1)),
    ("aten::miopen_convolution_transpose_backward_weight", datetime.date(9999, 1, 1)),
    ("aten::miopen_depthwise_convolution_backward", datetime.date(9999, 1, 1)),
    ("aten::miopen_depthwise_convolution_backward_input", datetime.date(9999, 1, 1)),
    ("aten::miopen_depthwise_convolution_backward_weight", datetime.date(9999, 1, 1)),
    ("aten::_nested_tensor", datetime.date(9999, 1, 1)),
    ("prepacked::unpack_prepacked_sizes_conv2d", datetime.date(9999, 1, 1)),
    ("prepacked::unpack_prepacked_sizes_linear", datetime.date(9999, 1, 1)),
    ("aten::_symeig_helper", datetime.date(9999, 1, 1)),
    ("aten::symeig", datetime.date(9999, 1, 1)),
    ("aten::symeig.e", datetime.date(9999, 1, 1)),
    ("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
    ("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
    ("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)),
    ("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
    ("prim::infer_squeeze_size.dim", datetime.date(9999, 1, 1)),
    ("prim::infer_squeeze_size", datetime.date(9999, 1, 1)),
    ("aten::_weight_norm_cuda_interface", datetime.date(9999, 1, 1)),
    ("aten::_weight_norm_cuda_interface_backward", datetime.date(9999, 1, 1)),
    ("aten::empty.SymInt", datetime.date(9999, 1, 1)),
    # nested tensor temporary auxiliary ops
    ("aten::_reshape_nested", datetime.date(9999, 1, 1)),
    ("aten::_reshape_nested_backward", datetime.date(9999, 1, 1)),
    ("aten::mps_linear", datetime.date(9999, 1, 1)),
    ("aten::_mps_linear", datetime.date(9999, 1, 1)),
    ("aten::_mps_max_pool2d", datetime.date(9999, 1, 1)),
    ("aten::_mps_max_pool2d.out", datetime.date(9999, 1, 1)),
    ("aten::mps_max_pool2d_backward", datetime.date(9999, 1, 1)),
    ("aten::mps_max_pool2d_backward.out", datetime.date(9999, 1, 1)),
    # TODO: FIXME: prims shouldn't be checked
    ("prims::.*", datetime.date(9999, 1, 1)),
    ("aten::_flash_attention_forward", datetime.date(2023, 12, 30)),
    ("aten::_flash_attention_backward", datetime.date(2023, 12, 30)),
    ("aten::_scaled_dot_product_cudnn_attention", datetime.date(9999, 1, 1)),
    ("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
    # BetterTransformer 1.0 internal operators
    ("aten::_transformer_decoder_only_layer_fwd", datetime.date(9999, 1, 1)),
    ("aten::_native_decoder_only_multi_head_attention", datetime.date(9999, 1, 1)),
    ("c10d::_allgather_base_", datetime.date(2023, 12, 30)),
    ("c10d::_reduce_scatter_base_", datetime.date(2023, 12, 30)),
    ("c10d::broadcast_", datetime.date(2023, 12, 30)),
    ("c10d::scatter_", datetime.date(2023, 12, 30)),
    # These ops were moved to Python under the c10d_functional namespace
    ("aten::wait_tensor", datetime.date(9999, 1, 30)),
    ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
    ("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
    ("aten::all_reduce", datetime.date(9999, 1, 30)),
    ("aten::to_sparse.out", datetime.date(2023, 12, 31)),
    ("aten::to_sparse.sparse_dim_out", datetime.date(2023, 12, 31)),
    ("aten::to_sparse_bsc.out", datetime.date(2023, 12, 31)),
    ("aten::to_sparse_bsr.out", datetime.date(2023, 12, 31)),
    ("aten::to_sparse_csc.out", datetime.date(2023, 12, 31)),
    ("aten::to_sparse_csr.out", datetime.date(2023, 12, 31)),
    ("aten::_structured_sparse_linear", datetime.date(2023, 12, 31)),
    ("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)),
    ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
    ("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
    ("aten::_efficient_attention_forward", datetime.date(2024, 7, 1)),
    ("aten::_efficient_attention_backward", datetime.date(2024, 7, 1)),
    ("onednn::qconv1d_pointwise", datetime.date(2024, 12, 31)),
    ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)),
    ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)),
    ("onednn::qconv2d_pointwise.binary", datetime.date(2024, 12, 31)),
    ("onednn::qlinear_pointwise.binary", datetime.date(2024, 12, 31)),
    ("onednn::qlinear_pointwise.binary_tensor", datetime.date(2024, 12, 31)),
    ("aten::_scaled_mm.out", datetime.date(2024, 12, 31)),
    ("aten::_scaled_mm", datetime.date(2024, 12, 31)),
    ("aten::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)),
    ("aten::wrapped_linear_prepack", datetime.date(2024, 12, 31)),
    ("_quantized::wrapped_linear_prepack", datetime.date(2024, 12, 31)),
    ("_quantized::wrapped_linear_prepacked", datetime.date(2024, 12, 31)),
    ("_quantized::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)),
    # BC-breaking change in can_cast signature: 'from' -> 'from_'
    ("aten::can_cast", datetime.date(2024, 5, 31)),
]

# Pre-compile the name (and optional argument) regexes, dropping entries
# whose expiration date has already passed so that stale allowlist entries
# stop matching automatically.
ALLOW_LIST_COMPILED = [
    (
        re.compile(item[0]),
        item[1],
        re.compile(item[2]) if len(item) > 2 else None,
    )
    for item in ALLOW_LIST
    if item[1] >= datetime.date.today()
]


def allow_listed(schema):
    for item in ALLOW_LIST_COMPILED:
        if item[0].search(str(schema)):
            # if an argument regex is present, it must match too
            if item[2] is not None:
                return bool(item[2].search(str(schema)))
            return True
    return False
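
# Illustrative (hypothetical) example: given a live allowlist entry
# ("aten::foo", datetime.date(2099, 1, 1)), calling
# allow_listed(parse_schema("aten::foo(Tensor self) -> Tensor")) returns True,
# since the compiled name regex matches the schema's string form.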


# The nightly build will fail to parse newly added syntax in schema
# declarations. Add new schemas that will fail to parse in the nightly here.
dont_parse_list = [
    ("_TorchScriptTesting.*", datetime.date(2099, 9, 17)),
    ("test_backend", datetime.date(2099, 9, 17)),
    ("dist_c10d", datetime.date(2099, 9, 17)),
    ("__backends__.nnc", datetime.date(2099, 9, 17)),
]


def has_valid_upgraders(schema, version_map):
    # We walk through the map to find out whether the schema has valid
    # upgraders. Since the version map has an entry for each overload,
    # we need to do some ugly parsing.

    # the name of the operator, without the overload
    schema_name = schema.name

    if schema_name not in version_map:
        return False

    entries = version_map[schema_name]

    # collect the old schemas registered by upgraders across all overloads
    possible_schemas = []
    for upgrader_schema_entries in entries.values():
        possible_schemas.extend(upgrader_schema_entries)

    # make sure this existing schema is one of those old schemas
    return any(old_schema == schema for old_schema in possible_schemas)
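
# Note: an existing schema is only treated as covered when it compares equal
# to some upgrader's registered old_schema; a merely similar schema does not
# count.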


def dont_parse(schema_line):
    for item in dont_parse_list:
        if item[1] < datetime.date.today():
            continue
        regexp = re.compile(item[0])
        if regexp.search(schema_line):
            return True
    return False


def load_schemas_to_dict():
    new_schemas = torch._C._jit_get_all_schemas()
    new_schemas += torch._C._jit_get_custom_class_schemas()
    new_schema_dict = defaultdict(list)
    for s in new_schemas:
        new_schema_dict[s.name].append(s)
    return new_schema_dict
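
# The resulting dict groups the currently registered schemas by operator
# name, e.g. (illustrative) new_schema_dict["aten::div"] holds the schemas
# of all aten::div overloads (aten::div.Tensor, aten::div.Scalar, ...).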


def process_version_map(version_map):
    # The version map maps a full schema name (including the overload) to a
    # list of upgraders. Since we only have the operator name (i.e. no
    # overload) at lookup time, we first reshape the map to make the key
    # lookup easier. Afterwards it is:
    # Dict[operator_name, Dict[overload, List[schema]]]

    output = defaultdict(dict)
    for key, entries in version_map.items():
        operator_name = key.split(".")[0]
        schema_entries = [parse_schema(entry.old_schema) for entry in entries]
        output[operator_name][key] = schema_entries
    return output
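
# Illustrative example of the reshaping, assuming a single upgrader entry
# registered for a hypothetical overload "aten::foo.bar":
#   {"aten::foo.bar": [entry]}
# becomes
#   {"aten::foo": {"aten::foo.bar": [parse_schema(entry.old_schema)]}}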


def check_bc(existing_schemas):
    new_schema_dict = load_schemas_to_dict()
    version_map = process_version_map(torch._C._get_operator_version_map())
    is_bc = True
    broken_ops = []
    for existing_schema in existing_schemas:
        if allow_listed(existing_schema):
            print("schema: ", str(existing_schema), " found on allowlist, skipping")
            continue
        if has_valid_upgraders(existing_schema, version_map):
            print("schema: ", str(existing_schema), " has a valid upgrader, skipping")
            continue
        print("processing existing schema: ", str(existing_schema))
        matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
        found = False
        for matching_new_schema in matching_new_schemas:
            if matching_new_schema.is_backward_compatible_with(existing_schema):
                found = True
                break
        if not found:
            print(
                "Can NOT find backward compatible schemas after changes "
                "for schema {} from the following candidates:\n[\n{}\n]".format(
                    str(existing_schema),
                    "\n\t".join(str(s) for s in matching_new_schemas),
                )
            )
            # TODO Print out more details about why candidates don't match.
            broken_ops.append(str(existing_schema))
            is_bc = False
    if is_bc:
        print("Found backward compatible schemas for all existing schemas")
    else:
        print(
            "The PR is introducing backward incompatible changes to the "
            "operator library. Please contact the PyTorch team to confirm "
            "whether this change is wanted or not. \n\nBroken ops: "
            "[\n\t{}\n]".format("\n\t".join(broken_ops))
        )
    return is_bc


def check_fc(existing_schemas):
    new_schema_dict = load_schemas_to_dict()
    is_fc = True
    broken_ops = []
    for existing_schema in existing_schemas:
        if allow_listed(existing_schema):
            print("schema: ", str(existing_schema), " found on allowlist, skipping")
            continue
        print("processing existing schema: ", str(existing_schema))
        matching_new_schemas = new_schema_dict.get(existing_schema.name, [])
        found = False
        possible_failure_reasons = []
        for matching_new_schema in matching_new_schemas:
            is_compatible, reason = matching_new_schema.check_forward_compatible_with(
                existing_schema
            )
            if is_compatible:
                found = True
                break
            if reason != "":
                possible_failure_reasons.append(reason)
        if not found:
            print(
                "Can NOT find forward compatible schemas after changes "
                "for schema {} from the following candidates:\n[\n{}\n]".format(
                    str(existing_schema),
                    "\n\t".join(str(s) for s in matching_new_schemas),
                )
            )
            print(
                "Refer to the following reasons for the failure "
                "to find an FC schema:\n[\n{}\n]".format(
                    "\n\t".join(str(r) for r in possible_failure_reasons)
                )
            )
            broken_ops.append(str(existing_schema))
            is_fc = False
    if is_fc:
        print("Found forward compatible schemas for all existing schemas")
    else:
        warnings.warn(
            "The PR is introducing potentially forward incompatible changes to the "
            "operator library. Please contact the PyTorch team to confirm "
            "whether this change is wanted or not. \n\nBroken ops: "
            "[\n\t{}\n]".format("\n\t".join(broken_ops))
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check operator schemas for forward and backward compatibility."
    )
    parser.add_argument(
        "--existing-schemas",
        help="filename to load existing schemas from",
        type=str,
        default="schemas.txt",
    )
    args = parser.parse_args()
    slist = []
    with open(args.existing_schemas) as f:
        for line in f:
            if dont_parse(line.strip()):
                print("Not parsing schema line: ", line.strip())
                continue
            s = parse_schema(line.strip())
            slist.append(s)

    # TODO: in case there are FC-breaking changes,
    # we just warn for now until there is a policy.
    check_fc(slist)

    if not check_bc(slist):
        sys.exit(1)