# Owner(s): ["oncall: jit"]

import os
import sys

import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


@skipIfTorchDynamo()
class TestProfiler(JitTestCase):
    def setUp(self):
        self.prev_exec = torch._C._jit_set_profiling_executor(True)
        self.prev_profiling = torch._C._get_graph_executor_optimize(True)
        self.inline_autodiff = torch._C._debug_set_autodiff_subgraph_inlining(False)
        self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
        self.can_fuse_on_cpu = torch._C._jit_can_fuse_on_cpu()
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        self.default_dtype = torch.get_default_dtype()
        self.old_reduction_enabled = torch._C._jit_set_texpr_reductions_enabled(True)
        torch.set_default_dtype(torch.double)
        self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
        torch._C._debug_set_fusion_group_inlining(False)
        self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
        torch._C._jit_set_te_must_use_llvm_cpu(False)

    def tearDown(self):
        torch._C._jit_set_profiling_executor(self.prev_exec)
        torch._C._get_graph_executor_optimize(self.prev_profiling)
        torch._C._debug_set_autodiff_subgraph_inlining(self.inline_autodiff)
        torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
        torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu)
        torch.set_default_dtype(self.default_dtype)
        torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
        torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
        torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)

    def test_tensor_type_not_determined_by_inputs(self):
        @torch.jit.script
        def scalar_type_input(x, y, z):
            return x + y + 4 + z.item()

        x = torch.tensor([2, 2])
        scalar_type_input(x, x, torch.tensor(1))
        scalar_type_input(x, x, torch.tensor(1))
        scalar_type_input(x, x, torch.tensor(1.0))
        g = torch.jit.last_executed_optimized_graph()

        # item & add should not get pulled into the fusion group -
        # we expect to see Fusion Group (item / add) Fusion Group in the IR dump
        FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next(
            "Tensor = aten::add"
        ).check("TensorExpr").run(g)

        @torch.jit.script
        def non_const_dtype(x, y, cond: bool):
            dtype = torch.int16 if cond else torch.int32
            return (x + y + 3).sum(dtype=dtype)

        non_const_dtype(x, x, True)
        non_const_dtype(x, x, True)
        g = torch.jit.last_executed_optimized_graph()
        # because dtype is non-const, sum should not get pulled into the Fusion Group
        FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(
            g
        )

    def test_specialize_backward(self):
        def test_fuse(a, b):
            c = a * b
            d = c * b
            return d

        test_fuse.__disable_jit_function_caching__ = True

        scripted_f = torch.jit.script(test_fuse)
        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        scripted_f(x, y)
        b = scripted_f(x, y)
        warmup_backward(b)
        g = torch.jit.last_executed_optimized_graph()
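        # warmup_backward runs backward repeatedly so the autodiff backward
        # graph gets profiled and optimized; the graph queried above is
        # therefore the optimized backward graph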
        # Backward has an if node guarding specializations,
        # within the if node true block there is only one if node
        # that guards a tensorexpr group
        optimized_block = next(g.findNode("prim::If").blocks())
        if_nodes = list(optimized_block.findAllNodes("prim::If"))

        self.assertEqual(len(if_nodes), 1)
        FileCheck().check("Group[Subgraph").run(str(if_nodes[0]))
        # no broadcasts occurred, sum_to_size has been specialized out
        self.assertIsNone(optimized_block.findNode("aten::_grad_sum_to_size"))

        broadcast_f = torch.jit.script(test_fuse)
        x = torch.ones([2, 2], requires_grad=True)
        y = torch.ones([1], requires_grad=True)
        broadcast_f(x, y)
        b = broadcast_f(x, y)
        b.backward(torch.ones([2, 2], dtype=torch.float), retain_graph=True)
        b.backward(torch.ones([2, 2], dtype=torch.float))
        # warmup_backward(b, torch.ones([2, 2], dtype=torch.float))
        g = torch.jit.last_executed_optimized_graph()
        optimized_block = next(g.findNode("prim::If").blocks())
        # broadcasts occurred, so we currently expect to see aten::_grad_sum_to_size
        self.assertIsNotNone(optimized_block.findNode("aten::_grad_sum_to_size"))

    def test_specialized_types(self):
        @torch.jit.script
        def test_fuse(a, b):
            c = a * b
            d = c * b
            return d

        x = torch.tensor([0.5])
        for _ in range(3):
            test_fuse(x, x)

        g = torch.jit.last_executed_optimized_graph()
        # Types should remain specialized for typecheck outputs & fusion outputs
        FileCheck().check("Double(").check_same("prim::TypeCheck").check_same(
            "\n"
        ).check("Double").check_same("TensorExpr").run(g)

        # other outputs should not be specialized
        FileCheck().check("Tensor = prim::If").run(g)

    def test_aliasing_merge(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            d.add_(b)
            e = d * b
            return d + e

        x = torch.ones(1)
        y = torch.ones(1)
        foo(x, y)
        b = foo(x, y)
        g = torch.jit.last_executed_optimized_graph()
        self.assertEqual(len(list(g.findAllNodes("prim::TypeCheck"))), 2)
        FileCheck().check("TensorExpr").check("aten::add_").check("TensorExpr").run(g)

    def test_use_not_profiled(self):
        def foo(t1, t2, t3, t4, t: float):
            h = t1 + t2 + t3 + t4
            if t > 0.5:
                # Putting a use of t1 in a never-executed conditional prevents
                # that use from being profiled; the adds should still fuse
                return t1 + 1
            return h

        t = torch.rand(8, dtype=torch.float)

        foo_script = torch.jit.script(foo)
        for _ in range(torch._C._jit_get_num_profiled_runs() + 1):
            foo_script(t, t, t, t, 0.1)

        self.assertEqual(foo(t, t, t, t, 0.1), foo_script(t, t, t, t, 0.1))
        g = torch.jit.last_executed_optimized_graph()
        # all adds fused
        FileCheck().check("graph").check_not("aten::add").check("prim::If").run(g)

    def test_not_fusing_scalar_ops(self):
        @torch.jit.script
        def foo(x: int, y: int):
            return x + y + 2 + 4 + 5 + 6

        foo(1, 2)
        foo(2, 3)
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("TensorExpr").run(g)

    def test_not_optimizing_property(self):
        @torch.jit.script
        def foo(x, y):
            return x + y + 1 + 2 + 3, x.size()

        x = torch.ones(1)
        foo(x, x)
        foo(x, x)
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("aten::size").run(g)
        x = torch.ones([2, 3, 5])
        self.assertEqual(foo(x, x), (x + x + 1 + 2 + 3, x.size()))
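    # The two tests below exercise the fallback path: when a specialization's
    # prim::TypeCheck guard fails, execution is expected to route to an
    # unspecialized fallback function (visible as CallFunction /
    # fallback_function in the optimized graph) whose outputs are plain
    # Tensors rather than shape-specialized types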
    def test_fallback_graph_not_specialized(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            e = d * b
            return d + e

        x = torch.ones(1)
        y = torch.ones(1)
        foo(x, y)
        foo(x, y)
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(
            g
        )

    def test_autograd_fallback_graph(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            e = d * b
            return d + e

        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        foo(x, y)
        b = foo(x, y)
        b.backward(torch.ones([1], dtype=torch.float), retain_graph=True)
        b.backward(torch.ones([1], dtype=torch.float))

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("fallback_function").check_next("CallFunction").run(g)

    def test_tensor_constant(self):
        def foo(a, b):
            return a + b + torch.tensor([2])

        x = torch.ones(1, requires_grad=False)
        foo_script = torch.jit.script(foo)
        foo_script(x, x)
        foo_script(x, x)

        self.assertEqual(foo_script(x, x), foo(x, x))
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_count("aten::add", 2, exactly=True).run(g)

    def test_local_fusion_strategy(self):
        @torch.jit.script
        def foo(x):
            return x + x + x

        torch.jit.set_fusion_strategy([("STATIC", 1)])
        for _ in range(3):
            foo(torch.rand([10]))

        torch.jit.set_fusion_strategy([("STATIC", 10)])

        for i in range(10):
            foo(torch.rand([i]))
            foo(torch.rand([i]))

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_count(":TensorExprGroup", 2, exactly=True).run(g)

    def test_iterative_fusion(self):
        @torch.jit.script
        def foo(a, b, c, d):
            a = a + b
            b.add_(3)
            c = c + b + d
            a = a + 1
            return a, c

        x = torch.ones(1, requires_grad=False)
        foo(x, x, x, x)
        foo(x, x, x, x)

        # when we iterate through the block, we start by fusing a = a + b
        # with a = a + 1; if we were to continue iterating from that fusion
        # point, we would miss the fusion opportunity of c = c + b + d

        g = torch.jit.last_executed_optimized_graph()
        self.assertEqual(len(list(g.findAllNodes("prim::TensorExprGroup"))), 2)