# xref: /aosp_15_r20/external/pytorch/test/inductor/test_torchinductor_opinfo.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: inductor"]
import atexit
import contextlib
import functools
import os
import sys
import unittest
from collections import defaultdict
from enum import Enum
from functools import partial
from unittest.mock import patch

import torch
from torch._dispatch.python import enable_python_dispatcher
from torch._inductor.test_case import run_tests, TestCase
from torch._subclasses.fake_tensor import (
    DataDependentOutputException,
    DynamicOutputShapeException,
    FakeTensorMode,
)
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyNativeDeviceTypes,
    OpDTypes,
    ops,
    skipCPUIf,
    skipCUDAIf,
)
from torch.testing._internal.common_methods_invocations import op_db, skipOps
from torch.testing._internal.common_utils import (
    dtype_abbrs,
    IS_MACOS,
    IS_X86,
    skipCUDAMemoryLeakCheckIf,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TEST_MKL,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map


try:
    try:
        from .test_torchinductor import check_model, check_model_gpu
    except ImportError:
        from test_torchinductor import check_model, check_model_gpu
except (unittest.SkipTest, ImportError) as e:
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise

bf16 = torch.bfloat16  # not tested
f64 = torch.float64
f32 = torch.float32
f16 = torch.float16
i8 = torch.int8  # not tested
i16 = torch.int16  # not tested
i32 = torch.int32
i64 = torch.int64
b8 = torch.bool
u8 = torch.uint8  # not tested except upsampling and interpolate ops
u16 = torch.uint16  # not tested
u32 = torch.uint32  # not tested
u64 = torch.uint64  # not tested

_ops = partial(
    ops,
    dtypes=OpDTypes.supported,
    allowed_dtypes=[f16, f32, f64, i32, i64, b8, u8, u16, u32, u64],
)
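# _ops instantiates each test once per dtype the OpInfo reports as supported,
# restricted to the allowed_dtypes list above.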

# Success forces pass; failure forces fail; skip unconditionally skips testing
ExpectedTestResult = Enum("ExpectedTestResult", ("SUCCESS", "XFAILURE", "SKIP"))

COLLECT_EXPECT = os.getenv("PYTORCH_COLLECT_EXPECT", "0") == "1"
ALL_SAMPLES = os.getenv("PYTORCH_ALL_SAMPLES", "0") == "1"
START = os.getenv("PYTORCH_TEST_RANGE_START", None)
END = os.getenv("PYTORCH_TEST_RANGE_END", None)

if START is not None or END is not None:
    assert END is not None
    assert START is not None
    START = int(START)
    END = int(END)
    assert START < END
else:
    START = 0
    END = len(op_db)
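# Example invocations for the knobs above (shell commands, shown for reference):
#   PYTORCH_COLLECT_EXPECT=1 python test/inductor/test_torchinductor_opinfo.py
#       collect failures and print refreshed expected-failure dictionaries at exit
#   PYTORCH_TEST_RANGE_START=0 PYTORCH_TEST_RANGE_END=200 python test/inductor/test_torchinductor_opinfo.py
#       only run the slice op_db[0:200]
#   PYTORCH_ALL_SAMPLES=1 python test/inductor/test_torchinductor_opinfo.py
#       run every OpInfo sample instead of the inductor_one_sample allowlist below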

seen_failed = defaultdict(set)
failed_reasons = defaultdict(set)


def print_seen():
    expected_failures = defaultdict(list)

    def fmt_dtypes(dtypes):
        r = ", ".join(sorted(dtype_abbrs[d] for d in dtypes))
        return "{" + r + "}"

    def sort_key(kv):
        k, v = kv
        device_type, op = k
        if isinstance(op, tuple):
            return op
        else:
            return op, ""

    for (device_type, op), failed_dtypes in sorted(seen_failed.items(), key=sort_key):
        key = device_type, op
        reasons = ""
        if failed_reasons[key]:

            def maybe_truncate(x, length=80):
                x = str(x).replace("\n", " ")

                idx = x.find("\\n")
                if idx >= 0:
                    x = f"{x[:idx]}..."
                if len(x) > length:
                    return f"{x[:length - 3]}..."
                return x

            reasons = sorted(set(map(maybe_truncate, failed_reasons[key])))
            reasons = "  # " + ", ".join(reasons)

        if failed_dtypes:

            def format_op(op):
                if isinstance(op, tuple):
                    return f'("{op[0]}", "{op[1]}")'
                else:
                    return f'"{op}"'

            expected_failures[device_type].append(
                f"    {format_op(op)}: {fmt_dtypes(failed_dtypes)},{reasons}"
            )

    for device_type in ("cpu", GPU_TYPE):
        expected_failures[device_type]
        nl = "\n"
        print(
            f"""
inductor_expected_failures_single_sample[\"{device_type}\"] = {{
{nl.join(expected_failures[device_type])}
}}
"""
        )


if COLLECT_EXPECT:
    atexit.register(print_seen)

# Note: in these skip/xfail dictionaries, use a string as the key for the default
# test, and a tuple of two strings (op name, variant name) for variants

inductor_skips = defaultdict(dict)


inductor_skips["cpu"] = {
    "linalg.ldl_factor": {f32, f64},  # flaky
    "nn.functional.cosine_embedding_loss": {b8},  # flaky
    ("index_reduce", "prod"): {f16},  # flaky
    ("index_reduce", "mean"): {f16},  # flaky
}

if IS_MACOS and IS_X86:
    inductor_skips["cpu"]["rsqrt"] = {b8, i32}
    inductor_skips["cpu"]["nn.functional.multi_margin_loss"] = {
        b8,
        f16,
        f32,
        f64,
        i32,
        i64,
    }

inductor_skips["cuda"] = {
    # Jiterator kernel is not expected to work with inductor
    "jiterator_2inputs_2outputs": {b8, f16, f32, f64, i32, i64},
    "jiterator_4inputs_with_extra_args": {b8, f16, f32, f64, i32, i64},
    "jiterator_binary": {b8, f16, f32, f64, i32, i64},
    "jiterator_binary_return_by_ref": {b8, f16, f32, f64, i32, i64},
    "jiterator_unary": {b8, f16, f32, f64, i32, i64},
    # flaky
    "nn.functional.cosine_embedding_loss": {b8},
    "native_batch_norm": {f16, f32, f64},
    "_native_batch_norm_legit": {f16, f32, f64},
    "_batch_norm_with_update": {f16, f32, f64},
}

if not SM80OrLater:
    inductor_skips["cuda"]["bfloat16"] = {b8, f16, f32, f64, i32, i64}

if TEST_WITH_ROCM:
    # Tensors are not alike
    inductor_skips["cuda"]["logcumsumexp"] = {f32}
    inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}

inductor_expected_failures_single_sample = defaultdict(dict)

inductor_expected_failures_single_sample["cpu"] = {
    "_softmax_backward_data": {
        f16
    },  # half_to_float is only valid for the CUDA implementation
    "_upsample_bilinear2d_aa": {f32, f64},
    "cholesky": {f32, f64},
    "complex": {f16},
    "resize_": {b8, f16, f32, f64, i32, i64},
    "resize_as_": {b8, f16, f32, f64, i32, i64},
    "histc": {f16},
    "multinomial": {f16, f32, f64},
    "nn.functional.avg_pool1d": {i64},
    "nn.functional.avg_pool2d": {i64},
    "nn.functional.avg_pool3d": {i64},
    "nn.functional.local_response_norm": {i64},
    "nn.functional.rrelu": {f32, f64},
    "nonzero_static": {b8, f16, f32, f64, i32, i64},
    ("normal", "in_place"): {f16, f32, f64},
    ("normal", "number_mean"): {f16, f32, f64},
    "normal": {f16, f32, f64},
    ("sparse.mm", "reduce"): {f32, f64, f16},
    "sparse.sampled_addmm": {f32, f64},
    "to_sparse": {
        f32,
        f64,
    },  # NYI: could not find kernel for aten.view.default at dispatch key DispatchKey.SparseCPU
    "view_as_complex": {f16},
}


inductor_expected_failures_single_sample["cuda"] = {
    "_upsample_bilinear2d_aa": {f16, f32, f64},
    "cholesky": {f32, f64},
    "multinomial": {f16, f32, f64},
    ("normal", "in_place"): {f16, f32, f64},
    ("normal", "number_mean"): {f16, f32, f64},
    "normal": {f16, f32, f64},
    "sparse.sampled_addmm": {f32, f64},
    "torch.ops.aten._flash_attention_forward": {f16},
    "torch.ops.aten._efficient_attention_forward": {f16, f32},
    "to_sparse": {
        f16,
        f32,
        f64,
    },  # NYI: could not find kernel for aten.view.default at dispatch key DispatchKey.SparseCUDA
}


# intentionally not handled
intentionally_not_handled = {
    "resize_": {b8, f16, f32, f64, i32, i64},
    "resize_as_": {b8, f16, f32, f64, i32, i64},
}
# This is only fixed when this config is set
# We should eventually always turn it on
import torch._functorch.config as functorch_config


if not functorch_config.view_replay_for_aliased_outputs:
    intentionally_not_handled[("as_strided", "partial_views")] = {
        b8,
        f16,
        f32,
        f64,
        i32,
        i64,
    }

inductor_expected_failures_single_sample["cuda"].update(intentionally_not_handled)


inductor_gradient_expected_failures_single_sample = defaultdict(dict)

inductor_gradient_expected_failures_single_sample["cuda"] = {}

if not TEST_MKL:
    inductor_expected_failures_single_sample["cpu"].update({})

inductor_should_fail_with_exception = defaultdict(dict)
inductor_should_fail_with_exception["cpu"] = {}
inductor_should_fail_with_exception["cuda"] = {}


def get_skips_and_xfails(from_dict, xfails=True):
    retval = set()
    for device, d in from_dict.items():
        for op, dtypes in d.items():
            if type(op) is tuple:
                op, variant_name = op
            else:
                variant_name = ""
            retval.add((op, variant_name, device, tuple(dtypes), xfails))
    return retval
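# For illustration: with the definitions above, an entry such as
#     inductor_skips["cpu"] = {("index_reduce", "prod"): {f16}}
# is flattened into the tuple
#     ("index_reduce", "prod", "cpu", (f16,), False)
# i.e. (op_name, variant_name, device, dtypes, xfail), the format passed to skipOps below.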


# Note: if you get a "AssertionError: Couldn't find OpInfo for ..." error for an OpInfo you are sure
# exists, you might be trying to use a test variant and you need to replace, for example,
# "max.reduction_no_dim" with ("max", "reduction_no_dim") as the key of one of these dictionaries
test_skips_or_fails = (
    get_skips_and_xfails(inductor_skips, xfails=False)
    | get_skips_and_xfails(inductor_expected_failures_single_sample, xfails=True)
    | get_skips_and_xfails(
        inductor_gradient_expected_failures_single_sample, xfails=True
    )
)


def wrapper_noop_set_seed(op, *args, **kwargs):
    return op(*args, **kwargs)


torch.testing._internal.common_methods_invocations.wrapper_set_seed = (
    wrapper_noop_set_seed
)
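# wrapper_set_seed is monkeypatched to a no-op, so OpInfo samples are not reseeded
# before each call here; ops that use RNG are instead detected at runtime via the
# HasRngOp dispatch mode below and compared with assert_equal=False (see get_contexts).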


# key can be either op_name, or (op_name, device_type), or (op_name, device_type, dtype)
inductor_override_kwargs = {
    # the return value of empty is undefined
    "empty": {"assert_equal": False},
    "empty_permuted": {"assert_equal": False},
    "empty_like": {"assert_equal": False},
    "new_empty": {"assert_equal": False},
    "empty_strided": {"assert_equal": False},
    "new_empty_strided": {"assert_equal": False},
    "randn": {"assert_equal": False},
    ("cross", "cuda", f16): {"reference_in_float": True},
    ("linalg.cross", "cuda", f16): {"reference_in_float": True},
    ("addr", "cuda", f16): {"reference_in_float": True},
    ("baddbmm", "cuda", f16): {"atol": 2e-3, "rtol": 0.002},  # decomp affects accuracy
    ("angle", "cuda", f64): {"reference_in_float": True},
    ("asin", "cuda", f16): {"reference_in_float": True},
    ("atanh", "cuda", f16): {"reference_in_float": True},
    ("cauchy", "cuda"): {"reference_in_float": True},
    ("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002},
    ("cumsum", "cuda", f16): {"reference_in_float": True},
    ("cumprod", "cuda"): {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002},
    ("logcumsumexp", "cuda"): {"grad_atol": 8e-4, "grad_rtol": 0.001},
    ("exponential", "cuda"): {"reference_in_float": True},
    ("geometric", "cuda"): {"reference_in_float": True},
    ("kron", "cuda", f16): {"reference_in_float": True},
    ("log_normal", "cuda"): {"reference_in_float": True},
    ("masked.softmin", "cuda", f16): {"atol": 1e-4, "rtol": 0.01},
    ("nn.functional.batch_norm", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.batch_norm.without_cudnn", "cuda", f16): {
        "reference_in_float": True
    },
    ("nn.functional.cosine_similarity", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.instance_norm", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.local_response_norm", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.normalize", "cuda", f16): {"atol": 1e-3, "rtol": 0.05},
    ("nn.functional.rms_norm", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.soft_margin_loss", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.softmin", "cuda", f16): {"atol": 1e-4, "rtol": 0.01},
    ("nn.functional.softsign", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001},
    ("nn.functional.multilabel_soft_margin_loss", "cpu", f16): {
        "atol": 3e-4,
        "rtol": 0.002,
    },
    ("outer", "cuda", f16): {"reference_in_float": True},
    ("round.decimals_3", "cuda", f16): {"reference_in_float": True},
    ("nn.functional.triplet_margin_loss", "cuda", f16): {"atol": 1e-4, "rtol": 0.02},
    ("nn.functional.triplet_margin_with_distance_loss", "cuda", f16): {
        "atol": 1e-4,
        "rtol": 0.02,
    },
    ("sinc", "cuda", f16): {"atol": 0.008, "rtol": 0.002},
    ("torch.ops.aten._safe_softmax.default", "cuda", f16): {"atol": 5e-4, "rtol": 0.02},
    ("softmax", "cpu", f16): {"atol": 1e-4, "rtol": 0.02},
    ("softmax", "cuda", f16): {"atol": 1e-4, "rtol": 0.02},
    ("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002},
    ("special.log_ndtr", "cuda", f64): {"atol": 1e-6, "rtol": 1e-5},
    ("polygamma.polygamma_n_0", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
    ("polygamma.polygamma_n_1", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
    ("polygamma.polygamma_n_2", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
    ("polygamma.polygamma_n_3", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
    ("polygamma.polygamma_n_4", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
    ("special.polygamma.special_polygamma_n_0", "cpu", f32): {
        "atol": 1e-3,
        "rtol": 1e-4,
    },
    ("std_mean.unbiased", "cuda", f16): {"reference_in_float": True},
    ("uniform", "cuda"): {"reference_in_float": True},
    ("_unsafe_masked_index_put_accumulate", "cuda", f16): {"atol": 1e-4, "rtol": 0.01},
    ("_unsafe_masked_index_put_accumulate", "cpu", f16): {"atol": 1e-4, "rtol": 0.01},
    # The following tests fail with strict comparison, but atol=1 is acceptable due to rounding errors
    ("nn.functional.interpolate.bilinear", "cpu", u8): {"atol": 1, "rtol": 0},
    ("nn.functional.upsample_bilinear", "cpu", u8): {"atol": 1, "rtol": 0},
    ("nn.functional.interpolate.bicubic", "cpu", u8): {"atol": 1, "rtol": 0},
    # High atol due to precision loss
    ("nn.functional.interpolate.bilinear", "cuda", f64): {"atol": 5e-4, "rtol": 0},
    ("nn.functional.upsample_bilinear", "cuda", f64): {"atol": 5e-4, "rtol": 0},
    ("nn.functional.interpolate.bicubic", "cpu", f32): {"atol": 5e-3, "rtol": 0},
    ("nn.functional.interpolate.bicubic", "cuda", f64): {"atol": 1e-3, "rtol": 0},
    # Unreasonably high atol requirement:
    ("index_reduce.mean", "cuda", f16): {"check_gradient": False},
    ("index_reduce.mean", "cuda", f32): {"check_gradient": False},
    ("index_reduce.mean", "cuda", f64): {"check_gradient": False},
    # Gradient contains non-finite entries:
    ("index_reduce.amin", "cuda", f64): {"check_gradient": False},
    ("index_reduce.amin", "cuda", f32): {"check_gradient": False},
    ("index_reduce.amin", "cuda", f16): {"check_gradient": False},
    ("index_reduce.amax", "cuda", f64): {"check_gradient": False},
    ("index_reduce.amax", "cuda", f32): {"check_gradient": False},
    ("index_reduce.amax", "cuda", f16): {"check_gradient": False},
    ("tanh", "cuda", f16): {"atol": 1e-4, "rtol": 1e-2},
}


# Test with one sample only for the following ops
inductor_one_sample = {
    "_segment_reduce.lengths": {f16},
    "_segment_reduce.offsets": {f16},
    "addmv": {f16},
    "as_strided.partial_views": {f16},
    "corrcoef": {f16},
    "diff": {f16},
    "einsum": {f16, i32},
    "gradient": {f16},
    "histogram": {f32, f64},
    "histogramdd": {f32, f64},
    "index_put": {f16, f32, f64},
    "linalg.eig": {f32, f64},
    "linspace": {f16, i32, i64},
    "linspace.tensor_overload": {f16, f32, f64, i32, i64},
    "logspace": {f16},
    "logspace.tensor_overload": {f16, f32, f64, i32, i64},
    "masked_logsumexp": {i64},
    "max_pool2d_with_indices_backward": {f16, f32, f64},
    "new_empty_strided": {f16},
    "nn.functional.adaptive_avg_pool3d": {f16},
    "nn.functional.adaptive_max_pool1d": {f16, f32},
    "nn.functional.adaptive_max_pool2d": {f16, f32},
    "nn.functional.bilinear": {f16},
    "nn.functional.conv_transpose1d": {f16},
    "nn.functional.conv_transpose2d": {f16},
    "nn.functional.conv_transpose3d": {f16},
    "nn.functional.cosine_similarity": {f16},
    "nn.functional.cross_entropy": {f16, f32, f64},
    "nn.functional.gaussian_nll_loss": {f16},
    "nn.functional.grid_sample": {f32, f64},
    "nn.functional.interpolate.area": {f16},
    "nn.functional.nll_loss": {f16, f32, f64},
    "normal": {f16, f32, f64},
    "put": {f16, f32, f64},
    "take": {b8, f16, f32, f64, i32, i64},
    ("__rdiv__", "cuda"): {f16},
    ("__rmod__", "cuda"): {f16, i64},
    ("__rmul__", "cuda"): {f16},
    ("__rpow__", "cuda"): {f16},
    ("_unsafe_masked_index", "cuda"): {f16},
    ("_unsafe_masked_index_put_accumulate", "cuda"): {f16},
    ("addcdiv", "cuda"): {f16},
    ("addcmul", "cuda"): {f16},
    ("atan2", "cuda"): {f16},
    ("cumsum", "cuda"): {f16},
    ("cumulative_trapezoid", "cuda"): {f16},
    ("dist", "cuda"): {f16},
    ("div.no_rounding_mode", "cuda"): {f16},
    ("fmod", "cuda"): {f16},
    ("grid_sampler_2d", "cuda"): {f16},
    ("index_fill", "cuda"): {f16, f32, f64},
    ("ldexp", "cuda"): {f16},
    ("lerp", "cuda"): {f16},
    ("linalg.householder_product", "cuda"): {f32},
    ("linalg.matrix_norm", "cuda"): {f16},
    ("linalg.vector_norm", "cuda"): {f16},
    ("logspace", "cuda"): {i32, i64},
    ("masked.cumsum", "cuda"): {f16},
    ("masked.logsumexp", "cuda"): {f16},
    ("masked.mean", "cuda"): {b8},
    ("masked.normalize", "cuda"): {f16},
    ("masked.prod", "cuda"): {f16},
    ("masked.std", "cuda"): {f16},
    ("masked.var", "cuda"): {f16},
    ("mul", "cuda"): {f16},
    ("nn.functional.alpha_dropout", "cuda"): {f16, f32, f64},
    ("nn.functional.avg_pool1d", "cuda"): {f16, f32, f64},
    ("nn.functional.avg_pool2d", "cuda"): {f16, f32, f64},
    ("nn.functional.avg_pool3d", "cuda"): {f16, f32, f64},
    ("nn.functional.binary_cross_entropy", "cuda"): {f16},
    ("nn.functional.binary_cross_entropy_with_logits", "cuda"): {f16},
    ("nn.functional.conv2d", "cuda"): {f16},
    ("nn.functional.cosine_embedding_loss", "cuda"): {f16},
    ("nn.functional.dropout2d", "cuda"): {f16, f32, f64},
    ("nn.functional.dropout3d", "cuda"): {f16, f32, f64},
    ("nn.functional.dropout", "cuda"): {f16, f32, f64},
    ("nn.functional.feature_alpha_dropout.with_train", "cuda"): {f16, f32, f64},
    ("nn.functional.fractional_max_pool2d", "cuda"): {f16, f32, f64},
    ("nn.functional.fractional_max_pool3d", "cuda"): {f16, f32, f64},
    ("nn.functional.grid_sample", "cuda"): {f16},
    ("nn.functional.group_norm", "cuda"): {f16},
    ("nn.functional.hinge_embedding_loss", "cuda"): {f16},
    # Running all samples for this op fails randomly
    # See https://github.com/pytorch/pytorch/issues/129238
    ("nn.functional.huber_loss", "cuda"): {f16},
    ("nn.functional.interpolate.bicubic", "cuda"): {f16},
    ("nn.functional.interpolate.bilinear", "cuda"): {f16},
    ("nn.functional.interpolate.trilinear", "cuda"): {f16},
    ("nn.functional.kl_div", "cuda"): {f16},
    ("nn.functional.margin_ranking_loss", "cuda"): {f16},
    ("nn.functional.max_pool1d", "cuda"): {f16, f32, f64},
    ("nn.functional.max_pool3d", "cuda"): {f16},
    ("nn.functional.mse_loss", "cuda"): {f16},
    ("nn.functional.multi_margin_loss", "cuda"): {f16},
    ("nn.functional.multilabel_margin_loss", "cuda"): {f16},
    ("nn.functional.multilabel_soft_margin_loss", "cuda"): {f16},
    ("nn.functional.normalize", "cuda"): {f16},
    ("nn.functional.pad.replicate", "cuda"): {f16, f32, f64},
    ("nn.functional.pad.reflect", "cuda"): {f16},
    ("nn.functional.pairwise_distance", "cuda"): {f16},
    ("nn.functional.poisson_nll_loss", "cuda"): {f16},
    ("nn.functional.rms_norm", "cuda"): {f16},
    ("norm", "cuda"): {f16},
    ("pow", "cuda"): {f16},
    ("prod", "cuda"): {f16},
    ("scatter_reduce.amax", "cuda"): {f16, f32, f64},
    ("scatter_reduce.amin", "cuda"): {f16, f32, f64},
    ("scatter_reduce.mean", "cuda"): {f16, f32, f64},
    ("special.xlog1py", "cuda"): {f16},
    ("std", "cuda"): {f16},
    ("std_mean", "cuda"): {f16},
    ("svd_lowrank", "cuda"): {f32, f64},
    ("trapezoid", "cuda"): {f16},
    ("trapz", "cuda"): {f16},
    ("true_divide", "cuda"): {f16},
    ("var", "cuda"): {f16},
    ("var_mean", "cuda"): {f16},
    ("xlogy", "cuda"): {f16},
}
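# Ops/dtypes listed above run against just the first OpInfo sample unless
# PYTORCH_ALL_SAMPLES=1 is set; see the sample-selection logic in test_comprehensive.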


def collection_decorator(fn):
    @functools.wraps(fn)
    def inner(self, device, dtype, op):
        try:
            fn(self, device, dtype, op)
        except Exception as e:
            if COLLECT_EXPECT:
                variant = op.variant_test_name
                op_key = op.name if not variant else (op.name, variant)
                device_type = torch.device(device).type
                # failed_reasons[device_type, op_key].add(repr(e))
                seen_failed[device_type, op_key].add(dtype)
            raise e

    return inner
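# When PYTORCH_COLLECT_EXPECT=1, collection_decorator records the failing
# (device_type, op) -> dtypes mapping in seen_failed so that print_seen
# (registered with atexit above) can emit refreshed expected-failure dictionaries.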


class TestInductorOpInfo(TestCase):
    def tearDown(self):
        torch._dynamo.reset()

    check_model = check_model
    check_model_gpu = check_model_gpu

    @onlyNativeDeviceTypes
    @suppress_warnings
    @skipCUDAMemoryLeakCheckIf(
        True
    )  # inductor kernels failing this test intermittently
    @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found")
    @skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found")
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfTorchDynamo("Test uses dynamo already")
    @skipIfCrossRef
    @_ops(op_db[START:END])
    @skipOps("TestInductorOpInfo", "test_comprehensive", test_skips_or_fails)
    @patch("torch._dynamo.config.raise_on_unsafe_aot_autograd", True)
    @torch._inductor.config.patch(
        {"implicit_fallbacks": False, "triton.autotune_pointwise": False}
    )
    @collection_decorator
    def test_comprehensive(self, device, dtype, op):
        device_type = torch.device(device).type

        assert device_type in (GPU_TYPE, "cpu")

        torch._dynamo.reset()
        with torch.no_grad():
            # TODO: should we move empty_cache to the common device interface
            if device_type == "cuda":
                torch.cuda.empty_cache()
        op_name = op.name
        if op.variant_test_name:
            op_name += f".{op.variant_test_name}"

        # Skip dtype=torch.uint8 for all ops except upsample and interpolate:
        allowed_dtypes = [f16, f32, f64, i32, i64, b8]
        if op_name not in (
            "nn.functional.interpolate.bilinear",
            "nn.functional.interpolate.bicubic",
            "nn.functional.upsample_bilinear",
            "nn.functional.upsample_nearest",
        ):
            if dtype not in allowed_dtypes:
                raise unittest.SkipTest("Skipped!")

        # with open("test_output.txt", "a") as f:
        #     print(f"CONSIDERING OP {op_name} on {device_type} with {dtype} |
        # {inductor_skips[device_type].get(op_name, set())}", flush=True, file=f)
        #     print(f"CONSIDERING OP {op_name} on {device_type} with {dtype} |
        # {inductor_skips[device_type].get(op_name, set())}", flush=True)
        if dtype in inductor_skips[device_type].get(op_name, set()):
            test_expect = ExpectedTestResult.SKIP
            # with open("test_output.txt", "a") as f:
            #     print(f"SKIPPING OP {op_name} on {device_type}", flush=True, file=f)
            #     print(f"SKIPPING OP {op_name} on {device_type}", flush=True)
        elif dtype in inductor_expected_failures_single_sample[device_type].get(
            op_name, set()
        ) or dtype in inductor_gradient_expected_failures_single_sample[
            device_type
        ].get(
            op_name, set()
        ):
            test_expect = ExpectedTestResult.XFAILURE
        else:
            test_expect = ExpectedTestResult.SUCCESS

        overridden_kwargs = {}
        if op_name in inductor_override_kwargs:
            overridden_kwargs = inductor_override_kwargs[op_name]
        elif (op_name, device_type) in inductor_override_kwargs:
            overridden_kwargs = inductor_override_kwargs[(op_name, device_type)]
        elif (op_name, device_type, dtype) in inductor_override_kwargs:
            overridden_kwargs = inductor_override_kwargs[(op_name, device_type, dtype)]
        func = op.get_op()

        def fn(*args, **kwargs):
            return func(*args, **kwargs)
        requires_grad = (
            op.supports_autograd
            and dtype in op.supported_backward_dtypes(device_type)
            # TODO: OpInfo really ought to error out for this case, but it's
            # not exercised in test_ops_gradients atm.  The problem is not
            # complex32 per se (which is supported by data-movement-only ops)
            # but that when we do backwards we expect other ops like add to work
            and not dtype == torch.complex32
        )
        samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)

        if (
            dtype in inductor_one_sample.get(op_name, {})
            or dtype in inductor_one_sample.get((op_name, device_type), {})
        ) and not ALL_SAMPLES:
            if isinstance(samples, (list, tuple)):
                samples = [samples[0]]
            else:
                samples = [next(samples)]

        class HasRngOp(TorchDispatchMode):
            def __init__(self) -> None:
                super().__init__()
                self.has_rng_op = False

            def __torch_dispatch__(self, func, types, args, kwargs=None):
                kwargs = kwargs if kwargs else {}
                if torch.Tag.nondeterministic_seeded in func.tags:
                    self.has_rng_op = True

                return func(*args, **kwargs)

        def do_nopython_and_has_rng(fn, args, kwargs):
            try:
                mode = FakeTensorMode()

                def map_to_fake(e):
                    if isinstance(e, torch.Tensor):
                        return mode.from_tensor(e)
                    else:
                        return e

                args, kwargs = tree_map(map_to_fake, (args, kwargs))
                with HasRngOp() as rng_mode, mode:
                    with enable_python_dispatcher():
                        fn(*args, **kwargs)

            except (DataDependentOutputException, DynamicOutputShapeException):
                return False, rng_mode.has_rng_op

            return True, rng_mode.has_rng_op
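        # do_nopython_and_has_rng traces fn under FakeTensorMode and returns
        # (nopython_ok, has_rng_op): nopython_ok is False when tracing raises a
        # data-dependent-output or dynamic-output-shape exception (so the compiled
        # run can't insist on nopython), and has_rng_op reports whether any op
        # tagged nondeterministic_seeded was dispatched.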

        def get_contexts(has_rng_op):
            if has_rng_op:
                # TODO - enable this, running into errors
                return (
                    # (
                    #     lambda: torch._inductor.config.patch(
                    #         {"fallback_random": True, "implicit_fallbacks": True}
                    #     ),
                    #     {"assert_equal": True},
                    # ),
                    (
                        contextlib.nullcontext,
                        {"assert_equal": False},
                    ),
                )
            return ((contextlib.nullcontext, {}),)
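        # For ops with RNG, exact comparison against eager is currently disabled
        # (assert_equal=False); the commented-out branch above would instead run
        # under fallback_random=True and compare exactly once it works.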

        try:

            def _get_tolerances(dtype):
                _custom_tolerances = {
                    torch.float32: (1.3e-5, 1.5e-5),
                }
                if dtype in _custom_tolerances:
                    return _custom_tolerances[dtype]
                else:
                    return None, None

            for sample_input in samples:
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs
                # UNCOMMENT TO DEBUG SEGFAULTS

                # with open("test_output.txt", "a") as f:
                #     print(f"RUNNING OP {op_name} on {device_type} with {dtype}", flush=True, file=f)
                #     print(f"RUNNING OP {op_name} on {device_type} with {dtype}", flush=True)
                rtol, atol = _get_tolerances(dtype)
                if device_type == GPU_TYPE:
                    # OpInfo test cases have already placed the input on the correct device,
                    # so we avoid an additional copy by setting copy_to_gpu=False

                    no_python, has_rng_op = do_nopython_and_has_rng(fn, args, kwargs)
                    for context_fn, kwarg_overrides in get_contexts(has_rng_op):
                        with context_fn():
                            adjusted_kwargs = {
                                "check_lowp": False,
                                "nopython": no_python,
                                "copy_to_gpu": False,
                                "reference_in_float": False,
                                "check_gradient": requires_grad,
                                "check_has_compiled": no_python,
                                "output_process_fn_grad": sample_input.output_process_fn_grad,
                                "atol": atol,
                                "rtol": rtol,
                            }
                            adjusted_kwargs.update(overridden_kwargs)
                            adjusted_kwargs.update(kwarg_overrides)
                            self.check_model_gpu(
                                fn,
                                args,
                                kwargs,
                                **adjusted_kwargs,
                            )
                elif device_type == "cpu":
                    no_python, has_rng_op = do_nopython_and_has_rng(fn, args, kwargs)
                    for context_fn, kwarg_overrides in get_contexts(has_rng_op):
                        with context_fn():
                            adjusted_kwargs = {
                                "check_lowp": False,
                                "nopython": no_python,
                                "check_has_compiled": no_python,
                                # skip checking gradient on CPU for now
                                "check_gradient": False,
                                "atol": atol,
                                "rtol": rtol,
                            }
                            adjusted_kwargs.update(overridden_kwargs)
                            adjusted_kwargs.update(kwarg_overrides)

                            self.check_model(
                                fn,
                                args,
                                kwargs,
                                **adjusted_kwargs,
                            )

        except Exception as e:
            known_failure = False
            if dtype in inductor_should_fail_with_exception[device_type].get(
                op_name, set()
            ):
                failure = inductor_should_fail_with_exception[device_type][op_name][
                    dtype
                ]
                if failure in str(e):
                    known_failure = True
            if not known_failure:
                raise e

        # with open("test_output.txt", "a") as f:
        #     print(f"SUCCEEDED OP {op_name} on {device_type} with {dtype}", flush=True, file=f)


instantiate_device_type_tests(TestInductorOpInfo, globals())

if __name__ == "__main__":
    run_tests()