# Owner(s): ["NNC"]

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import unittest
import itertools

from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo

from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions

LLVM_ENABLED = torch._C._llvm_enabled()

class BaseTestClass(JitTestCase):
    def setUp(self):
        super().setUp()
        self.tensorexpr_options = TensorExprTestOptions()
        self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        self.dtypes = [torch.float32, torch.bfloat16] if LLVM_ENABLED else [torch.float32]

    def tearDown(self):
        self.tensorexpr_options.restore()
        super().tearDown()

    def assertLastGraphAllFused(self):
        self.assertAllFused(torch.jit.last_executed_optimized_graph())


def warmup_and_run_forward(f, *args):
    # Run enough iterations for the profiling executor to collect profiling
    # information and trigger fusion, then return the last result.
    for _ in range(torch._C._jit_get_num_profiled_runs() + 1):
        results = f(*args)
    return results


@skipIfTorchDynamo()
class TestTensorExprFuser(BaseTestClass):
    def test_easy(self):
        def easy(x, y):
            aaa = torch.add(x, y)
            return aaa

        traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024)))

        a = torch.rand(1024)
        b = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())

    def test_three_arg(self):
        def easy(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.add(aaa, z)
            return bbb

        traced = torch.jit.trace(
            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
        )

        a = torch.rand(1024)
        b = torch.rand(1024)
        c = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = a.numpy() + b.numpy() + c.numpy()
        np.testing.assert_allclose(npr, x.numpy())

    def test_four_arg(self):
        def run_addcmul(x, y, z, w):
            c = torch.addcmul(torch.add(x, y), z, w)
            return c

        for dev in self.devices:
            rand_a = torch.rand(1024, dtype=torch.float, device=dev)
            rand_b = torch.rand(1024, dtype=torch.float, device=dev)
            rand_c = torch.rand(1024, dtype=torch.float, device=dev)
            rand_d = torch.rand(1024, dtype=torch.float, device=dev)

            traced = torch.jit.trace(
                run_addcmul,
                (
                    torch.zeros(1024, dtype=torch.float, device=dev),
                    torch.zeros(1024, dtype=torch.float, device=dev),
                    torch.zeros(1024, dtype=torch.float, device=dev),
                    torch.zeros(1024, dtype=torch.float, device=dev),
                ),
            )

            x = warmup_and_run_forward(traced, rand_a, rand_b, rand_c, rand_d)
            self.assertLastGraphAllFused()
            y = run_addcmul(rand_a, rand_b, rand_c, rand_d)
            np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6)

    def test_three_arg2(self):
        for device in self.devices:
            def test(x, y, z):
                aaa = torch.add(x, y)
                bbb = torch.add(aaa, z)
                return bbb

            M = 32
            N = 32
            traced = torch.jit.trace(
                test,
                (
                    torch.rand(M, N, device=device),
                    torch.rand(M, N, device=device),
                    torch.rand(M, N, device=device),
                ),
            )

            a = torch.rand(M, N, device=device)
            b = torch.rand(M, N, device=device)
            c = torch.rand(M, N, device=device)
            x = traced(a, b, c)
            x = warmup_and_run_forward(traced, a, b, c)
            self.assertLastGraphAllFused()
            npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
            np.testing.assert_allclose(npr, x.cpu().numpy())

    def test_broadcast3(self):
        for device in self.devices:
            def test_body(M, N, L, K):
                def test(x, y, z):
                    v1 = torch.add(x, y)
                    v2 = torch.add(v1, z)
                    return v2

                a_shape = [M, N]
                b_shape = [L, M, 1]
                c_shape = [K, L, 1, 1]
                traced = torch.jit.trace(
                    test,
                    (
                        torch.rand(*a_shape, device=device),
                        torch.rand(*b_shape, device=device),
                        torch.rand(*c_shape, device=device),
                    ),
                )

                a = torch.rand(*a_shape, device=device)
                b = torch.rand(*b_shape, device=device)
                c = torch.rand(*c_shape, device=device)
                x = warmup_and_run_forward(traced, a, b, c)
                self.assertLastGraphAllFused()
                npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
                np.testing.assert_allclose(npr, x.cpu().numpy())

            test_configs = [[5, 2, 7, 3], [8, 8, 8, 8]]
            for test_config in test_configs:
                test_body(*test_config)

    def test_all_combos(self):
        def easy(x, y, z):
            a = torch.add(x, y)
            b = torch.add(a, z)
            c = torch.add(x, b)
            d = torch.add(c, a)
            return d

        def np_easy(x, y, z):
            a = x + y
            b = a + z
            c = x + b
            d = c + a
            return d

        traced = torch.jit.trace(
            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
        )

        a = torch.rand(1024)
        b = torch.rand(1024)
        c = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(npr, x.numpy())

    def test_rank_two(self):
        def easy(x, y, z):
            a = torch.add(x, y)
            b = torch.add(a, z)
            c = torch.add(x, b)
            d = torch.add(c, a)
            return d

        def np_easy(x, y, z):
            a = x + y
            b = a + z
            c = x + b
            d = c + a
            return d

        shape = 32, 32
        traced = torch.jit.trace(
            easy, (torch.rand(shape), torch.rand(shape), torch.rand(shape))
        )

        a = torch.rand(shape)
        b = torch.rand(shape)
        c = torch.rand(shape)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(npr, x.numpy())

    def test_broadcast(self):
        def easy(x, y, z):
            a = torch.add(x, y)
            b = torch.add(a, z)
            return b

        def np_easy(x, y, z):
            a = x + y
            b = a + z
            return b

        N = 32
        traced = torch.jit.trace(easy, (torch.rand(N, N), torch.rand(N), torch.rand(N, N)))

        a = torch.rand(N, N)
        b = torch.rand(N)
        c = torch.rand(N, N)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(npr, x.numpy())

    def test_broadcast_2(self):
        zero = torch.tensor([0.0], dtype=torch.float)

        def foo(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.add(zero, aaa)
            return torch.add(bbb, z)

        def foo_np(x, y, z):
            a = x + y
            b = zero.numpy() + a
            return b + z

        x = torch.rand(3, 4)
        y = torch.ones(3, 1)
        z = torch.rand(4)
        traced = torch.jit.trace(foo, (x, y, z))

        r = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()

        rnp = foo_np(x.numpy(), y.numpy(), z.numpy())
        np.testing.assert_allclose(r, rnp)

    def test_broadcast_big2(self):
        zero = torch.tensor([0.0], dtype=torch.float)

        def foo(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.add(zero, aaa)
            return torch.add(bbb, z)

        def foo_np(x, y, z):
            a = x + y
            b = zero.numpy() + a
            return b + z

        x = torch.rand(32, 1024)
        y = torch.ones(32, 1)
        z = torch.rand(1024)
        traced = torch.jit.trace(foo, (x, y, z))

        r = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()
        rnp = foo_np(x.numpy(), y.numpy(), z.numpy())
        np.testing.assert_allclose(r, rnp)

    def test_alpha(self):
        def alpha(x):
            aaa = torch.add(x, x, alpha=2.0)
            return aaa

        traced = torch.jit.trace(alpha, (torch.tensor([1.0])))

        a = torch.tensor([1.0])
        x = traced(a)
        np.testing.assert_allclose(a.numpy() + 2.0 * a.numpy(), x.numpy())

    @suppress_warnings
    def test_constant(self):
        def constant(x):
            bbb = torch.tensor([1.0])
            aaa = torch.add(x, bbb)
            return aaa

        traced = torch.jit.trace(constant, (torch.tensor([1.0])))

        a = torch.tensor([1.0])
        x = warmup_and_run_forward(traced, a)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + 1.0, x.numpy())

    def test_add_sub(self):
        def easy(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.sub(aaa, z)
            return bbb

        traced = torch.jit.trace(
            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
        )

        a = torch.rand(1024)
        b = torch.rand(1024)
        c = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy() - c.numpy(), x.numpy())

    def test_promotion(self):
        def easy(x, y):
            aaa = torch.add(x, y)
            return aaa

        traced = torch.jit.trace(
            easy,
            (torch.zeros(1024, dtype=torch.int32), torch.rand(1024, dtype=torch.float32)),
        )

        a = torch.zeros(1024, dtype=torch.int32)
        b = torch.rand(1024, dtype=torch.float32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())

    def test_double(self):
        TENSOR_LEN = 8

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.rand(TENSOR_LEN, dtype=torch.float64), torch.full((TENSOR_LEN,), 0.5, dtype=torch.float64)),
        )

        a = torch.rand(TENSOR_LEN, dtype=torch.double)
        b = torch.full((TENSOR_LEN,), 0.5, dtype=torch.double)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())

    def test_short(self):
        TENSOR_LEN = 8

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16),
             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)),
        )

        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)
        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())

    def test_char(self):
        TENSOR_LEN = 8

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8),
             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)),
        )

        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())

    def test_int64_promotion(self):
        TENSOR_LEN = 8

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8),
             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64)),
        )

        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())

    def test_eq(self):
        def easy(x, y):
            c = torch.eq(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        a = torch.zeros(1024, dtype=torch.int32)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())

    def test_ne(self):
        def easy(x, y):
            c = torch.ne(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        a = torch.zeros(1024, dtype=torch.int32)
        b = torch.ones(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())

    def test_ge(self):
        def easy(x, y):
            c = torch.ge(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        aa = np.empty([1024], dtype=np.int32)
        aa.fill(5)
        a = torch.from_numpy(aa)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())

    def test_gt(self):
        def easy(x, y):
            c = torch.gt(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        a = torch.ones(1024, dtype=torch.int32)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())

    def test_le(self):
        def easy(x, y):
            c = torch.le(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        aa = np.empty([1024], dtype=np.int32)
        aa.fill(5)
        a = torch.from_numpy(aa)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.zeros(1024), x.numpy())

    def test_lt(self):
        def easy(x, y):
            c = torch.lt(x, y)
            return c

        for dev in self.devices:
            traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev)))
            a = torch.ones(1024, dtype=torch.int32, device=dev)
            b = torch.zeros(1024, dtype=torch.int32, device=dev)
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
            np.testing.assert_allclose(np.zeros(1024), x.cpu().numpy())

    @suppress_warnings
    def test_min_max(self):
        def test(x, y):
            return torch.max(torch.min(x, y), torch.tensor([4.0]))

        traced = torch.jit.trace(test, (torch.zeros(1024), torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        b = 8.0 * torch.rand(1024)
        np.testing.assert_allclose(
            warmup_and_run_forward(traced, a, b), np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0])
        )
        self.assertLastGraphAllFused()

    def test_min_max_reduction(self):
        def test(x):
            return torch.min(x) + torch.max(x)

        traced = torch.jit.trace(test, (torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy()))
        self.assertLastGraphAllFused()

    def test_min_max_reduction2(self):
        def test(x):
            return x.min() + x.max()

        traced = torch.jit.trace(test, (torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy()))
        self.assertLastGraphAllFused()

    def test_min_max_reduction_dim1(self):
        def test(x):
            return torch.min(x, 1)[0] + torch.max(x, 1)[0]

        traced = torch.jit.trace(test, (torch.zeros(16, 16)))
        a = 8.0 * torch.rand(16, 16)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(
            a.numpy(), axis=1) + np.amax(a.numpy(), axis=1))
        self.assertLastGraphAllFused()

    def test_min_max_reduction_dim1_2(self):
        def test(x):
            return torch.min(x * x, 1)

        traced = torch.jit.trace(test, (torch.zeros(16, 16)))
        a = 8.0 * torch.rand(16, 16)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a)[0], np.amin((a * a).numpy(), axis=1))
        self.assertLastGraphAllFused()

    def test_clamp(self):
        def test(x):
            return torch.clamp(x + 3.0, 0.0, 6.0)

        for dev in self.devices:
            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
            a = 20.0 * torch.rand(1024, device=dev) - 10.0
            an = a.cpu().numpy()
            np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip(an + 3.0, 0.0, 6.0))
            self.assertLastGraphAllFused()

    def test_relu(self):
        def test(x):
            return torch.clamp(F.relu(x), 0, 0.5)

        for dev in self.devices:
            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
            a = 20.0 * torch.rand(1024, device=dev) - 10.0
            an = a.cpu().numpy()
            np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5))
            self.assertLastGraphAllFused()

    def test_reps(self):
        def easy(x, y):
            c = torch.add(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024)))

        for _ in range(32):
            a = torch.ones(1024)
            b = torch.zeros(1024)
            x = warmup_and_run_forward(traced, a, b)
            np.testing.assert_allclose(np.ones(1024), x.numpy())

    def test_add_const_rhs(self):
        def test(x):
            return x + 3.0

        traced = torch.jit.trace(test, torch.rand(4))
        x = torch.rand(4)
        y = warmup_and_run_forward(traced, x)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(x.numpy() + 3.0, y.numpy())

    def test_int_output(self):
        def test(x, y, z):
            return x * y * z

        xs = [(torch.rand(4) * 3 + 1).to(torch.int32) for i in range(3)]
        x, y, z = xs
        xn, yn, zn = (t.numpy() for t in xs)
        traced = torch.jit.trace(test, (x, y, z))
        res = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(xn * yn * zn, res.numpy())

    def test_binary_ops(self):
        def test_atan2(x, y):
            c = torch.atan2(torch.add(x, y), y)
            return c

        def test_gt(x, y):
            c = torch.gt(torch.add(x, y), y)
            return c

        def test_ge(x, y):
            c = torch.ge(torch.add(x, y), y)
            return c

        def test_lt(x, y):
            c = torch.lt(torch.add(x, y), y)
            return c

        def test_le(x, y):
            c = torch.le(torch.add(x, y), y)
            return c

        def test_lerp(x, y):
            c = torch.lerp(torch.add(x, 1), x, 2.0)
            return c

        def test_mul(x, y):
            c = torch.mul(torch.add(x, y), y)
            return c

        def test_ne(x, y):
            c = torch.ne(torch.add(x, y), y)
            return c

        def test_div(x, y):
            c = torch.div(torch.add(x, y), 2)
            return c

        def test_eq(x, y):
            c = torch.eq(torch.add(x, y), y)
            return c

        def test_fmod(x, y):
            c = torch.fmod(torch.add(x, y), 2)
            return c

        def test_sub(x, y):
            c = torch.sub(torch.add(x, y), x)
            return c

        def test_remainder(x, y):
            c = torch.remainder(torch.add(x, y), 3.0)
            return c

        def test_pow(x, y):
            c = torch.pow(torch.add(x, y), 2.0)
            return c

        def test_type_as(x, y):
            return x.type_as(torch.add(x, y))

        cmp_fns = {
            test_gt,
            test_ge,
            test_lt,
            test_le,
            test_ne,
            test_eq
        }

        non_cmp_fns = {
            test_atan2,
            test_lerp,
            test_mul,
            test_div,
            test_fmod,
            test_sub,
            test_remainder,
            test_pow,
            test_type_as,
        }

        all_test_fns = cmp_fns.union(non_cmp_fns)
        fn_dev_dtype = itertools.product(all_test_fns, self.devices, self.dtypes)
        for torch_fn, dev, data_type in fn_dev_dtype:
            if torch_fn is test_lerp and data_type is torch.bfloat16:
                continue
            rand_a = torch.rand(1024, dtype=data_type, device=dev)
            rand_b = torch.rand(1024, dtype=data_type, device=dev)
            in1 = 20 * torch.rand(1024, dtype=data_type, device=dev)
            in2 = 20 * torch.rand(1024, dtype=data_type, device=dev)
            traced = torch.jit.trace(torch_fn, (in1, in2))
            x = warmup_and_run_forward(traced, rand_a, rand_b)
            self.assertLastGraphAllFused()

            _atol = 2e-3
            _rtol = 1e-5
            if data_type is torch.bfloat16:
                # Compared to the aten logic, NNC can save additional BF16/FP32 conversions.
                # Take d = a + b + c as an example; at the operator level, aten computes:
                #   tmp = to_bf16(to_fp32(a) + to_fp32(b))
                #   d = to_bf16(to_fp32(tmp) + to_fp32(c))
                # But NNC can fuse the computation and remove the redundant conversions,
                # so the final statement is:
                #   d = to_bf16(to_fp32(a) + to_fp32(b) + to_fp32(c))
                # Hence, we simulate the NNC computation by feeding fp32 tensors and converting
                # the result tensor back to bf16. The simulation avoids the numeric
                # deviation and simplifies the result comparison.
                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
                if torch_fn not in cmp_fns:
                    y = y.bfloat16()
                _atol = 2e-2
            else:
                y = torch_fn(rand_a, rand_b)
            self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)

    def test_unary_ops(self):
        def test_cast_float(x, y):
            c = torch.ops.aten._cast_Float(torch.add(x, y))
            return c

        def test_round(x, y):
            c = torch.round(torch.add(x, y))
            return c

        def test_sin(x, y):
            c = torch.sin(torch.add(x, y))
            return c

        def test_asin(x, y):
            c = torch.asin(torch.add(x, y))
            return c

        def test_sinh(x, y):
            c = torch.sinh(torch.add(x, y))
            return c

        def test_cos(x, y):
            c = torch.cos(torch.add(x, y))
            return c

        def test_acos(x, y):
            c = torch.acos(torch.add(x, y))
            return c

        def test_cosh(x, y):
            c = torch.cosh(torch.add(x, y))
            return c

        def test_tan(x, y):
            c = torch.tan(torch.add(x, y))
            return c

        def test_atan(x, y):
            c = torch.atan(torch.add(x, y))
            return c

        def test_tanh(x, y):
            c = torch.tanh(torch.add(x, y))
            return c

        def test_sqrt(x, y):
            c = torch.sqrt(torch.add(x, y))
            return c

        def test_rsqrt(x, y):
            c = torch.rsqrt(torch.add(x, y))
            return c

        def test_floor(x, y):
            c = torch.floor(torch.add(x, y))
            return c

        def test_ceil(x, y):
            c = torch.ceil(torch.add(x, y))
            return c

        def test_trunc(x, y):
            c = torch.trunc(torch.add(x, y))
            return c

        def test_abs(x, y):
            c = torch.abs(torch.add(x, y))
            return c

        def test_log(x, y):
            c = torch.log(torch.add(x, y))
            return c

        def test_log2(x, y):
            c = torch.log2(torch.add(x, y))
            return c

        def test_log10(x, y):
            c = torch.log10(torch.add(x, y))
            return c

        def test_log1p(x, y):
            c = torch.log1p(torch.add(x, y))
            return c

        def test_rqrt(x, y):
            c = torch.rsqrt(torch.add(x, y))
            return c

        def test_erf(x, y):
            c = torch.erf(torch.add(x, y))
            return c

        def test_exp(x, y):
            c = torch.exp(torch.add(x, y))
            return c

        def test_expm1(x, y):
            c = torch.expm1(torch.add(x, y))
            return c

        def test_erfc(x, y):
            c = torch.erfc(torch.add(x, y))
            return c

        def test_frac(x, y):
            c = torch.frac(torch.add(x, y))
            return c

        def test_lgamma(x, y):
            c = torch.lgamma(torch.add(x, y))
            return c

        def test_sigmoid(x, y):
            c = torch.sigmoid(torch.add(x, y))
            return c

        def test_reciprocal(x, y):
            c = torch.reciprocal(torch.add(x, y))
            return c

        def test_neg(x, y):
            c = torch.neg(torch.add(x, y))
            return c

        def test_relu(x, y):
            c = torch.relu(torch.add(x, y))
            return c

        def test_hardtanh(x, y):
            c = F.hardtanh(torch.add(x, y), -1.0, 1.0)
            return c

        def test_threshold(x, y):
            c = F.threshold(torch.add(x, y), 0.5, 10)
            return c

        gpu_only_fns = {
            test_erf,
            test_erfc
        }
        fns = {
            test_round,
            test_sin,
            test_asin,
            test_sinh,
            test_cos,
            test_acos,
            test_cosh,
            test_tan,
            test_atan,
            test_sqrt,
            test_floor,
            test_ceil,
            test_trunc,
            test_abs,
            test_log,
            test_log2,
            test_log10,
            test_log1p,
            test_rsqrt,
            test_exp,
            test_expm1,
            test_frac,
            test_lgamma,
            test_reciprocal,
            test_neg,
            test_threshold,
            test_relu,
            test_tanh,
            test_hardtanh,
            test_sigmoid,
        }
        fn_dev_dtype = itertools.product(gpu_only_fns.union(fns), self.devices, self.dtypes)

        torch.manual_seed(0)
        for torch_fn, dev, data_type in fn_dev_dtype:
            if torch_fn == test_lgamma and dev == "cuda":
                # lgamma_cuda does not support BF16
                continue
            rand_a = torch.rand(1024, dtype=data_type, device=dev)
            rand_b = torch.rand(1024, dtype=data_type, device=dev)

            ins = 20 * torch.rand(1024, dtype=data_type, device=dev)
            cc = np.empty([1024], dtype=np.float32)
            cc.fill(np.nan)
            nans = torch.from_numpy(cc).to(dev)
            traced = torch.jit.trace(torch_fn, (ins, ins))
            x = warmup_and_run_forward(traced, rand_a, rand_b)
            self.assertLastGraphAllFused()

            _atol = 5e-3 if data_type is torch.bfloat16 else 2e-3
            _rtol = 1e-5
            if data_type is torch.bfloat16 and torch_fn not in gpu_only_fns:
                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
                y = y.bfloat16()
            else:
                y = torch_fn(rand_a, rand_b)

            self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)
            # nans
            # TODO: re-enable. Currently all of these checks fail.
            # traced = torch.jit.trace(torch_fn, (ins, ins))
            # x = warmup_and_run_forward(traced, rand_a, rand_b)
            # y = torch_fn(nans, rand_b)
            # try:
            #     np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
            #     print("Succeeded on dev=", dev, "function=", torch_fn)
            # except AssertionError:
            #     # Print extra info before exiting:
            #     print("Failed on dev=", dev, "function=", torch_fn)
            #     # np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())

    def test_round_2(self):
        def round(x):
            return torch.round(x)

        for data_type in [torch.float32, torch.double]:
            a = torch.tensor([0.2, 1.6, 2.5, 3.5]).to(data_type)
            traced = torch.jit.trace(round, (a))
            x = warmup_and_run_forward(traced, a)
            self.assertLastGraphAllFused()
            y = round(a)
            self.assertEqual(x, y)

    def test_rand_like(self):
        N = 1 << 16

        def run_rand_like(x, y):
            return torch.rand_like(torch.add(x, y))

        for device in self.devices:
            x = torch.rand(N, device=device)
            traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False)

            for data_type in self.dtypes:
                _x = x.to(dtype=data_type)
                x_v = warmup_and_run_forward(traced, _x, _x)
                self.assertLastGraphAllFused()

            x_np = x.cpu().numpy()
            x1_mean = np.mean(x_np)
            x2_mean = np.mean(x_np ** 2)
            x3_mean = np.mean(x_np ** 3)
            np.testing.assert_allclose(x1_mean, 1. / 2, rtol=2e-2)
            np.testing.assert_allclose(x2_mean, 1. / 3, rtol=2e-2)
            np.testing.assert_allclose(x3_mean, 1. / 4, rtol=2e-2)

    def test_nans(self):
        def test_max(x, y):
            return torch.max(2 * x, 2 * y)

        def test_min(x, y):
            return torch.min(2 * x, 2 * y)

        tmax = torch.jit.trace(test_max, (torch.rand(1), torch.rand(1)))
        tmin = torch.jit.trace(test_min, (torch.rand(1), torch.rand(1)))

        for data_type in self.dtypes:
            x = torch.tensor([np.nan]).to(dtype=data_type)
            y = torch.tensor([1.0]).to(dtype=data_type)

            assert np.isnan(warmup_and_run_forward(tmin, x, y).float().item())
            assert np.isnan(warmup_and_run_forward(tmin, y, x).float().item())
            self.assertLastGraphAllFused()
            assert np.isnan(warmup_and_run_forward(tmax, x, y).float().item())
            assert np.isnan(warmup_and_run_forward(tmax, y, x).float().item())
            self.assertLastGraphAllFused()

    def test_double_intrinsics(self):
        def do_pow(x):
            return torch.pow(x, 7)

        for device in self.devices:
            x = torch.rand(10, dtype=torch.double, device=device)
            traced = torch.jit.trace(do_pow, (x))
            x = warmup_and_run_forward(traced, x)
            self.assertLastGraphAllFused()

    def test_remainder(self):
        def run_remainder(x, y):
            c = torch.remainder(torch.add(x, y), x)
            return c

        for data_type in self.dtypes:
            a = torch.rand(1024, dtype=data_type)
            b = torch.rand(1024, dtype=data_type)
            zeros = torch.zeros(1024, dtype=data_type)
            cc = np.empty([1024], dtype=float)
            cc.fill(np.nan)
            nans = torch.from_numpy(cc).to(dtype=data_type)

            # random floats
            zeros1 = torch.zeros(1024, dtype=data_type)
            zeros2 = torch.zeros(1024, dtype=data_type)

            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
            y = run_remainder(a, b)
            if data_type is torch.bfloat16:
                self.assertEqual(x, y, atol=4e-3, rtol=2e-3)
            else:
                self.assertEqual(x, y)

            # div by 0
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, zeros, a)
            self.assertLastGraphAllFused()
            y = run_remainder(zeros, a)
            self.assertEqual(x, y)

            # numerators and denominators are nan
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, nans, a)
            self.assertLastGraphAllFused()
            y = run_remainder(nans, a)
            self.assertEqual(x, y)

    def test_multioutput(self):
        def easy(x):
            b = x + 1
            c = b + b
            return (b, c)

        traced = torch.jit.trace(easy, (torch.zeros(1024)))

        a = torch.zeros(1024)
        b, c = warmup_and_run_forward(traced, a)
        self.assertLastGraphAllFused()
        bp = a.numpy() + 1
        cp = bp + bp
        np.testing.assert_allclose(b.numpy(), bp)
        np.testing.assert_allclose(c.numpy(), cp)

    def test_chunk(self):
        def easy(x):
            y = x + 1
            aaa, bbb = torch.chunk(y, 2)
            return aaa + bbb

        for data_type in self.dtypes:
            trace_input = torch.zeros(1024, 1024, dtype=data_type)
            traced = torch.jit.trace(easy, (trace_input))

            a = torch.zeros(32, 32, dtype=data_type)
            x = warmup_and_run_forward(traced, a)
            self.assertLastGraphAllFused()
            npr = a.float().numpy()
            npr2 = npr + 1
            npr_a, npr_b = np.array_split(npr2, 2)
            np.testing.assert_allclose(npr_a + npr_b, x.float().numpy())

    def test_cat(self):
        for device in self.devices:
            _dim = 1

            def foo(*args):
                args_2 = [v + i for i, v in enumerate(args)]
                v = torch.cat(args_2, dim=_dim)
                return v * v

            for data_type in self.dtypes:
                M = 16
                Ns = [128, 16, 1]
                values = [torch.zeros(M, N, dtype=data_type, device=device) for N in Ns]
                traced = torch.jit.trace(foo, values)

                x = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                ref = foo(*values)
                np.testing.assert_allclose(ref.cpu().float().numpy(), x.cpu().float().numpy())

            # Test channels-last
            for _cur_dim in range(4):
                _dim = _cur_dim
                values = [torch.randn((2, 3, 4, 5), device=device).to(memory_format=torch.channels_last) for _ in range(10)]
                traced = torch.jit.trace(foo, values)

                x = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                ref = foo(*values)
                self.assertEqual(ref, x)

    # This test checks that we correctly handle a fusion group that contains just aten::cat.
    # Note that the test only makes sense with min_fusion_group=1; otherwise no
    # fusion groups would be formed at all.
    # TODO: Fix and re-enable the test.
    @unittest.skip("cat is broken with fusion group inlining disabled")
    def test_cat_only(self):
        for device in self.devices:
            def foo(*args):
                args_2 = [v + i for i, v in enumerate(args)]
                v = torch.cat(args_2, dim=1)
                return v

            M = 16
            Ns = [128, 16, 1]
            values = [torch.zeros(M, N, device=device) for N in Ns]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

    def test_cat_negative_dim(self):
        for device in self.devices:
            def foo(*args):
                v = torch.cat(args, dim=-1)
                return v * v

            M = 16
            Ns = [128, 16, 1]
            values = [torch.randn(M, N, device=device) for N in Ns]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

    def test_cat_promote_inputs(self):
        for device in self.devices:
            def foo(*args):
                v = torch.cat(args, dim=1)
                return v * v

            M = 16
            Ns = [128, 16, 1]
            dtypes = [torch.half, torch.float32, torch.double]
            values = [torch.randn(M, N, device=device, dtype=dt) for N, dt in zip(Ns, dtypes)]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

    def test_cat_empty_tensors(self):
        for device in self.devices:
            def foo(*args):
                v = torch.cat(args, dim=1)
                return v * v

            M = 16
            Ns = [128, 16, 1]
            empty = torch.tensor([], device=device, dtype=torch.double)
            values = [empty] + [torch.randn(M, N, device=device) for N in Ns]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

            # now test with only empty tensors
            values = [empty for i in range(3)]
            traced = torch.jit.trace(foo, values)
            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

    def test_cat_with_constant_dim(self):
        for device in self.devices:
            def foo(*args):
                v1 = torch.cat(args, dim=1)
                v2 = torch.cat([v1], dim=1)
                return v2 * v2

            empty = torch.tensor([], device=device, dtype=torch.float32)
            inputs = [empty] + [torch.randn(1, 64, device=device), torch.randn(1, 64, device=device)]
            traced = torch.jit.trace(foo, inputs)

            x = warmup_and_run_forward(traced, *inputs)
            self.assertLastGraphAllFused()
            ref = foo(*inputs)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

    def test_scalar(self):
        @torch.jit.script
        def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor:
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

        @torch.jit.script
        def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor:
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

        for test in (test_float, test_int):
            for data_type in self.dtypes:
                x, y, z = (torch.rand(4, dtype=data_type) for i in range(3))
                a, b = 1, 2
                test(x, y, z, a, b)
                r = test(x, y, z, a, b)
                self.assertEqual(r, x + y * a + z * b)

    def test_loop(self):
        @torch.jit.script
        def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor:
            b = y
            for i in range(0, z):
                a = x + y
                b = b + y
            return b

        x, y, z = (torch.zeros(32, 32), torch.ones(32, 32), 4)
        test(x, y, z)
        r = test(x, y, z)

    def test_slice(self):
        def easy(x, y):
            a = x[0:512:2]
            b = y[0:512:2]
            return a + b

        traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024)))

        a = torch.ones(1024, 1024)
        x = traced(a, a)
        npr = a[0:512:2]
        npr = npr + npr
        np.testing.assert_allclose(npr.numpy(), x.numpy())

    def test_unsqueeze(self, N=256):
        def easy(x, y):
            a = torch.unsqueeze(x, 0)
            b = torch.unsqueeze(y, 0)
            return a + b

        traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N)))

        a = torch.rand(N, N)
        x = traced(a, a)
        npr = np.expand_dims(a, 0)
        npr = npr + npr
        np.testing.assert_allclose(npr, x.numpy())

    def _test_softmax(self, device):
        def test_softmax(x, y):
            a = F.softmax(x, dim=0, dtype=torch.float32)
            b = F.softmax(y, dim=0, dtype=torch.float32)
            c = F.softmax(x, dim=1, dtype=torch.float32)
            d = F.softmax(y, dim=1, dtype=torch.float32)
            return a + b + c + d

        def test_softmax_neg_index(x, y):
            a = F.softmax(x, dim=-2, dtype=torch.float32)
            b = F.softmax(y, dim=-2, dtype=torch.float32)
            c = F.softmax(x, dim=-1, dtype=torch.float32)
            d = F.softmax(y, dim=-1, dtype=torch.float32)
            return a + b + c + d

        def test_log_softmax(x, y):
            a = F.log_softmax(x, dim=0, dtype=torch.float32)
            b = F.log_softmax(y, dim=0, dtype=torch.float32)
            c = F.log_softmax(x, dim=1, dtype=torch.float32)
            d = F.log_softmax(y, dim=1, dtype=torch.float32)
            return a + b + c + d

        for test in (test_softmax, test_log_softmax, test_softmax_neg_index):
            for data_type in self.dtypes:
                old = torch._C._jit_set_texpr_reductions_enabled(True)
                traced_input = torch.randn(2, 3, dtype=data_type, device=device)
                traced = torch.jit.trace(test, (traced_input, traced_input))
                inp = torch.randn(2, 3, dtype=data_type, device=device)
                res = traced(inp, inp)
                # Use eager mode as reference.
                ref = test(inp, inp)
                np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06)
                torch._C._jit_set_texpr_reductions_enabled(old)

    def test_softmax_cpu(self):
        self._test_softmax('cpu')

    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    @unittest.skip("global allocs are not supported yet.")
    def test_softmax_cuda(self):
        self._test_softmax('cuda')

    def test_half_gelu(self):
        devices = ["cuda"] if torch.cuda.is_available() else []

        @torch.jit.script
        def bias_gelu(bias, y):
            x = bias + y
            return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

        for device in devices:
            a = torch.rand(1024, dtype=torch.half, device=device)
            b = torch.rand(1024, dtype=torch.half, device=device)
            traced = torch.jit.trace(bias_gelu, (a, b))
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()

    def test_half_bn_relu(self):
        devices = ["cuda"] if torch.cuda.is_available() else []

        def foo(a, b, c):
            y = torch.nn.functional.batch_norm(a, b, c)
            z = y.relu()
            return z

        for device in devices:
            a = torch.rand(16, 16, dtype=torch.half, device=device)
            b = torch.rand(16, dtype=torch.half, device=device)
            c = torch.rand(16, dtype=torch.half, device=device)
            traced = torch.jit.trace(foo, (a, b, c))
            print(traced.graph)
            x = warmup_and_run_forward(traced, a, b, c)
            self.assertLastGraphAllFused()

    def test_exp_pow(self):
        @torch.jit.script
        def do_exp(x, y, z):
            return ((x * y) * 2) * torch.pow(z, 2)

        for device in self.devices:
            x = torch.rand(10, dtype=torch.double, device=device)
            y = torch.rand(10, dtype=torch.double, device=device)
            z = torch.rand(10, dtype=torch.double, device=device)
            traced = torch.jit.trace(do_exp, (x, y, z))
            x = warmup_and_run_forward(traced, x, y, z)
            self.assertLastGraphAllFused()

    def test_sin_pow(self):
        def test(x):
            return torch.sin(torch.pow(x, 0))

        for data_type, shape in itertools.product(self.dtypes, [[3], [5], [10]]):
            x = torch.rand(shape, dtype=data_type)
            scripted = torch.jit.script(test)
            out = warmup_and_run_forward(scripted, x)
            self.assertLastGraphAllFused()
            self.assertEqual(out, test(x))

    def test_transpose(self):
        @torch.jit.script
        def test(x, y, z):
            return x.transpose(0, 1) + y + z
        x = torch.rand(4, 5, 2, 3)
        y = torch.rand(5, 4, 2, 3)
        z = torch.rand(5, 4, 2, 3)
        ref = test(x, y, z)
        res = test(x, y, z)
        np.testing.assert_allclose(ref.numpy(), res.numpy())

    def test_sliced_stride(self):
        @torch.jit.script
        def test(x, y, z):
            return x + y + z
        x = torch.rand(16, 4, 2, 3)[::2]
        y = torch.rand(8, 4, 2, 3)
        z = torch.rand(8, 4, 2, 3)
        ref = test(x, y, z)
        res = test(x, y, z)
        np.testing.assert_allclose(ref.numpy(), res.numpy())

    @unittest.skip("dynamic shapes are not quite there yet")
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_dynamic_shape(self):
        with num_profiled_runs(2):
            @torch.jit.script
            def test(x, y, z):
                return x * y * z
            x, y, z = (torch.rand(4, 8).cuda() for _ in range(3))
            ref = test(x, y, z)
            _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
            res = test(x, y, z)
            np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())

            # A wild broadcast appears.
            x = torch.rand(4, 8).cuda()
            y = torch.rand(1, 8).cuda()
            z = torch.rand(4, 1).cuda()
            res = test(x, y, z)
            xn, yn, zn = (t.cpu().numpy() for t in (x, y, z))
            np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)

            # Mismatched shapes shouldn't reach codegen.
            x = torch.rand(4, 8).cuda()
            y = torch.rand(4, 8).cuda()
            z = torch.rand(5, 8).cuda()
            try:
                res = test(x, y, z)
            except RuntimeError as e:
                assert "The size of tensor a (4) must match" in e.args[0]

            # Changing a static dimension fails guards.
            # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)]
            # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
            # res = test(x, y, z)
            # print(test.graph_for(x, y, z))
            # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)

    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_guard_fails(self):
        @torch.jit.script
        def test(x, y, z):
            return x * y * z
        r1 = test(*[torch.rand(4).cuda() for _ in range(3)])
        r2 = test(*[torch.rand(4).cuda() for _ in range(3)])
        r3 = test(*[torch.rand(4).cuda() for _ in range(3)])
        r4 = test(*[torch.rand(7).cuda() for _ in range(3)])

    def test_bitwise_ops(self):
        def run_and(x, y):
            return x & (x & y)

        def run_or(x, y):
            return x & (x | y)

        def run_xor(x, y):
            return x ^ (x ^ y)

        def run_lshift(x, y):
            return x & (x << y)

        def run_rshift(x, y):
            return x & (x >> y)

        fns = {run_and, run_or, run_xor, run_lshift, run_rshift}

        for device in self.devices:
            for fn in fns:
                a = torch.ones(128, dtype=torch.int32, device=device)
                b = torch.zeros(128, dtype=torch.int32, device=device)
                inp = torch.ones(128, dtype=torch.int32, device=device)
                traced = torch.jit.trace(fn, (inp, inp))
                x = warmup_and_run_forward(traced, a, b)
                self.assertLastGraphAllFused()
                y = fn(a, b)
                np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())

    def test_where(self):
        def run_where(x, y):
            return torch.where(torch.gt(x, y), x, y)

        for data_type in self.dtypes:
            a = torch.rand(1024, dtype=data_type)
            b = torch.rand(1024, dtype=data_type)
            zeros = torch.zeros(1024, dtype=data_type)
            traced = torch.jit.trace(run_where, (zeros, zeros))
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
            y = run_where(a, b)
            np.testing.assert_allclose(x.float().numpy(), y.float().numpy())

    def test_multi_rand(self):
        for device in self.devices:
            def test(x):
                y = torch.rand_like(x)
                return (x + y) - (y - x)

            _atol = 2e-3
            _rtol = 1e-5
            for data_type in self.dtypes:
                if data_type is torch.bfloat16:
                    _atol = 2e-2
                a = torch.rand(4, dtype=data_type, device=device)
                scripted = torch.jit.script(test)
                out = warmup_and_run_forward(scripted, a)
                self.assertLastGraphAllFused()
                assert torch.allclose(out, 2 * a, atol=_atol, rtol=_rtol)

    def test_mask(self):
        def test(x):
            return x.unsqueeze(1) == 0

        for d in self.devices:
            for data_type in self.dtypes:
                x = torch.rand(4, dtype=data_type, device=d) > 0.5
                scripted = torch.jit.script(test)
                out = warmup_and_run_forward(scripted, x)
                self.assertLastGraphAllFused()
                assert torch.equal(out, test(x))

    def test_simple_add(self):
        val = torch._C._jit_get_te_generate_block_code()
        torch._C._jit_set_te_generate_block_code(True)
        fall_bk = torch._C._jit_texpr_fallback_allowed()
        torch._C._jit_texpr_set_fallback_allowed(True)

        def simple(a, b):
            return torch.add(a, b)

        a = torch.ones(256, 256)
        b = torch.ones(256, 256)
        traced = torch.jit.trace(simple,
                                 (torch.ones(256, 256), torch.ones(256, 256)))
        f = traced(a, b)
        f_test = np.full((256, 256), 2, dtype=float)
        np.testing.assert_allclose(f.numpy(), f_test)
        torch._C._jit_set_te_generate_block_code(val)
        torch._C._jit_texpr_set_fallback_allowed(fall_bk)

    def test_strided_output_preserved(self):
        def foo(a, b):
            return a + b - a

        # smaller, easier to debug example
        x = torch.arange(6)
        x = torch.as_strided(x, (2, 3), (1, 2))
        total = 0
        for i in range(2):
            for j in range(3):
                x[i, j] = total
                total += 1
        foo_script = torch.jit.script(foo)
        foo_script(x, x)
        foo_script(x, x)
        out_s = foo_script(x, x)
        out_eager = foo(x, x)
        self.assertEqual(out_s, out_eager)
        self.assertEqual(out_s.stride(), out_eager.stride())
        self.assertLastGraphAllFused()

        # more dims
        N, C, H, W = 2, 3, 4, 5
        x = torch.rand(N, C, H, W).to(memory_format=torch.channels_last)
        foo_script = torch.jit.script(foo)
        foo_script(x, x)
        foo_script(x, x)
        out_s = foo_script(x, x)
        out_eager = foo(x, x)
        self.assertEqual(out_s, out_eager)
        self.assertEqual(out_s.stride(), out_eager.stride())
        self.assertLastGraphAllFused()

    def test_alias_analysis_module(self):
        class AliasModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                z = z + self.a
                self.b.add_(y)
                w = z + self.a
                z = w + x
                return z

        x = torch.randn(128, 128)

        def getModule(script):
            am = AliasModule()
            if script:
                return torch.jit.script(am)
            return am

        am = getModule(False)
        am_s = getModule(True)
        ref = am(x, x, x)
        test = am_s(x, x, x)
        torch.testing.assert_close(ref, test)

        # Now do the aliasing
        am.a = am.b
        ref = am(x, x, x)

        am_s.a = am_s.b
        test = am_s(x, x, x)

        torch.testing.assert_close(ref, test)

    def test_alias_analysis_inputs(self):
        class AliasModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                x.add_(y)
                w = z + self.a
                z = w + x
                return z

        def getModule(script):
            am = AliasModule()
            if script:
                return torch.jit.script(am)
            return am
        am = getModule(False)
        am_s = getModule(True)

        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        ref = am(x, x, x)

        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        test = am_s(x, x, x)

        torch.testing.assert_close(ref, test)

    def test_alias_analysis_input_and_module(self):
        class AliasModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                x.add_(y)
                w = z + self.b
                z = w + x
                return z

        def getModule(script):
            am = AliasModule()
            if script:
                return torch.jit.script(am)
            return am
        am = getModule(False)
        am_s = getModule(True)

        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        am.b = x
        ref = am(x, x, x)

        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        am_s.b = x
        test = am_s(x, x, x)

        torch.testing.assert_close(ref, test)

    def test_multiple_outputs(self):
        for device in self.devices:
            # A bug reported internally, similar to the one reported in #48533.
            def foo(a, b, c):
                t_next = c + 1
                t5 = t_next * b
                t6 = torch.unsqueeze(t_next, 1)
                t7 = a * t6
                return (t7, t5, t_next)

            for data_type in self.dtypes:
                a = torch.rand(20, 20, dtype=data_type, device=device)
                b = torch.rand(20 * 29, dtype=data_type, device=device).as_strided([20], [29])
                c = torch.ones(20, dtype=torch.int64, device=device)
                traced = torch.jit.trace(foo, (a, b, c))
                ref = foo(a, b, c)
                exp = traced(a, b, c)
                exp = traced(a, b, c)
                self.assertEqual(ref, exp)

    def test_propagated_mem_layout(self):
        def foo(a, b, c):
            t_next = c + 1
            t5 = t_next * b
            t7 = a * t5
            return t7

        def foo_multi_outputs(a, b, c):
            t_next = c + 1
            t5 = b * t_next
            t7 = a * t5
            return (t7, t5, t_next)

        def foo_multi_outputs_i_nhwc_o_nchw(a, b, c):
            t_next = c + 1
            t5 = b * t_next
            t7 = a * t5
            t8 = t7.to(memory_format=torch.contiguous_format)
            return (t8, t7, t5, t_next)

        def run_foo_case(foo, a, b, c):
            traced_contiguous = torch.jit.trace(foo, (a, b, c))
            ref = foo(a, b, c)
            exp = traced_contiguous(a, b, c)
            exp = traced_contiguous(a, b, c)
            self.assertEqual(ref, exp)

        mem_layouts = list(itertools.product([torch.contiguous_format, torch.channels_last], repeat=3))
        shapes = [(2, 3, 4, 5), (2, 1, 1, 5), (1, 1, 1, 1)]
        permutes = [(0, 3, 2, 1), (0, 3, 1, 2)]
        funcs = [foo, foo_multi_outputs, foo_multi_outputs_i_nhwc_o_nchw]
        # Materialize the configs so the product is not exhausted after the first
        # fusion strategy is exercised.
        configs = list(itertools.product(funcs, shapes, mem_layouts, permutes))
        for strategy in ["STATIC", "DYNAMIC"]:
            old_strategy = torch.jit.set_fusion_strategy([(strategy, 10)])
            for _func, _shape, _mem_layouts, _permute in configs:
                a = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[0])
                b = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[1])
                c = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[2])
                run_foo_case(_func, a, b, c)

                a = a.permute(dims=_permute)
                b = b.permute(dims=_permute)
                c = c.permute(dims=_permute)
                run_foo_case(_func, a, b, c)

            torch.jit.set_fusion_strategy(old_strategy)


if __name__ == '__main__':
    run_tests()