# Owner(s): ["oncall: jit"]

import unittest
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import FileCheck
from unittest import skipIf

from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
    enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \
    RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward
from textwrap import dedent
from itertools import product, permutations
from torch.testing._internal.common_cuda import with_tf32_off

from test_jit import backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \
    LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell

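# When the suite is configured to use the profiling graph executor, turn the
# profiling executor and profiling mode on explicitly so graphs are optimized
# (and fused) only after input shapes have been profiled.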
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)


def strip_profiling_nodes(nodes):
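    """Filter out the profiling executor's bookkeeping nodes
    (prim::BailOut / prim::BailoutTemplate) so callers count only nodes that
    perform real computation."""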
    profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'}
    return [n for n in nodes if n.kind() not in profiling_opcodes]


def warmup_forward(f, *args):
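    """Run f a couple of times so the profiling executor can record shapes and
    switch to an optimized (fused) graph; return the result of the last run."""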
    profiling_count = 2
    for i in range(profiling_count):
        results = f(*args)

    return results


@skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "skip due to SIGIOT failures, #67646")
class TestFuser(JitTestCase):
    def assertAllFused(self, graph, except_for=()):
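        """Assert that graph was fused into exactly one prim::FusionGroup.

        If the graph contains a prim::DifferentiableGraph, the check runs on its
        subgraph instead. Apart from the fusion group, only constants, bailout
        nodes, tuple construction, and the kinds listed in except_for may appear.
        """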

        diff_graphs = [n for n in graph.nodes() if n.kind() == 'prim::DifferentiableGraph']
        if len(diff_graphs) > 0:
            self.assertEqual(len(diff_graphs), 1)
            graph = diff_graphs[0].g('Subgraph')

        allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::BailoutTemplate',
                         'prim::BailOut', 'prim::TupleConstruct'} | set(except_for)
        self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
                        f'got {graph}')
        self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)

    def _test_fused_abs(self, device='cpu'):
        def func(x):
            return x.abs() * 2

        a = torch.randn(5, device=device)
        scripted = self.checkScript(func, (a,))
        self.assertAllFused(scripted.graph_for(a))

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_abs_cpu(self):
        self._test_fused_abs()

    @unittest.skipIf(not IS_WINDOWS, "This is meant to be Windows-specific")
    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_abs_cpu_unicode_temp_dir(self):
        with TemporaryDirectoryName(suffix='\u4e2d\u6587') as dname:
            shell_env = os.environ.copy()
            shell_env['TMP'] = dname
            cmd = [sys.executable, os.path.basename(__file__), type(self).__name__ + '.test_abs_cpu']
            legacy_jit_flag = '--jit-executor=legacy'
            for v in sys.argv:
                if v == legacy_jit_flag:
                    cmd.append(legacy_jit_flag)
            return_code = shell(cmd, cwd=os.path.dirname(__file__), env=shell_env)
            self.assertEqual(return_code, 0)

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_abs_cuda(self):
        self._test_fused_abs(device="cuda")

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_zero_element_tensors(self):
        def decode(sin_t, cos_t):
            theta = torch.atan2(sin_t.float(), cos_t.float())
            return theta

        sin = torch.zeros(0, device="cuda")
        cos = torch.zeros(0, device="cuda")
        inputs = [sin, cos]
        ge = self.checkScript(decode, inputs)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_arg_configurations_smoke_cuda(self):
        # A smoke test to make sure we won't use the same kernel for contiguous
        # and non-contiguous arguments.
        # TODO: add optionally enabled debug counters to the fuser to verify
        #       that we really can tell the difference between configurations
        def f(x, y):
            z1, z2 = (x + y).chunk(2, dim=1)
            return z1 * z2

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
        traced_f = torch.jit.trace(f, (x, y,))
        self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_broadcast_cuda(self):
        def scaleshift(x, scale, shift):
            return x * scale + shift

        inputs = [
            torch.randn(4, 4, dtype=torch.float, device='cuda'),
            torch.randn(4, dtype=torch.float, device='cuda'),
            torch.randn(4, dtype=torch.float, device='cuda'),
        ]
        ge = self.checkTrace(scaleshift, inputs)
        self.assertAllFused(ge.graph_for(*inputs))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no bfloat support with profiling on")
    def test_cuda_bfloat16(self):
        def foo(x, y):
            return (x + y).relu()
        m = torch.jit.script(foo)
        x = torch.randn(65536).cuda().bfloat16()
        y = torch.randn_like(x)
        self.assertAllFused(m.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
    def test_cuda_half(self):
        x = torch.randn(4, 4, dtype=torch.half, device='cuda')
        y = torch.randn(4, 4, dtype=torch.half, device='cuda')

        funcs = [
            self.fn_test_comparison_gt_lt,
            self.fn_test_relu,
            self.fn_test_exp
        ]

        # Note: Non-fused inputs must be float to prevent loss of precision
        inputs = (x.float(), y.float())
        fusion_inputs = (x, y)
        for fn in funcs:
            local_inputs = [t.clone().requires_grad_() for t in inputs]
            local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]

            # Verifies outputs
            fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False)
            outputs = fn(*local_inputs)
            fusion_outputs = fusion(*local_fusion_inputs)
            outputs_half = [t.half() for t in outputs]
            self.assertEqual(outputs_half, fusion_outputs)

            # Verifies gradients
            for output, fusion_output in zip(outputs_half, fusion_outputs):
                grads = torch.autograd.grad(
                    output.float().sum(), local_inputs, allow_unused=True, retain_graph=True)
                fusion_grads = torch.autograd.grad(
                    fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True)
                grads_half = [t.half() for t in grads]
                self.assertEqual(grads_half, fusion_grads)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_checks_cat_inputs(self):
        # We shouldn't treat cat nodes as broadcasting. All their inputs
        # need to be checked for having the same map size, before we can
        # run the kernel.
        def f(x, y):
            return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)

        # NOTE: y is broadcastable to x, but output of f(x, y) should have
        # shape 3x4, and not 4x4.
        x = torch.randn(2, 4, dtype=torch.float, device='cuda')
        y = torch.randn(1, 4, dtype=torch.float, device='cuda')

        scripted = self.checkScript(f, (x, y))
        self.assertAllFused(scripted.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_remainder_cuda(self):
        def cuda_rem(x, y):
            return 1 + torch.remainder(x, y) - 1

        a = torch.rand([512], dtype=torch.float).cuda()
        b = torch.rand([512], dtype=torch.float).cuda()
        inputs = [a, b]
        ge = self.checkScript(cuda_rem, inputs)
        graph = ge.graph_for(*inputs)
        self.assertAllFused(graph)

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_chunk_cuda(self):
        def fn(x):
            a, b, c = x.chunk(3, 1)
            return a * b + c

        inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]

        ge = self.checkScript(fn, inputs)
        graph = ge.graph_for(*inputs)
        self.assertAllFused(graph)
        FileCheck().check("prim::ConstantChunk[chunks=3, dim=1]").run(str(graph))

    @staticmethod
    def _test_chunk_correctness(self, device='cpu'):
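        """Check chunk-then-add along dims 0, 1, and 2 on small, contiguous,
        and non-contiguous inputs, comparing scripted results against eager mode."""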
        def chunk_4_0(x):
            x0, x1, x2, x3 = x.chunk(4, 0)
            return x0 + x1 + x2 + x3

        def chunk_4_1(x):
            x0, x1, x2, x3 = x.chunk(4, 1)
            return x0 + x1 + x2 + x3

        def chunk_4_last(x):
            x0, x1, x2, x3 = x.chunk(4, 2)
            return x0 + x1 + x2 + x3

        fns = [chunk_4_0, chunk_4_1, chunk_4_last]
        tensors = [
            # splitSize = 1
            torch.randn(4, 4, 4, dtype=torch.float, device=device),

            # contiguous case
            torch.randn(12, 8, 16, dtype=torch.float, device=device),

            # non-contiguous case
            torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2),
        ]

        for tensor in tensors:
            for fn in fns:
                self.checkScript(fn, [tensor])

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_chunk_correctness(self):
        return self._test_chunk_correctness(self, 'cpu')

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_chunk_correctness_cuda(self):
        return self._test_chunk_correctness(self, 'cuda')

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_chunk_distributes_cuda(self):
        def f(x, y):
            z1, z2 = (x + y).chunk(2, dim=1)
            return z1 * z2

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(f, (x, y))
        graph = ge.graph_for(x, y)
        FileCheck().check("broadcast_tensors").check('with prim::FusionGroup_') \
            .check_count('ConstantChunk', 2, exactly=True).run(str(graph))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_chunk_motion_deduplicates_inputs(self):
        def func1(x):
            z = x * x
            z0, z1 = z.chunk(2)
            return z0 * z1

        def func2(x):
            z = x * x * x
            z0, z1 = z.chunk(2)
            return z0 * z1

        inputs = [
            torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
        ]
        for func in [func1, func2]:
            module = self.checkScript(func, inputs)
            forward_graph = module.graph_for(*inputs)
            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
            fusion_group = list(forward_graph.nodes())[-1]
            self.assertEqual(len(list(fusion_group.inputs())), 1)

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_chunk_multiple_cuda(self):
        # The arguments are intentionally used out of order as a test to see
        # if the fusion compiler adds extra args in the correct order
        def fn(s, x, y, z):
            z1, z2 = z.chunk(2, 2)
            x1, x2, x3 = x.chunk(3, 1)
            y1, y2 = y.chunk(2, 0)
            return s + x1 + x2 + x3 + y1 + y2 + z1 + z2

        inputs = [
            torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
            torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
            torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
            torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
        ]

        ge = self.checkScript(fn, inputs)
        self.assertAllFused(ge.graph_for(*inputs))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_minmax(self):
        def tmax(a, b):
            return torch.max(2 * a, b)

        def tmin(a, b):
            return torch.min(2 * a, b)

        a = torch.randn(4, 4, dtype=torch.float, device="cuda")
        b = torch.randn(4, 4, dtype=torch.float, device="cuda")
        nan = torch.tensor(float('nan'), dtype=torch.float, device="cuda")

        for f, inputs in product(
                (tmax, tmin),
                ([a, b], [a, nan], [b, nan])):
            s = self.checkScript(f, inputs)
            self.assertAllFused(s.graph_for(*inputs))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_clamp(self):
        def func2(a, b):
            return torch.clamp(a + b, min=0, max=2)

        def funcInf(a, b):
            return torch.clamp(a + b, min=0, max=float('inf'))

        def funcOptMin(a, b):
            return torch.clamp(a + b, max=2)

        def funcOptMax(a, b):
            return torch.clamp(a + b, min=0)

        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
        b = torch.randn(4, 4, dtype=torch.float, device='cuda')
        nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

        funcs = (func2, funcInf, funcOptMin, funcOptMax)
        for f, inputs in product(funcs, [[a, b], [a, nan]]):
            f.__disable_jit_function_caching__ = True
            inp1, inp2 = inputs
            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
            c = s(inp1, inp2)
            with enable_profiling_mode_for_profiling_tests():
                warmup_backward(c.sum())
            graph = backward_graph(s)
            self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
    def test_dropout(self):
        def func(x):
            x = torch.nn.functional.dropout(x)
            return torch.nn.functional.relu(x)

        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
        s = torch.jit.script(func)
        c = s(a)
        c = s(a)
        warmup_backward(c.sum())
        # skip_check to skip extra bailout nodes in between
        graph = backward_graph(s, skip_check=True)
        self.assertAllFused(graph, except_for={'aten::div', 'prim::Constant'})

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_comparison_eq_ne(self):
        def f(x, y):
            mask = (x == 0).type_as(x)
            z = x * mask + y
            mask = (x != 0).type_as(x)
            z = z * mask + y
            return z

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(f, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

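    # Shared elementwise helper; reused by test_comparison_gt_lt_cuda and
    # (with half-precision inputs) by test_cuda_half.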
    @staticmethod
    def fn_test_comparison_gt_lt(x, y):
        mask = (x > 0).type_as(x)
        z = x * mask + y
        mask = (x < 0).type_as(x)
        z = z * mask + y
        return z

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_comparison_gt_lt_cuda(self):
        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_comparison_ge_le_cuda(self):
        def f(x, y):
            mask = (x >= 0).type_as(x)
            z = x * mask + y
            mask = (x <= 0).type_as(x)
            z = z * mask + y
            return z

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(f, (x, y))
        self.assertAllFused(ge.graph_for(x, y))
        x.requires_grad_(True)
        y.requires_grad_(True)
        self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
                                                            "aten::_size_if_not_equal"))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_addcmul_cuda(self):
        t = torch.randn(1, 4, dtype=torch.float, device='cuda')
        t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
        t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')

        def foo(t, t1, t2):
            return t.addcmul(t + 1, t2, value=0.1)

        ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
        graph = ge.graph_for(t, t1, t2)
        self.assertAllFused(graph)

    # TODO: We leak CUDA memory here because the traced graph holds onto a
    # constant-ified tensor. Since the Python-global CompilationUnit is alive
    # until the end of the process, the memory is effectively leaked.
    # The `_cuda` suffix was removed from this test's name, which disables
    # leak checking. If this is a real problem, we'll need to revisit
    # TorchScript Function lifetimes in Python.
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_lerp(self):
        start = torch.randn(4, 1, dtype=torch.float, device='cuda')
        end = torch.randn(1, 4, dtype=torch.float, device='cuda')
        weight = torch.tensor(0.5, dtype=torch.float, device='cuda')

        # scalar weight overload
        def foo_weight_scalar(start, end):
            return torch.lerp(start + 1, end, 0.5)

        # tensor weight overload
        def foo_weight_tensor(start, end):
            return torch.lerp(start + 1, end, weight)

        ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end))
        graph = ge_weight_scalar.graph_for(start, end)
        self.assertAllFused(graph)

        ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end))
        graph = ge_weight_tensor.graph_for(start, end)
        self.assertAllFused(graph)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_concat_cuda(self):
        hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
        cx = torch.randn(3, 20, dtype=torch.float, device='cuda')

        def foo(hx, cx):
            return torch.cat((hx + cx, hx * cx))

        ge = self.checkTrace(foo, (hx, cx))
        graph = ge.graph_for(hx, cx)
        self.assertAllFused(graph)
        FileCheck().check("FusedConcat").check_next("return").run(str(graph))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_concat_invariant_cuda(self):
        # Invariant: the output of prim::FusedConcat may
        # not be an input to any node inside the FusionGroup.
        def fn(x, y, z):
            x1 = x + y
            y1 = x - y
            w = torch.cat([x1, y1])
            return w + z

        x = torch.randn(2, 2, dtype=torch.float, device='cuda')
        y = torch.randn(2, 2, dtype=torch.float, device='cuda')
        z = torch.randn(4, 2, dtype=torch.float, device='cuda')
        ge = self.checkTrace(fn, (x, y, z))
        graph = ge.graph_for(x, y, z)
        self.assertAllFused(graph, except_for={'aten::add'})
        FileCheck().check("FusedConcat").check_next("return").run(str(graph))

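    # Shared elementwise helper; reused by test_exp_cuda and
    # (with half-precision inputs) by test_cuda_half.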
    @staticmethod
    def fn_test_exp(x, y):
        return (x + .5 * y).exp()

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_exp_cuda(self):
        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(self.fn_test_exp, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
    @torch._jit_internal._disable_emit_hooks_decorator
    @_inline_everything
    def test_fuse_decompose_normalization(self):
        class ResLike(torch.jit.ScriptModule):
            def __init__(self, norm_module):
                super().__init__()
                self.nm = norm_module

            @torch.jit.script_method
            def forward(self, x, y):
                return y + torch.relu(self.nm(x))

        def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph):
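            # nm: the normalization module under test.
            # in_opt_graph: substrings expected in the optimized graph.
            # not_in_opt_graph: substrings that must be absent from the optimized
            #     graph but present in the non-optimized graph.
            # in_fusegraph: substrings expected inside the fusion subgraph.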
            model = ResLike(nm).cuda()
            model_noopt = ResLike(nm).cuda()
            model_noopt.load_state_dict(model.state_dict())
            x = torch.randn(2, 16, 8, 8, device='cuda')
            y = torch.randn(2, 16, 8, 8, device='cuda')

            # FIXME: We need differentiation for CNNs for this optimization to trigger
            with torch.no_grad():
                out = model(x, y)
                graph = model.graph_for(x, y)
                rep = str(graph)

                with torch.jit.optimized_execution(False):
                    out_noopt = model_noopt(x, y)
                    rep_noopt = str(model_noopt.graph_for(x, y))
                self.assertEqual(out, out_noopt, atol=3e-5)

            # Check that normalization op has really been decomposed
            for node_in_graph in in_opt_graph:
                self.assertIn(node_in_graph, rep)

            for node_not_in_graph in not_in_opt_graph:
                self.assertNotIn(node_not_in_graph, rep)
                self.assertIn(node_not_in_graph, rep_noopt)

            fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
            self.assertEqual(len(fusion_groups), 1)
            fused_graph = str(fusion_groups[0].g('Subgraph'))
            for node_in_fusegraph in in_fusegraph:
                self.assertIn(node_in_fusegraph, fused_graph)

        # test for batchnorm decompose
        bm = nn.BatchNorm2d(16)
        test_norm_decompose(bm, ['aten::batch_norm_update_stats'],
                            ['aten::batch_norm('], ['aten::sqrt'])

        # test for layernorm decompose
        lm = nn.LayerNorm(8)
        test_norm_decompose(lm, ['aten::batch_norm_stats'],
                            ['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::add'])

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_threshold(self):
        def f(x):
            return torch.threshold(x, 0, -10) + x + x + x

        x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda')
        scripted = self.checkScript(f, (x,))
        self.assertAllFused(scripted.graph_for(x))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_scalar_arg_cuda(self):
        def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
            return p * (x * x + x)

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        p = 3
        scripted = self.checkScript(fn_test_scalar_arg, (x, p))
        self.assertAllFused(scripted.graph_for(x, p))

        x.requires_grad_(True)

        # use another function, otherwise we will bail out
        # and won't be able to do fused checks
        def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor:
            return p * (x * x + x)

        scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)
        out = scripted(x, p)
        self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes",
                                                                  "aten::_size_if_not_equal"))

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @unittest.skip("deduplicating introduces aliasing in backward graph's outputs")
    @enable_cpu_fuser
    def test_fuser_deduplication(self):
        # Check that fusion kernel outputs are deduplicated when _grad_sum_to_size
        # is removed during the fuser's compilation; see the discussion in PR #14957.
        def f(x, y):
            return torch.sigmoid(x + y)

        b = torch.randn(5, 5, requires_grad=True)
        a = torch.randn(5, 5, requires_grad=True)
        s = self.checkScript(f, (a, b))
        self.assertAllFused(s.graph_for(a, b), except_for={
                            'aten::size', 'aten::_size_if_not_equal', 'prim::BroadcastSizes'})

        c = s(a, b)
        results = warmup_backward(c.sum(), [a, b])
        ga2, gb2 = results.pop()
        graph = backward_graph(s)
        self.assertAllFused(graph)
        # check that the gradients of a and b share storage, i.e. were generated as a single output in the fuser
        self.assertEqual(ga2.data_ptr(), gb2.data_ptr())

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    @unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
    def test_fuser_iou(self):
        # This checks if most of Intersection over Union is fused.
        # In particular, the backward contains many _grad_sum_to_size.
        def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
            ltx = torch.max(b1x1, b2x1)  # [N,M]
            lty = torch.max(b1y1, b2y1)
            rbx = torch.min(b1x2, b2x2)
            rby = torch.min(b1y2, b2y2)

            w = (rbx - ltx).clamp(min=0, max=float('inf'))  # [N,M]
            h = (rby - lty).clamp(min=0, max=float('inf'))  # [N,M]
            inter = w * h  # [N,M]

            area1 = (b1x2 - b1x1) * (b1y2 - b1y1)  # [N,1]
            area2 = (b2x2 - b2x1) * (b2y2 - b2y1)  # [1,M]
            iou = inter / (area1 + area2 - inter)
            return iou

        box1 = torch.randn(5, 4, requires_grad=True)
        box2 = torch.randn(5, 4, requires_grad=True)
        # unsqueezing can currently not be fused
        b1x1 = box1[:, 0].unsqueeze(1)  # [N,1]
        b1y1 = box1[:, 1].unsqueeze(1)
        b1x2 = box1[:, 2].unsqueeze(1)
        b1y2 = box1[:, 3].unsqueeze(1)
        b2x1 = box2[:, 0].unsqueeze(0)  # [1,M]
646*da0073e9SAndroid Build Coastguard Worker        b2y1 = box2[:, 1].unsqueeze(0)
647*da0073e9SAndroid Build Coastguard Worker        b2x2 = box2[:, 2].unsqueeze(0)
648*da0073e9SAndroid Build Coastguard Worker        b2y2 = box2[:, 3].unsqueeze(0)
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker        s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
651*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
652*da0073e9SAndroid Build Coastguard Worker                            except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
653*da0073e9SAndroid Build Coastguard Worker
654*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests(True):
655*da0073e9SAndroid Build Coastguard Worker            c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
656*da0073e9SAndroid Build Coastguard Worker            warmup_backward(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
657*da0073e9SAndroid Build Coastguard Worker            graph = backward_graph(s)
658*da0073e9SAndroid Build Coastguard Worker            self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
659*da0073e9SAndroid Build Coastguard Worker
660*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
661*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
662*da0073e9SAndroid Build Coastguard Worker    @enable_cpu_fuser
663*da0073e9SAndroid Build Coastguard Worker    def test_fusion_reuse_multi_gpu(self):
664*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
665*da0073e9SAndroid Build Coastguard Worker            return x * y * x * y
666*da0073e9SAndroid Build Coastguard Worker
667*da0073e9SAndroid Build Coastguard Worker        inputs_cpu = [
668*da0073e9SAndroid Build Coastguard Worker            torch.randn(4, 4, dtype=torch.float),
669*da0073e9SAndroid Build Coastguard Worker            torch.randn(4, 4, dtype=torch.float),
670*da0073e9SAndroid Build Coastguard Worker        ]
671*da0073e9SAndroid Build Coastguard Worker        inputs_cuda0 = [x.cuda(0) for x in inputs_cpu]
672*da0073e9SAndroid Build Coastguard Worker        inputs_cuda1 = [y.cuda(1) for y in inputs_cpu]
673*da0073e9SAndroid Build Coastguard Worker
674*da0073e9SAndroid Build Coastguard Worker        # Should not crash; these should compile different kernels.
675*da0073e9SAndroid Build Coastguard Worker        ge = self.checkScript(fn, inputs_cpu)
676*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(ge.graph_for(*inputs_cpu))
677*da0073e9SAndroid Build Coastguard Worker        ge(*inputs_cuda0)
678*da0073e9SAndroid Build Coastguard Worker        ge(*inputs_cuda1)
679*da0073e9SAndroid Build Coastguard Worker
680*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
681*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
682*da0073e9SAndroid Build Coastguard Worker    @enable_cpu_fuser
683*da0073e9SAndroid Build Coastguard Worker    def test_kernel_cache_multi_gpu(self):
684*da0073e9SAndroid Build Coastguard Worker        def not_fusible(x):
685*da0073e9SAndroid Build Coastguard Worker            return x
686*da0073e9SAndroid Build Coastguard Worker
687*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
688*da0073e9SAndroid Build Coastguard Worker            x_out = x * x * x * x * x  # fusion: lambda x. x * x * x * x * x
689*da0073e9SAndroid Build Coastguard Worker            y_out = y * y * y * y * y
690*da0073e9SAndroid Build Coastguard Worker            z_out = z * z * z * z * z
691*da0073e9SAndroid Build Coastguard Worker            return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)
692*da0073e9SAndroid Build Coastguard Worker
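        # the same pointwise chain is run on three devices: CPU, cuda:0 and cuda:1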
693*da0073e9SAndroid Build Coastguard Worker        inputs = [
694*da0073e9SAndroid Build Coastguard Worker            torch.randn(4, 4, dtype=torch.float),
695*da0073e9SAndroid Build Coastguard Worker            torch.randn(4, 4, dtype=torch.float, device='cuda:0'),
696*da0073e9SAndroid Build Coastguard Worker            torch.randn(4, 4, dtype=torch.float, device='cuda:1'),
697*da0073e9SAndroid Build Coastguard Worker        ]
698*da0073e9SAndroid Build Coastguard Worker
699*da0073e9SAndroid Build Coastguard Worker        prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker        # There are 3 FusionGroups. Because they have the same graph, they
702*da0073e9SAndroid Build Coastguard Worker        # should reuse the same KernelSpec in the KernelSpec cache.
703*da0073e9SAndroid Build Coastguard Worker        ge = self.checkScript(fn, inputs)
704*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
705*da0073e9SAndroid Build Coastguard Worker            ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
706*da0073e9SAndroid Build Coastguard Worker        new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
707*da0073e9SAndroid Build Coastguard Worker        # XXX: This assumes that the same kernel isn't already used by another test
708*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(new_cache_size - prev_cache_size, 1)
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
711*da0073e9SAndroid Build Coastguard Worker    def test_nonzero_device_cuda(self):
712*da0073e9SAndroid Build Coastguard Worker        device = 'cuda:1'
713*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.4], dtype=torch.float, device=device)
714*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([0.7], dtype=torch.float, device=device)
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker        def doit(x, y):
717*da0073e9SAndroid Build Coastguard Worker            return torch.sigmoid(torch.tanh(x * (x + y) + x))
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker        ge = self.checkTrace(doit, (x, y))
720*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(ge.graph_for(x, y))
721*da0073e9SAndroid Build Coastguard Worker
722*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
723*da0073e9SAndroid Build Coastguard Worker    def test_lstm_cuda(self):
724*da0073e9SAndroid Build Coastguard Worker        inputs = get_lstm_inputs('cuda', training=True)
725*da0073e9SAndroid Build Coastguard Worker        module = self.checkScript(LSTMCellS, inputs)
726*da0073e9SAndroid Build Coastguard Worker        return  # the graph and backward checks below are currently skipped
727*da0073e9SAndroid Build Coastguard Worker        forward_graph = module.graph_for(*inputs)
728*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
729*da0073e9SAndroid Build Coastguard Worker            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
730*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(strip_profiling_nodes(forward_graph.nodes())), 2)
731*da0073e9SAndroid Build Coastguard Worker        # Everything is differentiable except the TupleConstruct return
732*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
733*da0073e9SAndroid Build Coastguard Worker            .check_next("return").run(str(forward_graph))
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests(True):
736*da0073e9SAndroid Build Coastguard Worker            hy, cy = module(*inputs)
737*da0073e9SAndroid Build Coastguard Worker            warmup_backward((hy + cy).sum())
738*da0073e9SAndroid Build Coastguard Worker            backward = backward_graph(module)
739*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
740*da0073e9SAndroid Build Coastguard Worker                                                  "aten::_grad_sum_to_size"))
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
743*da0073e9SAndroid Build Coastguard Worker    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
744*da0073e9SAndroid Build Coastguard Worker    # We want float tensors to be computed at full precision here, so TF32 is disabled for this test.
745*da0073e9SAndroid Build Coastguard Worker    @with_tf32_off
746*da0073e9SAndroid Build Coastguard Worker    def test_lstm_concat_cuda(self):
747*da0073e9SAndroid Build Coastguard Worker        inputs = get_lstm_inputs('cuda')
748*da0073e9SAndroid Build Coastguard Worker        ge = self.checkTrace(LSTMCellC, inputs)
749*da0073e9SAndroid Build Coastguard Worker        graph = ge.graph_for(*inputs)
750*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("FusedConcat").check_next("return").run(str(graph))
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
753*da0073e9SAndroid Build Coastguard Worker    def test_lstm_gates_permutations_cuda(self):
754*da0073e9SAndroid Build Coastguard Worker        # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
755*da0073e9SAndroid Build Coastguard Worker        # Test that any permutation of this will still result in one FusionGroup.
756*da0073e9SAndroid Build Coastguard Worker        choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
757*da0073e9SAndroid Build Coastguard Worker        template = dedent('''
758*da0073e9SAndroid Build Coastguard Worker        def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
759*da0073e9SAndroid Build Coastguard Worker            gates = {} + {} + {} + {}
760*da0073e9SAndroid Build Coastguard Worker            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
761*da0073e9SAndroid Build Coastguard Worker            return ingate * forgetgate * cellgate * outgate
762*da0073e9SAndroid Build Coastguard Worker        ''')
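        # e.g. one of the generated permutations reads:
        #     gates = b_hh + x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih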
763*da0073e9SAndroid Build Coastguard Worker        for permutation in permutations(choices, len(choices)):
764*da0073e9SAndroid Build Coastguard Worker            code = template.format(*permutation)
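            # exec() builds an eager Python reference `cell`, while CompilationUnit
            # scripts the same source; the two are compared below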
765*da0073e9SAndroid Build Coastguard Worker            scope = {}
766*da0073e9SAndroid Build Coastguard Worker            exec(code, globals(), scope)
767*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(code)
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker            inputs = get_lstm_inputs('cuda', training=False)
770*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
771*da0073e9SAndroid Build Coastguard Worker            forward_graph = cu.cell.graph_for(*inputs)
772*da0073e9SAndroid Build Coastguard Worker            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
773*da0073e9SAndroid Build Coastguard Worker
774*da0073e9SAndroid Build Coastguard Worker    # TODO: Fuser doesn't work at all when inputs require grad. Fix that
775*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
776*da0073e9SAndroid Build Coastguard Worker    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
777*da0073e9SAndroid Build Coastguard Worker    # We want float tensors to be computed at full precision here, so TF32 is disabled for this test.
778*da0073e9SAndroid Build Coastguard Worker    @with_tf32_off
779*da0073e9SAndroid Build Coastguard Worker    def test_lstm_traced_cuda(self):
780*da0073e9SAndroid Build Coastguard Worker        inputs = get_lstm_inputs('cuda')
781*da0073e9SAndroid Build Coastguard Worker        ge = self.checkTrace(LSTMCellF, inputs)
782*da0073e9SAndroid Build Coastguard Worker        graph = ge.graph_for(*inputs)
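        # the chunk, sigmoid and tanh ops should all be absorbed into a single FusionGroup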
783*da0073e9SAndroid Build Coastguard Worker        # no .check_not("aten::add") here: adds may not get pulled into the FusionGroup because of BailOuts
784*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("Chunk").check_not("aten::sigmoid") \
785*da0073e9SAndroid Build Coastguard Worker            .check_not("aten::tanh").check("FusionGroup").check_next("TupleConstruct") \
786*da0073e9SAndroid Build Coastguard Worker            .check_next("return").check_not("FusionGroup_2").run(str(graph))
787*da0073e9SAndroid Build Coastguard Worker
788*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
789*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
790*da0073e9SAndroid Build Coastguard Worker    @enable_cpu_fuser
791*da0073e9SAndroid Build Coastguard Worker    def test_lstm_traced_cpu(self):
792*da0073e9SAndroid Build Coastguard Worker        inputs = get_lstm_inputs('cpu')
793*da0073e9SAndroid Build Coastguard Worker        try:
794*da0073e9SAndroid Build Coastguard Worker            ge = self.checkTrace(LSTMCellF, inputs)
795*da0073e9SAndroid Build Coastguard Worker            graph = ge.graph_for(*inputs)
796*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("FusionGroup").run(str(graph))
797*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as e:
798*da0073e9SAndroid Build Coastguard Worker            if 'Failed to compile' in e.args[0]:
                import warnings
799*da0073e9SAndroid Build Coastguard Worker                warnings.warn('CPU fuser test has failed! This is not a hard failure, '
800*da0073e9SAndroid Build Coastguard Worker                              'because the kernels sometimes trigger bugs in compilers '
801*da0073e9SAndroid Build Coastguard Worker                              '(most notably GCC 7.2).')
802*da0073e9SAndroid Build Coastguard Worker                raise unittest.SkipTest('Failed to compile') from e
803*da0073e9SAndroid Build Coastguard Worker            else:
804*da0073e9SAndroid Build Coastguard Worker                raise
805*da0073e9SAndroid Build Coastguard Worker
806*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
807*da0073e9SAndroid Build Coastguard Worker    def test_milstm_cuda(self):
808*da0073e9SAndroid Build Coastguard Worker        inputs = get_milstm_inputs('cuda', training=True)
809*da0073e9SAndroid Build Coastguard Worker        module = self.checkScript(MiLSTMCell, inputs)
810*da0073e9SAndroid Build Coastguard Worker        forward_graph = module.graph_for(*inputs)
811*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
812*da0073e9SAndroid Build Coastguard Worker            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
813*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
814*da0073e9SAndroid Build Coastguard Worker            .check_next("return").check("FusionGroup").run(str(forward_graph))
815*da0073e9SAndroid Build Coastguard Worker        hy, cy = module(*inputs)
816*da0073e9SAndroid Build Coastguard Worker        warmup_backward((hy + cy).sum())
817*da0073e9SAndroid Build Coastguard Worker
818*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
819*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor")
820*da0073e9SAndroid Build Coastguard Worker    def test_rand_cuda(self):
821*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
822*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['d']
823*da0073e9SAndroid Build Coastguard Worker
824*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
825*da0073e9SAndroid Build Coastguard Worker                super().__init__()
826*da0073e9SAndroid Build Coastguard Worker                self.d = torch.device('cuda')
827*da0073e9SAndroid Build Coastguard Worker
828*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
829*da0073e9SAndroid Build Coastguard Worker            def create(self, x):
830*da0073e9SAndroid Build Coastguard Worker                return x * x + x + torch.rand_like(x)
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
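        # x is all zeros, so m.create(x) reduces to torch.rand_like(x); the two calls
        # below should produce different samples, all lying in [0, 1)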
833*da0073e9SAndroid Build Coastguard Worker        m = M()
834*da0073e9SAndroid Build Coastguard Worker        out1 = m.create(x)
835*da0073e9SAndroid Build Coastguard Worker        out2 = m.create(x)
836*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(out1, out2)
837*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(out1 >= 0))
838*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(out1 < 1))
839*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(out2 >= 0))
840*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(out2 < 1))
841*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(m.create.graph_for(x))
842*da0073e9SAndroid Build Coastguard Worker
843*da0073e9SAndroid Build Coastguard Worker    @staticmethod
844*da0073e9SAndroid Build Coastguard Worker    def fn_test_relu(x, y):
845*da0073e9SAndroid Build Coastguard Worker        return F.relu(x + .5 * y)
846*da0073e9SAndroid Build Coastguard Worker
847*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
848*da0073e9SAndroid Build Coastguard Worker    def test_relu_cuda(self):
849*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
850*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
851*da0073e9SAndroid Build Coastguard Worker
852*da0073e9SAndroid Build Coastguard Worker        ge = self.checkTrace(self.fn_test_relu, (x, y))
853*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(ge.graph_for(x, y))
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
856*da0073e9SAndroid Build Coastguard Worker    def test_erf_cuda(self):
857*da0073e9SAndroid Build Coastguard Worker        def fn_test_erf(x):
858*da0073e9SAndroid Build Coastguard Worker            return F.relu(torch.erf(x) - torch.erfc(x))
859*da0073e9SAndroid Build Coastguard Worker
860*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
861*da0073e9SAndroid Build Coastguard Worker        ge = self.checkTrace(fn_test_erf, (x,))
862*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(ge.graph_for(x))
863*da0073e9SAndroid Build Coastguard Worker        x.requires_grad_(True)
864*da0073e9SAndroid Build Coastguard Worker        ge = self.checkTrace(fn_test_erf, (x,))
865*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes",
866*da0073e9SAndroid Build Coastguard Worker                                                         "aten::_size_if_not_equal"))
867*da0073e9SAndroid Build Coastguard Worker
868*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
869*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor")
870*da0073e9SAndroid Build Coastguard Worker    def test_rand_broadcast_cuda(self):
871*da0073e9SAndroid Build Coastguard Worker        def fn_test_rand(x, y):
872*da0073e9SAndroid Build Coastguard Worker            r = torch.rand_like(y)
873*da0073e9SAndroid Build Coastguard Worker            return r * x + x
874*da0073e9SAndroid Build Coastguard Worker
875*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
876*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
877*da0073e9SAndroid Build Coastguard Worker        script_f = torch.jit.script(fn_test_rand)
878*da0073e9SAndroid Build Coastguard Worker        out = script_f(x, y)
879*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(script_f.graph_for(x, y))
880*da0073e9SAndroid Build Coastguard Worker        x.requires_grad_(True)
881*da0073e9SAndroid Build Coastguard Worker        out = script_f(x, y)
882*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
883*da0073e9SAndroid Build Coastguard Worker                                                                  "aten::_size_if_not_equal"))
884*da0073e9SAndroid Build Coastguard Worker        # test that broadcasting random produces correct results
885*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(4, 4, dtype=torch.float, device='cuda')
886*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(4, dtype=torch.float, device='cuda')
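        # y is 1-D, so rand_like(y) is a single row that broadcasts across every row
        # of x; every row of the output should therefore be identical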
887*da0073e9SAndroid Build Coastguard Worker        out = script_f(x, y)
888*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out[0], out[1])
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
891*da0073e9SAndroid Build Coastguard Worker    @enable_cpu_fuser
892*da0073e9SAndroid Build Coastguard Worker    def test_scalar(self):
893*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
894*da0073e9SAndroid Build Coastguard Worker            return 2 * x + y
895*da0073e9SAndroid Build Coastguard Worker
896*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(0.1, dtype=torch.float, device='cpu')
897*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(1, dtype=torch.float, device='cpu')
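        # both inputs are 0-dim tensors; the CPU fuser should still emit a single fused kernel for 2 * x + y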
898*da0073e9SAndroid Build Coastguard Worker        ge = self.checkScript(fn, (x, y))
899*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(ge.graph_for(x, y))
900*da0073e9SAndroid Build Coastguard Worker
901*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
902*da0073e9SAndroid Build Coastguard Worker    def test_small_constant_cuda(self):
903*da0073e9SAndroid Build Coastguard Worker        def fn_test_small_constant(x, y):
904*da0073e9SAndroid Build Coastguard Worker            return (1e-8 * x + 5e-9 * y) * 1e8
905*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
906*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
907*da0073e9SAndroid Build Coastguard Worker
908*da0073e9SAndroid Build Coastguard Worker        ge = self.checkTrace(fn_test_small_constant, (x, y))
909*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(ge.graph_for(x, y))
910*da0073e9SAndroid Build Coastguard Worker
911*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
912*da0073e9SAndroid Build Coastguard Worker    def test_tensor_scalar_ops_cuda(self):
913*da0073e9SAndroid Build Coastguard Worker        def should_fuse(x):
914*da0073e9SAndroid Build Coastguard Worker            z = 3.
915*da0073e9SAndroid Build Coastguard Worker            y = x + z
916*da0073e9SAndroid Build Coastguard Worker            return x * y
917*da0073e9SAndroid Build Coastguard Worker
918*da0073e9SAndroid Build Coastguard Worker        # XXX: right now we only support fusing scalars if
919*da0073e9SAndroid Build Coastguard Worker        # they're constant (#9940)
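        # In should_fuse, z is a float literal that becomes a constant in the graph, so the
        # whole expression can fuse. In should_not_fuse, int(z) is computed from a tensor
        # argument at runtime, so the scalar is not constant and nothing is fused.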
920*da0073e9SAndroid Build Coastguard Worker        def should_not_fuse(x, z):
921*da0073e9SAndroid Build Coastguard Worker            y = x + int(z)
922*da0073e9SAndroid Build Coastguard Worker            return x * y
923*da0073e9SAndroid Build Coastguard Worker
924*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
925*da0073e9SAndroid Build Coastguard Worker        ge = self.checkScript(should_fuse, inputs)
926*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(ge.graph_for(*inputs))
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker        inputs = [
929*da0073e9SAndroid Build Coastguard Worker            torch.randn(2, 2, dtype=torch.float, device='cuda'),
930*da0073e9SAndroid Build Coastguard Worker            torch.tensor(3., dtype=torch.float, device='cuda'),
931*da0073e9SAndroid Build Coastguard Worker        ]
932*da0073e9SAndroid Build Coastguard Worker        ge = self.checkScript(should_not_fuse, inputs)
933*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
934*da0073e9SAndroid Build Coastguard Worker            ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
937*da0073e9SAndroid Build Coastguard Worker    @enable_cpu_fuser
938*da0073e9SAndroid Build Coastguard Worker    def test_where_and_typing(self):
939*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
940*da0073e9SAndroid Build Coastguard Worker            mask = x > y
941*da0073e9SAndroid Build Coastguard Worker            res = torch.where(mask, x, y)
942*da0073e9SAndroid Build Coastguard Worker            return mask, res
943*da0073e9SAndroid Build Coastguard Worker
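        # mask is a bool tensor while res keeps the double dtype of the inputs, so the fused
        # kernel has to emit outputs of two different types; only the final TupleConstruct is
        # expected to stay outside the fusion group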
944*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4, dtype=torch.double)
945*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(4, 4, dtype=torch.double)
946*da0073e9SAndroid Build Coastguard Worker
947*da0073e9SAndroid Build Coastguard Worker        script_f = self.checkScript(f, (x, y))
948*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
949*da0073e9SAndroid Build Coastguard Worker
950*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
951*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
952*da0073e9SAndroid Build Coastguard Worker    def test_grad_sum_to_size_elimination(self):
953*da0073e9SAndroid Build Coastguard Worker
954*da0073e9SAndroid Build Coastguard Worker        def my_broadcasted_cell(a, b, c):
955*da0073e9SAndroid Build Coastguard Worker            return (a + b) + c
956*da0073e9SAndroid Build Coastguard Worker
957*da0073e9SAndroid Build Coastguard Worker        s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
958*da0073e9SAndroid Build Coastguard Worker        s2 = torch.randn(5, 5, requires_grad=True, device='cuda')
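        # s1 ([5, 1]) broadcasts against s2 ([5, 5]), so gradients flowing back to an s1
        # input have to be summed back down to [5, 1]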
959*da0073e9SAndroid Build Coastguard Worker
960*da0073e9SAndroid Build Coastguard Worker        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING)
961*da0073e9SAndroid Build Coastguard Worker        forward_graph = module.graph_for(s1, s1, s1)
962*da0073e9SAndroid Build Coastguard Worker        self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
963*da0073e9SAndroid Build Coastguard Worker                                                       "aten::_size_if_not_equal"))
964*da0073e9SAndroid Build Coastguard Worker
965*da0073e9SAndroid Build Coastguard Worker        old_plans = set()
966*da0073e9SAndroid Build Coastguard Worker        for i in range(3):
967*da0073e9SAndroid Build Coastguard Worker            # whenever an s2 argument is present, the gradients of the s1 arguments get _grad_sum_to_size'd
968*da0073e9SAndroid Build Coastguard Worker
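            # i == 0 passes (s2, s2, s2); i == 1 passes (s1, s2, s2); i == 2 passes (s1, s1, s2)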
969*da0073e9SAndroid Build Coastguard Worker            args = (s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
970*da0073e9SAndroid Build Coastguard Worker            args = [a.detach_().requires_grad_() for a in args]
971*da0073e9SAndroid Build Coastguard Worker            # recompile, so we don't trigger bailouts
972*da0073e9SAndroid Build Coastguard Worker            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING)
973*da0073e9SAndroid Build Coastguard Worker            res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
974*da0073e9SAndroid Build Coastguard Worker            warmup_backward(res.sum(), args)
975*da0073e9SAndroid Build Coastguard Worker            grads = torch.autograd.grad(res.sum(), args)
976*da0073e9SAndroid Build Coastguard Worker            for inp, gr in zip(args, grads):
977*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(inp.shape, gr.shape)
978*da0073e9SAndroid Build Coastguard Worker            backward = None
979*da0073e9SAndroid Build Coastguard Worker            # this works around the backward graphs not being returned in a
980*da0073e9SAndroid Build Coastguard Worker            # deterministic order on Python 2
981*da0073e9SAndroid Build Coastguard Worker            for g in all_backward_graphs(module):
982*da0073e9SAndroid Build Coastguard Worker                if str(g) not in old_plans:
983*da0073e9SAndroid Build Coastguard Worker                    assert backward is None
984*da0073e9SAndroid Build Coastguard Worker                    backward = g
985*da0073e9SAndroid Build Coastguard Worker                    old_plans.add(str(backward))
986*da0073e9SAndroid Build Coastguard Worker            num_grads = 1 if i > 0 else 0
987*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len([n for n in backward.nodes() if n.kind() == 'aten::_grad_sum_to_size']), num_grads)
988*da0073e9SAndroid Build Coastguard Worker
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
991*da0073e9SAndroid Build Coastguard Worker    run_tests()
992