xref: /aosp_15_r20/external/pytorch/test/test_tensorexpr.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["NNC"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport numpy as np
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
6*da0073e9SAndroid Build Coastguard Workerfrom torch import nn
7*da0073e9SAndroid Build Coastguard Workerimport unittest
8*da0073e9SAndroid Build Coastguard Workerimport itertools
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard WorkerLLVM_ENABLED = torch._C._llvm_enabled()
15*da0073e9SAndroid Build Coastguard Worker
class BaseTestClass(JitTestCase):
    """Common fixture for the TensorExpr fuser tests in this file.

    ``setUp`` creates a ``TensorExprTestOptions`` object (restored again in
    ``tearDown``) and records which devices and dtypes the tests may iterate
    over.
    """

    def setUp(self):
        super().setUp()
        # NOTE(review): presumably switches the JIT into the fuser
        # configuration these tests need; the matching restore() runs in
        # tearDown.
        self.tensorexpr_options = TensorExprTestOptions()
        # CUDA is exercised only when a CUDA device is actually available.
        self.devices = ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']
        # bfloat16 is added only when torch reports LLVM support.
        self.dtypes = [torch.float32]
        if LLVM_ENABLED:
            self.dtypes.append(torch.bfloat16)

    def tearDown(self):
        # Undo the options first, then let the base class clean up.
        self.tensorexpr_options.restore()
        super().tearDown()

    def assertLastGraphAllFused(self):
        """Assert that the most recently executed optimized graph is fully fused."""
        last_graph = torch.jit.last_executed_optimized_graph()
        self.assertAllFused(last_graph)
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker
def warmup_and_run_forward(f, *args):
    """Call ``f(*args)`` repeatedly and return the result of the final call.

    The callable is invoked ``torch._C._jit_get_num_profiled_runs() + 1``
    times, i.e. once more than the number of profiled runs the JIT is
    configured for.
    """
    calls_left = torch._C._jit_get_num_profiled_runs() + 1
    while calls_left > 0:
        results = f(*args)
        calls_left -= 1
    return results
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo()
38*da0073e9SAndroid Build Coastguard Workerclass TestTensorExprFuser(BaseTestClass):
39*da0073e9SAndroid Build Coastguard Worker    def test_easy(self):
40*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
41*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, y)
42*da0073e9SAndroid Build Coastguard Worker            return aaa
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024)))
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1024)
47*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1024)
48*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b)
49*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
50*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())
51*da0073e9SAndroid Build Coastguard Worker
52*da0073e9SAndroid Build Coastguard Worker    def test_three_arg(self):
53*da0073e9SAndroid Build Coastguard Worker        def easy(x, y, z):
54*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, y)
55*da0073e9SAndroid Build Coastguard Worker            bbb = torch.add(aaa, z)
56*da0073e9SAndroid Build Coastguard Worker            return bbb
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(
59*da0073e9SAndroid Build Coastguard Worker            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
60*da0073e9SAndroid Build Coastguard Worker        )
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1024)
63*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1024)
64*da0073e9SAndroid Build Coastguard Worker        c = torch.rand(1024)
65*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b, c)
66*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
67*da0073e9SAndroid Build Coastguard Worker        npr = a.numpy() + b.numpy() + c.numpy()
68*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(npr, x.numpy())
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker    def test_four_arg(self):
71*da0073e9SAndroid Build Coastguard Worker        def run_addcmul(x, y, z, w):
72*da0073e9SAndroid Build Coastguard Worker            c = torch.addcmul(torch.add(x, y), z, w)
73*da0073e9SAndroid Build Coastguard Worker            return c
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker        for dev in self.devices:
76*da0073e9SAndroid Build Coastguard Worker            rand_a = torch.rand(1024, dtype=torch.float, device=dev)
77*da0073e9SAndroid Build Coastguard Worker            rand_b = torch.rand(1024, dtype=torch.float, device=dev)
78*da0073e9SAndroid Build Coastguard Worker            rand_c = torch.rand(1024, dtype=torch.float, device=dev)
79*da0073e9SAndroid Build Coastguard Worker            rand_d = torch.rand(1024, dtype=torch.float, device=dev)
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(
82*da0073e9SAndroid Build Coastguard Worker                run_addcmul,
83*da0073e9SAndroid Build Coastguard Worker                (
84*da0073e9SAndroid Build Coastguard Worker                    torch.zeros(1024, dtype=torch.float, device=dev),
85*da0073e9SAndroid Build Coastguard Worker                    torch.zeros(1024, dtype=torch.float, device=dev),
86*da0073e9SAndroid Build Coastguard Worker                    torch.zeros(1024, dtype=torch.float, device=dev),
87*da0073e9SAndroid Build Coastguard Worker                    torch.zeros(1024, dtype=torch.float, device=dev),
88*da0073e9SAndroid Build Coastguard Worker                ),
89*da0073e9SAndroid Build Coastguard Worker            )
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker            x = warmup_and_run_forward(traced, rand_a, rand_b, rand_c, rand_d)
92*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
93*da0073e9SAndroid Build Coastguard Worker            y = run_addcmul(rand_a, rand_b, rand_c, rand_d)
94*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6)
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker    def test_three_arg2(self):
97*da0073e9SAndroid Build Coastguard Worker        for device in self.devices:
98*da0073e9SAndroid Build Coastguard Worker            def test(x, y, z):
99*da0073e9SAndroid Build Coastguard Worker                aaa = torch.add(x, y)
100*da0073e9SAndroid Build Coastguard Worker                bbb = torch.add(aaa, z)
101*da0073e9SAndroid Build Coastguard Worker                return bbb
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker            M = 32
104*da0073e9SAndroid Build Coastguard Worker            N = 32
105*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(
106*da0073e9SAndroid Build Coastguard Worker                test,
107*da0073e9SAndroid Build Coastguard Worker                (
108*da0073e9SAndroid Build Coastguard Worker                    torch.rand(M, N, device=device),
109*da0073e9SAndroid Build Coastguard Worker                    torch.rand(M, N, device=device),
110*da0073e9SAndroid Build Coastguard Worker                    torch.rand(M, N, device=device),
111*da0073e9SAndroid Build Coastguard Worker                ),
112*da0073e9SAndroid Build Coastguard Worker            )
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(M, N, device=device)
115*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(M, N, device=device)
116*da0073e9SAndroid Build Coastguard Worker            c = torch.rand(M, N, device=device)
117*da0073e9SAndroid Build Coastguard Worker            x = traced(a, b, c)
118*da0073e9SAndroid Build Coastguard Worker            x = warmup_and_run_forward(traced, a, b, c)
119*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
120*da0073e9SAndroid Build Coastguard Worker            npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
121*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(npr, x.cpu().numpy())
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker    def test_broadcast3(self):
124*da0073e9SAndroid Build Coastguard Worker        for device in self.devices:
125*da0073e9SAndroid Build Coastguard Worker            def test_body(M, N, L, K):
126*da0073e9SAndroid Build Coastguard Worker                def test(x, y, z):
127*da0073e9SAndroid Build Coastguard Worker                    v1 = torch.add(x, y)
128*da0073e9SAndroid Build Coastguard Worker                    v2 = torch.add(v1, z)
129*da0073e9SAndroid Build Coastguard Worker                    return v2
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker                a_shape = [M, N]
132*da0073e9SAndroid Build Coastguard Worker                b_shape = [L, M, 1]
133*da0073e9SAndroid Build Coastguard Worker                c_shape = [K, L, 1, 1]
134*da0073e9SAndroid Build Coastguard Worker                traced = torch.jit.trace(
135*da0073e9SAndroid Build Coastguard Worker                    test,
136*da0073e9SAndroid Build Coastguard Worker                    (
137*da0073e9SAndroid Build Coastguard Worker                        torch.rand(*a_shape, device=device),
138*da0073e9SAndroid Build Coastguard Worker                        torch.rand(*b_shape, device=device),
139*da0073e9SAndroid Build Coastguard Worker                        torch.rand(*c_shape, device=device),
140*da0073e9SAndroid Build Coastguard Worker                    ),
141*da0073e9SAndroid Build Coastguard Worker                )
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker                a = torch.rand(*a_shape, device=device)
144*da0073e9SAndroid Build Coastguard Worker                b = torch.rand(*b_shape, device=device)
145*da0073e9SAndroid Build Coastguard Worker                c = torch.rand(*c_shape, device=device)
146*da0073e9SAndroid Build Coastguard Worker                x = warmup_and_run_forward(traced, a, b, c)
147*da0073e9SAndroid Build Coastguard Worker                self.assertLastGraphAllFused()
148*da0073e9SAndroid Build Coastguard Worker                npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
149*da0073e9SAndroid Build Coastguard Worker                np.testing.assert_allclose(npr, x.cpu().numpy())
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker            test_configs = [[5, 2, 7, 3], [8, 8, 8, 8]]
152*da0073e9SAndroid Build Coastguard Worker            for test_config in test_configs:
153*da0073e9SAndroid Build Coastguard Worker                test_body(*test_config)
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker    def test_all_combos(self):
156*da0073e9SAndroid Build Coastguard Worker        def easy(x, y, z):
157*da0073e9SAndroid Build Coastguard Worker            a = torch.add(x, y)
158*da0073e9SAndroid Build Coastguard Worker            b = torch.add(a, z)
159*da0073e9SAndroid Build Coastguard Worker            c = torch.add(x, b)
160*da0073e9SAndroid Build Coastguard Worker            d = torch.add(c, a)
161*da0073e9SAndroid Build Coastguard Worker            return d
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker        def np_easy(x, y, z):
164*da0073e9SAndroid Build Coastguard Worker            a = x + y
165*da0073e9SAndroid Build Coastguard Worker            b = a + z
166*da0073e9SAndroid Build Coastguard Worker            c = x + b
167*da0073e9SAndroid Build Coastguard Worker            d = c + a
168*da0073e9SAndroid Build Coastguard Worker            return d
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(
171*da0073e9SAndroid Build Coastguard Worker            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
172*da0073e9SAndroid Build Coastguard Worker        )
173*da0073e9SAndroid Build Coastguard Worker
174*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1024)
175*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1024)
176*da0073e9SAndroid Build Coastguard Worker        c = torch.rand(1024)
177*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b, c)
178*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
179*da0073e9SAndroid Build Coastguard Worker        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
180*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(npr, x.numpy())
181*da0073e9SAndroid Build Coastguard Worker
182*da0073e9SAndroid Build Coastguard Worker    def test_rank_two(self):
183*da0073e9SAndroid Build Coastguard Worker        def easy(x, y, z):
184*da0073e9SAndroid Build Coastguard Worker            a = torch.add(x, y)
185*da0073e9SAndroid Build Coastguard Worker            b = torch.add(a, z)
186*da0073e9SAndroid Build Coastguard Worker            c = torch.add(x, b)
187*da0073e9SAndroid Build Coastguard Worker            d = torch.add(c, a)
188*da0073e9SAndroid Build Coastguard Worker            return d
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker        def np_easy(x, y, z):
191*da0073e9SAndroid Build Coastguard Worker            a = x + y
192*da0073e9SAndroid Build Coastguard Worker            b = a + z
193*da0073e9SAndroid Build Coastguard Worker            c = x + b
194*da0073e9SAndroid Build Coastguard Worker            d = c + a
195*da0073e9SAndroid Build Coastguard Worker            return d
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker        shape = 32, 32
198*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(
199*da0073e9SAndroid Build Coastguard Worker            easy, (torch.rand(shape), torch.rand(shape), torch.rand(shape))
200*da0073e9SAndroid Build Coastguard Worker        )
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(shape)
203*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(shape)
204*da0073e9SAndroid Build Coastguard Worker        c = torch.rand(shape)
205*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b, c)
206*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
207*da0073e9SAndroid Build Coastguard Worker        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
208*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(npr, x.numpy())
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker    def test_broadcast(self):
211*da0073e9SAndroid Build Coastguard Worker        def easy(x, y, z):
212*da0073e9SAndroid Build Coastguard Worker            a = torch.add(x, y)
213*da0073e9SAndroid Build Coastguard Worker            b = torch.add(a, z)
214*da0073e9SAndroid Build Coastguard Worker            return b
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker        def np_easy(x, y, z):
217*da0073e9SAndroid Build Coastguard Worker            a = x + y
218*da0073e9SAndroid Build Coastguard Worker            b = a + z
219*da0073e9SAndroid Build Coastguard Worker            return b
220*da0073e9SAndroid Build Coastguard Worker
221*da0073e9SAndroid Build Coastguard Worker        N = 32
222*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(easy, (torch.rand(N, N), torch.rand(N), torch.rand(N, N)))
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(N, N)
225*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(N)
226*da0073e9SAndroid Build Coastguard Worker        c = torch.rand(N, N)
227*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b, c)
228*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
229*da0073e9SAndroid Build Coastguard Worker        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
230*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(npr, x.numpy())
231*da0073e9SAndroid Build Coastguard Worker
232*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_2(self):
233*da0073e9SAndroid Build Coastguard Worker        zero = torch.tensor([0.0], dtype=torch.float)
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker        def foo(x, y, z):
236*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, y)
237*da0073e9SAndroid Build Coastguard Worker            bbb = torch.add(zero, aaa)
238*da0073e9SAndroid Build Coastguard Worker            return torch.add(bbb, z)
239*da0073e9SAndroid Build Coastguard Worker
240*da0073e9SAndroid Build Coastguard Worker        def foo_np(x, y, z):
241*da0073e9SAndroid Build Coastguard Worker            a = x + y
242*da0073e9SAndroid Build Coastguard Worker            b = zero.numpy() + a
243*da0073e9SAndroid Build Coastguard Worker            return b + z
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
246*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(3, 1)
247*da0073e9SAndroid Build Coastguard Worker        z = torch.rand(4)
248*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(foo, (x, y, z))
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker        r = warmup_and_run_forward(traced, x, y, z)
251*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker        rnp = foo_np(x.numpy(), y.numpy(), z.numpy())
254*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(r, rnp)
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_big2(self):
257*da0073e9SAndroid Build Coastguard Worker        zero = torch.tensor([0.0], dtype=torch.float)
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker        def foo(x, y, z):
260*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, y)
261*da0073e9SAndroid Build Coastguard Worker            bbb = torch.add(zero, aaa)
262*da0073e9SAndroid Build Coastguard Worker            return torch.add(bbb, z)
263*da0073e9SAndroid Build Coastguard Worker
264*da0073e9SAndroid Build Coastguard Worker        def foo_np(x, y, z):
265*da0073e9SAndroid Build Coastguard Worker            a = x + y
266*da0073e9SAndroid Build Coastguard Worker            b = zero.numpy() + a
267*da0073e9SAndroid Build Coastguard Worker            return b + z
268*da0073e9SAndroid Build Coastguard Worker
269*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(32, 1024)
270*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(32, 1)
271*da0073e9SAndroid Build Coastguard Worker        z = torch.rand(1024)
272*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(foo, (x, y, z))
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker        r = warmup_and_run_forward(traced, x, y, z)
275*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
276*da0073e9SAndroid Build Coastguard Worker        rnp = foo_np(x.numpy(), y.numpy(), z.numpy())
277*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(r, rnp)
278*da0073e9SAndroid Build Coastguard Worker
279*da0073e9SAndroid Build Coastguard Worker    def test_alpha(self):
280*da0073e9SAndroid Build Coastguard Worker        def alpha(x):
281*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, x, alpha=2.0)
282*da0073e9SAndroid Build Coastguard Worker            return aaa
283*da0073e9SAndroid Build Coastguard Worker
284*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(alpha, (torch.tensor([1.0])))
285*da0073e9SAndroid Build Coastguard Worker
286*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.0])
287*da0073e9SAndroid Build Coastguard Worker        x = traced(a)
288*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(a.numpy() + 2.0 * a.numpy(), x.numpy())
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
291*da0073e9SAndroid Build Coastguard Worker    def test_constant(self):
292*da0073e9SAndroid Build Coastguard Worker        def constant(x):
293*da0073e9SAndroid Build Coastguard Worker            bbb = torch.tensor([1.0])
294*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, bbb)
295*da0073e9SAndroid Build Coastguard Worker            return aaa
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(constant, (torch.tensor([1.0])))
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.0])
300*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a)
301*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
302*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(a.numpy() + 1.0, x.numpy())
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker    def test_add_sub(self):
305*da0073e9SAndroid Build Coastguard Worker        def easy(x, y, z):
306*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, y)
307*da0073e9SAndroid Build Coastguard Worker            bbb = torch.sub(aaa, z)
308*da0073e9SAndroid Build Coastguard Worker            return bbb
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(
311*da0073e9SAndroid Build Coastguard Worker            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
312*da0073e9SAndroid Build Coastguard Worker        )
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1024)
315*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1024)
316*da0073e9SAndroid Build Coastguard Worker        c = torch.rand(1024)
317*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b, c)
318*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
319*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(a.numpy() + b.numpy() - c.numpy(), x.numpy())
320*da0073e9SAndroid Build Coastguard Worker
321*da0073e9SAndroid Build Coastguard Worker    def test_promotion(self):
322*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
323*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, y)
324*da0073e9SAndroid Build Coastguard Worker            return aaa
325*da0073e9SAndroid Build Coastguard Worker
326*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(
327*da0073e9SAndroid Build Coastguard Worker            easy,
328*da0073e9SAndroid Build Coastguard Worker            (torch.zeros(1024, dtype=torch.int32), torch.rand(1024, dtype=torch.float32)),
329*da0073e9SAndroid Build Coastguard Worker        )
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(1024, dtype=torch.int32)
332*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1024, dtype=torch.float32)
333*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b)
334*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
335*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker    def test_double(self):
338*da0073e9SAndroid Build Coastguard Worker        TENSOR_LEN = 8
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
341*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, y)
342*da0073e9SAndroid Build Coastguard Worker            bbb = torch.mul(aaa, y)
343*da0073e9SAndroid Build Coastguard Worker            return bbb
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(
346*da0073e9SAndroid Build Coastguard Worker            easy,
347*da0073e9SAndroid Build Coastguard Worker            (torch.rand(TENSOR_LEN, dtype=torch.float64), torch.full((TENSOR_LEN,), 0.5, dtype=torch.float64)),
348*da0073e9SAndroid Build Coastguard Worker        )
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(TENSOR_LEN, dtype=torch.double)
351*da0073e9SAndroid Build Coastguard Worker        b = torch.full((TENSOR_LEN,), 0.5, dtype=torch.double)
352*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b)
353*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
354*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker    def test_short(self):
357*da0073e9SAndroid Build Coastguard Worker        TENSOR_LEN = 8
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
360*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, y)
361*da0073e9SAndroid Build Coastguard Worker            bbb = torch.mul(aaa, y)
362*da0073e9SAndroid Build Coastguard Worker            return bbb
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(
365*da0073e9SAndroid Build Coastguard Worker            easy,
366*da0073e9SAndroid Build Coastguard Worker            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16),
367*da0073e9SAndroid Build Coastguard Worker             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)),
368*da0073e9SAndroid Build Coastguard Worker        )
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)
371*da0073e9SAndroid Build Coastguard Worker        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)
372*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b)
373*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
374*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker    def test_char(self):
377*da0073e9SAndroid Build Coastguard Worker        TENSOR_LEN = 8
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
380*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, y)
381*da0073e9SAndroid Build Coastguard Worker            bbb = torch.mul(aaa, y)
382*da0073e9SAndroid Build Coastguard Worker            return bbb
383*da0073e9SAndroid Build Coastguard Worker
384*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(
385*da0073e9SAndroid Build Coastguard Worker            easy,
386*da0073e9SAndroid Build Coastguard Worker            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8),
387*da0073e9SAndroid Build Coastguard Worker             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)),
388*da0073e9SAndroid Build Coastguard Worker        )
389*da0073e9SAndroid Build Coastguard Worker
390*da0073e9SAndroid Build Coastguard Worker        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
391*da0073e9SAndroid Build Coastguard Worker        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
392*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b)
393*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
394*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker    def test_int64_promotion(self):
397*da0073e9SAndroid Build Coastguard Worker        TENSOR_LEN = 8
398*da0073e9SAndroid Build Coastguard Worker
399*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
400*da0073e9SAndroid Build Coastguard Worker            aaa = torch.add(x, y)
401*da0073e9SAndroid Build Coastguard Worker            bbb = torch.mul(aaa, y)
402*da0073e9SAndroid Build Coastguard Worker            return bbb
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(
405*da0073e9SAndroid Build Coastguard Worker            easy,
406*da0073e9SAndroid Build Coastguard Worker            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8),
407*da0073e9SAndroid Build Coastguard Worker             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64)),
408*da0073e9SAndroid Build Coastguard Worker        )
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
411*da0073e9SAndroid Build Coastguard Worker        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64)
412*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b)
413*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
414*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker    def test_eq(self):
417*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
418*da0073e9SAndroid Build Coastguard Worker            c = torch.eq(x, y)
419*da0073e9SAndroid Build Coastguard Worker            return c
420*da0073e9SAndroid Build Coastguard Worker
421*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
422*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(1024, dtype=torch.int32)
423*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(1024, dtype=torch.int32)
424*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b)
425*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
426*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(np.ones(1024), x.numpy())
427*da0073e9SAndroid Build Coastguard Worker
428*da0073e9SAndroid Build Coastguard Worker    def test_ne(self):
429*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
430*da0073e9SAndroid Build Coastguard Worker            c = torch.ne(x, y)
431*da0073e9SAndroid Build Coastguard Worker            return c
432*da0073e9SAndroid Build Coastguard Worker
433*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
434*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(1024, dtype=torch.int32)
435*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(1024, dtype=torch.int32)
436*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b)
437*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
438*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(np.ones(1024), x.numpy())
439*da0073e9SAndroid Build Coastguard Worker
440*da0073e9SAndroid Build Coastguard Worker    def test_ge(self):
441*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
442*da0073e9SAndroid Build Coastguard Worker            c = torch.ge(x, y)
443*da0073e9SAndroid Build Coastguard Worker            return c
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
446*da0073e9SAndroid Build Coastguard Worker        aa = np.empty([1024], dtype=np.int32)
447*da0073e9SAndroid Build Coastguard Worker        aa.fill(5)
448*da0073e9SAndroid Build Coastguard Worker        a = torch.from_numpy(aa)
449*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(1024, dtype=torch.int32)
450*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b)
451*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
452*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(np.ones(1024), x.numpy())
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker    def test_gt(self):
455*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
456*da0073e9SAndroid Build Coastguard Worker            c = torch.gt(x, y)
457*da0073e9SAndroid Build Coastguard Worker            return c
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
460*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1024, dtype=torch.int32)
461*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(1024, dtype=torch.int32)
462*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b)
463*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
464*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(np.ones(1024), x.numpy())
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker    def test_le(self):
467*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
468*da0073e9SAndroid Build Coastguard Worker            c = torch.le(x, y)
469*da0073e9SAndroid Build Coastguard Worker            return c
470*da0073e9SAndroid Build Coastguard Worker
471*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
472*da0073e9SAndroid Build Coastguard Worker        aa = np.empty([1024], dtype=np.int32)
473*da0073e9SAndroid Build Coastguard Worker        aa.fill(5)
474*da0073e9SAndroid Build Coastguard Worker        a = torch.from_numpy(aa)
475*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(1024, dtype=torch.int32)
476*da0073e9SAndroid Build Coastguard Worker        x = warmup_and_run_forward(traced, a, b)
477*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
478*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(np.zeros(1024), x.numpy())
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker    def test_lt(self):
481*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
482*da0073e9SAndroid Build Coastguard Worker            c = torch.lt(x, y)
483*da0073e9SAndroid Build Coastguard Worker            return c
484*da0073e9SAndroid Build Coastguard Worker
485*da0073e9SAndroid Build Coastguard Worker        for dev in self.devices:
486*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev)))
487*da0073e9SAndroid Build Coastguard Worker            a = torch.ones(1024, dtype=torch.int32, device=dev)
488*da0073e9SAndroid Build Coastguard Worker            b = torch.zeros(1024, dtype=torch.int32, device=dev)
489*da0073e9SAndroid Build Coastguard Worker            x = warmup_and_run_forward(traced, a, b)
490*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
491*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(np.zeros(1024), x.cpu().numpy())
492*da0073e9SAndroid Build Coastguard Worker
493*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
494*da0073e9SAndroid Build Coastguard Worker    def test_min_max(self):
495*da0073e9SAndroid Build Coastguard Worker        def test(x, y):
496*da0073e9SAndroid Build Coastguard Worker            return torch.max(torch.min(x, y), torch.tensor([4.0]))
497*da0073e9SAndroid Build Coastguard Worker
498*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(test, (torch.zeros(1024), torch.zeros(1024)))
499*da0073e9SAndroid Build Coastguard Worker        a = 8.0 * torch.rand(1024)
500*da0073e9SAndroid Build Coastguard Worker        b = 8.0 * torch.rand(1024)
501*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(
502*da0073e9SAndroid Build Coastguard Worker            warmup_and_run_forward(traced, a, b), np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0])
503*da0073e9SAndroid Build Coastguard Worker        )
504*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
505*da0073e9SAndroid Build Coastguard Worker
506*da0073e9SAndroid Build Coastguard Worker    def test_min_max_reduction(self):
507*da0073e9SAndroid Build Coastguard Worker        def test(x):
508*da0073e9SAndroid Build Coastguard Worker            return torch.min(x) + torch.max(x)
509*da0073e9SAndroid Build Coastguard Worker
510*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(test, (torch.zeros(1024)))
511*da0073e9SAndroid Build Coastguard Worker        a = 8.0 * torch.rand(1024)
512*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy()))
513*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
514*da0073e9SAndroid Build Coastguard Worker
515*da0073e9SAndroid Build Coastguard Worker    def test_min_max_reduction2(self):
516*da0073e9SAndroid Build Coastguard Worker        def test(x):
517*da0073e9SAndroid Build Coastguard Worker            return x.min() + x.max()
518*da0073e9SAndroid Build Coastguard Worker
519*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(test, (torch.zeros(1024)))
520*da0073e9SAndroid Build Coastguard Worker        a = 8.0 * torch.rand(1024)
521*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy()))
522*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
523*da0073e9SAndroid Build Coastguard Worker
524*da0073e9SAndroid Build Coastguard Worker    def test_min_max_reduction_dim1(self):
525*da0073e9SAndroid Build Coastguard Worker        def test(x):
526*da0073e9SAndroid Build Coastguard Worker            return torch.min(x, 1)[0] + torch.max(x, 1)[0]
527*da0073e9SAndroid Build Coastguard Worker
528*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(test, (torch.zeros(16, 16)))
529*da0073e9SAndroid Build Coastguard Worker        a = 8.0 * torch.rand(16, 16)
530*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(
531*da0073e9SAndroid Build Coastguard Worker            a.numpy(), axis=1) + np.amax(a.numpy(), axis=1))
532*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
533*da0073e9SAndroid Build Coastguard Worker
534*da0073e9SAndroid Build Coastguard Worker    def test_min_max_reduction_dim1_2(self):
535*da0073e9SAndroid Build Coastguard Worker        def test(x):
536*da0073e9SAndroid Build Coastguard Worker            return torch.min(x * x, 1)
537*da0073e9SAndroid Build Coastguard Worker
538*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(test, (torch.zeros(16, 16)))
539*da0073e9SAndroid Build Coastguard Worker        a = 8.0 * torch.rand(16, 16)
540*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(warmup_and_run_forward(traced, a)[0], np.amin((a * a).numpy(), axis=1))
541*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
542*da0073e9SAndroid Build Coastguard Worker
543*da0073e9SAndroid Build Coastguard Worker    def test_clamp(self):
544*da0073e9SAndroid Build Coastguard Worker        def test(x):
545*da0073e9SAndroid Build Coastguard Worker            return torch.clamp(x + 3.0, 0.0, 6.0)
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker        for dev in self.devices:
548*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
549*da0073e9SAndroid Build Coastguard Worker            a = 20.0 * torch.rand(1024, device=dev) - 10.0
550*da0073e9SAndroid Build Coastguard Worker            an = a.cpu().numpy()
551*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip(an + 3.0, 0.0, 6.0))
552*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
553*da0073e9SAndroid Build Coastguard Worker
554*da0073e9SAndroid Build Coastguard Worker    def test_relu(self):
555*da0073e9SAndroid Build Coastguard Worker        def test(x):
556*da0073e9SAndroid Build Coastguard Worker            return torch.clamp(F.relu(x), 0, 0.5)
557*da0073e9SAndroid Build Coastguard Worker
558*da0073e9SAndroid Build Coastguard Worker        for dev in self.devices:
559*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
560*da0073e9SAndroid Build Coastguard Worker            a = 20.0 * torch.rand(1024, device=dev) - 10.0
561*da0073e9SAndroid Build Coastguard Worker            an = a.cpu().numpy()
562*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5))
563*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
564*da0073e9SAndroid Build Coastguard Worker
565*da0073e9SAndroid Build Coastguard Worker    def test_reps(self):
566*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
567*da0073e9SAndroid Build Coastguard Worker            c = torch.add(x, y)
568*da0073e9SAndroid Build Coastguard Worker            return c
569*da0073e9SAndroid Build Coastguard Worker
570*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024)))
571*da0073e9SAndroid Build Coastguard Worker
572*da0073e9SAndroid Build Coastguard Worker        for _ in range(32):
573*da0073e9SAndroid Build Coastguard Worker            a = torch.ones(1024)
574*da0073e9SAndroid Build Coastguard Worker            b = torch.zeros(1024)
575*da0073e9SAndroid Build Coastguard Worker            x = warmup_and_run_forward(traced, a, b)
576*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(np.ones(1024), x.numpy())
577*da0073e9SAndroid Build Coastguard Worker
578*da0073e9SAndroid Build Coastguard Worker    def test_add_const_rhs(self):
579*da0073e9SAndroid Build Coastguard Worker        def test(x):
580*da0073e9SAndroid Build Coastguard Worker            return x + 3.0
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(test, torch.rand(4))
583*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
584*da0073e9SAndroid Build Coastguard Worker        y = warmup_and_run_forward(traced, x)
585*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
586*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(x.numpy() + 3.0, y.numpy())
587*da0073e9SAndroid Build Coastguard Worker
588*da0073e9SAndroid Build Coastguard Worker    def test_int_output(self):
589*da0073e9SAndroid Build Coastguard Worker        def test(x, y, z):
590*da0073e9SAndroid Build Coastguard Worker            return x * y * z
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker        xs = [(torch.rand(4) * 3 + 1).to(torch.int32) for i in range(3)]
593*da0073e9SAndroid Build Coastguard Worker        x, y, z = xs
594*da0073e9SAndroid Build Coastguard Worker        xn, yn, zn = (t.numpy() for t in xs)
595*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(test, (x, y, z))
596*da0073e9SAndroid Build Coastguard Worker        res = warmup_and_run_forward(traced, x, y, z)
597*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
598*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(xn * yn * zn, res.numpy())
599*da0073e9SAndroid Build Coastguard Worker
600*da0073e9SAndroid Build Coastguard Worker    def test_binary_ops(self):
601*da0073e9SAndroid Build Coastguard Worker        def test_atan2(x, y):
602*da0073e9SAndroid Build Coastguard Worker            c = torch.atan2(torch.add(x, y), y)
603*da0073e9SAndroid Build Coastguard Worker            return c
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker        def test_gt(x, y):
606*da0073e9SAndroid Build Coastguard Worker            c = torch.gt(torch.add(x, y), y)
607*da0073e9SAndroid Build Coastguard Worker            return c
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker        def test_ge(x, y):
610*da0073e9SAndroid Build Coastguard Worker            c = torch.ge(torch.add(x, y), y)
611*da0073e9SAndroid Build Coastguard Worker            return c
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker        def test_lt(x, y):
614*da0073e9SAndroid Build Coastguard Worker            c = torch.lt(torch.add(x, y), y)
615*da0073e9SAndroid Build Coastguard Worker            return c
616*da0073e9SAndroid Build Coastguard Worker
617*da0073e9SAndroid Build Coastguard Worker        def test_le(x, y):
618*da0073e9SAndroid Build Coastguard Worker            c = torch.le(torch.add(x, y), y)
619*da0073e9SAndroid Build Coastguard Worker            return c
620*da0073e9SAndroid Build Coastguard Worker
621*da0073e9SAndroid Build Coastguard Worker        def test_lerp(x, y):
622*da0073e9SAndroid Build Coastguard Worker            c = torch.lerp(torch.add(x, 1), x, 2.0)
623*da0073e9SAndroid Build Coastguard Worker            return c
624*da0073e9SAndroid Build Coastguard Worker
625*da0073e9SAndroid Build Coastguard Worker        def test_mul(x, y):
626*da0073e9SAndroid Build Coastguard Worker            c = torch.mul(torch.add(x, y), y)
627*da0073e9SAndroid Build Coastguard Worker            return c
628*da0073e9SAndroid Build Coastguard Worker
629*da0073e9SAndroid Build Coastguard Worker        def test_ne(x, y):
630*da0073e9SAndroid Build Coastguard Worker            c = torch.ne(torch.add(x, y), y)
631*da0073e9SAndroid Build Coastguard Worker            return c
632*da0073e9SAndroid Build Coastguard Worker
633*da0073e9SAndroid Build Coastguard Worker        def test_div(x, y):
634*da0073e9SAndroid Build Coastguard Worker            c = torch.div(torch.add(x, y), 2)
635*da0073e9SAndroid Build Coastguard Worker            return c
636*da0073e9SAndroid Build Coastguard Worker
637*da0073e9SAndroid Build Coastguard Worker        def test_eq(x, y):
638*da0073e9SAndroid Build Coastguard Worker            c = torch.eq(torch.add(x, y), y)
639*da0073e9SAndroid Build Coastguard Worker            return c
640*da0073e9SAndroid Build Coastguard Worker
641*da0073e9SAndroid Build Coastguard Worker        def test_fmod(x, y):
642*da0073e9SAndroid Build Coastguard Worker            c = torch.fmod(torch.add(x, y), 2)
643*da0073e9SAndroid Build Coastguard Worker            return c
644*da0073e9SAndroid Build Coastguard Worker
645*da0073e9SAndroid Build Coastguard Worker        def test_sub(x, y):
646*da0073e9SAndroid Build Coastguard Worker            c = torch.sub(torch.add(x, y), x)
647*da0073e9SAndroid Build Coastguard Worker            return c
648*da0073e9SAndroid Build Coastguard Worker
649*da0073e9SAndroid Build Coastguard Worker        def test_remainder(x, y):
650*da0073e9SAndroid Build Coastguard Worker            c = torch.remainder(torch.add(x, y), 3.0)
651*da0073e9SAndroid Build Coastguard Worker            return c
652*da0073e9SAndroid Build Coastguard Worker
653*da0073e9SAndroid Build Coastguard Worker        def test_pow(x, y):
654*da0073e9SAndroid Build Coastguard Worker            c = torch.pow(torch.add(x, y), 2.0)
655*da0073e9SAndroid Build Coastguard Worker            return c
656*da0073e9SAndroid Build Coastguard Worker
657*da0073e9SAndroid Build Coastguard Worker        def test_type_as(x, y):
658*da0073e9SAndroid Build Coastguard Worker            return x.type_as(torch.add(x, y))
659*da0073e9SAndroid Build Coastguard Worker
660*da0073e9SAndroid Build Coastguard Worker        cmp_fns = {
661*da0073e9SAndroid Build Coastguard Worker            test_gt,
662*da0073e9SAndroid Build Coastguard Worker            test_ge,
663*da0073e9SAndroid Build Coastguard Worker            test_lt,
664*da0073e9SAndroid Build Coastguard Worker            test_le,
665*da0073e9SAndroid Build Coastguard Worker            test_ne,
666*da0073e9SAndroid Build Coastguard Worker            test_eq
667*da0073e9SAndroid Build Coastguard Worker        }
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker        non_cmp_fns = {
670*da0073e9SAndroid Build Coastguard Worker            test_atan2,
671*da0073e9SAndroid Build Coastguard Worker            test_lerp,
672*da0073e9SAndroid Build Coastguard Worker            test_mul,
673*da0073e9SAndroid Build Coastguard Worker            test_div,
674*da0073e9SAndroid Build Coastguard Worker            test_fmod,
675*da0073e9SAndroid Build Coastguard Worker            test_sub,
676*da0073e9SAndroid Build Coastguard Worker            test_remainder,
677*da0073e9SAndroid Build Coastguard Worker            test_pow,
678*da0073e9SAndroid Build Coastguard Worker            test_type_as,
679*da0073e9SAndroid Build Coastguard Worker        }
680*da0073e9SAndroid Build Coastguard Worker
681*da0073e9SAndroid Build Coastguard Worker        all_test_fns = cmp_fns.union(non_cmp_fns)
682*da0073e9SAndroid Build Coastguard Worker        fn_dev_dtype = itertools.product(all_test_fns, self.devices, self.dtypes)
683*da0073e9SAndroid Build Coastguard Worker        for torch_fn, dev, data_type in fn_dev_dtype:
684*da0073e9SAndroid Build Coastguard Worker            if torch_fn is test_lerp and data_type is torch.bfloat16:
685*da0073e9SAndroid Build Coastguard Worker                continue
686*da0073e9SAndroid Build Coastguard Worker            rand_a = torch.rand(1024, dtype=data_type, device=dev)
687*da0073e9SAndroid Build Coastguard Worker            rand_b = torch.rand(1024, dtype=data_type, device=dev)
688*da0073e9SAndroid Build Coastguard Worker            in1 = 20 * torch.rand(1024, dtype=data_type, device=dev)
689*da0073e9SAndroid Build Coastguard Worker            in2 = 20 * torch.rand(1024, dtype=data_type, device=dev)
690*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(torch_fn, (in1, in2))
691*da0073e9SAndroid Build Coastguard Worker            x = warmup_and_run_forward(traced, rand_a, rand_b)
692*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker            _atol = 2e-3
695*da0073e9SAndroid Build Coastguard Worker            _rtol = 1e-5
696*da0073e9SAndroid Build Coastguard Worker            if data_type is torch.bfloat16:
697*da0073e9SAndroid Build Coastguard Worker                # Compared to aten logic, NNC coudl save addtional BF16/Fp32 conversion.
698*da0073e9SAndroid Build Coastguard Worker                # Take d = a + b - c as an example, the aten logic is as follows at
699*da0073e9SAndroid Build Coastguard Worker                # operator level:
700*da0073e9SAndroid Build Coastguard Worker                #    tmp = to_bf16(to_fp32(a) + to_fp32(b))
701*da0073e9SAndroid Build Coastguard Worker                #    d = to_bf16(to_fp32(tmp) + to_fp32(c))
702*da0073e9SAndroid Build Coastguard Worker                # But NNC could fuse the compression and remove the redudant conversions.
703*da0073e9SAndroid Build Coastguard Worker                # The final statement is as follows
704*da0073e9SAndroid Build Coastguard Worker                #    d = to_bf16(to_fp32(a) + to_fp32(b) + to_fp32(c))
705*da0073e9SAndroid Build Coastguard Worker                # Hence, we simulate NNC computation by feeding fp32 tensors and converting
706*da0073e9SAndroid Build Coastguard Worker                # the result tensor back to bf16. The simulation could avoid the numeric
707*da0073e9SAndroid Build Coastguard Worker                # deviation to simplify the result comprasion
708*da0073e9SAndroid Build Coastguard Worker                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
709*da0073e9SAndroid Build Coastguard Worker                if torch_fn not in cmp_fns:
710*da0073e9SAndroid Build Coastguard Worker                    y = y.bfloat16()
711*da0073e9SAndroid Build Coastguard Worker                _atol = 2e-2
712*da0073e9SAndroid Build Coastguard Worker            else:
713*da0073e9SAndroid Build Coastguard Worker                y = torch_fn(rand_a, rand_b)
714*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker    def test_unary_ops(self):
717*da0073e9SAndroid Build Coastguard Worker        def test_cast_float(x, y):
718*da0073e9SAndroid Build Coastguard Worker            c = torch.ops.aten._cast_Float(torch.add(x, y))
719*da0073e9SAndroid Build Coastguard Worker            return c
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Worker        def test_round(x, y):
722*da0073e9SAndroid Build Coastguard Worker            c = torch.round(torch.add(x, y))
723*da0073e9SAndroid Build Coastguard Worker            return c
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker        def test_sin(x, y):
726*da0073e9SAndroid Build Coastguard Worker            c = torch.sin(torch.add(x, y))
727*da0073e9SAndroid Build Coastguard Worker            return c
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker        def test_asin(x, y):
730*da0073e9SAndroid Build Coastguard Worker            c = torch.asin(torch.add(x, y))
731*da0073e9SAndroid Build Coastguard Worker            return c
732*da0073e9SAndroid Build Coastguard Worker
733*da0073e9SAndroid Build Coastguard Worker        def test_sinh(x, y):
734*da0073e9SAndroid Build Coastguard Worker            c = torch.sinh(torch.add(x, y))
735*da0073e9SAndroid Build Coastguard Worker            return c
736*da0073e9SAndroid Build Coastguard Worker
737*da0073e9SAndroid Build Coastguard Worker        def test_cos(x, y):
738*da0073e9SAndroid Build Coastguard Worker            c = torch.cos(torch.add(x, y))
739*da0073e9SAndroid Build Coastguard Worker            return c
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker        def test_acos(x, y):
742*da0073e9SAndroid Build Coastguard Worker            c = torch.acos(torch.add(x, y))
743*da0073e9SAndroid Build Coastguard Worker            return c
744*da0073e9SAndroid Build Coastguard Worker
745*da0073e9SAndroid Build Coastguard Worker        def test_cosh(x, y):
746*da0073e9SAndroid Build Coastguard Worker            c = torch.cosh(torch.add(x, y))
747*da0073e9SAndroid Build Coastguard Worker            return c
748*da0073e9SAndroid Build Coastguard Worker
749*da0073e9SAndroid Build Coastguard Worker        def test_tan(x, y):
750*da0073e9SAndroid Build Coastguard Worker            c = torch.tan(torch.add(x, y))
751*da0073e9SAndroid Build Coastguard Worker            return c
752*da0073e9SAndroid Build Coastguard Worker
753*da0073e9SAndroid Build Coastguard Worker        def test_atan(x, y):
754*da0073e9SAndroid Build Coastguard Worker            c = torch.atan(torch.add(x, y))
755*da0073e9SAndroid Build Coastguard Worker            return c
756*da0073e9SAndroid Build Coastguard Worker
757*da0073e9SAndroid Build Coastguard Worker        def test_tanh(x, y):
758*da0073e9SAndroid Build Coastguard Worker            c = torch.tanh(torch.add(x, y))
759*da0073e9SAndroid Build Coastguard Worker            return c
760*da0073e9SAndroid Build Coastguard Worker
761*da0073e9SAndroid Build Coastguard Worker        def test_sqrt(x, y):
762*da0073e9SAndroid Build Coastguard Worker            c = torch.sqrt(torch.add(x, y))
763*da0073e9SAndroid Build Coastguard Worker            return c
764*da0073e9SAndroid Build Coastguard Worker
765*da0073e9SAndroid Build Coastguard Worker        def test_rsqrt(x, y):
766*da0073e9SAndroid Build Coastguard Worker            c = torch.rsqrt(torch.add(x, y))
767*da0073e9SAndroid Build Coastguard Worker            return c
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker        def test_floor(x, y):
770*da0073e9SAndroid Build Coastguard Worker            c = torch.floor(torch.add(x, y))
771*da0073e9SAndroid Build Coastguard Worker            return c
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker        def test_ceil(x, y):
774*da0073e9SAndroid Build Coastguard Worker            c = torch.ceil(torch.add(x, y))
775*da0073e9SAndroid Build Coastguard Worker            return c
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker        def test_trunc(x, y):
778*da0073e9SAndroid Build Coastguard Worker            c = torch.trunc(torch.add(x, y))
779*da0073e9SAndroid Build Coastguard Worker            return c
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker        def test_abs(x, y):
782*da0073e9SAndroid Build Coastguard Worker            c = torch.abs(torch.add(x, y))
783*da0073e9SAndroid Build Coastguard Worker            return c
784*da0073e9SAndroid Build Coastguard Worker
785*da0073e9SAndroid Build Coastguard Worker        def test_log(x, y):
786*da0073e9SAndroid Build Coastguard Worker            c = torch.log(torch.add(x, y))
787*da0073e9SAndroid Build Coastguard Worker            return c
788*da0073e9SAndroid Build Coastguard Worker
789*da0073e9SAndroid Build Coastguard Worker        def test_log2(x, y):
790*da0073e9SAndroid Build Coastguard Worker            c = torch.log2(torch.add(x, y))
791*da0073e9SAndroid Build Coastguard Worker            return c
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Worker        def test_log10(x, y):
794*da0073e9SAndroid Build Coastguard Worker            c = torch.log10(torch.add(x, y))
795*da0073e9SAndroid Build Coastguard Worker            return c
796*da0073e9SAndroid Build Coastguard Worker
797*da0073e9SAndroid Build Coastguard Worker        def test_log1p(x, y):
798*da0073e9SAndroid Build Coastguard Worker            c = torch.log1p(torch.add(x, y))
799*da0073e9SAndroid Build Coastguard Worker            return c
800*da0073e9SAndroid Build Coastguard Worker
801*da0073e9SAndroid Build Coastguard Worker        def test_rqrt(x, y):
802*da0073e9SAndroid Build Coastguard Worker            c = torch.rsqrt(torch.add(x, y))
803*da0073e9SAndroid Build Coastguard Worker            return c
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker        def test_erf(x, y):
806*da0073e9SAndroid Build Coastguard Worker            c = torch.erf(torch.add(x, y))
807*da0073e9SAndroid Build Coastguard Worker            return c
808*da0073e9SAndroid Build Coastguard Worker
809*da0073e9SAndroid Build Coastguard Worker        def test_exp(x, y):
810*da0073e9SAndroid Build Coastguard Worker            c = torch.exp(torch.add(x, y))
811*da0073e9SAndroid Build Coastguard Worker            return c
812*da0073e9SAndroid Build Coastguard Worker
813*da0073e9SAndroid Build Coastguard Worker        def test_expm1(x, y):
814*da0073e9SAndroid Build Coastguard Worker            c = torch.expm1(torch.add(x, y))
815*da0073e9SAndroid Build Coastguard Worker            return c
816*da0073e9SAndroid Build Coastguard Worker
817*da0073e9SAndroid Build Coastguard Worker        def test_erfc(x, y):
818*da0073e9SAndroid Build Coastguard Worker            c = torch.erfc(torch.add(x, y))
819*da0073e9SAndroid Build Coastguard Worker            return c
820*da0073e9SAndroid Build Coastguard Worker
821*da0073e9SAndroid Build Coastguard Worker        def test_frac(x, y):
822*da0073e9SAndroid Build Coastguard Worker            c = torch.frac(torch.add(x, y))
823*da0073e9SAndroid Build Coastguard Worker            return c
824*da0073e9SAndroid Build Coastguard Worker
825*da0073e9SAndroid Build Coastguard Worker        def test_lgamma(x, y):
826*da0073e9SAndroid Build Coastguard Worker            c = torch.lgamma(torch.add(x, y))
827*da0073e9SAndroid Build Coastguard Worker            return c
828*da0073e9SAndroid Build Coastguard Worker
829*da0073e9SAndroid Build Coastguard Worker        def test_sigmoid(x, y):
830*da0073e9SAndroid Build Coastguard Worker            c = torch.sigmoid(torch.add(x, y))
831*da0073e9SAndroid Build Coastguard Worker            return c
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker        def test_reciprocal(x, y):
834*da0073e9SAndroid Build Coastguard Worker            c = torch.reciprocal(torch.add(x, y))
835*da0073e9SAndroid Build Coastguard Worker            return c
836*da0073e9SAndroid Build Coastguard Worker
837*da0073e9SAndroid Build Coastguard Worker        def test_neg(x, y):
838*da0073e9SAndroid Build Coastguard Worker            c = torch.neg(torch.add(x, y))
839*da0073e9SAndroid Build Coastguard Worker            return c
840*da0073e9SAndroid Build Coastguard Worker
841*da0073e9SAndroid Build Coastguard Worker        def test_relu(x, y):
842*da0073e9SAndroid Build Coastguard Worker            c = torch.relu(torch.add(x, y))
843*da0073e9SAndroid Build Coastguard Worker            return c
844*da0073e9SAndroid Build Coastguard Worker
845*da0073e9SAndroid Build Coastguard Worker        def test_hardtanh(x, y):
846*da0073e9SAndroid Build Coastguard Worker            c = F.hardtanh(torch.add(x, y), -1.0, 1.0)
847*da0073e9SAndroid Build Coastguard Worker            return c
848*da0073e9SAndroid Build Coastguard Worker
849*da0073e9SAndroid Build Coastguard Worker        def test_threshold(x, y):
850*da0073e9SAndroid Build Coastguard Worker            c = F.threshold(torch.add(x, y), 0.5, 10)
851*da0073e9SAndroid Build Coastguard Worker            return c
852*da0073e9SAndroid Build Coastguard Worker
853*da0073e9SAndroid Build Coastguard Worker        gpu_only_fns = {
854*da0073e9SAndroid Build Coastguard Worker            test_erf,
855*da0073e9SAndroid Build Coastguard Worker            test_erfc
856*da0073e9SAndroid Build Coastguard Worker        }
857*da0073e9SAndroid Build Coastguard Worker        fns = {
858*da0073e9SAndroid Build Coastguard Worker            test_round,
859*da0073e9SAndroid Build Coastguard Worker            test_sin,
860*da0073e9SAndroid Build Coastguard Worker            test_asin,
861*da0073e9SAndroid Build Coastguard Worker            test_sinh,
862*da0073e9SAndroid Build Coastguard Worker            test_cos,
863*da0073e9SAndroid Build Coastguard Worker            test_acos,
864*da0073e9SAndroid Build Coastguard Worker            test_cosh,
865*da0073e9SAndroid Build Coastguard Worker            test_tan,
866*da0073e9SAndroid Build Coastguard Worker            test_atan,
867*da0073e9SAndroid Build Coastguard Worker            test_sqrt,
868*da0073e9SAndroid Build Coastguard Worker            test_floor,
869*da0073e9SAndroid Build Coastguard Worker            test_ceil,
870*da0073e9SAndroid Build Coastguard Worker            test_trunc,
871*da0073e9SAndroid Build Coastguard Worker            test_abs,
872*da0073e9SAndroid Build Coastguard Worker            test_log,
873*da0073e9SAndroid Build Coastguard Worker            test_log2,
874*da0073e9SAndroid Build Coastguard Worker            test_log10,
875*da0073e9SAndroid Build Coastguard Worker            test_log1p,
876*da0073e9SAndroid Build Coastguard Worker            test_rsqrt,
877*da0073e9SAndroid Build Coastguard Worker            test_exp,
878*da0073e9SAndroid Build Coastguard Worker            test_expm1,
879*da0073e9SAndroid Build Coastguard Worker            test_frac,
880*da0073e9SAndroid Build Coastguard Worker            test_lgamma,
881*da0073e9SAndroid Build Coastguard Worker            test_reciprocal,
882*da0073e9SAndroid Build Coastguard Worker            test_neg,
883*da0073e9SAndroid Build Coastguard Worker            test_threshold,
884*da0073e9SAndroid Build Coastguard Worker            test_relu,
885*da0073e9SAndroid Build Coastguard Worker            test_tanh,
886*da0073e9SAndroid Build Coastguard Worker            test_hardtanh,
887*da0073e9SAndroid Build Coastguard Worker            test_sigmoid,
888*da0073e9SAndroid Build Coastguard Worker        }
889*da0073e9SAndroid Build Coastguard Worker        fn_dev_dtype = itertools.product(gpu_only_fns.union(fns), self.devices, self.dtypes)
890*da0073e9SAndroid Build Coastguard Worker
891*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(0)
892*da0073e9SAndroid Build Coastguard Worker        for torch_fn, dev, data_type in fn_dev_dtype:
893*da0073e9SAndroid Build Coastguard Worker            if torch_fn == test_lgamma and dev == "cuda":
894*da0073e9SAndroid Build Coastguard Worker                # lgamma_cuda does not support BF16
895*da0073e9SAndroid Build Coastguard Worker                continue
896*da0073e9SAndroid Build Coastguard Worker            rand_a = torch.rand(1024, dtype=data_type, device=dev)
897*da0073e9SAndroid Build Coastguard Worker            rand_b = torch.rand(1024, dtype=data_type, device=dev)
898*da0073e9SAndroid Build Coastguard Worker
899*da0073e9SAndroid Build Coastguard Worker            ins = 20 * torch.rand(1024, dtype=data_type, device=dev)
900*da0073e9SAndroid Build Coastguard Worker            cc = np.empty([1024], dtype=np.float32)
901*da0073e9SAndroid Build Coastguard Worker            cc.fill(np.nan)
902*da0073e9SAndroid Build Coastguard Worker            nans = torch.from_numpy(cc).to(dev)
903*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(torch_fn, (ins, ins))
904*da0073e9SAndroid Build Coastguard Worker            x = warmup_and_run_forward(traced, rand_a, rand_b)
905*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
906*da0073e9SAndroid Build Coastguard Worker
907*da0073e9SAndroid Build Coastguard Worker            _atol = 5e-3 if data_type is torch.bfloat16 else 2e-3
908*da0073e9SAndroid Build Coastguard Worker            _rtol = 1e-5
909*da0073e9SAndroid Build Coastguard Worker            if data_type is torch.bfloat16 and torch_fn not in gpu_only_fns:
910*da0073e9SAndroid Build Coastguard Worker                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
911*da0073e9SAndroid Build Coastguard Worker                y = y.bfloat16()
912*da0073e9SAndroid Build Coastguard Worker            else:
913*da0073e9SAndroid Build Coastguard Worker                y = torch_fn(rand_a, rand_b)
914*da0073e9SAndroid Build Coastguard Worker
915*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)
916*da0073e9SAndroid Build Coastguard Worker            # nans
917*da0073e9SAndroid Build Coastguard Worker            # TODO: reenable. Currently all of the tests fail
918*da0073e9SAndroid Build Coastguard Worker            # traced = torch.jit.trace(torch_fn, (ins, ins))
919*da0073e9SAndroid Build Coastguard Worker            # x = warmup_and_run_forward(traced, rand_a, rand_b)
920*da0073e9SAndroid Build Coastguard Worker            # y = torch_fn(nans, rand_b)
921*da0073e9SAndroid Build Coastguard Worker            # try:
922*da0073e9SAndroid Build Coastguard Worker            #     np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
923*da0073e9SAndroid Build Coastguard Worker            #     print("Succeeded on dev=", dev, "function=", torch_fn)
924*da0073e9SAndroid Build Coastguard Worker            # except AssertionError:
925*da0073e9SAndroid Build Coastguard Worker            #     # Print extra info before exiting:
926*da0073e9SAndroid Build Coastguard Worker            #     print("Failed on dev=", dev, "function=", torch_fn)
927*da0073e9SAndroid Build Coastguard Worker            #     # np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
928*da0073e9SAndroid Build Coastguard Worker
929*da0073e9SAndroid Build Coastguard Worker
def test_round_2(self):
    """Check that fused ``torch.round`` agrees with eager mode on tie values.

    The input contains the half-way values 2.5 and 3.5, so a fused kernel
    that breaks ties differently from eager ``torch.round`` is detected.
    """
    def run_round(x):
        return torch.round(x)

    for data_type in [torch.float32, torch.double]:
        a = torch.tensor([0.2, 1.6, 2.5, 3.5]).to(data_type)
        traced = torch.jit.trace(run_round, (a))
        x = warmup_and_run_forward(traced, a)
        self.assertLastGraphAllFused()
        # The reference must come from the original input. Rounding the
        # already-rounded fused output is a no-op and would hide a wrong
        # tie-breaking rule in the fused kernel.
        y = run_round(a)
        self.assertEqual(x, y)
941*da0073e9SAndroid Build Coastguard Worker
def test_rand_like(self):
    """Check that fused ``rand_like`` output looks uniform on [0, 1).

    For every device/dtype pair the traced function must be fully fused and
    the first three raw moments of its output must be close to those of the
    uniform distribution, E[U**k] = 1 / (k + 1).
    """
    N = 1 << 16

    def run_rand_like(x, y):
        return torch.rand_like(torch.add(x, y))

    for device in self.devices:
        x = torch.rand(N, device=device)
        traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False)

        for data_type in self.dtypes:
            _x = x.to(dtype=data_type)
            x_v = warmup_and_run_forward(traced, _x, _x)
            self.assertLastGraphAllFused()

            # Validate the *output* of the fused kernel (not the input that
            # was produced by eager torch.rand). Upcast first because numpy
            # has no bfloat16 dtype.
            samples = x_v.double().cpu().numpy()
            for power in (1, 2, 3):
                np.testing.assert_allclose(
                    np.mean(samples ** power), 1. / (power + 1), rtol=2e-2
                )
964*da0073e9SAndroid Build Coastguard Worker
def test_nans(self):
    """Check that fused ``max``/``min`` propagate NaN from either operand.

    The checks run for every dtype in ``self.dtypes``; both argument orders
    are exercised so NaN is tried as the first and as the second operand.
    """
    def test_max(x, y):
        return torch.max(2 * x, 2 * y)

    def test_min(x, y):
        return torch.min(2 * x, 2 * y)

    tmax = torch.jit.trace(test_max, (torch.rand(1), torch.rand(1)))
    tmin = torch.jit.trace(test_min, (torch.rand(1), torch.rand(1)))

    for data_type in self.dtypes:
        x = torch.tensor([np.nan]).to(dtype=data_type)
        y = torch.tensor([1.0]).to(dtype=data_type)

        # These checks live inside the loop so that every dtype is tested,
        # not only the last one assigned to x and y.
        self.assertTrue(np.isnan(warmup_and_run_forward(tmin, x, y).float().item()))
        self.assertTrue(np.isnan(warmup_and_run_forward(tmin, y, x).float().item()))
        self.assertLastGraphAllFused()
        self.assertTrue(np.isnan(warmup_and_run_forward(tmax, x, y).float().item()))
        self.assertTrue(np.isnan(warmup_and_run_forward(tmax, y, x).float().item()))
        self.assertLastGraphAllFused()
985*da0073e9SAndroid Build Coastguard Worker
def test_double_intrinsics(self):
    """Check that ``torch.pow`` on double tensors fuses and matches eager mode."""
    def do_pow(x):
        return torch.pow(x, 7)

    for device in self.devices:
        x = torch.rand(10, dtype=torch.double, device=device)
        traced = torch.jit.trace(do_pow, (x))
        # Keep the input intact (do not rebind ``x``) so the fused result
        # can be compared against the eager reference.
        fused = warmup_and_run_forward(traced, x)
        self.assertLastGraphAllFused()
        self.assertEqual(fused, do_pow(x))
995*da0073e9SAndroid Build Coastguard Worker
def test_remainder(self):
    """Check fused ``torch.remainder`` against eager mode.

    Three input regimes are covered for every dtype in ``self.dtypes``:
    random floats, a zero divisor, and NaN in both numerator and divisor.
    """
    def run_remainder(x, y):
        c = torch.remainder(torch.add(x, y), x)
        return c

    for data_type in self.dtypes:
        a = torch.rand(1024, dtype=data_type)
        b = torch.rand(1024, dtype=data_type)
        zeros = torch.zeros(1024, dtype=data_type)
        # A full 1024-element NaN vector. ``np.array(1024, dtype=float)``
        # would build a 0-d array holding the value 1024 instead.
        nans = torch.full((1024,), float("nan"), dtype=data_type)

        # Example inputs used only for tracing.
        zeros1 = torch.zeros(1024, dtype=data_type)
        zeros2 = torch.zeros(1024, dtype=data_type)

        # random floats
        traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        y = run_remainder(a, b)
        if data_type is torch.bfloat16:
            self.assertEqual(x, y, atol=4e-3, rtol=2e-3)
        else:
            self.assertEqual(x, y)

        # div by 0
        traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
        x = warmup_and_run_forward(traced, zeros, a)
        self.assertLastGraphAllFused()
        y = run_remainder(zeros, a)
        self.assertEqual(x, y)

        # numerators and denominators are nan
        traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
        x = warmup_and_run_forward(traced, nans, a)
        self.assertLastGraphAllFused()
        y = run_remainder(nans, a)
        self.assertEqual(x, y)
1035*da0073e9SAndroid Build Coastguard Worker
def test_multioutput(self):
    """Check a fused graph that returns two tensors, ``x + 1`` and twice that."""
    def easy(x):
        plus_one = x + 1
        return (plus_one, plus_one + plus_one)

    traced = torch.jit.trace(easy, (torch.zeros(1024)))

    inp = torch.zeros(1024)
    first, second = warmup_and_run_forward(traced, inp)
    self.assertLastGraphAllFused()

    # NumPy reference for both outputs.
    expected_first = inp.numpy() + 1
    expected_second = expected_first + expected_first
    np.testing.assert_allclose(first.numpy(), expected_first)
    np.testing.assert_allclose(second.numpy(), expected_second)
1051*da0073e9SAndroid Build Coastguard Worker
def test_chunk(self):
    """Check fusion of ``torch.chunk`` followed by adding the two halves."""
    def easy(x):
        shifted = x + 1
        top, bottom = torch.chunk(shifted, 2)
        return top + bottom

    for data_type in self.dtypes:
        # Traced with a larger example input than the one it is run with.
        traced = torch.jit.trace(easy, (torch.zeros(1024, 1024, dtype=data_type)))

        a = torch.zeros(32, 32, dtype=data_type)
        x = warmup_and_run_forward(traced, a)
        self.assertLastGraphAllFused()

        # NumPy reference, computed in float32.
        halves = np.array_split(a.float().numpy() + 1, 2)
        np.testing.assert_allclose(halves[0] + halves[1], x.float().numpy())
1069*da0073e9SAndroid Build Coastguard Worker
def test_cat(self):
    """Check fusion of ``torch.cat`` for contiguous and channels-last inputs."""
    for device in self.devices:
        # ``foo`` reads this variable through its closure; the channels-last
        # loop below rebinds it to test every dimension.
        cat_dim = 1

        def foo(*args):
            shifted = [t + idx for idx, t in enumerate(args)]
            joined = torch.cat(shifted, dim=cat_dim)
            return joined * joined

        for data_type in self.dtypes:
            rows = 16
            widths = [128, 16, 1]
            values = [
                torch.zeros(rows, width, dtype=data_type, device=device)
                for width in widths
            ]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().float().numpy(), x.cpu().float().numpy())

        # Test channels-last
        for cat_dim in range(4):
            values = [
                torch.randn((2, 3, 4, 5), device=device).to(memory_format=torch.channels_last)
                for _ in range(10)
            ]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            self.assertEqual(ref, x)
1100*da0073e9SAndroid Build Coastguard Worker
# This test checks that a fusion group containing only aten::cat is handled
# correctly. It only makes sense with min_fusion_group=1; otherwise no fusion
# groups would be formed at all.
# TODO: Fix and re-enable the test.
@unittest.skip("cat is broken with fusion group inlining disabled")
def test_cat_only(self):
    """Concatenate shifted inputs along dim 1 and compare with eager mode."""
    for device in self.devices:
        def foo(*args):
            shifted = [t + idx for idx, t in enumerate(args)]
            return torch.cat(shifted, dim=1)

        rows = 16
        widths = [128, 16, 1]
        values = [torch.zeros(rows, width, device=device) for width in widths]
        traced = torch.jit.trace(foo, values)

        fused = warmup_and_run_forward(traced, *values)
        self.assertLastGraphAllFused()
        expected = foo(*values)
        np.testing.assert_allclose(expected.cpu().numpy(), fused.cpu().numpy())
1122*da0073e9SAndroid Build Coastguard Worker
def test_cat_negative_dim(self):
    """Check fusion of ``torch.cat`` when the dimension is given as ``-1``."""
    for device in self.devices:
        def foo(*args):
            joined = torch.cat(args, dim=-1)
            return joined * joined

        rows = 16
        widths = [128, 16, 1]
        values = [torch.randn(rows, width, device=device) for width in widths]
        traced = torch.jit.trace(foo, values)

        fused = warmup_and_run_forward(traced, *values)
        self.assertLastGraphAllFused()
        expected = foo(*values)
        np.testing.assert_allclose(expected.cpu().numpy(), fused.cpu().numpy())
1138*da0073e9SAndroid Build Coastguard Worker
def test_cat_promote_inputs(self):
    """Check fusion of ``torch.cat`` over inputs of half, float and double dtype."""
    for device in self.devices:
        def foo(*args):
            joined = torch.cat(args, dim=1)
            return joined * joined

        rows = 16
        # (width, dtype) of each concatenated input.
        specs = zip([128, 16, 1], [torch.half, torch.float32, torch.double])
        values = [
            torch.randn(rows, width, device=device, dtype=dt)
            for width, dt in specs
        ]
        traced = torch.jit.trace(foo, values)

        fused = warmup_and_run_forward(traced, *values)
        self.assertLastGraphAllFused()
        expected = foo(*values)
        np.testing.assert_allclose(expected.cpu().numpy(), fused.cpu().numpy())
1155*da0073e9SAndroid Build Coastguard Worker
def test_cat_empty_tensors(self):
    """Check fusion of ``torch.cat`` when some or all inputs are empty tensors."""
    for device in self.devices:
        def foo(*args):
            joined = torch.cat(args, dim=1)
            return joined * joined

        rows = 16
        widths = [128, 16, 1]
        empty = torch.tensor([], device=device, dtype=torch.double)

        # An empty double tensor followed by non-empty default-dtype tensors.
        values = [empty]
        values.extend(torch.randn(rows, width, device=device) for width in widths)
        traced = torch.jit.trace(foo, values)

        fused = warmup_and_run_forward(traced, *values)
        self.assertLastGraphAllFused()
        expected = foo(*values)
        np.testing.assert_allclose(expected.cpu().numpy(), fused.cpu().numpy())

        # now test with only empty tensors
        values = [empty] * 3
        traced = torch.jit.trace(foo, values)
        fused = warmup_and_run_forward(traced, *values)
        self.assertLastGraphAllFused()
        expected = foo(*values)
        np.testing.assert_allclose(expected.cpu().numpy(), fused.cpu().numpy())
1180*da0073e9SAndroid Build Coastguard Worker
def test_cat_with_constant_dim(self):
    """Check fusion of two chained ``torch.cat`` calls, the second over a single tensor."""
    for device in self.devices:
        def foo(*args):
            first = torch.cat(args, dim=1)
            second = torch.cat([first], dim=1)
            return second * second

        empty = torch.tensor([], device=device, dtype=torch.float32)
        inputs = [
            empty,
            torch.randn(1, 64, device=device),
            torch.randn(1, 64, device=device),
        ]
        traced = torch.jit.trace(foo, inputs)

        fused = warmup_and_run_forward(traced, *inputs)
        self.assertLastGraphAllFused()
        expected = foo(*inputs)
        np.testing.assert_allclose(expected.cpu().numpy(), fused.cpu().numpy())
1196*da0073e9SAndroid Build Coastguard Worker
def test_scalar(self):
    """Check scripted ``add`` with scalar ``alpha`` arguments typed as float and as int."""
    @torch.jit.script
    def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor:
        return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

    @torch.jit.script
    def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor:
        return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

    for test, data_type in itertools.product((test_float, test_int), self.dtypes):
        x, y, z = [torch.rand(4, dtype=data_type) for _ in range(3)]
        a, b = 1, 2
        # The scripted function is called twice; only the second result is checked.
        test(x, y, z, a, b)
        r = test(x, y, z, a, b)
        self.assertEqual(r, x + y * a + z * b)
1213*da0073e9SAndroid Build Coastguard Worker
def test_loop(self):
    """Check a scripted function that accumulates ``y`` inside a Python loop.

    Starting from ``b = y``, the loop adds ``y`` to ``b`` ``z`` times, so the
    result must equal ``y * (z + 1)``.
    """
    @torch.jit.script
    def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor:
        b = y
        for i in range(0, z):
            # ``a`` is intentionally unused: the loop body contains an op
            # whose result is dead.
            a = x + y
            b = b + y
        return b

    x, y, z = (torch.zeros(32, 32), torch.ones(32, 32), 4)
    # The function is called twice; the second result is the one verified.
    test(x, y, z)
    r = test(x, y, z)
    self.assertEqual(r, y * (z + 1))
1226*da0073e9SAndroid Build Coastguard Worker
1227*da0073e9SAndroid Build Coastguard Worker    def test_slice(self):
1228*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
1229*da0073e9SAndroid Build Coastguard Worker            a = x[0:512:2]
1230*da0073e9SAndroid Build Coastguard Worker            b = y[0:512:2]
1231*da0073e9SAndroid Build Coastguard Worker            return a + b
1232*da0073e9SAndroid Build Coastguard Worker
1233*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024)))
1234*da0073e9SAndroid Build Coastguard Worker
1235*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1024, 1024)
1236*da0073e9SAndroid Build Coastguard Worker        x = traced(a, a)
1237*da0073e9SAndroid Build Coastguard Worker        npr = a[0:512:2]
1238*da0073e9SAndroid Build Coastguard Worker        npr = npr + npr
1239*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(npr.numpy(), x.numpy())
1240*da0073e9SAndroid Build Coastguard Worker
1241*da0073e9SAndroid Build Coastguard Worker    def test_unsqueeze(self, N=256):
1242*da0073e9SAndroid Build Coastguard Worker        def easy(x, y):
1243*da0073e9SAndroid Build Coastguard Worker            a = torch.unsqueeze(x, 0)
1244*da0073e9SAndroid Build Coastguard Worker            b = torch.unsqueeze(y, 0)
1245*da0073e9SAndroid Build Coastguard Worker            return a + b
1246*da0073e9SAndroid Build Coastguard Worker
1247*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N)))
1248*da0073e9SAndroid Build Coastguard Worker
1249*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(N, N)
1250*da0073e9SAndroid Build Coastguard Worker        x = traced(a, a)
1251*da0073e9SAndroid Build Coastguard Worker        npr = np.expand_dims(a, 0)
1252*da0073e9SAndroid Build Coastguard Worker        npr = npr + npr
1253*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(npr, x.numpy())
1254*da0073e9SAndroid Build Coastguard Worker
1255*da0073e9SAndroid Build Coastguard Worker    def _test_softmax(self, device):
1256*da0073e9SAndroid Build Coastguard Worker        def test_softmax(x, y):
1257*da0073e9SAndroid Build Coastguard Worker            a = F.softmax(x, dim=0, dtype=torch.float32)
1258*da0073e9SAndroid Build Coastguard Worker            b = F.softmax(y, dim=0, dtype=torch.float32)
1259*da0073e9SAndroid Build Coastguard Worker            c = F.softmax(x, dim=1, dtype=torch.float32)
1260*da0073e9SAndroid Build Coastguard Worker            d = F.softmax(y, dim=1, dtype=torch.float32)
1261*da0073e9SAndroid Build Coastguard Worker            return a + b + c + d
1262*da0073e9SAndroid Build Coastguard Worker
1263*da0073e9SAndroid Build Coastguard Worker        def test_softmax_neg_index(x, y):
1264*da0073e9SAndroid Build Coastguard Worker            a = F.softmax(x, dim=-2, dtype=torch.float32)
1265*da0073e9SAndroid Build Coastguard Worker            b = F.softmax(y, dim=-2, dtype=torch.float32)
1266*da0073e9SAndroid Build Coastguard Worker            c = F.softmax(x, dim=-1, dtype=torch.float32)
1267*da0073e9SAndroid Build Coastguard Worker            d = F.softmax(y, dim=-1, dtype=torch.float32)
1268*da0073e9SAndroid Build Coastguard Worker            return a + b + c + d
1269*da0073e9SAndroid Build Coastguard Worker
1270*da0073e9SAndroid Build Coastguard Worker        def test_log_softmax(x, y):
1271*da0073e9SAndroid Build Coastguard Worker            a = F.log_softmax(x, dim=0, dtype=torch.float32)
1272*da0073e9SAndroid Build Coastguard Worker            b = F.log_softmax(y, dim=0, dtype=torch.float32)
1273*da0073e9SAndroid Build Coastguard Worker            c = F.log_softmax(x, dim=1, dtype=torch.float32)
1274*da0073e9SAndroid Build Coastguard Worker            d = F.log_softmax(y, dim=1, dtype=torch.float32)
1275*da0073e9SAndroid Build Coastguard Worker            return a + b + c + d
1276*da0073e9SAndroid Build Coastguard Worker
1277*da0073e9SAndroid Build Coastguard Worker        for test in (test_softmax, test_log_softmax, test_softmax_neg_index):
1278*da0073e9SAndroid Build Coastguard Worker            for data_type in self.dtypes:
1279*da0073e9SAndroid Build Coastguard Worker                old = torch._C._jit_set_texpr_reductions_enabled(True)
1280*da0073e9SAndroid Build Coastguard Worker                traced_input = torch.randn(2, 3, dtype=data_type, device=device)
1281*da0073e9SAndroid Build Coastguard Worker                traced = torch.jit.trace(test, (traced_input, traced_input))
1282*da0073e9SAndroid Build Coastguard Worker                inp = torch.randn(2, 3, dtype=data_type, device=device)
1283*da0073e9SAndroid Build Coastguard Worker                res = traced(inp, inp)
1284*da0073e9SAndroid Build Coastguard Worker                # Use eager mode as reference.
1285*da0073e9SAndroid Build Coastguard Worker                ref = test(inp, inp)
1286*da0073e9SAndroid Build Coastguard Worker                np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06)
1287*da0073e9SAndroid Build Coastguard Worker                torch._C._jit_set_texpr_reductions_enabled(old)
1288*da0073e9SAndroid Build Coastguard Worker
1289*da0073e9SAndroid Build Coastguard Worker    def test_softmax_cpu(self):
1290*da0073e9SAndroid Build Coastguard Worker        self._test_softmax('cpu')
1291*da0073e9SAndroid Build Coastguard Worker
1292*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
1293*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("global allocs are not supported yet.")
1294*da0073e9SAndroid Build Coastguard Worker    def test_softmax_cuda(self):
1295*da0073e9SAndroid Build Coastguard Worker        self._test_softmax('cuda')
1296*da0073e9SAndroid Build Coastguard Worker
1297*da0073e9SAndroid Build Coastguard Worker    def test_half_gelu(self):
1298*da0073e9SAndroid Build Coastguard Worker        devices = ["cuda"] if torch.cuda.is_available() else []
1299*da0073e9SAndroid Build Coastguard Worker
1300*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1301*da0073e9SAndroid Build Coastguard Worker        def bias_gelu(bias, y):
1302*da0073e9SAndroid Build Coastguard Worker            x = bias + y
1303*da0073e9SAndroid Build Coastguard Worker            return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
1304*da0073e9SAndroid Build Coastguard Worker
1305*da0073e9SAndroid Build Coastguard Worker        for device in devices:
1306*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(1024, dtype=torch.half, device=device)
1307*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(1024, dtype=torch.half, device=device)
1308*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(bias_gelu, (a, b))
1309*da0073e9SAndroid Build Coastguard Worker            x = warmup_and_run_forward(traced, a, b)
1310*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
1311*da0073e9SAndroid Build Coastguard Worker
1312*da0073e9SAndroid Build Coastguard Worker    def test_half_bn_relu(self):
1313*da0073e9SAndroid Build Coastguard Worker        devices = ["cuda"] if torch.cuda.is_available() else []
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Worker        def foo(a, b, c):
1316*da0073e9SAndroid Build Coastguard Worker            y = torch.nn.functional.batch_norm(a, b, c)
1317*da0073e9SAndroid Build Coastguard Worker            z = y.relu()
1318*da0073e9SAndroid Build Coastguard Worker            return z
1319*da0073e9SAndroid Build Coastguard Worker
1320*da0073e9SAndroid Build Coastguard Worker        for device in devices:
1321*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(16, 16, dtype=torch.half, device=device)
1322*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(16, dtype=torch.half, device=device)
1323*da0073e9SAndroid Build Coastguard Worker            c = torch.rand(16, dtype=torch.half, device=device)
1324*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(foo, (a, b, c))
1325*da0073e9SAndroid Build Coastguard Worker            print(traced.graph)
1326*da0073e9SAndroid Build Coastguard Worker            x = warmup_and_run_forward(traced, a, b, c)
1327*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
1328*da0073e9SAndroid Build Coastguard Worker
1329*da0073e9SAndroid Build Coastguard Worker    def test_exp_pow(self):
1330*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1331*da0073e9SAndroid Build Coastguard Worker        def do_exp(x, y, z):
1332*da0073e9SAndroid Build Coastguard Worker            return ((x * y) * 2) * torch.pow(z, 2)
1333*da0073e9SAndroid Build Coastguard Worker
1334*da0073e9SAndroid Build Coastguard Worker        for device in self.devices:
1335*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(10, dtype=torch.double, device=device)
1336*da0073e9SAndroid Build Coastguard Worker            y = torch.rand(10, dtype=torch.double, device=device)
1337*da0073e9SAndroid Build Coastguard Worker            z = torch.rand(10, dtype=torch.double, device=device)
1338*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(do_exp, (x, y, z))
1339*da0073e9SAndroid Build Coastguard Worker            x = warmup_and_run_forward(traced, x, y, z)
1340*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
1341*da0073e9SAndroid Build Coastguard Worker
1342*da0073e9SAndroid Build Coastguard Worker    def test_sin_pow(self):
1343*da0073e9SAndroid Build Coastguard Worker        def test(x):
1344*da0073e9SAndroid Build Coastguard Worker            return torch.sin(torch.pow(x, 0))
1345*da0073e9SAndroid Build Coastguard Worker
1346*da0073e9SAndroid Build Coastguard Worker        for data_type, shape in itertools.product(self.dtypes, [[3], [5], [10]]):
1347*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(shape, dtype=data_type)
1348*da0073e9SAndroid Build Coastguard Worker            scripted = torch.jit.script(test)
1349*da0073e9SAndroid Build Coastguard Worker            out = warmup_and_run_forward(scripted, x)
1350*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
1351*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, test(x))
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Worker    def test_transpose(self):
1354*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1355*da0073e9SAndroid Build Coastguard Worker        def test(x, y, z):
1356*da0073e9SAndroid Build Coastguard Worker            return x.transpose(0, 1) + y + z
1357*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4, 5, 2, 3)
1358*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(5, 4, 2, 3)
1359*da0073e9SAndroid Build Coastguard Worker        z = torch.rand(5, 4, 2, 3)
1360*da0073e9SAndroid Build Coastguard Worker        ref = test(x, y, z)
1361*da0073e9SAndroid Build Coastguard Worker        res = test(x, y, z)
1362*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(ref.numpy(), res.numpy())
1363*da0073e9SAndroid Build Coastguard Worker
1364*da0073e9SAndroid Build Coastguard Worker    def test_sliced_stride(self):
1365*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1366*da0073e9SAndroid Build Coastguard Worker        def test(x, y, z):
1367*da0073e9SAndroid Build Coastguard Worker            return x + y + z
1368*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(16, 4, 2, 3)[::2]
1369*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(8, 4, 2, 3)
1370*da0073e9SAndroid Build Coastguard Worker        z = torch.rand(8, 4, 2, 3)
1371*da0073e9SAndroid Build Coastguard Worker        ref = test(x, y, z)
1372*da0073e9SAndroid Build Coastguard Worker        res = test(x, y, z)
1373*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(ref.numpy(), res.numpy())
1374*da0073e9SAndroid Build Coastguard Worker
1375*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("dynamic shapes are not quite there yet")
1376*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
1377*da0073e9SAndroid Build Coastguard Worker    def test_dynamic_shape(self):
1378*da0073e9SAndroid Build Coastguard Worker        with num_profiled_runs(2):
1379*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
1380*da0073e9SAndroid Build Coastguard Worker            def test(x, y, z):
1381*da0073e9SAndroid Build Coastguard Worker                return x * y * z
1382*da0073e9SAndroid Build Coastguard Worker            x, y, z = (torch.rand(4, 8).cuda() for _ in range(3))
1383*da0073e9SAndroid Build Coastguard Worker            ref = test(x, y, z)
1384*da0073e9SAndroid Build Coastguard Worker            _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
1385*da0073e9SAndroid Build Coastguard Worker            res = test(x, y, z)
1386*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())
1387*da0073e9SAndroid Build Coastguard Worker
1388*da0073e9SAndroid Build Coastguard Worker            # A wild broadcast appears.
1389*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(4, 8).cuda()
1390*da0073e9SAndroid Build Coastguard Worker            y = torch.rand(1, 8).cuda()
1391*da0073e9SAndroid Build Coastguard Worker            z = torch.rand(4, 1).cuda()
1392*da0073e9SAndroid Build Coastguard Worker            res = test(x, y, z)
1393*da0073e9SAndroid Build Coastguard Worker            xn, yn, zn = (t.cpu().numpy() for t in (x, y, z))
1394*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
1395*da0073e9SAndroid Build Coastguard Worker
1396*da0073e9SAndroid Build Coastguard Worker            # Mismatched shapes shouldn't reach codegen.
1397*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(4, 8).cuda()
1398*da0073e9SAndroid Build Coastguard Worker            y = torch.rand(4, 8).cuda()
1399*da0073e9SAndroid Build Coastguard Worker            z = torch.rand(5, 8).cuda()
1400*da0073e9SAndroid Build Coastguard Worker            try:
1401*da0073e9SAndroid Build Coastguard Worker                res = test(x, y, z)
1402*da0073e9SAndroid Build Coastguard Worker            except RuntimeError as e:
1403*da0073e9SAndroid Build Coastguard Worker                assert "The size of tensor a (4) must match" in e.args[0]
1404*da0073e9SAndroid Build Coastguard Worker
1405*da0073e9SAndroid Build Coastguard Worker            # Changing a static dimension fails guards.
1406*da0073e9SAndroid Build Coastguard Worker            # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)]
1407*da0073e9SAndroid Build Coastguard Worker            # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
1408*da0073e9SAndroid Build Coastguard Worker            # res = test(x, y, z)
1409*da0073e9SAndroid Build Coastguard Worker            # print(test.graph_for(x, y, z))
1410*da0073e9SAndroid Build Coastguard Worker            # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
1411*da0073e9SAndroid Build Coastguard Worker
1412*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
1413*da0073e9SAndroid Build Coastguard Worker    def test_guard_fails(self):
1414*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1415*da0073e9SAndroid Build Coastguard Worker        def test(x, y, z):
1416*da0073e9SAndroid Build Coastguard Worker            return x * y * z
1417*da0073e9SAndroid Build Coastguard Worker        r1 = test(*[torch.rand(4).cuda() for _ in range(3)])
1418*da0073e9SAndroid Build Coastguard Worker        r2 = test(*[torch.rand(4).cuda() for _ in range(3)])
1419*da0073e9SAndroid Build Coastguard Worker        r3 = test(*[torch.rand(4).cuda() for _ in range(3)])
1420*da0073e9SAndroid Build Coastguard Worker        r4 = test(*[torch.rand(7).cuda() for _ in range(3)])
1421*da0073e9SAndroid Build Coastguard Worker
1422*da0073e9SAndroid Build Coastguard Worker    def test_bitwise_ops(self):
1423*da0073e9SAndroid Build Coastguard Worker        def run_and(x, y):
1424*da0073e9SAndroid Build Coastguard Worker            return x & (x & y)
1425*da0073e9SAndroid Build Coastguard Worker
1426*da0073e9SAndroid Build Coastguard Worker        def run_or(x, y):
1427*da0073e9SAndroid Build Coastguard Worker            return x & (x | y)
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Worker        def run_xor(x, y):
1430*da0073e9SAndroid Build Coastguard Worker            return x ^ (x ^ y)
1431*da0073e9SAndroid Build Coastguard Worker
1432*da0073e9SAndroid Build Coastguard Worker        def run_lshift(x, y):
1433*da0073e9SAndroid Build Coastguard Worker            return x & (x << y)
1434*da0073e9SAndroid Build Coastguard Worker
1435*da0073e9SAndroid Build Coastguard Worker        def run_rshift(x, y):
1436*da0073e9SAndroid Build Coastguard Worker            return x & (x >> y)
1437*da0073e9SAndroid Build Coastguard Worker
1438*da0073e9SAndroid Build Coastguard Worker        fns = {run_and, run_or, run_xor, run_lshift, run_rshift}
1439*da0073e9SAndroid Build Coastguard Worker
1440*da0073e9SAndroid Build Coastguard Worker        for device in self.devices:
1441*da0073e9SAndroid Build Coastguard Worker            for fn in fns:
1442*da0073e9SAndroid Build Coastguard Worker                a = torch.ones(128, dtype=torch.int32, device=device)
1443*da0073e9SAndroid Build Coastguard Worker                b = torch.zeros(128, dtype=torch.int32, device=device)
1444*da0073e9SAndroid Build Coastguard Worker                inp = torch.ones(128, dtype=torch.int32, device=device)
1445*da0073e9SAndroid Build Coastguard Worker                traced = torch.jit.trace(fn, (inp, inp))
1446*da0073e9SAndroid Build Coastguard Worker                x = warmup_and_run_forward(traced, a, b)
1447*da0073e9SAndroid Build Coastguard Worker                self.assertLastGraphAllFused()
1448*da0073e9SAndroid Build Coastguard Worker                y = fn(a, b)
1449*da0073e9SAndroid Build Coastguard Worker                np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
1450*da0073e9SAndroid Build Coastguard Worker
1451*da0073e9SAndroid Build Coastguard Worker    def test_where(self):
1452*da0073e9SAndroid Build Coastguard Worker        def run_where(x, y):
1453*da0073e9SAndroid Build Coastguard Worker            return torch.where(torch.gt(x, y), x, y)
1454*da0073e9SAndroid Build Coastguard Worker
1455*da0073e9SAndroid Build Coastguard Worker        for data_type in self.dtypes:
1456*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(1024, dtype=data_type)
1457*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(1024, dtype=data_type)
1458*da0073e9SAndroid Build Coastguard Worker            zeros = torch.zeros(1024, dtype=data_type)
1459*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(run_where, (zeros, zeros))
1460*da0073e9SAndroid Build Coastguard Worker            x = warmup_and_run_forward(traced, a, b)
1461*da0073e9SAndroid Build Coastguard Worker            self.assertLastGraphAllFused()
1462*da0073e9SAndroid Build Coastguard Worker            y = run_where(a, b)
1463*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(x.float().numpy(), y.float().numpy())
1464*da0073e9SAndroid Build Coastguard Worker
1465*da0073e9SAndroid Build Coastguard Worker    def test_multi_rand(self):
1466*da0073e9SAndroid Build Coastguard Worker        for device in self.devices:
1467*da0073e9SAndroid Build Coastguard Worker            def test(x):
1468*da0073e9SAndroid Build Coastguard Worker                y = torch.rand_like(x)
1469*da0073e9SAndroid Build Coastguard Worker                return (x + y) - (y - x)
1470*da0073e9SAndroid Build Coastguard Worker
1471*da0073e9SAndroid Build Coastguard Worker            _atol = 2e-3
1472*da0073e9SAndroid Build Coastguard Worker            _rtol = 1e-5
1473*da0073e9SAndroid Build Coastguard Worker            for data_type in self.dtypes:
1474*da0073e9SAndroid Build Coastguard Worker                if data_type is torch.bfloat16:
1475*da0073e9SAndroid Build Coastguard Worker                    _atol = 2e-2
1476*da0073e9SAndroid Build Coastguard Worker                a = torch.rand(4, dtype=data_type, device=device)
1477*da0073e9SAndroid Build Coastguard Worker                scripted = torch.jit.script(test)
1478*da0073e9SAndroid Build Coastguard Worker                out = warmup_and_run_forward(scripted, a)
1479*da0073e9SAndroid Build Coastguard Worker                self.assertLastGraphAllFused()
1480*da0073e9SAndroid Build Coastguard Worker                assert torch.allclose(out, 2 * a, atol=_atol, rtol=_rtol)
1481*da0073e9SAndroid Build Coastguard Worker
1482*da0073e9SAndroid Build Coastguard Worker    def test_mask(self):
1483*da0073e9SAndroid Build Coastguard Worker        def test(x):
1484*da0073e9SAndroid Build Coastguard Worker            return x.unsqueeze(1) == 0
1485*da0073e9SAndroid Build Coastguard Worker
1486*da0073e9SAndroid Build Coastguard Worker        for d in self.devices:
1487*da0073e9SAndroid Build Coastguard Worker            for data_type in self.dtypes:
1488*da0073e9SAndroid Build Coastguard Worker                x = torch.rand(4, dtype=data_type, device=d) > 0.5
1489*da0073e9SAndroid Build Coastguard Worker                scripted = torch.jit.script(test)
1490*da0073e9SAndroid Build Coastguard Worker                out = warmup_and_run_forward(scripted, x)
1491*da0073e9SAndroid Build Coastguard Worker                self.assertLastGraphAllFused()
1492*da0073e9SAndroid Build Coastguard Worker                assert torch.equal(out, test(x))
1493*da0073e9SAndroid Build Coastguard Worker
1494*da0073e9SAndroid Build Coastguard Worker    def test_simple_add(self):
1495*da0073e9SAndroid Build Coastguard Worker        val = torch._C._jit_get_te_generate_block_code()
1496*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_set_te_generate_block_code(True)
1497*da0073e9SAndroid Build Coastguard Worker        fall_bk = torch._C._jit_texpr_fallback_allowed()
1498*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_texpr_set_fallback_allowed(True)
1499*da0073e9SAndroid Build Coastguard Worker
1500*da0073e9SAndroid Build Coastguard Worker        def simple(a, b):
1501*da0073e9SAndroid Build Coastguard Worker            return torch.add(a, b)
1502*da0073e9SAndroid Build Coastguard Worker
1503*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(256, 256)
1504*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(256, 256)
1505*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(simple,
1506*da0073e9SAndroid Build Coastguard Worker                                 (torch.ones(256, 256), torch.ones(256, 256)))
1507*da0073e9SAndroid Build Coastguard Worker        f = traced(a, b)
1508*da0073e9SAndroid Build Coastguard Worker        f_test = np.full((256, 256), 2, dtype=float)
1509*da0073e9SAndroid Build Coastguard Worker        np.testing.assert_allclose(f.numpy(), f_test)
1510*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_set_te_generate_block_code(val)
1511*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_texpr_set_fallback_allowed(fall_bk)
1512*da0073e9SAndroid Build Coastguard Worker
1513*da0073e9SAndroid Build Coastguard Worker    def test_strided_output_preserved(self):
1514*da0073e9SAndroid Build Coastguard Worker        def foo(a, b):
1515*da0073e9SAndroid Build Coastguard Worker            return a + b - a
1516*da0073e9SAndroid Build Coastguard Worker
1517*da0073e9SAndroid Build Coastguard Worker        # smaller, easier to debug example
1518*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(6)
1519*da0073e9SAndroid Build Coastguard Worker        x = torch.as_strided(x, (2, 3), (1, 2))
1520*da0073e9SAndroid Build Coastguard Worker        total = 0
1521*da0073e9SAndroid Build Coastguard Worker        for i in range(2):
1522*da0073e9SAndroid Build Coastguard Worker            for j in range(3):
1523*da0073e9SAndroid Build Coastguard Worker                x[i, j] = total
1524*da0073e9SAndroid Build Coastguard Worker                total += 1
1525*da0073e9SAndroid Build Coastguard Worker        foo_script = torch.jit.script(foo)
1526*da0073e9SAndroid Build Coastguard Worker        foo_script(x, x)
1527*da0073e9SAndroid Build Coastguard Worker        foo_script(x, x)
1528*da0073e9SAndroid Build Coastguard Worker        out_s = foo_script(x, x)
1529*da0073e9SAndroid Build Coastguard Worker        out_eager = foo(x, x)
1530*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_s, out_eager)
1531*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_s.stride(), out_eager.stride())
1532*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
1533*da0073e9SAndroid Build Coastguard Worker
1534*da0073e9SAndroid Build Coastguard Worker        # more dims
1535*da0073e9SAndroid Build Coastguard Worker        N, C, H, W, = 2, 3, 4, 5
1536*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(N, C, H, W).to(memory_format=torch.channels_last)
1537*da0073e9SAndroid Build Coastguard Worker        foo_script = torch.jit.script(foo)
1538*da0073e9SAndroid Build Coastguard Worker        foo_script(x, x)
1539*da0073e9SAndroid Build Coastguard Worker        foo_script(x, x)
1540*da0073e9SAndroid Build Coastguard Worker        out_s = foo_script(x, x)
1541*da0073e9SAndroid Build Coastguard Worker        out_eager = foo(x, x)
1542*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_s, out_eager)
1543*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_s.stride(), out_eager.stride())
1544*da0073e9SAndroid Build Coastguard Worker        self.assertLastGraphAllFused()
1545*da0073e9SAndroid Build Coastguard Worker
1546*da0073e9SAndroid Build Coastguard Worker    def test_alias_analysis_module(self):
1547*da0073e9SAndroid Build Coastguard Worker        class AliasModule(nn.Module):
1548*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1549*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1550*da0073e9SAndroid Build Coastguard Worker                torch.manual_seed(1337)
1551*da0073e9SAndroid Build Coastguard Worker                self.a = torch.randn(128, 128)
1552*da0073e9SAndroid Build Coastguard Worker                self.b = torch.randn(128, 128)
1553*da0073e9SAndroid Build Coastguard Worker                self.c = torch.randn(128, 128)
1554*da0073e9SAndroid Build Coastguard Worker
1555*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y, z):
1556*da0073e9SAndroid Build Coastguard Worker                z = z + self.a
1557*da0073e9SAndroid Build Coastguard Worker                self.b.add_(y)
1558*da0073e9SAndroid Build Coastguard Worker                w = z + self.a
1559*da0073e9SAndroid Build Coastguard Worker                z = w + x
1560*da0073e9SAndroid Build Coastguard Worker                return z
1561*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(128, 128)
1562*da0073e9SAndroid Build Coastguard Worker
1563*da0073e9SAndroid Build Coastguard Worker        def getModule(script):
1564*da0073e9SAndroid Build Coastguard Worker            am = AliasModule()
1565*da0073e9SAndroid Build Coastguard Worker            if script:
1566*da0073e9SAndroid Build Coastguard Worker                return torch.jit.script(am)
1567*da0073e9SAndroid Build Coastguard Worker            return am
1568*da0073e9SAndroid Build Coastguard Worker
1569*da0073e9SAndroid Build Coastguard Worker        am = getModule(False)
1570*da0073e9SAndroid Build Coastguard Worker        am_s = getModule(True)
1571*da0073e9SAndroid Build Coastguard Worker        ref = am(x, x, x)
1572*da0073e9SAndroid Build Coastguard Worker        test = am_s(x, x, x)
1573*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ref, test)
1574*da0073e9SAndroid Build Coastguard Worker
1575*da0073e9SAndroid Build Coastguard Worker        # Now do the aliasing
1576*da0073e9SAndroid Build Coastguard Worker        am.a = am.b
1577*da0073e9SAndroid Build Coastguard Worker        ref = am(x, x, x)
1578*da0073e9SAndroid Build Coastguard Worker
1579*da0073e9SAndroid Build Coastguard Worker        am_s.a = am_s.b
1580*da0073e9SAndroid Build Coastguard Worker        test = am_s(x, x, x)
1581*da0073e9SAndroid Build Coastguard Worker
1582*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ref, test)
1583*da0073e9SAndroid Build Coastguard Worker
1584*da0073e9SAndroid Build Coastguard Worker    def test_alias_analysis_inputs(self):
1585*da0073e9SAndroid Build Coastguard Worker        class AliasModule(nn.Module):
1586*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1587*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1588*da0073e9SAndroid Build Coastguard Worker                torch.manual_seed(1337)
1589*da0073e9SAndroid Build Coastguard Worker                self.a = torch.randn(128, 128)
1590*da0073e9SAndroid Build Coastguard Worker                self.b = torch.randn(128, 128)
1591*da0073e9SAndroid Build Coastguard Worker                self.c = torch.randn(128, 128)
1592*da0073e9SAndroid Build Coastguard Worker
1593*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y, z):
1594*da0073e9SAndroid Build Coastguard Worker                x.add_(y)
1595*da0073e9SAndroid Build Coastguard Worker                w = z + self.a
1596*da0073e9SAndroid Build Coastguard Worker                z = w + x
1597*da0073e9SAndroid Build Coastguard Worker                return z
1598*da0073e9SAndroid Build Coastguard Worker
1599*da0073e9SAndroid Build Coastguard Worker        def getModule(script):
1600*da0073e9SAndroid Build Coastguard Worker            am = AliasModule()
1601*da0073e9SAndroid Build Coastguard Worker            if script:
1602*da0073e9SAndroid Build Coastguard Worker                return torch.jit.script(am)
1603*da0073e9SAndroid Build Coastguard Worker            return am
1604*da0073e9SAndroid Build Coastguard Worker        am = getModule(False)
1605*da0073e9SAndroid Build Coastguard Worker        am_s = getModule(True)
1606*da0073e9SAndroid Build Coastguard Worker
1607*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1337)
1608*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(128, 128)
1609*da0073e9SAndroid Build Coastguard Worker        ref = am(x, x, x)
1610*da0073e9SAndroid Build Coastguard Worker
1611*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1337)
1612*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(128, 128)
1613*da0073e9SAndroid Build Coastguard Worker        test = am_s(x, x, x)
1614*da0073e9SAndroid Build Coastguard Worker
1615*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ref, test)
1616*da0073e9SAndroid Build Coastguard Worker
1617*da0073e9SAndroid Build Coastguard Worker    def test_alias_analysis_input_and_module(self):
1618*da0073e9SAndroid Build Coastguard Worker        class AliasModule(nn.Module):
1619*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1620*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1621*da0073e9SAndroid Build Coastguard Worker                torch.manual_seed(1337)
1622*da0073e9SAndroid Build Coastguard Worker                self.a = torch.randn(128, 128)
1623*da0073e9SAndroid Build Coastguard Worker                self.b = torch.randn(128, 128)
1624*da0073e9SAndroid Build Coastguard Worker                self.c = torch.randn(128, 128)
1625*da0073e9SAndroid Build Coastguard Worker
1626*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y, z):
1627*da0073e9SAndroid Build Coastguard Worker                x.add_(y)
1628*da0073e9SAndroid Build Coastguard Worker                w = z + self.b
1629*da0073e9SAndroid Build Coastguard Worker                z = w + x
1630*da0073e9SAndroid Build Coastguard Worker                return z
1631*da0073e9SAndroid Build Coastguard Worker
1632*da0073e9SAndroid Build Coastguard Worker        def getModule(script):
1633*da0073e9SAndroid Build Coastguard Worker            am = AliasModule()
1634*da0073e9SAndroid Build Coastguard Worker            if script:
1635*da0073e9SAndroid Build Coastguard Worker                return torch.jit.script(am)
1636*da0073e9SAndroid Build Coastguard Worker            return am
1637*da0073e9SAndroid Build Coastguard Worker        am = getModule(False)
1638*da0073e9SAndroid Build Coastguard Worker        am_s = getModule(True)
1639*da0073e9SAndroid Build Coastguard Worker
1640*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1337)
1641*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(128, 128)
1642*da0073e9SAndroid Build Coastguard Worker        am.b = x
1643*da0073e9SAndroid Build Coastguard Worker        ref = am(x, x, x)
1644*da0073e9SAndroid Build Coastguard Worker
1645*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1337)
1646*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(128, 128)
1647*da0073e9SAndroid Build Coastguard Worker        am_s.b = x
1648*da0073e9SAndroid Build Coastguard Worker        test = am_s(x, x, x)
1649*da0073e9SAndroid Build Coastguard Worker
1650*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ref, test)
1651*da0073e9SAndroid Build Coastguard Worker
1652*da0073e9SAndroid Build Coastguard Worker    def test_multiple_outputs(self):
1653*da0073e9SAndroid Build Coastguard Worker        for device in self.devices:
1654*da0073e9SAndroid Build Coastguard Worker            # A bug reported internally similar to the one reported in #48533
1655*da0073e9SAndroid Build Coastguard Worker            def foo(a, b, c):
1656*da0073e9SAndroid Build Coastguard Worker                t_next = c + 1
1657*da0073e9SAndroid Build Coastguard Worker                t5 = t_next * b
1658*da0073e9SAndroid Build Coastguard Worker                t6 = torch.unsqueeze(t_next, 1)
1659*da0073e9SAndroid Build Coastguard Worker                t7 = a * t6
1660*da0073e9SAndroid Build Coastguard Worker                return (t7, t5, t_next)
1661*da0073e9SAndroid Build Coastguard Worker
1662*da0073e9SAndroid Build Coastguard Worker            for data_type in self.dtypes:
1663*da0073e9SAndroid Build Coastguard Worker                a = torch.rand(20, 20, dtype=data_type, device=device)
1664*da0073e9SAndroid Build Coastguard Worker                b = torch.rand(20 * 29, dtype=data_type, device=device).as_strided([20], [29])
1665*da0073e9SAndroid Build Coastguard Worker                c = torch.ones(20, dtype=torch.int64, device=device)
1666*da0073e9SAndroid Build Coastguard Worker                traced = torch.jit.trace(foo, (a, b, c))
1667*da0073e9SAndroid Build Coastguard Worker                ref = foo(a, b, c)
1668*da0073e9SAndroid Build Coastguard Worker                exp = traced(a, b, c)
1669*da0073e9SAndroid Build Coastguard Worker                exp = traced(a, b, c)
1670*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(ref, exp)
1671*da0073e9SAndroid Build Coastguard Worker
1672*da0073e9SAndroid Build Coastguard Worker    def test_propagated_mem_layout(self):
1673*da0073e9SAndroid Build Coastguard Worker        def foo(a, b, c):
1674*da0073e9SAndroid Build Coastguard Worker            t_next = c + 1
1675*da0073e9SAndroid Build Coastguard Worker            t5 = t_next * b
1676*da0073e9SAndroid Build Coastguard Worker            t7 = a * t5
1677*da0073e9SAndroid Build Coastguard Worker            return t7
1678*da0073e9SAndroid Build Coastguard Worker
1679*da0073e9SAndroid Build Coastguard Worker        def foo_multi_outputs(a, b, c):
1680*da0073e9SAndroid Build Coastguard Worker            t_next = c + 1
1681*da0073e9SAndroid Build Coastguard Worker            t5 = b * t_next
1682*da0073e9SAndroid Build Coastguard Worker            t7 = a * t5
1683*da0073e9SAndroid Build Coastguard Worker            return (t7, t5, t_next)
1684*da0073e9SAndroid Build Coastguard Worker
1685*da0073e9SAndroid Build Coastguard Worker        def foo_multi_outputs_i_nhwc_o_nchw(a, b, c):
1686*da0073e9SAndroid Build Coastguard Worker            t_next = c + 1
1687*da0073e9SAndroid Build Coastguard Worker            t5 = b * t_next
1688*da0073e9SAndroid Build Coastguard Worker            t7 = a * t5
1689*da0073e9SAndroid Build Coastguard Worker            t8 = t7.to(memory_format=torch.contiguous_format)
1690*da0073e9SAndroid Build Coastguard Worker            return (t8, t7, t5, t_next)
1691*da0073e9SAndroid Build Coastguard Worker
1692*da0073e9SAndroid Build Coastguard Worker        def run_foo_case(foo, a, b, c):
1693*da0073e9SAndroid Build Coastguard Worker            traced_contiguous = torch.jit.trace(foo, (a, b, c))
1694*da0073e9SAndroid Build Coastguard Worker            ref = foo(a, b, c)
1695*da0073e9SAndroid Build Coastguard Worker            exp = traced_contiguous(a, b, c)
1696*da0073e9SAndroid Build Coastguard Worker            exp = traced_contiguous(a, b, c)
1697*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref, exp)
1698*da0073e9SAndroid Build Coastguard Worker
1699*da0073e9SAndroid Build Coastguard Worker        mem_layouts = list(itertools.product([torch.contiguous_format, torch.channels_last], repeat=3))
1700*da0073e9SAndroid Build Coastguard Worker        shapes = [(2, 3, 4, 5), (2, 1, 1, 5), (1, 1, 1, 1)]
1701*da0073e9SAndroid Build Coastguard Worker        permutes = [(0, 3, 2, 1), (0, 3, 1, 2)]
1702*da0073e9SAndroid Build Coastguard Worker        funcs = [foo, foo_multi_outputs, foo_multi_outputs_i_nhwc_o_nchw]
1703*da0073e9SAndroid Build Coastguard Worker        configs = itertools.product(funcs, shapes, mem_layouts, permutes)
1704*da0073e9SAndroid Build Coastguard Worker        for strategy in ["STATIC", "DYNAMIC"]:
1705*da0073e9SAndroid Build Coastguard Worker            old_strategy = torch.jit.set_fusion_strategy([(strategy, 10)])
1706*da0073e9SAndroid Build Coastguard Worker            for _func, _shape, _mem_layouts, _permute in configs:
1707*da0073e9SAndroid Build Coastguard Worker                a = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[0])
1708*da0073e9SAndroid Build Coastguard Worker                b = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[1])
1709*da0073e9SAndroid Build Coastguard Worker                c = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[2])
1710*da0073e9SAndroid Build Coastguard Worker                run_foo_case(_func, a, b, c)
1711*da0073e9SAndroid Build Coastguard Worker
1712*da0073e9SAndroid Build Coastguard Worker                a = a.permute(dims=_permute)
1713*da0073e9SAndroid Build Coastguard Worker                b = b.permute(dims=_permute)
1714*da0073e9SAndroid Build Coastguard Worker                c = c.permute(dims=_permute)
1715*da0073e9SAndroid Build Coastguard Worker                run_foo_case(_func, a, b, c)
1716*da0073e9SAndroid Build Coastguard Worker
1717*da0073e9SAndroid Build Coastguard Worker            torch.jit.set_fusion_strategy(old_strategy)
1718*da0073e9SAndroid Build Coastguard Worker
# Run this module's tests via run_tests() when the file is executed as a script.
if __name__ == '__main__':
    run_tests()
1721