1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5import unittest 6 7import torch 8from torch.testing._internal.common_jit import check_against_reference 9from torch.testing._internal.common_utils import ( 10 enable_profiling_mode_for_profiling_tests, 11 GRAPH_EXECUTOR, 12 num_profiled_runs, 13 ProfilingMode, 14) 15 16 17# Make the helper files in test/ importable 18pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 19sys.path.append(pytorch_test_dir) 20from typing import List, Optional, Tuple 21 22from torch.testing import FileCheck 23from torch.testing._internal.jit_utils import ( 24 disable_autodiff_subgraph_inlining, 25 JitTestCase, 26) 27 28 29if __name__ == "__main__": 30 raise RuntimeError( 31 "This test file is not meant to be run directly, use:\n\n" 32 "\tpython test/test_jit.py TESTNAME\n\n" 33 "instead." 34 ) 35 36 37@unittest.skipIf( 38 GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients" 39) 40class TestAutodiffSubgraphSlicing(JitTestCase): 41 # TODO: It is better if we can test directly on graphs instead of the current 42 # end-to-end fashion. 43 def _perform_ad_subgraph_slicing(self, fn, *input_sizes): 44 with disable_autodiff_subgraph_inlining(): 45 with enable_profiling_mode_for_profiling_tests(): 46 ge = torch.jit.script(fn) 47 inputs = [torch.randn(size, requires_grad=True) for size in input_sizes] 48 ge(*inputs, profile_and_replay=True) 49 return ge.graph_for(*inputs) 50 51 def assertGraphSize(self, graph, size): 52 nodes = list( 53 filter( 54 lambda n: ( 55 n.kind() != "prim::BailOut" 56 and n.kind() != "prim::BailoutTemplate" 57 and n.kind() != "prim::TypeCheck" 58 and n.kind() != "prim::RequiresGradCheck" 59 ), 60 graph.nodes(), 61 ) 62 ) 63 self.assertEqual(len(list(nodes)), size) 64 65 def test_chunk_constant_script_ad(self): 66 @torch.jit.script 67 def func(x): 68 x1, x2 = torch.chunk(x, 2) 69 return (x1, x2) 70 71 input = torch.rand(6, 10).requires_grad_() 72 with disable_autodiff_subgraph_inlining(): 73 with enable_profiling_mode_for_profiling_tests(): 74 output = func(input, profile_and_replay=True) 75 FileCheck().check_not("prim::DifferentiableGraph").run( 76 func.graph_for(input) 77 ) 78 79 @unittest.skipIf( 80 GRAPH_EXECUTOR != ProfilingMode.PROFILING, 81 "This threshold is only valid for Profiling Executor", 82 ) 83 def test_diff_graph_inline_threshold(self): 84 with enable_profiling_mode_for_profiling_tests(): 85 NUM_RUNS = 1 86 with num_profiled_runs(NUM_RUNS): 87 88 @torch.jit.script 89 def foo(x): 90 # two nodes should be fused 91 # see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49 92 return torch.sigmoid(torch.sigmoid(x)) 93 94 @torch.jit.script 95 def bar(x): 96 # two nodes should NOT be fused 97 return torch.sigmoid(x) 98 99 input = torch.rand([4, 4], requires_grad=True) 100 foo(input) 101 foo(input) 102 103 bar(input) 104 bar(input) 105 106 self.assertGraphContainsExactly( 107 foo.graph_for(input), "prim::DifferentiableGraph", 1 108 ) 109 self.assertGraphContainsExactly( 110 bar.graph_for(input), "prim::DifferentiableGraph", 0 111 ) 112 113 def test_bias_as_module_attr(self): 114 with enable_profiling_mode_for_profiling_tests(): 115 116 class M(torch.nn.Module): 117 def __init__(self, has_bias): 118 super().__init__() 119 self.ll = torch.nn.Linear(10, 10, has_bias) 120 121 def forward(self, x, y): 122 return self.ll(x + y) * x + y 123 124 x = torch.rand(10, 10, requires_grad=True) 125 no_bias = M(False) 126 scripted_no_bias = torch.jit.script(no_bias) 127 scripted_no_bias(x, x) 128 scripted_no_bias(x, x) 129 scripted_no_bias(x, x) 130 has_bias = M(True) 131 check_against_reference( 132 self, 133 scripted_no_bias, 134 no_bias, 135 lambda x: x, 136 ( 137 x, 138 x, 139 ), 140 check_types=False, 141 ) 142 scripted_has_bias = torch.jit.script(has_bias) 143 scripted_has_bias(x, x) 144 scripted_has_bias(x, x) 145 scripted_has_bias(x, x) 146 check_against_reference( 147 self, 148 scripted_has_bias, 149 has_bias, 150 lambda x: x, 151 ( 152 x, 153 x, 154 ), 155 check_types=False, 156 ) 157 158 def test_constructed_bias(self): 159 with enable_profiling_mode_for_profiling_tests(): 160 161 def method1(x, weight, b1, b2): 162 bias = b1 * b2 163 return torch.nn.functional.linear(x, weight, bias) 164 165 N = 10 166 x = torch.rand(N, N, requires_grad=True) 167 weight = torch.rand(N, N, requires_grad=True) 168 b1 = torch.rand(N, N, requires_grad=True) 169 b2 = torch.rand(N, N, requires_grad=True) 170 scripted = self.checkScript(method1, (x, weight, b1, b2)) 171 # check_types requires last_graph on scripted to be set, so we just skip it 172 check_against_reference( 173 self, 174 scripted, 175 method1, 176 lambda x: x, 177 (x, weight, b1, b2), 178 check_types=False, 179 ) 180 181 def test_bias_as_arg(self): 182 with enable_profiling_mode_for_profiling_tests(): 183 184 def method1(x, weight, bias: Optional[torch.Tensor]): 185 return torch.nn.functional.linear(x, weight, bias).relu() + 2 186 187 N = 10 188 x = torch.rand(N, N, requires_grad=True) 189 weight = torch.rand(N, N, requires_grad=True) 190 bias = None 191 scripted = self.checkScript(method1, (x, weight, bias)) 192 # check_types requires last_graph on scripted to be set, so we just skip it 193 check_against_reference( 194 self, 195 scripted, 196 method1, 197 lambda x: x, 198 (x, weight, bias), 199 check_types=False, 200 ) 201 bias = torch.rand(N, N, requires_grad=True) 202 scripted = self.checkScript(method1, (x, weight, bias)) 203 # check_types requires last_graph on scripted to be set, so we just skip it 204 check_against_reference( 205 self, 206 scripted, 207 method1, 208 lambda x: x, 209 (x, weight, bias), 210 check_types=False, 211 ) 212 213 def test_requires_grad_for_tensor_list(self): 214 with enable_profiling_mode_for_profiling_tests(): 215 # output & var_list[0] should have requires_grad set to True 216 def func( 217 input0: torch.Tensor, input1: torch.Tensor 218 ) -> Tuple[torch.Tensor, List[torch.Tensor]]: 219 var_list = [input0, input1] 220 var = torch.cat(var_list) 221 output = var + 1.0 222 return output, var_list 223 224 jit_f = torch.jit.script(func) 225 input0 = torch.randn((2,), requires_grad=True) 226 input1 = torch.randn((2,)) 227 output_ref = func(input0, input1) 228 for i in range(2): 229 output = jit_f(input0, input1) 230 assert output_ref[0].requires_grad == output[0].requires_grad 231 assert output_ref[1][0].requires_grad == output[1][0].requires_grad 232 assert output_ref[1][1].requires_grad == output[1][1].requires_grad 233 234 @unittest.skip( 235 "disable until we property handle tensor lists with undefined gradients" 236 ) 237 def test_differentiable_graph_ops_requires_grad(self): 238 x = torch.randn(8, 2, dtype=torch.float).requires_grad_() 239 y = torch.randn(8, 2, dtype=torch.float) 240 241 def t(x: torch.Tensor, y: torch.Tensor, flag: bool): 242 o = x + 1.0 243 o1 = torch.relu(o) 244 o = y + 1.5 245 o2 = torch.relu(o) 246 o3 = o1 + o2 247 248 if flag: 249 o = o1 + 1.0 250 oo1 = torch.relu(o) 251 o = o2 + 2.5 252 oo2 = torch.relu(o) 253 oo3 = oo1 + oo2 254 else: 255 o = o1 * 1.0 256 oo1 = torch.relu(o) 257 o = o2 * 2.0 258 oo2 = torch.relu(o) 259 oo3 = oo1 + oo2 260 261 return o1, o2, o3, oo1, oo2, oo3 262 263 with enable_profiling_mode_for_profiling_tests(): 264 t_jit = torch.jit.script(t) 265 jit_o = t_jit(x, y, False) 266 jit_o = t_jit(x, y, False) 267 o = t(x, y, False) 268 269 FileCheck().check("prim::DifferentiableGraph").run( 270 t_jit.graph_for(x, y, False) 271 ) 272 # validate the differentiableGraphOps are marking proper requires_grad 273 for oo, jit_oo in zip(o, jit_o): 274 self.assertEqual(oo.requires_grad, jit_oo.requires_grad) 275 self.assertEqual(oo, jit_oo) 276 # one more runs to trigger fusion 277 jit_o = t_jit(x, y, False) 278 for oo, jit_oo in zip(o, jit_o): 279 self.assertEqual(oo.dtype, jit_oo.dtype) 280 self.assertEqual(oo.requires_grad, jit_oo.requires_grad) 281 self.assertEqual(oo, jit_oo) 282 283 @unittest.skipIf( 284 GRAPH_EXECUTOR == ProfilingMode.PROFILING, 285 "Simple Executor doesn't support gradients", 286 ) 287 def test_prune_grad(self): 288 @torch.jit.script 289 def t(input, bias): 290 return torch.nn.functional.relu(input + bias) 291 292 input = torch.randn(2, 8, requires_grad=True) 293 bias = torch.randn(8, requires_grad=False) # bias does NOT require grad 294 NUM_PROFILED_RUNS = 1 295 with num_profiled_runs(NUM_PROFILED_RUNS): 296 WARMUP = 3 # 2 runs to reach backward + 1 to optimize it 297 for x in range(WARMUP): 298 o = t(input, bias) 299 o.sum().backward() 300 301 fwd_plan = list(t.get_debug_state().execution_plans.values())[0] 302 bwd_graph = list( 303 fwd_plan.code.grad_executor_states()[0].execution_plans.values() 304 )[0].graph 305 tup = next(bwd_graph.outputs()) 306 self.assertEqual(len(list(tup.node().inputs())), 1) 307 308 def test_simple_merge(self): 309 # o --> o 310 def fn(x, y, z): 311 a = x * y 312 b = a * z 313 return b 314 315 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1) 316 317 self.assertGraphSize(graph, 1) 318 self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1) 319 320 def test_simple_no_merge(self): 321 # o: autodiff supported. x: not autodiff supported. 322 # o --> x 323 def fn(x, y, z): 324 a = x * y 325 b = torch.zeros([abs(int(y))]) 326 return a, b 327 328 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1) 329 g_str = str(graph) 330 FileCheck().check("aten::Int").check("aten::zeros").check_not("aten::mul").run( 331 g_str[0 : g_str.find("return")] 332 ) 333 self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1) 334 335 def test_does_not_merge_unrelated(self): 336 # o o 337 def fn(w, x, y, z): 338 a = x * y 339 b = w * z 340 return a, b 341 342 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1) 343 344 self.assertGraphSize(graph, 3) 345 self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2) 346 347 def test_merges_without_cycles(self): 348 # o --> o --> o 349 # | ^ 350 # \_________/ 351 def fn(w, x, y): 352 a = w * x 353 b = a * y 354 c = a * b 355 return c 356 357 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1) 358 359 self.assertGraphSize(graph, 1) 360 self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1) 361 362 def test_merges_dense(self): 363 # o o 364 # |\ /| 365 # | \ / | 366 # | /\ | 367 # vv vv 368 # o o 369 def fn(x, y): 370 a, b = x.chunk(2) 371 c, d = y.chunk(2) 372 return a + c, b + d 373 374 graph = self._perform_ad_subgraph_slicing(fn, 2, 2) 375 376 self.assertGraphSize(graph, 2) 377 self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1) 378 379 def test_does_not_create_cycles(self): 380 # o --> x --> o 381 # | ^ 382 # \_________/ 383 def fn(w, x, y): 384 a = w * x 385 b = torch.zeros(abs(int(a))) 386 c = a * b 387 return c 388 389 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1) 390 self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2) 391 392 def test_merges_up(self): 393 # o --> x o 394 # | ^ 395 # \_________/ 396 def fn(w, x, y, z): 397 a = w * x 398 b = torch.zeros(abs(int(y))) 399 c = a * z 400 return b, c 401 402 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1) 403 g_str = str(graph) 404 FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")]) 405 self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1) 406 407 def test_merges_down(self): 408 # o x --> o 409 # | ^ 410 # \_________/ 411 def fn(v, w, x, y): 412 a = v * w 413 b = torch.ones(int(y)) 414 c = b * a 415 return a, c 416 417 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1) 418 419 num_nodes = 4 if GRAPH_EXECUTOR == ProfilingMode.PROFILING else 3 420 # add moved down 421 g_str = str(graph) 422 FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")]) 423 self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1) 424 425 def test_respects_lexical_scoping(self): 426 def fn(x, k): 427 y = x * 1.1 428 if bool(k): 429 k = k + y 430 z = y * k 431 return z, k 432 433 graph = self._perform_ad_subgraph_slicing(fn, 1, 1) 434 # We should not have combined the two multiplications into 435 # the same group; they should each be a separate DiffGraph 436 self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 3) 437 438 def test_merge_respects_aliasing(self): 439 def fn(x, k, cond): 440 y = x * 1.1 441 y = y * k 442 y = y * 2.2 443 if bool(cond): 444 z1 = y[0] 445 z2 = y[1] 446 z1.add_(3) 447 out = z2 + k + 3.3 448 out = out * out 449 return out 450 451 graph = self._perform_ad_subgraph_slicing(fn, [2, 2], [2, 2], 1) 452 # z2 did did not get merged into the subgraph 453 FileCheck().check("prim::If").check("aten::select").check_next( 454 "aten::select" 455 ).check_next("aten::add_").check("Differentiable").run(graph) 456 self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2) 457 458 def test_aliased_outputs(self): 459 with enable_profiling_mode_for_profiling_tests(): 460 # Case 1: aliasing between relu and t 461 # is within a DifferentiableGraph. It should be valid 462 # to merge both split_with_sizes in relu in one graph 463 input_str = """ 464 graph(%a : Tensor): 465 %b : Tensor = aten::relu(%a) 466 %2 : Tensor = aten::t(%b) 467 return (%2) 468 """ 469 470 graph = torch._C.parse_ir(input_str) 471 torch._C._jit_pass_create_autodiff_subgraphs(graph, 1) 472 FileCheck().check("with prim::DifferentiableGraph").check( 473 "aten::relu" 474 ).check("aten::t").run(graph) 475 476 # Case 2: aliasing between relu and split_with_sizes 477 # are both outputs of a Diff graph. It should be invalid 478 # to merge both split_with_sizes in relu in one graph 479 # i.e. relu and split_with_sizes should be in different 480 # differentiable graphs 481 input_str = """ 482 graph(%a : Tensor): 483 %b : Tensor = aten::relu(%a) 484 %0 : int[] = prim::Constant[value=[2, 2, 1]]() 485 %1 : int = prim::Constant[value=0]() 486 %2 : Tensor[] = aten::split_with_sizes(%b, %0, %1) 487 %3 : (Tensor[], Tensor[]) = prim::TupleConstruct(%b, %2) 488 return (%3) 489""" 490 491 graph = torch._C.parse_ir(input_str) 492 torch._C._jit_pass_create_autodiff_subgraphs(graph, 1) 493 FileCheck().check("Tensor = prim::DifferentiableGraph").check( 494 "with prim::DifferentiableGraph" 495 ).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run( 496 graph 497 ) 498 499 # Case 3: two aliased nodes in a graph. 500 # Both `split_with_sizes` should be unfused 501 input_str = """ 502 graph(%a : Tensor): 503 %b : Tensor = aten::relu(%a) 504 %s1 : int[] = prim::Constant[value=[2, 2, 1]]() 505 %s2 : int[] = prim::Constant[value=[3, 1]]() 506 %1 : int = prim::Constant[value=0]() 507 %2 : Tensor[] = aten::split_with_sizes(%b, %s1, %1) 508 %3 : Tensor[] = aten::split_with_sizes(%b, %s2, %1) 509 %4 : (Tensor, Tensor[]) = prim::TupleConstruct(%b, %2, %3) 510 return (%4) 511""" 512 513 graph = torch._C.parse_ir(input_str) 514 torch._C._jit_pass_create_autodiff_subgraphs(graph, 1) 515 FileCheck().check("Tensor = prim::DifferentiableGraph").check( 516 "with prim::DifferentiableGraph" 517 ).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run( 518 graph 519 ) 520 521 # Case 4: the aliased output has a descendant 522 # Both should be unfused. Note, %3 comes before %2 523 # to test that we unfuse in the reverse topo order 524 input_str = """ 525 graph(%a : Tensor): 526 %b : Tensor = aten::relu(%a) 527 %0 : int[] = prim::Constant[value=[2, 2, 1]]() 528 %1 : int = prim::Constant[value=0]() 529 %2 : Tensor = aten::t(%b) 530 %3 : Tensor = aten::relu(%2) 531 %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%b, %3, %2) 532 return (%4) 533""" 534 535 graph = torch._C.parse_ir(input_str) 536 torch._C._jit_pass_create_autodiff_subgraphs(graph, 1) 537 FileCheck().check("Tensor = prim::DifferentiableGraph").check( 538 "with prim::DifferentiableGraph" 539 ).check("Tensor = aten::relu").check_not("aten::t").run(graph) 540 541 # Case 5: multiple aliased groups 542 # Both should be unfused. Note, %3 comes before %2 543 # to test that we unfuse in the reverse topo order 544 input_str = """ 545 graph(%a : Tensor): 546 %b : Tensor = aten::relu(%a) 547 %c : Tensor = aten::abs(%a) 548 %0 : int[] = prim::Constant[value=[2, 2, 1]]() 549 %1 : int = prim::Constant[value=0]() 550 %d : Tensor = aten::t(%c) 551 %2 : Tensor = aten::t(%b) 552 %3 : Tensor = aten::relu(%2) 553 %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%3, %2, %d, %b, %c, %b) 554 return (%4) 555""" 556 557 graph = torch._C.parse_ir(input_str) 558 torch._C._jit_pass_create_autodiff_subgraphs(graph, 1) 559 FileCheck().check("Tensor = prim::DifferentiableGraph").check( 560 "with prim::DifferentiableGraph" 561 ).check("Tensor = aten::relu").check_not("aten::t").run(graph) 562 563 def test_has_profiled_info_aliasing_outputs(self): 564 # The expectation is that CallFunction will prevent the final profile node from 565 # getting merged into the DifferentiableGraph, and that create_autodiff_subgraphs 566 # will instead add this to the type for %4. 567 ir = """ 568 graph(%a : Tensor): 569 %1 : Tensor = prim::profile[profiled_type=Float(requires_grad=0)](%a) 570 %2 : Tensor = aten::relu(%1) 571 %3 : Tensor = prim::profile[profiled_type=Float(requires_grad=0)](%2) 572 %4 : Tensor = aten::relu(%3) 573 %5 : Tensor = prim::CallFunction(%4) 574 %6 : Tensor = prim::profile[profiled_type=Float(requires_grad=0)](%4) 575 return (%6) 576 """ 577 578 graph = torch._C.parse_ir(ir) 579 torch._C._jit_pass_create_autodiff_subgraphs(graph) 580 581 for n in graph.nodes(): 582 if n.kind() == "prim::DifferentiableGraph": 583 diff_graph = n.g("Subgraph") 584 585 outputs = list(diff_graph.outputs()) 586 self.assertEqual(1, len(outputs)) 587 output = outputs[0] 588 self.assertEqual(False, output.requiresGrad()) 589 590 FileCheck().check("= prim::DifferentiableGraph").check( 591 "with prim::DifferentiableGraph" 592 ).check(" = aten::relu").check("requires_grad=0").check("aten::relu").run(graph) 593