1# Owner(s): ["NNC"] 2 3import torch 4import numpy as np 5import torch._C._te as te 6 7from torch.testing._internal.common_utils import run_tests 8from torch.testing._internal.jit_utils import JitTestCase 9import unittest 10 11LLVM_ENABLED = torch._C._llvm_enabled() 12 13 14def construct_adder(n: int, dtype=torch.float32): 15 A = te.BufHandle("A", [n], dtype) 16 B = te.BufHandle("B", [n], dtype) 17 18 def compute(i): 19 return A.load([i]) + B.load([i]) 20 21 C = te.Compute("C", [n], compute) 22 23 loopnest = te.LoopNest([C]) 24 loopnest.prepare_for_codegen() 25 stmt = te.simplify(loopnest.root_stmt()) 26 27 return te.construct_codegen("ir_eval", stmt, [A, B, C]) 28 29 30class TestTensorExprPyBind(JitTestCase): 31 def test_simple_sum(self): 32 n = 32 33 cg = construct_adder(n) 34 35 tA = torch.randn(n) 36 tB = torch.randn(n) 37 tC = torch.empty(n) 38 cg.call([tA, tB, tC]) 39 torch.testing.assert_close(tA + tB, tC) 40 41 def test_call_raw(self): 42 n = 16 43 cg = construct_adder(n, dtype=torch.float64) 44 45 tA = torch.randn(n, dtype=torch.float64) 46 tB = torch.randn(n, dtype=torch.float64) 47 tC = torch.empty(n, dtype=torch.float64) 48 cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()]) 49 torch.testing.assert_close(tA + tB, tC) 50 51 def test_external_calls(self): 52 dtype = torch.float32 53 54 A = te.BufHandle("A", [1, 4], dtype) 55 B = te.BufHandle("B", [4, 1], dtype) 56 C = te.BufHandle("C", [1, 1], dtype) 57 58 s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], []) 59 60 loopnest = te.LoopNest(s, [C]) 61 loopnest.prepare_for_codegen() 62 codegen = te.construct_codegen("ir_eval", s, [A, B, C]) 63 64 tA = torch.ones(1, 4) 65 tB = torch.ones(4, 1) 66 tC = torch.empty(1, 1) 67 codegen.call([tA, tB, tC]) 68 torch.testing.assert_close(torch.matmul(tA, tB), tC) 69 70 def test_dynamic_shape(self): 71 dN = te.VarHandle(torch.int32) 72 A = te.BufHandle([dN], torch.float64) 73 B = te.BufHandle([dN], torch.float64) 74 75 def compute(i): 76 return A.load(i) - B.load(i) 77 78 C = te.Compute("C", [dN], compute) 79 80 loopnest = te.LoopNest([C]) 81 loopnest.prepare_for_codegen() 82 83 cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN]) 84 85 def test_with_shape(n): 86 tA = torch.randn(n, dtype=torch.double) 87 tB = torch.randn(n, dtype=torch.double) 88 tC = torch.empty(n, dtype=torch.double) 89 cg.call([tA, tB, tC, n]) 90 torch.testing.assert_close(tA - tB, tC) 91 92 test_with_shape(8) 93 test_with_shape(31) 94 95 def test_dynamic_shape_2d(self): 96 dN = te.VarHandle(torch.int32) 97 dM = te.VarHandle(torch.int32) 98 A = te.BufHandle([dN, dM], torch.float64) 99 B = te.BufHandle([dN, dM], torch.float64) 100 101 def compute(i, j): 102 return A.load([i, j]) - B.load([i, j]) 103 104 C = te.Compute("C", [dN, dM], compute) 105 106 loopnest = te.LoopNest([C]) 107 loopnest.prepare_for_codegen() 108 109 cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN, dM]) 110 111 def test_with_shape(n, m): 112 tA = torch.randn(n, m, dtype=torch.double) 113 tB = torch.randn(n, m, dtype=torch.double) 114 tC = torch.empty(n, m, dtype=torch.double) 115 cg.call([tA, tB, tC, n, m]) 116 torch.testing.assert_close(tA - tB, tC) 117 118 test_with_shape(2, 4) 119 test_with_shape(5, 3) 120 121 def test_dtype_error(self): 122 te.BufHandle("a", [1], torch.float32) # ok 123 self.assertRaises(TypeError, lambda: te.BufHandle("a", [1], "float55")) 124 125 @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") 126 def test_kernel_with_tensor_inputs(self): 127 def f(a, b, c): 128 return a + b + c 129 130 device, size = "cpu", (4, 4) 131 x = torch.rand(size, device=device) 132 y = torch.rand(size, device=device) 133 z = torch.rand(size, device=device) 134 135 graph_str = """ 136graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), 137 %b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu), 138 %c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)): 139 %6 : int = prim::Constant[value=1]() 140 %7 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %6) 141 %3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%7, %c.1, %6) 142 return (%3) 143 """ 144 graph = torch._C.parse_ir(graph_str) 145 146 kernel = te.TensorExprKernel(graph) 147 res1 = kernel.run((x, y, z)) 148 res2 = kernel.fallback((x, y, z)) 149 correct = f(x, y, z) 150 np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) 151 np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3) 152 153 @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") 154 def test_kernel_with_scalar_inputs(self): 155 def f(a, b, c): 156 return a + b + c 157 158 x = torch.tensor(0.1, dtype=torch.float, device="cpu") 159 y = torch.tensor(0.6, dtype=torch.float, device="cpu") 160 z = torch.tensor(0.7, dtype=torch.float, device="cpu") 161 162 graph_str = """ 163graph(%a.1 : Float(requires_grad=0, device=cpu), 164 %b.1 : Float(requires_grad=0, device=cpu), 165 %c.1 : Float(requires_grad=0, device=cpu)): 166 %3 : int = prim::Constant[value=1]() 167 %6 : Float(requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %3) 168 %9 : Float(requires_grad=0, device=cpu) = aten::add(%6, %c.1, %3) 169 return (%9) 170 """ 171 graph = torch._C.parse_ir(graph_str) 172 173 kernel = te.TensorExprKernel(graph) 174 res1 = kernel.run((x, y, z)) 175 res2 = kernel.fallback((x, y, z)) 176 correct = f(x, y, z) 177 np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) 178 np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3) 179 180 @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") 181 def test_kernel_shape_prop(self): 182 device, size = "cpu", (4, 4) 183 x = torch.rand(size, device=device) 184 y = torch.rand(size, device=device) 185 186 graph_str = """ 187graph(%a : Tensor, %b : Tensor): 188 %c : Tensor = aten::mul(%a, %b) 189 return (%c) 190 """ 191 graph = torch._C.parse_ir(graph_str) 192 193 exception_thrown = False 194 try: 195 kernel = te.TensorExprKernel(graph) 196 except RuntimeError: 197 # Graph doesn't have shape info for inputs => compilation should 198 # fail 199 exception_thrown = True 200 assert exception_thrown 201 202 # Inject shape info and try compiling again 203 example_inputs = [torch.rand(4, 4), torch.rand(4, 4)] 204 torch._C._te.annotate_input_shapes(graph, example_inputs) 205 torch._C._jit_pass_propagate_shapes_on_graph(graph) 206 207 # Now compilation should pass 208 kernel = te.TensorExprKernel(graph) 209 210 res = kernel.run((x, y)) 211 correct = torch.mul(x, y) 212 np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5) 213 214 @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") 215 def test_kernel_shape_prop_module(self): 216 class TestModule(torch.nn.Module): 217 def forward(self, x, y): 218 return x * x + y 219 220 graph = torch.jit.script(TestModule()).graph 221 222 # Try compiling the graph as-is. It should fail because it doesn't have 223 # shape info. 224 exception_thrown = False 225 try: 226 kernel = te.TensorExprKernel(graph) 227 except RuntimeError: 228 exception_thrown = True 229 assert exception_thrown 230 231 # Try injecting shape info for graph inputs 232 example_inputs = [torch.rand(4, 4), torch.rand(4, 4)] 233 234 exception_thrown = False 235 try: 236 torch._C._te.annotate_input_shapes(graph, example_inputs) 237 except RuntimeError: 238 # Graph has a 'self' argument for which we can't set shapes 239 exception_thrown = True 240 assert exception_thrown 241 242 # Remove 'self' argument and try annotating shapes one more time 243 torch._C._te.remove_unused_self_argument(graph) 244 245 # Inject shape info and try compiling again 246 torch._C._te.annotate_input_shapes(graph, example_inputs) 247 torch._C._jit_pass_propagate_shapes_on_graph(graph) 248 249 # Now compilation should pass 250 kernel = te.TensorExprKernel(graph) 251 252 device, size = "cpu", (4, 4) 253 x = torch.rand(size, device=device) 254 y = torch.rand(size, device=device) 255 256 res = kernel.run((x, y)) 257 correct = TestModule().forward(x, y) 258 np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5) 259 260 @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") 261 def test_kernel_with_t(self): 262 def f(a): 263 return a.t() 264 265 device, size = "cpu", (3, 4) 266 x = torch.rand(size, device=device) 267 268 graph_str = """ 269graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)): 270 %3 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::t(%a.1) 271 return (%3) 272 """ 273 graph = torch._C.parse_ir(graph_str) 274 275 kernel = te.TensorExprKernel(graph) 276 res1 = kernel.run((x,)) 277 res2 = kernel.fallback((x,)) 278 correct = f(x) 279 np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) 280 np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3) 281 282 @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") 283 def test_kernel_with_transpose(self): 284 def f(a): 285 return a.transpose(-1, -2) 286 287 device, size = "cpu", (3, 4) 288 x = torch.rand(size, device=device) 289 290 graph_str = """ 291graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)): 292 %2 : int = prim::Constant[value=-1]() 293 %3 : int = prim::Constant[value=-2]() 294 %4 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %2, %3) 295 return (%4) 296 """ 297 graph = torch._C.parse_ir(graph_str) 298 299 kernel = te.TensorExprKernel(graph) 300 res1 = kernel.run((x,)) 301 res2 = kernel.fallback((x,)) 302 correct = f(x) 303 np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) 304 np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3) 305 306 @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") 307 def test_kernel_with_permute(self): 308 def f(a): 309 return a.permute([2, 1, 0]) 310 311 device, size = "cpu", (3, 4, 5) 312 x = torch.rand(size, device=device) 313 314 graph_str = """ 315graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)): 316 %1 : int = prim::Constant[value=2]() 317 %2 : int = prim::Constant[value=1]() 318 %3 : int = prim::Constant[value=0]() 319 %4 : int[] = prim::ListConstruct(%1, %2, %3) 320 %5 : Float(5, 4, 3, strides=[12, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a.1, %4) 321 return (%5) 322 """ 323 graph = torch._C.parse_ir(graph_str) 324 325 kernel = te.TensorExprKernel(graph) 326 res1 = kernel.run((x,)) 327 res2 = kernel.fallback((x,)) 328 correct = f(x) 329 np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) 330 np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3) 331 332 @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") 333 def test_kernel_with_custom_lowering(self): 334 def f(a): 335 return a.nan_to_num() 336 337 device = "cpu" 338 x = torch.ones((2, 2), device=device) 339 x[0, 0] = x[1, 1] = torch.nan 340 graph_str = """ 341graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): 342 %none : NoneType = prim::Constant() 343 %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none) 344 return (%y) 345 """ 346 graph = torch._C.parse_ir(graph_str) 347 348 def my_custom_lowering(inputs, out_shape, out_stride, out_type, device): 349 def compute(idxs): 350 load = inputs[0].as_buf().load(idxs) 351 return te.ifThenElse( 352 te.ExprHandle.isnan(load), te.ExprHandle.float(0.0), load 353 ) 354 355 return te.Compute2("custom_nan_to_num", out_shape, compute) 356 357 kernel = te.TensorExprKernel(graph, {"aten::nan_to_num": my_custom_lowering}) 358 res1 = kernel.run((x,)) 359 res2 = kernel.fallback((x,)) 360 correct = f(x) 361 np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) 362 np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3) 363 364 @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") 365 def test_kernel_with_expand(self): 366 def f(a): 367 return a.expand((2, 3, 4)) 368 369 device = "cpu" 370 x = torch.rand((1, 3, 1), device=device) 371 graph_str = """ 372graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)): 373 %1 : int = prim::Constant[value=2]() 374 %2 : int = prim::Constant[value=3]() 375 %3 : int = prim::Constant[value=4]() 376 %4 : int[] = prim::ListConstruct(%1, %2, %3) 377 %5 : bool = prim::Constant[value=0]() 378 %6 : Float(2, 3, 4, strides=[12, 4, 0], requires_grad=0, device=cpu) = aten::expand(%a, %4, %5) 379 return (%6) 380 """ 381 graph = torch._C.parse_ir(graph_str) 382 383 kernel = te.TensorExprKernel(graph) 384 res1 = kernel.run((x,)) 385 res2 = kernel.fallback((x,)) 386 correct = f(x) 387 np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) 388 np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3) 389 390 @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") 391 def test_alloc_in_loop(self): 392 a, tmp, b = ( 393 te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"] 394 ) 395 body = te.Block([tmp.store([0], a.load([0])), b.store([0], tmp.load([0]))]) 396 for _ in range(4): 397 i = te.VarHandle("i", torch.int32) 398 body = te.For.make(i, 0, 100, body) 399 nest = te.LoopNest(body, [b]) 400 nest.prepare_for_codegen() 401 f = te.construct_codegen("llvm", nest.simplify(), [a, b]) 402 ta, tb = (torch.ones(1) for _ in range(2)) 403 f.call([ta.data_ptr(), tb.data_ptr()]) 404 405 406class TestExprHandlePyBind(JitTestCase): 407 def test_unary_ops(self): 408 unary_operators = { 409 torch.sin: torch._C._te.sin, 410 torch.cos: torch._C._te.cos, 411 torch.tan: torch._C._te.tan, 412 torch.asin: torch._C._te.asin, 413 torch.acos: torch._C._te.acos, 414 torch.atan: torch._C._te.atan, 415 torch.sinh: torch._C._te.sinh, 416 torch.cosh: torch._C._te.cosh, 417 torch.tanh: torch._C._te.tanh, 418 torch.sigmoid: torch._C._te.sigmoid, 419 torch.exp: torch._C._te.exp, 420 torch.expm1: torch._C._te.expm1, 421 torch.abs: torch._C._te.abs, 422 torch.log: torch._C._te.log, 423 torch.log2: torch._C._te.log2, 424 torch.log10: torch._C._te.log10, 425 torch.log1p: torch._C._te.log1p, 426 torch.erf: torch._C._te.erf, 427 torch.erfc: torch._C._te.erfc, 428 torch.sqrt: torch._C._te.sqrt, 429 torch.rsqrt: torch._C._te.rsqrt, 430 torch.ceil: torch._C._te.ceil, 431 torch.floor: torch._C._te.floor, 432 torch.round: torch._C._te.round, 433 torch.trunc: torch._C._te.trunc, 434 torch.lgamma: torch._C._te.lgamma, 435 torch.frac: torch._C._te.frac, 436 } 437 438 def construct_te_fn(op, n: int, dtype=torch.float32): 439 A = torch._C._te.BufHandle("A", [n], dtype) 440 441 def compute(i): 442 return op(A.load([i])) 443 444 C = te.Compute("C", [n], compute) 445 446 loopnest = te.LoopNest([C]) 447 loopnest.prepare_for_codegen() 448 stmt = te.simplify(loopnest.root_stmt()) 449 450 return te.construct_codegen("ir_eval", stmt, [A, C]) 451 452 n = 10 453 a = torch.rand(n) 454 for torch_op, te_op in unary_operators.items(): 455 ref = torch_op(a) 456 457 te_fn = construct_te_fn(te_op, n, torch.float32) 458 res = torch.empty(n) 459 te_fn.call([a, res]) 460 assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3) 461 462 463if __name__ == "__main__": 464 run_tests() 465