1# Owner(s): ["NNC"] 2 3import contextlib 4import math 5import operator 6import os 7import unittest 8import warnings 9from typing import List 10 11import torch 12import torch.nn.functional as F 13from torch.testing import FileCheck 14 15 16# these needs to be set before `common_utils` 17# infers `GRAPH_EXECUTOR`. 18# this file **requires** these settings 19# and setting them after `GRAPH_EXECUTOR` is 20# inferred erroneously runs or skips 21# some tests 22torch._C._jit_set_profiling_executor(True) 23torch._C._get_graph_executor_optimize(True) 24 25from itertools import combinations, permutations, product 26from textwrap import dedent 27 28from jit.test_fuser_common import TestFuserCommon # noqa: F401 29from test_jit import ( 30 backward_graph, 31 get_lstm_inputs, 32 get_milstm_inputs, 33 LSTMCellC, 34 LSTMCellF, 35 LSTMCellS, 36 MiLSTMCell, 37) 38 39from torch.testing._internal.common_device_type import ( 40 instantiate_device_type_tests, 41 onlyCPU, 42 OpDTypes, 43 ops, 44) 45from torch.testing._internal.common_jit import JitCommonTestCase 46from torch.testing._internal.common_methods_invocations import op_db 47from torch.testing._internal.common_utils import ( 48 enable_profiling_mode_for_profiling_tests, 49 GRAPH_EXECUTOR, 50 IS_FBCODE, 51 ProfilingMode, 52 run_tests, 53 skipIfTorchDynamo, 54 slowTest, 55 TEST_WITH_ASAN, 56 TEST_WITH_ROCM, 57) 58from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn 59from torch.testing._internal.jit_utils import ( 60 clone_inputs, 61 get_traced_sample_variant_pairs, 62 JitTestCase, 63 NoTracerWarnContextManager, 64 RUN_CUDA, 65 RUN_CUDA_HALF, 66 RUN_CUDA_MULTI_GPU, 67 set_fusion_group_inlining, 68 TensorExprTestOptions, 69 warmup_backward, 70) 71 72 73FUSION_GROUP = "prim::TensorExprGroup" 74LLVM_ENABLED = torch._C._llvm_enabled() 75 76autograd_check_set = { 77 "aten::__is__", 78 "prim::AutogradAllNonZero", 79 "prim::AutogradAllZero", 80 "prim::ListConstruct", 81} 82 83 84def strip_profiling_nodes(nodes): 85 profiling_opcodes = {"prim::BailoutTemplate", "prim::BailOut"} 86 return [n for n in nodes if n.kind() not in profiling_opcodes] 87 88 89def warmup_forward(f, *args, profiling_count=2): 90 for i in range(profiling_count): 91 results = f(*args) 92 93 return results 94 95 96@contextlib.contextmanager 97def texpr_reductions_enabled(): 98 old = torch._C._jit_set_texpr_reductions_enabled(True) 99 try: 100 yield 101 finally: 102 torch._C._jit_set_texpr_reductions_enabled(old) 103 104 105@contextlib.contextmanager 106def texpr_enable_strategy(strategy): 107 old = torch._C._jit_set_fusion_strategy(strategy) 108 try: 109 yield 110 finally: 111 torch._C._jit_set_fusion_strategy(old) 112 113 114@contextlib.contextmanager 115def inline_fusion_groups(): 116 old_inlining = torch._C._debug_get_fusion_group_inlining() 117 torch._C._debug_set_fusion_group_inlining(True) 118 try: 119 yield 120 finally: 121 torch._C._debug_set_fusion_group_inlining(old_inlining) 122 123 124class TestTEFuser(JitTestCase): 125 def setUp(self): 126 super().setUp() 127 self.tensorexpr_options = TensorExprTestOptions() 128 129 # note: `self.dynamic_shapes` instatiated in specialization of class 130 # defined below 131 132 fusion_strategy = [("DYNAMIC", 20)] if self.dynamic_shapes else [("STATIC", 20)] 133 self.old_fusion_strategy = torch._C._jit_set_fusion_strategy(fusion_strategy) 134 135 self.devices = ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"] 136 self.int_dtypes = [ 137 torch.int8, 138 torch.int16, 139 torch.int32, 140 torch.int64, 141 torch.bool, 142 ] 143 self.fp_dtypes = [ 144 torch.float16, 145 torch.float32, 146 torch.float64, 147 torch.bfloat16, 148 ] 149 self.dtypes = self.int_dtypes + self.fp_dtypes 150 151 def tearDown(self): 152 self.tensorexpr_options.restore() 153 torch._C._jit_set_fusion_strategy(self.old_fusion_strategy) 154 super().tearDown() 155 156 def assertAllFused(self, graph, except_for=None): 157 except_for = except_for if except_for is not None else set() 158 # TODO - upstream 159 guards = ( 160 "prim::TypeCheck", 161 "prim::RequiresGradCheck", 162 "prim::TensorExprDynamicGuard", 163 ) 164 guard_found = False 165 166 def autodiff_guard(node): 167 if node.kind() != "aten::all": 168 return False 169 inps = list(node.inputs()) 170 if len(inps) != 1 or inps[0].node().kind() != "prim::ListConstruct": 171 return False 172 li_inps = list(inps[0].node().inputs()) 173 for li_inp in li_inps: 174 if li_inp.node().kind() in ( 175 "prim::AutogradAllNonZero", 176 "prim::AutogradAllZero", 177 ): 178 return True 179 return False 180 181 def is_guard(node): 182 return node.kind() in guards or autodiff_guard(node) 183 184 for node in graph.block().nodes(): 185 if node.kind() == "prim::Constant": 186 continue 187 if is_guard(node): 188 self.assertFalse(guard_found) 189 guard_found = True 190 continue 191 if node.kind() in except_for: 192 continue 193 if node.kind() == "prim::If": 194 self.assertTrue(is_guard(node.prev())) 195 continue 196 self.assertTrue(False, "Found unexpected node:" + node.kind()) 197 198 self.assertTrue(guard_found) 199 200 def assertLastGraphAllFused(self): 201 self.assertAllFused(torch.jit.last_executed_optimized_graph()) 202 203 def findFusionGroups(self, graph): 204 result = [] 205 for n in graph.nodes(): 206 if n.kind() == FUSION_GROUP: 207 result.append(n.g("Subgraph")) 208 continue 209 for block in n.blocks(): 210 result += self.findFusionGroups(block) 211 return result 212 213 def test_typecheck(self): 214 a = torch.ones(1) 215 216 def fused_kernel(a, b): 217 return (a + b) * 2.0 218 219 scripted = self.checkScript(fused_kernel, (a, a)) 220 graph = scripted.graph_for(a, a) 221 # double check we fused 222 fusion_groups = self.findFusionGroups(graph) 223 self.assertEqual(len(fusion_groups), 1) 224 # we use a bigger tensor now (size 2) 225 # if we won't trigger a recompilation 226 # we will still create a tensor up to (size 1) 227 # if the type check fails 228 a = torch.ones(2) 229 # shape changed if we don't trigger recompilation 230 # we would compute the wrong result silently 231 self.assertEqual(scripted(a, a), fused_kernel(a, a)) 232 233 def test_sum_simple(self): 234 def func(x): 235 x2 = x * x 236 return x2.sum() 237 238 with texpr_reductions_enabled(): 239 a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") 240 a = a.reshape(5, 3) 241 scripted = self.checkScript(func, (a,)) 242 self.assertLastGraphAllFused() 243 244 def test_nop(self): 245 pass 246 247 def test_sum_dim(self): 248 def func(x): 249 return x.sum((0,)) * 2 250 251 def func_neg(x): 252 return x.sum((-2,)) * 2 253 254 with texpr_reductions_enabled(): 255 a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") 256 a = a.reshape(5, 3) 257 scripted = self.checkScript(func, (a,)) 258 self.assertLastGraphAllFused() 259 scripted = self.checkScript(func_neg, (a,)) 260 self.assertLastGraphAllFused() 261 262 def test_sum_keepdim_cast(self): 263 def func(x): 264 return x.sum((0,), keepdim=True, dtype=torch.double) * 2 265 266 with texpr_reductions_enabled(): 267 a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") 268 a = a.reshape(5, 3) 269 270 self.checkScript(func, (a,)) 271 self.assertLastGraphAllFused() 272 273 def test_abs(self): 274 for device in self.devices: 275 276 def func(x): 277 return x.abs() * 2 278 279 a = torch.randn(5, device=device) 280 scripted = self.checkScript(func, (a,)) 281 self.assertLastGraphAllFused() 282 283 def test_unsqueeze_size_calculation(self): 284 for device in self.devices: 285 286 def foo(b, d): 287 x = d.unsqueeze(1) 288 y = x * 42.0 289 z = b + y 290 r = z / 42.0 291 return r 292 293 inputs = ( 294 torch.rand(20, 28, device=device, requires_grad=True), 295 torch.rand(20, device=device), 296 ) 297 scripted = self.checkScript(foo, inputs) 298 self.assertAllFused(scripted.graph_for(*inputs)) 299 300 def test_zero_element_tensors(self): 301 for device in self.devices: 302 303 def decode(sin_t, cos_t): 304 theta = torch.atan2(sin_t.float(), cos_t.float()) 305 return theta 306 307 sin = torch.zeros(0, device=device) 308 cos = torch.zeros(0, device=device) 309 inputs = [sin, cos] 310 ge = self.checkScript(decode, inputs) 311 312 def test_arg_configurations_smoke(self): 313 if self.dynamic_shapes: 314 self.skipTest("TODO: chunk dynamic shapes") 315 316 # A smoke test to make sure we won't use the same kernel for contiguous 317 # and non-contiguous arguments. 318 # TODO: add optionally enabled debug counters to the fuser to verify 319 # that we really can tell the difference between configurations 320 for device in self.devices: 321 322 def f(x, y): 323 z1, z2 = (x + y).chunk(2, dim=1) 324 return z1 * z2 325 326 x = torch.randn(4, 4, dtype=torch.float, device=device) 327 y = torch.randn(4, 4, dtype=torch.float, device=device) 328 traced_f = torch.jit.trace(f, (x, y)) 329 self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y)) 330 331 def test_broadcast(self): 332 for device in self.devices: 333 334 def scaleshift(x, scale, shift): 335 return x * scale + shift 336 337 inputs = [ 338 torch.randn(4, 4, dtype=torch.float, device=device), 339 torch.randn(4, dtype=torch.float, device=device), 340 torch.randn(4, dtype=torch.float, device=device), 341 ] 342 self.checkScript(scaleshift, inputs) 343 344 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 345 @unittest.skipIf(not RUN_CUDA_HALF, "no half support") 346 @unittest.skipIf( 347 GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on" 348 ) 349 def test_cuda_half(self): 350 x = torch.randn(4, 4, dtype=torch.half, device="cuda") 351 y = torch.randn(4, 4, dtype=torch.half, device="cuda") 352 353 funcs = [self.fn_test_comparison_gt_lt, self.fn_test_relu, self.fn_test_exp] 354 355 # Note: Non fused inputs must be float to prevent loss of precision 356 inputs = (x.float(), y.float()) 357 fusion_inputs = (x, y) 358 for fn in funcs: 359 local_inputs = [t.clone().requires_grad_() for t in inputs] 360 local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs] 361 362 # Verifies outputs 363 fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False) 364 outputs = fn(*local_inputs) 365 fusion_outputs = fusion(*local_fusion_inputs) 366 outputs_half = [t.half() for t in outputs] 367 self.assertEqual(outputs_half, fusion_outputs) 368 369 # Verifies gradients 370 for output, fusion_output in zip(outputs_half, fusion_outputs): 371 grads = torch.autograd.grad( 372 output.float().sum(), 373 local_inputs, 374 allow_unused=True, 375 retain_graph=True, 376 ) 377 fusion_grads = torch.autograd.grad( 378 fusion_output.sum(), 379 local_fusion_inputs, 380 allow_unused=True, 381 retain_graph=True, 382 ) 383 grads_half = [t.half() for t in grads] 384 self.assertEqual(grads_half, fusion_grads) 385 386 def test_checks_cat_inputs(self): 387 # single fusion node causes error 388 with set_fusion_group_inlining(True): 389 for device in self.devices: 390 # We shouldn't treat cat nodes as broadcasting. All their inputs 391 # need to be checked for having the same map size, before we can 392 # run the kernel. 393 def f(x, y): 394 return torch.cat([x + 2 * x + x**2, y + 4 * y + y**3], dim=0) 395 396 # NOTE: y is broadcastable to x, but output of f(x, y) should have 397 # shape 3x4, and not 4x4. 398 x = torch.randn(2, 4, dtype=torch.float, device=device) 399 y = torch.randn(1, 4, dtype=torch.float, device=device) 400 401 scripted = self.checkScript(f, (x, y)) 402 self.assertEqual(scripted(x, y).shape, (3, 4)) 403 self.assertAllFused(scripted.graph_for(x, y)) 404 405 def test_chunk(self): 406 if self.dynamic_shapes: 407 self.skipTest("TODO: chunk dynamic shapes") 408 409 for device in self.devices: 410 411 def fn(x): 412 a, b, c = x.chunk(3, 1) 413 return a * b + c 414 415 inputs = [torch.randn(10, 6, dtype=torch.float, device=device)] 416 417 self.checkScript(fn, inputs) 418 self.assertLastGraphAllFused() 419 420 def test_chunk_correctness(self): 421 if self.dynamic_shapes: 422 self.skipTest("TODO: chunk dynamic shapes") 423 424 for device in self.devices: 425 426 def chunk_4_0(x): 427 x0, x1, x2, x3 = x.chunk(4, 0) 428 return x0 + x1 + x2 + x3 429 430 def chunk_4_1(x): 431 x0, x1, x2, x3 = x.chunk(4, 1) 432 return x0 + x1 + x2 + x3 433 434 def chunk_4_last(x): 435 x0, x1, x2, x3 = x.chunk(4, 2) 436 return x0 + x1 + x2 + x3 437 438 fns = [chunk_4_0, chunk_4_1, chunk_4_last] 439 tensors = [ 440 # splitSize = 1 441 torch.randn(4, 4, 4, dtype=torch.float, device=device), 442 # contiguous case 443 torch.randn(12, 8, 16, dtype=torch.float, device=device), 444 # non-contiguous case 445 torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose( 446 1, 2 447 ), 448 ] 449 450 for tensor in tensors: 451 for fn in fns: 452 self.checkScript(fn, [tensor]) 453 self.assertLastGraphAllFused() 454 455 def test_chunk_distributes(self): 456 if self.dynamic_shapes: 457 self.skipTest("TODO: chunk dynamic shapes") 458 459 if self.dynamic_shapes: 460 self.skipTest("TODO: chunk dynamic shapes") 461 462 for device in self.devices: 463 464 def f(x, y): 465 z1, z2 = (x + y).chunk(2, dim=1) 466 return z1 * z2 467 468 x = torch.randn(4, 4, dtype=torch.float, device=device) 469 y = torch.randn(4, 4, dtype=torch.float, device=device) 470 471 ge = self.checkTrace(f, (x, y)) 472 graph = ge.graph_for(x, y) 473 # XXX: The old fuser does broadcast_tensors but the new fuser doesn't. 474 # FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \ 475 # .check_count('ConstantChunk', 2, exactly=True).run(str(graph)) 476 FileCheck().check("with " + FUSION_GROUP + "_").check_count( 477 "ConstantChunk", 1, exactly=True 478 ).run(str(graph)) 479 480 def test_chunk_motion_deduplicates_inputs(self): 481 if self.dynamic_shapes: 482 self.skipTest("TODO: chunk dynamic shapes") 483 484 for device in self.devices: 485 486 def func1(x): 487 z = x * x 488 z0, z1 = z.chunk(2) 489 return z0 * z1 490 491 def func2(x): 492 z = x * x * x 493 z0, z1 = z.chunk(2) 494 return z0 * z1 495 496 inputs = [torch.tensor([1.1, 1.2], device=device, dtype=torch.float)] 497 for func in [func1, func2]: 498 self.checkScript(func, inputs) 499 self.assertLastGraphAllFused() 500 501 def test_chunk_multiple(self): 502 if self.dynamic_shapes: 503 self.skipTest("TODO: chunk dynamic shapes") 504 505 for device in self.devices: 506 # The arguments are intentionally used out of order as a test to see 507 # if the fusion compiler adds extra args in the correct order 508 def fn(s, x, y, z): 509 z1, z2 = z.chunk(2, 2) 510 x1, x2, x3 = x.chunk(3, 1) 511 y1, y2 = y.chunk(2, 0) 512 return s + x1 + x2 + x3 + y1 + y2 + z1 + z2 513 514 inputs = [ 515 torch.randn(5, 2, 3, dtype=torch.float, device=device), 516 torch.randn(5, 6, 3, dtype=torch.float, device=device), 517 torch.randn(10, 2, 3, dtype=torch.float, device=device), 518 torch.randn(5, 2, 6, dtype=torch.float, device=device), 519 ] 520 521 ge = self.checkScript(fn, inputs) 522 self.assertAllFused(ge.graph_for(*inputs)) 523 524 def test_minmax(self): 525 for device in self.devices: 526 527 def tmax(a, b): 528 return torch.max(2 * a, b) 529 530 def tmin(a, b): 531 return torch.min(2 * a, b) 532 533 a = torch.randn(4, 4, dtype=torch.float) 534 b = torch.randn(4, 4, dtype=torch.float) 535 nan = torch.tensor(float("nan"), dtype=torch.float) 536 537 for f, inputs, device in product( 538 (tmax, tmin), ([a, b], [a, nan], [b, nan]), self.devices 539 ): 540 inputs = [t.to(device) for t in inputs] 541 s = self.checkScript(f, inputs) 542 self.assertAllFused(s.graph_for(*inputs)) 543 544 def test_clamp(self): 545 for device in self.devices: 546 547 def func2(a, b): 548 return torch.clamp(a + b, min=0, max=2) 549 550 def funcInf(a, b): 551 return torch.clamp(a + b, min=0, max=float("inf")) 552 553 def funcNegInf(a, b): 554 return torch.clamp(a + b, min=float("-inf"), max=0) 555 556 def funcOptMin(a, b): 557 return torch.clamp(a + b, max=2) 558 559 def funcOptMax(a, b): 560 return torch.clamp(a + b, min=0) 561 562 a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True) 563 b = torch.randn(4, 4, dtype=torch.float, device=device) 564 nan = torch.tensor(float("nan"), dtype=torch.float, device=device) 565 566 funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax) 567 for f, inputs in product(funcs, [[a, b], [a, nan]]): 568 inp1, inp2 = inputs 569 s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING) 570 self.assertAllFused( 571 s.graph_for(inp1, inp2), 572 except_for={"aten::size", "aten::_size_if_not_equal"}, 573 ) 574 c = s(inp1, inp2) 575 with enable_profiling_mode_for_profiling_tests(): 576 warmup_backward(c.sum()) 577 graph = backward_graph(s) 578 self.assertAllFused( 579 graph, 580 except_for={"aten::Float", "aten::_grad_sum_to_size"}.union( 581 autograd_check_set 582 ), 583 ) 584 585 def test_clamp_double(self): 586 for device in self.devices: 587 588 def clamp_double(x, eta: float): 589 return 1 - x.clamp(eta, 1 - eta) 590 591 x = torch.tensor([1.0, 1.0], dtype=torch.double, device=device) 592 eta = 1e-9 593 s = self.checkScript( 594 clamp_double, 595 (x, eta), 596 profiling=ProfilingMode.PROFILING, 597 atol=1e-10, 598 rtol=1e-5, 599 ) 600 self.assertAllFused(s.graph_for(x, eta), except_for={"aten::sub"}) 601 602 def test_clamp_int(self): 603 for device in self.devices: 604 605 def clamp_int(x, eta: int): 606 return x.clamp(0, eta) 607 608 x = torch.tensor([1, 1], device=device) 609 eta = 1 << 32 610 s = self.checkScript(clamp_int, (x, eta), profiling=ProfilingMode.PROFILING) 611 self.assertAllFused(s.graph_for(x, eta)) 612 613 def test_add_bool(self): 614 sizes = [(1,), (2,), (4, 4)] 615 for device, size in product(self.devices, sizes): 616 617 def f(x, y, z): 618 return x + y + z 619 620 x = torch.randint(0, 2, size, dtype=torch.bool, device=device) 621 y = torch.randint(0, 2, size, dtype=torch.bool, device=device) 622 z = torch.randint(0, 2, size, dtype=torch.bool, device=device) 623 ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) 624 self.assertAllFused(ge.graph_for(x, y, z)) 625 626 def test_mul_bool(self): 627 for device in self.devices: 628 629 def f(x, y, z): 630 return x * y * z 631 632 x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) 633 y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) 634 z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) 635 636 ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) 637 self.assertAllFused(ge.graph_for(x, y, z)) 638 639 def test_div_bool(self): 640 for device in self.devices: 641 642 def f(x, y, z): 643 return (x + y) / z 644 645 x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) 646 y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) 647 z = torch.ones_like(x, dtype=torch.bool, device=device) 648 649 ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) 650 self.assertAllFused(ge.graph_for(x, y, z)) 651 652 def test_bitwise_ops(self): 653 def apply(fn): 654 return lambda x, y, z: fn(fn(x, y), z) 655 656 binary_ops = [ 657 operator.__and__, 658 operator.__or__, 659 operator.__xor__, 660 operator.__lshift__, 661 operator.__rshift__, 662 ] 663 devices = self.devices 664 for dtype, op, device in product(self.int_dtypes, binary_ops, devices): 665 try: 666 x = self.data_for(dtype, device) 667 y = self.data_for(dtype, device) 668 z = self.data_for(dtype, device) 669 fn = apply(op) 670 ref = fn(x, y, z) 671 except Exception: 672 # If eager mode doesn't support a dtype/op/device combo, 673 # neither does the fuser. Catch everything to avoid needing to 674 # guess what errors might be thrown by eager. 675 continue 676 try: 677 t = torch.jit.trace(fn, (x, y, z)) 678 self.assertEqual(ref, t(x, y, z)) 679 self.assertAllFused(t.graph_for(x, y, z)) 680 except Exception as e: 681 raise RuntimeError( 682 " ".join(["Failed:", str(dtype), op.__name__, device]) 683 ) from e 684 685 def test_minmax_int_ops(self): 686 def apply(fn): 687 return lambda x, y, z: fn(fn(x, y), z) 688 689 binary_ops = [torch.min, torch.max] 690 devices = self.devices 691 for dtype, op, device in product(self.int_dtypes, binary_ops, devices): 692 try: 693 x = self.data_for(dtype, device) 694 y = self.data_for(dtype, device) 695 z = self.data_for(dtype, device) 696 fn = apply(op) 697 ref = fn(x, y, z) 698 except Exception: 699 # If eager mode doesn't support a dtype/op/device combo, 700 # neither does the fuser. Catch everything to avoid needing to 701 # guess what errors might be thrown by eager. 702 continue 703 try: 704 t = torch.jit.trace(fn, (x, y, z)) 705 self.assertEqual(ref, t(x, y, z)) 706 self.assertAllFused(t.graph_for(x, y, z)) 707 except Exception as e: 708 raise RuntimeError( 709 " ".join(["Failed:", str(dtype), op.__name__, device]) 710 ) from e 711 712 def test_comparison_eq_ne(self): 713 for device in self.devices: 714 715 def f(x, y): 716 mask = (x == 0).type_as(x) 717 z = x * mask + y 718 mask = (x != 0).type_as(x) 719 z = z * mask + y 720 return z 721 722 x = torch.randn(4, 4, dtype=torch.float, device=device) 723 y = torch.randn(4, 4, dtype=torch.float, device=device) 724 725 ge = self.checkTrace(f, (x, y)) 726 self.assertAllFused(ge.graph_for(x, y)) 727 728 @staticmethod 729 def fn_test_comparison_gt_lt(x, y): 730 mask = (x > 0).type_as(x) 731 z = x * mask + y 732 mask = (x < 0).type_as(x) 733 z = z * mask + y 734 return z 735 736 def test_comparison_gt_lt(self): 737 for device in self.devices: 738 x = torch.randn(4, 4, dtype=torch.float, device=device) 739 y = torch.randn(4, 4, dtype=torch.float, device=device) 740 741 ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y)) 742 self.assertAllFused(ge.graph_for(x, y)) 743 744 def test_comparison_ge_le(self): 745 for device in self.devices: 746 747 def f(x, y): 748 mask = (x >= 0).type_as(x) 749 z = x * mask + y 750 mask = (x <= 0).type_as(x) 751 z = z * mask + y 752 return z 753 754 x = torch.randn(4, 4, dtype=torch.float, device=device) 755 y = torch.randn(4, 4, dtype=torch.float, device=device) 756 757 ge = self.checkTrace(f, (x, y)) 758 self.assertAllFused(ge.graph_for(x, y)) 759 x.requires_grad_(True) 760 y.requires_grad_(True) 761 self.assertAllFused( 762 ge.graph_for(x, y), 763 except_for=( 764 "aten::size", 765 "prim::BroadcastSizes", 766 "aten::_size_if_not_equal", 767 ), 768 ) 769 770 def test_addcmul(self): 771 for device in self.devices: 772 t = torch.randn(1, 4, dtype=torch.float, device=device) 773 t1 = torch.randn(4, 1, dtype=torch.float, device=device) 774 t2 = torch.randn(1, 4, dtype=torch.float, device=device) 775 776 def foo(t, t1, t2): 777 return t.addcmul(t + 1, t2, value=0.1) 778 779 ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True) 780 graph = ge.graph_for(t, t1, t2) 781 fusion_groups = self.findFusionGroups(graph) 782 self.assertEqual(len(fusion_groups), 1) 783 FileCheck().check("aten::add(").check("aten::addcmul(").run( 784 str(fusion_groups[0]) 785 ) 786 787 # TODO: We leak CUDA memory here because the traced graph holds onto a 788 # constant-ified tensor. Since the Python-global CompilationUnit is alive 789 # until the end of the process, the memory is effectively leaked. 790 # Removed `_cuda` suffix from this test which disables leak-checking. 791 # If this is a real problem, we'll need to revisit Torchscript Function 792 # lifetimes in Python. 793 def test_lerp(self): 794 for device in self.devices: 795 start = torch.randn(4, 1, dtype=torch.float, device=device) 796 end = torch.randn(1, 4, dtype=torch.float, device=device) 797 weight = torch.tensor(0.5, dtype=torch.float, device=device) 798 799 # scalar weight overload 800 def foo_weight_scalar(start, end): 801 return torch.lerp(start + 1, end, 0.5) 802 803 # tensor weight overload 804 def foo_weight_tensor(start, end): 805 return torch.lerp(start + 1, end, weight) 806 807 ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end)) 808 graph = ge_weight_scalar.graph_for(start, end) 809 self.assertAllFused(graph) 810 811 # TODO: uncomment when TE enables support for scalar tensors 812 # ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end)) 813 # graph = ge_weight_tensor.graph_for(start, end) 814 # self.assertAllFused(graph) 815 816 def test_concat(self): 817 # disabling concat causes error with single concat node 818 with set_fusion_group_inlining(True): 819 for device in self.devices: 820 hx = torch.randn(3, 20, dtype=torch.float, device=device) 821 cx = torch.randn(3, 20, dtype=torch.float, device=device) 822 823 def foo(hx, cx): 824 return torch.cat((hx + cx, hx * cx)) 825 826 ge = self.checkTrace(foo, (hx, cx)) 827 graph = ge.graph_for(hx, cx) 828 self.assertAllFused(graph) 829 # XXX: TE fuser can handle concats in a fusion group. 830 # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) 831 832 def test_remove_output_used_only_in_size(self): 833 for device in self.devices: 834 835 def test_fuse(a, b): 836 c = a + b 837 d = c + b 838 return d 839 840 scripted_f = torch.jit.script(test_fuse) 841 x = torch.ones(1, requires_grad=True, device=device) 842 y = torch.ones(1, requires_grad=True, device=device) 843 warmup_forward(scripted_f, x, y, profiling_count=3) 844 g = scripted_f.graph_for(x, y) 845 diff_nodes = g.findAllNodes("prim::DifferentiableGraph") 846 self.assertEqual(len(diff_nodes), 1) 847 g = diff_nodes[0].g("Subgraph") 848 if_nodes = [n for n in g.nodes() if n.kind() == "prim::If"] 849 self.assertEqual(len(if_nodes), 1) 850 851 # the if node and the fusion group inside it should only have one output 852 self.assertEqual(len(list(if_nodes[0].outputs())), 1) 853 854 def test_concat_invariant(self): 855 for device in self.devices: 856 # Invariant: the output of prim::FusedConcat may 857 # not be an input to any node inside the FusionGroup. 858 def fn(x, y, z): 859 x1 = x + y 860 y1 = x - y 861 w = torch.cat([x1, y1]) 862 return w + z 863 864 x = torch.randn(2, 2, dtype=torch.float, device=device) 865 y = torch.randn(2, 2, dtype=torch.float, device=device) 866 z = torch.randn(4, 2, dtype=torch.float, device=device) 867 ge = self.checkTrace(fn, (x, y, z)) 868 graph = ge.graph_for(x, y, z) 869 self.assertAllFused(graph, except_for={"aten::add"}) 870 # XXX: TE fuser can handle concats inside a fusion group. 871 # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) 872 873 @staticmethod 874 def fn_test_exp(x, y): 875 return (x + 0.5 * y).exp() 876 877 def test_exp(self): 878 for device in self.devices: 879 x = torch.randn(4, 4, dtype=torch.float, device=device) 880 y = torch.randn(4, 4, dtype=torch.float, device=device) 881 882 ge = self.checkTrace(self.fn_test_exp, (x, y)) 883 self.assertAllFused(ge.graph_for(x, y)) 884 885 def test_threshold(self): 886 for device in self.devices: 887 888 def f(x): 889 return torch.threshold(x, 0, -10) + x + x + x 890 891 x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device=device) 892 scripted = self.checkScript(f, (x,)) 893 self.assertAllFused(scripted.graph_for(x)) 894 895 def test_scalar_arg(self): 896 for device in self.devices: 897 898 def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor: 899 return p * (x * x + x) 900 901 x = torch.randn(4, 4, dtype=torch.float, device=device) 902 p = 3 903 scripted = self.checkScript(fn_test_scalar_arg, (x, p)) 904 self.assertAllFused(scripted.graph_for(x, p)) 905 906 x.requires_grad_(True) 907 908 # use another function otherwise we will bailout 909 # and won't be able to do fused checks 910 def fn_test_scalar_arg_requires_grad( 911 x: torch.Tensor, p: float 912 ) -> torch.Tensor: 913 return p * (x * x + x) 914 915 scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) 916 out = scripted(x, p) 917 out = scripted(x, p) 918 out = scripted(x, p) 919 self.assertAllFused( 920 scripted.graph_for(x, p), 921 except_for=( 922 "aten::size", 923 "prim::BroadcastSizes", 924 "aten::_size_if_not_equal", 925 ), 926 ) 927 928 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 929 @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") 930 def test_fusion_reuse_multi_gpu(self): 931 def fn(x, y): 932 return x * y * x * y 933 934 inputs_cpu = [ 935 torch.randn(4, 4, dtype=torch.float), 936 torch.randn(4, 4, dtype=torch.float), 937 ] 938 inputs_cuda0 = [x.cuda(0) for x in inputs_cpu] 939 inputs_cuda1 = [y.cuda(1) for y in inputs_cpu] 940 941 # Should not crash; these should compile different kernels. 942 ge = self.checkScript(fn, inputs_cpu) 943 self.assertAllFused(ge.graph_for(*inputs_cpu)) 944 ge(*inputs_cuda0) 945 ge(*inputs_cuda1) 946 947 # TODO: we're currently not checking 'device' in the type info when pulling 948 # nodes into a fusion group. We should fix that and re-enable this test. 949 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 950 @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") 951 def test_kernel_cache_multi_gpu(self): 952 def not_fusible(x): 953 return x 954 955 def fn(x, y, z): 956 x_out = x * x * x * x * x # fusion: lambda x. x * x * x * x * x 957 y_out = y * y * y * y * y 958 z_out = z * z * z * z * z 959 return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out) 960 961 inputs = [ 962 torch.randn(4, 4, dtype=torch.float), 963 torch.randn(4, 4, dtype=torch.float, device="cuda:0"), 964 torch.randn(4, 4, dtype=torch.float, device="cuda:1"), 965 ] 966 967 prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() 968 969 # There are 3 FusionGroups. Because they have the same graph, they 970 # should reuse the same KernelSpec in the KernelSpec cache. 971 ge = self.checkScript(fn, inputs) 972 self.assertGraphContainsExactly(ge.graph_for(*inputs), FUSION_GROUP, 3, True) 973 new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() 974 # XXX: This assumes that the same kernel isn't already used by another test 975 # FIXME: Use the TE fuser's way of querying the cache. 976 # self.assertEqual(new_cache_size - prev_cache_size, 1) 977 978 @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") 979 def test_nonzero_device_cuda(self): 980 device = "cuda:" + str(1) 981 x = torch.tensor([0.4], dtype=torch.float, device=device) 982 y = torch.tensor([0.7], dtype=torch.float, device=device) 983 984 def doit(x, y): 985 return torch.sigmoid(torch.tanh(x * (x + y) + x)) 986 987 ge = self.checkTrace(doit, (x, y)) 988 self.assertAllFused(ge.graph_for(x, y)) 989 990 def test_lstm(self): 991 for device in self.devices: 992 inputs = get_lstm_inputs(device, training=True) 993 module = self.checkScript(LSTMCellS, inputs) 994 self.assertAllFused( 995 module.graph_for(inputs), except_for={"prim::TupleConstruct"} 996 ) 997 998 def test_lstm_concat(self): 999 # single fusion node causes error 1000 with set_fusion_group_inlining(True): 1001 for device in self.devices: 1002 inputs = get_lstm_inputs(device) 1003 ge = self.checkTrace(LSTMCellC, inputs) 1004 graph = ge.graph_for(*inputs) 1005 except_nodes = {"prim::TupleConstruct", "aten::linear"} 1006 # TODO... Chunk 1007 if self.dynamic_shapes: 1008 except_nodes = except_nodes.union( 1009 {"aten::add", "prim::ConstantChunk"} 1010 ) 1011 self.assertAllFused(ge.graph_for(*inputs), except_for=except_nodes) 1012 # XXX: TE fuser can handle concats inside a fusion group. 1013 # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) 1014 1015 def test_lstm_gates_permutations(self): 1016 for device in self.devices: 1017 # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh. 1018 # Test that any permutation of this will still result in one FusionGroup. 1019 choices = ["x.mm(w_ih.t())", "hx.mm(w_hh.t())", "b_ih", "b_hh"] 1020 template = dedent( 1021 """ 1022 def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): 1023 gates = {} + {} + {} + {} 1024 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 1025 return ingate * forgetgate * cellgate * outgate 1026 """ 1027 ) 1028 for permutation in permutations(choices, len(choices)): 1029 code = template.format(*permutation) 1030 scope = {} 1031 exec(code, globals(), scope) 1032 cu = torch.jit.CompilationUnit(code) 1033 fusion_group_len = 2 if self.dynamic_shapes else 1 1034 inputs = get_lstm_inputs(device, training=False) 1035 self.assertEqual(cu.cell(*inputs), scope["cell"](*inputs)) 1036 forward_graph = cu.cell.graph_for(*inputs) 1037 self.assertGraphContainsExactly( 1038 forward_graph, FUSION_GROUP, fusion_group_len 1039 ) 1040 1041 # TODO: Fuser doesn't work at all when inputs require grad. Fix that 1042 def test_lstm_traced(self): 1043 for device in self.devices: 1044 inputs = get_lstm_inputs(device) 1045 ge = self.checkTrace(LSTMCellF, inputs) 1046 graph = ge.graph_for(*inputs) 1047 fusion_groups = self.findFusionGroups(graph) 1048 # TODO: chunk 1049 fusion_group_len = 2 if self.dynamic_shapes else 1 1050 self.assertEqual(len(fusion_groups), fusion_group_len) 1051 f = FileCheck() 1052 if not self.dynamic_shapes: 1053 f.check("Chunk") 1054 f.check("aten::sigmoid").check("aten::tanh").run( 1055 str(fusion_groups[0 if not self.dynamic_shapes else 1]) 1056 ) 1057 1058 def test_milstm(self): 1059 if self.dynamic_shapes: 1060 self.skipTest("don't run conv with dynamic shapes") 1061 1062 for device in self.devices: 1063 inputs = get_milstm_inputs(device, training=True) 1064 module = self.checkScript(MiLSTMCell, inputs) 1065 forward_graph = module.graph_for(*inputs) 1066 # TODO: chunk 1067 fusion_group_len = 2 if self.dynamic_shapes else 1 1068 self.assertGraphContainsExactly( 1069 forward_graph, FUSION_GROUP, fusion_group_len, consider_subgraphs=True 1070 ) 1071 FileCheck().check("DifferentiableGraph").check("TupleConstruct").check_next( 1072 "return" 1073 ).check(FUSION_GROUP).run(str(forward_graph)) 1074 hy, cy = module(*inputs) 1075 warmup_backward((hy + cy).sum()) 1076 1077 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 1078 @unittest.skip("rand_like is not supported yet") 1079 def test_rand_cuda(self): 1080 class M(torch.jit.ScriptModule): 1081 __constants__ = ["d"] 1082 1083 def __init__(self) -> None: 1084 super().__init__() 1085 self.d = torch.device("cuda") 1086 1087 @torch.jit.script_method 1088 def create(self, x): 1089 return x * x + x + torch.rand_like(x) 1090 1091 x = torch.zeros([3, 4, 5], dtype=torch.float, device="cuda") 1092 m = M() 1093 out1 = m.create(x) 1094 out2 = m.create(x) 1095 self.assertNotEqual(out1, out2) 1096 self.assertTrue(torch.all(out1 >= 0)) 1097 self.assertTrue(torch.all(out1 < 1)) 1098 self.assertTrue(torch.all(out2 >= 0)) 1099 self.assertTrue(torch.all(out2 < 1)) 1100 self.assertAllFused(m.create.graph_for(x)) 1101 1102 @staticmethod 1103 def fn_test_relu(x, y): 1104 return F.relu(x + 0.5 * y) 1105 1106 def test_relu(self): 1107 for device in self.devices: 1108 x = torch.randn(4, 4, dtype=torch.float, device=device) 1109 y = torch.randn(4, 4, dtype=torch.float, device=device) 1110 1111 ge = self.checkTrace(self.fn_test_relu, (x, y)) 1112 self.assertAllFused(ge.graph_for(x, y)) 1113 1114 def test_erf(self): 1115 for device in self.devices: 1116 # only enabled on gpu 1117 if device == "cpu": 1118 continue 1119 1120 def fn_test_erf(x): 1121 return F.relu(torch.erf(x) - torch.erfc(x)) 1122 1123 x = torch.randn(4, 4, dtype=torch.float, device=device) 1124 ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING) 1125 self.assertAllFused(ge.graph_for(x)) 1126 x.requires_grad_(True) 1127 ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING) 1128 self.assertAllFused( 1129 ge.graph_for(x), 1130 except_for=( 1131 "aten::size", 1132 "prim::BroadcastSizes", 1133 "aten::_size_if_not_equal", 1134 ), 1135 ) 1136 1137 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 1138 @unittest.skip("rand_like is not supported yet") 1139 def test_rand_broadcast_cuda(self): 1140 def fn_test_rand(x, y): 1141 r = torch.rand_like(y) 1142 return r * x + x 1143 1144 # If using profiling, a different function is needed to test different 1145 # shapes, or we'll use a cached script. 1146 def fn_test_rand2(x, y): 1147 r = torch.rand_like(y) 1148 return r * x * x 1149 1150 x = torch.randn(4, 4, dtype=torch.float, device="cuda") 1151 y = torch.randn(4, 4, dtype=torch.float, device="cuda") 1152 script_f = torch.jit.script(fn_test_rand) 1153 warmup_forward(script_f, x, y) 1154 out = script_f(x, y) 1155 self.assertAllFused(script_f.graph_for(x, y)) 1156 x.requires_grad_(True) 1157 out = script_f(x, y) 1158 self.assertAllFused( 1159 script_f.graph_for(x, y), 1160 except_for=( 1161 "aten::size", 1162 "prim::BroadcastSizes", 1163 "aten::_size_if_not_equal", 1164 ), 1165 ) 1166 1167 # test that broadcasting random produces correct results 1168 x = torch.ones(4, 4, dtype=torch.float, device="cuda") 1169 y = torch.ones(4, dtype=torch.float, device="cuda") 1170 script_f = torch.jit.script(fn_test_rand2) 1171 warmup_forward(script_f, x, y) 1172 out = script_f(x, y) 1173 self.assertEqual(out[0, :] + torch.zeros(4, 4, device="cuda"), out) 1174 1175 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 1176 @unittest.skip("rand_like is not supported yet") 1177 def test_rand_diamond(self): 1178 def fn_test_diamond(x, y): 1179 r = torch.rand_like(y) 1180 a = x + r 1181 b = y - r 1182 return a + b 1183 1184 x = torch.randn(4, 4, dtype=torch.float, device="cuda") 1185 y = torch.randn(4, 4, dtype=torch.float, device="cuda") 1186 script_f = torch.jit.script(fn_test_diamond) 1187 warmup_forward(script_f, x, y) 1188 out = script_f(x, y) 1189 self.assertEqual(out, x + y) 1190 1191 def test_scalar(self): 1192 def fn(x, y): 1193 return 2 * x + y 1194 1195 x = torch.tensor(0.1, dtype=torch.float, device="cpu") 1196 y = torch.tensor(1, dtype=torch.float, device="cpu") 1197 ge = self.checkScript(fn, (x, y)) 1198 self.assertAllFused(ge.graph_for(x, y)) 1199 1200 def test_inlined_optimized_graph(self): 1201 @torch.jit.script 1202 def foo(x): 1203 return torch.relu(x + x) 1204 1205 for _ in range(3): 1206 foo(torch.rand([4, 4])) 1207 1208 for _ in range(3): 1209 foo(torch.rand([10])) 1210 1211 for _ in range(3): 1212 foo(torch.rand([2, 2, 2])) 1213 1214 g = torch.jit.last_executed_optimized_graph() 1215 1216 FileCheck().check_count("prim::If", 1, exactly=True).check( 1217 "prim::TensorExpr" 1218 ).run(g) 1219 torch._C._jit_pass_inline(g) 1220 f = FileCheck() 1221 for _ in range(3): 1222 f.check("prim::If").check("prim::TensorExpr") 1223 f.run(g) 1224 1225 def test_small_constant(self): 1226 for device in self.devices: 1227 1228 def fn_test_small_constant(x, y): 1229 return (1e-8 * x + 5e-9 * y) * 1e8 1230 1231 x = torch.randn(4, 4, dtype=torch.float, device=device) 1232 y = torch.randn(4, 4, dtype=torch.float, device=device) 1233 1234 ge = self.checkTrace(fn_test_small_constant, (x, y)) 1235 self.assertAllFused(ge.graph_for(x, y)) 1236 1237 # Currently we don't pull constants into fusion groups, because in some 1238 # cases it could remove the constant from the original graph and now our 1239 # fusion group needs to return that constant for its other users. 1240 # Instead of never pulling constants into the fusion group, we should just 1241 # be more careful at how we rewrite its users. 1242 # TODO: fix that and reenable the test. 1243 def test_tensor_scalar_ops(self): 1244 for device in self.devices: 1245 1246 def should_fuse(x): 1247 z = 3.0 1248 y = x + z 1249 return x * y 1250 1251 def should_fuse_scalar(x, z): 1252 y = x + int(z) 1253 return x * y 1254 1255 inputs = [torch.randn(2, 2, dtype=torch.float, device=device)] 1256 ge = self.checkScript(should_fuse, inputs) 1257 graph = ge.graph_for(*inputs) 1258 fusion_groups = self.findFusionGroups(graph) 1259 self.assertEqual(len(fusion_groups), 1) 1260 FileCheck().check("aten::add").check("aten::mul").run(str(fusion_groups[0])) 1261 1262 inputs = [ 1263 torch.randn(2, 2, dtype=torch.float, device=device), 1264 torch.tensor(3.0, dtype=torch.float, device=device), 1265 ] 1266 ge = self.checkScript(should_fuse_scalar, inputs) 1267 # Check that the fused graph computes correct results when the scalar 1268 # input changes. 1269 inputs = [ 1270 torch.randn(2, 2, dtype=torch.float, device=device), 1271 torch.tensor(7.0, dtype=torch.float, device=device), 1272 ] 1273 self.assertEqual(ge(*inputs), should_fuse_scalar(*inputs)) 1274 # The TE fuser supports fusion of non-constant scalars 1275 self.assertGraphContainsExactly( 1276 ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True 1277 ) 1278 1279 def test_where_and_typing(self): 1280 for device in self.devices: 1281 1282 def f(x, y): 1283 mask = x > y 1284 res = torch.where(mask, x, y) 1285 return mask, res 1286 1287 x = torch.randn(4, 4, dtype=torch.double, device=device) 1288 y = torch.randn(4, 4, dtype=torch.double, device=device) 1289 1290 script_f = self.checkScript(f, (x, y)) 1291 self.assertAllFused( 1292 script_f.graph_for(x, y), except_for={"prim::TupleConstruct"} 1293 ) 1294 1295 def test_disabled(self): 1296 old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() 1297 torch._C._jit_override_can_fuse_on_cpu(False) 1298 1299 def fn(a): 1300 return a**2 + a 1301 1302 x = torch.randn(4, dtype=torch.float, device="cpu") 1303 s = self.checkScript(fn, (x,)) 1304 g = s.graph_for(x) 1305 self.assertEqual(len(self.findFusionGroups(g)), 0) 1306 1307 torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state) 1308 1309 def data_for(self, dtype, device="cuda", size=None): 1310 if size is None: 1311 v = torch.arange(1, 3, dtype=torch.float, device=device) 1312 else: 1313 v = torch.rand(*size, device=device) 1314 if dtype == torch.bool: 1315 return v > 2 1316 elif dtype in [torch.qint8, torch.quint8, torch.qint32]: 1317 return torch.quantize_per_tensor(v, 0.1, 1, dtype=dtype) 1318 else: 1319 return v.to(dtype) 1320 1321 def test_torch_to(self): 1322 # test no op 1323 @torch.jit.script 1324 def foo(x): 1325 return x.to(torch.float) 1326 1327 foo(torch.tensor([3.0], dtype=torch.float)) 1328 foo(torch.tensor([3.0], dtype=torch.float)) 1329 FileCheck().check_not("TensorExpr").run( 1330 torch.jit.last_executed_optimized_graph() 1331 ) 1332 1333 # test not fusing non-const inputs 1334 @torch.jit.script 1335 def foo(x, dtype: int): 1336 return x.to(dtype) 1337 1338 foo(torch.tensor([3.0], dtype=torch.float), torch.int) 1339 foo(torch.tensor([3.0], dtype=torch.float), torch.int) 1340 FileCheck().check_not("TensorExpr").run( 1341 torch.jit.last_executed_optimized_graph() 1342 ) 1343 1344 # test not fusing to_pinned inputs 1345 @torch.jit.script 1346 def foo(x, dtype: int): 1347 return x.to(pin_memory=True) 1348 1349 foo(torch.tensor([3.0], dtype=torch.float), torch.int) 1350 foo(torch.tensor([3.0], dtype=torch.float), torch.int) 1351 FileCheck().check_not("TensorExpr").run( 1352 torch.jit.last_executed_optimized_graph() 1353 ) 1354 1355 # test across-device not supported 1356 if torch.cuda.is_available(): 1357 1358 @torch.jit.script 1359 def foo(x): 1360 return x.to(device="cuda") 1361 1362 foo(torch.tensor([3.0], dtype=torch.float)) 1363 foo(torch.tensor([3.0], dtype=torch.float)) 1364 FileCheck().check_not("TensorExpr").run( 1365 torch.jit.last_executed_optimized_graph() 1366 ) 1367 1368 sizes = [(1, 4), (4, 4)] 1369 # reuses cast impl, smaller dtype set for faster test 1370 dtypes = [ 1371 torch.bool, 1372 torch.int, 1373 torch.float16, 1374 torch.float32, 1375 torch.float64, 1376 ] 1377 1378 class MyMod(torch.nn.Module): 1379 def __init__(self, dtype): 1380 super().__init__() 1381 self.dtype = dtype 1382 1383 def forward(self, x): 1384 return x.to(self.dtype) 1385 1386 bad_dtypes = [] 1387 for dtype, output_dtype, device, size in product( 1388 dtypes, dtypes, self.devices, sizes 1389 ): 1390 # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed 1391 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1392 continue 1393 if dtype == output_dtype: 1394 continue 1395 1396 x = self.data_for(dtype, device, size=size) 1397 mod = MyMod(output_dtype) 1398 ref = mod.forward(x) 1399 # use freezing to make non-Tensor args to `to` constant 1400 mod = torch.jit.freeze(torch.jit.script(mod.eval())) 1401 warmup_forward(mod.forward, x) 1402 self.assertEqual(ref, mod.forward(x)) 1403 self.assertLastGraphAllFused() 1404 1405 @unittest.skip("Temporarily disabled") 1406 def test_masked_fill(self): 1407 dtypes = [ 1408 torch.int8, 1409 torch.int16, 1410 torch.int32, 1411 torch.int64, 1412 # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed 1413 # torch.float16, 1414 torch.float32, 1415 torch.float64, 1416 torch.bool, 1417 ] 1418 sizes = [(2,), (4, 4)] 1419 for self_dtype, device, scalar_val, size in product( 1420 dtypes, self.devices, [0.4, 3], sizes 1421 ): 1422 input_v = self.data_for(self_dtype, device, size=size) 1423 mask = self.data_for(torch.bool, device, size=size) 1424 1425 def fn(input_v, mask): 1426 return torch.masked_fill(input_v, mask, scalar_val) 1427 1428 ref = fn(input_v, mask) 1429 try: 1430 t = torch.jit.trace(fn, (input_v, mask)) 1431 torch.testing.assert_close(ref, t(input_v, mask)) 1432 self.assertLastGraphAllFused() 1433 except Exception as e: 1434 raise RuntimeError( 1435 " ".join( 1436 [ 1437 "Failed:", 1438 str(self_dtype), 1439 op.__name__, # noqa: F821 1440 device, 1441 str(size), 1442 ] 1443 ) 1444 ) from e 1445 1446 def test_isnan(self): 1447 x = torch.rand([4]) 1448 x[0] = float("nan") 1449 inputs = [x, torch.tensor([float("nan"), 0.5])] 1450 dtypes = [ 1451 torch.int8, 1452 torch.int16, 1453 torch.int32, 1454 torch.int64, 1455 torch.float16, 1456 torch.float32, 1457 torch.float64, 1458 torch.bool, 1459 ] 1460 1461 for inp, device, dtype in product(inputs, self.devices, dtypes): 1462 # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed 1463 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1464 continue 1465 inp = inp.to(device=device, dtype=dtype) 1466 try: 1467 f = torch.jit.trace(lambda x: x.isnan(), (inp,)) 1468 warmup_forward(f, inp) 1469 self.assertEqual(f(inp), inp.isnan()) 1470 self.assertLastGraphAllFused() 1471 except Exception as e: 1472 raise RuntimeError( 1473 " ".join(["Failed:", str(dtype), "isnan", device]) 1474 ) from e 1475 1476 def test_gelu(self): 1477 def apply(fn): 1478 return lambda x, approximate: fn(x, approximate) 1479 1480 unary_ops = [ 1481 F.gelu, 1482 ] 1483 sizes = [(1,), (2,), (4, 4)] 1484 for dtype, op, device, size in product( 1485 self.dtypes, unary_ops, self.devices, sizes 1486 ): 1487 # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed 1488 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1489 continue 1490 try: 1491 x = self.data_for(dtype, device, size=size) 1492 cond = self.data_for(torch.bool, device) 1493 fn = apply(op) 1494 ref = fn(x, cond) 1495 except Exception: 1496 # If eager mode doesn't support a dtype/op/device combo, 1497 # neither does the fuser. Catch everything to avoid needing to 1498 # guess what errors might be thrown by eager. 1499 continue 1500 try: 1501 t = torch.jit.trace(fn, (x, cond)) 1502 torch.testing.assert_close(ref, t(x, cond)) 1503 self.assertAllFused(t.graph_for(x, cond)) 1504 except Exception as e: 1505 raise RuntimeError( 1506 " ".join(["Failed:", str(dtype), op.__name__, device, str(size)]) 1507 ) from e 1508 1509 def test_unary_ops(self): 1510 with torch._jit_internal._disable_emit_hooks(): 1511 1512 def apply(fn): 1513 return lambda x: fn(x) 1514 1515 unary_ops = [ 1516 torch.lgamma, 1517 torch.sigmoid, 1518 torch.reciprocal, 1519 torch.neg, 1520 torch.relu, 1521 F.relu6, 1522 torch.log, 1523 torch.log10, 1524 torch.log1p, 1525 torch.log2, 1526 torch.exp, 1527 torch.expm1, 1528 torch.erf, 1529 torch.erfc, 1530 torch.cos, 1531 torch.sin, 1532 torch.tan, 1533 torch.acos, 1534 torch.asin, 1535 torch.cosh, 1536 torch.sinh, 1537 torch.atan, 1538 torch.tanh, 1539 F.hardtanh, 1540 F.hardsigmoid, 1541 F.hardswish, 1542 F.softplus, 1543 F.silu, 1544 F.mish, 1545 F.elu, 1546 torch.sqrt, 1547 torch.rsqrt, 1548 torch.abs, 1549 # TODO broken on int8 since 1550 # https://github.com/pytorch/pytorch/pull/85144 1551 # RuntimeError: Invalid integral op_type: 23 1552 # torch.ceil, 1553 # torch.floor, 1554 # torch.round, 1555 # torch.trunc, 1556 torch.frac, 1557 # TODO: broken on ROCm? 1558 # F.hardshrink, 1559 F.leaky_relu, 1560 lambda x: torch.threshold(x, 0, -10), 1561 # TODO: broken since type promotion was added 1562 # lambda x: torch.clamp(x, -10, 10), 1563 ] 1564 gpu_only = {torch.erf, torch.erfc} 1565 sizes = [(1,), (2,), (4, 4)] 1566 for dtype, op, device, size in product( 1567 self.dtypes, unary_ops, self.devices, sizes 1568 ): 1569 # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed 1570 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1571 continue 1572 # todo - re-enable. fails with .500 1573 if dtype == torch.bfloat16 and op == torch.round: 1574 continue 1575 if op in gpu_only and device == "cpu": 1576 continue 1577 try: 1578 x = self.data_for(dtype, device, size=size) 1579 fn = apply(op) 1580 ref = fn(x) 1581 except Exception: 1582 # If eager mode doesn't support a dtype/op/device combo, 1583 # neither does the fuser. Catch everything to avoid needing to 1584 # guess what errors might be thrown by eager. 1585 continue 1586 try: 1587 t = torch.jit.trace(fn, (x,)) 1588 torch.testing.assert_close(ref, t(x)) 1589 self.assertAllFused(t.graph_for(x)) 1590 except Exception as e: 1591 raise RuntimeError( 1592 " ".join( 1593 ["Failed:", str(dtype), op.__name__, device, str(size)] 1594 ) 1595 ) from e 1596 1597 def test_binary_ops(self): 1598 def apply(fn): 1599 return lambda x, y: fn(x, y) 1600 1601 binary_ops = [ 1602 operator.__and__, 1603 operator.__or__, 1604 operator.__xor__, 1605 torch.add, 1606 torch.sub, 1607 torch.mul, 1608 torch.min, 1609 torch.max, 1610 lambda x, y: torch.lerp(x, y, 0.5), 1611 torch.atan2, 1612 torch.div, 1613 torch.eq, 1614 torch.ne, 1615 torch.ge, 1616 torch.gt, 1617 torch.lt, 1618 torch.fmod, 1619 torch.remainder, 1620 lambda x, y: y.type_as(x), 1621 ] 1622 fp_only = [ 1623 torch.fmod, 1624 torch.remainder, 1625 ] 1626 devices = self.devices 1627 for dtype, op, device in product(self.dtypes, binary_ops, devices): 1628 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1629 continue 1630 try: 1631 x = self.data_for(dtype, device) 1632 y = self.data_for(dtype, device) 1633 fn = apply(op) 1634 ref = fn(x, y) 1635 except Exception: 1636 # If eager mode doesn't support a dtype/op/device combo, 1637 # neither does the fuser. Catch everything to avoid needing to 1638 # guess what errors might be thrown by eager. 1639 continue 1640 try: 1641 t = torch.jit.trace(fn, (x, y)) 1642 self.assertEqual(ref, t(x, y)) 1643 if op not in fp_only or dtype.is_floating_point: 1644 self.assertAllFused(t.graph_for(x, y)) 1645 except Exception as e: 1646 raise RuntimeError( 1647 " ".join(["Failed:", str(dtype), op.__name__, device]) 1648 ) from e 1649 1650 def test_binary_scalar_ops(self): 1651 def apply(fn): 1652 return lambda x, y: fn(x, y) 1653 1654 ir_template = """ 1655 graph(%x : {dtype_x}, %y : {dtype_y}): 1656 %z = {op}(%x, %y) 1657 return (%z)""" 1658 1659 binary_ops = [ 1660 "aten::mul", 1661 "aten::add", 1662 "aten::sub", 1663 "aten::div", 1664 "aten::lt", 1665 "aten::le", 1666 "aten::eq", 1667 "aten::ne", 1668 "aten::gt", 1669 "aten::ge", 1670 "aten::__or__", 1671 "aten::__xor__", 1672 "aten::__and__", 1673 "aten::__lshift__", 1674 "aten::__rshift__", 1675 ] 1676 dtypes = ["int", "float", "bool"] 1677 values = {"int": [10, 3], "float": [12.34, 2.78], "bool": [True, False]} 1678 devices = self.devices 1679 for dtype_x, dtype_y, op, device in product( 1680 dtypes, dtypes, binary_ops, devices 1681 ): 1682 code = ir_template.format(**locals()) 1683 1684 # Interpret the graph 1685 try: 1686 graph = torch._C.parse_ir(code) 1687 for x, y in product(values[dtype_x], values[dtype_y]): 1688 ref = torch._C._jit_interpret_graph(graph, (x, y)) 1689 except Exception: 1690 # If we can't interpret this IR, don't bother checking NNC. 1691 continue 1692 1693 # Compile the graph 1694 try: 1695 k = torch._C._te.TensorExprKernel(graph) 1696 except Exception as e: 1697 raise RuntimeError( 1698 " ".join(["Compilation failed:", device, str(code)]) 1699 ) from e 1700 1701 # Run the graph 1702 for x, y in product(values[dtype_x], values[dtype_y]): 1703 ref = torch._C._jit_interpret_graph(graph, (x, y)) 1704 try: 1705 res = k.run((x, y)) 1706 self.assertEqual(ref, res) 1707 except Exception as e: 1708 raise RuntimeError( 1709 " ".join( 1710 ["Failed at runtime:", device, str(x), str(y), str(code)] 1711 ) 1712 ) from e 1713 1714 def test_matmul(self): 1715 if self.dynamic_shapes: 1716 self.skipTest("don't run conv with dynamic shapes") 1717 1718 def fn(x, y): 1719 return torch.matmul(x, y) 1720 1721 devices = ["cpu"] # No cuda support for ext calls yet 1722 sizes = [ 1723 [[128, 128], [128, 128]], 1724 [[10, 10], [10, 10]], 1725 [[1, 16], [16, 128]], 1726 [[128], [128]], 1727 [[128], [128, 128]], 1728 [[3], [3]], 1729 [[3, 4], [4]], 1730 [[10, 3, 4], [4]], 1731 [[10, 3, 4], [10, 4, 5]], 1732 [[10, 3, 4], [4, 5]], 1733 ] 1734 1735 # Only 2D x 2D matrix multiply is supported. For non-supported sizes we 1736 # still want to run results verification to test that we didn't 1737 # accidentally fuse it, but we skip the 'is-fused' check. 1738 # TODO: add support for other shape combinations and make this set empty: 1739 skip_is_fused_check_sizes = [ 1740 "[[128], [128]]", 1741 "[[128], [128, 128]]", 1742 "[[3], [3]]", 1743 "[[3, 4], [4]]", 1744 "[[10, 3, 4], [4]]", 1745 "[[10, 3, 4], [10, 4, 5]]", 1746 "[[10, 3, 4], [4, 5]]", 1747 ] 1748 for dtype, size, device in product(self.dtypes, sizes, devices): 1749 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1750 continue 1751 try: 1752 size_x, size_y = size 1753 x = self.data_for(dtype, device, size=size_x) 1754 y = self.data_for(dtype, device, size=size_y) 1755 ref = fn(x, y) 1756 except Exception as e: 1757 # If eager mode doesn't support a dtype/op/device combo, 1758 # neither does the fuser. Catch everything to avoid needing to 1759 # guess what errors might be thrown by eager. 1760 continue 1761 try: 1762 t = torch.jit.trace(fn, (x, y)) 1763 t(x, y) 1764 self.assertEqual(ref, t(x, y)) 1765 if str(size) not in skip_is_fused_check_sizes: 1766 self.assertAllFused(t.graph_for(x, y)) 1767 except Exception as e: 1768 raise RuntimeError(" ".join(["Failed:", str(dtype), device])) from e 1769 1770 def test_binary_tensor_scalar_ops(self): 1771 with torch._jit_internal._disable_emit_hooks(): 1772 1773 def apply_with_scalar(fn, scalar): 1774 return lambda x: fn(x, scalar) 1775 1776 # FIXME: Fails in IR Eval: torch.int64 and_ cpu 1777 binary_ops = [ 1778 operator.__and__, 1779 operator.__or__, 1780 operator.__xor__, 1781 torch.add, 1782 torch.sub, 1783 torch.mul, 1784 torch.eq, 1785 torch.ne, 1786 torch.ge, 1787 torch.lt, 1788 torch.gt, 1789 ] 1790 devices = self.devices 1791 # Maybe we should split this into separate tests to speed it up by 1792 # only using scalar values relevant to particular ops 1793 scalars = [1.5, 3, 0, -2.0, -1] 1794 for dtype, op, device, scalar in product( 1795 self.dtypes, binary_ops, devices, scalars 1796 ): 1797 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1798 continue 1799 try: 1800 x = self.data_for(dtype, device) 1801 fn = apply_with_scalar(op, scalar) 1802 ref = fn(x) 1803 except Exception: 1804 # If eager mode doesn't support a dtype/op/device combo, 1805 # neither does the fuser. Catch everything to avoid needing to 1806 # guess what errors might be thrown by eager. 1807 continue 1808 try: 1809 t = torch.jit.trace(fn, (x)) 1810 self.assertEqual(ref, t(x)) 1811 self.assertAllFused(t.graph_for(x)) 1812 except Exception as e: 1813 raise RuntimeError( 1814 " ".join(["Failed:", str(dtype), op.__name__, device]) 1815 ) from e 1816 1817 def test_binary_div_ops(self): 1818 def apply_with_scalar(fn, scalar): 1819 return lambda x: fn(x, scalar) 1820 1821 binary_ops = [ 1822 torch.div, 1823 torch.remainder, 1824 torch.fmod, 1825 ] 1826 devices = self.devices 1827 # Maybe we should split this into separate tests to speed it up by 1828 # only using scalar values relevant to particular ops 1829 scalars = [1.5, 3, -2.0, -1] # skip 0 1830 for dtype, op, device, scalar in product( 1831 self.dtypes, binary_ops, devices, scalars 1832 ): 1833 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1834 continue 1835 try: 1836 x = self.data_for(dtype, device) 1837 fn = apply_with_scalar(op, scalar) 1838 ref = fn(x) 1839 except Exception: 1840 # If eager mode doesn't support a dtype/op/device combo, 1841 # neither does the fuser. Catch everything to avoid needing to 1842 # guess what errors might be thrown by eager. 1843 continue 1844 try: 1845 t = torch.jit.trace(fn, (x)) 1846 self.assertEqual(ref, t(x)) 1847 except Exception as e: 1848 raise RuntimeError( 1849 f"Failed: {dtype} {op.__name__} {device} {scalar}" 1850 ) from e 1851 1852 def test_binary_pow(self): 1853 def apply_with_scalar(fn, scalar): 1854 return lambda x: fn(x, scalar) 1855 1856 dtypes = [ 1857 # FIXME: 'pow' fails with dtype=torch.float16/device=cuda/scalar=0 1858 # torch.float16, 1859 torch.float32, 1860 torch.float64, 1861 # torch.bool intentionally not included 1862 ] 1863 binary_ops = [ 1864 torch.pow, 1865 ] 1866 # Maybe we should split this into separate tests to speed it up by 1867 # only using scalar values relevant to particular ops 1868 scalars = [1.5, 3, 0, -2.0, -1] 1869 for dtype, op, device, scalar in product( 1870 dtypes, binary_ops, self.devices, scalars 1871 ): 1872 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1873 continue 1874 try: 1875 x = self.data_for(dtype, device) 1876 fn = apply_with_scalar(op, scalar) 1877 ref = fn(x) 1878 except Exception: 1879 # If eager mode doesn't support a dtype/op/device combo, 1880 # neither does the fuser. Catch everything to avoid needing to 1881 # guess what errors might be thrown by eager. 1882 continue 1883 try: 1884 t = torch.jit.trace(fn, (x)) 1885 self.assertEqual(ref, t(x)) 1886 self.assertAllFused(t.graph_for(x)) 1887 except Exception as e: 1888 raise RuntimeError( 1889 " ".join(["Failed:", str(dtype), op.__name__, device]) 1890 ) from e 1891 1892 def test_ternary_ops(self): 1893 def apply(fn): 1894 return lambda x, y, z: fn(x, y, z) 1895 1896 ternary_ops = [ 1897 torch.lerp, 1898 torch.addcmul, 1899 ] 1900 devices = self.devices 1901 for dtype, op, device in product(self.dtypes, ternary_ops, devices): 1902 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1903 continue 1904 try: 1905 x = self.data_for(dtype, device) 1906 y = self.data_for(dtype, device) 1907 z = self.data_for(dtype, device) 1908 fn = apply(op) 1909 ref = fn(x, y, z) 1910 except Exception: 1911 # If eager mode doesn't support a dtype/op/device combo, 1912 # neither does the fuser. Catch everything to avoid needing to 1913 # guess what errors might be thrown by eager. 1914 continue 1915 try: 1916 t = torch.jit.trace(fn, (x, y, z)) 1917 self.assertEqual(ref, t(x, y, z)) 1918 self.assertAllFused(t.graph_for(x, y, z)) 1919 except Exception as e: 1920 raise RuntimeError( 1921 " ".join(["Failed:", str(dtype), op.__name__, device]) 1922 ) from e 1923 1924 def test_ternary_norm_ops(self): 1925 def apply(fn): 1926 return lambda x, y, z: fn(x, y, z) 1927 1928 ternary_ops = [ 1929 F.batch_norm, 1930 ] 1931 devices = self.devices 1932 for dtype, op, device in product(self.dtypes, ternary_ops, devices): 1933 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1934 continue 1935 try: 1936 x = self.data_for(dtype, device, size=[5, 3, 128, 128]) 1937 y = self.data_for(dtype, device, size=[3]) 1938 z = self.data_for(dtype, device, size=[3]) 1939 fn = apply(op) 1940 ref = fn(x, y, z) 1941 except Exception: 1942 # If eager mode doesn't support a dtype/op/device combo, 1943 # neither does the fuser. Catch everything to avoid needing to 1944 # guess what errors might be thrown by eager. 1945 continue 1946 try: 1947 t = torch.jit.trace(fn, (x, y, z)) 1948 self.assertEqual(ref, t(x, y, z)) 1949 self.assertAllFused(t.graph_for(x, y, z)) 1950 except Exception as e: 1951 raise RuntimeError( 1952 " ".join(["Failed:", str(dtype), op.__name__, device]) 1953 ) from e 1954 1955 @unittest.skip( 1956 "FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure" 1957 ) 1958 def test_list_ops(self): 1959 def apply(fn): 1960 return lambda x, y, z: fn([x * x, y * y, z * z]) 1961 1962 devices = self.devices 1963 list_ops = [ 1964 torch.cat, 1965 ] 1966 for dtype, op, device in product(self.dtypes, list_ops, devices): 1967 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 1968 continue 1969 try: 1970 x = self.data_for(dtype, device, size=[5, 4, 1, 7]) 1971 y = self.data_for(dtype, device, size=[5, 4, 1, 7]) 1972 z = self.data_for(dtype, device, size=[5, 4, 1, 7]) 1973 fn = apply(op) 1974 ref = fn(x, y, z) 1975 except Exception: 1976 # If eager mode doesn't support a dtype/op/device combo, 1977 # neither does the fuser. Catch everything to avoid needing to 1978 # guess what errors might be thrown by eager. 1979 continue 1980 try: 1981 t = torch.jit.trace(fn, (x, y, z)) 1982 self.assertEqual(ref, t(x, y, z)) 1983 self.assertAllFused(t.graph_for(x, y, z)) 1984 except Exception as e: 1985 raise RuntimeError( 1986 " ".join(["Failed:", str(dtype), op.__name__, device]) 1987 ) from e 1988 1989 def test_where_ops(self): 1990 def apply(fn): 1991 return lambda cond, x, y: fn(cond, x, y) 1992 1993 ops = [ 1994 torch.where, 1995 lambda cond, x, y: torch.where(cond, x, 3.1415), 1996 lambda cond, x, y: torch.where(cond, 42, y), 1997 ] 1998 devices = self.devices 1999 for dtype, op, device in product(self.dtypes, ops, devices): 2000 if dtype in [torch.float16, torch.bfloat16] and device == "cpu": 2001 continue 2002 try: 2003 cond = self.data_for(torch.bool, device) 2004 x = self.data_for(dtype, device) 2005 y = self.data_for(dtype, device) 2006 fn = apply(op) 2007 ref = fn(cond, x, y) 2008 except Exception: 2009 # If eager mode doesn't support a dtype/op/device combo, 2010 # neither does the fuser. Catch everything to avoid needing to 2011 # guess what errors might be thrown by eager. 2012 continue 2013 try: 2014 t = torch.jit.trace(fn, (cond, x, y)) 2015 self.assertEqual(ref, t(cond, x, y)) 2016 self.assertAllFused(t.graph_for(cond, x, y)) 2017 except Exception as e: 2018 raise RuntimeError( 2019 " ".join(["Failed:", str(dtype), op.__name__, device]) 2020 ) from e 2021 2022 def test_unsupported_dtypes(self): 2023 for device in self.devices: 2024 2025 def fn(x): 2026 return x * x + x 2027 2028 unsupported_dtypes = [ 2029 torch.uint8, 2030 torch.complex32, 2031 torch.complex64, 2032 torch.complex128, 2033 torch.qint8, 2034 torch.quint8, 2035 torch.qint32, 2036 ] 2037 for dtype in unsupported_dtypes: 2038 try: 2039 x = self.data_for(dtype, device) 2040 ref = fn(x) 2041 except Exception: 2042 # If eager mode doesn't support a dtype/op/device combo, 2043 # neither does the fuser. Catch everything to avoid needing to 2044 # guess what errors might be thrown by eager. 2045 continue 2046 t = torch.jit.trace(fn, (x,)) 2047 self.assertEqual(ref, t(x)) 2048 self.assertEqual(len(self.findFusionGroups(t.graph_for(x))), 0) 2049 2050 def test_superslomo(self): 2051 devices = self.devices.copy() 2052 if not LLVM_ENABLED: 2053 devices.remove("cpu") 2054 for device in devices: 2055 # Test extracted from Super-SloMo: https://github.com/avinashpaliwal/Super-SloMo 2056 # A few interesting things happen here: strided inputs of mixed size, 2057 # plus outputs of mixed shapes. The latter characteristic happened to 2058 # expose a memory corruption bug due to not properly guarding the 2059 # outputs. 2060 def eager(t0, t1, t2, t3, t4): 2061 t5 = torch.mul(t0, t4) 2062 t6 = torch.mul(t2, t3) 2063 t7 = torch.mul(t6, t1) 2064 t9 = torch.add(t5, t7) 2065 t11 = torch.add(t0, t6) 2066 ft_p = torch.div(t9, t11) 2067 return (ft_p, t11, t9, t6) 2068 2069 t0 = torch.rand(1, 6, 352, 352, device=device).transpose(0, 1) 2070 t1 = torch.rand(6, 3, 352, 352, device=device) 2071 t2 = torch.rand(6, device=device)[None, None, None, :].permute(3, 0, 1, 2) 2072 t3 = torch.rand(6, 1, 352, 352, device=device) 2073 t4 = torch.rand(6, 3, 352, 352, device=device) 2074 inputs = [t0, t1, t2, t3, t4] 2075 2076 script = torch.jit.script(eager) 2077 for _ in range(4): 2078 for pair in zip(script(*inputs), eager(*inputs)): 2079 test, ref = pair 2080 torch.testing.assert_close(test, ref) 2081 self.assertAllFused( 2082 script.graph_for(*inputs), except_for={"prim::TupleConstruct"} 2083 ) 2084 2085 def test_sub_gt_and(self): 2086 for device in self.devices: 2087 2088 def eager(t1, t2, t3, t4, t: float): 2089 w = t1 - t2 2090 h = t3 - t4 2091 k = (w > t) & (h > t) 2092 assert k.dtype == torch.bool 2093 if t > 0.5: 2094 # Putting a use of k in a never-executed conditional prevents 2095 # profiling its type, which leaves it as "Tensor". If we 2096 # propagate Tensor back to the definition of k, we have to be 2097 # careful not to create a fusion group containing it. 2098 return k + 1 2099 return w 2100 2101 t = torch.rand(8, dtype=torch.float, device=device) 2102 scripted = self.checkScript(eager, (t, t, t, t, 0.1)) 2103 2104 @skipIfTorchDynamo("too slow") 2105 def test_chunk_mul_one(self): 2106 if self.dynamic_shapes: 2107 self.skipTest("TODO: chunk dynamic shapes") 2108 2109 for device in self.devices: 2110 2111 def eager(x): 2112 z, y, w = torch.chunk(x, 3, -1) 2113 return z * 3, y, w 2114 2115 x = torch.rand(64, 1, 3072, dtype=torch.float, device=device) 2116 z, y, w = eager(x) 2117 script = self.checkScript(eager, (x,)) 2118 2119 def test_eq_unsqueeze_type_as(self): 2120 for device in self.devices: 2121 2122 def eager(a, b): 2123 mask = b == 1 2124 mask = torch.unsqueeze(mask, -1) 2125 x = mask.type_as(a) 2126 return x, mask 2127 2128 a = torch.rand(1, 64, 1024, device=device, dtype=torch.float) 2129 b = torch.randint(-2, 2, (1, 64), device=device, dtype=torch.long) 2130 script = self.checkScript(eager, (a, b)) 2131 2132 def test_neg_pow(self): 2133 def eager_tt(a: torch.Tensor, b: torch.Tensor): 2134 return torch.neg(torch.pow(a, b)) 2135 2136 def eager_ts(a: torch.Tensor, b: float): 2137 return torch.neg(torch.pow(a, b)) 2138 2139 def eager_st(a: float, b: torch.Tensor): 2140 return torch.neg(torch.pow(a, b)) 2141 2142 a = torch.rand(1, dtype=torch.float) 2143 b = torch.rand(1, dtype=torch.float) 2144 s = b.item() 2145 script = self.checkScript(eager_tt, (a, b)) 2146 # TODO: re-enable fusion, which doesn't work right now. just test correctness for now 2147 # self.assertAllFused(script.graph_for(a, b)) 2148 script = self.checkScript(eager_ts, (a, s)) 2149 # self.assertAllFused(script.graph_for(a, s)) 2150 script = self.checkScript(eager_st, (s, b)) 2151 # self.assertAllFused(script.graph_for(s, b)) 2152 2153 @unittest.skipIf(not LLVM_ENABLED, "Too slow to run with the TE interpreter") 2154 def test_conv2d_depthwise(self): 2155 if self.dynamic_shapes: 2156 self.skipTest("don't run conv with dynamic shapes") 2157 2158 def eager(input, weight, bias): 2159 return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=72) 2160 2161 input = torch.rand((1, 72, 56, 56), dtype=torch.float) 2162 weight = torch.rand((72, 1, 3, 3), dtype=torch.float) 2163 bias = torch.rand((72), dtype=torch.float) 2164 2165 script = self.checkScript(eager, (input, weight, bias)) 2166 self.assertAllFused(script.graph_for(input, weight, bias)) 2167 2168 def test_conv2d(self): 2169 if self.dynamic_shapes: 2170 self.skipTest("don't run conv with dynamic shapes") 2171 2172 def eager(input, weight, bias): 2173 return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=1) 2174 2175 input = torch.rand((1, 64, 56, 56), dtype=torch.float) 2176 weight = torch.rand((64, 64, 3, 3), dtype=torch.float) 2177 bias = torch.rand((64), dtype=torch.float) 2178 2179 script = self.checkScript(eager, (input, weight, bias)) 2180 FileCheck().check_not("TensorExpr").run( 2181 torch.jit.last_executed_optimized_graph() 2182 ) 2183 2184 def test_type_as_cat(self): 2185 with inline_fusion_groups(): 2186 2187 def eager(x, y): 2188 return torch.cat((x, y.type_as(x)), dim=1) 2189 2190 dtypes = self.dtypes.copy() 2191 # CPU fuser doesn't support float16. 2192 dtypes.remove(torch.float16) 2193 dtypes.remove(torch.bfloat16) 2194 for dtype1, dtype2 in product(dtypes, dtypes): 2195 x = torch.randint(2, (1, 13)).to(dtype1) 2196 zero = torch.tensor([[0]]).to(dtype2) 2197 one = torch.tensor([[1]]).to(dtype2) 2198 script = torch.jit.trace(eager, (x, zero)) 2199 for _ in range(3): 2200 torch.testing.assert_close(script(x, zero), eager(x, zero)) 2201 torch.testing.assert_close(script(x, one), eager(x, one)) 2202 self.assertAllFused(script.graph_for(x, one)) 2203 2204 def test_to_device(self): 2205 def eager(x): 2206 return x.to(device="cpu").relu() 2207 2208 x = torch.rand(8) 2209 script = self.checkScript(eager, (x,)) 2210 self.assertAllFused(script.graph_for(x)) 2211 2212 def test_dims(self): 2213 def eager(x, y): 2214 return x / (y + 0.0001) 2215 2216 x = torch.linspace(-1, 1, 768, dtype=torch.float32).as_strided( 2217 (1, 1, 768), (768, 1, 1) 2218 ) 2219 y = torch.tensor([[[2.0]]], dtype=torch.float32) 2220 script = self.checkScript(eager, (x, y)) 2221 self.assertAllFused(script.graph_for(x, y)) 2222 2223 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 2224 def test_channels_last_dims_dynamic(self): 2225 def eager(x, y): 2226 return x + (y + 0.0001) 2227 2228 indices = [0, 1, 2, 3] 2229 sets = [] 2230 for i in range(0, len(indices) + 1): 2231 for subset in combinations(indices, i): 2232 sets.append(subset) # noqa: PERF402 2233 2234 for set in sets: 2235 size = [2, 3, 4, 5] 2236 for index in set: 2237 size[index] = 1 2238 inp = torch.rand(size).to(memory_format=torch.channels_last).cuda() 2239 with texpr_enable_strategy([("DYNAMIC", 20)]): 2240 foo_s = torch.jit.trace(eager, (inp, inp)) 2241 for _ in range(3): 2242 out = foo_s(inp, inp) 2243 out_eager = eager(inp, inp) 2244 self.assertEqual(out_eager, out) 2245 self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) 2246 g = torch.jit.last_executed_optimized_graph() 2247 FileCheck().check("TensorExpr").run(g) 2248 2249 def test_exhaust_specializations(self): 2250 with texpr_enable_strategy([("STATIC", 1)]): 2251 2252 @torch.jit.script 2253 def foo(x): 2254 return x + x + x 2255 2256 for _ in range(3): 2257 foo(torch.rand([2, 2])) 2258 2259 for _ in range(3): 2260 foo(torch.rand([4, 4, 4])) 2261 2262 g = torch.jit.last_executed_optimized_graph() 2263 torch._C._jit_pass_inline(g) 2264 2265 FileCheck().check_count("TensorExpr", 2, exactly=True).run(g) 2266 2267 def test_unsqueeze_var_dim(self): 2268 def eager(x, y, z: int): 2269 return x * torch.unsqueeze(y, dim=z) 2270 2271 x = torch.rand(4, 4, 64).permute(1, 0, 2) 2272 y = torch.rand(4, 4) 2273 z = 2 2274 script = self.checkScript(eager, (x, y, z)) 2275 2276 def _test_fwd_bwd(self, fn): 2277 x = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True) 2278 xs = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True) 2279 script = torch.jit.script(fn) 2280 for i in range(11): 2281 y = fn(x) 2282 g0 = torch.rand_like(y) 2283 y.backward(g0) 2284 2285 ys = script(xs) 2286 ys.backward(g0) 2287 2288 with torch.no_grad(): 2289 x -= 0.1 * x.grad 2290 xs -= 0.1 * xs.grad 2291 x.grad = None 2292 xs.grad = None 2293 torch.testing.assert_close(y, ys) 2294 2295 def test_relu_fwd_bwd(self): 2296 def eager(x): 2297 return torch.relu(x * 1.01) 2298 2299 self._test_fwd_bwd(eager) 2300 2301 def test_hardswish_fwd_bwd(self): 2302 def eager(x): 2303 return F.hardswish(x) * 1.01 2304 2305 self._test_fwd_bwd(eager) 2306 2307 def test_hardsigmoid_fwd_bwd(self): 2308 def eager(x): 2309 return F.hardsigmoid(x) * 1.01 2310 2311 self._test_fwd_bwd(eager) 2312 2313 def test_cat_graph_opt(self): 2314 def foo(x, y, z): 2315 return torch.log(torch.cat([x, y, z])) 2316 2317 self.checkScript( 2318 foo, (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 5])) 2319 ) 2320 # TODO: not sure why not updated graph isn't reflected in last_optimized_graph 2321 self.assertLastGraphAllFused() 2322 2323 def test_dynamic_cat(self): 2324 with inline_fusion_groups(): 2325 2326 @torch.jit.script 2327 def repro( 2328 xs: List[torch.Tensor], ys: List[torch.Tensor], zs: List[torch.Tensor] 2329 ): 2330 return [ 2331 torch.cat([x, torch.cat([y, z], dim=-1)], dim=-1) 2332 for x, y, z in zip(xs, ys, zs) 2333 ] 2334 2335 for _ in range(3): 2336 N = 3 2337 xs = [torch.ones(21) for _ in range(N)] 2338 # Note: concat of ys and zs will have the same size for each 2339 # pair, even though the individual ys and zs do not. 2340 ys = [torch.ones(N - i) for i in range(N)] 2341 zs = [torch.ones(i) for i in range(N)] 2342 repro(xs, ys, zs) 2343 2344 def test_scalar_only_inputs(self): 2345 def eager(b: float): 2346 a = torch.ones(1) 2347 return a * b 2348 2349 script = self.checkScript(eager, (1.0,)) 2350 2351 def test_cat_2k_args(self): 2352 with inline_fusion_groups(): 2353 2354 def eager(x): 2355 return torch.relu(torch.cat([x for _ in range(2000)])) 2356 2357 x = torch.randn(1) 2358 trace = self.checkTrace(eager, (x,)) 2359 fusion_groups = self.findFusionGroups(trace.graph_for(x)) 2360 self.assertEqual(len(fusion_groups), 0) 2361 2362 def test_adaptive_avg_pool2d(self): 2363 # TODO: once the adaptive_avg_pool2d is available in OpInfo DB, this 2364 # test should be moved there 2365 with inline_fusion_groups(): 2366 2367 def foo1(x): 2368 return torch.nn.functional.adaptive_avg_pool2d(x, (2, 2)) 2369 2370 def foo2(x): 2371 return torch.nn.functional.adaptive_avg_pool2d(x, (2)) 2372 2373 x = torch.randn(4, 4, 4) 2374 for foo in [foo1, foo2]: 2375 f = torch.jit.trace(foo, (x,)) 2376 kernel = torch._C._te.TensorExprKernel(f.graph) 2377 correct_val = f(x) 2378 self.assertEqual(kernel.run((x,)), correct_val) 2379 2380 def test_unrolled_cat(self): 2381 with inline_fusion_groups(): 2382 2383 def eager(x): 2384 ret = torch.empty(0) 2385 for i in range(x.shape[0]): 2386 ret = torch.cat([ret, x[i].relu()]) 2387 return ret 2388 2389 script = torch.jit.script(eager) 2390 2391 # Warm up with size=1 tensor; since the loop iterates once the 2392 # profile data will be "burned in" assuming size=1, and then 2393 # unrolled. 2394 x = torch.ones(1, 1) 2395 for _ in range(3): 2396 script(x) 2397 2398 torch.testing.assert_close(eager(x), script(x)) 2399 2400 # Now when an input hits the unrolled path, it will produce an 2401 # incorrectly-sized tensor, since size=1 has been burned in. 2402 x = torch.ones((8, 1)) 2403 torch.testing.assert_close(eager(x), script(x)) 2404 2405 @skipIfTorchDynamo("too slow") 2406 @unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan") 2407 @unittest.skipIf(TEST_WITH_ROCM, "Tensor-likes are not close for nans") 2408 def test_batch_norm(self): 2409 def test(fn, args): 2410 trace = torch.jit.trace(fn, args) 2411 self.assertAllFused(trace.graph_for(*args)) 2412 # TODO: Are `NaN`'s actually ok here or did this pass silently before, because `equal_nan=True` was the 2413 # default? 2414 torch.testing.assert_close(fn(*args), trace(*args), equal_nan=True) 2415 2416 def bn(i, x): 2417 return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu() 2418 2419 def bn_no_weight(i, x): 2420 return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu() 2421 2422 def bn_no_bias(i, x): 2423 return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu() 2424 2425 def bn_neither(i, x): 2426 return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu() 2427 2428 for device in self.devices: 2429 i = torch.randn(4, 16, 32, 40, device=device) 2430 x = torch.randn(16, device=device) 2431 for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]: 2432 test(fn, (i, x)) 2433 2434 def test_profiler(self): 2435 @torch.jit.script 2436 def test(x, y, z): 2437 return x * y + z 2438 2439 args = [torch.randn(4) for _ in range(3)] 2440 with torch.autograd.profiler.profile() as prof: 2441 for _ in range(3): 2442 test(*args) 2443 self.assertIn("fused_mul_add", prof.table()) 2444 2445 def test_skip_grad_in_check(self): 2446 @torch.jit.script 2447 def foo(x): 2448 return (x + 2) / 2 2449 2450 inp = torch.rand([4, 4]) 2451 for _ in range(3): 2452 foo(inp) 2453 2454 inp.requires_grad_(True) 2455 with torch.inference_mode(): 2456 for _ in range(3): 2457 foo(inp) 2458 g = torch.jit.last_executed_optimized_graph() 2459 torch._C._jit_pass_inline(g) 2460 torch._C._jit_pass_inline(g) 2461 FileCheck().check_count("prim::If", 1, exactly=True).run(g) 2462 2463 def test_dynamic_shapes(self): 2464 from functools import partial 2465 2466 n = 10 2467 2468 gen_tensor = ( 2469 lambda n: R(1, n), 2470 lambda n: R(n, n), 2471 lambda n: R(n, n).transpose(0, 1), 2472 lambda n: R(n + 1, n + 1, 2)[:n, n, 0], 2473 lambda n: R(n, n, 2)[:, :, 0], 2474 lambda n: R(n, n + 1, n + 2, n + 3).to(memory_format=torch.channels_last), 2475 ) 2476 2477 with texpr_enable_strategy([("DYNAMIC", 20)]): 2478 2479 def foo(x, y, z): 2480 return torch.sigmoid(torch.tanh(x)) 2481 2482 foo.__disable_jit_function_caching__ = True 2483 2484 def fi(x, y, z): 2485 return torch.tanh(x + y) 2486 2487 fi.__disable_jit_function_caching__ = True 2488 2489 def fum(x, y, z): 2490 return torch.tanh(x + y) + z 2491 2492 fum.__disable_jit_function_caching__ = True 2493 2494 funcs = [foo, fi, fum] 2495 with inline_fusion_groups(): 2496 for device in self.devices: 2497 I = partial(torch.randint, 0, 100, device=device) 2498 R = partial(torch.randn, device=device) 2499 2500 for i, func in enumerate(funcs): 2501 num_args = i + 1 2502 for j, gen in enumerate(gen_tensor): 2503 inps = (gen(n), gen(n), gen(n)) 2504 func_s = torch.jit.trace(func, inps, check_trace=False) 2505 torch._C._jit_pass_erase_shape_information(func_s.graph) 2506 for _ in range(2): 2507 x, y, z = gen(n), gen(n), gen(n) 2508 func_s(x, y, z) 2509 2510 for incr in range(3): 2511 func_s(*[gen(n + 1) for _ in range(3)]) 2512 2513 g = torch.jit.last_executed_optimized_graph() 2514 torch._C._jit_pass_inline(g) 2515 torch._C._jit_pass_dce(g) 2516 2517 # We should see only one optimized kernel 2518 FileCheck().check_count( 2519 "TensorExprDynamicGuard", 1, exactly=True 2520 ).run(g) 2521 self.assertEqual(func(*inps), func_s(*inps)) 2522 2523 gen = gen_tensor[0] 2524 inps = (gen(n), gen(n), gen(n)) 2525 foo_s = torch.jit.trace(foo, inps) 2526 torch._C._jit_pass_erase_shape_information(foo_s.graph) 2527 g_prev = None 2528 for gen in gen_tensor: 2529 for i in range(3): 2530 foo_s(*[gen(n + i) for _ in range(3)]) 2531 inps = (gen(n), gen(n), gen(n)) 2532 self.assertEqual(foo_s(*inps), foo(*inps)) 2533 g = torch.jit.last_executed_optimized_graph() 2534 torch._C._jit_pass_inline(g) 2535 torch._C._jit_pass_dce(g) 2536 FileCheck().check_count( 2537 "TensorExprDynamicGuard", len(gen_tensor), exactly=True 2538 ).run(g) 2539 2540 @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA") 2541 def test_autocast_up(self): 2542 def f(x): 2543 y = x._autocast_to_full_precision(True, True) 2544 z = torch.exp(y) 2545 return z 2546 2547 x = torch.rand((2, 2), dtype=torch.half, device="cuda") 2548 scr = torch.jit.script(f) 2549 scr(x) 2550 scr(x) 2551 self.assertLastGraphAllFused() 2552 2553 @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA") 2554 def test_autocast_down(self): 2555 def f(x): 2556 y = torch.sigmoid(x) 2557 z = y._autocast_to_reduced_precision(True, True, torch.half, torch.half) 2558 return z 2559 2560 x = torch.rand((2, 2), dtype=torch.float, device="cuda") 2561 scr = torch.jit.script(f) 2562 scr(x) 2563 scr(x) 2564 self.assertLastGraphAllFused() 2565 2566 @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") 2567 def test_to_dtype(self): 2568 def f(x): 2569 y = torch.sigmoid(x) 2570 z = y._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16) 2571 h = z._autocast_to_full_precision(True, True) 2572 i = h.to(dtype=torch.bfloat16) 2573 j = i.to(dtype=torch.float32) 2574 return j 2575 2576 x = torch.rand((2, 2), dtype=torch.float32) 2577 scr = torch.jit.trace(f, x) 2578 scr(x) 2579 scr(x) 2580 self.assertLastGraphAllFused() 2581 self.assertEqual(f(x), scr(x), atol=4e-3, rtol=4e-3) 2582 2583 bf_x = torch.rand((2, 2), dtype=torch.bfloat16) 2584 bf_scr = torch.jit.trace(f, bf_x) 2585 bf_scr(bf_x) 2586 bf_scr(bf_x) 2587 graph = bf_scr.graph_for(bf_x) 2588 fusion_groups = self.findFusionGroups(graph) 2589 self.assertEqual(len(fusion_groups), 2) 2590 self.assertEqual(f(bf_x), bf_scr(bf_x), atol=4e-3, rtol=4e-3) 2591 2592 def test_with_strict_fusion(self): 2593 def success(x): 2594 with torch.jit.strict_fusion(): 2595 return x + x + x 2596 2597 scripted = self.checkScript(success, (torch.rand([4]),)) 2598 g = torch.jit.last_executed_optimized_graph() 2599 FileCheck().check_not("aten::add").check("prim::TensorExprGroup").run(g) 2600 2601 def foo(x): 2602 with torch.jit.strict_fusion(): 2603 return x + x + torch.rand([4]) + 3 2604 2605 with self.assertRaises(Exception) as error_out: 2606 foo_s = torch.jit.script(foo) 2607 foo_s(torch.rand([4])) 2608 foo_s(torch.rand([4])) 2609 print(torch.jit.last_executed_optimized_graph()) 2610 fc = FileCheck().check("Found unfused operators") 2611 fc.check("aten::rand(SymInt[] size") 2612 fc.check("torch.rand([4]").run(str(error_out.exception)) 2613 2614 with warnings.catch_warnings(record=True) as warns: 2615 foo(torch.rand([4])) 2616 2617 FileCheck().check("Only works in script mode").run(str(warns[0])) 2618 2619 def test_autodiff(x): 2620 with torch.jit.strict_fusion(): 2621 return torch.rand([4]) + x + x + x 2622 2623 foo_s = torch.jit.script(test_autodiff) 2624 inp = torch.rand([4], requires_grad=True) 2625 with self.assertRaises(Exception) as error_out: 2626 for _ in range(3): 2627 foo_s(inp) 2628 f = FileCheck().check("unfused operators").check("aten::rand") 2629 f.run(str(error_out.exception)) 2630 2631 def test_separate_fusions(x, y): 2632 with torch.jit.strict_fusion(): 2633 return x + x + x, y + y + y 2634 2635 inp = torch.rand([4], requires_grad=True) 2636 with self.assertRaises(Exception) as error_out: 2637 for _ in range(3): 2638 foo_s = torch.jit.script(test_separate_fusions) 2639 foo_s(inp, inp) 2640 2641 f = FileCheck().check("Found multiple fusions") 2642 f.run(str(error_out.exception)) 2643 2644 def test_constant_chunk_shapes(self): 2645 # We had an issue where buildShapeExpressions would fail as show below: 2646 # 2647 # %1 : Tensor = Constant[..] # not supported, we don't build this shape 2648 # %2 : Tensor = Constant[..] # not supported 2649 # %3 : Tensor = aten::add(%1, %2) # inputs not supported, we don't build shape 2650 # ... = prim::ConstantChunk[..](%3) # it forgets to check whether input shapes exist, and fails 2651 if self.dynamic_shapes: 2652 self.skipTest("TODO: chunk dynamic shapes") 2653 2654 for device in self.devices: 2655 2656 def f(x, y): 2657 r = torch.tensor(4) 2658 z1, z2 = (x + y + r).chunk(2, dim=1) 2659 return z1 * z2 2660 2661 x = torch.randn(4, 4, dtype=torch.float, device=device) 2662 y = torch.randn(4, 4, dtype=torch.float, device=device) 2663 2664 ge = self.checkTrace(f, (x, y)) 2665 graph = ge.graph_for(x, y) 2666 2667 # make sure that we are actually testing the right scenario 2668 FileCheck().check("with " + FUSION_GROUP + "_").check_count( 2669 "ConstantChunk", 1, exactly=True 2670 ).run(str(graph)) 2671 2672 f_traced = torch.jit.trace(f, (x, y)) 2673 2674 for i in range(4): 2675 # make sure this doesn't error out 2676 res = f_traced(x, y) 2677 2678 self.assertEqual(res, f(x, y)) 2679 2680 @unittest.skipIf(not RUN_CUDA_HALF, "half-precision NNC fusion requires CUDA") 2681 def test_pow_multiple_dtype(self): 2682 # https://github.com/pytorch/pytorch/issues/75476 2683 def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor: 2684 p = torch.sigmoid(p) 2685 result = p**gamma 2686 return result 2687 2688 x = torch.rand((2, 2), dtype=torch.half, device="cuda") 2689 2690 ref = fn(x) 2691 2692 script_fn = torch.jit.script(fn) 2693 for i in range(4): 2694 res = script_fn(x) 2695 2696 self.assertEqual(ref, res) 2697 2698 2699class TestTEFuserStatic(TestTEFuser): 2700 dynamic_shapes = False 2701 2702 2703class TestTEFuserDynamic(TestTEFuser): 2704 dynamic_shapes = True 2705 2706 2707del TestTEFuser 2708 2709works_list = [ 2710 "__radd__", 2711 "__rdiv__", 2712 "__rmul__", 2713 "__rmod__", 2714 "abs", 2715 "acos", 2716 "add", 2717 "addcmul", 2718 "addmm.decomposed", 2719 "asin", 2720 "atan", 2721 "atan2", 2722 "ceil", 2723 "clamp", 2724 "clamp.scalar", 2725 "contiguous", 2726 "cos", 2727 "cosh", 2728 "div.no_rounding_mode", 2729 "div.true_rounding", 2730 "div.floor_rounding", 2731 "div.trunc_rounding", 2732 "eq", 2733 "erf", 2734 "erfc", 2735 "exp", 2736 "expand", 2737 "expand_as", 2738 "expm1", 2739 "floor", 2740 "fmod", 2741 "fmod.autodiffed", 2742 "ge", 2743 "gt", 2744 "isnan", 2745 "le", 2746 "lerp", 2747 "lgamma", 2748 "log", 2749 "log10", 2750 "log1p", 2751 "log2", 2752 "lt", 2753 "masked_fill", 2754 "max.binary", 2755 "mean", 2756 "min.binary", 2757 "mm", 2758 "mul", 2759 "ne", 2760 "neg", 2761 "nn.functional.hardshrink", 2762 "nn.functional.hardsigmoid", 2763 "nn.functional.hardswish", 2764 "nn.functional.softplus", 2765 "nn.functional.hardtanh", 2766 "nn.functional.leaky_relu", 2767 "nn.functional.relu", 2768 "nn.functional.relu6", 2769 "nn.functional.softsign", 2770 "nn.functional.tanhshrink", 2771 "nn.functional.threshold", 2772 "permute", 2773 "pow", 2774 "reciprocal", 2775 "remainder", 2776 "remainder.autodiffed", 2777 "reshape", 2778 "reshape_as", 2779 "round", 2780 "rsub", 2781 "rsub.rsub_tensor", 2782 "rsqrt", 2783 "sigmoid", 2784 "sign", 2785 "sin", 2786 "sinh", 2787 "sqrt", 2788 "sub", 2789 "sum", 2790 "t", 2791 "tan", 2792 "tanh", 2793 "transpose", 2794 "true_divide", 2795 "trunc", 2796 "unsqueeze", 2797 "view", 2798 "view_as", 2799 "where", 2800 "bool", 2801 "byte", 2802 "char", 2803 "double", 2804 "float", 2805 "half", 2806 "int", 2807 "long", 2808 "short", 2809 "bool.channels_last", 2810 "byte.channels_last", 2811 "char.channels_last", 2812 "double.channels_last", 2813 "float.channels_last", 2814 "half.channels_last", 2815 "int.channels_last", 2816 "long.channels_last", 2817 "short.channels_last", 2818] 2819 2820known_failures = [ 2821 "__rmatmul__", 2822 "frac", 2823 "matmul", 2824] 2825 2826# If your OpInfo test causes this test to fail, add it here 2827skip_ops = ["conj"] 2828 2829 2830def get_name(op): 2831 l = [op.name] 2832 if op.variant_test_name != "": 2833 l.append(op.variant_test_name) 2834 return ".".join(l) 2835 2836 2837# Purpose of this class is to allow super() calls. 2838# super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works. 2839# super(TestNNCOpInfo, self) fails because TestNNCOpInfo gets deleted from global scope. 2840# super(JitCommonTestCase, self).fn() would skip JitCommonTestCase.fn() implementation 2841class TestNNCOpInfoParent(JitCommonTestCase): 2842 pass 2843 2844 2845class TestNNCOpInfo(TestNNCOpInfoParent): 2846 def setUp(self): 2847 super(TestNNCOpInfoParent, self).setUp() 2848 self.tensorexpr_options = TensorExprTestOptions() 2849 2850 def tearDown(self): 2851 self.tensorexpr_options.restore() 2852 super(TestNNCOpInfoParent, self).tearDown() 2853 2854 def te_compile(self, device, dtype, op): 2855 if op.name in skip_ops: 2856 return 2857 sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) 2858 for sample_input in sample_inputs_itr: 2859 arg_values = [sample_input.input] + list(sample_input.args) 2860 kwarg_values = sample_input.kwargs 2861 param_names = [] 2862 param_values = [] 2863 fx_args = [] 2864 for idx, v in enumerate(arg_values): 2865 if isinstance(v, torch.Tensor): 2866 param_names.append(f"arg_{idx}") 2867 param_values.append(v) 2868 fx_args.append(param_names[-1]) 2869 else: 2870 fx_args.append(f"{repr(v)}") 2871 2872 for k, v in kwarg_values.items(): 2873 if isinstance(v, torch.Tensor): 2874 param_names.append(k) 2875 param_values.append(v) 2876 fx_args.append(f"{k} = {k}") 2877 else: 2878 fx_args.append(f"{k} = {repr(v)}") 2879 2880 code = f""" 2881def f({', '.join(param_names)}): 2882 return op.op({', '.join(fx_args)})""" 2883 g = {"torch": torch, "inf": math.inf, "op": op} 2884 exec(code, g) 2885 f = g["f"] 2886 f.__module__ = "test" 2887 out = f(*param_values) 2888 2889 ts_g = torch.jit.trace(f, param_values) 2890 kernel = torch._C._te.TensorExprKernel(ts_g.graph) 2891 correct_val = f(*param_values) 2892 self.assertEqual(kernel.run(tuple(param_values)), correct_val) 2893 self.assertEqual(kernel.fallback(tuple(param_values)), correct_val) 2894 2895 @onlyCPU 2896 @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") 2897 @ops( 2898 [op for op in op_db if get_name(op) in works_list], 2899 allowed_dtypes=(torch.float,), 2900 ) 2901 def test_working(self, device, dtype, op): 2902 self.te_compile(device, dtype, op) 2903 2904 @onlyCPU 2905 @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") 2906 @ops( 2907 [op for op in op_db if get_name(op) in known_failures], 2908 allowed_dtypes=(torch.float,), 2909 ) 2910 def test_failures(self, device, dtype, op): 2911 try: 2912 self.te_compile(device, dtype, op) 2913 except Exception as e: 2914 pass 2915 else: 2916 raise RuntimeError( 2917 "Expected test to fail. If it now works, move op into works_list" 2918 ) 2919 2920 @onlyCPU 2921 @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") 2922 @ops( 2923 [op for op in op_db if get_name(op) not in works_list + known_failures], 2924 allowed_dtypes=(torch.float,), 2925 ) 2926 def test_unsupported(self, device, dtype, op): 2927 if get_name(op) in skip_ops: 2928 return 2929 try: 2930 with warnings.catch_warnings(): 2931 warnings.simplefilter("ignore", TracerWarning) # noqa: F821 2932 self.te_compile(device, dtype, op) 2933 except Exception as e: 2934 pass 2935 else: 2936 raise RuntimeError( 2937 "Expected test to fail. If it now works, move op into works_list" 2938 ) 2939 2940 @slowTest 2941 @onlyCPU 2942 @ops(op_db, dtypes=OpDTypes.supported) 2943 def test_nnc_correctness(self, device, dtype, op): 2944 if not op.supports_tracing: 2945 self.skipTest("Requires tracing support") 2946 2947 with NoTracerWarnContextManager() as no_warn: 2948 variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op) 2949 2950 for variant, sample in variant_sample_pairs: 2951 trace = create_traced_fn(self, variant, cache_traced_fn=True) 2952 ref = variant( 2953 *clone_inputs((sample.input, *sample.args)), **sample.kwargs 2954 ) 2955 2956 trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) 2957 val = trace( 2958 *clone_inputs((sample.input, *sample.args)), **sample.kwargs 2959 ) 2960 2961 atol = 2e-1 if dtype == torch.bfloat16 else 1e-5 2962 rtol = 2e-1 if dtype == torch.bfloat16 else 1e-5 2963 self.assertEqual(ref, val, atol=atol, rtol=rtol) 2964 2965 # https://github.com/pytorch/pytorch/issues/35600 2966 # each torch.jit.trace adds state to the _python_cu compilation unit 2967 # since this test traces a lot of functions, out-of-memory can occur 2968 # if the CU is not cleared. 2969 torch.jit._state._python_cu.drop_all_functions() 2970 2971 2972# CPU fuser not currently used in fbcode 2973only_for = ("cuda") if IS_FBCODE else ("cpu", "cuda") 2974instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for) 2975 2976 2977# Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent) 2978class TestLoopnestRandomizationParent(JitTestCase): 2979 pass 2980 2981 2982class TestLoopnestRandomization(TestLoopnestRandomizationParent): 2983 def setUp(self): 2984 super(TestLoopnestRandomizationParent, self).setUp() 2985 self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() 2986 self.old_must_use_cpu_state = torch._C._jit_get_te_must_use_llvm_cpu() 2987 self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu() 2988 2989 torch._C._jit_override_can_fuse_on_cpu(True) 2990 # TODO: force LLVM. need to add it to asan, mac, windows builds + sandcastle 2991 # torch._C._jit_set_te_must_use_llvm_cpu(True) 2992 torch._C._jit_override_can_fuse_on_gpu(True) 2993 2994 self.old_profiling_executor = torch._C._jit_set_profiling_executor(True) 2995 self.old_profiling_mode = torch._C._get_graph_executor_optimize(True) 2996 2997 self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() 2998 torch._C._debug_set_fusion_group_inlining(False) 2999 3000 self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() 3001 torch._C._jit_set_texpr_fuser_enabled(True) 3002 3003 self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu() 3004 torch._C._jit_set_te_must_use_llvm_cpu(False) 3005 3006 # Set the seed to 1. This tests the codepath through random 3007 # transformation. 3008 os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "1" 3009 3010 def tearDown(self): 3011 torch._C._jit_set_profiling_executor(self.old_profiling_executor) 3012 torch._C._get_graph_executor_optimize(self.old_profiling_mode) 3013 3014 torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state) 3015 torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) 3016 torch._C._jit_set_te_must_use_llvm_cpu(self.old_must_use_cpu_state) 3017 torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) 3018 3019 torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) 3020 torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu) 3021 3022 # Set it back to 0. 3023 os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "0" 3024 super(TestLoopnestRandomizationParent, self).tearDown() 3025 3026 @onlyCPU 3027 @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") 3028 def test_relu(self, device): 3029 def fn_test_relu(x, y): 3030 return F.relu(x + 0.5 * y) 3031 3032 x = torch.randn(4, 4, dtype=torch.float, device=device) 3033 y = torch.randn(4, 4, dtype=torch.float, device=device) 3034 3035 fn = fn_test_relu 3036 traced_fn = torch.jit.trace(fn, (x, y)) 3037 3038 ref = fn(x, y) 3039 res = traced_fn(x, y) 3040 assert torch.allclose(ref, res) 3041 3042 3043instantiate_device_type_tests(TestLoopnestRandomization, globals(), only_for=("cpu")) 3044 3045 3046if __name__ == "__main__": 3047 run_tests() 3048