# Owner(s): ["module: functorch"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import gc
from unittest import skip, skipIf

from attn_ft import BertSelfAttention as BertSelfAttentionA, Linear
from attn_positional import BertSelfAttention as BertSelfAttentionB

import torch
from functorch._C import dim as _C
from functorch.dim import (
    Dim,
    DimensionBindError,
    DimList,
    dimlists,
    dims,
    stack,
    Tensor,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfTorchDynamo,
    TEST_CUDA,
    TestCase,
)


try:
    from torchvision.models import resnet18
except ImportError:
    resnet18 = None

_test_c, _parse_test, _set_pointwise_optimize = (
    _C._test_c,
    _C._parse_test,
    _C._set_pointwise_optimize,
)

from contextlib import contextmanager
from time import perf_counter


measure_perf = False
if measure_perf:
    from torchdim.magic_trace import magic_trace
else:

    @contextmanager
    def magic_trace(*args, **kwargs):
        yield


@contextmanager
def measure(what):
    b = perf_counter()
    yield
    e = perf_counter()
    print(f"{what}: {e - b:.20f} seconds")


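# A first-class-dims rewrite of torch.triu: dims i and j bind to A's two
# axes, and the Dim objects themselves act as index (iota) tensors in the
# i <= j mask.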
def triu(A):
    i, j = dims()
    a = A[i, j]
    zero = torch.tensor(0, dtype=torch.float)  # XXX - torch.where is janky...
    return torch.where(i <= j, a, zero).order(i, j)


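# Average GPU time per call of lmb in milliseconds: r warmup calls, then r
# timed calls bracketed by CUDA events.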
def gpu_time(lmb, name, r=100):
    b = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    # with magic_trace(name + ".fxt"):
    for _ in range(r):
        lmb()
    b.record()
    for _ in range(r):
        lmb()
    e.record()
    e.synchronize()
    elapsed = b.elapsed_time(e)
    # with torch.profiler.profile(schedule=torch.profiler.schedule(
    #     wait=0,
    #     warmup=1,
    #     active=2), on_trace_ready=tensorboard_trace_handler(name), with_stack=True) as profiler:
    #     for _ in range(3):
    #         lmb()
    #         profiler.step()
    print(name, elapsed / r)
    return elapsed / r


@skipIfTorchDynamo("Bad interaction")
class TestMin(TestCase):
    def setUp(self):
        super().setUp()
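        # Snapshot the ids of every live torch.Tensor/Dim/DimList so that
        # tearDown can flag anything these tests leak.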
        gc.disable()
        gc.collect()
        self.interesting = set()
        for o in gc.get_objects():
            if isinstance(o, (torch.Tensor, Dim, Tensor, DimList)):
                self.interesting.add(id(o))
        if "cuda" in self._testMethodName:
            self.mem_allocated = torch.cuda.memory_allocated()

    def tearDown(self):
        interesting = []
        for o in gc.get_objects():
            if (
                isinstance(o, (torch.Tensor, Dim, Tensor, DimList))
                and id(o) not in self.interesting
            ):
                interesting.append(o)

        extra_memory = 0
        if "cuda" in self._testMethodName:
            extra_memory += torch.cuda.memory_allocated() - self.mem_allocated

        #  nolevels = _n_levels_in_use() == 0
        if extra_memory != 0 or len(interesting) != 0:
            import refcycle

            refcycle.garbage().export_image("garbage.pdf")
        gc.collect()
        # assert nolevels, f"cleanup failed? {_n_levels_in_use()}"
        assert extra_memory == 0, f"extra cuda memory left allocated: {extra_memory}"
        assert len(interesting) == 0, (
            f"extra torch.Tensor, Dim, or Tensor left allocated: {len(interesting)} objects of types:"
            f" { [type(t) for t in interesting] }"
        )

    def test_manual_stuff(self):
        A_ = torch.rand(3, 4)
        B_ = torch.rand(4, 5)
        i, j, k = dims()
        A = A_[i, k]
        B = B_[k, j]
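        # Matrix multiply spelled out with first-class dims: expand each
        # operand over the dim it lacks, multiply pointwise, contract over k.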
        C = (A.expand(j) * B.expand(i)).sum(k)
        self.assertTrue(torch.allclose(C.order(i, j), torch.mm(A_, B_)))
        self.assertTrue(torch.allclose(torch.triu(A_, 0), triu(A_)))

        D_ = torch.randint(0, 3, (6,))
        d = dims()
        D = D_[d]

        A.index([i], [D]).order(k, d)

    def attn(
        self,
        batch_size=1,
        sequence_length=4,
        hidden_size=6,
        num_attention_heads=3,
        linear=Linear,
        device=None,
        time=False,
    ):
        def maybe_to(x):
            return x if device is None else x.to(device)

        attention_probs_dropout_prob = 0.0
        A = maybe_to(
            BertSelfAttentionA(
                hidden_size,
                num_attention_heads,
                attention_probs_dropout_prob,
                linear=linear,
            )
        )
        B = maybe_to(
            BertSelfAttentionB(
                hidden_size, num_attention_heads, attention_probs_dropout_prob
            )
        )

        A.load_state_dict(B.state_dict())
        hidden_state = maybe_to(torch.rand(batch_size, sequence_length, hidden_size))
        b_out = B(hidden_state)
        a_out = A(hidden_state)
        self.assertTrue(
            torch.allclose(a_out, b_out)
        )  # why does a simple matmul not do the right thing?

        if time:
            gpu_time(lambda: B(hidden_state), "positional", r=3)
            gpu_time(lambda: A(hidden_state), "first_class", r=3)

        for approach in ("relative_key", "relative_key_query"):
            A = maybe_to(
                BertSelfAttentionA(
                    hidden_size,
                    num_attention_heads,
                    attention_probs_dropout_prob,
                    approach,
                    sequence_length,
                    linear=linear,
                )
            )
            B = maybe_to(
                BertSelfAttentionB(
                    hidden_size,
                    num_attention_heads,
                    attention_probs_dropout_prob,
                    approach,
                    sequence_length,
                )
            )
            A.load_state_dict(B.state_dict())

            hidden_state = maybe_to(
                torch.rand(batch_size, sequence_length, hidden_size)
            )
            b_out = B(hidden_state)
            a_out = A(hidden_state)
            self.assertTrue(torch.allclose(a_out, b_out))

            if time:
                gpu_time(lambda: B(hidden_state), "positional", r=3)
                gpu_time(lambda: A(hidden_state), "first_class", r=3)

        A = maybe_to(
            BertSelfAttentionA(
                hidden_size,
                num_attention_heads,
                attention_probs_dropout_prob,
                None,
                None,
                linear=linear,
            )
        )
        B = maybe_to(
            BertSelfAttentionB(
                hidden_size,
                num_attention_heads,
                attention_probs_dropout_prob,
                None,
                None,
            )
        )
        A.load_state_dict(B.state_dict())

        hidden_state = maybe_to(torch.rand(batch_size, sequence_length, hidden_size))
        past_key_value = (
            maybe_to(
                torch.rand(
                    batch_size,
                    num_attention_heads,
                    sequence_length,
                    hidden_size // num_attention_heads,
                )
            ),
            maybe_to(
                torch.rand(
                    batch_size,
                    num_attention_heads,
                    sequence_length,
                    hidden_size // num_attention_heads,
                )
            ),
        )

        b_out = B(hidden_state, past_key_value=past_key_value)
        a_out = A(hidden_state, past_key_value=past_key_value)
        self.assertTrue(torch.allclose(a_out, b_out))

        if time:
            gpu_time(lambda: B(hidden_state), "positional", r=3)
            gpu_time(lambda: A(hidden_state), "first_class", r=3)

    def test_attn(self):
        self.attn()

    def test_inplace(self):
        # some embeddings table
        embeddings = torch.zeros(10, 3)

        # some sparse updates to the embeddings
        indices = torch.arange(2) + 1
        values = torch.rand(2, 3)

        i, n, f = dims()

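        # In-place indexed add through first-class dims; for the distinct
        # indices used here this is equivalent to
        # embeddings.index_add_(0, indices, values).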
        embeddings[indices[i], f] += values[i, f]

    def test_adapt(self):
        def f():
            ci, co = dims()

        # python 3.11 adapts bytecode after a number of iterations
        # check that we still match names correctly
        for i in range(10):
            f()

    @skipIf(not TEST_CUDA, "no CUDA")
    def test_attn_cuda(self):
        # size from the BERT paper, 90% pretraining of sequence length 128
        self.attn(
            batch_size=256,
            hidden_size=768,
            sequence_length=128,
            num_attention_heads=12,
            device="cuda",
            time=measure_perf,
            linear=torch.nn.Linear,
        )

    def test_stack(self):
        i, j, d = dims()
        A = torch.rand(4, 5)
        r = stack([A[i, j]], d, j)
        # a, b = r.unbind(d)
        # self.assertTrue(torch.allclose(a.order(i, j), i.expand(j).order(i, j)))
        # self.assertTrue(torch.allclose(b.order(i, j), j.expand(i).order(i, j)))

    def test_max(self):
        ap = torch.rand(2, 3, 2)
        i, j, k = dims()
        a = ap[i, j, k]
        r, i0 = a.max(dim=k)
        self.assertTrue(torch.allclose(r.order(i, j), ap.max(2)[0]))

    def test_mm(self):
        i, j, k, q = dims()
        a = torch.rand(3, 4)
        b = torch.rand(4, 5)
        a_ = a[i, k]
        b_ = b[k, j]
        q.size = 1
        r = (a_.expand(j, q) * b_.expand(i, q)).sum(k).order(q, i, j)
        # r = (a_*b_).sum(k).order(q, i, j)
        # print(r)
        # print(a @ b)

    def test_with_dims_split(self):
        a = torch.arange(3 * 12).view(3, 12)
        i, j, k = dims()
        k.size = 4
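        # Indexing with [j, k] splits the size-12 axis in two; k is pinned to
        # 4, so j's size (3) is inferred.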
        r = a[i, [j, k]]
        x = r.order(i, [j, k])
        self.assertTrue(torch.allclose(a, x))

    def test_hello(self):
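        # Kitchen-sink coverage of first-class dim operations: reductions,
        # sort/max, renorm, expand, index/split, dimlists, and error cases.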
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)
        i, j, k = dims()

        # r = A[i]*4
        r = (A[i, k] * B[k, j]).sum(k).order(i, j)
        assert torch.allclose(r, A @ B)

        assert A.sum() == A[i].sum((0, i))
        assert A.sum() == A[i].sum((-1, i))

        assert torch.allclose(A.sum(), A[i].sum(0, keepdim=True).sum((0, i)))
        assert torch.allclose(A[i].std(i, True), A.std(0, True))

        assert torch.allclose(A[i, k].max(i)[0].order(k), A.max(0)[0])
        assert torch.allclose(A.sort(1)[0], A[i, k].sort(k)[0].order(i, k))
        # XXX - chunk changes the size of a dimension, has to take a new dimension...
        # assert torch.allclose(A.chunk(2,1)[0], A[i, k].chunk(2, k)[0].order(i, k))
        assert torch.allclose(A[i].renorm(1, i, 7).order(i), A.renorm(1, 0, 7))
        kk = dims()
        # assert torch.allclose( torch.stack([A, A], 1), stack([A[i,k], A[i, k]], kk, k).order(i, kk, k))

        k2 = dims()
        # r = cat((A[i, k], A[i,k]), k, k2)
        # assert torch.allclose(torch.cat([A, A], 1), r.order(i, k2))
        # assert k2.size == 2*k.size

        assert torch.allclose(A.expand(5, -1, -1), A[i, k].expand(j).order(j, i, k))
        z = dims()
        C = torch.arange(2)
        assert torch.allclose(A[:, 0:2], A[i, k].index(k, C[z]).order(i, z))

        o, l = dims()
        o.size = 2
        r = A[i, k].index(k, (o, l))
        assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2))
        rr = r.index((o, l), k)
        assert torch.allclose(A, rr.order(i, k))

        r = i + k - 1
        r2 = torch.arange(3)[:, None] + torch.arange(4)[None, :] - 1
        assert torch.allclose(r.order(i, k), r2)

        # test with ...
        assert torch.allclose(A.T, A[..., k].order(k))

        # test with dimlist
        a_, b_ = dimlists()
        assert torch.allclose(A[i, a_].order(*a_, i), A.T)
        # test with one bound dimlist
        assert torch.allclose(A[:, a_].order(*a_), A.T)
        # test with a dimlist that will end up empty
        assert torch.allclose(A[i, b_, k].order(i, k, *b_), A)
        # test with too few things
        (A[i] + i)
        assert torch.allclose((A[i] + i).order(i), A + torch.arange(3)[:, None])
        # test with too many elements
        try:
            A[1, ..., 1, 1]
            raise NotImplementedError
        except IndexError:
            pass
        c, d = dims()
        c.size = 2
        assert torch.allclose(A[i, [c, d]].order(i, c, d), A.view(3, 2, 2))

        assert torch.allclose(
            A[c + 1, c + 0].order(c), A[torch.arange(2) + 1, torch.arange(2)]
        )
        try:
            A[..., 3, ...]
            raise NotImplementedError
        except DimensionBindError:
            pass

        C = torch.rand(4, 7)
        c_, x, y, z = dims()

        a, b, c = C.split((3, 3, 1), dim=1)
        s = dims()
        ref = C.split((3, 3, 1), dim=1)
        t = C[s, c_].split((x, y, z), dim=c_)
        for a, b, d in zip(ref, t, (x, y, z)):
            assert torch.allclose(a, b.order(s, d))

        D = torch.rand(3, 4, 5)
        assert torch.allclose(
            D.transpose(0, 1).flatten(1, 2), D[i, k, j].order((i, j)).order(k)
        )

        r = [id(x) for x in torch.rand_like(A[i, k]).dims]
        assert id(i) in r and id(k) in r
        r = [id(x) for x in torch.nn.functional.dropout(A[i, k]).dims]
        assert id(i) in r and id(k) in r

    def test_simple(self):
        i, j, k = dims()
        x = torch.rand(3, 4)
        z = x[i, j]
        (z + z + z + z)
        (z.order(i, j))

    def test_mm_fuse(self):
        i, j, k = dims()
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)

        C = (A[i, k] * B[k, j]).sum(k).order(i, j)
        assert torch.allclose(C, A @ B)

    def test_time_mm_fuse(self):
        i, j, k = dims()
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)

        for _ in range(10):
            r0 = A @ B

        for _ in range(10):
            a = A[i, k]
            b = B[k, j]
            r1 = (a * b).sum(k)

        with measure("pp"):
            for _ in range(10000):
                A @ B
        # magic_trace_stop_indicator()

        with measure("fc"):
            for _ in range(10000):
                (A[i, k] * B[k, j]).sum(k).order(i, j)

        with magic_trace("f.fxt"):
            for _ in range(10000):
                (A[i, k] * B[k, j]).sum(k).order(i, j)

        with magic_trace("p.fxt"):
            for _ in range(10000):
                A @ B

        # magic_trace_stop_indicator()

        assert torch.allclose(r1.order(i, j), r0)

    def test_compare_dims(self):
        i, j = dims()
        i.size = 3
        j.size = 4
        (i < j)  # noqa: B015

    def test_c(self):
        _test_c()

    def test_seg(self):
        A = torch.rand(3, 4)
        i, k = dims()
        i.size = 4
        k.size = 3
        r = i + k - 1

    def test_expand(self):
        A = torch.rand(3, 4)
        i = dims()
        assert list(A[i].expand(2, 4).order(i).size()) == [3, 2, 4]

    def test_parse(self):
        self.assertEqual(("x", None, None, None), _parse_test(1, 0, "x"))
        self.assertEqual(("x", None, "y", None), _parse_test(1, 0, "x", c="y"))
        self.assertEqual(("x", None, "y", "z"), _parse_test(1, 0, "x", d="z", c="y"))

        self.assertEqual(("x", "4", None, None), _parse_test(2, 0, "x", b="4"))
        self.assertEqual(("x", "y", "z", "q"), _parse_test(2, 0, "x", "y", "z", "q"))
        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x", "y", "z", "q", "5")
        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x", "y", b="y")

        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x", c="y")
        with self.assertRaises(TypeError):
            _parse_test(2, 0, "x")

    def test_network(self):
        if resnet18 is None:
            self.skipTest("no torchvision")
        rn = resnet18(
            norm_layer=lambda x: torch.nn.BatchNorm2d(x, track_running_stats=False)
        )
        rn.train()
        img = torch.rand(1, 1, 2, 3, 224, 224)
        imgf = img.view(2, 3, 224, 224)

        i, j = dims()
        r = rn(img[i, j])
        r = r.order(i, j).view(2, 1000)
        r2 = rn(imgf)
        assert torch.allclose(r2, r, atol=1e-06)

    def test_dim_args(self):
        a = dimlists()
        assert isinstance(a, DimList)
        a = dims()
        b = dimlists()
        assert isinstance(a, Dim)
        assert isinstance(b, DimList)
        assert str(a) == "a"
        a, b = dims(sizes=[3, 4])
        assert a.size == 3
        assert b.size == 4
        a = dims(sizes=[3])
        b = dimlists(sizes=[4])
        assert len(b) == 4
        a = dims()
        b = dimlists(sizes=[[4, 5]])
        assert b[0].size == 4
        assert b[1].size == 5

    def test_diag(self):
        i = dims()
        A = torch.rand(4, 4)
        (A[i, i])

    def test_softmax_split(self):
        a = torch.rand(16)
        g, i = dims(sizes=[2, None])
        a2 = a[[i, g],]

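        # Two-pass "split" softmax: compute per-group max and sum first, then
        # combine the partials using the global max for numerical stability.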
        m_b, _ = a2.max(i)
        f_b = torch.exp(a2 - m_b)
        l_b = f_b.sum(i)

        m, _ = m_b.max(g)
        c = torch.exp(m_b - m)
        f = (c * f_b).order((i, g))
        l = (c * l_b).sum(g)
        assert torch.allclose(f / l, torch.nn.functional.softmax(a, dim=0))

    def test_index(self):
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)
        i, j, k = dims()

        o, l = dims()
        o.size = 2
        r = A[i, k].index(k, [o, l])
        assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2))
        rr = r.index([o, l], k)
        assert torch.allclose(A, rr.order(i, k))
        z = dims()
        C = torch.arange(2)
        x = A[i, k].index(k, C[z]).order(i, z)
        assert torch.allclose(A[:, 0:2], x)

        C = torch.rand(3, 4, 5)
        ik = dims()
        assert torch.allclose(
            C.index((0, 2), ik).order(ik), C.permute(0, 2, 1).reshape(15, 4)
        )

    # failures that came up from monkey patching some operators...
    def test_monkey(self):
        A = torch.rand(3, 4)
        A[0, 0] = 5
        x = torch.randn(3, 4, 4, 4, 3)
        x_clone1 = x.clone()
        ia = torch.tensor([0, 2, 1])
        ib = torch.tensor([0, 2, 1])
        first_shape = x[:, ia, None, ib, 0].shape
        x_clone1[:, ia, None, ib, 0] = torch.randn(first_shape).to(x_clone1)
        x = torch.autograd.Variable(torch.tensor([]))
        z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
        a = [z[2], z[0] + 3]
        x.new(a)
        # self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4])

    def test_index_placement(self):
        A = torch.rand(1, 2, 3, 4)

        i, j = dims(sizes=[2, 4])

        a = A[:, i + 0, :, j + 0]
        r = a.order(i, j)

        assert torch.allclose(A.permute(1, 3, 0, 2), r)

    def test_order(self):
        i, j = dims()
        A = torch.rand(3, 4, 5)
        assert torch.allclose(A[i].order(1, i), A.permute(2, 0, 1))

    def test_mask(self):
        a = torch.rand(5)
        i, j = dims(sizes=[a.size(0), a.size(0)])
        ((i >= j) * a[i]).sum(j).order(i)

    def test_eq(self):
        i, j = dims(sizes=[3, 3])
        assert (i == j).sum((i, j)) == 3

    def test_dims_with_size(self):
        x = dims(3)
        assert len(x) == 3 and isinstance(x[0], Dim)

        class Foo:
            pass

        y = Foo()
        z, y.x, q = dims(3)
        assert str(z) == "z"
        assert str(y.x) == "d1"
        assert str(q) == "d2"

    def test_dir(self):
        i, j = dims(sizes=[3, 3])
        dir(i <= j)

    def test_doc(self):
        assert Tensor.clamp.__doc__ == torch.Tensor.clamp.__doc__

    def test_embed(self):
        embeddings = torch.rand(8, 32)
        ids = torch.tensor([1, 0, 3, 4])

        # slow but Pythonic
        values_ = torch.empty(4, 32)
        for batch in range(ids.size(0)):
            for feature in range(embeddings.size(1)):
                values_[batch, feature] = embeddings[ids[batch], feature]

        # with torchdim, single indexing kernel
        batch, feature = dims(2)
        values = embeddings[ids[batch], feature].order(batch, feature)

        assert torch.allclose(values, values_)

    def test_functorch(self):
        A = torch.rand(3, 4, 5)
        B = torch.rand(3, 4, 5)
        C = torch.rand(5, 2)

        i, j = dims()

        AA = torch.mm(A[i], C)  # 3, 4, 2
        BB = torch.mm(B[j], C)  # 3, 4, 2
        assert list(torch.mm(AA.T, BB).order(i, j).shape) == [3, 3, 2, 2]

    def test_permute_orig(self):
        d = dims(1)
        t_fc = torch.rand(1, 2, 3, 4)[d]
        assert t_fc.permute(dims=(1, 0, 2)).shape == t_fc.permute(1, 0, 2).shape

    def test_order_keyword(self):
        d = dims(1)
        t = torch.rand(3)[d]
        self.assertRaises(TypeError, lambda: t.order(wrong=3))

    def test_big_split(self):
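        # Exercise Tensor.split with a long list of random chunk sizes
        # totalling at least 6400 rows.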
        total = 0
        l = []
        while total < 6400:
            l.append(torch.randint(2, 10, (1,)).item())
            total += l[-1]
        x = torch.randn(total, 1)
        x.split(l, 0)


skip_functorch_only = ["test_time_mm_fuse", "test_attn_cuda"]


class TestMinFunctorchOnly(TestMin):
    def setUp(self):
        super().setUp()
        _set_pointwise_optimize(False)

    def tearDown(self):
        _set_pointwise_optimize(True)
        super().tearDown()

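# Replace the tests listed in skip_functorch_only with no-op skips on the
# functorch-only (pointwise optimization disabled) variant.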
for n in skip_functorch_only:
    setattr(TestMinFunctorchOnly, n, skip("skip_functorch_only")(lambda self: None))

if __name__ == "__main__":
    run_tests()