xref: /aosp_15_r20/external/pytorch/test/test_jit_fuser_te.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["NNC"]
2
3import contextlib
4import math
5import operator
6import os
7import unittest
8import warnings
9from typing import List
10
11import torch
12import torch.nn.functional as F
13from torch.testing import FileCheck
14
15
16# these needs to be set before `common_utils`
17# infers `GRAPH_EXECUTOR`.
18# this file **requires** these settings
19# and setting them after `GRAPH_EXECUTOR` is
20# inferred erroneously runs or skips
21# some tests
22torch._C._jit_set_profiling_executor(True)
23torch._C._get_graph_executor_optimize(True)
24
25from itertools import combinations, permutations, product
26from textwrap import dedent
27
28from jit.test_fuser_common import TestFuserCommon  # noqa: F401
29from test_jit import (
30    backward_graph,
31    get_lstm_inputs,
32    get_milstm_inputs,
33    LSTMCellC,
34    LSTMCellF,
35    LSTMCellS,
36    MiLSTMCell,
37)
38
39from torch.testing._internal.common_device_type import (
40    instantiate_device_type_tests,
41    onlyCPU,
42    OpDTypes,
43    ops,
44)
45from torch.testing._internal.common_jit import JitCommonTestCase
46from torch.testing._internal.common_methods_invocations import op_db
47from torch.testing._internal.common_utils import (
48    enable_profiling_mode_for_profiling_tests,
49    GRAPH_EXECUTOR,
50    IS_FBCODE,
51    ProfilingMode,
52    run_tests,
53    skipIfTorchDynamo,
54    slowTest,
55    TEST_WITH_ASAN,
56    TEST_WITH_ROCM,
57)
58from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
59from torch.testing._internal.jit_utils import (
60    clone_inputs,
61    get_traced_sample_variant_pairs,
62    JitTestCase,
63    NoTracerWarnContextManager,
64    RUN_CUDA,
65    RUN_CUDA_HALF,
66    RUN_CUDA_MULTI_GPU,
67    set_fusion_group_inlining,
68    TensorExprTestOptions,
69    warmup_backward,
70)
71
72
73FUSION_GROUP = "prim::TensorExprGroup"
74LLVM_ENABLED = torch._C._llvm_enabled()
75
76autograd_check_set = {
77    "aten::__is__",
78    "prim::AutogradAllNonZero",
79    "prim::AutogradAllZero",
80    "prim::ListConstruct",
81}
82
83
84def strip_profiling_nodes(nodes):
85    profiling_opcodes = {"prim::BailoutTemplate", "prim::BailOut"}
86    return [n for n in nodes if n.kind() not in profiling_opcodes]
87
88
89def warmup_forward(f, *args, profiling_count=2):
90    for i in range(profiling_count):
91        results = f(*args)
92
93    return results
94
95
96@contextlib.contextmanager
97def texpr_reductions_enabled():
98    old = torch._C._jit_set_texpr_reductions_enabled(True)
99    try:
100        yield
101    finally:
102        torch._C._jit_set_texpr_reductions_enabled(old)
103
104
105@contextlib.contextmanager
106def texpr_enable_strategy(strategy):
107    old = torch._C._jit_set_fusion_strategy(strategy)
108    try:
109        yield
110    finally:
111        torch._C._jit_set_fusion_strategy(old)
112
113
114@contextlib.contextmanager
115def inline_fusion_groups():
116    old_inlining = torch._C._debug_get_fusion_group_inlining()
117    torch._C._debug_set_fusion_group_inlining(True)
118    try:
119        yield
120    finally:
121        torch._C._debug_set_fusion_group_inlining(old_inlining)
122
123
124class TestTEFuser(JitTestCase):
125    def setUp(self):
126        super().setUp()
127        self.tensorexpr_options = TensorExprTestOptions()
128
129        # note: `self.dynamic_shapes` instatiated in specialization of class
130        # defined below
131
132        fusion_strategy = [("DYNAMIC", 20)] if self.dynamic_shapes else [("STATIC", 20)]
133        self.old_fusion_strategy = torch._C._jit_set_fusion_strategy(fusion_strategy)
134
135        self.devices = ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
136        self.int_dtypes = [
137            torch.int8,
138            torch.int16,
139            torch.int32,
140            torch.int64,
141            torch.bool,
142        ]
143        self.fp_dtypes = [
144            torch.float16,
145            torch.float32,
146            torch.float64,
147            torch.bfloat16,
148        ]
149        self.dtypes = self.int_dtypes + self.fp_dtypes
150
151    def tearDown(self):
152        self.tensorexpr_options.restore()
153        torch._C._jit_set_fusion_strategy(self.old_fusion_strategy)
154        super().tearDown()
155
156    def assertAllFused(self, graph, except_for=None):
157        except_for = except_for if except_for is not None else set()
158        # TODO - upstream
159        guards = (
160            "prim::TypeCheck",
161            "prim::RequiresGradCheck",
162            "prim::TensorExprDynamicGuard",
163        )
164        guard_found = False
165
166        def autodiff_guard(node):
167            if node.kind() != "aten::all":
168                return False
169            inps = list(node.inputs())
170            if len(inps) != 1 or inps[0].node().kind() != "prim::ListConstruct":
171                return False
172            li_inps = list(inps[0].node().inputs())
173            for li_inp in li_inps:
174                if li_inp.node().kind() in (
175                    "prim::AutogradAllNonZero",
176                    "prim::AutogradAllZero",
177                ):
178                    return True
179            return False
180
181        def is_guard(node):
182            return node.kind() in guards or autodiff_guard(node)
183
184        for node in graph.block().nodes():
185            if node.kind() == "prim::Constant":
186                continue
187            if is_guard(node):
188                self.assertFalse(guard_found)
189                guard_found = True
190                continue
191            if node.kind() in except_for:
192                continue
193            if node.kind() == "prim::If":
194                self.assertTrue(is_guard(node.prev()))
195                continue
196            self.assertTrue(False, "Found unexpected node:" + node.kind())
197
198        self.assertTrue(guard_found)
199
200    def assertLastGraphAllFused(self):
201        self.assertAllFused(torch.jit.last_executed_optimized_graph())
202
203    def findFusionGroups(self, graph):
204        result = []
205        for n in graph.nodes():
206            if n.kind() == FUSION_GROUP:
207                result.append(n.g("Subgraph"))
208                continue
209            for block in n.blocks():
210                result += self.findFusionGroups(block)
211        return result
212
213    def test_typecheck(self):
214        a = torch.ones(1)
215
216        def fused_kernel(a, b):
217            return (a + b) * 2.0
218
219        scripted = self.checkScript(fused_kernel, (a, a))
220        graph = scripted.graph_for(a, a)
221        # double check we fused
222        fusion_groups = self.findFusionGroups(graph)
223        self.assertEqual(len(fusion_groups), 1)
224        # we use a bigger tensor now (size 2)
225        # if we won't trigger a recompilation
226        # we will still create a tensor up to (size 1)
227        # if the type check fails
228        a = torch.ones(2)
229        # shape changed if we don't trigger recompilation
230        # we would compute the wrong result silently
231        self.assertEqual(scripted(a, a), fused_kernel(a, a))
232
233    def test_sum_simple(self):
234        def func(x):
235            x2 = x * x
236            return x2.sum()
237
238        with texpr_reductions_enabled():
239            a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
240            a = a.reshape(5, 3)
241            scripted = self.checkScript(func, (a,))
242            self.assertLastGraphAllFused()
243
244    def test_nop(self):
245        pass
246
247    def test_sum_dim(self):
248        def func(x):
249            return x.sum((0,)) * 2
250
251        def func_neg(x):
252            return x.sum((-2,)) * 2
253
254        with texpr_reductions_enabled():
255            a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
256            a = a.reshape(5, 3)
257            scripted = self.checkScript(func, (a,))
258            self.assertLastGraphAllFused()
259            scripted = self.checkScript(func_neg, (a,))
260            self.assertLastGraphAllFused()
261
262    def test_sum_keepdim_cast(self):
263        def func(x):
264            return x.sum((0,), keepdim=True, dtype=torch.double) * 2
265
266        with texpr_reductions_enabled():
267            a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
268            a = a.reshape(5, 3)
269
270            self.checkScript(func, (a,))
271            self.assertLastGraphAllFused()
272
273    def test_abs(self):
274        for device in self.devices:
275
276            def func(x):
277                return x.abs() * 2
278
279            a = torch.randn(5, device=device)
280            scripted = self.checkScript(func, (a,))
281            self.assertLastGraphAllFused()
282
283    def test_unsqueeze_size_calculation(self):
284        for device in self.devices:
285
286            def foo(b, d):
287                x = d.unsqueeze(1)
288                y = x * 42.0
289                z = b + y
290                r = z / 42.0
291                return r
292
293            inputs = (
294                torch.rand(20, 28, device=device, requires_grad=True),
295                torch.rand(20, device=device),
296            )
297            scripted = self.checkScript(foo, inputs)
298            self.assertAllFused(scripted.graph_for(*inputs))
299
300    def test_zero_element_tensors(self):
301        for device in self.devices:
302
303            def decode(sin_t, cos_t):
304                theta = torch.atan2(sin_t.float(), cos_t.float())
305                return theta
306
307            sin = torch.zeros(0, device=device)
308            cos = torch.zeros(0, device=device)
309            inputs = [sin, cos]
310            ge = self.checkScript(decode, inputs)
311
312    def test_arg_configurations_smoke(self):
313        if self.dynamic_shapes:
314            self.skipTest("TODO: chunk dynamic shapes")
315
316        # A smoke test to make sure we won't use the same kernel for contiguous
317        # and non-contiguous arguments.
318        # TODO: add optionally enabled debug counters to the fuser to verify
319        #       that we really can tell the difference between configurations
320        for device in self.devices:
321
322            def f(x, y):
323                z1, z2 = (x + y).chunk(2, dim=1)
324                return z1 * z2
325
326            x = torch.randn(4, 4, dtype=torch.float, device=device)
327            y = torch.randn(4, 4, dtype=torch.float, device=device)
328            traced_f = torch.jit.trace(f, (x, y))
329            self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))
330
331    def test_broadcast(self):
332        for device in self.devices:
333
334            def scaleshift(x, scale, shift):
335                return x * scale + shift
336
337            inputs = [
338                torch.randn(4, 4, dtype=torch.float, device=device),
339                torch.randn(4, dtype=torch.float, device=device),
340                torch.randn(4, dtype=torch.float, device=device),
341            ]
342            self.checkScript(scaleshift, inputs)
343
344    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
345    @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
346    @unittest.skipIf(
347        GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on"
348    )
349    def test_cuda_half(self):
350        x = torch.randn(4, 4, dtype=torch.half, device="cuda")
351        y = torch.randn(4, 4, dtype=torch.half, device="cuda")
352
353        funcs = [self.fn_test_comparison_gt_lt, self.fn_test_relu, self.fn_test_exp]
354
355        # Note: Non fused inputs must be float to prevent loss of precision
356        inputs = (x.float(), y.float())
357        fusion_inputs = (x, y)
358        for fn in funcs:
359            local_inputs = [t.clone().requires_grad_() for t in inputs]
360            local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]
361
362            # Verifies outputs
363            fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False)
364            outputs = fn(*local_inputs)
365            fusion_outputs = fusion(*local_fusion_inputs)
366            outputs_half = [t.half() for t in outputs]
367            self.assertEqual(outputs_half, fusion_outputs)
368
369            # Verifies gradients
370            for output, fusion_output in zip(outputs_half, fusion_outputs):
371                grads = torch.autograd.grad(
372                    output.float().sum(),
373                    local_inputs,
374                    allow_unused=True,
375                    retain_graph=True,
376                )
377                fusion_grads = torch.autograd.grad(
378                    fusion_output.sum(),
379                    local_fusion_inputs,
380                    allow_unused=True,
381                    retain_graph=True,
382                )
383                grads_half = [t.half() for t in grads]
384                self.assertEqual(grads_half, fusion_grads)
385
386    def test_checks_cat_inputs(self):
387        # single fusion node causes error
388        with set_fusion_group_inlining(True):
389            for device in self.devices:
390                # We shouldn't treat cat nodes as broadcasting. All their inputs
391                # need to be checked for having the same map size, before we can
392                # run the kernel.
393                def f(x, y):
394                    return torch.cat([x + 2 * x + x**2, y + 4 * y + y**3], dim=0)
395
396                # NOTE: y is broadcastable to x, but output of f(x, y) should have
397                # shape 3x4, and not 4x4.
398                x = torch.randn(2, 4, dtype=torch.float, device=device)
399                y = torch.randn(1, 4, dtype=torch.float, device=device)
400
401                scripted = self.checkScript(f, (x, y))
402                self.assertEqual(scripted(x, y).shape, (3, 4))
403                self.assertAllFused(scripted.graph_for(x, y))
404
405    def test_chunk(self):
406        if self.dynamic_shapes:
407            self.skipTest("TODO: chunk dynamic shapes")
408
409        for device in self.devices:
410
411            def fn(x):
412                a, b, c = x.chunk(3, 1)
413                return a * b + c
414
415            inputs = [torch.randn(10, 6, dtype=torch.float, device=device)]
416
417            self.checkScript(fn, inputs)
418            self.assertLastGraphAllFused()
419
420    def test_chunk_correctness(self):
421        if self.dynamic_shapes:
422            self.skipTest("TODO: chunk dynamic shapes")
423
424        for device in self.devices:
425
426            def chunk_4_0(x):
427                x0, x1, x2, x3 = x.chunk(4, 0)
428                return x0 + x1 + x2 + x3
429
430            def chunk_4_1(x):
431                x0, x1, x2, x3 = x.chunk(4, 1)
432                return x0 + x1 + x2 + x3
433
434            def chunk_4_last(x):
435                x0, x1, x2, x3 = x.chunk(4, 2)
436                return x0 + x1 + x2 + x3
437
438            fns = [chunk_4_0, chunk_4_1, chunk_4_last]
439            tensors = [
440                # splitSize = 1
441                torch.randn(4, 4, 4, dtype=torch.float, device=device),
442                # contiguous case
443                torch.randn(12, 8, 16, dtype=torch.float, device=device),
444                # non-contiguous case
445                torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(
446                    1, 2
447                ),
448            ]
449
450            for tensor in tensors:
451                for fn in fns:
452                    self.checkScript(fn, [tensor])
453                    self.assertLastGraphAllFused()
454
455    def test_chunk_distributes(self):
456        if self.dynamic_shapes:
457            self.skipTest("TODO: chunk dynamic shapes")
458
459        if self.dynamic_shapes:
460            self.skipTest("TODO: chunk dynamic shapes")
461
462        for device in self.devices:
463
464            def f(x, y):
465                z1, z2 = (x + y).chunk(2, dim=1)
466                return z1 * z2
467
468            x = torch.randn(4, 4, dtype=torch.float, device=device)
469            y = torch.randn(4, 4, dtype=torch.float, device=device)
470
471            ge = self.checkTrace(f, (x, y))
472            graph = ge.graph_for(x, y)
473            # XXX: The old fuser does broadcast_tensors but the new fuser doesn't.
474            # FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \
475            #     .check_count('ConstantChunk', 2, exactly=True).run(str(graph))
476            FileCheck().check("with " + FUSION_GROUP + "_").check_count(
477                "ConstantChunk", 1, exactly=True
478            ).run(str(graph))
479
480    def test_chunk_motion_deduplicates_inputs(self):
481        if self.dynamic_shapes:
482            self.skipTest("TODO: chunk dynamic shapes")
483
484        for device in self.devices:
485
486            def func1(x):
487                z = x * x
488                z0, z1 = z.chunk(2)
489                return z0 * z1
490
491            def func2(x):
492                z = x * x * x
493                z0, z1 = z.chunk(2)
494                return z0 * z1
495
496            inputs = [torch.tensor([1.1, 1.2], device=device, dtype=torch.float)]
497            for func in [func1, func2]:
498                self.checkScript(func, inputs)
499                self.assertLastGraphAllFused()
500
501    def test_chunk_multiple(self):
502        if self.dynamic_shapes:
503            self.skipTest("TODO: chunk dynamic shapes")
504
505        for device in self.devices:
506            # The arguments are intentionally used out of order as a test to see
507            # if the fusion compiler adds extra args in the correct order
508            def fn(s, x, y, z):
509                z1, z2 = z.chunk(2, 2)
510                x1, x2, x3 = x.chunk(3, 1)
511                y1, y2 = y.chunk(2, 0)
512                return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
513
514            inputs = [
515                torch.randn(5, 2, 3, dtype=torch.float, device=device),
516                torch.randn(5, 6, 3, dtype=torch.float, device=device),
517                torch.randn(10, 2, 3, dtype=torch.float, device=device),
518                torch.randn(5, 2, 6, dtype=torch.float, device=device),
519            ]
520
521            ge = self.checkScript(fn, inputs)
522            self.assertAllFused(ge.graph_for(*inputs))
523
524    def test_minmax(self):
525        for device in self.devices:
526
527            def tmax(a, b):
528                return torch.max(2 * a, b)
529
530            def tmin(a, b):
531                return torch.min(2 * a, b)
532
533            a = torch.randn(4, 4, dtype=torch.float)
534            b = torch.randn(4, 4, dtype=torch.float)
535            nan = torch.tensor(float("nan"), dtype=torch.float)
536
537            for f, inputs, device in product(
538                (tmax, tmin), ([a, b], [a, nan], [b, nan]), self.devices
539            ):
540                inputs = [t.to(device) for t in inputs]
541                s = self.checkScript(f, inputs)
542                self.assertAllFused(s.graph_for(*inputs))
543
544    def test_clamp(self):
545        for device in self.devices:
546
547            def func2(a, b):
548                return torch.clamp(a + b, min=0, max=2)
549
550            def funcInf(a, b):
551                return torch.clamp(a + b, min=0, max=float("inf"))
552
553            def funcNegInf(a, b):
554                return torch.clamp(a + b, min=float("-inf"), max=0)
555
556            def funcOptMin(a, b):
557                return torch.clamp(a + b, max=2)
558
559            def funcOptMax(a, b):
560                return torch.clamp(a + b, min=0)
561
562            a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True)
563            b = torch.randn(4, 4, dtype=torch.float, device=device)
564            nan = torch.tensor(float("nan"), dtype=torch.float, device=device)
565
566            funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
567            for f, inputs in product(funcs, [[a, b], [a, nan]]):
568                inp1, inp2 = inputs
569                s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
570                self.assertAllFused(
571                    s.graph_for(inp1, inp2),
572                    except_for={"aten::size", "aten::_size_if_not_equal"},
573                )
574                c = s(inp1, inp2)
575                with enable_profiling_mode_for_profiling_tests():
576                    warmup_backward(c.sum())
577                graph = backward_graph(s)
578                self.assertAllFused(
579                    graph,
580                    except_for={"aten::Float", "aten::_grad_sum_to_size"}.union(
581                        autograd_check_set
582                    ),
583                )
584
585    def test_clamp_double(self):
586        for device in self.devices:
587
588            def clamp_double(x, eta: float):
589                return 1 - x.clamp(eta, 1 - eta)
590
591            x = torch.tensor([1.0, 1.0], dtype=torch.double, device=device)
592            eta = 1e-9
593            s = self.checkScript(
594                clamp_double,
595                (x, eta),
596                profiling=ProfilingMode.PROFILING,
597                atol=1e-10,
598                rtol=1e-5,
599            )
600            self.assertAllFused(s.graph_for(x, eta), except_for={"aten::sub"})
601
602    def test_clamp_int(self):
603        for device in self.devices:
604
605            def clamp_int(x, eta: int):
606                return x.clamp(0, eta)
607
608            x = torch.tensor([1, 1], device=device)
609            eta = 1 << 32
610            s = self.checkScript(clamp_int, (x, eta), profiling=ProfilingMode.PROFILING)
611            self.assertAllFused(s.graph_for(x, eta))
612
613    def test_add_bool(self):
614        sizes = [(1,), (2,), (4, 4)]
615        for device, size in product(self.devices, sizes):
616
617            def f(x, y, z):
618                return x + y + z
619
620            x = torch.randint(0, 2, size, dtype=torch.bool, device=device)
621            y = torch.randint(0, 2, size, dtype=torch.bool, device=device)
622            z = torch.randint(0, 2, size, dtype=torch.bool, device=device)
623            ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
624            self.assertAllFused(ge.graph_for(x, y, z))
625
626    def test_mul_bool(self):
627        for device in self.devices:
628
629            def f(x, y, z):
630                return x * y * z
631
632            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
633            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
634            z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
635
636            ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
637            self.assertAllFused(ge.graph_for(x, y, z))
638
639    def test_div_bool(self):
640        for device in self.devices:
641
642            def f(x, y, z):
643                return (x + y) / z
644
645            x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
646            y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
647            z = torch.ones_like(x, dtype=torch.bool, device=device)
648
649            ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
650            self.assertAllFused(ge.graph_for(x, y, z))
651
652    def test_bitwise_ops(self):
653        def apply(fn):
654            return lambda x, y, z: fn(fn(x, y), z)
655
656        binary_ops = [
657            operator.__and__,
658            operator.__or__,
659            operator.__xor__,
660            operator.__lshift__,
661            operator.__rshift__,
662        ]
663        devices = self.devices
664        for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
665            try:
666                x = self.data_for(dtype, device)
667                y = self.data_for(dtype, device)
668                z = self.data_for(dtype, device)
669                fn = apply(op)
670                ref = fn(x, y, z)
671            except Exception:
672                # If eager mode doesn't support a dtype/op/device combo,
673                # neither does the fuser.  Catch everything to avoid needing to
674                # guess what errors might be thrown by eager.
675                continue
676            try:
677                t = torch.jit.trace(fn, (x, y, z))
678                self.assertEqual(ref, t(x, y, z))
679                self.assertAllFused(t.graph_for(x, y, z))
680            except Exception as e:
681                raise RuntimeError(
682                    " ".join(["Failed:", str(dtype), op.__name__, device])
683                ) from e
684
685    def test_minmax_int_ops(self):
686        def apply(fn):
687            return lambda x, y, z: fn(fn(x, y), z)
688
689        binary_ops = [torch.min, torch.max]
690        devices = self.devices
691        for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
692            try:
693                x = self.data_for(dtype, device)
694                y = self.data_for(dtype, device)
695                z = self.data_for(dtype, device)
696                fn = apply(op)
697                ref = fn(x, y, z)
698            except Exception:
699                # If eager mode doesn't support a dtype/op/device combo,
700                # neither does the fuser.  Catch everything to avoid needing to
701                # guess what errors might be thrown by eager.
702                continue
703            try:
704                t = torch.jit.trace(fn, (x, y, z))
705                self.assertEqual(ref, t(x, y, z))
706                self.assertAllFused(t.graph_for(x, y, z))
707            except Exception as e:
708                raise RuntimeError(
709                    " ".join(["Failed:", str(dtype), op.__name__, device])
710                ) from e
711
712    def test_comparison_eq_ne(self):
713        for device in self.devices:
714
715            def f(x, y):
716                mask = (x == 0).type_as(x)
717                z = x * mask + y
718                mask = (x != 0).type_as(x)
719                z = z * mask + y
720                return z
721
722            x = torch.randn(4, 4, dtype=torch.float, device=device)
723            y = torch.randn(4, 4, dtype=torch.float, device=device)
724
725            ge = self.checkTrace(f, (x, y))
726            self.assertAllFused(ge.graph_for(x, y))
727
728    @staticmethod
729    def fn_test_comparison_gt_lt(x, y):
730        mask = (x > 0).type_as(x)
731        z = x * mask + y
732        mask = (x < 0).type_as(x)
733        z = z * mask + y
734        return z
735
736    def test_comparison_gt_lt(self):
737        for device in self.devices:
738            x = torch.randn(4, 4, dtype=torch.float, device=device)
739            y = torch.randn(4, 4, dtype=torch.float, device=device)
740
741            ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
742            self.assertAllFused(ge.graph_for(x, y))
743
744    def test_comparison_ge_le(self):
745        for device in self.devices:
746
747            def f(x, y):
748                mask = (x >= 0).type_as(x)
749                z = x * mask + y
750                mask = (x <= 0).type_as(x)
751                z = z * mask + y
752                return z
753
754            x = torch.randn(4, 4, dtype=torch.float, device=device)
755            y = torch.randn(4, 4, dtype=torch.float, device=device)
756
757            ge = self.checkTrace(f, (x, y))
758            self.assertAllFused(ge.graph_for(x, y))
759            x.requires_grad_(True)
760            y.requires_grad_(True)
761            self.assertAllFused(
762                ge.graph_for(x, y),
763                except_for=(
764                    "aten::size",
765                    "prim::BroadcastSizes",
766                    "aten::_size_if_not_equal",
767                ),
768            )
769
770    def test_addcmul(self):
771        for device in self.devices:
772            t = torch.randn(1, 4, dtype=torch.float, device=device)
773            t1 = torch.randn(4, 1, dtype=torch.float, device=device)
774            t2 = torch.randn(1, 4, dtype=torch.float, device=device)
775
776            def foo(t, t1, t2):
777                return t.addcmul(t + 1, t2, value=0.1)
778
779            ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
780            graph = ge.graph_for(t, t1, t2)
781            fusion_groups = self.findFusionGroups(graph)
782            self.assertEqual(len(fusion_groups), 1)
783            FileCheck().check("aten::add(").check("aten::addcmul(").run(
784                str(fusion_groups[0])
785            )
786
787    # TODO: We leak CUDA memory here because the traced graph holds onto a
788    # constant-ified tensor. Since the Python-global CompilationUnit is alive
789    # until the end of the process, the memory is effectively leaked.
790    # Removed `_cuda` suffix from this test which disables leak-checking.
791    # If this is a real problem, we'll need to revisit Torchscript Function
792    # lifetimes in Python.
793    def test_lerp(self):
794        for device in self.devices:
795            start = torch.randn(4, 1, dtype=torch.float, device=device)
796            end = torch.randn(1, 4, dtype=torch.float, device=device)
797            weight = torch.tensor(0.5, dtype=torch.float, device=device)
798
799            # scalar weight overload
800            def foo_weight_scalar(start, end):
801                return torch.lerp(start + 1, end, 0.5)
802
803            # tensor weight overload
804            def foo_weight_tensor(start, end):
805                return torch.lerp(start + 1, end, weight)
806
807            ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end))
808            graph = ge_weight_scalar.graph_for(start, end)
809            self.assertAllFused(graph)
810
811            # TODO: uncomment when TE enables support for scalar tensors
812            # ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end))
813            # graph = ge_weight_tensor.graph_for(start, end)
814            # self.assertAllFused(graph)
815
816    def test_concat(self):
817        # disabling concat causes error with single concat node
818        with set_fusion_group_inlining(True):
819            for device in self.devices:
820                hx = torch.randn(3, 20, dtype=torch.float, device=device)
821                cx = torch.randn(3, 20, dtype=torch.float, device=device)
822
823                def foo(hx, cx):
824                    return torch.cat((hx + cx, hx * cx))
825
826                ge = self.checkTrace(foo, (hx, cx))
827                graph = ge.graph_for(hx, cx)
828                self.assertAllFused(graph)
829                # XXX: TE fuser can handle concats in a fusion group.
830                # FileCheck().check("FusedConcat").check_next("return").run(str(graph))
831
832    def test_remove_output_used_only_in_size(self):
833        for device in self.devices:
834
835            def test_fuse(a, b):
836                c = a + b
837                d = c + b
838                return d
839
840            scripted_f = torch.jit.script(test_fuse)
841            x = torch.ones(1, requires_grad=True, device=device)
842            y = torch.ones(1, requires_grad=True, device=device)
843            warmup_forward(scripted_f, x, y, profiling_count=3)
844            g = scripted_f.graph_for(x, y)
845            diff_nodes = g.findAllNodes("prim::DifferentiableGraph")
846            self.assertEqual(len(diff_nodes), 1)
847            g = diff_nodes[0].g("Subgraph")
848            if_nodes = [n for n in g.nodes() if n.kind() == "prim::If"]
849            self.assertEqual(len(if_nodes), 1)
850
851            # the if node and the fusion group inside it should only have one output
852            self.assertEqual(len(list(if_nodes[0].outputs())), 1)
853
854    def test_concat_invariant(self):
855        for device in self.devices:
856            # Invariant: the output of prim::FusedConcat may
857            # not be an input to any node inside the FusionGroup.
858            def fn(x, y, z):
859                x1 = x + y
860                y1 = x - y
861                w = torch.cat([x1, y1])
862                return w + z
863
864            x = torch.randn(2, 2, dtype=torch.float, device=device)
865            y = torch.randn(2, 2, dtype=torch.float, device=device)
866            z = torch.randn(4, 2, dtype=torch.float, device=device)
867            ge = self.checkTrace(fn, (x, y, z))
868            graph = ge.graph_for(x, y, z)
869            self.assertAllFused(graph, except_for={"aten::add"})
870            # XXX: TE fuser can handle concats inside a fusion group.
871            # FileCheck().check("FusedConcat").check_next("return").run(str(graph))
872
873    @staticmethod
874    def fn_test_exp(x, y):
875        return (x + 0.5 * y).exp()
876
877    def test_exp(self):
878        for device in self.devices:
879            x = torch.randn(4, 4, dtype=torch.float, device=device)
880            y = torch.randn(4, 4, dtype=torch.float, device=device)
881
882            ge = self.checkTrace(self.fn_test_exp, (x, y))
883            self.assertAllFused(ge.graph_for(x, y))
884
885    def test_threshold(self):
886        for device in self.devices:
887
888            def f(x):
889                return torch.threshold(x, 0, -10) + x + x + x
890
891            x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device=device)
892            scripted = self.checkScript(f, (x,))
893            self.assertAllFused(scripted.graph_for(x))
894
895    def test_scalar_arg(self):
896        for device in self.devices:
897
898            def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
899                return p * (x * x + x)
900
901            x = torch.randn(4, 4, dtype=torch.float, device=device)
902            p = 3
903            scripted = self.checkScript(fn_test_scalar_arg, (x, p))
904            self.assertAllFused(scripted.graph_for(x, p))
905
906            x.requires_grad_(True)
907
908            # use another function otherwise we will bailout
909            # and won't be able to do fused checks
910            def fn_test_scalar_arg_requires_grad(
911                x: torch.Tensor, p: float
912            ) -> torch.Tensor:
913                return p * (x * x + x)
914
915            scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)
916            out = scripted(x, p)
917            out = scripted(x, p)
918            out = scripted(x, p)
919            self.assertAllFused(
920                scripted.graph_for(x, p),
921                except_for=(
922                    "aten::size",
923                    "prim::BroadcastSizes",
924                    "aten::_size_if_not_equal",
925                ),
926            )
927
928    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
929    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
930    def test_fusion_reuse_multi_gpu(self):
931        def fn(x, y):
932            return x * y * x * y
933
934        inputs_cpu = [
935            torch.randn(4, 4, dtype=torch.float),
936            torch.randn(4, 4, dtype=torch.float),
937        ]
938        inputs_cuda0 = [x.cuda(0) for x in inputs_cpu]
939        inputs_cuda1 = [y.cuda(1) for y in inputs_cpu]
940
941        # Should not crash; these should compile different kernels.
942        ge = self.checkScript(fn, inputs_cpu)
943        self.assertAllFused(ge.graph_for(*inputs_cpu))
944        ge(*inputs_cuda0)
945        ge(*inputs_cuda1)
946
947    # TODO: we're currently not checking 'device' in the type info when pulling
948    # nodes into a fusion group. We should fix that and re-enable this test.
949    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
950    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
951    def test_kernel_cache_multi_gpu(self):
952        def not_fusible(x):
953            return x
954
955        def fn(x, y, z):
956            x_out = x * x * x * x * x  # fusion: lambda x. x * x * x * x * x
957            y_out = y * y * y * y * y
958            z_out = z * z * z * z * z
959            return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)
960
961        inputs = [
962            torch.randn(4, 4, dtype=torch.float),
963            torch.randn(4, 4, dtype=torch.float, device="cuda:0"),
964            torch.randn(4, 4, dtype=torch.float, device="cuda:1"),
965        ]
966
967        prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
968
969        # There are 3 FusionGroups. Because they have the same graph, they
970        # should reuse the same KernelSpec in the KernelSpec cache.
971        ge = self.checkScript(fn, inputs)
972        self.assertGraphContainsExactly(ge.graph_for(*inputs), FUSION_GROUP, 3, True)
973        new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
974        # XXX: This assumes that the same kernel isn't already used by another test
975        # FIXME: Use the TE fuser's way of querying the cache.
976        # self.assertEqual(new_cache_size - prev_cache_size, 1)
977
978    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
979    def test_nonzero_device_cuda(self):
980        device = "cuda:" + str(1)
981        x = torch.tensor([0.4], dtype=torch.float, device=device)
982        y = torch.tensor([0.7], dtype=torch.float, device=device)
983
984        def doit(x, y):
985            return torch.sigmoid(torch.tanh(x * (x + y) + x))
986
987        ge = self.checkTrace(doit, (x, y))
988        self.assertAllFused(ge.graph_for(x, y))
989
990    def test_lstm(self):
991        for device in self.devices:
992            inputs = get_lstm_inputs(device, training=True)
993            module = self.checkScript(LSTMCellS, inputs)
994            self.assertAllFused(
995                module.graph_for(inputs), except_for={"prim::TupleConstruct"}
996            )
997
998    def test_lstm_concat(self):
999        # single fusion node causes error
1000        with set_fusion_group_inlining(True):
1001            for device in self.devices:
1002                inputs = get_lstm_inputs(device)
1003                ge = self.checkTrace(LSTMCellC, inputs)
1004                graph = ge.graph_for(*inputs)
1005                except_nodes = {"prim::TupleConstruct", "aten::linear"}
1006                # TODO... Chunk
1007                if self.dynamic_shapes:
1008                    except_nodes = except_nodes.union(
1009                        {"aten::add", "prim::ConstantChunk"}
1010                    )
1011                self.assertAllFused(ge.graph_for(*inputs), except_for=except_nodes)
1012                # XXX: TE fuser can handle concats inside a fusion group.
1013                # FileCheck().check("FusedConcat").check_next("return").run(str(graph))
1014
1015    def test_lstm_gates_permutations(self):
1016        for device in self.devices:
1017            # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
1018            # Test that any permutation of this will still result in one FusionGroup.
1019            choices = ["x.mm(w_ih.t())", "hx.mm(w_hh.t())", "b_ih", "b_hh"]
1020            template = dedent(
1021                """
1022            def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
1023                gates = {} + {} + {} + {}
1024                ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
1025                return ingate * forgetgate * cellgate * outgate
1026            """
1027            )
1028            for permutation in permutations(choices, len(choices)):
1029                code = template.format(*permutation)
1030                scope = {}
1031                exec(code, globals(), scope)
1032                cu = torch.jit.CompilationUnit(code)
1033                fusion_group_len = 2 if self.dynamic_shapes else 1
1034                inputs = get_lstm_inputs(device, training=False)
1035                self.assertEqual(cu.cell(*inputs), scope["cell"](*inputs))
1036                forward_graph = cu.cell.graph_for(*inputs)
1037                self.assertGraphContainsExactly(
1038                    forward_graph, FUSION_GROUP, fusion_group_len
1039                )
1040
1041    # TODO: Fuser doesn't work at all when inputs require grad. Fix that
1042    def test_lstm_traced(self):
1043        for device in self.devices:
1044            inputs = get_lstm_inputs(device)
1045            ge = self.checkTrace(LSTMCellF, inputs)
1046            graph = ge.graph_for(*inputs)
1047            fusion_groups = self.findFusionGroups(graph)
1048            # TODO: chunk
1049            fusion_group_len = 2 if self.dynamic_shapes else 1
1050            self.assertEqual(len(fusion_groups), fusion_group_len)
1051            f = FileCheck()
1052            if not self.dynamic_shapes:
1053                f.check("Chunk")
1054            f.check("aten::sigmoid").check("aten::tanh").run(
1055                str(fusion_groups[0 if not self.dynamic_shapes else 1])
1056            )
1057
1058    def test_milstm(self):
1059        if self.dynamic_shapes:
1060            self.skipTest("don't run conv with dynamic shapes")
1061
1062        for device in self.devices:
1063            inputs = get_milstm_inputs(device, training=True)
1064            module = self.checkScript(MiLSTMCell, inputs)
1065            forward_graph = module.graph_for(*inputs)
1066            # TODO: chunk
1067            fusion_group_len = 2 if self.dynamic_shapes else 1
1068            self.assertGraphContainsExactly(
1069                forward_graph, FUSION_GROUP, fusion_group_len, consider_subgraphs=True
1070            )
1071            FileCheck().check("DifferentiableGraph").check("TupleConstruct").check_next(
1072                "return"
1073            ).check(FUSION_GROUP).run(str(forward_graph))
1074            hy, cy = module(*inputs)
1075            warmup_backward((hy + cy).sum())
1076
1077    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
1078    @unittest.skip("rand_like is not supported yet")
1079    def test_rand_cuda(self):
1080        class M(torch.jit.ScriptModule):
1081            __constants__ = ["d"]
1082
1083            def __init__(self) -> None:
1084                super().__init__()
1085                self.d = torch.device("cuda")
1086
1087            @torch.jit.script_method
1088            def create(self, x):
1089                return x * x + x + torch.rand_like(x)
1090
1091        x = torch.zeros([3, 4, 5], dtype=torch.float, device="cuda")
1092        m = M()
1093        out1 = m.create(x)
1094        out2 = m.create(x)
1095        self.assertNotEqual(out1, out2)
1096        self.assertTrue(torch.all(out1 >= 0))
1097        self.assertTrue(torch.all(out1 < 1))
1098        self.assertTrue(torch.all(out2 >= 0))
1099        self.assertTrue(torch.all(out2 < 1))
1100        self.assertAllFused(m.create.graph_for(x))
1101
1102    @staticmethod
1103    def fn_test_relu(x, y):
1104        return F.relu(x + 0.5 * y)
1105
1106    def test_relu(self):
1107        for device in self.devices:
1108            x = torch.randn(4, 4, dtype=torch.float, device=device)
1109            y = torch.randn(4, 4, dtype=torch.float, device=device)
1110
1111            ge = self.checkTrace(self.fn_test_relu, (x, y))
1112            self.assertAllFused(ge.graph_for(x, y))
1113
1114    def test_erf(self):
1115        for device in self.devices:
1116            # only enabled on gpu
1117            if device == "cpu":
1118                continue
1119
1120            def fn_test_erf(x):
1121                return F.relu(torch.erf(x) - torch.erfc(x))
1122
1123            x = torch.randn(4, 4, dtype=torch.float, device=device)
1124            ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING)
1125            self.assertAllFused(ge.graph_for(x))
1126            x.requires_grad_(True)
1127            ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING)
1128            self.assertAllFused(
1129                ge.graph_for(x),
1130                except_for=(
1131                    "aten::size",
1132                    "prim::BroadcastSizes",
1133                    "aten::_size_if_not_equal",
1134                ),
1135            )
1136
1137    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
1138    @unittest.skip("rand_like is not supported yet")
1139    def test_rand_broadcast_cuda(self):
1140        def fn_test_rand(x, y):
1141            r = torch.rand_like(y)
1142            return r * x + x
1143
1144        # If using profiling, a different function is needed to test different
1145        # shapes, or we'll use a cached script.
1146        def fn_test_rand2(x, y):
1147            r = torch.rand_like(y)
1148            return r * x * x
1149
1150        x = torch.randn(4, 4, dtype=torch.float, device="cuda")
1151        y = torch.randn(4, 4, dtype=torch.float, device="cuda")
1152        script_f = torch.jit.script(fn_test_rand)
1153        warmup_forward(script_f, x, y)
1154        out = script_f(x, y)
1155        self.assertAllFused(script_f.graph_for(x, y))
1156        x.requires_grad_(True)
1157        out = script_f(x, y)
1158        self.assertAllFused(
1159            script_f.graph_for(x, y),
1160            except_for=(
1161                "aten::size",
1162                "prim::BroadcastSizes",
1163                "aten::_size_if_not_equal",
1164            ),
1165        )
1166
1167        # test that broadcasting random produces correct results
1168        x = torch.ones(4, 4, dtype=torch.float, device="cuda")
1169        y = torch.ones(4, dtype=torch.float, device="cuda")
1170        script_f = torch.jit.script(fn_test_rand2)
1171        warmup_forward(script_f, x, y)
1172        out = script_f(x, y)
1173        self.assertEqual(out[0, :] + torch.zeros(4, 4, device="cuda"), out)
1174
1175    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
1176    @unittest.skip("rand_like is not supported yet")
1177    def test_rand_diamond(self):
1178        def fn_test_diamond(x, y):
1179            r = torch.rand_like(y)
1180            a = x + r
1181            b = y - r
1182            return a + b
1183
1184        x = torch.randn(4, 4, dtype=torch.float, device="cuda")
1185        y = torch.randn(4, 4, dtype=torch.float, device="cuda")
1186        script_f = torch.jit.script(fn_test_diamond)
1187        warmup_forward(script_f, x, y)
1188        out = script_f(x, y)
1189        self.assertEqual(out, x + y)
1190
1191    def test_scalar(self):
1192        def fn(x, y):
1193            return 2 * x + y
1194
1195        x = torch.tensor(0.1, dtype=torch.float, device="cpu")
1196        y = torch.tensor(1, dtype=torch.float, device="cpu")
1197        ge = self.checkScript(fn, (x, y))
1198        self.assertAllFused(ge.graph_for(x, y))
1199
1200    def test_inlined_optimized_graph(self):
1201        @torch.jit.script
1202        def foo(x):
1203            return torch.relu(x + x)
1204
1205        for _ in range(3):
1206            foo(torch.rand([4, 4]))
1207
1208        for _ in range(3):
1209            foo(torch.rand([10]))
1210
1211        for _ in range(3):
1212            foo(torch.rand([2, 2, 2]))
1213
1214        g = torch.jit.last_executed_optimized_graph()
1215
1216        FileCheck().check_count("prim::If", 1, exactly=True).check(
1217            "prim::TensorExpr"
1218        ).run(g)
1219        torch._C._jit_pass_inline(g)
1220        f = FileCheck()
1221        for _ in range(3):
1222            f.check("prim::If").check("prim::TensorExpr")
1223        f.run(g)
1224
1225    def test_small_constant(self):
1226        for device in self.devices:
1227
1228            def fn_test_small_constant(x, y):
1229                return (1e-8 * x + 5e-9 * y) * 1e8
1230
1231            x = torch.randn(4, 4, dtype=torch.float, device=device)
1232            y = torch.randn(4, 4, dtype=torch.float, device=device)
1233
1234            ge = self.checkTrace(fn_test_small_constant, (x, y))
1235            self.assertAllFused(ge.graph_for(x, y))
1236
1237    # Currently we don't pull constants into fusion groups, because in some
1238    # cases it could remove the constant from the original graph and now our
1239    # fusion group needs to return that constant for its other users.
1240    # Instead of never pulling constants into the fusion group, we should just
1241    # be more careful at how we rewrite its users.
1242    # TODO: fix that and reenable the test.
1243    def test_tensor_scalar_ops(self):
1244        for device in self.devices:
1245
1246            def should_fuse(x):
1247                z = 3.0
1248                y = x + z
1249                return x * y
1250
1251            def should_fuse_scalar(x, z):
1252                y = x + int(z)
1253                return x * y
1254
1255            inputs = [torch.randn(2, 2, dtype=torch.float, device=device)]
1256            ge = self.checkScript(should_fuse, inputs)
1257            graph = ge.graph_for(*inputs)
1258            fusion_groups = self.findFusionGroups(graph)
1259            self.assertEqual(len(fusion_groups), 1)
1260            FileCheck().check("aten::add").check("aten::mul").run(str(fusion_groups[0]))
1261
1262            inputs = [
1263                torch.randn(2, 2, dtype=torch.float, device=device),
1264                torch.tensor(3.0, dtype=torch.float, device=device),
1265            ]
1266            ge = self.checkScript(should_fuse_scalar, inputs)
1267            # Check that the fused graph computes correct results when the scalar
1268            # input changes.
1269            inputs = [
1270                torch.randn(2, 2, dtype=torch.float, device=device),
1271                torch.tensor(7.0, dtype=torch.float, device=device),
1272            ]
1273            self.assertEqual(ge(*inputs), should_fuse_scalar(*inputs))
1274            # The TE fuser supports fusion of non-constant scalars
1275            self.assertGraphContainsExactly(
1276                ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True
1277            )
1278
1279    def test_where_and_typing(self):
1280        for device in self.devices:
1281
1282            def f(x, y):
1283                mask = x > y
1284                res = torch.where(mask, x, y)
1285                return mask, res
1286
1287            x = torch.randn(4, 4, dtype=torch.double, device=device)
1288            y = torch.randn(4, 4, dtype=torch.double, device=device)
1289
1290            script_f = self.checkScript(f, (x, y))
1291            self.assertAllFused(
1292                script_f.graph_for(x, y), except_for={"prim::TupleConstruct"}
1293            )
1294
1295    def test_disabled(self):
1296        old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
1297        torch._C._jit_override_can_fuse_on_cpu(False)
1298
1299        def fn(a):
1300            return a**2 + a
1301
1302        x = torch.randn(4, dtype=torch.float, device="cpu")
1303        s = self.checkScript(fn, (x,))
1304        g = s.graph_for(x)
1305        self.assertEqual(len(self.findFusionGroups(g)), 0)
1306
1307        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state)
1308
1309    def data_for(self, dtype, device="cuda", size=None):
1310        if size is None:
1311            v = torch.arange(1, 3, dtype=torch.float, device=device)
1312        else:
1313            v = torch.rand(*size, device=device)
1314        if dtype == torch.bool:
1315            return v > 2
1316        elif dtype in [torch.qint8, torch.quint8, torch.qint32]:
1317            return torch.quantize_per_tensor(v, 0.1, 1, dtype=dtype)
1318        else:
1319            return v.to(dtype)
1320
1321    def test_torch_to(self):
1322        # test no op
1323        @torch.jit.script
1324        def foo(x):
1325            return x.to(torch.float)
1326
1327        foo(torch.tensor([3.0], dtype=torch.float))
1328        foo(torch.tensor([3.0], dtype=torch.float))
1329        FileCheck().check_not("TensorExpr").run(
1330            torch.jit.last_executed_optimized_graph()
1331        )
1332
1333        # test not fusing non-const inputs
1334        @torch.jit.script
1335        def foo(x, dtype: int):
1336            return x.to(dtype)
1337
1338        foo(torch.tensor([3.0], dtype=torch.float), torch.int)
1339        foo(torch.tensor([3.0], dtype=torch.float), torch.int)
1340        FileCheck().check_not("TensorExpr").run(
1341            torch.jit.last_executed_optimized_graph()
1342        )
1343
1344        # test not fusing to_pinned inputs
1345        @torch.jit.script
1346        def foo(x, dtype: int):
1347            return x.to(pin_memory=True)
1348
1349        foo(torch.tensor([3.0], dtype=torch.float), torch.int)
1350        foo(torch.tensor([3.0], dtype=torch.float), torch.int)
1351        FileCheck().check_not("TensorExpr").run(
1352            torch.jit.last_executed_optimized_graph()
1353        )
1354
1355        # test across-device not supported
1356        if torch.cuda.is_available():
1357
1358            @torch.jit.script
1359            def foo(x):
1360                return x.to(device="cuda")
1361
1362            foo(torch.tensor([3.0], dtype=torch.float))
1363            foo(torch.tensor([3.0], dtype=torch.float))
1364            FileCheck().check_not("TensorExpr").run(
1365                torch.jit.last_executed_optimized_graph()
1366            )
1367
1368        sizes = [(1, 4), (4, 4)]
1369        # reuses cast impl, smaller dtype set for faster test
1370        dtypes = [
1371            torch.bool,
1372            torch.int,
1373            torch.float16,
1374            torch.float32,
1375            torch.float64,
1376        ]
1377
1378        class MyMod(torch.nn.Module):
1379            def __init__(self, dtype):
1380                super().__init__()
1381                self.dtype = dtype
1382
1383            def forward(self, x):
1384                return x.to(self.dtype)
1385
1386        bad_dtypes = []
1387        for dtype, output_dtype, device, size in product(
1388            dtypes, dtypes, self.devices, sizes
1389        ):
1390            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
1391            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1392                continue
1393            if dtype == output_dtype:
1394                continue
1395
1396            x = self.data_for(dtype, device, size=size)
1397            mod = MyMod(output_dtype)
1398            ref = mod.forward(x)
1399            # use freezing to make non-Tensor args to `to` constant
1400            mod = torch.jit.freeze(torch.jit.script(mod.eval()))
1401            warmup_forward(mod.forward, x)
1402            self.assertEqual(ref, mod.forward(x))
1403            self.assertLastGraphAllFused()
1404
1405    @unittest.skip("Temporarily disabled")
1406    def test_masked_fill(self):
1407        dtypes = [
1408            torch.int8,
1409            torch.int16,
1410            torch.int32,
1411            torch.int64,
1412            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
1413            # torch.float16,
1414            torch.float32,
1415            torch.float64,
1416            torch.bool,
1417        ]
1418        sizes = [(2,), (4, 4)]
1419        for self_dtype, device, scalar_val, size in product(
1420            dtypes, self.devices, [0.4, 3], sizes
1421        ):
1422            input_v = self.data_for(self_dtype, device, size=size)
1423            mask = self.data_for(torch.bool, device, size=size)
1424
1425            def fn(input_v, mask):
1426                return torch.masked_fill(input_v, mask, scalar_val)
1427
1428            ref = fn(input_v, mask)
1429            try:
1430                t = torch.jit.trace(fn, (input_v, mask))
1431                torch.testing.assert_close(ref, t(input_v, mask))
1432                self.assertLastGraphAllFused()
1433            except Exception as e:
1434                raise RuntimeError(
1435                    " ".join(
1436                        [
1437                            "Failed:",
1438                            str(self_dtype),
1439                            op.__name__,  # noqa: F821
1440                            device,
1441                            str(size),
1442                        ]
1443                    )
1444                ) from e
1445
1446    def test_isnan(self):
1447        x = torch.rand([4])
1448        x[0] = float("nan")
1449        inputs = [x, torch.tensor([float("nan"), 0.5])]
1450        dtypes = [
1451            torch.int8,
1452            torch.int16,
1453            torch.int32,
1454            torch.int64,
1455            torch.float16,
1456            torch.float32,
1457            torch.float64,
1458            torch.bool,
1459        ]
1460
1461        for inp, device, dtype in product(inputs, self.devices, dtypes):
1462            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
1463            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1464                continue
1465            inp = inp.to(device=device, dtype=dtype)
1466            try:
1467                f = torch.jit.trace(lambda x: x.isnan(), (inp,))
1468                warmup_forward(f, inp)
1469                self.assertEqual(f(inp), inp.isnan())
1470                self.assertLastGraphAllFused()
1471            except Exception as e:
1472                raise RuntimeError(
1473                    " ".join(["Failed:", str(dtype), "isnan", device])
1474                ) from e
1475
1476    def test_gelu(self):
1477        def apply(fn):
1478            return lambda x, approximate: fn(x, approximate)
1479
1480        unary_ops = [
1481            F.gelu,
1482        ]
1483        sizes = [(1,), (2,), (4, 4)]
1484        for dtype, op, device, size in product(
1485            self.dtypes, unary_ops, self.devices, sizes
1486        ):
1487            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
1488            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1489                continue
1490            try:
1491                x = self.data_for(dtype, device, size=size)
1492                cond = self.data_for(torch.bool, device)
1493                fn = apply(op)
1494                ref = fn(x, cond)
1495            except Exception:
1496                # If eager mode doesn't support a dtype/op/device combo,
1497                # neither does the fuser.  Catch everything to avoid needing to
1498                # guess what errors might be thrown by eager.
1499                continue
1500            try:
1501                t = torch.jit.trace(fn, (x, cond))
1502                torch.testing.assert_close(ref, t(x, cond))
1503                self.assertAllFused(t.graph_for(x, cond))
1504            except Exception as e:
1505                raise RuntimeError(
1506                    " ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
1507                ) from e
1508
1509    def test_unary_ops(self):
1510        with torch._jit_internal._disable_emit_hooks():
1511
1512            def apply(fn):
1513                return lambda x: fn(x)
1514
1515            unary_ops = [
1516                torch.lgamma,
1517                torch.sigmoid,
1518                torch.reciprocal,
1519                torch.neg,
1520                torch.relu,
1521                F.relu6,
1522                torch.log,
1523                torch.log10,
1524                torch.log1p,
1525                torch.log2,
1526                torch.exp,
1527                torch.expm1,
1528                torch.erf,
1529                torch.erfc,
1530                torch.cos,
1531                torch.sin,
1532                torch.tan,
1533                torch.acos,
1534                torch.asin,
1535                torch.cosh,
1536                torch.sinh,
1537                torch.atan,
1538                torch.tanh,
1539                F.hardtanh,
1540                F.hardsigmoid,
1541                F.hardswish,
1542                F.softplus,
1543                F.silu,
1544                F.mish,
1545                F.elu,
1546                torch.sqrt,
1547                torch.rsqrt,
1548                torch.abs,
1549                # TODO broken on int8 since
1550                # https://github.com/pytorch/pytorch/pull/85144
1551                # RuntimeError: Invalid integral op_type: 23
1552                # torch.ceil,
1553                # torch.floor,
1554                # torch.round,
1555                # torch.trunc,
1556                torch.frac,
1557                # TODO: broken on ROCm?
1558                # F.hardshrink,
1559                F.leaky_relu,
1560                lambda x: torch.threshold(x, 0, -10),
1561                # TODO: broken since type promotion was added
1562                # lambda x: torch.clamp(x, -10, 10),
1563            ]
1564            gpu_only = {torch.erf, torch.erfc}
1565            sizes = [(1,), (2,), (4, 4)]
1566            for dtype, op, device, size in product(
1567                self.dtypes, unary_ops, self.devices, sizes
1568            ):
1569                # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
1570                if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1571                    continue
1572                # todo - re-enable. fails with .500
1573                if dtype == torch.bfloat16 and op == torch.round:
1574                    continue
1575                if op in gpu_only and device == "cpu":
1576                    continue
1577                try:
1578                    x = self.data_for(dtype, device, size=size)
1579                    fn = apply(op)
1580                    ref = fn(x)
1581                except Exception:
1582                    # If eager mode doesn't support a dtype/op/device combo,
1583                    # neither does the fuser.  Catch everything to avoid needing to
1584                    # guess what errors might be thrown by eager.
1585                    continue
1586                try:
1587                    t = torch.jit.trace(fn, (x,))
1588                    torch.testing.assert_close(ref, t(x))
1589                    self.assertAllFused(t.graph_for(x))
1590                except Exception as e:
1591                    raise RuntimeError(
1592                        " ".join(
1593                            ["Failed:", str(dtype), op.__name__, device, str(size)]
1594                        )
1595                    ) from e
1596
1597    def test_binary_ops(self):
1598        def apply(fn):
1599            return lambda x, y: fn(x, y)
1600
1601        binary_ops = [
1602            operator.__and__,
1603            operator.__or__,
1604            operator.__xor__,
1605            torch.add,
1606            torch.sub,
1607            torch.mul,
1608            torch.min,
1609            torch.max,
1610            lambda x, y: torch.lerp(x, y, 0.5),
1611            torch.atan2,
1612            torch.div,
1613            torch.eq,
1614            torch.ne,
1615            torch.ge,
1616            torch.gt,
1617            torch.lt,
1618            torch.fmod,
1619            torch.remainder,
1620            lambda x, y: y.type_as(x),
1621        ]
1622        fp_only = [
1623            torch.fmod,
1624            torch.remainder,
1625        ]
1626        devices = self.devices
1627        for dtype, op, device in product(self.dtypes, binary_ops, devices):
1628            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1629                continue
1630            try:
1631                x = self.data_for(dtype, device)
1632                y = self.data_for(dtype, device)
1633                fn = apply(op)
1634                ref = fn(x, y)
1635            except Exception:
1636                # If eager mode doesn't support a dtype/op/device combo,
1637                # neither does the fuser.  Catch everything to avoid needing to
1638                # guess what errors might be thrown by eager.
1639                continue
1640            try:
1641                t = torch.jit.trace(fn, (x, y))
1642                self.assertEqual(ref, t(x, y))
1643                if op not in fp_only or dtype.is_floating_point:
1644                    self.assertAllFused(t.graph_for(x, y))
1645            except Exception as e:
1646                raise RuntimeError(
1647                    " ".join(["Failed:", str(dtype), op.__name__, device])
1648                ) from e
1649
1650    def test_binary_scalar_ops(self):
1651        def apply(fn):
1652            return lambda x, y: fn(x, y)
1653
1654        ir_template = """
1655        graph(%x : {dtype_x}, %y : {dtype_y}):
1656          %z = {op}(%x, %y)
1657          return (%z)"""
1658
1659        binary_ops = [
1660            "aten::mul",
1661            "aten::add",
1662            "aten::sub",
1663            "aten::div",
1664            "aten::lt",
1665            "aten::le",
1666            "aten::eq",
1667            "aten::ne",
1668            "aten::gt",
1669            "aten::ge",
1670            "aten::__or__",
1671            "aten::__xor__",
1672            "aten::__and__",
1673            "aten::__lshift__",
1674            "aten::__rshift__",
1675        ]
1676        dtypes = ["int", "float", "bool"]
1677        values = {"int": [10, 3], "float": [12.34, 2.78], "bool": [True, False]}
1678        devices = self.devices
1679        for dtype_x, dtype_y, op, device in product(
1680            dtypes, dtypes, binary_ops, devices
1681        ):
1682            code = ir_template.format(**locals())
1683
1684            # Interpret the graph
1685            try:
1686                graph = torch._C.parse_ir(code)
1687                for x, y in product(values[dtype_x], values[dtype_y]):
1688                    ref = torch._C._jit_interpret_graph(graph, (x, y))
1689            except Exception:
1690                # If we can't interpret this IR, don't bother checking NNC.
1691                continue
1692
1693            # Compile the graph
1694            try:
1695                k = torch._C._te.TensorExprKernel(graph)
1696            except Exception as e:
1697                raise RuntimeError(
1698                    " ".join(["Compilation failed:", device, str(code)])
1699                ) from e
1700
1701            # Run the graph
1702            for x, y in product(values[dtype_x], values[dtype_y]):
1703                ref = torch._C._jit_interpret_graph(graph, (x, y))
1704                try:
1705                    res = k.run((x, y))
1706                    self.assertEqual(ref, res)
1707                except Exception as e:
1708                    raise RuntimeError(
1709                        " ".join(
1710                            ["Failed at runtime:", device, str(x), str(y), str(code)]
1711                        )
1712                    ) from e
1713
1714    def test_matmul(self):
1715        if self.dynamic_shapes:
1716            self.skipTest("don't run conv with dynamic shapes")
1717
1718        def fn(x, y):
1719            return torch.matmul(x, y)
1720
1721        devices = ["cpu"]  # No cuda support for ext calls yet
1722        sizes = [
1723            [[128, 128], [128, 128]],
1724            [[10, 10], [10, 10]],
1725            [[1, 16], [16, 128]],
1726            [[128], [128]],
1727            [[128], [128, 128]],
1728            [[3], [3]],
1729            [[3, 4], [4]],
1730            [[10, 3, 4], [4]],
1731            [[10, 3, 4], [10, 4, 5]],
1732            [[10, 3, 4], [4, 5]],
1733        ]
1734
1735        # Only 2D x 2D matrix multiply is supported. For non-supported sizes we
1736        # still want to run results verification to test that we didn't
1737        # accidentally fuse it, but we skip the 'is-fused' check.
1738        # TODO: add support for other shape combinations and make this set empty:
1739        skip_is_fused_check_sizes = [
1740            "[[128], [128]]",
1741            "[[128], [128, 128]]",
1742            "[[3], [3]]",
1743            "[[3, 4], [4]]",
1744            "[[10, 3, 4], [4]]",
1745            "[[10, 3, 4], [10, 4, 5]]",
1746            "[[10, 3, 4], [4, 5]]",
1747        ]
1748        for dtype, size, device in product(self.dtypes, sizes, devices):
1749            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1750                continue
1751            try:
1752                size_x, size_y = size
1753                x = self.data_for(dtype, device, size=size_x)
1754                y = self.data_for(dtype, device, size=size_y)
1755                ref = fn(x, y)
1756            except Exception as e:
1757                # If eager mode doesn't support a dtype/op/device combo,
1758                # neither does the fuser.  Catch everything to avoid needing to
1759                # guess what errors might be thrown by eager.
1760                continue
1761            try:
1762                t = torch.jit.trace(fn, (x, y))
1763                t(x, y)
1764                self.assertEqual(ref, t(x, y))
1765                if str(size) not in skip_is_fused_check_sizes:
1766                    self.assertAllFused(t.graph_for(x, y))
1767            except Exception as e:
1768                raise RuntimeError(" ".join(["Failed:", str(dtype), device])) from e
1769
1770    def test_binary_tensor_scalar_ops(self):
1771        with torch._jit_internal._disable_emit_hooks():
1772
1773            def apply_with_scalar(fn, scalar):
1774                return lambda x: fn(x, scalar)
1775
1776            # FIXME: Fails in IR Eval: torch.int64 and_ cpu
1777            binary_ops = [
1778                operator.__and__,
1779                operator.__or__,
1780                operator.__xor__,
1781                torch.add,
1782                torch.sub,
1783                torch.mul,
1784                torch.eq,
1785                torch.ne,
1786                torch.ge,
1787                torch.lt,
1788                torch.gt,
1789            ]
1790            devices = self.devices
1791            # Maybe we should split this into separate tests to speed it up by
1792            # only using  scalar values relevant to particular ops
1793            scalars = [1.5, 3, 0, -2.0, -1]
1794            for dtype, op, device, scalar in product(
1795                self.dtypes, binary_ops, devices, scalars
1796            ):
1797                if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1798                    continue
1799                try:
1800                    x = self.data_for(dtype, device)
1801                    fn = apply_with_scalar(op, scalar)
1802                    ref = fn(x)
1803                except Exception:
1804                    # If eager mode doesn't support a dtype/op/device combo,
1805                    # neither does the fuser.  Catch everything to avoid needing to
1806                    # guess what errors might be thrown by eager.
1807                    continue
1808                try:
1809                    t = torch.jit.trace(fn, (x))
1810                    self.assertEqual(ref, t(x))
1811                    self.assertAllFused(t.graph_for(x))
1812                except Exception as e:
1813                    raise RuntimeError(
1814                        " ".join(["Failed:", str(dtype), op.__name__, device])
1815                    ) from e
1816
1817    def test_binary_div_ops(self):
1818        def apply_with_scalar(fn, scalar):
1819            return lambda x: fn(x, scalar)
1820
1821        binary_ops = [
1822            torch.div,
1823            torch.remainder,
1824            torch.fmod,
1825        ]
1826        devices = self.devices
1827        # Maybe we should split this into separate tests to speed it up by
1828        # only using  scalar values relevant to particular ops
1829        scalars = [1.5, 3, -2.0, -1]  # skip 0
1830        for dtype, op, device, scalar in product(
1831            self.dtypes, binary_ops, devices, scalars
1832        ):
1833            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1834                continue
1835            try:
1836                x = self.data_for(dtype, device)
1837                fn = apply_with_scalar(op, scalar)
1838                ref = fn(x)
1839            except Exception:
1840                # If eager mode doesn't support a dtype/op/device combo,
1841                # neither does the fuser.  Catch everything to avoid needing to
1842                # guess what errors might be thrown by eager.
1843                continue
1844            try:
1845                t = torch.jit.trace(fn, (x))
1846                self.assertEqual(ref, t(x))
1847            except Exception as e:
1848                raise RuntimeError(
1849                    f"Failed: {dtype} {op.__name__} {device} {scalar}"
1850                ) from e
1851
1852    def test_binary_pow(self):
1853        def apply_with_scalar(fn, scalar):
1854            return lambda x: fn(x, scalar)
1855
1856        dtypes = [
1857            # FIXME: 'pow' fails with dtype=torch.float16/device=cuda/scalar=0
1858            # torch.float16,
1859            torch.float32,
1860            torch.float64,
1861            # torch.bool intentionally not included
1862        ]
1863        binary_ops = [
1864            torch.pow,
1865        ]
1866        # Maybe we should split this into separate tests to speed it up by
1867        # only using  scalar values relevant to particular ops
1868        scalars = [1.5, 3, 0, -2.0, -1]
1869        for dtype, op, device, scalar in product(
1870            dtypes, binary_ops, self.devices, scalars
1871        ):
1872            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1873                continue
1874            try:
1875                x = self.data_for(dtype, device)
1876                fn = apply_with_scalar(op, scalar)
1877                ref = fn(x)
1878            except Exception:
1879                # If eager mode doesn't support a dtype/op/device combo,
1880                # neither does the fuser.  Catch everything to avoid needing to
1881                # guess what errors might be thrown by eager.
1882                continue
1883            try:
1884                t = torch.jit.trace(fn, (x))
1885                self.assertEqual(ref, t(x))
1886                self.assertAllFused(t.graph_for(x))
1887            except Exception as e:
1888                raise RuntimeError(
1889                    " ".join(["Failed:", str(dtype), op.__name__, device])
1890                ) from e
1891
1892    def test_ternary_ops(self):
1893        def apply(fn):
1894            return lambda x, y, z: fn(x, y, z)
1895
1896        ternary_ops = [
1897            torch.lerp,
1898            torch.addcmul,
1899        ]
1900        devices = self.devices
1901        for dtype, op, device in product(self.dtypes, ternary_ops, devices):
1902            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1903                continue
1904            try:
1905                x = self.data_for(dtype, device)
1906                y = self.data_for(dtype, device)
1907                z = self.data_for(dtype, device)
1908                fn = apply(op)
1909                ref = fn(x, y, z)
1910            except Exception:
1911                # If eager mode doesn't support a dtype/op/device combo,
1912                # neither does the fuser.  Catch everything to avoid needing to
1913                # guess what errors might be thrown by eager.
1914                continue
1915            try:
1916                t = torch.jit.trace(fn, (x, y, z))
1917                self.assertEqual(ref, t(x, y, z))
1918                self.assertAllFused(t.graph_for(x, y, z))
1919            except Exception as e:
1920                raise RuntimeError(
1921                    " ".join(["Failed:", str(dtype), op.__name__, device])
1922                ) from e
1923
1924    def test_ternary_norm_ops(self):
1925        def apply(fn):
1926            return lambda x, y, z: fn(x, y, z)
1927
1928        ternary_ops = [
1929            F.batch_norm,
1930        ]
1931        devices = self.devices
1932        for dtype, op, device in product(self.dtypes, ternary_ops, devices):
1933            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1934                continue
1935            try:
1936                x = self.data_for(dtype, device, size=[5, 3, 128, 128])
1937                y = self.data_for(dtype, device, size=[3])
1938                z = self.data_for(dtype, device, size=[3])
1939                fn = apply(op)
1940                ref = fn(x, y, z)
1941            except Exception:
1942                # If eager mode doesn't support a dtype/op/device combo,
1943                # neither does the fuser.  Catch everything to avoid needing to
1944                # guess what errors might be thrown by eager.
1945                continue
1946            try:
1947                t = torch.jit.trace(fn, (x, y, z))
1948                self.assertEqual(ref, t(x, y, z))
1949                self.assertAllFused(t.graph_for(x, y, z))
1950            except Exception as e:
1951                raise RuntimeError(
1952                    " ".join(["Failed:", str(dtype), op.__name__, device])
1953                ) from e
1954
1955    @unittest.skip(
1956        "FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure"
1957    )
1958    def test_list_ops(self):
1959        def apply(fn):
1960            return lambda x, y, z: fn([x * x, y * y, z * z])
1961
1962        devices = self.devices
1963        list_ops = [
1964            torch.cat,
1965        ]
1966        for dtype, op, device in product(self.dtypes, list_ops, devices):
1967            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
1968                continue
1969            try:
1970                x = self.data_for(dtype, device, size=[5, 4, 1, 7])
1971                y = self.data_for(dtype, device, size=[5, 4, 1, 7])
1972                z = self.data_for(dtype, device, size=[5, 4, 1, 7])
1973                fn = apply(op)
1974                ref = fn(x, y, z)
1975            except Exception:
1976                # If eager mode doesn't support a dtype/op/device combo,
1977                # neither does the fuser.  Catch everything to avoid needing to
1978                # guess what errors might be thrown by eager.
1979                continue
1980            try:
1981                t = torch.jit.trace(fn, (x, y, z))
1982                self.assertEqual(ref, t(x, y, z))
1983                self.assertAllFused(t.graph_for(x, y, z))
1984            except Exception as e:
1985                raise RuntimeError(
1986                    " ".join(["Failed:", str(dtype), op.__name__, device])
1987                ) from e
1988
1989    def test_where_ops(self):
1990        def apply(fn):
1991            return lambda cond, x, y: fn(cond, x, y)
1992
1993        ops = [
1994            torch.where,
1995            lambda cond, x, y: torch.where(cond, x, 3.1415),
1996            lambda cond, x, y: torch.where(cond, 42, y),
1997        ]
1998        devices = self.devices
1999        for dtype, op, device in product(self.dtypes, ops, devices):
2000            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
2001                continue
2002            try:
2003                cond = self.data_for(torch.bool, device)
2004                x = self.data_for(dtype, device)
2005                y = self.data_for(dtype, device)
2006                fn = apply(op)
2007                ref = fn(cond, x, y)
2008            except Exception:
2009                # If eager mode doesn't support a dtype/op/device combo,
2010                # neither does the fuser.  Catch everything to avoid needing to
2011                # guess what errors might be thrown by eager.
2012                continue
2013            try:
2014                t = torch.jit.trace(fn, (cond, x, y))
2015                self.assertEqual(ref, t(cond, x, y))
2016                self.assertAllFused(t.graph_for(cond, x, y))
2017            except Exception as e:
2018                raise RuntimeError(
2019                    " ".join(["Failed:", str(dtype), op.__name__, device])
2020                ) from e
2021
2022    def test_unsupported_dtypes(self):
2023        for device in self.devices:
2024
2025            def fn(x):
2026                return x * x + x
2027
2028            unsupported_dtypes = [
2029                torch.uint8,
2030                torch.complex32,
2031                torch.complex64,
2032                torch.complex128,
2033                torch.qint8,
2034                torch.quint8,
2035                torch.qint32,
2036            ]
2037            for dtype in unsupported_dtypes:
2038                try:
2039                    x = self.data_for(dtype, device)
2040                    ref = fn(x)
2041                except Exception:
2042                    # If eager mode doesn't support a dtype/op/device combo,
2043                    # neither does the fuser.  Catch everything to avoid needing to
2044                    # guess what errors might be thrown by eager.
2045                    continue
2046                t = torch.jit.trace(fn, (x,))
2047                self.assertEqual(ref, t(x))
2048                self.assertEqual(len(self.findFusionGroups(t.graph_for(x))), 0)
2049
2050    def test_superslomo(self):
2051        devices = self.devices.copy()
2052        if not LLVM_ENABLED:
2053            devices.remove("cpu")
2054        for device in devices:
2055            # Test extracted from Super-SloMo: https://github.com/avinashpaliwal/Super-SloMo
2056            # A few interesting things happen here: strided inputs of mixed size,
2057            # plus outputs of mixed shapes.  The latter characteristic happened to
2058            # expose a memory corruption bug due to not properly guarding the
2059            # outputs.
2060            def eager(t0, t1, t2, t3, t4):
2061                t5 = torch.mul(t0, t4)
2062                t6 = torch.mul(t2, t3)
2063                t7 = torch.mul(t6, t1)
2064                t9 = torch.add(t5, t7)
2065                t11 = torch.add(t0, t6)
2066                ft_p = torch.div(t9, t11)
2067                return (ft_p, t11, t9, t6)
2068
2069            t0 = torch.rand(1, 6, 352, 352, device=device).transpose(0, 1)
2070            t1 = torch.rand(6, 3, 352, 352, device=device)
2071            t2 = torch.rand(6, device=device)[None, None, None, :].permute(3, 0, 1, 2)
2072            t3 = torch.rand(6, 1, 352, 352, device=device)
2073            t4 = torch.rand(6, 3, 352, 352, device=device)
2074            inputs = [t0, t1, t2, t3, t4]
2075
2076            script = torch.jit.script(eager)
2077            for _ in range(4):
2078                for pair in zip(script(*inputs), eager(*inputs)):
2079                    test, ref = pair
2080                    torch.testing.assert_close(test, ref)
2081                    self.assertAllFused(
2082                        script.graph_for(*inputs), except_for={"prim::TupleConstruct"}
2083                    )
2084
2085    def test_sub_gt_and(self):
2086        for device in self.devices:
2087
2088            def eager(t1, t2, t3, t4, t: float):
2089                w = t1 - t2
2090                h = t3 - t4
2091                k = (w > t) & (h > t)
2092                assert k.dtype == torch.bool
2093                if t > 0.5:
2094                    # Putting a use of k in a never-executed conditional prevents
2095                    # profiling its type, which leaves it as "Tensor".  If we
2096                    # propagate Tensor back to the definition of k, we have to be
2097                    # careful not to create a fusion group containing it.
2098                    return k + 1
2099                return w
2100
2101            t = torch.rand(8, dtype=torch.float, device=device)
2102            scripted = self.checkScript(eager, (t, t, t, t, 0.1))
2103
2104    @skipIfTorchDynamo("too slow")
2105    def test_chunk_mul_one(self):
2106        if self.dynamic_shapes:
2107            self.skipTest("TODO: chunk dynamic shapes")
2108
2109        for device in self.devices:
2110
2111            def eager(x):
2112                z, y, w = torch.chunk(x, 3, -1)
2113                return z * 3, y, w
2114
2115            x = torch.rand(64, 1, 3072, dtype=torch.float, device=device)
2116            z, y, w = eager(x)
2117            script = self.checkScript(eager, (x,))
2118
2119    def test_eq_unsqueeze_type_as(self):
2120        for device in self.devices:
2121
2122            def eager(a, b):
2123                mask = b == 1
2124                mask = torch.unsqueeze(mask, -1)
2125                x = mask.type_as(a)
2126                return x, mask
2127
2128            a = torch.rand(1, 64, 1024, device=device, dtype=torch.float)
2129            b = torch.randint(-2, 2, (1, 64), device=device, dtype=torch.long)
2130            script = self.checkScript(eager, (a, b))
2131
2132    def test_neg_pow(self):
2133        def eager_tt(a: torch.Tensor, b: torch.Tensor):
2134            return torch.neg(torch.pow(a, b))
2135
2136        def eager_ts(a: torch.Tensor, b: float):
2137            return torch.neg(torch.pow(a, b))
2138
2139        def eager_st(a: float, b: torch.Tensor):
2140            return torch.neg(torch.pow(a, b))
2141
2142        a = torch.rand(1, dtype=torch.float)
2143        b = torch.rand(1, dtype=torch.float)
2144        s = b.item()
2145        script = self.checkScript(eager_tt, (a, b))
2146        # TODO: re-enable fusion, which doesn't work right now. just test correctness for now
2147        # self.assertAllFused(script.graph_for(a, b))
2148        script = self.checkScript(eager_ts, (a, s))
2149        # self.assertAllFused(script.graph_for(a, s))
2150        script = self.checkScript(eager_st, (s, b))
2151        # self.assertAllFused(script.graph_for(s, b))
2152
2153    @unittest.skipIf(not LLVM_ENABLED, "Too slow to run with the TE interpreter")
2154    def test_conv2d_depthwise(self):
2155        if self.dynamic_shapes:
2156            self.skipTest("don't run conv with dynamic shapes")
2157
2158        def eager(input, weight, bias):
2159            return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=72)
2160
2161        input = torch.rand((1, 72, 56, 56), dtype=torch.float)
2162        weight = torch.rand((72, 1, 3, 3), dtype=torch.float)
2163        bias = torch.rand((72), dtype=torch.float)
2164
2165        script = self.checkScript(eager, (input, weight, bias))
2166        self.assertAllFused(script.graph_for(input, weight, bias))
2167
2168    def test_conv2d(self):
2169        if self.dynamic_shapes:
2170            self.skipTest("don't run conv with dynamic shapes")
2171
2172        def eager(input, weight, bias):
2173            return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=1)
2174
2175        input = torch.rand((1, 64, 56, 56), dtype=torch.float)
2176        weight = torch.rand((64, 64, 3, 3), dtype=torch.float)
2177        bias = torch.rand((64), dtype=torch.float)
2178
2179        script = self.checkScript(eager, (input, weight, bias))
2180        FileCheck().check_not("TensorExpr").run(
2181            torch.jit.last_executed_optimized_graph()
2182        )
2183
2184    def test_type_as_cat(self):
2185        with inline_fusion_groups():
2186
2187            def eager(x, y):
2188                return torch.cat((x, y.type_as(x)), dim=1)
2189
2190            dtypes = self.dtypes.copy()
2191            # CPU fuser doesn't support float16.
2192            dtypes.remove(torch.float16)
2193            dtypes.remove(torch.bfloat16)
2194            for dtype1, dtype2 in product(dtypes, dtypes):
2195                x = torch.randint(2, (1, 13)).to(dtype1)
2196                zero = torch.tensor([[0]]).to(dtype2)
2197                one = torch.tensor([[1]]).to(dtype2)
2198                script = torch.jit.trace(eager, (x, zero))
2199                for _ in range(3):
2200                    torch.testing.assert_close(script(x, zero), eager(x, zero))
2201                    torch.testing.assert_close(script(x, one), eager(x, one))
2202                self.assertAllFused(script.graph_for(x, one))
2203
2204    def test_to_device(self):
2205        def eager(x):
2206            return x.to(device="cpu").relu()
2207
2208        x = torch.rand(8)
2209        script = self.checkScript(eager, (x,))
2210        self.assertAllFused(script.graph_for(x))
2211
2212    def test_dims(self):
2213        def eager(x, y):
2214            return x / (y + 0.0001)
2215
2216        x = torch.linspace(-1, 1, 768, dtype=torch.float32).as_strided(
2217            (1, 1, 768), (768, 1, 1)
2218        )
2219        y = torch.tensor([[[2.0]]], dtype=torch.float32)
2220        script = self.checkScript(eager, (x, y))
2221        self.assertAllFused(script.graph_for(x, y))
2222
2223    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
2224    def test_channels_last_dims_dynamic(self):
2225        def eager(x, y):
2226            return x + (y + 0.0001)
2227
2228        indices = [0, 1, 2, 3]
2229        sets = []
2230        for i in range(0, len(indices) + 1):
2231            for subset in combinations(indices, i):
2232                sets.append(subset)  # noqa: PERF402
2233
2234        for set in sets:
2235            size = [2, 3, 4, 5]
2236            for index in set:
2237                size[index] = 1
2238            inp = torch.rand(size).to(memory_format=torch.channels_last).cuda()
2239            with texpr_enable_strategy([("DYNAMIC", 20)]):
2240                foo_s = torch.jit.trace(eager, (inp, inp))
2241                for _ in range(3):
2242                    out = foo_s(inp, inp)
2243                out_eager = eager(inp, inp)
2244                self.assertEqual(out_eager, out)
2245                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
2246                g = torch.jit.last_executed_optimized_graph()
2247                FileCheck().check("TensorExpr").run(g)
2248
2249    def test_exhaust_specializations(self):
2250        with texpr_enable_strategy([("STATIC", 1)]):
2251
2252            @torch.jit.script
2253            def foo(x):
2254                return x + x + x
2255
2256            for _ in range(3):
2257                foo(torch.rand([2, 2]))
2258
2259            for _ in range(3):
2260                foo(torch.rand([4, 4, 4]))
2261
2262            g = torch.jit.last_executed_optimized_graph()
2263            torch._C._jit_pass_inline(g)
2264
2265            FileCheck().check_count("TensorExpr", 2, exactly=True).run(g)
2266
2267    def test_unsqueeze_var_dim(self):
2268        def eager(x, y, z: int):
2269            return x * torch.unsqueeze(y, dim=z)
2270
2271        x = torch.rand(4, 4, 64).permute(1, 0, 2)
2272        y = torch.rand(4, 4)
2273        z = 2
2274        script = self.checkScript(eager, (x, y, z))
2275
2276    def _test_fwd_bwd(self, fn):
2277        x = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True)
2278        xs = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True)
2279        script = torch.jit.script(fn)
2280        for i in range(11):
2281            y = fn(x)
2282            g0 = torch.rand_like(y)
2283            y.backward(g0)
2284
2285            ys = script(xs)
2286            ys.backward(g0)
2287
2288            with torch.no_grad():
2289                x -= 0.1 * x.grad
2290                xs -= 0.1 * xs.grad
2291                x.grad = None
2292                xs.grad = None
2293        torch.testing.assert_close(y, ys)
2294
2295    def test_relu_fwd_bwd(self):
2296        def eager(x):
2297            return torch.relu(x * 1.01)
2298
2299        self._test_fwd_bwd(eager)
2300
2301    def test_hardswish_fwd_bwd(self):
2302        def eager(x):
2303            return F.hardswish(x) * 1.01
2304
2305        self._test_fwd_bwd(eager)
2306
2307    def test_hardsigmoid_fwd_bwd(self):
2308        def eager(x):
2309            return F.hardsigmoid(x) * 1.01
2310
2311        self._test_fwd_bwd(eager)
2312
2313    def test_cat_graph_opt(self):
2314        def foo(x, y, z):
2315            return torch.log(torch.cat([x, y, z]))
2316
2317        self.checkScript(
2318            foo, (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 5]))
2319        )
2320        # TODO: not sure why not updated graph isn't reflected in last_optimized_graph
2321        self.assertLastGraphAllFused()
2322
2323    def test_dynamic_cat(self):
2324        with inline_fusion_groups():
2325
2326            @torch.jit.script
2327            def repro(
2328                xs: List[torch.Tensor], ys: List[torch.Tensor], zs: List[torch.Tensor]
2329            ):
2330                return [
2331                    torch.cat([x, torch.cat([y, z], dim=-1)], dim=-1)
2332                    for x, y, z in zip(xs, ys, zs)
2333                ]
2334
2335            for _ in range(3):
2336                N = 3
2337                xs = [torch.ones(21) for _ in range(N)]
2338                # Note: concat of ys and zs will have the same size for each
2339                # pair, even though the individual ys and zs do not.
2340                ys = [torch.ones(N - i) for i in range(N)]
2341                zs = [torch.ones(i) for i in range(N)]
2342                repro(xs, ys, zs)
2343
2344    def test_scalar_only_inputs(self):
2345        def eager(b: float):
2346            a = torch.ones(1)
2347            return a * b
2348
2349        script = self.checkScript(eager, (1.0,))
2350
2351    def test_cat_2k_args(self):
2352        with inline_fusion_groups():
2353
2354            def eager(x):
2355                return torch.relu(torch.cat([x for _ in range(2000)]))
2356
2357            x = torch.randn(1)
2358            trace = self.checkTrace(eager, (x,))
2359            fusion_groups = self.findFusionGroups(trace.graph_for(x))
2360            self.assertEqual(len(fusion_groups), 0)
2361
2362    def test_adaptive_avg_pool2d(self):
2363        # TODO: once the adaptive_avg_pool2d is available in OpInfo DB, this
2364        # test should be moved there
2365        with inline_fusion_groups():
2366
2367            def foo1(x):
2368                return torch.nn.functional.adaptive_avg_pool2d(x, (2, 2))
2369
2370            def foo2(x):
2371                return torch.nn.functional.adaptive_avg_pool2d(x, (2))
2372
2373            x = torch.randn(4, 4, 4)
2374            for foo in [foo1, foo2]:
2375                f = torch.jit.trace(foo, (x,))
2376                kernel = torch._C._te.TensorExprKernel(f.graph)
2377                correct_val = f(x)
2378                self.assertEqual(kernel.run((x,)), correct_val)
2379
2380    def test_unrolled_cat(self):
2381        with inline_fusion_groups():
2382
2383            def eager(x):
2384                ret = torch.empty(0)
2385                for i in range(x.shape[0]):
2386                    ret = torch.cat([ret, x[i].relu()])
2387                return ret
2388
2389            script = torch.jit.script(eager)
2390
2391            # Warm up with size=1 tensor; since the loop iterates once the
2392            # profile data will be "burned in" assuming size=1, and then
2393            # unrolled.
2394            x = torch.ones(1, 1)
2395            for _ in range(3):
2396                script(x)
2397
2398            torch.testing.assert_close(eager(x), script(x))
2399
2400            # Now when an input hits the unrolled path, it will produce an
2401            # incorrectly-sized tensor, since size=1 has been burned in.
2402            x = torch.ones((8, 1))
2403            torch.testing.assert_close(eager(x), script(x))
2404
2405    @skipIfTorchDynamo("too slow")
2406    @unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan")
2407    @unittest.skipIf(TEST_WITH_ROCM, "Tensor-likes are not close for nans")
2408    def test_batch_norm(self):
2409        def test(fn, args):
2410            trace = torch.jit.trace(fn, args)
2411            self.assertAllFused(trace.graph_for(*args))
2412            # TODO: Are `NaN`'s actually ok here or did this pass silently before, because `equal_nan=True` was the
2413            #  default?
2414            torch.testing.assert_close(fn(*args), trace(*args), equal_nan=True)
2415
2416        def bn(i, x):
2417            return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu()
2418
2419        def bn_no_weight(i, x):
2420            return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu()
2421
2422        def bn_no_bias(i, x):
2423            return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu()
2424
2425        def bn_neither(i, x):
2426            return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu()
2427
2428        for device in self.devices:
2429            i = torch.randn(4, 16, 32, 40, device=device)
2430            x = torch.randn(16, device=device)
2431            for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]:
2432                test(fn, (i, x))
2433
2434    def test_profiler(self):
2435        @torch.jit.script
2436        def test(x, y, z):
2437            return x * y + z
2438
2439        args = [torch.randn(4) for _ in range(3)]
2440        with torch.autograd.profiler.profile() as prof:
2441            for _ in range(3):
2442                test(*args)
2443        self.assertIn("fused_mul_add", prof.table())
2444
2445    def test_skip_grad_in_check(self):
2446        @torch.jit.script
2447        def foo(x):
2448            return (x + 2) / 2
2449
2450        inp = torch.rand([4, 4])
2451        for _ in range(3):
2452            foo(inp)
2453
2454        inp.requires_grad_(True)
2455        with torch.inference_mode():
2456            for _ in range(3):
2457                foo(inp)
2458        g = torch.jit.last_executed_optimized_graph()
2459        torch._C._jit_pass_inline(g)
2460        torch._C._jit_pass_inline(g)
2461        FileCheck().check_count("prim::If", 1, exactly=True).run(g)
2462
2463    def test_dynamic_shapes(self):
2464        from functools import partial
2465
2466        n = 10
2467
2468        gen_tensor = (
2469            lambda n: R(1, n),
2470            lambda n: R(n, n),
2471            lambda n: R(n, n).transpose(0, 1),
2472            lambda n: R(n + 1, n + 1, 2)[:n, n, 0],
2473            lambda n: R(n, n, 2)[:, :, 0],
2474            lambda n: R(n, n + 1, n + 2, n + 3).to(memory_format=torch.channels_last),
2475        )
2476
2477        with texpr_enable_strategy([("DYNAMIC", 20)]):
2478
2479            def foo(x, y, z):
2480                return torch.sigmoid(torch.tanh(x))
2481
2482            foo.__disable_jit_function_caching__ = True
2483
2484            def fi(x, y, z):
2485                return torch.tanh(x + y)
2486
2487            fi.__disable_jit_function_caching__ = True
2488
2489            def fum(x, y, z):
2490                return torch.tanh(x + y) + z
2491
2492            fum.__disable_jit_function_caching__ = True
2493
2494            funcs = [foo, fi, fum]
2495            with inline_fusion_groups():
2496                for device in self.devices:
2497                    I = partial(torch.randint, 0, 100, device=device)
2498                    R = partial(torch.randn, device=device)
2499
2500                    for i, func in enumerate(funcs):
2501                        num_args = i + 1
2502                        for j, gen in enumerate(gen_tensor):
2503                            inps = (gen(n), gen(n), gen(n))
2504                            func_s = torch.jit.trace(func, inps, check_trace=False)
2505                            torch._C._jit_pass_erase_shape_information(func_s.graph)
2506                            for _ in range(2):
2507                                x, y, z = gen(n), gen(n), gen(n)
2508                                func_s(x, y, z)
2509
2510                            for incr in range(3):
2511                                func_s(*[gen(n + 1) for _ in range(3)])
2512
2513                            g = torch.jit.last_executed_optimized_graph()
2514                            torch._C._jit_pass_inline(g)
2515                            torch._C._jit_pass_dce(g)
2516
2517                            # We should see only one optimized kernel
2518                            FileCheck().check_count(
2519                                "TensorExprDynamicGuard", 1, exactly=True
2520                            ).run(g)
2521                            self.assertEqual(func(*inps), func_s(*inps))
2522
2523                    gen = gen_tensor[0]
2524                    inps = (gen(n), gen(n), gen(n))
2525                    foo_s = torch.jit.trace(foo, inps)
2526                    torch._C._jit_pass_erase_shape_information(foo_s.graph)
2527                    g_prev = None
2528                    for gen in gen_tensor:
2529                        for i in range(3):
2530                            foo_s(*[gen(n + i) for _ in range(3)])
2531                            inps = (gen(n), gen(n), gen(n))
2532                            self.assertEqual(foo_s(*inps), foo(*inps))
2533                    g = torch.jit.last_executed_optimized_graph()
2534                    torch._C._jit_pass_inline(g)
2535                    torch._C._jit_pass_dce(g)
2536                    FileCheck().check_count(
2537                        "TensorExprDynamicGuard", len(gen_tensor), exactly=True
2538                    ).run(g)
2539
2540    @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA")
2541    def test_autocast_up(self):
2542        def f(x):
2543            y = x._autocast_to_full_precision(True, True)
2544            z = torch.exp(y)
2545            return z
2546
2547        x = torch.rand((2, 2), dtype=torch.half, device="cuda")
2548        scr = torch.jit.script(f)
2549        scr(x)
2550        scr(x)
2551        self.assertLastGraphAllFused()
2552
2553    @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA")
2554    def test_autocast_down(self):
2555        def f(x):
2556            y = torch.sigmoid(x)
2557            z = y._autocast_to_reduced_precision(True, True, torch.half, torch.half)
2558            return z
2559
2560        x = torch.rand((2, 2), dtype=torch.float, device="cuda")
2561        scr = torch.jit.script(f)
2562        scr(x)
2563        scr(x)
2564        self.assertLastGraphAllFused()
2565
2566    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
2567    def test_to_dtype(self):
2568        def f(x):
2569            y = torch.sigmoid(x)
2570            z = y._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16)
2571            h = z._autocast_to_full_precision(True, True)
2572            i = h.to(dtype=torch.bfloat16)
2573            j = i.to(dtype=torch.float32)
2574            return j
2575
2576        x = torch.rand((2, 2), dtype=torch.float32)
2577        scr = torch.jit.trace(f, x)
2578        scr(x)
2579        scr(x)
2580        self.assertLastGraphAllFused()
2581        self.assertEqual(f(x), scr(x), atol=4e-3, rtol=4e-3)
2582
2583        bf_x = torch.rand((2, 2), dtype=torch.bfloat16)
2584        bf_scr = torch.jit.trace(f, bf_x)
2585        bf_scr(bf_x)
2586        bf_scr(bf_x)
2587        graph = bf_scr.graph_for(bf_x)
2588        fusion_groups = self.findFusionGroups(graph)
2589        self.assertEqual(len(fusion_groups), 2)
2590        self.assertEqual(f(bf_x), bf_scr(bf_x), atol=4e-3, rtol=4e-3)
2591
2592    def test_with_strict_fusion(self):
2593        def success(x):
2594            with torch.jit.strict_fusion():
2595                return x + x + x
2596
2597        scripted = self.checkScript(success, (torch.rand([4]),))
2598        g = torch.jit.last_executed_optimized_graph()
2599        FileCheck().check_not("aten::add").check("prim::TensorExprGroup").run(g)
2600
2601        def foo(x):
2602            with torch.jit.strict_fusion():
2603                return x + x + torch.rand([4]) + 3
2604
2605        with self.assertRaises(Exception) as error_out:
2606            foo_s = torch.jit.script(foo)
2607            foo_s(torch.rand([4]))
2608            foo_s(torch.rand([4]))
2609            print(torch.jit.last_executed_optimized_graph())
2610        fc = FileCheck().check("Found unfused operators")
2611        fc.check("aten::rand(SymInt[] size")
2612        fc.check("torch.rand([4]").run(str(error_out.exception))
2613
2614        with warnings.catch_warnings(record=True) as warns:
2615            foo(torch.rand([4]))
2616
2617        FileCheck().check("Only works in script mode").run(str(warns[0]))
2618
2619        def test_autodiff(x):
2620            with torch.jit.strict_fusion():
2621                return torch.rand([4]) + x + x + x
2622
2623        foo_s = torch.jit.script(test_autodiff)
2624        inp = torch.rand([4], requires_grad=True)
2625        with self.assertRaises(Exception) as error_out:
2626            for _ in range(3):
2627                foo_s(inp)
2628        f = FileCheck().check("unfused operators").check("aten::rand")
2629        f.run(str(error_out.exception))
2630
2631        def test_separate_fusions(x, y):
2632            with torch.jit.strict_fusion():
2633                return x + x + x, y + y + y
2634
2635        inp = torch.rand([4], requires_grad=True)
2636        with self.assertRaises(Exception) as error_out:
2637            for _ in range(3):
2638                foo_s = torch.jit.script(test_separate_fusions)
2639                foo_s(inp, inp)
2640
2641        f = FileCheck().check("Found multiple fusions")
2642        f.run(str(error_out.exception))
2643
2644    def test_constant_chunk_shapes(self):
2645        # We had an issue where buildShapeExpressions would fail as show below:
2646        #
2647        # %1 : Tensor = Constant[..]  # not supported, we don't build this shape
2648        # %2 : Tensor = Constant[..]  # not supported
2649        # %3 : Tensor = aten::add(%1, %2)  # inputs not supported, we don't build shape
2650        # ... = prim::ConstantChunk[..](%3)  # it forgets to check whether input shapes exist, and fails
2651        if self.dynamic_shapes:
2652            self.skipTest("TODO: chunk dynamic shapes")
2653
2654        for device in self.devices:
2655
2656            def f(x, y):
2657                r = torch.tensor(4)
2658                z1, z2 = (x + y + r).chunk(2, dim=1)
2659                return z1 * z2
2660
2661            x = torch.randn(4, 4, dtype=torch.float, device=device)
2662            y = torch.randn(4, 4, dtype=torch.float, device=device)
2663
2664            ge = self.checkTrace(f, (x, y))
2665            graph = ge.graph_for(x, y)
2666
2667            # make sure that we are actually testing the right scenario
2668            FileCheck().check("with " + FUSION_GROUP + "_").check_count(
2669                "ConstantChunk", 1, exactly=True
2670            ).run(str(graph))
2671
2672            f_traced = torch.jit.trace(f, (x, y))
2673
2674            for i in range(4):
2675                # make sure this doesn't error out
2676                res = f_traced(x, y)
2677
2678            self.assertEqual(res, f(x, y))
2679
2680    @unittest.skipIf(not RUN_CUDA_HALF, "half-precision NNC fusion requires CUDA")
2681    def test_pow_multiple_dtype(self):
2682        # https://github.com/pytorch/pytorch/issues/75476
2683        def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
2684            p = torch.sigmoid(p)
2685            result = p**gamma
2686            return result
2687
2688        x = torch.rand((2, 2), dtype=torch.half, device="cuda")
2689
2690        ref = fn(x)
2691
2692        script_fn = torch.jit.script(fn)
2693        for i in range(4):
2694            res = script_fn(x)
2695
2696        self.assertEqual(ref, res)
2697
2698
2699class TestTEFuserStatic(TestTEFuser):
2700    dynamic_shapes = False
2701
2702
2703class TestTEFuserDynamic(TestTEFuser):
2704    dynamic_shapes = True
2705
2706
2707del TestTEFuser
2708
2709works_list = [
2710    "__radd__",
2711    "__rdiv__",
2712    "__rmul__",
2713    "__rmod__",
2714    "abs",
2715    "acos",
2716    "add",
2717    "addcmul",
2718    "addmm.decomposed",
2719    "asin",
2720    "atan",
2721    "atan2",
2722    "ceil",
2723    "clamp",
2724    "clamp.scalar",
2725    "contiguous",
2726    "cos",
2727    "cosh",
2728    "div.no_rounding_mode",
2729    "div.true_rounding",
2730    "div.floor_rounding",
2731    "div.trunc_rounding",
2732    "eq",
2733    "erf",
2734    "erfc",
2735    "exp",
2736    "expand",
2737    "expand_as",
2738    "expm1",
2739    "floor",
2740    "fmod",
2741    "fmod.autodiffed",
2742    "ge",
2743    "gt",
2744    "isnan",
2745    "le",
2746    "lerp",
2747    "lgamma",
2748    "log",
2749    "log10",
2750    "log1p",
2751    "log2",
2752    "lt",
2753    "masked_fill",
2754    "max.binary",
2755    "mean",
2756    "min.binary",
2757    "mm",
2758    "mul",
2759    "ne",
2760    "neg",
2761    "nn.functional.hardshrink",
2762    "nn.functional.hardsigmoid",
2763    "nn.functional.hardswish",
2764    "nn.functional.softplus",
2765    "nn.functional.hardtanh",
2766    "nn.functional.leaky_relu",
2767    "nn.functional.relu",
2768    "nn.functional.relu6",
2769    "nn.functional.softsign",
2770    "nn.functional.tanhshrink",
2771    "nn.functional.threshold",
2772    "permute",
2773    "pow",
2774    "reciprocal",
2775    "remainder",
2776    "remainder.autodiffed",
2777    "reshape",
2778    "reshape_as",
2779    "round",
2780    "rsub",
2781    "rsub.rsub_tensor",
2782    "rsqrt",
2783    "sigmoid",
2784    "sign",
2785    "sin",
2786    "sinh",
2787    "sqrt",
2788    "sub",
2789    "sum",
2790    "t",
2791    "tan",
2792    "tanh",
2793    "transpose",
2794    "true_divide",
2795    "trunc",
2796    "unsqueeze",
2797    "view",
2798    "view_as",
2799    "where",
2800    "bool",
2801    "byte",
2802    "char",
2803    "double",
2804    "float",
2805    "half",
2806    "int",
2807    "long",
2808    "short",
2809    "bool.channels_last",
2810    "byte.channels_last",
2811    "char.channels_last",
2812    "double.channels_last",
2813    "float.channels_last",
2814    "half.channels_last",
2815    "int.channels_last",
2816    "long.channels_last",
2817    "short.channels_last",
2818]
2819
2820known_failures = [
2821    "__rmatmul__",
2822    "frac",
2823    "matmul",
2824]
2825
2826# If your OpInfo test causes this test to fail, add it here
2827skip_ops = ["conj"]
2828
2829
2830def get_name(op):
2831    l = [op.name]
2832    if op.variant_test_name != "":
2833        l.append(op.variant_test_name)
2834    return ".".join(l)
2835
2836
2837# Purpose of this class is to allow super() calls.
2838# super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works.
2839# super(TestNNCOpInfo, self) fails because TestNNCOpInfo gets deleted from global scope.
2840# super(JitCommonTestCase, self).fn() would skip JitCommonTestCase.fn() implementation
2841class TestNNCOpInfoParent(JitCommonTestCase):
2842    pass
2843
2844
2845class TestNNCOpInfo(TestNNCOpInfoParent):
2846    def setUp(self):
2847        super(TestNNCOpInfoParent, self).setUp()
2848        self.tensorexpr_options = TensorExprTestOptions()
2849
2850    def tearDown(self):
2851        self.tensorexpr_options.restore()
2852        super(TestNNCOpInfoParent, self).tearDown()
2853
2854    def te_compile(self, device, dtype, op):
2855        if op.name in skip_ops:
2856            return
2857        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
2858        for sample_input in sample_inputs_itr:
2859            arg_values = [sample_input.input] + list(sample_input.args)
2860            kwarg_values = sample_input.kwargs
2861            param_names = []
2862            param_values = []
2863            fx_args = []
2864            for idx, v in enumerate(arg_values):
2865                if isinstance(v, torch.Tensor):
2866                    param_names.append(f"arg_{idx}")
2867                    param_values.append(v)
2868                    fx_args.append(param_names[-1])
2869                else:
2870                    fx_args.append(f"{repr(v)}")
2871
2872            for k, v in kwarg_values.items():
2873                if isinstance(v, torch.Tensor):
2874                    param_names.append(k)
2875                    param_values.append(v)
2876                    fx_args.append(f"{k} = {k}")
2877                else:
2878                    fx_args.append(f"{k} = {repr(v)}")
2879
2880            code = f"""
2881def f({', '.join(param_names)}):
2882    return op.op({', '.join(fx_args)})"""
2883            g = {"torch": torch, "inf": math.inf, "op": op}
2884            exec(code, g)
2885            f = g["f"]
2886            f.__module__ = "test"
2887            out = f(*param_values)
2888
2889            ts_g = torch.jit.trace(f, param_values)
2890            kernel = torch._C._te.TensorExprKernel(ts_g.graph)
2891            correct_val = f(*param_values)
2892            self.assertEqual(kernel.run(tuple(param_values)), correct_val)
2893            self.assertEqual(kernel.fallback(tuple(param_values)), correct_val)
2894
2895    @onlyCPU
2896    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
2897    @ops(
2898        [op for op in op_db if get_name(op) in works_list],
2899        allowed_dtypes=(torch.float,),
2900    )
2901    def test_working(self, device, dtype, op):
2902        self.te_compile(device, dtype, op)
2903
2904    @onlyCPU
2905    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
2906    @ops(
2907        [op for op in op_db if get_name(op) in known_failures],
2908        allowed_dtypes=(torch.float,),
2909    )
2910    def test_failures(self, device, dtype, op):
2911        try:
2912            self.te_compile(device, dtype, op)
2913        except Exception as e:
2914            pass
2915        else:
2916            raise RuntimeError(
2917                "Expected test to fail. If it now works, move op into works_list"
2918            )
2919
2920    @onlyCPU
2921    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
2922    @ops(
2923        [op for op in op_db if get_name(op) not in works_list + known_failures],
2924        allowed_dtypes=(torch.float,),
2925    )
2926    def test_unsupported(self, device, dtype, op):
2927        if get_name(op) in skip_ops:
2928            return
2929        try:
2930            with warnings.catch_warnings():
2931                warnings.simplefilter("ignore", TracerWarning)  # noqa: F821
2932                self.te_compile(device, dtype, op)
2933        except Exception as e:
2934            pass
2935        else:
2936            raise RuntimeError(
2937                "Expected test to fail. If it now works, move op into works_list"
2938            )
2939
2940    @slowTest
2941    @onlyCPU
2942    @ops(op_db, dtypes=OpDTypes.supported)
2943    def test_nnc_correctness(self, device, dtype, op):
2944        if not op.supports_tracing:
2945            self.skipTest("Requires tracing support")
2946
2947        with NoTracerWarnContextManager() as no_warn:
2948            variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
2949
2950            for variant, sample in variant_sample_pairs:
2951                trace = create_traced_fn(self, variant, cache_traced_fn=True)
2952                ref = variant(
2953                    *clone_inputs((sample.input, *sample.args)), **sample.kwargs
2954                )
2955
2956                trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
2957                val = trace(
2958                    *clone_inputs((sample.input, *sample.args)), **sample.kwargs
2959                )
2960
2961                atol = 2e-1 if dtype == torch.bfloat16 else 1e-5
2962                rtol = 2e-1 if dtype == torch.bfloat16 else 1e-5
2963                self.assertEqual(ref, val, atol=atol, rtol=rtol)
2964
2965            # https://github.com/pytorch/pytorch/issues/35600
2966            # each torch.jit.trace adds state to the _python_cu compilation unit
2967            # since this test traces a lot of functions, out-of-memory can occur
2968            # if the CU is not cleared.
2969            torch.jit._state._python_cu.drop_all_functions()
2970
2971
2972# CPU fuser not currently used in fbcode
2973only_for = ("cuda") if IS_FBCODE else ("cpu", "cuda")
2974instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
2975
2976
2977# Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent)
2978class TestLoopnestRandomizationParent(JitTestCase):
2979    pass
2980
2981
2982class TestLoopnestRandomization(TestLoopnestRandomizationParent):
2983    def setUp(self):
2984        super(TestLoopnestRandomizationParent, self).setUp()
2985        self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
2986        self.old_must_use_cpu_state = torch._C._jit_get_te_must_use_llvm_cpu()
2987        self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
2988
2989        torch._C._jit_override_can_fuse_on_cpu(True)
2990        # TODO: force LLVM. need to add it to asan, mac, windows builds + sandcastle
2991        # torch._C._jit_set_te_must_use_llvm_cpu(True)
2992        torch._C._jit_override_can_fuse_on_gpu(True)
2993
2994        self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
2995        self.old_profiling_mode = torch._C._get_graph_executor_optimize(True)
2996
2997        self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
2998        torch._C._debug_set_fusion_group_inlining(False)
2999
3000        self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
3001        torch._C._jit_set_texpr_fuser_enabled(True)
3002
3003        self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
3004        torch._C._jit_set_te_must_use_llvm_cpu(False)
3005
3006        # Set the seed to 1. This tests the codepath through random
3007        # transformation.
3008        os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "1"
3009
3010    def tearDown(self):
3011        torch._C._jit_set_profiling_executor(self.old_profiling_executor)
3012        torch._C._get_graph_executor_optimize(self.old_profiling_mode)
3013
3014        torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
3015        torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
3016        torch._C._jit_set_te_must_use_llvm_cpu(self.old_must_use_cpu_state)
3017        torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
3018
3019        torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
3020        torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
3021
3022        # Set it back to 0.
3023        os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "0"
3024        super(TestLoopnestRandomizationParent, self).tearDown()
3025
3026    @onlyCPU
3027    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
3028    def test_relu(self, device):
3029        def fn_test_relu(x, y):
3030            return F.relu(x + 0.5 * y)
3031
3032        x = torch.randn(4, 4, dtype=torch.float, device=device)
3033        y = torch.randn(4, 4, dtype=torch.float, device=device)
3034
3035        fn = fn_test_relu
3036        traced_fn = torch.jit.trace(fn, (x, y))
3037
3038        ref = fn(x, y)
3039        res = traced_fn(x, y)
3040        assert torch.allclose(ref, res)
3041
3042
3043instantiate_device_type_tests(TestLoopnestRandomization, globals(), only_for=("cpu"))
3044
3045
3046if __name__ == "__main__":
3047    run_tests()
3048