1# Owner(s): ["oncall: jit"] 2 3import unittest 4import os 5import sys 6import torch 7import torch.nn as nn 8import torch.nn.functional as F 9from torch.testing import FileCheck 10from unittest import skipIf 11 12from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \ 13 enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell 14from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \ 15 RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward 16from textwrap import dedent 17from itertools import product, permutations 18from torch.testing._internal.common_cuda import with_tf32_off 19 20from test_jit import backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \ 21 LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell 22 23if GRAPH_EXECUTOR == ProfilingMode.PROFILING: 24 torch._C._jit_set_profiling_executor(True) 25 torch._C._jit_set_profiling_mode(True) 26 27 28def strip_profiling_nodes(nodes): 29 profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'} 30 return [n for n in nodes if n.kind() not in profiling_opcodes] 31 32 33def warmup_forward(f, *args): 34 profiling_count = 2 35 for i in range(profiling_count): 36 results = f(*args) 37 38 return results 39 40 41@skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "skip due to SIGIOT failures, #67646") 42class TestFuser(JitTestCase): 43 def assertAllFused(self, graph, except_for=()): 44 45 diff_graphs = [n for n in graph.nodes() if n.kind() == 'prim::DifferentiableGraph'] 46 if len(diff_graphs) > 0: 47 self.assertEqual(len(diff_graphs), 1) 48 graph = diff_graphs[0].g('Subgraph') 49 50 allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::BailoutTemplate', 51 'prim::BailOut', 'prim::TupleConstruct'} | set(except_for) 52 self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()), 53 f'got {graph}') 54 self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1) 55 56 def _test_fused_abs(self, device='cpu'): 57 def func(x): 58 return x.abs() * 2 59 60 a = torch.randn(5, device=device) 61 scripted = self.checkScript(func, (a,)) 62 self.assertAllFused(scripted.graph_for(a)) 63 64 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") 65 @enable_cpu_fuser 66 def test_abs_cpu(self): 67 self._test_fused_abs() 68 69 @unittest.skipIf(not IS_WINDOWS, "This is meant to be Windows-specific") 70 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") 71 @enable_cpu_fuser 72 def test_abs_cpu_unicode_temp_dir(self): 73 with TemporaryDirectoryName(suffix='\u4e2d\u6587') as dname: 74 shell_env = os.environ.copy() 75 shell_env['TMP'] = dname 76 cmd = [sys.executable, os.path.basename(__file__), type(self).__name__ + '.test_abs_cpu'] 77 legacy_jit_flag = '--jit-executor=legacy' 78 for v in sys.argv: 79 if v == legacy_jit_flag: 80 cmd.append(legacy_jit_flag) 81 return_code = shell(cmd, cwd=os.path.dirname(__file__), env=shell_env) 82 self.assertEqual(return_code, 0) 83 84 @unittest.skipIf(not RUN_CUDA, "requires CUDA") 85 def test_abs_cuda(self): 86 self._test_fused_abs(device="cuda") 87 88 @unittest.skipIf(not RUN_CUDA, "requires CUDA") 89 def test_zero_element_tensors(self): 90 def decode(sin_t, cos_t): 91 theta = torch.atan2(sin_t.float(), cos_t.float()) 92 return theta 93 94 sin = torch.zeros(0, device="cuda") 95 cos = torch.zeros(0, device="cuda") 96 inputs = [sin, cos] 97 ge = self.checkScript(decode, inputs) 98 99 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 100 def test_arg_configurations_smoke_cuda(self): 101 # A smoke test to make sure we won't use the same kernel for contiguous 102 # and non-contiguous arguments. 103 # TODO: add optionally enabled debug counters to the fuser to verify 104 # that we really can tell the difference between configurations 105 def f(x, y): 106 z1, z2 = (x + y).chunk(2, dim=1) 107 return z1 * z2 108 109 x = torch.randn(4, 4, dtype=torch.float, device='cuda') 110 y = torch.randn(4, 4, dtype=torch.float, device='cuda') 111 traced_f = torch.jit.trace(f, (x, y,)) 112 self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y)) 113 114 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 115 def test_broadcast_cuda(self): 116 def scaleshift(x, scale, shift): 117 return x * scale + shift 118 119 inputs = [ 120 torch.randn(4, 4, dtype=torch.float, device='cuda'), 121 torch.randn(4, dtype=torch.float, device='cuda'), 122 torch.randn(4, dtype=torch.float, device='cuda'), 123 ] 124 ge = self.checkTrace(scaleshift, inputs) 125 self.assertAllFused(ge.graph_for(*inputs)) 126 127 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 128 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no bfloat support with profiling on") 129 def test_cuda_bfloat16(self): 130 def foo(x, y): 131 return (x + y).relu() 132 m = torch.jit.script(foo) 133 x = torch.randn(65536).cuda().bfloat16() 134 y = torch.randn_like(x) 135 self.assertAllFused(m.graph_for(x, y)) 136 137 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 138 @unittest.skipIf(not RUN_CUDA_HALF, "no half support") 139 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on") 140 def test_cuda_half(self): 141 x = torch.randn(4, 4, dtype=torch.half, device='cuda') 142 y = torch.randn(4, 4, dtype=torch.half, device='cuda') 143 144 funcs = [ 145 self.fn_test_comparison_gt_lt, 146 self.fn_test_relu, 147 self.fn_test_exp 148 ] 149 150 # Note: Non fused inputs must be float to prevent loss of precision 151 inputs = (x.float(), y.float()) 152 fusion_inputs = (x, y) 153 for fn in funcs: 154 local_inputs = [t.clone().requires_grad_() for t in inputs] 155 local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs] 156 157 # Verifies outputs 158 fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False) 159 outputs = fn(*local_inputs) 160 fusion_outputs = fusion(*local_fusion_inputs) 161 outputs_half = [t.half() for t in outputs] 162 self.assertEqual(outputs_half, fusion_outputs) 163 164 # Verifies gradients 165 for output, fusion_output in zip(outputs_half, fusion_outputs): 166 grads = torch.autograd.grad( 167 output.float().sum(), local_inputs, allow_unused=True, retain_graph=True) 168 fusion_grads = torch.autograd.grad( 169 fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True) 170 grads_half = [t.half() for t in grads] 171 self.assertEqual(grads_half, fusion_grads) 172 173 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 174 def test_checks_cat_inputs(self): 175 # We shouldn't treat cat nodes as broadcasting. All their inputs 176 # need to be checked for having the same map size, before we can 177 # run the kernel. 178 def f(x, y): 179 return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0) 180 181 # NOTE: y is broadcastable to x, but output of f(x, y) should have 182 # shape 3x4, and not 4x4. 183 x = torch.randn(2, 4, dtype=torch.float, device='cuda') 184 y = torch.randn(1, 4, dtype=torch.float, device='cuda') 185 186 scripted = self.checkScript(f, (x, y)) 187 self.assertAllFused(scripted.graph_for(x, y)) 188 189 @unittest.skipIf(not RUN_CUDA, "No CUDA") 190 def test_remainder_cuda(self): 191 def cuda_rem(x, y): 192 return 1 + torch.remainder(x, y) - 1 193 194 a = torch.rand([512], dtype=torch.float).cuda() 195 b = torch.rand([512], dtype=torch.float).cuda() 196 inputs = [a, b] 197 ge = self.checkScript(cuda_rem, inputs) 198 graph = ge.graph_for(*inputs) 199 self.assertAllFused(graph) 200 201 @unittest.skipIf(not RUN_CUDA, "No CUDA") 202 def test_chunk_cuda(self): 203 def fn(x): 204 a, b, c = x.chunk(3, 1) 205 return a * b + c 206 207 inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')] 208 209 ge = self.checkScript(fn, inputs) 210 graph = ge.graph_for(*inputs) 211 self.assertAllFused(graph) 212 FileCheck().check("prim::ConstantChunk[chunks=3, dim=1]").run(str(graph)) 213 214 @staticmethod 215 def _test_chunk_correctness(self, device='cpu'): 216 def chunk_4_0(x): 217 x0, x1, x2, x3 = x.chunk(4, 0) 218 return x0 + x1 + x2 + x3 219 220 def chunk_4_1(x): 221 x0, x1, x2, x3 = x.chunk(4, 1) 222 return x0 + x1 + x2 + x3 223 224 def chunk_4_last(x): 225 x0, x1, x2, x3 = x.chunk(4, 2) 226 return x0 + x1 + x2 + x3 227 228 fns = [chunk_4_0, chunk_4_1, chunk_4_last] 229 tensors = [ 230 # splitSize = 1 231 torch.randn(4, 4, 4, dtype=torch.float, device=device), 232 233 # contiguous case 234 torch.randn(12, 8, 16, dtype=torch.float, device=device), 235 236 # non-contiguous case 237 torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2), 238 ] 239 240 for tensor in tensors: 241 for fn in fns: 242 self.checkScript(fn, [tensor]) 243 244 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") 245 @enable_cpu_fuser 246 def test_chunk_correctness(self): 247 return self._test_chunk_correctness(self, 'cpu') 248 249 @unittest.skipIf(not RUN_CUDA, "No CUDA") 250 def test_chunk_correctness_cuda(self): 251 return self._test_chunk_correctness(self, 'cuda') 252 253 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 254 def test_chunk_distributes_cuda(self): 255 def f(x, y): 256 z1, z2 = (x + y).chunk(2, dim=1) 257 return z1 * z2 258 259 x = torch.randn(4, 4, dtype=torch.float, device='cuda') 260 y = torch.randn(4, 4, dtype=torch.float, device='cuda') 261 262 ge = self.checkTrace(f, (x, y)) 263 graph = ge.graph_for(x, y) 264 FileCheck().check("broadcast_tensors").check('with prim::FusionGroup_') \ 265 .check_count('ConstantChunk', 2, exactly=True).run(str(graph)) 266 267 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 268 def test_chunk_motion_deduplicates_inputs(self): 269 def func1(x): 270 z = x * x 271 z0, z1 = z.chunk(2) 272 return z0 * z1 273 274 def func2(x): 275 z = x * x * x 276 z0, z1 = z.chunk(2) 277 return z0 * z1 278 279 inputs = [ 280 torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float), 281 ] 282 for func in [func1, func2]: 283 module = self.checkScript(func, inputs) 284 forward_graph = module.graph_for(*inputs) 285 self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1) 286 fusion_group = list(forward_graph.nodes())[-1] 287 self.assertEqual(len(list(fusion_group.inputs())), 1) 288 289 @unittest.skipIf(not RUN_CUDA, "No CUDA") 290 def test_chunk_multiple_cuda(self): 291 # The arguments are intentionally used out of order as a test to see 292 # if the fusion compiler adds extra args in the correct order 293 def fn(s, x, y, z): 294 z1, z2 = z.chunk(2, 2) 295 x1, x2, x3 = x.chunk(3, 1) 296 y1, y2 = y.chunk(2, 0) 297 return s + x1 + x2 + x3 + y1 + y2 + z1 + z2 298 299 inputs = [ 300 torch.randn(5, 2, 3, dtype=torch.float, device='cuda'), 301 torch.randn(5, 6, 3, dtype=torch.float, device='cuda'), 302 torch.randn(10, 2, 3, dtype=torch.float, device='cuda'), 303 torch.randn(5, 2, 6, dtype=torch.float, device='cuda'), 304 ] 305 306 ge = self.checkScript(fn, inputs) 307 self.assertAllFused(ge.graph_for(*inputs)) 308 309 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 310 def test_minmax(self): 311 def tmax(a, b): 312 return torch.max(2 * a, b) 313 314 def tmin(a, b): 315 return torch.min(2 * a, b) 316 317 a = torch.randn(4, 4, dtype=torch.float, device="cuda") 318 b = torch.randn(4, 4, dtype=torch.float, device="cuda") 319 nan = torch.tensor(float('nan'), dtype=torch.float, device="cuda") 320 321 for f, inputs in product( 322 (tmax, tmin), 323 ([a, b], [a, nan], [b, nan])): 324 s = self.checkScript(f, inputs) 325 self.assertAllFused(s.graph_for(*inputs)) 326 327 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 328 def test_clamp(self): 329 def func2(a, b): 330 return torch.clamp(a + b, min=0, max=2) 331 332 def funcInf(a, b): 333 return torch.clamp(a + b, min=0, max=float('inf')) 334 335 def funcOptMin(a, b): 336 return torch.clamp(a + b, max=2) 337 338 def funcOptMax(a, b): 339 return torch.clamp(a + b, min=0) 340 341 a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True) 342 b = torch.randn(4, 4, dtype=torch.float, device='cuda') 343 nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda') 344 345 funcs = (func2, funcInf, funcOptMin, funcOptMax) 346 for f, inputs in product(funcs, [[a, b], [a, nan]]): 347 f.__disable_jit_function_caching__ = True 348 inp1, inp2 = inputs 349 s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING) 350 self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'}) 351 c = s(inp1, inp2) 352 with enable_profiling_mode_for_profiling_tests(): 353 warmup_backward(c.sum()) 354 graph = backward_graph(s) 355 self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'}) 356 357 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 358 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on") 359 def test_dropout(self): 360 def func(x): 361 x = torch.nn.functional.dropout(x) 362 return torch.nn.functional.relu(x) 363 364 a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True) 365 s = torch.jit.script(func) 366 c = s(a) 367 c = s(a) 368 warmup_backward(c.sum()) 369 # skip_check to skip extra bailout nodes in between 370 graph = backward_graph(s, skip_check=True) 371 self.assertAllFused(graph, except_for={'aten::div', 'prim::Constant'}) 372 373 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 374 def test_comparison_eq_ne(self): 375 def f(x, y): 376 mask = (x == 0).type_as(x) 377 z = x * mask + y 378 mask = (x != 0).type_as(x) 379 z = z * mask + y 380 return z 381 382 x = torch.randn(4, 4, dtype=torch.float, device='cuda') 383 y = torch.randn(4, 4, dtype=torch.float, device='cuda') 384 385 ge = self.checkTrace(f, (x, y)) 386 self.assertAllFused(ge.graph_for(x, y)) 387 388 @staticmethod 389 def fn_test_comparison_gt_lt(x, y): 390 mask = (x > 0).type_as(x) 391 z = x * mask + y 392 mask = (x < 0).type_as(x) 393 z = z * mask + y 394 return z 395 396 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 397 def test_comparison_gt_lt_cuda(self): 398 x = torch.randn(4, 4, dtype=torch.float, device='cuda') 399 y = torch.randn(4, 4, dtype=torch.float, device='cuda') 400 401 ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y)) 402 self.assertAllFused(ge.graph_for(x, y)) 403 404 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 405 def test_comparison_ge_le_cuda(self): 406 def f(x, y): 407 mask = (x >= 0).type_as(x) 408 z = x * mask + y 409 mask = (x <= 0).type_as(x) 410 z = z * mask + y 411 return z 412 413 x = torch.randn(4, 4, dtype=torch.float, device='cuda') 414 y = torch.randn(4, 4, dtype=torch.float, device='cuda') 415 416 ge = self.checkTrace(f, (x, y)) 417 self.assertAllFused(ge.graph_for(x, y)) 418 x.requires_grad_(True) 419 y.requires_grad_(True) 420 self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes", 421 "aten::_size_if_not_equal")) 422 423 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 424 def test_addcmul_cuda(self): 425 t = torch.randn(1, 4, dtype=torch.float, device='cuda') 426 t1 = torch.randn(4, 1, dtype=torch.float, device='cuda') 427 t2 = torch.randn(1, 4, dtype=torch.float, device='cuda') 428 429 def foo(t, t1, t2): 430 return t.addcmul(t + 1, t2, value=0.1) 431 432 ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True) 433 graph = ge.graph_for(t, t1, t2) 434 self.assertAllFused(graph) 435 436 # TODO: We leak CUDA memory here because the traced graph holds onto a 437 # constant-ified tensor. Since the Python-global CompilationUnit is alive 438 # until the end of the process, the memory is effectively leaked. 439 # Removed `_cuda` suffix from this test which disables leak-checking. 440 # If this is a real problem, we'll need to revisit Torchscript Function 441 # lifetimes in Python. 442 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 443 def test_lerp(self): 444 start = torch.randn(4, 1, dtype=torch.float, device='cuda') 445 end = torch.randn(1, 4, dtype=torch.float, device='cuda') 446 weight = torch.tensor(0.5, dtype=torch.float, device='cuda') 447 448 # scalar weight overload 449 def foo_weight_scalar(start, end): 450 return torch.lerp(start + 1, end, 0.5) 451 452 # tensor weight overload 453 def foo_weight_tensor(start, end): 454 return torch.lerp(start + 1, end, weight) 455 456 ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end)) 457 graph = ge_weight_scalar.graph_for(start, end) 458 self.assertAllFused(graph) 459 460 ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end)) 461 graph = ge_weight_tensor.graph_for(start, end) 462 self.assertAllFused(graph) 463 464 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 465 def test_concat_cuda(self): 466 hx = torch.randn(3, 20, dtype=torch.float, device='cuda') 467 cx = torch.randn(3, 20, dtype=torch.float, device='cuda') 468 469 def foo(hx, cx): 470 return torch.cat((hx + cx, hx * cx)) 471 472 ge = self.checkTrace(foo, (hx, cx)) 473 graph = ge.graph_for(hx, cx) 474 self.assertAllFused(graph) 475 FileCheck().check("FusedConcat").check_next("return").run(str(graph)) 476 477 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 478 def test_concat_invariant_cuda(self): 479 # Invariant: the output of prim::FusedConcat may 480 # not be an input to any node inside the FusionGroup. 481 def fn(x, y, z): 482 x1 = x + y 483 y1 = x - y 484 w = torch.cat([x1, y1]) 485 return w + z 486 487 x = torch.randn(2, 2, dtype=torch.float, device='cuda') 488 y = torch.randn(2, 2, dtype=torch.float, device='cuda') 489 z = torch.randn(4, 2, dtype=torch.float, device='cuda') 490 ge = self.checkTrace(fn, (x, y, z)) 491 graph = ge.graph_for(x, y, z) 492 self.assertAllFused(graph, except_for={'aten::add'}) 493 FileCheck().check("FusedConcat").check_next("return").run(str(graph)) 494 495 @staticmethod 496 def fn_test_exp(x, y): 497 return (x + .5 * y).exp() 498 499 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 500 def test_exp_cuda(self): 501 x = torch.randn(4, 4, dtype=torch.float, device='cuda') 502 y = torch.randn(4, 4, dtype=torch.float, device='cuda') 503 504 ge = self.checkTrace(self.fn_test_exp, (x, y)) 505 self.assertAllFused(ge.graph_for(x, y)) 506 507 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 508 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on") 509 @torch._jit_internal._disable_emit_hooks_decorator 510 @_inline_everything 511 def test_fuse_decompose_normalization(self): 512 class ResLike(torch.jit.ScriptModule): 513 def __init__(self, norm_module): 514 super().__init__() 515 self.nm = norm_module 516 517 @torch.jit.script_method 518 def forward(self, x, y): 519 return y + torch.relu(self.nm(x)) 520 521 def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph): 522 model = ResLike(nm).cuda() 523 model_noopt = ResLike(nm).cuda() 524 model_noopt.load_state_dict(model.state_dict()) 525 x = torch.randn(2, 16, 8, 8, device='cuda') 526 y = torch.randn(2, 16, 8, 8, device='cuda') 527 528 # FIXME: We need differentiation for CNNs for this optimization to trigger 529 with torch.no_grad(): 530 out = model(x, y) 531 graph = model.graph_for(x, y) 532 rep = str(graph) 533 534 with torch.jit.optimized_execution(False): 535 out_noopt = model_noopt(x, y) 536 rep_noopt = str(model_noopt.graph_for(x, y)) 537 self.assertEqual(out, out_noopt, atol=3e-5) 538 539 # Check that normalization op has really been decomposed 540 for node_in_graph in in_opt_graph: 541 self.assertIn(node_in_graph, rep) 542 543 for node_not_in_graph in not_in_opt_graph: 544 self.assertNotIn(node_not_in_graph, rep) 545 self.assertIn(node_not_in_graph, rep_noopt) 546 547 fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup'] 548 self.assertEqual(len(fusion_groups), 1) 549 fused_graph = str(fusion_groups[0].g('Subgraph')) 550 for node_in_fusegraph in in_fusegraph: 551 self.assertIn(node_in_fusegraph, fused_graph) 552 553 # test for batchnorm decompose 554 bm = nn.BatchNorm2d(16) 555 test_norm_decompose(bm, ['aten::batch_norm_update_stats'], 556 ['aten::batch_norm('], ['aten::sqrt']) 557 558 # test for layernorm decompose 559 lm = nn.LayerNorm(8) 560 test_norm_decompose(lm, ['aten::batch_norm_stats'], 561 ['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::add']) 562 563 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 564 def test_threshold(self): 565 def f(x): 566 return torch.threshold(x, 0, -10) + x + x + x 567 568 x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda') 569 scripted = self.checkScript(f, (x,)) 570 self.assertAllFused(scripted.graph_for(x)) 571 572 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 573 def test_scalar_arg_cuda(self): 574 def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor: 575 return p * (x * x + x) 576 577 x = torch.randn(4, 4, dtype=torch.float, device='cuda') 578 p = 3 579 scripted = self.checkScript(fn_test_scalar_arg, (x, p)) 580 self.assertAllFused(scripted.graph_for(x, p)) 581 582 x.requires_grad_(True) 583 584 # use another function otherwise we will bailout 585 # and won't be able to do fused checks 586 def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor: 587 return p * (x * x + x) 588 589 scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) 590 out = scripted(x, p) 591 self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes", 592 "aten::_size_if_not_equal")) 593 594 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") 595 @unittest.skip("deduplicating introduces aliasing in backward graph's outputs") 596 @enable_cpu_fuser 597 def test_fuser_deduplication(self): 598 # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation 599 # see the discussion in PR #14957. 600 def f(x, y): 601 return torch.sigmoid(x + y) 602 603 b = torch.randn(5, 5, requires_grad=True) 604 a = torch.randn(5, 5, requires_grad=True) 605 s = self.checkScript(f, (a, b)) 606 self.assertAllFused(s.graph_for(a, b), except_for={ 607 'aten::size', 'aten::_size_if_not_equal', 'prim::BroadcastSizes'}) 608 609 c = s(a, b) 610 results = warmup_backward(c.sum(), [a, b]) 611 ga2, gb2 = results.pop() 612 graph = backward_graph(s) 613 self.assertAllFused(graph) 614 # check that a, b share storage, i.e. were generated as a single output in the fuser 615 self.assertEqual(ga2.data_ptr(), gb2.data_ptr()) 616 617 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") 618 @enable_cpu_fuser 619 @unittest.skip("temporarily disabled because fusion was restricted in fixing #22833") 620 def test_fuser_iou(self): 621 # This checks if most of Intersection over Union is fused. 622 # In particular, the backward contains many _grad_sum_to_size. 623 def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2): 624 ltx = torch.max(b1x1, b2x1) # [N,M] 625 lty = torch.max(b1y1, b2y1) 626 rbx = torch.min(b1x2, b2x2) 627 rby = torch.min(b1y2, b2y2) 628 629 w = (rbx - ltx).clamp(min=0, max=float('inf')) # [N,M] 630 h = (rby - lty).clamp(min=0, max=float('inf')) # [N,M] 631 inter = w * h # [N,M] 632 633 area1 = (b1x2 - b1x1) * (b1y2 - b1y2) # [N,1] 634 area2 = (b2x2 - b2x1) * (b2y2 - b2y2) # [1,M] 635 iou = inter / (area1 + area2 - inter) 636 return iou 637 638 box1 = torch.randn(5, 4, requires_grad=True) 639 box2 = torch.randn(5, 4, requires_grad=True) 640 # unsqueezing can currently not be fused 641 b1x1 = box1[:, 0].unsqueeze(1) # [N,1] 642 b1y1 = box1[:, 1].unsqueeze(1) 643 b1x2 = box1[:, 2].unsqueeze(1) 644 b1y2 = box1[:, 3].unsqueeze(1) 645 b2x1 = box2[:, 0].unsqueeze(0) # [1,N] 646 b2y1 = box2[:, 1].unsqueeze(0) 647 b2x2 = box2[:, 2].unsqueeze(0) 648 b2y2 = box2[:, 3].unsqueeze(0) 649 650 s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)) 651 self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2), 652 except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'}) 653 654 with enable_profiling_mode_for_profiling_tests(True): 655 c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2) 656 warmup_backward(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2]) 657 graph = backward_graph(s) 658 self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'}) 659 660 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 661 @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") 662 @enable_cpu_fuser 663 def test_fusion_reuse_multi_gpu(self): 664 def fn(x, y): 665 return x * y * x * y 666 667 inputs_cpu = [ 668 torch.randn(4, 4, dtype=torch.float), 669 torch.randn(4, 4, dtype=torch.float), 670 ] 671 inputs_cuda0 = [x.cuda(0) for x in inputs_cpu] 672 inputs_cuda1 = [y.cuda(1) for y in inputs_cpu] 673 674 # Should not crash; these should compile different kernels. 675 ge = self.checkScript(fn, inputs_cpu) 676 self.assertAllFused(ge.graph_for(*inputs_cpu)) 677 ge(*inputs_cuda0) 678 ge(*inputs_cuda1) 679 680 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 681 @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") 682 @enable_cpu_fuser 683 def test_kernel_cache_multi_gpu(self): 684 def not_fusible(x): 685 return x 686 687 def fn(x, y, z): 688 x_out = x * x * x * x * x # fusion: lambda x. x * x * x * x * x 689 y_out = y * y * y * y * y 690 z_out = z * z * z * z * z 691 return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out) 692 693 inputs = [ 694 torch.randn(4, 4, dtype=torch.float), 695 torch.randn(4, 4, dtype=torch.float, device='cuda:0'), 696 torch.randn(4, 4, dtype=torch.float, device='cuda:1'), 697 ] 698 699 prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() 700 701 # There are 3 FusionGroups. Because they have the same graph, they 702 # should reuse the same KernelSpec in the KernelSpec cache. 703 ge = self.checkScript(fn, inputs) 704 self.assertGraphContainsExactly( 705 ge.graph_for(*inputs), 'prim::FusionGroup', 3, True) 706 new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() 707 # XXX: This assumes that the same kernel isn't already used by another test 708 self.assertEqual(new_cache_size - prev_cache_size, 1) 709 710 @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") 711 def test_nonzero_device_cuda(self): 712 device = 'cuda:' + str(1) 713 x = torch.tensor([0.4], dtype=torch.float, device=device) 714 y = torch.tensor([0.7], dtype=torch.float, device=device) 715 716 def doit(x, y): 717 return torch.sigmoid(torch.tanh(x * (x + y) + x)) 718 719 ge = self.checkTrace(doit, (x, y)) 720 self.assertAllFused(ge.graph_for(x, y)) 721 722 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 723 def test_lstm_cuda(self): 724 inputs = get_lstm_inputs('cuda', training=True) 725 module = self.checkScript(LSTMCellS, inputs) 726 return 727 forward_graph = module.graph_for(*inputs) 728 self.assertGraphContainsExactly( 729 forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True) 730 self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2) 731 # Everything is differentiable but TupleConstruct return 732 FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \ 733 .check_next("return").run(str(forward_graph)) 734 735 with enable_profiling_mode_for_profiling_tests(True): 736 hy, cy = module(*inputs) 737 warmup_backward((hy + cy).sum()) 738 backward = backward_graph(module) 739 self.assertAllFused(backward, except_for=("aten::t", "aten::mm", 740 "aten::_grad_sum_to_size")) 741 742 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 743 # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision. 744 # We want float tensors to be computed at full precision in order to use the default precision 745 @with_tf32_off 746 def test_lstm_concat_cuda(self): 747 inputs = get_lstm_inputs('cuda') 748 ge = self.checkTrace(LSTMCellC, inputs) 749 graph = ge.graph_for(*inputs) 750 FileCheck().check("FusedConcat").check_next("return").run(str(graph)) 751 752 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 753 def test_lstm_gates_permutations_cuda(self): 754 # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh. 755 # Test that any permutation of this will still result in one FusionGroup. 756 choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh'] 757 template = dedent(''' 758 def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): 759 gates = {} + {} + {} + {} 760 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 761 return ingate * forgetgate * cellgate * outgate 762 ''') 763 for permutation in permutations(choices, len(choices)): 764 code = template.format(*permutation) 765 scope = {} 766 exec(code, globals(), scope) 767 cu = torch.jit.CompilationUnit(code) 768 769 inputs = get_lstm_inputs('cuda', training=False) 770 self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs)) 771 forward_graph = cu.cell.graph_for(*inputs) 772 self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1) 773 774 # TODO: Fuser doesn't work at all when inputs require grad. Fix that 775 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 776 # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision. 777 # We want float tensors to be computed at full precision in order to use the default precision 778 @with_tf32_off 779 def test_lstm_traced_cuda(self): 780 inputs = get_lstm_inputs('cuda') 781 ge = self.checkTrace(LSTMCellF, inputs) 782 graph = ge.graph_for(*inputs) 783 # .check_not("aten::add") don't get pulled into FusionGroup because of BailOuts 784 FileCheck().check_not("Chunk").check_not("aten::sigmoid") \ 785 .check_not("aten::tanh").check("FusionGroup").check_next("TupleConstruct") \ 786 .check_next("return").check_not("FusionGroup_2").run(str(graph)) 787 788 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") 789 @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746") 790 @enable_cpu_fuser 791 def test_lstm_traced_cpu(self): 792 inputs = get_lstm_inputs('cpu') 793 try: 794 ge = self.checkTrace(LSTMCellF, inputs) 795 graph = ge.graph_for(*inputs) 796 FileCheck.check("FusionGroup").run(str(graph)) 797 except RuntimeError as e: 798 if 'Failed to compile' in e.args[0]: 799 warnings.warn('CPU fuser test has failed! This is not a hard failure, ' # noqa: F821 800 'because the kernels sometimes trigger bugs in compilers ' 801 '(most notably GCC 7.2).') 802 raise unittest.SkipTest('Failed to compile') from e 803 else: 804 raise 805 806 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 807 def test_milstm_cuda(self): 808 inputs = get_milstm_inputs('cuda', training=True) 809 module = self.checkScript(MiLSTMCell, inputs) 810 forward_graph = module.graph_for(*inputs) 811 self.assertGraphContainsExactly( 812 forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True) 813 FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \ 814 .check_next("return").check("FusionGroup").run(str(forward_graph)) 815 hy, cy = module(*inputs) 816 warmup_backward((hy + cy).sum()) 817 818 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 819 @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor") 820 def test_rand_cuda(self): 821 class M(torch.jit.ScriptModule): 822 __constants__ = ['d'] 823 824 def __init__(self) -> None: 825 super().__init__() 826 self.d = torch.device('cuda') 827 828 @torch.jit.script_method 829 def create(self, x): 830 return x * x + x + torch.rand_like(x) 831 832 x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda') 833 m = M() 834 out1 = m.create(x) 835 out2 = m.create(x) 836 self.assertNotEqual(out1, out2) 837 self.assertTrue(torch.all(out1 >= 0)) 838 self.assertTrue(torch.all(out1 < 1)) 839 self.assertTrue(torch.all(out2 >= 0)) 840 self.assertTrue(torch.all(out2 < 1)) 841 self.assertAllFused(m.create.graph_for(x)) 842 843 @staticmethod 844 def fn_test_relu(x, y): 845 return F.relu(x + .5 * y) 846 847 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 848 def test_relu_cuda(self): 849 x = torch.randn(4, 4, dtype=torch.float, device='cuda') 850 y = torch.randn(4, 4, dtype=torch.float, device='cuda') 851 852 ge = self.checkTrace(self.fn_test_relu, (x, y)) 853 self.assertAllFused(ge.graph_for(x, y)) 854 855 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 856 def test_erf_cuda(self): 857 def fn_test_erf(x): 858 return F.relu(torch.erf(x) - torch.erfc(x)) 859 860 x = torch.randn(4, 4, dtype=torch.float, device='cuda') 861 ge = self.checkTrace(fn_test_erf, (x,)) 862 self.assertAllFused(ge.graph_for(x)) 863 x.requires_grad_(True) 864 ge = self.checkTrace(fn_test_erf, (x,)) 865 self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes", 866 "aten::_size_if_not_equal")) 867 868 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 869 @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor") 870 def test_rand_broadcast_cuda(self): 871 def fn_test_rand(x, y): 872 r = torch.rand_like(y) 873 return r * x + x 874 875 x = torch.randn(4, 4, dtype=torch.float, device='cuda') 876 y = torch.randn(4, 4, dtype=torch.float, device='cuda') 877 script_f = torch.jit.script(fn_test_rand) 878 out = script_f(x, y) 879 self.assertAllFused(script_f.graph_for(x, y)) 880 x.requires_grad_(True) 881 out = script_f(x, y) 882 self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes", 883 "aten::_size_if_not_equal")) 884 # test that broadcasting random produces correct results 885 x = torch.ones(4, 4, dtype=torch.float, device='cuda') 886 y = torch.ones(4, dtype=torch.float, device='cuda') 887 out = script_f(x, y) 888 self.assertEqual(out[0], out[1]) 889 890 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") 891 @enable_cpu_fuser 892 def test_scalar(self): 893 def fn(x, y): 894 return 2 * x + y 895 896 x = torch.tensor(0.1, dtype=torch.float, device='cpu') 897 y = torch.tensor(1, dtype=torch.float, device='cpu') 898 ge = self.checkScript(fn, (x, y)) 899 self.assertAllFused(ge.graph_for(x, y)) 900 901 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 902 def test_small_constant_cuda(self): 903 def fn_test_small_constant(x, y): 904 return (1e-8 * x + 5e-9 * y) * 1e8 905 x = torch.randn(4, 4, dtype=torch.float, device='cuda') 906 y = torch.randn(4, 4, dtype=torch.float, device='cuda') 907 908 ge = self.checkTrace(fn_test_small_constant, (x, y)) 909 self.assertAllFused(ge.graph_for(x, y)) 910 911 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 912 def test_tensor_scalar_ops_cuda(self): 913 def should_fuse(x): 914 z = 3. 915 y = x + z 916 return x * y 917 918 # XXX: right now we only support fusing scalars if 919 # they're constant (#9940) 920 def should_not_fuse(x, z): 921 y = x + int(z) 922 return x * y 923 924 inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')] 925 ge = self.checkScript(should_fuse, inputs) 926 self.assertAllFused(ge.graph_for(*inputs)) 927 928 inputs = [ 929 torch.randn(2, 2, dtype=torch.float, device='cuda'), 930 torch.tensor(3., dtype=torch.float, device='cuda'), 931 ] 932 ge = self.checkScript(should_not_fuse, inputs) 933 self.assertGraphContainsExactly( 934 ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True) 935 936 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") 937 @enable_cpu_fuser 938 def test_where_and_typing(self): 939 def f(x, y): 940 mask = x > y 941 res = torch.where(mask, x, y) 942 return mask, res 943 944 x = torch.randn(4, 4, dtype=torch.double) 945 y = torch.randn(4, 4, dtype=torch.double) 946 947 script_f = self.checkScript(f, (x, y)) 948 self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'}) 949 950 @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") 951 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on") 952 def test_grad_sum_to_size_elimination(self): 953 954 def my_broadcasted_cell(a, b, c): 955 return (a + b) + c 956 957 s1 = torch.randn(5, 1, requires_grad=True, device='cuda') 958 s2 = torch.randn(5, 5, requires_grad=True, device='cuda') 959 960 module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING) 961 forward_graph = module.graph_for(s1, s1, s1) 962 self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes", 963 "aten::_size_if_not_equal")) 964 965 old_plans = set() 966 for i in range(3): 967 # if we have s2, then the s1 are _grad_sum_to_size'd 968 969 args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2 970 args = [a.detach_().requires_grad_() for a in args] 971 # recompile, so we don't trigger bailouts 972 module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING) 973 res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2) 974 warmup_backward(res.sum(), args) 975 grads = torch.autograd.grad(res.sum(), args) 976 for inp, gr in zip(args, grads): 977 self.assertEqual(inp.shape, gr.shape) 978 backward = None 979 # this is a workaround for the backward graphs not being 980 # in order for Python 2 981 for g in all_backward_graphs(module): 982 if str(g) not in old_plans: 983 assert backward is None 984 backward = g 985 old_plans.add(str(backward)) 986 num_grads = 1 if i > 0 else 0 987 self.assertEqual(len([n for n in backward.nodes() if n.kind() == 'aten::_grad_sum_to_size']), num_grads) 988 989 990if __name__ == '__main__': 991 run_tests() 992