# Owner(s): ["module: inductor"]
import functools
from unittest.mock import patch

import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
import torch.nn.functional as F
from torch._dynamo.testing import expectedFailureDynamicWrapper
from torch._dynamo.utils import counters
from torch._inductor.autotune_process import TritonBenchmarkRequest
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import is_big_gpu
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA


aten = torch.ops.aten


def patches(fn):
    """Decorator that applies the config patches these tests rely on:
    verbose dynamo logging, max_autotune with epilogue fusion, numeric
    verification of every autotuned choice, a bypassed autotune cache,
    and TF32 disabled so the correctness checks stay tight."""

    # Bypass the algorithm-selection cache so every test re-runs benchmarking.
    def skip_cache(self, choices, name, key, benchmark):
        if benchmark is None:
            return {}
        return benchmark(choices)

    for patcher in [
        dynamo_config.patch(verbose=True),
        inductor_config.patch(debug=True, max_autotune=True, epilogue_fusion=True),
        patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
        patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
        torch.backends.cudnn.flags(allow_tf32=False),
    ]:
        fn = patcher(fn)

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        counters.clear()
        torch.manual_seed(12345)
        assert (
            not torch.backends.cuda.matmul.allow_tf32
        ), "correctness testing is allergic to tf32"
        return fn(*args, **kwargs)

    return wrapped


class TestSelectAlgorithm(TestCase):
    def setUp(self):
        super().setUp()
        if not is_big_gpu(0):
            return self.skipTest("Need a big GPU to run max_autotune=True")

    @patches
    def test_linear_relu_cuda(self):
        @torch.compile
        def foo(input, weight, bias):
            return F.relu(F.linear(input, weight, bias))

        foo(
            torch.randn(64, 32, device="cuda"),
            torch.randn(16, 32, device="cuda"),
            torch.randn(1, 16, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        # It would be nice to assert this got fused into a single kernel, but that
        # only happens if we select a triton template (and not aten).

    @patches
    def test_addmm_cuda(self):
        @torch.compile
        def foo(input, weight, bias):
            return torch.addmm(bias, input, weight)

        inps = (
            torch.randn(20, 33, device="cuda"),
            torch.randn(33, 16, device="cuda"),
            torch.randn(20, 16, device="cuda"),
        )

        foo(*inps)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2))
    @patches
    def test_addmm_fp16(self):
        @torch.compile
        def foo(input, weight, bias):
            return torch.addmm(bias, input, weight)

        inps = (
            torch.randn(2, 320, device="cuda", dtype=torch.half),
            torch.randn(320, 320, device="cuda", dtype=torch.half).t(),
            torch.empty(320, device="cuda", dtype=torch.half),
        )

        foo(*inps)
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device="cuda"),
            torch.randn(32, 8, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    # FIXME: Investigate why _int_mm_out_cuda is not compiled on ROCm
    @skipIfRocm
    @patches
    def test__int_mm(self):
        @torch.compile
        def foo(a, b):
            return torch._int_mm(a, b)

        foo(
            torch.randint(-10, 10, (64, 32), device="cuda", dtype=torch.int8),
            torch.randint(-10, 10, (32, 64), device="cuda", dtype=torch.int8),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_skip(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device="cuda", dtype=torch.float64),
            torch.randn(32, 8, device="cuda", dtype=torch.float64),
        )
        # float64 not supported by tl.dot()
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)

    @patches
    def test_bmm(self):
        @torch.compile
        def foo(a, b):
            return torch.bmm(a, b)

        foo(
            torch.randn(2, 8, 32, device="cuda"),
            torch.randn(2, 32, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_not_even_k(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(11, 22, device="cuda"),
            torch.randn(22, 33, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_baddbmm(self):
        @torch.compile
        def foo(a, b, c):
            return torch.baddbmm(c, a, b)

        foo(
            torch.randn(2, 8, 32, device="cuda"),
            torch.randn(2, 32, 8, device="cuda"),
            torch.randn(2, 1, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_plus_mm(self):
        @torch.compile
        def foo(a, b, c, d):
            return (a @ b) + (c @ d)

        foo(
            torch.randn(32, 32, device="cuda"),
            torch.randn(32, 32, device="cuda"),
            torch.randn(32, 32, device="cuda"),
            torch.randn(32, 32, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_plus_mm2_cuda(self):
        @torch.compile
        def foo(a, b, c, d):
            return (a @ b) + (c @ d)

        foo(
            torch.randn(512, 512, device="cuda"),
            torch.randn(512, 512, device="cuda"),
            torch.randn(512, 512, device="cuda"),
            torch.randn(512, 512, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @expectedFailureDynamicWrapper
    @patches
    def test_mm_plus_mm3_cuda(self):
        @torch.compile
        def foo(a, b, c, d):
            return (a @ b) + (c @ d)

        foo(
            torch.randn(512, 32, device="cuda"),
            torch.randn(32, 8, device="cuda"),
            torch.randn(512, 32, device="cuda"),
            torch.randn(32, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_dup_args(self):
        @torch.compile
        def foo(a):
            return torch.mm(a, a)

        foo(torch.randn(32, 32, device="cuda"))
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_dup_args_view(self):
        @torch.compile
        def foo(a):
            q = a[:32, :]
            k = a[32:, :]
            return torch.mm(q, k.transpose(0, 1))

        foo(torch.randn(64, 64, device="cuda"))
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @expectedFailureDynamicWrapper
    @patches
    def test_convolution1(self):
        @torch.compile
        def foo(x, w, b):
            return aten.convolution(
                x + 1,
                w,
                b,
                stride=(2, 3),
                padding=(4, 5),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
            )

        foo(
            torch.randn(2, 33, 34, 41, device="cuda"),
            torch.randn(34, 33, 3, 3, device="cuda"),
            torch.randn(34, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @skipIfRocm
    @patches
    def test_mm_dropout(self):
        @torch.compile
        def fn(x1, x2, seed):
            mm_4 = torch.ops.aten.mm.default(x2, x1)
            rnd = torch.ops.prims.inductor_random.default(mm_4.shape, seed, "rand")
            return mm_4 * rnd

        # sizes picked so triton autotuning wins
        fn(
            torch.randn(512, 1024, dtype=torch.float16, device="cuda"),
            torch.randn(384, 512, dtype=torch.float16, device="cuda"),
            torch.tensor(12345, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @skipIfRocm
    @patches
    @torch._inductor.config.patch(conv_1x1_as_mm=False)
    def test_convolution2(self):
        @torch.compile
        def foo(x, w, b):
            return aten.convolution(
                x,
                w,
                b,
                stride=(1, 1),
                padding=(0, 0),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
            )

        foo(
            torch.randn(1, 33, 16, 16, device="cuda"),
            torch.randn(34, 33, 1, 1, device="cuda"),
            torch.randn(34, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch._inductor.config.patch(conv_1x1_as_mm=True)
    def test_convolution_as_mm(self):
        @torch.compile
        def foo(x, w, b):
            return aten.convolution(
                x + 1,
                w,
                b,
                stride=(1, 1),
                padding=(0, 0),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
            )

        foo(
            torch.randn(2, 33, 16, 16, device="cuda"),
            torch.randn(34, 33, 1, 1, device="cuda"),
            torch.randn(34, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    def test_TritonTemplateCaller_str(self):
        """
        Make sure str(TritonTemplateCaller) does not raise exceptions.
        """
        module_path = "abc.py"
        bmreq = TritonBenchmarkRequest(
            module_path=module_path,
            module_cache_key=None,
            kernel_name=None,
            grid=None,
            extra_args=None,
            num_stages=None,
            num_warps=None,
            input_tensor_meta=None,
            output_tensor_meta=None,
        )
        caller = select_algorithm.TritonTemplateCaller(
            None, None, None, None, "extra", bmreq
        )
        caller_str = str(caller)
        self.assertEqual(caller_str, f"TritonTemplateCaller({module_path}, extra)")


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA and is_big_gpu(0):
        run_tests()