# Owner(s): ["module: functorch"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import gc
from unittest import skip, skipIf

from attn_ft import BertSelfAttention as BertSelfAttentionA, Linear
from attn_positional import BertSelfAttention as BertSelfAttentionB

import torch
from functorch._C import dim as _C
from functorch.dim import (
    Dim,
    DimensionBindError,
    DimList,
    dimlists,
    dims,
    stack,
    Tensor,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfTorchDynamo,
    TEST_CUDA,
    TestCase,
)


try:
    from torchvision.models import resnet18
except ImportError:
    # torchvision is optional; test_network skips itself when it is missing.
    resnet18 = None

_test_c, _parse_test, _set_pointwise_optimize = (
    _C._test_c,
    _C._parse_test,
    _C._set_pointwise_optimize,
)

from contextlib import contextmanager
from time import perf_counter


# Flip to True to trace with torchdim.magic_trace and to time test_attn_cuda.
measure_perf = False
if measure_perf:
    from torchdim.magic_trace import magic_trace
else:

    @contextmanager
    def magic_trace(*args, **kwargs):
        """No-op stand-in for ``magic_trace`` used when ``measure_perf`` is off."""
        yield


@contextmanager
def measure(what):
    """Print the wall-clock time spent inside the ``with`` body.

    Args:
        what: label printed in front of the elapsed time.

    The timing line is only printed when the body finishes without raising.
    """
    b = perf_counter()
    yield
    e = perf_counter()
    print(f"{what}: {e - b:.20f} seconds")


def triu(A):
    """Return ``A`` with entries below the main diagonal replaced by zero.

    Written with first-class dims: ``A`` is indexed by two fresh dims ``i`` and
    ``j``, elements where ``i <= j`` are kept, and the result is converted back
    to positional order ``(i, j)``.
    """
    i, j = dims()
    a = A[i, j]
    zero = torch.tensor(0, dtype=torch.float)  # XXX - torch.where is janky...
    return torch.where(i <= j, a, zero).order(i, j)


def gpu_time(lmb, name, r=100):
    """Time ``lmb`` with CUDA events and return the mean time per call.

    Args:
        lmb: zero-argument callable to benchmark.
        name: label printed next to the measurement.
        r: number of warm-up calls, and also the number of timed calls.

    Returns:
        ``elapsed / r`` where ``elapsed`` is ``Event.elapsed_time`` between the
        two recorded events (milliseconds, per the torch.cuda.Event docs).
    """
    b = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    # with magic_trace(name + ".fxt"):
    # Warm-up iterations, not timed.
    for _ in range(r):
        lmb()
    b.record()
    for _ in range(r):
        lmb()
    e.record()
    # Wait for the end event so the elapsed time is valid.
    e.synchronize()
    elapsed = b.elapsed_time(e)
    # with torch.profiler.profile(schedule=torch.profiler.schedule(
    #     wait=0,
    #     warmup=1,
    #     active=2), on_trace_ready=tensorboard_trace_handler(name), with_stack=True) as profiler:
    #     for _ in range(3):
    #         lmb()
    #         profiler.step()
    print(name, elapsed / r)
    return elapsed / r


@skipIfTorchDynamo("Bad interaction")
class TestMin(TestCase):
    """Tests for first-class dimensions (``functorch.dim``).

    Every test is bracketed by a leak check: ``setUp`` records the ids of all
    live ``torch.Tensor`` / ``Dim`` / ``Tensor`` / ``DimList`` objects and
    ``tearDown`` asserts that no new ones are still alive afterwards.

    NOTE: ``dims()`` / ``dimlists()`` called without a count derive the number
    and names of the dims from the assignment targets, so the variable names
    in these tests are significant and must not be renamed casually.
    """

    def setUp(self):
        """Disable the cyclic GC and snapshot live tensor/dim objects.

        The collector is disabled so that objects kept alive only by reference
        cycles are still visible to ``tearDown``. For tests with ``"cuda"`` in
        their name the currently allocated CUDA memory is recorded as well.
        """
        super().setUp()
        # Re-enable the collector after the test (cleanups run after tearDown,
        # even when it fails) so disabling it here does not leak into the rest
        # of the process. Only restore it if it was enabled to begin with.
        if gc.isenabled():
            self.addCleanup(gc.enable)
        gc.disable()
        gc.collect()
        self.interesting = set()
        for o in gc.get_objects():
            if isinstance(o, (torch.Tensor, Dim, Tensor, DimList)):
                self.interesting.add(id(o))
        if "cuda" in self._testMethodName:
            self.mem_allocated = torch.cuda.memory_allocated()

    def tearDown(self):
        """Assert that the test left no tensor/dim objects or CUDA memory behind.

        Raises:
            AssertionError: if objects of the tracked types that were not alive
                in ``setUp`` still exist, or (for CUDA tests) if the allocated
                CUDA memory differs from the value recorded in ``setUp``.
        """
        interesting = []
        for o in gc.get_objects():
            if (
                isinstance(o, (torch.Tensor, Dim, Tensor, DimList))
                and id(o) not in self.interesting
            ):
                interesting.append(o)

        extra_memory = 0
        if "cuda" in self._testMethodName:
            extra_memory += torch.cuda.memory_allocated() - self.mem_allocated

        # nolevels = _n_levels_in_use() == 0
        if extra_memory != 0 or len(interesting) != 0:
            # Best-effort debugging aid: dump the garbage graph when the
            # optional ``refcycle`` package is available. A missing package
            # must not mask the leak assertions below with an ImportError.
            try:
                import refcycle
            except ImportError:
                refcycle = None
            if refcycle is not None:
                refcycle.garbage().export_image("garbage.pdf")
        gc.collect()
        # assert nolevels, f"cleanup failed? {_n_levels_in_use()}"
        assert extra_memory == 0, f"extra cuda memory left allocated: {extra_memory}"
        assert len(interesting) == 0, (
            f"extra torch.Tensor, Dim, or Tensor left allocated: {len(interesting)} objects of types:"
            f" { [type(t) for t in interesting] }"
        )

    def test_manual_stuff(self):
        # Matrix multiply spelled out with expand/multiply/sum over first-class
        # dims, plus the triu helper, checked against the positional versions.
        A_ = torch.rand(3, 4)
        B_ = torch.rand(4, 5)
        i, j, k = dims()
        A = A_[i, k]
        B = B_[k, j]
        C = (A.expand(j) * B.expand(i)).sum(k)
        self.assertTrue(torch.allclose(C.order(i, j), torch.mm(A_, B_)))
        self.assertTrue(torch.allclose(torch.triu(A_, 0), triu(A_)))

        D_ = torch.randint(0, 3, (6,))
        d = dims()
        D = D_[d]

        # Smoke test only: the result is not checked.
        A.index([i], [D]).order(k, d)

    def attn(
        self,
        batch_size=1,
        sequence_length=4,
        hidden_size=6,
        num_attention_heads=3,
        linear=Linear,
        device=None,
        time=False,
    ):
        """Check the first-class-dims attention against the positional one.

        ``BertSelfAttentionA`` (from ``attn_ft``) is loaded with the state dict
        of ``BertSelfAttentionB`` (from ``attn_positional``) and both are run
        on the same random input; outputs must be ``allclose``. This is done
        for the default configuration, for the ``"relative_key"`` and
        ``"relative_key_query"`` approaches, and with a ``past_key_value``.

        Args:
            batch_size, sequence_length, hidden_size: shape of the random
                ``hidden_state`` input.
            num_attention_heads: passed to both attention modules.
            linear: linear-layer class handed to ``BertSelfAttentionA``.
            device: if not ``None``, modules and inputs are moved there.
            time: if true, additionally time both modules with ``gpu_time``.
        """

        def maybe_to(x):
            return x if device is None else x.to(device)

        attention_probs_dropout_prob = 0.0
        A = maybe_to(
            BertSelfAttentionA(
                hidden_size,
                num_attention_heads,
                attention_probs_dropout_prob,
                linear=linear,
            )
        )
        B = maybe_to(
            BertSelfAttentionB(
                hidden_size, num_attention_heads, attention_probs_dropout_prob
            )
        )

        A.load_state_dict(B.state_dict())
        hidden_state = maybe_to(torch.rand(batch_size, sequence_length, hidden_size))
        b_out = B(hidden_state)
        a_out = A(hidden_state)
        self.assertTrue(
            torch.allclose(a_out, b_out)
        )  # why does a simple matmul not do the right thing?

        if time:
            gpu_time(lambda: B(hidden_state), "positional", r=3)
            gpu_time(lambda: A(hidden_state), "first_class", r=3)

        for approach in ("relative_key", "relative_key_query"):
            A = maybe_to(
                BertSelfAttentionA(
                    hidden_size,
                    num_attention_heads,
                    attention_probs_dropout_prob,
                    approach,
                    sequence_length,
                    linear=linear,
                )
            )
            B = maybe_to(
                BertSelfAttentionB(
                    hidden_size,
                    num_attention_heads,
                    attention_probs_dropout_prob,
                    approach,
                    sequence_length,
                )
            )
            A.load_state_dict(B.state_dict())

            hidden_state = maybe_to(
                torch.rand(batch_size, sequence_length, hidden_size)
            )
            b_out = B(hidden_state)
            a_out = A(hidden_state)
            self.assertTrue(torch.allclose(a_out, b_out))

            if time:
                gpu_time(lambda: B(hidden_state), "positional", r=3)
                gpu_time(lambda: A(hidden_state), "first_class", r=3)

        A = maybe_to(
            BertSelfAttentionA(
                hidden_size,
                num_attention_heads,
                attention_probs_dropout_prob,
                None,
                None,
                linear=linear,
            )
        )
        B = maybe_to(
            BertSelfAttentionB(
                hidden_size,
                num_attention_heads,
                attention_probs_dropout_prob,
                None,
                None,
            )
        )
        A.load_state_dict(B.state_dict())

        hidden_state = maybe_to(torch.rand(batch_size, sequence_length, hidden_size))
        past_key_value = (
            maybe_to(
                torch.rand(
                    batch_size,
                    num_attention_heads,
                    sequence_length,
                    hidden_size // num_attention_heads,
                )
            ),
            maybe_to(
                torch.rand(
                    batch_size,
                    num_attention_heads,
                    sequence_length,
                    hidden_size // num_attention_heads,
                )
            ),
        )

        b_out = B(hidden_state, past_key_value=past_key_value)
        a_out = A(hidden_state, past_key_value=past_key_value)
        self.assertTrue(torch.allclose(a_out, b_out))

        if time:
            gpu_time(lambda: B(hidden_state), "positional", r=3)
            gpu_time(lambda: A(hidden_state), "first_class", r=3)

    def test_attn(self):
        self.attn()

    def test_inplace(self):
        # In-place sparse update through a dim-indexed view (smoke test).
        # some embeddings table
        embeddings = torch.zeros(10, 3)

        # some sparse updates to the embeddings
        indices = torch.arange(2) + 1
        values = torch.rand(2, 3)

        i, n, f = dims()

        embeddings[indices[i], f] += values[i, f]

    def test_adapt(self):
        def f():
            ci, co = dims()

        # python 3.11 adapts bytecode after a number of iterations
        # check that we still match names correctly
        for i in range(10):
            f()

    @skipIf(not TEST_CUDA, "no CUDA")
    def test_attn_cuda(self):
        # size from the BERT paper, 90% pretraining of sequence length 128
        self.attn(
            batch_size=256,
            hidden_size=768,
            sequence_length=128,
            num_attention_heads=12,
            device="cuda",
            time=measure_perf,
            linear=torch.nn.Linear,
        )

    def test_stack(self):
        # Smoke test for stack(); the result checks below are disabled.
        i, j, d = dims()
        A = torch.rand(4, 5)
        r = stack([A[i, j]], d, j)
        # a, b = r.unbind(d)
        # self.assertTrue(torch.allclose(a.order(i, j), i.expand(j).order(i, j)))
        # self.assertTrue(torch.allclose(b.order(i, j), j.expand(i).order(i, j)))

    def test_max(self):
        # max over a first-class dim matches max over the positional dim.
        ap = torch.rand(2, 3, 2)
        i, j, k = dims()
        a = ap[i, j, k]
        r, i0 = a.max(dim=k)
        self.assertTrue(torch.allclose(r.order(i, j), ap.max(2)[0]))

    def test_mm(self):
        # Smoke test: matmul pattern with an extra size-1 dim `q`; unchecked.
        i, j, k, q = dims()
        a = torch.rand(3, 4)
        b = torch.rand(4, 5)
        a_ = a[i, k]
        b_ = b[k, j]
        q.size = 1
        r = (a_.expand(j, q) * b_.expand(i, q)).sum(k).order(q, i, j)
        # r = (a_*b_).sum(k).order(q, i, j)
        # print(r)
        # print(a @ b)

    def test_with_dims_split(self):
        # Splitting a positional dim into [j, k] on indexing and re-joining it
        # with order(i, [j, k]) round-trips the tensor.
        a = torch.arange(3 * 12).view(3, 12)
        i, j, k = dims()
        k.size = 4
        r = a[i, [j, k]]
        x = r.order(i, [j, k])
        self.assertTrue(torch.allclose(a, x))

    def test_hello(self):
        # Broad sanity test covering reductions, index/order, dimlists,
        # ellipsis handling, error cases and split with first-class dims.
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)
        i, j, k = dims()

        # r = A[i]*4
        r = (A[i, k] * B[k, j]).sum(k).order(i, j)
        assert torch.allclose(r, A @ B)

        assert A.sum() == A[i].sum((0, i))
        assert A.sum() == A[i].sum((-1, i))

        assert torch.allclose(A.sum(), A[i].sum(0, keepdim=True).sum((0, i)))
        assert torch.allclose(A[i].std(i, True), A.std(0, True))

        assert torch.allclose(A[i, k].max(i)[0].order(k), A.max(0)[0])
        assert torch.allclose(A.sort(1)[0], A[i, k].sort(k)[0].order(i, k))
        # XXX - chunk changes the size of a dimension, has to take a new dimension...
        # assert torch.allclose(A.chunk(2,1)[0], A[i, k].chunk(2, k)[0].order(i, k))
        assert torch.allclose(A[i].renorm(1, i, 7).order(i), A.renorm(1, 0, 7))
        kk = dims()
        # assert torch.allclose( torch.stack([A, A], 1), stack([A[i,k], A[i, k]], kk, k).order(i, kk, k))

        k2 = dims()
        # r = cat((A[i, k], A[i,k]), k, k2)
        # assert torch.allclose(torch.cat([A, A], 1), r.order(i, k2))
        # assert k2.size == 2*k.size

        assert torch.allclose(A.expand(5, -1, -1), A[i, k].expand(j).order(j, i, k))
        z = dims()
        C = torch.arange(2)
        assert torch.allclose(A[:, 0:2], A[i, k].index(k, C[z]).order(i, z))

        o, l = dims()
        o.size = 2
        r = A[i, k].index(k, (o, l))
        assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2))
        rr = r.index((o, l), k)
        assert torch.allclose(A, rr.order(i, k))

        r = i + k - 1
        r2 = torch.arange(3)[:, None] + torch.arange(4)[None, :] - 1
        assert torch.allclose(r.order(i, k), r2)

        # test with ...
        assert torch.allclose(A.T, A[..., k].order(k))

        # test with dimlist
        a_, b_ = dimlists()
        assert torch.allclose(A[i, a_].order(*a_, i), A.T)
        # test with one bound dimlist
        assert torch.allclose(A[:, a_].order(*a_), A.T)
        # test with a dimlist that will end up empty
        assert torch.allclose(A[i, b_, k].order(i, k, *b_), A)
        # test with too few things
        (A[i] + i)
        assert torch.allclose((A[i] + i).order(i), A + torch.arange(3)[:, None])
        # test with too many elements
        try:
            A[1, ..., 1, 1]
            # Reaching this line means the expected IndexError was not raised.
            raise NotImplementedError
        except IndexError:
            pass
        c, d = dims()
        c.size = 2
        assert torch.allclose(A[i, [c, d]].order(i, c, d), A.view(3, 2, 2))

        assert torch.allclose(
            A[c + 1, c + 0].order(c), A[torch.arange(2) + 1, torch.arange(2)]
        )
        try:
            A[..., 3, ...]
            # Reaching this line means the expected error was not raised.
            raise NotImplementedError
        except DimensionBindError:
            pass

        C = torch.rand(4, 7)
        c_, x, y, z = dims()

        a, b, c = C.split((3, 3, 1), dim=1)
        s = dims()
        ref = C.split((3, 3, 1), dim=1)
        t = C[s, c_].split((x, y, z), dim=c_)
        for a, b, d in zip(ref, t, (x, y, z)):
            assert torch.allclose(a, b.order(s, d))

        D = torch.rand(3, 4, 5)
        assert torch.allclose(
            D.transpose(0, 1).flatten(1, 2), D[i, k, j].order((i, j)).order(k)
        )

        # Factory-like and functional ops must keep the input's dims.
        r = [id(x) for x in torch.rand_like(A[i, k]).dims]
        assert id(i) in r and id(k) in r
        r = [id(x) for x in torch.nn.functional.dropout(A[i, k]).dims]
        assert id(i) in r and id(k) in r

    def test_simple(self):
        # Smoke test: pointwise adds and order() on a dim-indexed tensor.
        i, j, k = dims()
        x = torch.rand(3, 4)
        z = x[i, j]
        (z + z + z + z)
        (z.order(i, j))

    def test_mm_fuse(self):
        # multiply-then-sum over a shared dim equals a matrix multiply.
        i, j, k = dims()
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)

        C = (A[i, k] * B[k, j]).sum(k).order(i, j)
        assert torch.allclose(C, A @ B)

    def test_time_mm_fuse(self):
        # Prints timings for positional ("pp") vs first-class ("fc") matmul and
        # finally checks that both give the same result.
        i, j, k = dims()
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)

        for _ in range(10):
            r0 = A @ B

        for _ in range(10):
            a = A[i, k]
            b = B[k, j]
            r1 = (a * b).sum(k)

        with measure("pp"):
            for _ in range(10000):
                A @ B
        # magic_trace_stop_indicator()

        with measure("fc"):
            for _ in range(10000):
                (A[i, k] * B[k, j]).sum(k).order(i, j)

        with magic_trace("f.fxt"):
            for _ in range(10000):
                (A[i, k] * B[k, j]).sum(k).order(i, j)

        with magic_trace("p.fxt"):
            for _ in range(10000):
                A @ B

        # magic_trace_stop_indicator()

        assert torch.allclose(r1.order(i, j), r0)

    def test_compare_dims(self):
        # Smoke test: comparing two sized dims must not raise.
        i, j = dims()
        i.size = 3
        j.size = 4
        (i < j)  # noqa: B015

    def test_c(self):
        _test_c()

    def test_seg(self):
        # Smoke test: arithmetic on bare dims; unchecked.
        A = torch.rand(3, 4)
        i, k = dims()
        i.size = 4
        k.size = 3
        r = i + k - 1

    def test_expand(self):
        A = torch.rand(3, 4)
        i = dims()
        assert list(A[i].expand(2, 4).order(i).size()) == [3, 2, 4]

    def test_parse(self):
        # Exercises the C argument parser test hook with positional/keyword
        # mixes, including over-supplied, duplicated and missing arguments.
        self.assertEqual(("x", None, None, None), _parse_test(1, 0, "x"))
        self.assertEqual(("x", None, "y", None), _parse_test(1, 0, "x", c="y"))
        self.assertEqual(("x", None, "y", "z"), _parse_test(1, 0, "x", d="z", c="y"))

        self.assertEqual(("x", "4", None, None), _parse_test(2, 0, "x", b="4"))
        self.assertEqual(("x", "y", "z", "q"), _parse_test(2, 0, "x", "y", "z", "q"))
        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x", "y", "z", "q", "5")
        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x", "y", b="y")

        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x", c="y")
        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x")

    def test_network(self):
        # Running resnet18 on an input with two first-class batch dims must
        # match running it on the equivalent flattened positional batch.
        if resnet18 is None:
            self.skipTest("no torchvision")
        rn = resnet18(
            norm_layer=lambda x: torch.nn.BatchNorm2d(x, track_running_stats=False)
        )
        rn.train()
        img = torch.rand(1, 1, 2, 3, 224, 224)
        imgf = img.view(2, 3, 224, 224)

        i, j = dims()
        r = rn(img[i, j])
        r = r.order(i, j).view(2, 1000)
        r2 = rn(imgf)
        assert torch.allclose(r2, r, atol=1e-06)

    def test_dim_args(self):
        # dims()/dimlists() with inferred names and with explicit sizes.
        a = dimlists()
        assert isinstance(a, DimList)
        a = dims()
        b = dimlists()
        assert isinstance(a, Dim)
        assert isinstance(b, DimList)
        assert str(a) == "a"
        a, b = dims(sizes=[3, 4])
        assert a.size == 3
        assert b.size == 4
        a = dims(sizes=[3])
        b = dimlists(sizes=[4])
        assert len(b) == 4
        a = dims()
        b = dimlists(sizes=[[4, 5]])
        assert b[0].size == 4
        assert b[1].size == 5

    def test_diag(self):
        # Smoke test: indexing with the same dim twice must not raise.
        i = dims()
        A = torch.rand(4, 4)
        (A[i, i])

    def test_softmax_split(self):
        # Softmax computed blockwise over a split dim [i, g] (per-block max and
        # sum, then rescaled by the global max) matches F.softmax.
        a = torch.rand(16)
        g, i = dims(sizes=[2, None])
        a2 = a[[i, g],]

        m_b, _ = a2.max(i)
        f_b = torch.exp(a2 - m_b)
        l_b = f_b.sum(i)

        m, _ = m_b.max(g)
        c = torch.exp(m_b - m)
        f = (c * f_b).order((i, g))
        l = (c * l_b).sum(g)
        assert torch.allclose(f / l, torch.nn.functional.softmax(a, dim=0))

    def test_index(self):
        # index() used to split a dim, merge dims back, gather with an index
        # tensor, and flatten positional dims into a single first-class dim.
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)
        i, j, k = dims()

        o, l = dims()
        o.size = 2
        r = A[i, k].index(k, [o, l])
        assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2))
        rr = r.index([o, l], k)
        assert torch.allclose(A, rr.order(i, k))
        z = dims()
        C = torch.arange(2)
        x = A[i, k].index(k, C[z]).order(i, z)
        assert torch.allclose(A[:, 0:2], x)

        C = torch.rand(3, 4, 5)
        ik = dims()
        assert torch.allclose(
            C.index((0, 2), ik).order(ik), C.permute(0, 2, 1).reshape(15, 4)
        )

    # failures that came up from monkey patching some operators...
    def test_monkey(self):
        A = torch.rand(3, 4)
        A[0, 0] = 5
        x = torch.randn(3, 4, 4, 4, 3)
        x_clone1 = x.clone()
        ia = torch.tensor([0, 2, 1])
        ib = torch.tensor([0, 2, 1])
        first_shape = x[:, ia, None, ib, 0].shape
        x_clone1[:, ia, None, ib, 0] = torch.randn(first_shape).to(x_clone1)
        x = torch.autograd.Variable(torch.tensor([]))
        z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
        a = [z[2], z[0] + 3]
        x.new(a)
        # self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4])

    def test_index_placement(self):
        # Dim expressions used as indices at positions 1 and 3; after
        # order(i, j) the remaining positional dims follow in original order.
        A = torch.rand(1, 2, 3, 4)

        i, j = dims(sizes=[2, 4])

        a = A[:, i + 0, :, j + 0]
        r = a.order(i, j)

        assert torch.allclose(A.permute(1, 3, 0, 2), r)

    def test_order(self):
        # order() accepts a mix of positional indices and first-class dims.
        i, j = dims()
        A = torch.rand(3, 4, 5)
        assert torch.allclose(A[i].order(1, i), A.permute(2, 0, 1))

    def test_mask(self):
        # Smoke test: mask built from a dim comparison; unchecked.
        a = torch.rand(5)
        i, j = dims(sizes=[a.size(0), a.size(0)])
        ((i >= j) * a[i]).sum(j).order(i)

    def test_eq(self):
        i, j = dims(sizes=[3, 3])
        assert (i == j).sum((i, j)) == 3

    def test_dims_with_size(self):
        # dims(n) returns n dims; plain-name targets name the dim, while the
        # attribute target `y.x` gets the positional fallback name "d1".
        x = dims(3)
        assert len(x) == 3 and isinstance(x[0], Dim)

        class Foo:
            pass

        y = Foo()
        z, y.x, q = dims(3)
        assert str(z) == "z"
        assert str(y.x) == "d1"
        assert str(q) == "d2"

    def test_dir(self):
        # Smoke test: dir() on a first-class-dims tensor must not raise.
        i, j = dims(sizes=[3, 3])
        dir(i <= j)

    def test_doc(self):
        assert Tensor.clamp.__doc__ == torch.Tensor.clamp.__doc__

    def test_embed(self):
        embeddings = torch.rand(8, 32)
        ids = torch.tensor([1, 0, 3, 4])

        # slow but Pythonic
        values_ = torch.empty(4, 32)
        for batch in range(ids.size(0)):
            for feature in range(embeddings.size(1)):
                values_[batch, feature] = embeddings[ids[batch], feature]

        # with torchdim, single indexing kernel
        batch, feature = dims(2)
        values = embeddings[ids[batch], feature].order(batch, feature)

        assert torch.allclose(values, values_)

    def test_functorch(self):
        # torch.mm applied to tensors carrying first-class batch dims.
        A = torch.rand(3, 4, 5)
        B = torch.rand(3, 4, 5)
        C = torch.rand(5, 2)

        i, j = dims()

        AA = torch.mm(A[i], C)  # 3, 4, 2
        BB = torch.mm(B[j], C)  # 3, 4, 2
        assert list(torch.mm(AA.T, BB).order(i, j).shape) == [3, 3, 2, 2]

    def test_permute_orig(self):
        # permute must accept both the `dims=` keyword and positional form.
        d = dims(1)
        t_fc = torch.rand(1, 2, 3, 4)[d]
        assert t_fc.permute(dims=(1, 0, 2)).shape == t_fc.permute(1, 0, 2).shape

    def test_order_keyword(self):
        # order() must reject unknown keyword arguments.
        d = dims(1)
        t = torch.rand(3)[d]
        self.assertRaises(TypeError, lambda: t.order(wrong=3))

    def test_big_split(self):
        # Smoke test: split() with a long list of sizes (total >= 6400).
        total = 0
        l = []
        while total < 6400:
            l.append(torch.randint(2, 10, (1,)).item())
            total += l[-1]
        x = torch.randn(total, 1)
        x.split(l, 0)


# Tests that are replaced by skipped stubs in TestMinFunctorchOnly.
skip_functorch_only = ["test_time_mm_fuse", "test_attn_cuda"]


class TestMinFunctorchOnly(TestMin):
    """Re-run the ``TestMin`` tests with ``_set_pointwise_optimize(False)``.

    The flag is switched back to ``True`` in ``tearDown`` before the inherited
    leak check runs.
    """

    def setUp(self):
        super().setUp()
        _set_pointwise_optimize(False)

    def tearDown(self):
        _set_pointwise_optimize(True)
        super().tearDown()


# Override the listed tests on the subclass with skipped no-op stubs.
for n in skip_functorch_only:
    setattr(TestMinFunctorchOnly, n, skip("skip_functorch_only")(lambda self: None))

if __name__ == "__main__":
    run_tests()