1# Owner(s): ["module: dynamo"] 2import unittest.mock 3 4import torch 5import torch._dynamo.test_case 6import torch._dynamo.testing 7from torch._dynamo.testing import same 8 9 10try: 11 from diffusers.models import unet_2d 12except ImportError: 13 unet_2d = None 14 15 16def maybe_skip(fn): 17 if unet_2d is None: 18 return unittest.skip("requires diffusers")(fn) 19 return fn 20 21 22class TestBaseOutput(torch._dynamo.test_case.TestCase): 23 @maybe_skip 24 def test_create(self): 25 def fn(a): 26 tmp = unet_2d.UNet2DOutput(a + 1) 27 return tmp 28 29 torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=1) 30 31 @maybe_skip 32 def test_assign(self): 33 def fn(a): 34 tmp = unet_2d.UNet2DOutput(a + 1) 35 tmp.sample = a + 2 36 return tmp 37 38 args = [torch.randn(10)] 39 obj1 = fn(*args) 40 41 cnts = torch._dynamo.testing.CompileCounter() 42 opt_fn = torch._dynamo.optimize_assert(cnts)(fn) 43 obj2 = opt_fn(*args) 44 self.assertTrue(same(obj1.sample, obj2.sample)) 45 self.assertEqual(cnts.frame_count, 1) 46 self.assertEqual(cnts.op_count, 2) 47 48 def _common(self, fn, op_count): 49 args = [ 50 unet_2d.UNet2DOutput( 51 sample=torch.randn(10), 52 ) 53 ] 54 obj1 = fn(*args) 55 cnts = torch._dynamo.testing.CompileCounter() 56 opt_fn = torch._dynamo.optimize_assert(cnts)(fn) 57 obj2 = opt_fn(*args) 58 self.assertTrue(same(obj1, obj2)) 59 self.assertEqual(cnts.frame_count, 1) 60 self.assertEqual(cnts.op_count, op_count) 61 62 @maybe_skip 63 def test_getattr(self): 64 def fn(obj: unet_2d.UNet2DOutput): 65 x = obj.sample * 10 66 return x 67 68 self._common(fn, 1) 69 70 @maybe_skip 71 def test_getitem(self): 72 def fn(obj: unet_2d.UNet2DOutput): 73 x = obj["sample"] * 10 74 return x 75 76 self._common(fn, 1) 77 78 @maybe_skip 79 def test_tuple(self): 80 def fn(obj: unet_2d.UNet2DOutput): 81 a = obj.to_tuple() 82 return a[0] * 10 83 84 self._common(fn, 1) 85 86 @maybe_skip 87 def test_index(self): 88 def fn(obj: unet_2d.UNet2DOutput): 89 return obj[0] * 10 90 91 self._common(fn, 1) 92 93 94if __name__ == "__main__": 95 from torch._dynamo.test_case import run_tests 96 97 run_tests() 98