# Owner(s): ["oncall: jit"]

import unittest
import os
import sys
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import FileCheck
from unittest import skipIf

from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
    enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \
    RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward
from textwrap import dedent
from itertools import product, permutations
from torch.testing._internal.common_cuda import with_tf32_off

from test_jit import backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \
    LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell

if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)


def strip_profiling_nodes(nodes):
    profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'}
    return [n for n in nodes if n.kind() not in profiling_opcodes]


def warmup_forward(f, *args):
    profiling_count = 2
    for i in range(profiling_count):
        results = f(*args)

    return results


@skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "skip due to SIGIOT failures, #67646")
class TestFuser(JitTestCase):
    def assertAllFused(self, graph, except_for=()):

        diff_graphs = [n for n in graph.nodes() if n.kind() == 'prim::DifferentiableGraph']
        if len(diff_graphs) > 0:
            self.assertEqual(len(diff_graphs), 1)
            graph = diff_graphs[0].g('Subgraph')

        allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::BailoutTemplate',
                         'prim::BailOut', 'prim::TupleConstruct'} | set(except_for)
        self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
                        f'got {graph}')
        self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)

    def _test_fused_abs(self, device='cpu'):
        def func(x):
            return x.abs() * 2

        a = torch.randn(5, device=device)
        scripted = self.checkScript(func, (a,))
        self.assertAllFused(scripted.graph_for(a))

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_abs_cpu(self):
        self._test_fused_abs()

    @unittest.skipIf(not IS_WINDOWS, "This is meant to be Windows-specific")
    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_abs_cpu_unicode_temp_dir(self):
        with TemporaryDirectoryName(suffix='\u4e2d\u6587') as dname:
            shell_env = os.environ.copy()
            shell_env['TMP'] = dname
            cmd = [sys.executable, os.path.basename(__file__), type(self).__name__ + '.test_abs_cpu']
            legacy_jit_flag = '--jit-executor=legacy'
            for v in sys.argv:
                if v == legacy_jit_flag:
                    cmd.append(legacy_jit_flag)
            return_code = shell(cmd, cwd=os.path.dirname(__file__), env=shell_env)
            self.assertEqual(return_code, 0)

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_abs_cuda(self):
        self._test_fused_abs(device="cuda")

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_zero_element_tensors(self):
        def decode(sin_t, cos_t):
            theta = torch.atan2(sin_t.float(), cos_t.float())
            return theta

        sin = torch.zeros(0, device="cuda")
        cos = torch.zeros(0, device="cuda")
        inputs = [sin, cos]
        ge = self.checkScript(decode, inputs)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_arg_configurations_smoke_cuda(self):
        # A smoke test to make sure we won't use the same kernel for contiguous
        # and non-contiguous arguments.
        # TODO: add optionally enabled debug counters to the fuser to verify
        #       that we really can tell the difference between configurations
        def f(x, y):
            z1, z2 = (x + y).chunk(2, dim=1)
            return z1 * z2

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
        traced_f = torch.jit.trace(f, (x, y,))
        self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_broadcast_cuda(self):
        def scaleshift(x, scale, shift):
            return x * scale + shift

        inputs = [
            torch.randn(4, 4, dtype=torch.float, device='cuda'),
            torch.randn(4, dtype=torch.float, device='cuda'),
            torch.randn(4, dtype=torch.float, device='cuda'),
        ]
        ge = self.checkTrace(scaleshift, inputs)
        self.assertAllFused(ge.graph_for(*inputs))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no bfloat support with profiling on")
    def test_cuda_bfloat16(self):
        def foo(x, y):
            return (x + y).relu()
        m = torch.jit.script(foo)
        x = torch.randn(65536).cuda().bfloat16()
        y = torch.randn_like(x)
        self.assertAllFused(m.graph_for(x, y))

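    # Illustrative sketch of how the helpers above are typically combined: script a
    # function, run it a couple of times so the profiling executor can specialize it,
    # then assert that the whole graph was fused. The method name below is hypothetical
    # (it does not start with "test_", so unittest will not collect it).
    def _example_fusion_check(self, fn, *args):
        scripted = torch.jit.script(fn)
        warmup_forward(scripted, *args)  # two warm-up runs, see warmup_forward above
        self.assertAllFused(scripted.graph_for(*args))
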
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
    def test_cuda_half(self):
        x = torch.randn(4, 4, dtype=torch.half, device='cuda')
        y = torch.randn(4, 4, dtype=torch.half, device='cuda')

        funcs = [
            self.fn_test_comparison_gt_lt,
            self.fn_test_relu,
            self.fn_test_exp
        ]

        # Note: Non fused inputs must be float to prevent loss of precision
        inputs = (x.float(), y.float())
        fusion_inputs = (x, y)
        for fn in funcs:
            local_inputs = [t.clone().requires_grad_() for t in inputs]
            local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]

            # Verifies outputs
            fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False)
            outputs = fn(*local_inputs)
            fusion_outputs = fusion(*local_fusion_inputs)
            outputs_half = [t.half() for t in outputs]
            self.assertEqual(outputs_half, fusion_outputs)

            # Verifies gradients
            for output, fusion_output in zip(outputs_half, fusion_outputs):
                grads = torch.autograd.grad(
                    output.float().sum(), local_inputs, allow_unused=True, retain_graph=True)
                fusion_grads = torch.autograd.grad(
                    fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True)
                grads_half = [t.half() for t in grads]
                self.assertEqual(grads_half, fusion_grads)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_checks_cat_inputs(self):
        # We shouldn't treat cat nodes as broadcasting. All their inputs
        # need to be checked for having the same map size, before we can
        # run the kernel.
        def f(x, y):
            return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)

        # NOTE: y is broadcastable to x, but output of f(x, y) should have
        # shape 3x4, and not 4x4.
        x = torch.randn(2, 4, dtype=torch.float, device='cuda')
        y = torch.randn(1, 4, dtype=torch.float, device='cuda')

        scripted = self.checkScript(f, (x, y))
        self.assertAllFused(scripted.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_remainder_cuda(self):
        def cuda_rem(x, y):
            return 1 + torch.remainder(x, y) - 1

        a = torch.rand([512], dtype=torch.float).cuda()
        b = torch.rand([512], dtype=torch.float).cuda()
        inputs = [a, b]
        ge = self.checkScript(cuda_rem, inputs)
        graph = ge.graph_for(*inputs)
        self.assertAllFused(graph)

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_chunk_cuda(self):
        def fn(x):
            a, b, c = x.chunk(3, 1)
            return a * b + c

        inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]

        ge = self.checkScript(fn, inputs)
        graph = ge.graph_for(*inputs)
        self.assertAllFused(graph)
        FileCheck().check("prim::ConstantChunk[chunks=3, dim=1]").run(str(graph))

    @staticmethod
    def _test_chunk_correctness(self, device='cpu'):
        def chunk_4_0(x):
            x0, x1, x2, x3 = x.chunk(4, 0)
            return x0 + x1 + x2 + x3

        def chunk_4_1(x):
            x0, x1, x2, x3 = x.chunk(4, 1)
            return x0 + x1 + x2 + x3

        def chunk_4_last(x):
            x0, x1, x2, x3 = x.chunk(4, 2)
            return x0 + x1 + x2 + x3
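        # The three variants above exercise chunking along dim 0, dim 1 and the
        # last dim (dim 2 of the 3-D test tensors constructed below).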

        fns = [chunk_4_0, chunk_4_1, chunk_4_last]
        tensors = [
            # splitSize = 1
            torch.randn(4, 4, 4, dtype=torch.float, device=device),

            # contiguous case
            torch.randn(12, 8, 16, dtype=torch.float, device=device),

            # non-contiguous case
            torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2),
        ]

        for tensor in tensors:
            for fn in fns:
                self.checkScript(fn, [tensor])

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_chunk_correctness(self):
        return self._test_chunk_correctness(self, 'cpu')

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_chunk_correctness_cuda(self):
        return self._test_chunk_correctness(self, 'cuda')

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_chunk_distributes_cuda(self):
        def f(x, y):
            z1, z2 = (x + y).chunk(2, dim=1)
            return z1 * z2

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(f, (x, y))
        graph = ge.graph_for(x, y)
        FileCheck().check("broadcast_tensors").check('with prim::FusionGroup_') \
            .check_count('ConstantChunk', 2, exactly=True).run(str(graph))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_chunk_motion_deduplicates_inputs(self):
        def func1(x):
            z = x * x
            z0, z1 = z.chunk(2)
            return z0 * z1

        def func2(x):
            z = x * x * x
            z0, z1 = z.chunk(2)
            return z0 * z1

        inputs = [
            torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
        ]
        for func in [func1, func2]:
            module = self.checkScript(func, inputs)
            forward_graph = module.graph_for(*inputs)
            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
            fusion_group = list(forward_graph.nodes())[-1]
            self.assertEqual(len(list(fusion_group.inputs())), 1)

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_chunk_multiple_cuda(self):
        # The arguments are intentionally used out of order as a test to see
        # if the fusion compiler adds extra args in the correct order
        def fn(s, x, y, z):
            z1, z2 = z.chunk(2, 2)
            x1, x2, x3 = x.chunk(3, 1)
            y1, y2 = y.chunk(2, 0)
            return s + x1 + x2 + x3 + y1 + y2 + z1 + z2

        inputs = [
            torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
            torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
            torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
            torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
        ]

        ge = self.checkScript(fn, inputs)
        self.assertAllFused(ge.graph_for(*inputs))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_minmax(self):
        def tmax(a, b):
            return torch.max(2 * a, b)

        def tmin(a, b):
            return torch.min(2 * a, b)

        a = torch.randn(4, 4, dtype=torch.float, device="cuda")
        b = torch.randn(4, 4, dtype=torch.float, device="cuda")
        nan = torch.tensor(float('nan'), dtype=torch.float, device="cuda")

        for f, inputs in product(
                (tmax, tmin),
                ([a, b], [a, nan], [b, nan])):
            s = self.checkScript(f, inputs)
            self.assertAllFused(s.graph_for(*inputs))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_clamp(self):
        def func2(a, b):
            return torch.clamp(a + b, min=0, max=2)

        def funcInf(a, b):
            return torch.clamp(a + b, min=0, max=float('inf'))

        def funcOptMin(a, b):
            return torch.clamp(a + b, max=2)

        def funcOptMax(a, b):
            return torch.clamp(a + b, min=0)

        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
        b = torch.randn(4, 4, dtype=torch.float, device='cuda')
        nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

        funcs = (func2, funcInf, funcOptMin, funcOptMax)
        for f, inputs in product(funcs, [[a, b], [a, nan]]):
            f.__disable_jit_function_caching__ = True
            inp1, inp2 = inputs
            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
            c = s(inp1, inp2)
            with enable_profiling_mode_for_profiling_tests():
                warmup_backward(c.sum())
            graph = backward_graph(s)
            self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
    def test_dropout(self):
        def func(x):
            x = torch.nn.functional.dropout(x)
            return torch.nn.functional.relu(x)

        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
        s = torch.jit.script(func)
        c = s(a)
        c = s(a)
        warmup_backward(c.sum())
        # skip_check to skip extra bailout nodes in between
        graph = backward_graph(s, skip_check=True)
        self.assertAllFused(graph, except_for={'aten::div', 'prim::Constant'})

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_comparison_eq_ne(self):
        def f(x, y):
            mask = (x == 0).type_as(x)
            z = x * mask + y
            mask = (x != 0).type_as(x)
            z = z * mask + y
            return z

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(f, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @staticmethod
    def fn_test_comparison_gt_lt(x, y):
        mask = (x > 0).type_as(x)
        z = x * mask + y
        mask = (x < 0).type_as(x)
        z = z * mask + y
        return z

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_comparison_gt_lt_cuda(self):
        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_comparison_ge_le_cuda(self):
        def f(x, y):
            mask = (x >= 0).type_as(x)
            z = x * mask + y
            mask = (x <= 0).type_as(x)
            z = z * mask + y
            return z

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(f, (x, y))
        self.assertAllFused(ge.graph_for(x, y))
        x.requires_grad_(True)
        y.requires_grad_(True)
        self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
                                                            "aten::_size_if_not_equal"))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_addcmul_cuda(self):
        t = torch.randn(1, 4, dtype=torch.float, device='cuda')
        t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
        t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')

        def foo(t, t1, t2):
            return t.addcmul(t + 1, t2, value=0.1)

        ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
        graph = ge.graph_for(t, t1, t2)
        self.assertAllFused(graph)

    # TODO: We leak CUDA memory here because the traced graph holds onto a
    # constant-ified tensor. Since the Python-global CompilationUnit is alive
    # until the end of the process, the memory is effectively leaked.
    # Removed `_cuda` suffix from this test which disables leak-checking.
    # If this is a real problem, we'll need to revisit Torchscript Function
    # lifetimes in Python.
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_lerp(self):
        start = torch.randn(4, 1, dtype=torch.float, device='cuda')
        end = torch.randn(1, 4, dtype=torch.float, device='cuda')
        weight = torch.tensor(0.5, dtype=torch.float, device='cuda')

        # scalar weight overload
        def foo_weight_scalar(start, end):
            return torch.lerp(start + 1, end, 0.5)

        # tensor weight overload
        def foo_weight_tensor(start, end):
            return torch.lerp(start + 1, end, weight)

        ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end))
        graph = ge_weight_scalar.graph_for(start, end)
        self.assertAllFused(graph)

        ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end))
        graph = ge_weight_tensor.graph_for(start, end)
        self.assertAllFused(graph)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_concat_cuda(self):
        hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
        cx = torch.randn(3, 20, dtype=torch.float, device='cuda')

        def foo(hx, cx):
            return torch.cat((hx + cx, hx * cx))

        ge = self.checkTrace(foo, (hx, cx))
        graph = ge.graph_for(hx, cx)
        self.assertAllFused(graph)
        FileCheck().check("FusedConcat").check_next("return").run(str(graph))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_concat_invariant_cuda(self):
        # Invariant: the output of prim::FusedConcat may
        # not be an input to any node inside the FusionGroup.
        def fn(x, y, z):
            x1 = x + y
            y1 = x - y
            w = torch.cat([x1, y1])
            return w + z

        x = torch.randn(2, 2, dtype=torch.float, device='cuda')
        y = torch.randn(2, 2, dtype=torch.float, device='cuda')
        z = torch.randn(4, 2, dtype=torch.float, device='cuda')
        ge = self.checkTrace(fn, (x, y, z))
        graph = ge.graph_for(x, y, z)
        self.assertAllFused(graph, except_for={'aten::add'})
        FileCheck().check("FusedConcat").check_next("return").run(str(graph))

    @staticmethod
    def fn_test_exp(x, y):
        return (x + .5 * y).exp()

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_exp_cuda(self):
        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(self.fn_test_exp, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
    @torch._jit_internal._disable_emit_hooks_decorator
    @_inline_everything
    def test_fuse_decompose_normalization(self):
        class ResLike(torch.jit.ScriptModule):
            def __init__(self, norm_module):
                super().__init__()
                self.nm = norm_module

            @torch.jit.script_method
            def forward(self, x, y):
                return y + torch.relu(self.nm(x))

        def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph):
            model = ResLike(nm).cuda()
            model_noopt = ResLike(nm).cuda()
            model_noopt.load_state_dict(model.state_dict())
            x = torch.randn(2, 16, 8, 8, device='cuda')
            y = torch.randn(2, 16, 8, 8, device='cuda')

            # FIXME: We need differentiation for CNNs for this optimization to trigger
            with torch.no_grad():
                out = model(x, y)
                graph = model.graph_for(x, y)
                rep = str(graph)

                with torch.jit.optimized_execution(False):
                    out_noopt = model_noopt(x, y)
                    rep_noopt = str(model_noopt.graph_for(x, y))
                self.assertEqual(out, out_noopt, atol=3e-5)

            # Check that normalization op has really been decomposed
            for node_in_graph in in_opt_graph:
                self.assertIn(node_in_graph, rep)

            for node_not_in_graph in not_in_opt_graph:
                self.assertNotIn(node_not_in_graph, rep)
                self.assertIn(node_not_in_graph, rep_noopt)

            fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
            self.assertEqual(len(fusion_groups), 1)
            fused_graph = str(fusion_groups[0].g('Subgraph'))
            for node_in_fusegraph in in_fusegraph:
                self.assertIn(node_in_fusegraph, fused_graph)

        # test for batchnorm decompose
        bm = nn.BatchNorm2d(16)
        test_norm_decompose(bm, ['aten::batch_norm_update_stats'],
                            ['aten::batch_norm('], ['aten::sqrt'])

        # test for layernorm decompose
        lm = nn.LayerNorm(8)
        test_norm_decompose(lm, ['aten::batch_norm_stats'],
                            ['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::add'])

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_threshold(self):
        def f(x):
            return torch.threshold(x, 0, -10) + x + x + x

        x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda')
        scripted = self.checkScript(f, (x,))
        self.assertAllFused(scripted.graph_for(x))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
"fuser requires CUDA") 573*da0073e9SAndroid Build Coastguard Worker def test_scalar_arg_cuda(self): 574*da0073e9SAndroid Build Coastguard Worker def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor: 575*da0073e9SAndroid Build Coastguard Worker return p * (x * x + x) 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 4, dtype=torch.float, device='cuda') 578*da0073e9SAndroid Build Coastguard Worker p = 3 579*da0073e9SAndroid Build Coastguard Worker scripted = self.checkScript(fn_test_scalar_arg, (x, p)) 580*da0073e9SAndroid Build Coastguard Worker self.assertAllFused(scripted.graph_for(x, p)) 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker x.requires_grad_(True) 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Worker # use another function otherwise we will bailout 585*da0073e9SAndroid Build Coastguard Worker # and won't be able to do fused checks 586*da0073e9SAndroid Build Coastguard Worker def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor: 587*da0073e9SAndroid Build Coastguard Worker return p * (x * x + x) 588*da0073e9SAndroid Build Coastguard Worker 589*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) 590*da0073e9SAndroid Build Coastguard Worker out = scripted(x, p) 591*da0073e9SAndroid Build Coastguard Worker self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes", 592*da0073e9SAndroid Build Coastguard Worker "aten::_size_if_not_equal")) 593*da0073e9SAndroid Build Coastguard Worker 594*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") 595*da0073e9SAndroid Build Coastguard Worker @unittest.skip("deduplicating introduces aliasing in backward graph's outputs") 596*da0073e9SAndroid Build Coastguard Worker @enable_cpu_fuser 597*da0073e9SAndroid Build Coastguard Worker def test_fuser_deduplication(self): 598*da0073e9SAndroid Build Coastguard Worker # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation 599*da0073e9SAndroid Build Coastguard Worker # see the discussion in PR #14957. 600*da0073e9SAndroid Build Coastguard Worker def f(x, y): 601*da0073e9SAndroid Build Coastguard Worker return torch.sigmoid(x + y) 602*da0073e9SAndroid Build Coastguard Worker 603*da0073e9SAndroid Build Coastguard Worker b = torch.randn(5, 5, requires_grad=True) 604*da0073e9SAndroid Build Coastguard Worker a = torch.randn(5, 5, requires_grad=True) 605*da0073e9SAndroid Build Coastguard Worker s = self.checkScript(f, (a, b)) 606*da0073e9SAndroid Build Coastguard Worker self.assertAllFused(s.graph_for(a, b), except_for={ 607*da0073e9SAndroid Build Coastguard Worker 'aten::size', 'aten::_size_if_not_equal', 'prim::BroadcastSizes'}) 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker c = s(a, b) 610*da0073e9SAndroid Build Coastguard Worker results = warmup_backward(c.sum(), [a, b]) 611*da0073e9SAndroid Build Coastguard Worker ga2, gb2 = results.pop() 612*da0073e9SAndroid Build Coastguard Worker graph = backward_graph(s) 613*da0073e9SAndroid Build Coastguard Worker self.assertAllFused(graph) 614*da0073e9SAndroid Build Coastguard Worker # check that a, b share storage, i.e. 
        self.assertEqual(ga2.data_ptr(), gb2.data_ptr())

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    @unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
    def test_fuser_iou(self):
        # This checks if most of Intersection over Union is fused.
        # In particular, the backward contains many _grad_sum_to_size.
        def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
            ltx = torch.max(b1x1, b2x1)  # [N,M]
            lty = torch.max(b1y1, b2y1)
            rbx = torch.min(b1x2, b2x2)
            rby = torch.min(b1y2, b2y2)

            w = (rbx - ltx).clamp(min=0, max=float('inf'))  # [N,M]
            h = (rby - lty).clamp(min=0, max=float('inf'))  # [N,M]
            inter = w * h  # [N,M]

            area1 = (b1x2 - b1x1) * (b1y2 - b1y1)  # [N,1]
            area2 = (b2x2 - b2x1) * (b2y2 - b2y1)  # [1,M]
            iou = inter / (area1 + area2 - inter)
            return iou

        box1 = torch.randn(5, 4, requires_grad=True)
        box2 = torch.randn(5, 4, requires_grad=True)
        # unsqueezing can currently not be fused
        b1x1 = box1[:, 0].unsqueeze(1)  # [N,1]
        b1y1 = box1[:, 1].unsqueeze(1)
        b1x2 = box1[:, 2].unsqueeze(1)
        b1y2 = box1[:, 3].unsqueeze(1)
        b2x1 = box2[:, 0].unsqueeze(0)  # [1,N]
        b2y1 = box2[:, 1].unsqueeze(0)
        b2x2 = box2[:, 2].unsqueeze(0)
        b2y2 = box2[:, 3].unsqueeze(0)

        s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
        self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
                            except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})

        with enable_profiling_mode_for_profiling_tests(True):
            c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
            warmup_backward(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
            graph = backward_graph(s)
            self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
    @enable_cpu_fuser
    def test_fusion_reuse_multi_gpu(self):
        def fn(x, y):
            return x * y * x * y

        inputs_cpu = [
            torch.randn(4, 4, dtype=torch.float),
            torch.randn(4, 4, dtype=torch.float),
        ]
        inputs_cuda0 = [x.cuda(0) for x in inputs_cpu]
        inputs_cuda1 = [y.cuda(1) for y in inputs_cpu]

        # Should not crash; these should compile different kernels.
        ge = self.checkScript(fn, inputs_cpu)
        self.assertAllFused(ge.graph_for(*inputs_cpu))
        ge(*inputs_cuda0)
        ge(*inputs_cuda1)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
    @enable_cpu_fuser
    def test_kernel_cache_multi_gpu(self):
        def not_fusible(x):
            return x

        def fn(x, y, z):
            x_out = x * x * x * x * x  # fusion: lambda x. x * x * x * x * x
            y_out = y * y * y * y * y
            z_out = z * z * z * z * z
            return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)

        inputs = [
            torch.randn(4, 4, dtype=torch.float),
            torch.randn(4, 4, dtype=torch.float, device='cuda:0'),
            torch.randn(4, 4, dtype=torch.float, device='cuda:1'),
        ]

        prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()

        # There are 3 FusionGroups. Because they have the same graph, they
        # should reuse the same KernelSpec in the KernelSpec cache.
        ge = self.checkScript(fn, inputs)
        self.assertGraphContainsExactly(
            ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
        new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
        # XXX: This assumes that the same kernel isn't already used by another test
        self.assertEqual(new_cache_size - prev_cache_size, 1)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
    def test_nonzero_device_cuda(self):
        device = 'cuda:' + str(1)
        x = torch.tensor([0.4], dtype=torch.float, device=device)
        y = torch.tensor([0.7], dtype=torch.float, device=device)

        def doit(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y) + x))

        ge = self.checkTrace(doit, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_lstm_cuda(self):
        inputs = get_lstm_inputs('cuda', training=True)
        module = self.checkScript(LSTMCellS, inputs)
        return
        forward_graph = module.graph_for(*inputs)
        self.assertGraphContainsExactly(
            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
        self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
        # Everything is differentiable but TupleConstruct return
        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
            .check_next("return").run(str(forward_graph))

        with enable_profiling_mode_for_profiling_tests(True):
            hy, cy = module(*inputs)
            warmup_backward((hy + cy).sum())
            backward = backward_graph(module)
            self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
                                                      "aten::_grad_sum_to_size"))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
    # We want float tensors to be computed at full precision so the fused results match
    # the reference computation.
    @with_tf32_off
    def test_lstm_concat_cuda(self):
        inputs = get_lstm_inputs('cuda')
        ge = self.checkTrace(LSTMCellC, inputs)
        graph = ge.graph_for(*inputs)
        FileCheck().check("FusedConcat").check_next("return").run(str(graph))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_lstm_gates_permutations_cuda(self):
        # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
        # Test that any permutation of this sum still results in one FusionGroup.
        choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
        template = dedent('''
        def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
            gates = {} + {} + {} + {}
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
            return ingate * forgetgate * cellgate * outgate
        ''')
        for permutation in permutations(choices, len(choices)):
            code = template.format(*permutation)
            scope = {}
            exec(code, globals(), scope)
            cu = torch.jit.CompilationUnit(code)

            inputs = get_lstm_inputs('cuda', training=False)
            self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
            forward_graph = cu.cell.graph_for(*inputs)
            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
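
    # For reference, a sketch of one instantiation produced by the template above
    # (the particular gate ordering shown here is arbitrary):
    #
    #     def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    #         gates = b_hh + x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih
    #         ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    #         return ingate * forgetgate * cellgate * outgate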

    # TODO: Fuser doesn't work at all when inputs require grad. Fix that
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
    # We want float tensors to be computed at full precision so the fused results match
    # the reference computation.
    @with_tf32_off
    def test_lstm_traced_cuda(self):
        inputs = get_lstm_inputs('cuda')
        ge = self.checkTrace(LSTMCellF, inputs)
        graph = ge.graph_for(*inputs)
        # Note: .check_not("aten::add") is omitted because adds don't get pulled
        # into the FusionGroup when BailOuts are present.
        FileCheck().check_not("Chunk").check_not("aten::sigmoid") \
            .check_not("aten::tanh").check("FusionGroup").check_next("TupleConstruct") \
            .check_next("return").check_not("FusionGroup_2").run(str(graph))
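
    # The CPU test below (like the other CPU tests in this file) relies on the
    # @enable_cpu_fuser decorator from torch.testing._internal.jit_utils, which is
    # roughly equivalent to calling torch._C._jit_override_can_fuse_on_cpu(True) for
    # the duration of the test and restoring the previous setting afterwards (an
    # approximation of the decorator's behavior).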

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
    @enable_cpu_fuser
    def test_lstm_traced_cpu(self):
        import warnings

        inputs = get_lstm_inputs('cpu')
        try:
            ge = self.checkTrace(LSTMCellF, inputs)
            graph = ge.graph_for(*inputs)
            FileCheck().check("FusionGroup").run(str(graph))
        except RuntimeError as e:
            if 'Failed to compile' in e.args[0]:
                warnings.warn('CPU fuser test has failed! This is not a hard failure, '
                              'because the kernels sometimes trigger bugs in compilers '
                              '(most notably GCC 7.2).')
                raise unittest.SkipTest('Failed to compile') from e
            else:
                raise

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_milstm_cuda(self):
        inputs = get_milstm_inputs('cuda', training=True)
        module = self.checkScript(MiLSTMCell, inputs)
        forward_graph = module.graph_for(*inputs)
        self.assertGraphContainsExactly(
            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
            .check_next("return").check("FusionGroup").run(str(forward_graph))
        hy, cy = module(*inputs)
        warmup_backward((hy + cy).sum())

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor")
    def test_rand_cuda(self):
        class M(torch.jit.ScriptModule):
            __constants__ = ['d']

            def __init__(self) -> None:
                super().__init__()
                self.d = torch.device('cuda')

            @torch.jit.script_method
            def create(self, x):
                return x * x + x + torch.rand_like(x)

        x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
        m = M()
        out1 = m.create(x)
        out2 = m.create(x)
        self.assertNotEqual(out1, out2)
        self.assertTrue(torch.all(out1 >= 0))
        self.assertTrue(torch.all(out1 < 1))
        self.assertTrue(torch.all(out2 >= 0))
        self.assertTrue(torch.all(out2 < 1))
        self.assertAllFused(m.create.graph_for(x))
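
    # The two create() calls above must return different samples: a correctly fused
    # kernel re-draws the random values on every invocation instead of baking a fixed
    # sample into the generated code, and rand_like keeps all values in [0, 1).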

    @staticmethod
    def fn_test_relu(x, y):
        return F.relu(x + .5 * y)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_relu_cuda(self):
        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(self.fn_test_relu, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_erf_cuda(self):
        def fn_test_erf(x):
            return F.relu(torch.erf(x) - torch.erfc(x))

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        ge = self.checkTrace(fn_test_erf, (x,))
        self.assertAllFused(ge.graph_for(x))
        x.requires_grad_(True)
        ge = self.checkTrace(fn_test_erf, (x,))
        self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes",
                                                         "aten::_size_if_not_equal"))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor")
    def test_rand_broadcast_cuda(self):
        def fn_test_rand(x, y):
            r = torch.rand_like(y)
            return r * x + x

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
        script_f = torch.jit.script(fn_test_rand)
        out = script_f(x, y)
        self.assertAllFused(script_f.graph_for(x, y))
        x.requires_grad_(True)
        out = script_f(x, y)
        self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
                                                                  "aten::_size_if_not_equal"))
        # test that broadcasting random produces correct results
        x = torch.ones(4, 4, dtype=torch.float, device='cuda')
        y = torch.ones(4, dtype=torch.float, device='cuda')
        out = script_f(x, y)
        self.assertEqual(out[0], out[1])
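
    # In the broadcast check above, rand_like(y) has shape (4,), so the same random
    # row is broadcast against every row of the all-ones x; the output rows are
    # therefore identical, which is what assertEqual(out[0], out[1]) verifies.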

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_scalar(self):
        def fn(x, y):
            return 2 * x + y

        x = torch.tensor(0.1, dtype=torch.float, device='cpu')
        y = torch.tensor(1, dtype=torch.float, device='cpu')
        ge = self.checkScript(fn, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_small_constant_cuda(self):
        def fn_test_small_constant(x, y):
            return (1e-8 * x + 5e-9 * y) * 1e8

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(fn_test_small_constant, (x, y))
        self.assertAllFused(ge.graph_for(x, y))
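
    # The tiny literals above (1e-8, 5e-9) presumably guard against the fused kernel
    # mangling small float constants during code generation; if they were rounded or
    # flushed to zero, scaling back up by 1e8 would no longer match eager mode.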

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_tensor_scalar_ops_cuda(self):
        def should_fuse(x):
            z = 3.
            y = x + z
            return x * y

        # XXX: right now we only support fusing scalars if
        # they're constant (#9940)
        def should_not_fuse(x, z):
            y = x + int(z)
            return x * y

        inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
        ge = self.checkScript(should_fuse, inputs)
        self.assertAllFused(ge.graph_for(*inputs))

        inputs = [
            torch.randn(2, 2, dtype=torch.float, device='cuda'),
            torch.tensor(3., dtype=torch.float, device='cuda'),
        ]
        ge = self.checkScript(should_not_fuse, inputs)
        self.assertGraphContainsExactly(
            ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_where_and_typing(self):
        def f(x, y):
            mask = x > y
            res = torch.where(mask, x, y)
            return mask, res

        x = torch.randn(4, 4, dtype=torch.double)
        y = torch.randn(4, 4, dtype=torch.double)

        script_f = self.checkScript(f, (x, y))
        self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
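
    # test_where_and_typing exercises mixed output dtypes inside one fused kernel:
    # mask is a bool tensor while res stays double, and only the tuple construction
    # is exempted from the all-fused check.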

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
    def test_grad_sum_to_size_elimination(self):
        def my_broadcasted_cell(a, b, c):
            return (a + b) + c

        s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
        s2 = torch.randn(5, 5, requires_grad=True, device='cuda')

        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING)
        forward_graph = module.graph_for(s1, s1, s1)
        self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
                                                       "aten::_size_if_not_equal"))

        old_plans = set()
        for i in range(3):
            # if we have s2, then the s1 inputs are _grad_sum_to_size'd
            args = (s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
            args = [a.detach_().requires_grad_() for a in args]
            # recompile, so we don't trigger bailouts
            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING)
            res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
            warmup_backward(res.sum(), args)
            grads = torch.autograd.grad(res.sum(), args)
            for inp, gr in zip(args, grads):
                self.assertEqual(inp.shape, gr.shape)
            backward = None
            # workaround: the backward graphs are not returned in a deterministic
            # order on Python 2, so pick out the one we haven't seen yet
            for g in all_backward_graphs(module):
                if str(g) not in old_plans:
                    assert backward is None
                    backward = g
            old_plans.add(str(backward))
            num_grads = 1 if i > 0 else 0
            self.assertEqual(len([n for n in backward.nodes() if n.kind() == 'aten::_grad_sum_to_size']), num_grads)


if __name__ == '__main__':
    run_tests()