# Owner(s): ["NNC"]

import torch
import numpy as np
import torch._C._te as te

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase
import unittest

LLVM_ENABLED = torch._C._llvm_enabled()


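# construct_adder walks the full NNC TensorExpr pipeline: declare input
# buffers, define an elementwise computation, lower it through a LoopNest,
# and build an interpreter-backed ("ir_eval") codegen object. The compiled
# kernel is roughly equivalent to:
#
#   for i in range(n):
#       C[i] = A[i] + B[i]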
def construct_adder(n: int, dtype=torch.float32):
    A = te.BufHandle("A", [n], dtype)
    B = te.BufHandle("B", [n], dtype)

    def compute(i):
        return A.load([i]) + B.load([i])

    C = te.Compute("C", [n], compute)

    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())

    return te.construct_codegen("ir_eval", stmt, [A, B, C])


class TestTensorExprPyBind(JitTestCase):
    def test_simple_sum(self):
        n = 32
        cg = construct_adder(n)

        tA = torch.randn(n)
        tB = torch.randn(n)
        tC = torch.empty(n)
        cg.call([tA, tB, tC])
        torch.testing.assert_close(tA + tB, tC)

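    # call_raw() takes raw data pointers instead of tensors; shapes and
    # dtypes must already match what the kernel was compiled for.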
    def test_call_raw(self):
        n = 16
        cg = construct_adder(n, dtype=torch.float64)

        tA = torch.randn(n, dtype=torch.float64)
        tB = torch.randn(n, dtype=torch.float64)
        tC = torch.empty(n, dtype=torch.float64)
        cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])
        torch.testing.assert_close(tA + tB, tC)

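    # te.ExternalCall emits a call to a pre-registered external function
    # instead of a loop nest; "nnc_aten_matmul" dispatches to aten matmul,
    # which is what the result is checked against below.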
    def test_external_calls(self):
        dtype = torch.float32

        A = te.BufHandle("A", [1, 4], dtype)
        B = te.BufHandle("B", [4, 1], dtype)
        C = te.BufHandle("C", [1, 1], dtype)

        s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

        loopnest = te.LoopNest(s, [C])
        loopnest.prepare_for_codegen()
        codegen = te.construct_codegen("ir_eval", s, [A, B, C])

        tA = torch.ones(1, 4)
        tB = torch.ones(4, 1)
        tC = torch.empty(1, 1)
        codegen.call([tA, tB, tC])
        torch.testing.assert_close(torch.matmul(tA, tB), tC)

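    # A te.VarHandle introduces a symbolic dimension that is bound at call
    # time by passing the concrete size as a trailing argument, e.g.
    # cg.call([tA, tB, tC, 8]).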
    def test_dynamic_shape(self):
        dN = te.VarHandle(torch.int32)
        A = te.BufHandle([dN], torch.float64)
        B = te.BufHandle([dN], torch.float64)

        def compute(i):
            return A.load([i]) - B.load([i])

        C = te.Compute("C", [dN], compute)

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()

        cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN])

        def test_with_shape(n):
            tA = torch.randn(n, dtype=torch.double)
            tB = torch.randn(n, dtype=torch.double)
            tC = torch.empty(n, dtype=torch.double)
            cg.call([tA, tB, tC, n])
            torch.testing.assert_close(tA - tB, tC)

        test_with_shape(8)
        test_with_shape(31)

    def test_dynamic_shape_2d(self):
        dN = te.VarHandle(torch.int32)
        dM = te.VarHandle(torch.int32)
        A = te.BufHandle([dN, dM], torch.float64)
        B = te.BufHandle([dN, dM], torch.float64)

        def compute(i, j):
            return A.load([i, j]) - B.load([i, j])

        C = te.Compute("C", [dN, dM], compute)

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()

        cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN, dM])

        def test_with_shape(n, m):
            tA = torch.randn(n, m, dtype=torch.double)
            tB = torch.randn(n, m, dtype=torch.double)
            tC = torch.empty(n, m, dtype=torch.double)
            cg.call([tA, tB, tC, n, m])
            torch.testing.assert_close(tA - tB, tC)

        test_with_shape(2, 4)
        test_with_shape(5, 3)

    def test_dtype_error(self):
        te.BufHandle("a", [1], torch.float32)  # ok
        self.assertRaises(TypeError, lambda: te.BufHandle("a", [1], "float55"))

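    # te.TensorExprKernel compiles a TorchScript graph (with full shape and
    # dtype annotations) into a fused kernel: run() executes the compiled
    # code, while fallback() evaluates the same graph without it and serves
    # as a reference.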
    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_tensor_inputs(self):
        def f(a, b, c):
            return a + b + c

        device, size = "cpu", (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)
        z = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %6 : int = prim::Constant[value=1]()
  %7 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %6)
  %3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%7, %c.1, %6)
  return (%3)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x, y, z))
        res2 = kernel.fallback((x, y, z))
        correct = f(x, y, z)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_scalar_inputs(self):
        def f(a, b, c):
            return a + b + c

        x = torch.tensor(0.1, dtype=torch.float, device="cpu")
        y = torch.tensor(0.6, dtype=torch.float, device="cpu")
        z = torch.tensor(0.7, dtype=torch.float, device="cpu")

        graph_str = """
graph(%a.1 : Float(requires_grad=0, device=cpu),
      %b.1 : Float(requires_grad=0, device=cpu),
      %c.1 : Float(requires_grad=0, device=cpu)):
  %3 : int = prim::Constant[value=1]()
  %6 : Float(requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %3)
  %9 : Float(requires_grad=0, device=cpu) = aten::add(%6, %c.1, %3)
  return (%9)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x, y, z))
        res2 = kernel.fallback((x, y, z))
        correct = f(x, y, z)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

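    # Compilation needs complete shape information: the graph below only has
    # plain Tensor types, so kernel construction must fail until shapes are
    # injected with annotate_input_shapes and propagated across the graph.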
    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_shape_prop(self):
        device, size = "cpu", (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)

        graph_str = """
graph(%a : Tensor, %b : Tensor):
  %c : Tensor = aten::mul(%a, %b)
  return (%c)
        """
        graph = torch._C.parse_ir(graph_str)

        # Graph doesn't have shape info for inputs => compilation should fail
        with self.assertRaises(RuntimeError):
            te.TensorExprKernel(graph)

        # Inject shape info and try compiling again
        example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
        torch._C._te.annotate_input_shapes(graph, example_inputs)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)

        # Now compilation should pass
        kernel = te.TensorExprKernel(graph)

        res = kernel.run((x, y))
        correct = torch.mul(x, y)
        np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)

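    # A graph scripted from an nn.Module still carries a %self argument that
    # annotate_input_shapes cannot handle, so it has to be stripped with
    # remove_unused_self_argument first.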
    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_shape_prop_module(self):
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                return x * x + y

        graph = torch.jit.script(TestModule()).graph

        # Try compiling the graph as-is. It should fail because it doesn't
        # have shape info.
        with self.assertRaises(RuntimeError):
            te.TensorExprKernel(graph)

        # Try injecting shape info for graph inputs. It should fail because
        # the graph has a 'self' argument for which we can't set shapes.
        example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
        with self.assertRaises(RuntimeError):
            torch._C._te.annotate_input_shapes(graph, example_inputs)

        # Remove the 'self' argument, then inject shape info and try
        # compiling one more time.
        torch._C._te.remove_unused_self_argument(graph)
        torch._C._te.annotate_input_shapes(graph, example_inputs)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)

        # Now compilation should pass
        kernel = te.TensorExprKernel(graph)

        device, size = "cpu", (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)

        res = kernel.run((x, y))
        correct = TestModule().forward(x, y)
        np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)

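    # The next few tests (t, transpose, permute, expand) check that
    # view-producing ops lower correctly through TensorExprKernel.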
    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_t(self):
        def f(a):
            return a.t()

        device, size = "cpu", (3, 4)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %3 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::t(%a.1)
  return (%3)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_transpose(self):
        def f(a):
            return a.transpose(-1, -2)

        device, size = "cpu", (3, 4)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %2 : int = prim::Constant[value=-1]()
  %3 : int = prim::Constant[value=-2]()
  %4 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %2, %3)
  return (%4)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_permute(self):
        def f(a):
            return a.permute([2, 1, 0])

        device, size = "cpu", (3, 4, 5)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=2]()
  %2 : int = prim::Constant[value=1]()
  %3 : int = prim::Constant[value=0]()
  %4 : int[] = prim::ListConstruct(%1, %2, %3)
  %5 : Float(5, 4, 3, strides=[12, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a.1, %4)
  return (%5)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

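    # TensorExprKernel also accepts a dict of custom lowerings keyed by op
    # name. A lowering receives the op's inputs plus the output shape,
    # strides, dtype, and device, and returns the te.Compute implementing
    # the op.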
    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_custom_lowering(self):
        def f(a):
            return a.nan_to_num()

        device = "cpu"
        x = torch.ones((2, 2), device=device)
        x[0, 0] = x[1, 1] = torch.nan
        graph_str = """
graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
    %none : NoneType = prim::Constant()
    %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none)
    return (%y)
        """
        graph = torch._C.parse_ir(graph_str)

        def my_custom_lowering(inputs, out_shape, out_stride, out_type, device):
            def compute(idxs):
                # Replace NaNs with 0.0, matching nan_to_num's default.
                load = inputs[0].as_buf().load(idxs)
                return te.ifThenElse(
                    te.ExprHandle.isnan(load), te.ExprHandle.float(0.0), load
                )

            return te.Compute2("custom_nan_to_num", out_shape, compute)

        kernel = te.TensorExprKernel(graph, {"aten::nan_to_num": my_custom_lowering})
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_expand(self):
        def f(a):
            return a.expand((2, 3, 4))

        device = "cpu"
        x = torch.rand((1, 3, 1), device=device)
        graph_str = """
graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=2]()
  %2 : int = prim::Constant[value=3]()
  %3 : int = prim::Constant[value=4]()
  %4 : int[] = prim::ListConstruct(%1, %2, %3)
  %5 : bool = prim::Constant[value=0]()
  %6 : Float(2, 3, 4, strides=[12, 4, 0], requires_grad=0, device=cpu) = aten::expand(%a, %4, %5)
  return (%6)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

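    # The intermediate buffer `tmp` is not a kernel argument, so
    # prepare_for_codegen has to allocate it; wrapping the body in nested
    # loops checks that this allocation is handled by the LLVM backend.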
    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_alloc_in_loop(self):
        a, tmp, b = (
            te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"]
        )
        body = te.Block([tmp.store([0], a.load([0])), b.store([0], tmp.load([0]))])
        for _ in range(4):
            i = te.VarHandle("i", torch.int32)
            body = te.For.make(i, 0, 100, body)
        nest = te.LoopNest(body, [b])
        nest.prepare_for_codegen()
        f = te.construct_codegen("llvm", nest.simplify(), [a, b])
        ta, tb = (torch.ones(1) for _ in range(2))
        f.call([ta.data_ptr(), tb.data_ptr()])


class TestExprHandlePyBind(JitTestCase):
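    # Each te unary intrinsic should match its torch counterpart
    # elementwise; every op is compiled into a small one-input kernel and
    # compared against the eager result.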
    def test_unary_ops(self):
        unary_operators = {
            torch.sin: te.sin,
            torch.cos: te.cos,
            torch.tan: te.tan,
            torch.asin: te.asin,
            torch.acos: te.acos,
            torch.atan: te.atan,
            torch.sinh: te.sinh,
            torch.cosh: te.cosh,
            torch.tanh: te.tanh,
            torch.sigmoid: te.sigmoid,
            torch.exp: te.exp,
            torch.expm1: te.expm1,
            torch.abs: te.abs,
            torch.log: te.log,
            torch.log2: te.log2,
            torch.log10: te.log10,
            torch.log1p: te.log1p,
            torch.erf: te.erf,
            torch.erfc: te.erfc,
            torch.sqrt: te.sqrt,
            torch.rsqrt: te.rsqrt,
            torch.ceil: te.ceil,
            torch.floor: te.floor,
            torch.round: te.round,
            torch.trunc: te.trunc,
            torch.lgamma: te.lgamma,
            torch.frac: te.frac,
        }

        def construct_te_fn(op, n: int, dtype=torch.float32):
            A = te.BufHandle("A", [n], dtype)

            def compute(i):
                return op(A.load([i]))

            C = te.Compute("C", [n], compute)

            loopnest = te.LoopNest([C])
            loopnest.prepare_for_codegen()
            stmt = te.simplify(loopnest.root_stmt())

            return te.construct_codegen("ir_eval", stmt, [A, C])

        n = 10
        a = torch.rand(n)
        for torch_op, te_op in unary_operators.items():
            ref = torch_op(a)

            te_fn = construct_te_fn(te_op, n, torch.float32)
            res = torch.empty(n)
            te_fn.call([a, res])
            assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
    run_tests()