# Owner(s): ["module: dynamo"]
import dataclasses
import unittest.mock

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TEST_HPU, TestCase


try:
    from transformers import modeling_outputs
    from transformers.configuration_utils import PretrainedConfig
    from transformers.file_utils import ModelOutput
    from transformers.modeling_outputs import (
        BaseModelOutput,
        BaseModelOutputWithPastAndCrossAttentions,
        BaseModelOutputWithPoolingAndCrossAttentions,
        CausalLMOutputWithPast,
    )
except ImportError:
    modeling_outputs = None


def maybe_skip(fn):
    if modeling_outputs is None:
        return unittest.skip("requires HuggingFace")(fn)
    return fn


class TestHFPretrained(torch._dynamo.test_case.TestCase):
    @maybe_skip
    def test_pretrained(self):
        def fn(a, tmp):
            if hasattr(tmp, "somekey"):
                a = a + 1
            if tmp.return_dict:
                return a + torch.ones(2) * tmp.max_length
            return a

        x = torch.randn(2)
        tmp = PretrainedConfig(return_dict=True, max_length=20)
        ref = fn(x, tmp)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = opt_fn(x, tmp)
        self.assertTrue(same(ref, res))

    @maybe_skip
    def test_pretrained_non_const_attr(self):
        def fn(a, tmp):
            if tmp.pruned_heads:
                return a + 1
            else:
                return a - 1

        x = torch.randn(2)
        tmp = PretrainedConfig()
        ref = fn(x, tmp)
        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
        res = opt_fn(x, tmp)
        self.assertTrue(same(ref, res))


class TestModelOutput(torch._dynamo.test_case.TestCase):
    @maybe_skip
    def test_mo_create(self):
        def fn(a, b):
            tmp = BaseModelOutput(a + 1, attentions=b + 3)
            return tmp

        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2)

    @maybe_skip
    def test_mo_assign(self):
        def fn(a, b):
            tmp = BaseModelOutput(last_hidden_state=b + 3)
            tmp.hidden_states = a + 7
            tmp["attentions"] = a + b + 6
            return tmp

        args = [torch.randn(10), torch.randn(10)]
        obj1 = fn(*args)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state))
        self.assertTrue(same(obj1.hidden_states, obj2.hidden_states))
        self.assertTrue(same(obj1.attentions, obj2.attentions))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 4)

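    # Shared helper: run fn eagerly and under optimize_assert (no graph breaks
    # allowed) on a BaseModelOutput input, then compare the results and check
    # the expected frame and op counts.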
    def _common(self, fn, op_count):
        args = [
            BaseModelOutput(
                last_hidden_state=torch.randn(10), attentions=torch.randn(10)
            )
        ]
        obj1 = fn(*args)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1, obj2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, op_count)

    @maybe_skip
    def test_mo_getattr(self):
        def fn(obj: BaseModelOutput):
            x = obj.last_hidden_state * 10
            if obj.hidden_states is not None:
                x += obj.hidden_states
            if obj.attentions is not None:
                x += obj.attentions
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_getattr_missing(self):
        def fn(obj: BaseModelOutput):
            if getattr(obj, "asdf", None) is not None:
                obj.asdf += 1
            return obj.attentions + 1

        self._common(fn, 1)

    @maybe_skip
    def test_mo_getitem(self):
        def fn(obj: BaseModelOutput):
            x = obj["last_hidden_state"] * 10
            if "hidden_states" in obj:
                x += obj["hidden_states"]
            if "attentions" in obj:
                x += obj["attentions"]
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_tuple(self):
        def fn(obj: BaseModelOutput):
            a, b = obj.to_tuple()
            return a + b * 10

        self._common(fn, 2)

    @maybe_skip
    def test_mo_index(self):
        def fn(obj: BaseModelOutput):
            return obj[0] * 10 + obj[1]

        self._common(fn, 2)

    @maybe_skip
    def test_mo_init(self):
        @dataclasses.dataclass
        class MyDataClass(ModelOutput):
            a: torch.Tensor
            b: torch.Tensor = None
            c: torch.Tensor = None
            d: torch.Tensor = None
            e: torch.Tensor = None

        def fn(obj):
            class_fields = dataclasses.fields(obj)
            assert len(class_fields)
            assert all(field.default is None for field in class_fields[1:])
            other_fields_are_none = all(
                getattr(obj, field.name) is None for field in class_fields[1:]
            )
            assert not other_fields_are_none

            total = getattr(obj, class_fields[0].name)
            for field in class_fields[1:]:
                v = getattr(obj, field.name)
                if v is not None:
                    total += v

            return total

        tensors = [torch.randn(10), torch.randn(10), torch.randn(10)]
        obj1 = MyDataClass(*tensors)
        correct1 = fn(obj1)

        obj2 = MyDataClass(*tensors)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(obj2), correct1))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    @maybe_skip
    def test_mo_init2(self):
        # This ModelOutput subclass runs a different __post_init__ codepath
        @dataclasses.dataclass
        class MyDataClass(ModelOutput):
            x: torch.FloatTensor = None

        def fn(x):
            obj = MyDataClass(x=x)
            return obj

        inp = torch.randn(3, 3)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        self.assertEqual(fn(inp).x, opt_fn(inp).x)

    @maybe_skip
    def test_mo_init_with_disable(self):
        # Constructing the ModelOutput can result in a
        # "non-function or method super: <slot wrapper '__setattr__' of 'object' objects>"
        # graph break (although it may not be the first one).
        # Minimal repro for https://github.com/pytorch/pytorch/issues/126028
        @dataclasses.dataclass
        class MyDataClass(ModelOutput):
            x: torch.FloatTensor = None

        @torch._dynamo.disable(recursive=False)
        def fn(x):
            return MyDataClass(x=x)

        inp = torch.randn(3, 3)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        self.assertEqual(fn(inp).x, opt_fn(inp).x)

    @maybe_skip
    def test_mo_newkey(self):
        obj = BaseModelOutput()

        def fn(obj):
            return obj["wwww"] + 1

        inp = torch.randn(3, 3)
        obj["wwww"] = inp
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        self.assertEqual(fn(obj), opt_fn(obj))

    @maybe_skip
    def test_mo_from_outside(self):
        def fn(obj):
            return obj.attentions + 1

        obj = BaseModelOutput(attentions=torch.randn(3, 3))
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        self.assertEqual(fn(obj), opt_fn(obj))

    @maybe_skip
    def test_mo_reconstruct_bytecode(self):
        def fn(inp):
            return BaseModelOutput(attentions=inp + 1)

        inp = torch.randn(3, 3)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        self.assertEqual(fn(inp).attentions, opt_fn(inp).attentions)

    @maybe_skip
    def test_none(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                x = x + 1
                return CausalLMOutputWithPast(loss=None, logits=x)[0]

        model = Model()
        opt_model = torch.compile(model, backend="eager", fullgraph=True)
        x = torch.randn(1, 1, 1, 1)

        self.assertTrue(same(model(x), opt_model(x)))

    @maybe_skip
    def test_reconstruction(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                x = x + 1
                return CausalLMOutputWithPast(loss=x, logits=None)

        model = Model()
        x = torch.randn(1, 1, 1, 1)
        eo = torch._dynamo.export(Model(), aten_graph=True)(x)
        self.assertTrue(same(model(x), eo.graph_module(x)))


class TestModelOutputBert(TestCase):
    @maybe_skip
    def test_HF_bert_model_output(self, device):
        class BertPooler(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dense = torch.nn.Linear(768, 768).to(device)
                self.activation = torch.nn.Tanh()

            def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token.
                first_token_tensor = hidden_states[:, 0]
                pooled_output = self.dense(first_token_tensor)
                pooled_output = self.activation(pooled_output)
                return pooled_output

        class BertEncoder(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(
                self,
                hidden_states: torch.Tensor,
            ) -> BaseModelOutputWithPastAndCrossAttentions:
                return BaseModelOutputWithPastAndCrossAttentions(
                    last_hidden_state=hidden_states,
                    past_key_values=None,
                    hidden_states=None,
                    attentions=None,
                    cross_attentions=None,
                )

        class BertModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.encoder = BertEncoder()
                self.pooler = BertPooler()

            def forward(
                self,
                sequence_output: torch.Tensor,
            ) -> BaseModelOutputWithPoolingAndCrossAttentions:
                encoder_outputs = self.encoder(sequence_output)
                # test __getitem__ and to_tuple
                sequence_output = encoder_outputs[0]
                pooled_output = (
                    self.pooler(sequence_output) if self.pooler is not None else None
                )
                # test CustomDictVariable.create
                result = BaseModelOutputWithPoolingAndCrossAttentions(
                    last_hidden_state=sequence_output,
                    pooler_output=pooled_output,
                    past_key_values=encoder_outputs.past_key_values,
                    hidden_states=encoder_outputs.hidden_states,
                    attentions=encoder_outputs.attentions,
                    cross_attentions=encoder_outputs.cross_attentions,
                )
                # test __setattr__
                result.pooler_output = pooled_output
                # test __setitem__
                result["pooler_output"] = pooled_output
                return result

        sequence_output = torch.rand(1, 12, 768).to(device)
        model = BertModel()
        orig_result = model(sequence_output)
        compiled_model = torch.compile(model, backend="eager")
        compiled_result = compiled_model(sequence_output)
        self.assertTrue(
            torch.allclose(
                orig_result.last_hidden_state, compiled_result.last_hidden_state
            )
        )
        self.assertTrue(
            torch.allclose(orig_result.pooler_output, compiled_result.pooler_output)
        )


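# Instantiate device-specific variants of TestModelOutputBert for the listed
# device types; HPU is included only when TEST_HPU is set.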
devices = ["cpu", "cuda"]
if TEST_HPU:
    devices.append("hpu")

instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()