# Owner(s): ["module: inductor"]

import sys
import unittest

import torch
import torch._inductor
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    TestCase,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.triton_utils import requires_cuda


aten = torch.ops.aten

try:
    try:
        from .test_torchinductor import check_model, check_model_cuda
    except ImportError:
        from test_torchinductor import check_model, check_model_cuda
except (unittest.SkipTest, ImportError) as e:
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise


@instantiate_parametrized_tests
class ComboKernelTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()
        # Enable horizontal fusion into combo kernels; skip benchmarking them.
        torch._inductor.config.combo_kernels = True
        torch._inductor.config.benchmark_combo_kernel = False

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    @requires_cuda
    def test_activation_functions(self):
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]

        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    def test_reduce_functions(self):
        def test_reduce(a, b, c, d):
            a1 = torch.sum(a, dim=0)
            b1 = torch.max(b, dim=0)
            c1 = torch.min(c, dim=0)
            d1 = torch.nn.functional.tanh(d)

            return a1, b1, c1, d1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_reduce(*inps)
        out_compiled = torch.compile(test_reduce)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(torch._inductor.metrics.generated_kernel_count <= 2)

    @requires_cuda
    def test_mutated_args(self):
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()

            return a, b, c, d

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    def test_reduce_split(self):
        def fn(a, b):
            a1 = torch.linalg.vector_norm(a)
            b1 = torch.sum(b, dim=0)
            return a1, b1

        inps = [
            torch.rand(2048, 512, device="cuda"),
            torch.rand(20, 20, device="cuda"),
        ]
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)

        self.assertEqual(out_eager, out_compiled)

    @requires_cuda
    def test_2d_blocking_partitioning(self):
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda"),
                torch.rand(40, 30, device="cuda"),
                torch.rand(36, 40, device="cuda"),
                torch.rand(30, 20, device="cuda"),
                torch.rand(30, 40, device="cuda").t(),
                torch.rand(40, 36, device="cuda").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)


@instantiate_parametrized_tests
class ComboKernelBenchmarkTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()
        # Enable combo kernels and benchmark them against the unfused kernels.
        torch._inductor.config.combo_kernels = True
        torch._inductor.config.benchmark_combo_kernel = True

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    @requires_cuda
    def test_activation_benchmark(self):
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]

        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

    @requires_cuda
    def test_reduce_benchmark(self):
        def test_reduce(a, b, c, d):
            a1 = torch.sum(a, dim=0)
            b1 = torch.max(b, dim=0)
            c1 = torch.min(c, dim=0)
            d1 = torch.nn.functional.tanh(d)

            return a1, b1, c1, d1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_reduce(*inps)
        out_compiled = torch.compile(test_reduce)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10)

    @requires_cuda
    def test_mutated_benchmark(self):
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()

            return a, b, c, d

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(torch._inductor.metrics.generated_kernel_count in [6, 9])

    @requires_cuda
    def test_round_robin_dispatch(self):
        # combo kernel dispatch strategy: round robin
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()

            return a, b, c, d

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 5, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(5, 18, device="cuda"),
        ]

        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6)

    @requires_cuda
    def test_2d_blocking_benchmark(self):
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda"),
                torch.rand(40, 30, device="cuda"),
                torch.rand(36, 40, device="cuda"),
                torch.rand(30, 20, device="cuda"),
                torch.rand(30, 40, device="cuda").t(),
                torch.rand(40, 36, device="cuda").t(),
            ),
        )

        self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)

    @requires_cuda
    def test_persistent_reduction_no_x_dim(self):
        def fn(x, y):
            return x.sum(1), y.sum(1)

        inps = (
            torch.rand(16, 256, device="cuda"),
            torch.rand(32, 256, device="cuda"),
        )
        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)


@instantiate_parametrized_tests
class ComboKernelDynamicShapesTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()
        torch._inductor.config.combo_kernels = True
        torch._inductor.config.benchmark_combo_kernel = True
        # Compile with dynamic shapes by default.
        torch._dynamo.config.automatic_dynamic_shapes = False
        torch._dynamo.config.assume_static_by_default = False

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    @requires_cuda
    def test_dynamic_shapes_activations(self):
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]

        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

    @requires_cuda
    def test_dynamic_shapes_2d_blocking(self):
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda"),
                torch.rand(40, 30, device="cuda"),
                torch.rand(36, 40, device="cuda"),
                torch.rand(30, 20, device="cuda"),
                torch.rand(30, 40, device="cuda").t(),
                torch.rand(40, 36, device="cuda").t(),
            ),
        )

        self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)

    @requires_cuda
    def test_dynamic_shapes_reduce(self):
        def test_reduce(a, b, c, d):
            a1 = torch.sum(a, dim=0)
            b1 = torch.max(b, dim=0)
            c1 = torch.min(c, dim=0)
            d1 = torch.nn.functional.tanh(d)

            return a1, b1, c1, d1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_reduce(*inps)
        out_compiled = torch.compile(test_reduce)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10)

    @requires_cuda
    def test_dynamic_shapes_mutated(self):
        # combo kernel dispatch strategy: round robin
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()

            return a, b, c, d

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 5, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(5, 18, device="cuda"),
        ]

        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6)

    @requires_cuda
    @torch._inductor.config.patch("combo_kernels_autotune", 0)
    def test_dynamic_shapes_activations_no_autotune(self):
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]

        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", True)
    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_dynamic_shapes_persistent_reduction_no_x_dim(self):
        def fn(x, y):
            return x.sum(1), y.sum(1)

        inps = (
            torch.rand(16, 256, device="cuda"),
            torch.rand(32, 256, device="cuda"),
        )
        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)

    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", True)
    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_dynamic_shapes_2d_blocking_round_robin(self):
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        inps = (
            torch.rand(20, 30, device="cuda"),
            torch.rand(30, 30, device="cuda"),
            torch.rand(40, 32, device="cuda"),
            torch.rand(30, 20, device="cuda").t(),
            torch.rand(30, 30, device="cuda").t(),
            torch.rand(32, 40, device="cuda").t(),
        )

        out_eager = fn(*inps)
        compiled = torch.compile(fn)
        out_compiled = compiled(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6)
        torch._inductor.metrics.reset()

        inps = (
            torch.rand(24, 30, device="cuda"),
            torch.rand(32, 30, device="cuda"),
            torch.rand(48, 32, device="cuda"),
            torch.rand(30, 24, device="cuda").t(),
            torch.rand(30, 32, device="cuda").t(),
            torch.rand(32, 48, device="cuda").t(),
        )
        out_compiled = compiled(*inps)
        out_eager = fn(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6)

    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", True)
    @torch._dynamo.config.patch("assume_static_by_default", True)
    @torch._inductor.config.patch("triton.autotune_at_compile_time", True)
    def test_dynamic_shapes_persistent_reduction_mixed_x_dim_cuda(self):
        def fn(x, y, z):
            return x.sum(1), y.mean(1), z.max(1)

        inps = (
            torch.rand(16, 128, device="cuda"),
            torch.rand(32, 128, device="cuda"),
            torch.rand(32, 256, device="cuda"),
        )
        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[2], 0, min=1, max=256)
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)

        self.assertEqual(out_eager, out_compiled)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")