1# Owner(s): ["module: inductor"] 2import copy 3import os 4import random 5 6import torch 7from torch import nn 8from torch._dynamo.utils import same 9from torch._inductor import config 10from torch._inductor.test_case import run_tests, TestCase 11from torch.testing._internal.common_cuda import tf32_off 12from torch.testing._internal.inductor_utils import HAS_CUDA 13 14 15USE_DDP_WRAPPER = os.environ.get("USE_DDP_WRAPPER", "1") == "1" 16 17 18class Model2Conv(nn.Module): 19 def __init__(self, dim=512, manual_graph_break=False): 20 super().__init__() 21 self.conv1 = nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False) 22 self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False) 23 self.manual_graph_break = manual_graph_break 24 25 def forward(self, x): 26 x = self.conv1(x) 27 if self.manual_graph_break: 28 torch._dynamo.graph_break() 29 x = self.conv2(x) 30 return x 31 32 def get_example_inputs(self): 33 return (torch.rand(2, 3, 16, 16),) 34 35 36class TestLayoutOptim(TestCase): 37 @classmethod 38 def setUpClass(cls): 39 super().setUpClass() 40 41 import torch.distributed as dist 42 43 # not use a fixed port for stress test 44 tot_retry = 5 45 for retry_no in range(tot_retry): 46 try: 47 port = random.randint(10000, 60000) 48 dist.init_process_group( 49 backend="nccl", 50 init_method=f"tcp://localhost:{port}", 51 world_size=1, 52 rank=0, 53 ) 54 break 55 except RuntimeError: 56 if retry_no == tot_retry - 1: 57 raise 58 else: 59 continue 60 61 def verify_accuracy( 62 self, model_class, use_ddp_wrapper=USE_DDP_WRAPPER, is_train=False 63 ): 64 # there are 2 potential ways to introduce graph breaks 65 # 1. manually 66 # 2. using DDP 67 # if we are not using DDP to introduce graph breaks, do that manually 68 def wrap_mod(m): 69 if is_train: 70 71 def f(*inp): 72 x = m(*inp) 73 x.sum().backward() 74 75 grads = [] 76 for name, param in m.named_parameters(): 77 grad = param.grad 78 if param.grad is None: 79 grad = torch.zeros_like(param) 80 grads.append(grad) 81 return grads 82 83 return f 84 else: 85 return m 86 87 manual_graph_break = not use_ddp_wrapper 88 mod = model_class(manual_graph_break=manual_graph_break).cuda() 89 inp = [t.cuda() for t in mod.get_example_inputs()] 90 expected_out = wrap_mod(mod)(*inp) 91 92 fp64_mod = copy.deepcopy(mod).to(torch.float64) 93 fp64_inp = [t.to(torch.float64) for t in copy.deepcopy(inp)] 94 fp64_out = wrap_mod(fp64_mod)(*fp64_inp) 95 96 if use_ddp_wrapper: 97 from torch.nn.parallel import DistributedDataParallel as DDP 98 99 ddp_wrapped_mod = DDP(mod) 100 opt_mod = torch.compile(wrap_mod(ddp_wrapped_mod)) 101 else: 102 opt_mod = torch.compile(wrap_mod(mod)) 103 actual_out = opt_mod(*inp) 104 105 if is_train: 106 self.assertTrue(same(expected_out, actual_out, fp64_ref=fp64_out)) 107 else: 108 expected_sum = expected_out.sum() 109 actual_sum = actual_out.sum() 110 print(f"Expected sum {expected_sum}, actual sum {actual_sum}") 111 self.assertTrue(same(expected_out, actual_out, fp64_ref=fp64_out)) 112 113 def verify_accuracy_for_infer(self, *args, **kwargs): 114 self.verify_accuracy(*args, **kwargs, is_train=False) 115 116 def verify_accuracy_for_train(self, *args, **kwargs): 117 self.verify_accuracy(*args, **kwargs, is_train=True) 118 119 def test_2conv_with_graph_break(self): 120 """ 121 Make sure graph break does not cause any accuracy issue. 
122 """ 123 self.verify_accuracy_for_infer(Model2Conv) 124 125 def test_3conv_with_graph_break(self): 126 class Model(nn.Module): 127 def __init__( 128 self, dim=512, patch_size=7, kernel_size=7, manual_graph_break=False 129 ): 130 super().__init__() 131 self.seq = nn.Sequential( 132 nn.Conv2d( 133 3, dim, kernel_size=patch_size, stride=patch_size, bias=False 134 ), 135 nn.Conv2d( 136 dim, dim, kernel_size, groups=dim, padding="same", bias=False 137 ), 138 ) 139 self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=False) 140 self.manual_graph_break = manual_graph_break 141 142 def forward(self, x): 143 x = self.seq(x) 144 if self.manual_graph_break: 145 torch._dynamo.graph_break() 146 x = self.conv(x) 147 return x 148 149 def get_example_inputs(self): 150 return (torch.randn(2, 3, 16, 16),) 151 152 self.verify_accuracy_for_infer(Model) 153 154 @torch.no_grad() 155 def test_keep_output_layout_infer(self): 156 class Model(nn.Module): 157 def __init__(self) -> None: 158 super().__init__() 159 self.conv = nn.Conv2d( 160 3, 128, kernel_size=3, padding=1, stride=1, bias=False 161 ) 162 163 def forward(self, x): 164 x = self.conv(x) 165 return x 166 167 def get_example_inputs(self): 168 return (torch.randn(2, 3, 5, 5),) 169 170 mod = Model().cuda() 171 inp = [t.cuda() for t in mod.get_example_inputs()] 172 out = mod(*inp) 173 174 opt_mod = torch.compile(mod) 175 opt_out = opt_mod(*inp) 176 177 # We should be able to do view on eager output 178 out.view(5, -1) 179 180 # We should be able to do view on the output of the optimized module 181 # Note that if the output is channels last, the view op will fail. 182 opt_out.view(5, -1) 183 184 def test_keep_output_layout_with_freezing(self): 185 with config.patch( 186 { 187 "freezing": True, 188 } 189 ): 190 self.test_keep_output_layout_infer() 191 192 def test_training_acc(self): 193 self.verify_accuracy_for_train(Model2Conv) 194 195 def test_mutate_view(self): 196 """ 197 The GraphModule passed to GraphLowering init method is like: 198 https://gist.github.com/shunting314/07228313fd017e2267101ff32edc6d64 199 200 It shows that we will call copy_ to update the argument in the end. This 201 guarantees the correctnesss. 202 """ 203 204 @torch.compile 205 def f(x): 206 y = x.view(3, 2) 207 y.mul_(2) 208 209 x = torch.ones(2, 3).cuda() 210 f(x) 211 self.assertTrue(torch.equal(x, torch.ones(2, 3).cuda() * 2)) 212 213 def test_mutate_base(self): 214 """ 215 The GraphModule passed to GraphLowering init method is like: 216 https://gist.github.com/shunting314/fd60fe11d1f844c6db76aba7b06811bc 217 218 It shows that the output of the graph is the mul node which contains 219 the update we applied to the base tensor. 
220 """ 221 222 @torch.compile 223 def f(x): 224 y = x.view(3, 2) 225 x.mul_(2) 226 return y 227 228 x = torch.ones(2, 3).cuda() 229 y = f(x) 230 self.assertTrue(torch.equal(y, torch.ones(3, 2).cuda() * 2)) 231 232 @tf32_off() 233 def test_mutate_base_for_conv_output(self): 234 class Model(nn.Module): 235 def __init__(self, manual_graph_break=False): 236 super().__init__() 237 self.conv = nn.Conv2d(3, 512, kernel_size=3, stride=2, bias=False) 238 239 def forward(self, x): 240 x = self.conv(x) 241 y = x.view(-1) 242 x.mul_(2) 243 return y 244 245 def get_example_inputs(self): 246 return (torch.rand(2, 3, 16, 16),) 247 248 self.verify_accuracy_for_infer(Model) 249 250 @tf32_off() 251 def test_mutate_view_for_conv_output(self): 252 class Model(nn.Module): 253 def __init__(self, manual_graph_break=False): 254 super().__init__() 255 self.conv = nn.Conv2d(3, 512, kernel_size=3, stride=2, bias=False) 256 257 def forward(self, x): 258 x = self.conv(x) 259 y = x.view(-1) 260 y.mul_(2) 261 return x 262 263 def get_example_inputs(self): 264 return (torch.rand(2, 3, 16, 16),) 265 266 self.verify_accuracy_for_infer(Model) 267 268 def test_dynamic_shape_specialization(self): 269 """ 270 Previously in aot_autograd.py we compare strides of FakeTensor 271 with real tensor. That cause dynamic dimensions of the FakeTensor 272 being specialized to static shapes. This test protects against that. 273 """ 274 275 def f(a, b): 276 x = a.sin() 277 y = b.cos() 278 z = x + y 279 return z 280 281 for size in [4, 8, 16]: 282 a = torch.randn(2, size, requires_grad=True).cuda() 283 b = torch.randn(2, size).cuda() 284 actual = torch.compile(f, dynamic=True)(a, b) 285 self.assertTrue(torch.allclose(f(a, b), actual)) 286 287 # Trigger the compiling of the backward graph 288 actual.sum().backward() 289 290 def test_nll_loss_backward(self): 291 """ 292 Repro for issue https://github.com/pytorch/pytorch/issues/120759 293 294 The CUDA implementation of aten.nll_loss2d_backward.default requires 295 the self tensor (whose layout will be used to create grad_input) 296 to be contiguous. Layout optimization may change the self tensor's layout 297 and cause failure. We fix that by adding layout constaints to the 298 fallback of aten.nll_loss2d_backward.default . 299 """ 300 301 class MyModel(torch.nn.Module): 302 def __init__(self, input_dim, num_classes): 303 super().__init__() 304 self.conv = torch.nn.Conv2d(1, num_classes, 3, 1, padding="same") 305 self.out = torch.nn.Linear(input_dim * num_classes, num_classes) 306 307 def forward(self, x: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 308 x = self.conv(x) 309 b, c, t, f = x.size() 310 x = self.out(x.reshape(b, t, c * f)) 311 logits = x.reshape(x.size(0), x.size(2), x.size(1)) 312 loss = torch.nn.functional.cross_entropy(logits, targets) 313 return loss 314 315 device = "cuda" 316 batch_size = 48 317 seq_len = 144 318 input_dim = 39 319 num_classes = 111 320 321 model = MyModel(input_dim, num_classes) 322 model.to(device) 323 324 opt_model = torch.compile(model) 325 326 x = torch.ones((batch_size, 1, seq_len, input_dim), device=device) 327 targets = torch.randint( 328 0, num_classes - 1, (batch_size, seq_len), device=device, dtype=torch.int64 329 ) 330 331 loss = model(x, targets) 332 loss.backward() 333 334 ref = model(x, targets) 335 self.assertTrue(torch.allclose(ref, loss)) 336 337 338if __name__ == "__main__": 339 if HAS_CUDA: 340 run_tests() 341