xref: /aosp_15_r20/external/pytorch/test/dynamo/test_base_output.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2import unittest.mock
3
4import torch
5import torch._dynamo.test_case
6import torch._dynamo.testing
7from torch._dynamo.testing import same
8
9
10try:
11    from diffusers.models import unet_2d
12except ImportError:
13    unet_2d = None
14
15
def maybe_skip(fn):
    """Skip the decorated test when ``diffusers`` could not be imported.

    ``unet_2d`` is ``None`` when the optional ``diffusers`` import at module
    load time raised ``ImportError``. In that case ``fn`` is replaced by a
    ``unittest`` skip wrapper with reason ``"requires diffusers"``; otherwise
    ``fn`` is returned unchanged.
    """
    return unittest.skipIf(unet_2d is None, "requires diffusers")(fn)
20
21
class TestBaseOutput(torch._dynamo.test_case.TestCase):
    """Dynamo tests for diffusers ``BaseOutput`` objects, via ``UNet2DOutput``.

    Every test is decorated with ``maybe_skip`` so it is skipped when the
    optional ``diffusers`` dependency is not importable.
    """

    def _compile(self, fn):
        """Return ``(counter, compiled_fn)`` for ``fn``.

        ``compiled_fn`` is ``fn`` wrapped by ``torch._dynamo.optimize_assert``
        using a fresh ``CompileCounter`` as the backend; the counter exposes
        ``frame_count`` and ``op_count`` for the assertions below.
        """
        counter = torch._dynamo.testing.CompileCounter()
        compiled_fn = torch._dynamo.optimize_assert(counter)(fn)
        return counter, compiled_fn

    def _common(self, fn, op_count):
        """Compare eager and compiled ``fn`` on a ``UNet2DOutput`` argument.

        ``fn`` is called once eagerly and once through dynamo on the same
        ``UNet2DOutput(sample=torch.randn(10))`` instance. The two results must
        agree according to ``same``, exactly one frame must be compiled, and the
        counter must report ``op_count`` ops.
        """
        inputs = [unet_2d.UNet2DOutput(sample=torch.randn(10))]
        expected = fn(*inputs)
        counter, compiled_fn = self._compile(fn)
        actual = compiled_fn(*inputs)
        self.assertTrue(same(expected, actual))
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, op_count)

    @maybe_skip
    def test_create(self):
        # Build the output object inside the compiled function.
        def fn(a):
            return unet_2d.UNet2DOutput(a + 1)

        torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=1)

    @maybe_skip
    def test_assign(self):
        # Overwrite a field after construction; the test expects two ops
        # (``a + 1`` and ``a + 2``) in the compiled graph.
        def fn(a):
            output = unet_2d.UNet2DOutput(a + 1)
            output.sample = a + 2
            return output

        inputs = [torch.randn(10)]
        expected = fn(*inputs)
        counter, compiled_fn = self._compile(fn)
        actual = compiled_fn(*inputs)
        self.assertTrue(same(expected.sample, actual.sample))
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 2)

    @maybe_skip
    def test_getattr(self):
        # Attribute access on an input UNet2DOutput.
        def fn(obj: unet_2d.UNet2DOutput):
            return obj.sample * 10

        self._common(fn, 1)

    @maybe_skip
    def test_getitem(self):
        # Subscript access by field name (``obj["sample"]``).
        def fn(obj: unet_2d.UNet2DOutput):
            return obj["sample"] * 10

        self._common(fn, 1)

    @maybe_skip
    def test_tuple(self):
        # Conversion through ``to_tuple()`` followed by positional access.
        def fn(obj: unet_2d.UNet2DOutput):
            return obj.to_tuple()[0] * 10

        self._common(fn, 1)

    @maybe_skip
    def test_index(self):
        # Positional indexing directly on the output object.
        def fn(obj: unet_2d.UNet2DOutput):
            return obj[0] * 10

        self._common(fn, 1)
92
93
if __name__ == "__main__":
    # When executed as a script, hand off to dynamo's test-case runner.
    torch._dynamo.test_case.run_tests()
98