# Owner(s): ["NNC"]

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import unittest
import itertools

from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo

from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions

LLVM_ENABLED = torch._C._llvm_enabled()

class BaseTestClass(JitTestCase):
    def setUp(self):
        super().setUp()
        self.tensorexpr_options = TensorExprTestOptions()
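        # CUDA is exercised only when available; bfloat16 is tested only when
        # the LLVM backend is compiled in (LLVM_ENABLED above).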
        self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        self.dtypes = [torch.float32, torch.bfloat16] if LLVM_ENABLED else [torch.float32]

    def tearDown(self):
        self.tensorexpr_options.restore()
        super().tearDown()

    def assertLastGraphAllFused(self):
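        # `torch.jit.last_executed_optimized_graph` returns the optimized graph
        # from the most recent execution; asserting it is all fused verifies
        # that the whole computation landed in a single fusion group.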
        self.assertAllFused(torch.jit.last_executed_optimized_graph())


def warmup_and_run_forward(f, *args):
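    # The profiling executor only optimizes (and fuses) a graph after the
    # configured number of profiling runs, so run once more than that and
    # return the result of the final, optimized execution.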
    for _ in range(torch._C._jit_get_num_profiled_runs() + 1):
        results = f(*args)
    return results


@skipIfTorchDynamo()
class TestTensorExprFuser(BaseTestClass):
    def test_easy(self):
        def easy(x, y):
            aaa = torch.add(x, y)
            return aaa

        traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024)))

        a = torch.rand(1024)
        b = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())

    def test_three_arg(self):
        def easy(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.add(aaa, z)
            return bbb

        traced = torch.jit.trace(
            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
        )

        a = torch.rand(1024)
        b = torch.rand(1024)
        c = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = a.numpy() + b.numpy() + c.numpy()
        np.testing.assert_allclose(npr, x.numpy())

    def test_four_arg(self):
        def run_addcmul(x, y, z, w):
            c = torch.addcmul(torch.add(x, y), z, w)
            return c

        for dev in self.devices:
            rand_a = torch.rand(1024, dtype=torch.float, device=dev)
            rand_b = torch.rand(1024, dtype=torch.float, device=dev)
            rand_c = torch.rand(1024, dtype=torch.float, device=dev)
            rand_d = torch.rand(1024, dtype=torch.float, device=dev)

            traced = torch.jit.trace(
                run_addcmul,
                (
                    torch.zeros(1024, dtype=torch.float, device=dev),
                    torch.zeros(1024, dtype=torch.float, device=dev),
                    torch.zeros(1024, dtype=torch.float, device=dev),
                    torch.zeros(1024, dtype=torch.float, device=dev),
                ),
            )

            x = warmup_and_run_forward(traced, rand_a, rand_b, rand_c, rand_d)
            self.assertLastGraphAllFused()
            y = run_addcmul(rand_a, rand_b, rand_c, rand_d)
            np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6)

    def test_three_arg2(self):
        for device in self.devices:
            def test(x, y, z):
                aaa = torch.add(x, y)
                bbb = torch.add(aaa, z)
                return bbb

            M = 32
            N = 32
            traced = torch.jit.trace(
                test,
                (
                    torch.rand(M, N, device=device),
                    torch.rand(M, N, device=device),
                    torch.rand(M, N, device=device),
                ),
            )

            a = torch.rand(M, N, device=device)
            b = torch.rand(M, N, device=device)
            c = torch.rand(M, N, device=device)
            x = warmup_and_run_forward(traced, a, b, c)
            self.assertLastGraphAllFused()
            npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
            np.testing.assert_allclose(npr, x.cpu().numpy())

    def test_broadcast3(self):
        for device in self.devices:
            def test_body(M, N, L, K):
                def test(x, y, z):
                    v1 = torch.add(x, y)
                    v2 = torch.add(v1, z)
                    return v2

                a_shape = [M, N]
                b_shape = [L, M, 1]
                c_shape = [K, L, 1, 1]
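                # Broadcasting: x + y has shape [L, M, N]; adding z broadcasts
                # the result up to [K, L, M, N].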
                traced = torch.jit.trace(
                    test,
                    (
                        torch.rand(*a_shape, device=device),
                        torch.rand(*b_shape, device=device),
                        torch.rand(*c_shape, device=device),
                    ),
                )

                a = torch.rand(*a_shape, device=device)
                b = torch.rand(*b_shape, device=device)
                c = torch.rand(*c_shape, device=device)
                x = warmup_and_run_forward(traced, a, b, c)
                self.assertLastGraphAllFused()
                npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
                np.testing.assert_allclose(npr, x.cpu().numpy())

            test_configs = [[5, 2, 7, 3], [8, 8, 8, 8]]
            for test_config in test_configs:
                test_body(*test_config)

    def test_all_combos(self):
        def easy(x, y, z):
            a = torch.add(x, y)
            b = torch.add(a, z)
            c = torch.add(x, b)
            d = torch.add(c, a)
            return d

        def np_easy(x, y, z):
            a = x + y
            b = a + z
            c = x + b
            d = c + a
            return d

        traced = torch.jit.trace(
            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
        )

        a = torch.rand(1024)
        b = torch.rand(1024)
        c = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(npr, x.numpy())

    def test_rank_two(self):
        def easy(x, y, z):
            a = torch.add(x, y)
            b = torch.add(a, z)
            c = torch.add(x, b)
            d = torch.add(c, a)
            return d

        def np_easy(x, y, z):
            a = x + y
            b = a + z
            c = x + b
            d = c + a
            return d

        shape = 32, 32
        traced = torch.jit.trace(
            easy, (torch.rand(shape), torch.rand(shape), torch.rand(shape))
        )

        a = torch.rand(shape)
        b = torch.rand(shape)
        c = torch.rand(shape)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(npr, x.numpy())

    def test_broadcast(self):
        def easy(x, y, z):
            a = torch.add(x, y)
            b = torch.add(a, z)
            return b

        def np_easy(x, y, z):
            a = x + y
            b = a + z
            return b

        N = 32
        traced = torch.jit.trace(easy, (torch.rand(N, N), torch.rand(N), torch.rand(N, N)))

        a = torch.rand(N, N)
        b = torch.rand(N)
        c = torch.rand(N, N)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(npr, x.numpy())

    def test_broadcast_2(self):
        zero = torch.tensor([0.0], dtype=torch.float)

        def foo(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.add(zero, aaa)
            return torch.add(bbb, z)

        def foo_np(x, y, z):
            a = x + y
            b = zero.numpy() + a
            return b + z

        x = torch.rand(3, 4)
        y = torch.ones(3, 1)
        z = torch.rand(4)
        traced = torch.jit.trace(foo, (x, y, z))

        r = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()

        rnp = foo_np(x.numpy(), y.numpy(), z.numpy())
        np.testing.assert_allclose(r, rnp)

    def test_broadcast_big2(self):
        zero = torch.tensor([0.0], dtype=torch.float)

        def foo(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.add(zero, aaa)
            return torch.add(bbb, z)

        def foo_np(x, y, z):
            a = x + y
            b = zero.numpy() + a
            return b + z

        x = torch.rand(32, 1024)
        y = torch.ones(32, 1)
        z = torch.rand(1024)
        traced = torch.jit.trace(foo, (x, y, z))

        r = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()
        rnp = foo_np(x.numpy(), y.numpy(), z.numpy())
        np.testing.assert_allclose(r, rnp)

    def test_alpha(self):
        def alpha(x):
            aaa = torch.add(x, x, alpha=2.0)
            return aaa

        traced = torch.jit.trace(alpha, (torch.tensor([1.0])))

        a = torch.tensor([1.0])
        x = traced(a)
        np.testing.assert_allclose(a.numpy() + 2.0 * a.numpy(), x.numpy())

    @suppress_warnings
    def test_constant(self):
        def constant(x):
            bbb = torch.tensor([1.0])
            aaa = torch.add(x, bbb)
            return aaa

        traced = torch.jit.trace(constant, (torch.tensor([1.0])))

        a = torch.tensor([1.0])
        x = warmup_and_run_forward(traced, a)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + 1.0, x.numpy())

    def test_add_sub(self):
        def easy(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.sub(aaa, z)
            return bbb

        traced = torch.jit.trace(
            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
        )

        a = torch.rand(1024)
        b = torch.rand(1024)
        c = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy() - c.numpy(), x.numpy())

    def test_promotion(self):
        def easy(x, y):
            aaa = torch.add(x, y)
            return aaa

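        # int32 + float32 promotes to float32 under the standard type-promotion
        # rules; the fused kernel has to produce the promoted dtype as well.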
        traced = torch.jit.trace(
            easy,
            (torch.zeros(1024, dtype=torch.int32), torch.rand(1024, dtype=torch.float32)),
        )

        a = torch.zeros(1024, dtype=torch.int32)
        b = torch.rand(1024, dtype=torch.float32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())

    def test_double(self):
        TENSOR_LEN = 8

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.rand(TENSOR_LEN, dtype=torch.float64), torch.full((TENSOR_LEN,), 0.5, dtype=torch.float64)),
        )

        a = torch.rand(TENSOR_LEN, dtype=torch.double)
        b = torch.full((TENSOR_LEN,), 0.5, dtype=torch.double)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())

    def test_short(self):
        TENSOR_LEN = 8

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16),
             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)),
        )

        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)
        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())

    def test_char(self):
        TENSOR_LEN = 8

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8),
             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)),
        )

        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())

    def test_int64_promotion(self):
        TENSOR_LEN = 8

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8),
             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64)),
        )

        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())

    def test_eq(self):
        def easy(x, y):
            c = torch.eq(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        a = torch.zeros(1024, dtype=torch.int32)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())

    def test_ne(self):
        def easy(x, y):
            c = torch.ne(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        a = torch.zeros(1024, dtype=torch.int32)
        b = torch.ones(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())

    def test_ge(self):
        def easy(x, y):
            c = torch.ge(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        aa = np.empty([1024], dtype=np.int32)
        aa.fill(5)
        a = torch.from_numpy(aa)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())

    def test_gt(self):
        def easy(x, y):
            c = torch.gt(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        a = torch.ones(1024, dtype=torch.int32)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())

    def test_le(self):
        def easy(x, y):
            c = torch.le(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        aa = np.empty([1024], dtype=np.int32)
        aa.fill(5)
        a = torch.from_numpy(aa)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.zeros(1024), x.numpy())

    def test_lt(self):
        def easy(x, y):
            c = torch.lt(x, y)
            return c

        for dev in self.devices:
            traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev)))
            a = torch.ones(1024, dtype=torch.int32, device=dev)
            b = torch.zeros(1024, dtype=torch.int32, device=dev)
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
            np.testing.assert_allclose(np.zeros(1024), x.cpu().numpy())

    @suppress_warnings
    def test_min_max(self):
        def test(x, y):
            return torch.max(torch.min(x, y), torch.tensor([4.0]))

        traced = torch.jit.trace(test, (torch.zeros(1024), torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        b = 8.0 * torch.rand(1024)
        np.testing.assert_allclose(
            warmup_and_run_forward(traced, a, b), np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0])
        )
        self.assertLastGraphAllFused()

    def test_min_max_reduction(self):
        def test(x):
            return torch.min(x) + torch.max(x)

        traced = torch.jit.trace(test, (torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy()))
        self.assertLastGraphAllFused()

    def test_min_max_reduction2(self):
        def test(x):
            return x.min() + x.max()

        traced = torch.jit.trace(test, (torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy()))
        self.assertLastGraphAllFused()

    def test_min_max_reduction_dim1(self):
        def test(x):
            return torch.min(x, 1)[0] + torch.max(x, 1)[0]

        traced = torch.jit.trace(test, (torch.zeros(16, 16)))
        a = 8.0 * torch.rand(16, 16)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(
            a.numpy(), axis=1) + np.amax(a.numpy(), axis=1))
        self.assertLastGraphAllFused()

    def test_min_max_reduction_dim1_2(self):
        def test(x):
            return torch.min(x * x, 1)

        traced = torch.jit.trace(test, (torch.zeros(16, 16)))
        a = 8.0 * torch.rand(16, 16)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a)[0], np.amin((a * a).numpy(), axis=1))
        self.assertLastGraphAllFused()

    def test_clamp(self):
        def test(x):
            return torch.clamp(x + 3.0, 0.0, 6.0)

        for dev in self.devices:
            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
            a = 20.0 * torch.rand(1024, device=dev) - 10.0
            an = a.cpu().numpy()
            np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip(an + 3.0, 0.0, 6.0))
            self.assertLastGraphAllFused()

    def test_relu(self):
        def test(x):
            return torch.clamp(F.relu(x), 0, 0.5)

        for dev in self.devices:
            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
            a = 20.0 * torch.rand(1024, device=dev) - 10.0
            an = a.cpu().numpy()
            np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5))
            self.assertLastGraphAllFused()

    def test_reps(self):
        def easy(x, y):
            c = torch.add(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024)))

        for _ in range(32):
            a = torch.ones(1024)
            b = torch.zeros(1024)
            x = warmup_and_run_forward(traced, a, b)
            np.testing.assert_allclose(np.ones(1024), x.numpy())

    def test_add_const_rhs(self):
        def test(x):
            return x + 3.0

        traced = torch.jit.trace(test, torch.rand(4))
        x = torch.rand(4)
        y = warmup_and_run_forward(traced, x)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(x.numpy() + 3.0, y.numpy())

    def test_int_output(self):
        def test(x, y, z):
            return x * y * z

        xs = [(torch.rand(4) * 3 + 1).to(torch.int32) for i in range(3)]
        x, y, z = xs
        xn, yn, zn = (t.numpy() for t in xs)
        traced = torch.jit.trace(test, (x, y, z))
        res = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(xn * yn * zn, res.numpy())

    def test_binary_ops(self):
        def test_atan2(x, y):
            c = torch.atan2(torch.add(x, y), y)
            return c

        def test_gt(x, y):
            c = torch.gt(torch.add(x, y), y)
            return c

        def test_ge(x, y):
            c = torch.ge(torch.add(x, y), y)
            return c

        def test_lt(x, y):
            c = torch.lt(torch.add(x, y), y)
            return c

        def test_le(x, y):
            c = torch.le(torch.add(x, y), y)
            return c

        def test_lerp(x, y):
            c = torch.lerp(torch.add(x, 1), x, 2.0)
            return c

        def test_mul(x, y):
            c = torch.mul(torch.add(x, y), y)
            return c

        def test_ne(x, y):
            c = torch.ne(torch.add(x, y), y)
            return c

        def test_div(x, y):
            c = torch.div(torch.add(x, y), 2)
            return c

        def test_eq(x, y):
            c = torch.eq(torch.add(x, y), y)
            return c

        def test_fmod(x, y):
            c = torch.fmod(torch.add(x, y), 2)
            return c

        def test_sub(x, y):
            c = torch.sub(torch.add(x, y), x)
            return c

        def test_remainder(x, y):
            c = torch.remainder(torch.add(x, y), 3.0)
            return c

        def test_pow(x, y):
            c = torch.pow(torch.add(x, y), 2.0)
            return c

        def test_type_as(x, y):
            return x.type_as(torch.add(x, y))

        cmp_fns = {
            test_gt,
            test_ge,
            test_lt,
            test_le,
            test_ne,
            test_eq
        }

        non_cmp_fns = {
            test_atan2,
            test_lerp,
            test_mul,
            test_div,
            test_fmod,
            test_sub,
            test_remainder,
            test_pow,
            test_type_as,
        }

        all_test_fns = cmp_fns.union(non_cmp_fns)
        fn_dev_dtype = itertools.product(all_test_fns, self.devices, self.dtypes)
        for torch_fn, dev, data_type in fn_dev_dtype:
            if torch_fn is test_lerp and data_type is torch.bfloat16:
                continue
            rand_a = torch.rand(1024, dtype=data_type, device=dev)
            rand_b = torch.rand(1024, dtype=data_type, device=dev)
            in1 = 20 * torch.rand(1024, dtype=data_type, device=dev)
            in2 = 20 * torch.rand(1024, dtype=data_type, device=dev)
            traced = torch.jit.trace(torch_fn, (in1, in2))
            x = warmup_and_run_forward(traced, rand_a, rand_b)
            self.assertLastGraphAllFused()

            _atol = 2e-3
            _rtol = 1e-5
            if data_type is torch.bfloat16:
                # Compared to the aten logic, NNC can elide additional BF16/FP32
                # conversions. Take d = a + b - c as an example; at the operator
                # level, aten computes:
                #    tmp = to_bf16(to_fp32(a) + to_fp32(b))
                #    d = to_bf16(to_fp32(tmp) + to_fp32(c))
                # NNC, however, fuses the computation and removes the redundant
                # conversions, producing:
                #    d = to_bf16(to_fp32(a) + to_fp32(b) + to_fp32(c))
                # Hence, we simulate the NNC computation by feeding fp32 tensors
                # and converting the result back to bf16. The simulation avoids
                # the numeric deviation and simplifies the result comparison.
                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
                if torch_fn not in cmp_fns:
                    y = y.bfloat16()
                _atol = 2e-2
            else:
                y = torch_fn(rand_a, rand_b)
            self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)

    def test_unary_ops(self):
        def test_cast_float(x, y):
            c = torch.ops.aten._cast_Float(torch.add(x, y))
            return c

        def test_round(x, y):
            c = torch.round(torch.add(x, y))
            return c

        def test_sin(x, y):
            c = torch.sin(torch.add(x, y))
            return c

        def test_asin(x, y):
            c = torch.asin(torch.add(x, y))
            return c

        def test_sinh(x, y):
            c = torch.sinh(torch.add(x, y))
            return c

        def test_cos(x, y):
            c = torch.cos(torch.add(x, y))
            return c

        def test_acos(x, y):
            c = torch.acos(torch.add(x, y))
            return c

        def test_cosh(x, y):
            c = torch.cosh(torch.add(x, y))
            return c

        def test_tan(x, y):
            c = torch.tan(torch.add(x, y))
            return c

        def test_atan(x, y):
            c = torch.atan(torch.add(x, y))
            return c

        def test_tanh(x, y):
            c = torch.tanh(torch.add(x, y))
            return c

        def test_sqrt(x, y):
            c = torch.sqrt(torch.add(x, y))
            return c

        def test_rsqrt(x, y):
            c = torch.rsqrt(torch.add(x, y))
            return c

        def test_floor(x, y):
            c = torch.floor(torch.add(x, y))
            return c

        def test_ceil(x, y):
            c = torch.ceil(torch.add(x, y))
            return c

        def test_trunc(x, y):
            c = torch.trunc(torch.add(x, y))
            return c

        def test_abs(x, y):
            c = torch.abs(torch.add(x, y))
            return c

        def test_log(x, y):
            c = torch.log(torch.add(x, y))
            return c

        def test_log2(x, y):
            c = torch.log2(torch.add(x, y))
            return c

        def test_log10(x, y):
            c = torch.log10(torch.add(x, y))
            return c

        def test_log1p(x, y):
            c = torch.log1p(torch.add(x, y))
            return c

        def test_erf(x, y):
            c = torch.erf(torch.add(x, y))
            return c

        def test_exp(x, y):
            c = torch.exp(torch.add(x, y))
            return c

        def test_expm1(x, y):
            c = torch.expm1(torch.add(x, y))
            return c

        def test_erfc(x, y):
            c = torch.erfc(torch.add(x, y))
            return c

        def test_frac(x, y):
            c = torch.frac(torch.add(x, y))
            return c

        def test_lgamma(x, y):
            c = torch.lgamma(torch.add(x, y))
            return c

        def test_sigmoid(x, y):
            c = torch.sigmoid(torch.add(x, y))
            return c

        def test_reciprocal(x, y):
            c = torch.reciprocal(torch.add(x, y))
            return c

        def test_neg(x, y):
            c = torch.neg(torch.add(x, y))
            return c

        def test_relu(x, y):
            c = torch.relu(torch.add(x, y))
            return c

        def test_hardtanh(x, y):
            c = F.hardtanh(torch.add(x, y), -1.0, 1.0)
            return c

        def test_threshold(x, y):
            c = F.threshold(torch.add(x, y), 0.5, 10)
            return c

        gpu_only_fns = {
            test_erf,
            test_erfc
        }
        fns = {
            test_round,
            test_sin,
            test_asin,
            test_sinh,
            test_cos,
            test_acos,
            test_cosh,
            test_tan,
            test_atan,
            test_sqrt,
            test_floor,
            test_ceil,
            test_trunc,
            test_abs,
            test_log,
            test_log2,
            test_log10,
            test_log1p,
            test_rsqrt,
            test_exp,
            test_expm1,
            test_frac,
            test_lgamma,
            test_reciprocal,
            test_neg,
            test_threshold,
            test_relu,
            test_tanh,
            test_hardtanh,
            test_sigmoid,
        }
        fn_dev_dtype = itertools.product(gpu_only_fns.union(fns), self.devices, self.dtypes)

        torch.manual_seed(0)
        for torch_fn, dev, data_type in fn_dev_dtype:
            if torch_fn == test_lgamma and dev == "cuda":
                # lgamma_cuda does not support BF16, so skip lgamma on CUDA.
                continue
            rand_a = torch.rand(1024, dtype=data_type, device=dev)
            rand_b = torch.rand(1024, dtype=data_type, device=dev)

            ins = 20 * torch.rand(1024, dtype=data_type, device=dev)
            cc = np.empty([1024], dtype=np.float32)
            cc.fill(np.nan)
            nans = torch.from_numpy(cc).to(dev)
            traced = torch.jit.trace(torch_fn, (ins, ins))
            x = warmup_and_run_forward(traced, rand_a, rand_b)
            self.assertLastGraphAllFused()

            _atol = 5e-3 if data_type is torch.bfloat16 else 2e-3
            _rtol = 1e-5
            if data_type is torch.bfloat16 and torch_fn not in gpu_only_fns:
                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
                y = y.bfloat16()
            else:
                y = torch_fn(rand_a, rand_b)

            self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)
            # nans
            # TODO: re-enable. Currently all of the tests fail
            # traced = torch.jit.trace(torch_fn, (ins, ins))
            # x = warmup_and_run_forward(traced, rand_a, rand_b)
            # y = torch_fn(nans, rand_b)
            # try:
            #     np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
            #     print("Succeeded on dev=", dev, "function=", torch_fn)
            # except AssertionError:
            #     # Print extra info before exiting:
            #     print("Failed on dev=", dev, "function=", torch_fn)
            #     # np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())


    def test_round_2(self):
        def round(x):
            return torch.round(x)

        for data_type in [torch.float32, torch.double]:
            a = torch.tensor([0.2, 1.6, 2.5, 3.5]).to(data_type)
            traced = torch.jit.trace(round, (a))
            x = warmup_and_run_forward(traced, a)
            self.assertLastGraphAllFused()
            y = round(a)
            self.assertEqual(x, y)

    def test_rand_like(self):
        N = 1 << 16

        def run_rand_like(x, y):
            return torch.rand_like(torch.add(x, y))

        for device in self.devices:
            x = torch.rand(N, device=device)
            traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False)

            for data_type in self.dtypes:
                _x = x.to(dtype=data_type)
                x_v = warmup_and_run_forward(traced, _x, _x)
                self.assertLastGraphAllFused()

            x_np = x.cpu().numpy()
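            # For x ~ U[0, 1), E[x^k] = 1 / (k + 1); check the first three
            # moments of the uniform input.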
            x1_mean = np.mean(x_np)
            x2_mean = np.mean(x_np ** 2)
            x3_mean = np.mean(x_np ** 3)
            np.testing.assert_allclose(x1_mean, 1. / 2, rtol=2e-2)
            np.testing.assert_allclose(x2_mean, 1. / 3, rtol=2e-2)
            np.testing.assert_allclose(x3_mean, 1. / 4, rtol=2e-2)

    def test_nans(self):
        def test_max(x, y):
            return torch.max(2 * x, 2 * y)

        def test_min(x, y):
            return torch.min(2 * x, 2 * y)

        tmax = torch.jit.trace(test_max, (torch.rand(1), torch.rand(1)))
        tmin = torch.jit.trace(test_min, (torch.rand(1), torch.rand(1)))

        for data_type in self.dtypes:
            x = torch.tensor([np.nan]).to(dtype=data_type)
            y = torch.tensor([1.0]).to(dtype=data_type)

            assert np.isnan(warmup_and_run_forward(tmin, x, y).float().item())
            assert np.isnan(warmup_and_run_forward(tmin, y, x).float().item())
            self.assertLastGraphAllFused()
            assert np.isnan(warmup_and_run_forward(tmax, x, y).float().item())
            assert np.isnan(warmup_and_run_forward(tmax, y, x).float().item())
            self.assertLastGraphAllFused()

    def test_double_intrinsics(self):
        def do_pow(x):
            return torch.pow(x, 7)

        for device in self.devices:
            x = torch.rand(10, dtype=torch.double, device=device)
            traced = torch.jit.trace(do_pow, (x))
            x = warmup_and_run_forward(traced, x)
            self.assertLastGraphAllFused()

    def test_remainder(self):
        def run_remainder(x, y):
            c = torch.remainder(torch.add(x, y), x)
            return c

        for data_type in self.dtypes:
            a = torch.rand(1024, dtype=data_type)
            b = torch.rand(1024, dtype=data_type)
            zeros = torch.zeros(1024, dtype=data_type)
            cc = np.empty([1024], dtype=float)
            cc.fill(np.nan)
            nans = torch.from_numpy(cc).to(dtype=data_type)

            # random floats
            zeros1 = torch.zeros(1024, dtype=data_type)
            zeros2 = torch.zeros(1024, dtype=data_type)

            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
            y = run_remainder(a, b)
            if data_type is torch.bfloat16:
                self.assertEqual(x, y, atol=4e-3, rtol=2e-3)
            else:
                self.assertEqual(x, y)

            # div by 0
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, zeros, a)
            self.assertLastGraphAllFused()
            y = run_remainder(zeros, a)
            self.assertEqual(x, y)

            # numerators and denominators are nan
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, nans, a)
            self.assertLastGraphAllFused()
            y = run_remainder(nans, a)
            self.assertEqual(x, y)

    def test_multioutput(self):
        def easy(x):
            b = x + 1
            c = b + b
            return (b, c)

        traced = torch.jit.trace(easy, (torch.zeros(1024)))

        a = torch.zeros(1024)
        b, c = warmup_and_run_forward(traced, a)
        self.assertLastGraphAllFused()
        bp = a.numpy() + 1
        cp = bp + bp
        np.testing.assert_allclose(b.numpy(), bp)
        np.testing.assert_allclose(c.numpy(), cp)

    def test_chunk(self):
        def easy(x):
            y = x + 1
            aaa, bbb = torch.chunk(y, 2)
            return aaa + bbb

        for data_type in self.dtypes:
            trace_input = torch.zeros(1024, 1024, dtype=data_type)
            traced = torch.jit.trace(easy, (trace_input))

            a = torch.zeros(32, 32, dtype=data_type)
            x = warmup_and_run_forward(traced, a)
            self.assertLastGraphAllFused()
            npr = a.float().numpy()
            npr2 = npr + 1
            npr_a, npr_b = np.array_split(npr2, 2)
            np.testing.assert_allclose(npr_a + npr_b, x.float().numpy())

    def test_cat(self):
        for device in self.devices:
            _dim = 1

            def foo(*args):
                args_2 = [v + i for i, v in enumerate(args)]
                v = torch.cat(args_2, dim=_dim)
                return v * v

            for data_type in self.dtypes:
                M = 16
                Ns = [128, 16, 1]
                values = [torch.zeros(M, N, dtype=data_type, device=device) for N in Ns]
                traced = torch.jit.trace(foo, values)

                x = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                ref = foo(*values)
                np.testing.assert_allclose(ref.cpu().float().numpy(), x.cpu().float().numpy())

            # Test channels-last
            for _cur_dim in range(4):
                _dim = _cur_dim
                values = [torch.randn((2, 3, 4, 5), device=device).to(memory_format=torch.channels_last) for _ in range(10)]
                traced = torch.jit.trace(foo, values)

                x = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                ref = foo(*values)
                self.assertEqual(ref, x)
    # This test checks that we correctly handle a fusion group containing just
    # aten::cat. Note that the test only makes sense with min_fusion_group=1;
    # otherwise no fusion groups would be formed at all.
    # TODO: Fix and re-enable the test.
    @unittest.skip("cat is broken with fusion group inlining disabled")
    def test_cat_only(self):
        for device in self.devices:
            def foo(*args):
                args_2 = [v + i for i, v in enumerate(args)]
                v = torch.cat(args_2, dim=1)
                return v

            M = 16
            Ns = [128, 16, 1]
            values = [torch.zeros(M, N, device=device) for N in Ns]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

    def test_cat_negative_dim(self):
        for device in self.devices:
            def foo(*args):
                v = torch.cat(args, dim=-1)
                return v * v

            M = 16
            Ns = [128, 16, 1]
            values = [torch.randn(M, N, device=device) for N in Ns]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

    def test_cat_promote_inputs(self):
        for device in self.devices:
            def foo(*args):
                v = torch.cat(args, dim=1)
                return v * v

            M = 16
            Ns = [128, 16, 1]
            dtypes = [torch.half, torch.float32, torch.double]
            values = [torch.randn(M, N, device=device, dtype=dt) for N, dt in zip(Ns, dtypes)]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

    def test_cat_empty_tensors(self):
        for device in self.devices:
            def foo(*args):
                v = torch.cat(args, dim=1)
                return v * v

            M = 16
            Ns = [128, 16, 1]
            empty = torch.tensor([], device=device, dtype=torch.double)
            values = [empty] + [torch.randn(M, N, device=device) for N in Ns]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

            # now test with only empty tensors
            values = [empty for i in range(3)]
            traced = torch.jit.trace(foo, values)
            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

    def test_cat_with_constant_dim(self):
        for device in self.devices:
            def foo(*args):
                v1 = torch.cat(args, dim=1)
                v2 = torch.cat([v1], dim=1)
                return v2 * v2

            empty = torch.tensor([], device=device, dtype=torch.float32)
            inputs = [empty] + [torch.randn(1, 64, device=device), torch.randn(1, 64, device=device)]
            traced = torch.jit.trace(foo, inputs)

            x = warmup_and_run_forward(traced, *inputs)
            self.assertLastGraphAllFused()
            ref = foo(*inputs)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

    def test_scalar(self):
        @torch.jit.script
        def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor:
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

        @torch.jit.script
        def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor:
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

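        # torch.add(x, y, alpha=a) computes x + a * y, which gives the
        # expected value x + y * a + z * b checked below.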
        for test in (test_float, test_int):
            for data_type in self.dtypes:
                x, y, z = (torch.rand(4, dtype=data_type) for i in range(3))
                a, b = 1, 2
                test(x, y, z, a, b)
                r = test(x, y, z, a, b)
                self.assertEqual(r, x + y * a + z * b)

    def test_loop(self):
        @torch.jit.script
        def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor:
            b = y
            for i in range(0, z):
                a = x + y
                b = b + y
            return b

        x, y, z = (torch.zeros(32, 32), torch.ones(32, 32), 4)
        test(x, y, z)
        r = test(x, y, z)

    def test_slice(self):
        def easy(x, y):
            a = x[0:512:2]
            b = y[0:512:2]
            return a + b

        traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024)))

        a = torch.ones(1024, 1024)
        x = traced(a, a)
        npr = a[0:512:2]
        npr = npr + npr
        np.testing.assert_allclose(npr.numpy(), x.numpy())

    def test_unsqueeze(self, N=256):
        def easy(x, y):
            a = torch.unsqueeze(x, 0)
            b = torch.unsqueeze(y, 0)
            return a + b

        traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N)))

        a = torch.rand(N, N)
        x = traced(a, a)
        npr = np.expand_dims(a, 0)
        npr = npr + npr
        np.testing.assert_allclose(npr, x.numpy())

    def _test_softmax(self, device):
        def test_softmax(x, y):
            a = F.softmax(x, dim=0, dtype=torch.float32)
            b = F.softmax(y, dim=0, dtype=torch.float32)
            c = F.softmax(x, dim=1, dtype=torch.float32)
            d = F.softmax(y, dim=1, dtype=torch.float32)
            return a + b + c + d

        def test_softmax_neg_index(x, y):
            a = F.softmax(x, dim=-2, dtype=torch.float32)
            b = F.softmax(y, dim=-2, dtype=torch.float32)
            c = F.softmax(x, dim=-1, dtype=torch.float32)
            d = F.softmax(y, dim=-1, dtype=torch.float32)
            return a + b + c + d

        def test_log_softmax(x, y):
            a = F.log_softmax(x, dim=0, dtype=torch.float32)
            b = F.log_softmax(y, dim=0, dtype=torch.float32)
            c = F.log_softmax(x, dim=1, dtype=torch.float32)
            d = F.log_softmax(y, dim=1, dtype=torch.float32)
            return a + b + c + d

        for test in (test_softmax, test_log_softmax, test_softmax_neg_index):
            for data_type in self.dtypes:
                old = torch._C._jit_set_texpr_reductions_enabled(True)
                traced_input = torch.randn(2, 3, dtype=data_type, device=device)
                traced = torch.jit.trace(test, (traced_input, traced_input))
                inp = torch.randn(2, 3, dtype=data_type, device=device)
                res = traced(inp, inp)
                # Use eager mode as reference.
                ref = test(inp, inp)
                np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06)
                torch._C._jit_set_texpr_reductions_enabled(old)

    def test_softmax_cpu(self):
        self._test_softmax('cpu')

    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    @unittest.skip("global allocs are not supported yet.")
    def test_softmax_cuda(self):
        self._test_softmax('cuda')

    def test_half_gelu(self):
        devices = ["cuda"] if torch.cuda.is_available() else []

        @torch.jit.script
        def bias_gelu(bias, y):
            x = bias + y
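            # GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2))); 1.41421 approximates sqrt(2).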
            return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

        for device in devices:
            a = torch.rand(1024, dtype=torch.half, device=device)
            b = torch.rand(1024, dtype=torch.half, device=device)
            traced = torch.jit.trace(bias_gelu, (a, b))
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()

    def test_half_bn_relu(self):
        devices = ["cuda"] if torch.cuda.is_available() else []

        def foo(a, b, c):
            y = torch.nn.functional.batch_norm(a, b, c)
            z = y.relu()
            return z

        for device in devices:
            a = torch.rand(16, 16, dtype=torch.half, device=device)
            b = torch.rand(16, dtype=torch.half, device=device)
            c = torch.rand(16, dtype=torch.half, device=device)
            traced = torch.jit.trace(foo, (a, b, c))
            x = warmup_and_run_forward(traced, a, b, c)
            self.assertLastGraphAllFused()

    def test_exp_pow(self):
        @torch.jit.script
        def do_exp(x, y, z):
            return ((x * y) * 2) * torch.pow(z, 2)

        for device in self.devices:
            x = torch.rand(10, dtype=torch.double, device=device)
            y = torch.rand(10, dtype=torch.double, device=device)
            z = torch.rand(10, dtype=torch.double, device=device)
            traced = torch.jit.trace(do_exp, (x, y, z))
            x = warmup_and_run_forward(traced, x, y, z)
            self.assertLastGraphAllFused()

    def test_sin_pow(self):
        def test(x):
            return torch.sin(torch.pow(x, 0))

        for data_type, shape in itertools.product(self.dtypes, [[3], [5], [10]]):
            x = torch.rand(shape, dtype=data_type)
            scripted = torch.jit.script(test)
            out = warmup_and_run_forward(scripted, x)
            self.assertLastGraphAllFused()
            self.assertEqual(out, test(x))

    def test_transpose(self):
        @torch.jit.script
        def test(x, y, z):
            return x.transpose(0, 1) + y + z
        x = torch.rand(4, 5, 2, 3)
        y = torch.rand(5, 4, 2, 3)
        z = torch.rand(5, 4, 2, 3)
        ref = test(x, y, z)
        res = test(x, y, z)
        np.testing.assert_allclose(ref.numpy(), res.numpy())

    def test_sliced_stride(self):
        @torch.jit.script
        def test(x, y, z):
            return x + y + z
        x = torch.rand(16, 4, 2, 3)[::2]
        y = torch.rand(8, 4, 2, 3)
        z = torch.rand(8, 4, 2, 3)
        ref = test(x, y, z)
        res = test(x, y, z)
        np.testing.assert_allclose(ref.numpy(), res.numpy())

    @unittest.skip("dynamic shapes are not quite there yet")
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_dynamic_shape(self):
        with num_profiled_runs(2):
            @torch.jit.script
            def test(x, y, z):
                return x * y * z
            x, y, z = (torch.rand(4, 8).cuda() for _ in range(3))
            ref = test(x, y, z)
            _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
            res = test(x, y, z)
            np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())

            # A wild broadcast appears.
            x = torch.rand(4, 8).cuda()
            y = torch.rand(1, 8).cuda()
            z = torch.rand(4, 1).cuda()
            res = test(x, y, z)
            xn, yn, zn = (t.cpu().numpy() for t in (x, y, z))
            np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)

            # Mismatched shapes shouldn't reach codegen.
            x = torch.rand(4, 8).cuda()
            y = torch.rand(4, 8).cuda()
            z = torch.rand(5, 8).cuda()
            try:
                res = test(x, y, z)
            except RuntimeError as e:
                assert "The size of tensor a (4) must match" in e.args[0]

            # Changing a static dimension fails guards.
            # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)]
            # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
            # res = test(x, y, z)
            # print(test.graph_for(x, y, z))
            # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)

    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_guard_fails(self):
        @torch.jit.script
        def test(x, y, z):
            return x * y * z
        r1 = test(*[torch.rand(4).cuda() for _ in range(3)])
        r2 = test(*[torch.rand(4).cuda() for _ in range(3)])
        r3 = test(*[torch.rand(4).cuda() for _ in range(3)])
        r4 = test(*[torch.rand(7).cuda() for _ in range(3)])
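        # The last call changes the profiled size (4 -> 7), so the shape guard
        # fails and execution falls back to the unoptimized path.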

    def test_bitwise_ops(self):
        def run_and(x, y):
            return x & (x & y)

        def run_or(x, y):
            return x & (x | y)

        def run_xor(x, y):
            return x ^ (x ^ y)

        def run_lshift(x, y):
            return x & (x << y)

        def run_rshift(x, y):
            return x & (x >> y)

        fns = {run_and, run_or, run_xor, run_lshift, run_rshift}

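        # Each fn chains two bitwise ops so the fuser sees a multi-node group
        # rather than a single operation.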
        for device in self.devices:
            for fn in fns:
                a = torch.ones(128, dtype=torch.int32, device=device)
                b = torch.zeros(128, dtype=torch.int32, device=device)
                inp = torch.ones(128, dtype=torch.int32, device=device)
                traced = torch.jit.trace(fn, (inp, inp))
                x = warmup_and_run_forward(traced, a, b)
                self.assertLastGraphAllFused()
                y = fn(a, b)
                np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())

    def test_where(self):
        def run_where(x, y):
            return torch.where(torch.gt(x, y), x, y)

        for data_type in self.dtypes:
            a = torch.rand(1024, dtype=data_type)
            b = torch.rand(1024, dtype=data_type)
            zeros = torch.zeros(1024, dtype=data_type)
            traced = torch.jit.trace(run_where, (zeros, zeros))
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
            y = run_where(a, b)
            np.testing.assert_allclose(x.float().numpy(), y.float().numpy())

    def test_multi_rand(self):
        for device in self.devices:
            def test(x):
                y = torch.rand_like(x)
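                # (x + y) - (y - x) == 2 * x for any y, so the random values cancel.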
1469                return (x + y) - (y - x)
1470
1471            _atol = 2e-3
1472            _rtol = 1e-5
1473            for data_type in self.dtypes:
1474                if data_type is torch.bfloat16:
1475                    _atol = 2e-2
1476                a = torch.rand(4, dtype=data_type, device=device)
1477                scripted = torch.jit.script(test)
1478                out = warmup_and_run_forward(scripted, a)
1479                self.assertLastGraphAllFused()
1480                assert torch.allclose(out, 2 * a, atol=_atol, rtol=_rtol)
1481
1482    def test_mask(self):
1483        def test(x):
1484            return x.unsqueeze(1) == 0
1485
1486        for d in self.devices:
1487            for data_type in self.dtypes:
1488                x = torch.rand(4, dtype=data_type, device=d) > 0.5
1489                scripted = torch.jit.script(test)
1490                out = warmup_and_run_forward(scripted, x)
1491                self.assertLastGraphAllFused()
1492                assert torch.equal(out, test(x))
1493
1494    def test_simple_add(self):
        val = torch._C._jit_get_te_generate_block_code()
        torch._C._jit_set_te_generate_block_code(True)
        fall_bk = torch._C._jit_texpr_fallback_allowed()
        torch._C._jit_texpr_set_fallback_allowed(True)

        def simple(a, b):
            return torch.add(a, b)

        a = torch.ones(256, 256)
        b = torch.ones(256, 256)
        try:
            traced = torch.jit.trace(simple,
                                     (torch.ones(256, 256), torch.ones(256, 256)))
            f = traced(a, b)
            f_test = np.full((256, 256), 2, dtype=float)
            np.testing.assert_allclose(f.numpy(), f_test)
        finally:
            # Restore the global flags even if the assertion above fails.
            torch._C._jit_set_te_generate_block_code(val)
            torch._C._jit_texpr_set_fallback_allowed(fall_bk)

    def test_strided_output_preserved(self):
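        # The fused graph must produce outputs with the same strides as
        # eager mode, both for an as_strided input and for channels_last.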
        def foo(a, b):
            return a + b - a

        # smaller, easier to debug example
        x = torch.arange(6)
        x = torch.as_strided(x, (2, 3), (1, 2))
        total = 0
        for i in range(2):
            for j in range(3):
                x[i, j] = total
                total += 1
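        # Call the scripted function a few times so profiling completes and
        # the optimized (fused) graph is the one being checked, here and in
        # the channels_last case below.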
        foo_script = torch.jit.script(foo)
        foo_script(x, x)
        foo_script(x, x)
        out_s = foo_script(x, x)
        out_eager = foo(x, x)
        self.assertEqual(out_s, out_eager)
        self.assertEqual(out_s.stride(), out_eager.stride())
        self.assertLastGraphAllFused()

        # more dims
        N, C, H, W = 2, 3, 4, 5
        x = torch.rand(N, C, H, W).to(memory_format=torch.channels_last)
        foo_script = torch.jit.script(foo)
        foo_script(x, x)
        foo_script(x, x)
        out_s = foo_script(x, x)
        out_eager = foo(x, x)
        self.assertEqual(out_s, out_eager)
        self.assertEqual(out_s.stride(), out_eager.stride())
        self.assertLastGraphAllFused()

    def test_alias_analysis_module(self):
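        # After self.a is aliased to self.b, the in-place add_ on b must be
        # visible through a; scripted and eager outputs must agree both
        # before and after the aliasing.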
        class AliasModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                z = z + self.a
                self.b.add_(y)
                w = z + self.a
                z = w + x
                return z

        x = torch.randn(128, 128)

        def getModule(script):
            am = AliasModule()
            if script:
                return torch.jit.script(am)
            return am

        am = getModule(False)
        am_s = getModule(True)
        ref = am(x, x, x)
        test = am_s(x, x, x)
        torch.testing.assert_close(ref, test)

        # Now do the aliasing
        am.a = am.b
        ref = am(x, x, x)

        am_s.a = am_s.b
        test = am_s(x, x, x)

        torch.testing.assert_close(ref, test)

    def test_alias_analysis_inputs(self):
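        # x, y, and z are all the same tensor here, so the in-place
        # x.add_(y) must be visible to the rest of the computation.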
        class AliasModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                x.add_(y)
                w = z + self.a
                z = w + x
                return z

        def getModule(script):
            am = AliasModule()
            if script:
                return torch.jit.script(am)
            return am

        am = getModule(False)
        am_s = getModule(True)

        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        ref = am(x, x, x)

        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        test = am_s(x, x, x)

        torch.testing.assert_close(ref, test)

    def test_alias_analysis_input_and_module(self):
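        # self.b is aliased to the input tensor x, so the in-place x.add_(y)
        # must be visible when forward reads self.b.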
        class AliasModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                x.add_(y)
                w = z + self.b
                z = w + x
                return z

        def getModule(script):
            am = AliasModule()
            if script:
                return torch.jit.script(am)
            return am

        am = getModule(False)
        am_s = getModule(True)

        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        am.b = x
        ref = am(x, x, x)

        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        am_s.b = x
        test = am_s(x, x, x)

        torch.testing.assert_close(ref, test)

    def test_multiple_outputs(self):
        for device in self.devices:
            # A bug reported internally similar to the one reported in #48533
            def foo(a, b, c):
                t_next = c + 1
                t5 = t_next * b
                t6 = torch.unsqueeze(t_next, 1)
                t7 = a * t6
                return (t7, t5, t_next)

            for data_type in self.dtypes:
                a = torch.rand(20, 20, dtype=data_type, device=device)
                b = torch.rand(20 * 29, dtype=data_type, device=device).as_strided([20], [29])
                c = torch.ones(20, dtype=torch.int64, device=device)
                traced = torch.jit.trace(foo, (a, b, c))
                ref = foo(a, b, c)
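                # Run the traced function twice: the first call is a
                # profiling run; the second executes the optimized graph.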
                exp = traced(a, b, c)
                exp = traced(a, b, c)
                self.assertEqual(ref, exp)

    def test_propagated_mem_layout(self):
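        # Fused kernels should preserve the propagated memory layout for
        # every combination of contiguous/channels_last inputs, permuted
        # views, and both the STATIC and DYNAMIC fusion strategies.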
        def foo(a, b, c):
            t_next = c + 1
            t5 = t_next * b
            t7 = a * t5
            return t7

        def foo_multi_outputs(a, b, c):
            t_next = c + 1
            t5 = b * t_next
            t7 = a * t5
            return (t7, t5, t_next)

        def foo_multi_outputs_i_nhwc_o_nchw(a, b, c):
            t_next = c + 1
            t5 = b * t_next
            t7 = a * t5
            t8 = t7.to(memory_format=torch.contiguous_format)
            return (t8, t7, t5, t_next)

        def run_foo_case(foo, a, b, c):
            traced_contiguous = torch.jit.trace(foo, (a, b, c))
            ref = foo(a, b, c)
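            # As above, run the traced function twice so the second call
            # exercises the optimized graph.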
            exp = traced_contiguous(a, b, c)
            exp = traced_contiguous(a, b, c)
            self.assertEqual(ref, exp)

        mem_layouts = list(itertools.product([torch.contiguous_format, torch.channels_last], repeat=3))
        shapes = [(2, 3, 4, 5), (2, 1, 1, 5), (1, 1, 1, 1)]
        permutes = [(0, 3, 2, 1), (0, 3, 1, 2)]
        funcs = [foo, foo_multi_outputs, foo_multi_outputs_i_nhwc_o_nchw]
        # Materialize the configs: itertools.product returns a one-shot
        # iterator, which would otherwise be exhausted after the first
        # fusion strategy and silently skip the DYNAMIC pass.
        configs = list(itertools.product(funcs, shapes, mem_layouts, permutes))
        for strategy in ["STATIC", "DYNAMIC"]:
            old_strategy = torch.jit.set_fusion_strategy([(strategy, 10)])
            for _func, _shape, _mem_layouts, _permute in configs:
                a = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[0])
                b = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[1])
                c = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[2])
                run_foo_case(_func, a, b, c)

                a = a.permute(dims=_permute)
                b = b.permute(dims=_permute)
                c = c.permute(dims=_permute)
                run_foo_case(_func, a, b, c)

            torch.jit.set_fusion_strategy(old_strategy)

if __name__ == '__main__':
    run_tests()