1# Owner(s): ["module: inductor"] 2import functools 3import importlib 4import itertools 5import os 6import sys 7 8import torch 9from torch import nn 10from torch._inductor import config as inductor_config 11from torch.testing._internal.common_cuda import TEST_CUDNN 12 13 14# Make the helper files in test/ importable 15pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 16sys.path.append(pytorch_test_dir) 17 18from inductor.test_inductor_freezing import TestCase 19from inductor.test_torchinductor import check_model, check_model_gpu, copy_tests 20from torch.testing._internal.common_utils import TEST_WITH_ASAN 21from torch.testing._internal.inductor_utils import skipCUDAIf 22 23 24importlib.import_module("functorch") 25importlib.import_module("filelock") 26 27from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU 28 29 30aten = torch.ops.aten 31 32 33class BinaryFoldingTemplate(TestCase): 34 @skipCUDAIf(TEST_CUDNN, "CUDNN has accuracy issues for this test") 35 def test_conv_binary_folding(self): 36 @torch.no_grad() 37 def test_conv_fusion(use_bias, module, op, scalar, add_tensor, expect_success): 38 class ConvOp(nn.Module): 39 __constants__ = ["use_scalar"] 40 41 def __init__(self, in_channels, out_channels, device, **kwargs): 42 super().__init__() 43 self.conv = module( 44 in_channels, out_channels, bias=use_bias, **kwargs 45 ).to(device) 46 self.conv2 = module( 47 in_channels, out_channels, bias=use_bias, **kwargs 48 ).to(device) 49 self.use_scalar = scalar 50 tensor_size = [1 for _ in range(self.conv.weight.ndim)] 51 tensor_size[1] = self.conv.weight.size(0) 52 self.tensor = torch.nn.Parameter( 53 add_tensor 54 if add_tensor is not None 55 else torch.rand(tensor_size).to(device) 56 ) 57 self.op = op 58 59 def forward(self, x): 60 x = self.conv(x) 61 if self.use_scalar: 62 return self.op(x, 2.0) 63 else: 64 return self.op(x, self.tensor) 65 66 from torch._inductor.compile_fx import compile_fx, compile_fx_inner 67 68 aten_binary = { 69 torch.add: aten.add.Tensor, 70 torch.sub: aten.sub.Tensor, 71 torch.mul: aten.mul.Tensor, 72 torch.div: aten.div.Tensor, 73 } 74 n_binary_ops = 0 75 76 def my_inner_compile(gm, example_inputs, *args, **kwargs): 77 out = compile_fx_inner(gm, example_inputs, *args, **kwargs) 78 nonlocal n_binary_ops 79 binarry_ops = [n for n in gm.graph.nodes if n.target == aten_binary[op]] 80 n_binary_ops += len(binarry_ops) 81 return out 82 83 torch._dynamo.reset() 84 mod_eager = ConvOp(3, 32, self.device, kernel_size=3, stride=2).eval() 85 out_optimized = torch.compile( 86 mod_eager, 87 backend=functools.partial(compile_fx, inner_compile=my_inner_compile), 88 ) 89 90 inps = [4, 3, 4] 91 if module == nn.Conv2d: 92 inps.append(inps[-1]) 93 if module == nn.Conv3d: 94 inps.append(inps[-1]) 95 inps.append(inps[-1]) 96 97 torch.manual_seed(1234) 98 inp = torch.rand(inps).to(self.device) 99 out_eager = mod_eager(inp) 100 out_optimized = out_optimized(inp) 101 self.assertEqual(out_optimized, out_eager) 102 if expect_success: 103 self.assertTrue(n_binary_ops == 0) 104 else: 105 self.assertTrue(n_binary_ops == 1) 106 107 conv_bias = [True, False] 108 modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d] 109 use_scalar = [True, False] 110 ops = [torch.add, torch.sub, torch.mul, torch.div] 111 for use_bias, module, pytorch_op, scalar in itertools.product( 112 conv_bias, modules, ops, use_scalar 113 ): 114 # TODO: support scalar case 115 expect_success = not scalar 116 test_conv_fusion( 117 use_bias, 118 module, 119 pytorch_op, 120 scalar, 121 
add_tensor=None, 122 expect_success=expect_success, 123 ) 124 125 for use_bias, pytorch_op in itertools.product(conv_bias, ops): 126 # broadcasting add 127 test_conv_fusion( 128 use_bias, 129 nn.Conv2d, 130 pytorch_op, 131 False, 132 add_tensor=torch.rand( 133 32, 134 1, 135 32, 136 ).to(self.device), 137 expect_success=False, 138 ) 139 140 # broadcasting add 141 test_conv_fusion( 142 use_bias, 143 nn.Conv2d, 144 pytorch_op, 145 False, 146 add_tensor=torch.rand(1, 1).to(self.device), 147 expect_success=True, 148 ) 149 150 # add with different dtype 151 test_conv_fusion( 152 use_bias, 153 nn.Conv2d, 154 pytorch_op, 155 False, 156 add_tensor=torch.tensor([2]).to(torch.float64).to(self.device), 157 expect_success=False, 158 ) 159 160 @inductor_config.patch({"freezing": True}) 161 def test_conv_bn_folding(self): 162 @torch.no_grad() 163 def test_conv_fusion(use_bias, module, expect_success): 164 class ConvOp(nn.Module): 165 def __init__(self, in_channels, out_channels, device, **kwargs): 166 super().__init__() 167 self.conv = module[0]( 168 in_channels, out_channels, bias=use_bias, **kwargs 169 ).to(device) 170 self.bn = module[1](out_channels).to(device) 171 172 def forward(self, x): 173 x = self.conv(x) 174 return self.bn(x) 175 176 from torch._inductor.compile_fx import compile_fx, compile_fx_inner 177 178 aten_binary = [ 179 aten.add.Tensor, 180 aten.sub.Tensor, 181 aten.mul.Tensor, 182 aten.div.Tensor, 183 ] 184 n_binary_ops = 0 185 186 def my_inner_compile(gm, example_inputs, *args, **kwargs): 187 out = compile_fx_inner(gm, example_inputs, *args, **kwargs) 188 nonlocal n_binary_ops 189 binarry_ops = [n for n in gm.graph.nodes if n.target in aten_binary] 190 n_binary_ops += len(binarry_ops) 191 return out 192 193 torch._dynamo.reset() 194 mod_eager = ConvOp(3, 32, self.device, kernel_size=3, stride=2).eval() 195 out_optimized = torch.compile( 196 mod_eager, 197 backend=functools.partial(compile_fx, inner_compile=my_inner_compile), 198 ) 199 200 inps = [4, 3, 4] 201 if module[0] == nn.Conv2d: 202 inps.append(inps[-1]) 203 if module[0] == nn.Conv3d: 204 inps.append(inps[-1]) 205 inps.append(inps[-1]) 206 207 inp = torch.rand(inps).to(self.device) 208 out_eager = mod_eager(inp) 209 out_optimized = out_optimized(inp) 210 self.assertEqual(out_optimized, out_eager, atol=2e-04, rtol=1e-5) 211 if expect_success: 212 self.assertTrue(n_binary_ops == 0) 213 else: 214 self.assertTrue(n_binary_ops > 1) 215 216 conv_bias = [True, False] 217 modules = [ 218 (nn.Conv1d, nn.BatchNorm1d), 219 (nn.Conv2d, nn.BatchNorm2d), 220 (nn.Conv3d, nn.BatchNorm3d), 221 ] 222 for use_bias, module in itertools.product(conv_bias, modules): 223 test_conv_fusion( 224 use_bias, 225 module, 226 expect_success=True, 227 ) 228 229 230if HAS_CPU and not torch.backends.mps.is_available(): 231 232 class FreezingCpuTests(TestCase): 233 common = check_model 234 device = "cpu" 235 autocast = torch.cpu.amp.autocast 236 237 copy_tests(BinaryFoldingTemplate, FreezingCpuTests, "cpu") 238 239if HAS_GPU and not TEST_WITH_ASAN: 240 241 class FreezingGpuTests(TestCase): 242 common = check_model_gpu 243 device = GPU_TYPE 244 autocast = torch.amp.autocast(device_type=GPU_TYPE) 245 246 copy_tests(BinaryFoldingTemplate, FreezingGpuTests, GPU_TYPE) 247 248 249del BinaryFoldingTemplate 250 251if __name__ == "__main__": 252 from torch._inductor.test_case import run_tests 253 254 if HAS_CPU or HAS_GPU: 255 run_tests(needs="filelock") 256