# Owner(s): ["module: dynamo"]
import dataclasses
import unittest.mock

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TEST_HPU, TestCase


try:
    from transformers import modeling_outputs
    from transformers.configuration_utils import PretrainedConfig
    from transformers.file_utils import ModelOutput
    from transformers.modeling_outputs import (
        BaseModelOutput,
        BaseModelOutputWithPastAndCrossAttentions,
        BaseModelOutputWithPoolingAndCrossAttentions,
        CausalLMOutputWithPast,
    )
except ImportError:
    modeling_outputs = None


def maybe_skip(fn):
    if modeling_outputs is None:
        return unittest.skip("requires HuggingFace")(fn)
    return fn


class TestHFPretrained(torch._dynamo.test_case.TestCase):
    @maybe_skip
    def test_pretrained(self):
        def fn(a, tmp):
            if hasattr(tmp, "somekey"):
                a = a + 1
            if tmp.return_dict:
                return a + torch.ones(2) * tmp.max_length
            return a

        x = torch.randn(2)
        tmp = PretrainedConfig(return_dict=True, max_length=20)
        ref = fn(x, tmp)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = opt_fn(x, tmp)
        self.assertTrue(same(ref, res))

    @maybe_skip
    def test_pretrained_non_const_attr(self):
        def fn(a, tmp):
            if tmp.pruned_heads:
                return a + 1
            else:
                return a - 1

        x = torch.randn(2)
        tmp = PretrainedConfig()
        ref = fn(x, tmp)
        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
        res = opt_fn(x, tmp)
        self.assertTrue(same(ref, res))


class TestModelOutput(torch._dynamo.test_case.TestCase):
    @maybe_skip
    def test_mo_create(self):
        def fn(a, b):
            tmp = BaseModelOutput(a + 1, attentions=b + 3)
            return tmp

        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2)

    @maybe_skip
    def test_mo_assign(self):
        def fn(a, b):
            tmp = BaseModelOutput(last_hidden_state=b + 3)
            tmp.hidden_states = a + 7
            tmp["attentions"] = a + b + 6
            return tmp

        args = [torch.randn(10), torch.randn(10)]
        obj1 = fn(*args)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state))
        self.assertTrue(same(obj1.hidden_states, obj2.hidden_states))
        self.assertTrue(same(obj1.attentions, obj2.attentions))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 4)

    def _common(self, fn, op_count):
        args = [
            BaseModelOutput(
                last_hidden_state=torch.randn(10), attentions=torch.randn(10)
            )
        ]
        obj1 = fn(*args)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1, obj2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, op_count)

    @maybe_skip
    def test_mo_getattr(self):
        def fn(obj: BaseModelOutput):
            x = obj.last_hidden_state * 10
            if obj.hidden_states is not None:
                x += obj.hidden_states
            if obj.attentions is not None:
                x += obj.attentions
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_getattr_missing(self):
        def fn(obj: BaseModelOutput):
            if getattr(obj, "asdf", None) is not None:
                obj.asdf += 1
            return obj.attentions + 1

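        # Note: "asdf" is never set on the BaseModelOutput, so getattr() takes
        # the default path and only `obj.attentions + 1` is traced, matching
        # the expected op_count of 1 below.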
        self._common(fn, 1)

    @maybe_skip
    def test_mo_getitem(self):
        def fn(obj: BaseModelOutput):
            x = obj["last_hidden_state"] * 10
            if "hidden_states" in obj:
                x += obj["hidden_states"]
            if "attentions" in obj:
                x += obj["attentions"]
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_tuple(self):
        def fn(obj: BaseModelOutput):
            a, b = obj.to_tuple()
            return a + b * 10

        self._common(fn, 2)

    @maybe_skip
    def test_mo_index(self):
        def fn(obj: BaseModelOutput):
            return obj[0] * 10 + obj[1]

        self._common(fn, 2)

    @maybe_skip
    def test_mo_init(self):
        @dataclasses.dataclass
        class MyDataClass(ModelOutput):
            a: torch.Tensor
            b: torch.Tensor = None
            c: torch.Tensor = None
            d: torch.Tensor = None
            e: torch.Tensor = None

        def fn(obj):
            class_fields = dataclasses.fields(obj)
            assert len(class_fields)
            assert all(field.default is None for field in class_fields[1:])
            other_fields_are_none = all(
                getattr(obj, field.name) is None for field in class_fields[1:]
            )
            assert not other_fields_are_none

            total = getattr(obj, class_fields[0].name)
            for field in class_fields[1:]:
                v = getattr(obj, field.name)
                if v is not None:
                    total += v

            return total

        tensors = [torch.randn(10), torch.randn(10), torch.randn(10)]
        obj1 = MyDataClass(*tensors)
        correct1 = fn(obj1)

        obj2 = MyDataClass(*tensors)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(obj2), correct1))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    @maybe_skip
    def test_mo_init2(self):
        # this ModelOutput subclass runs a different __post_init__ codepath
        @dataclasses.dataclass
        class MyDataClass(ModelOutput):
            x: torch.FloatTensor = None

        def fn(x):
            obj = MyDataClass(x=x)
            return obj

        inp = torch.randn(3, 3)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        self.assertEqual(fn(inp).x, opt_fn(inp).x)

    @maybe_skip
    def test_mo_init_with_disable(self):
        # Can result in "non-function or method super: <slot wrapper '__setattr__' of 'object' objects>"
        # graph breaks (although it may not be the first)
        # Minimal repro for https://github.com/pytorch/pytorch/issues/126028
        @dataclasses.dataclass
        class MyDataClass(ModelOutput):
            x: torch.FloatTensor = None

        @torch._dynamo.disable(recursive=False)
        def fn(x):
            return MyDataClass(x=x)

        inp = torch.randn(3, 3)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        self.assertEqual(fn(inp).x, opt_fn(inp).x)

    @maybe_skip
    def test_mo_newkey(self):
        obj = BaseModelOutput()

        def fn(obj):
            return obj["wwww"] + 1

        inp = torch.randn(3, 3)
        obj["wwww"] = inp
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        self.assertEqual(fn(obj), opt_fn(obj))

    @maybe_skip
    def test_mo_from_outside(self):
        def fn(obj):
            return obj.attentions + 1

        obj = BaseModelOutput(attentions=torch.randn(3, 3))
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        self.assertEqual(fn(obj), opt_fn(obj))

    @maybe_skip
    def test_mo_reconstruct_bytecode(self):
        def fn(inp):
            return BaseModelOutput(attentions=inp + 1)

        inp = torch.randn(3, 3)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        self.assertEqual(fn(inp).attentions, opt_fn(inp).attentions)

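    # A ModelOutput stores only fields that are not None, so in the test below
    # indexing with [0] resolves to `logits` when `loss` is left unset.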
    @maybe_skip
    def test_none(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                x = x + 1
                return CausalLMOutputWithPast(loss=None, logits=x)[0]

        model = Model()
        opt_model = torch.compile(model, backend="eager", fullgraph=True)
        x = torch.randn(1, 1, 1, 1)

        self.assertTrue(same(model(x), opt_model(x)))

    @maybe_skip
    def test_reconstruction(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                x = x + 1
                return CausalLMOutputWithPast(loss=x, logits=None)

        model = Model()
        x = torch.randn(1, 1, 1, 1)
        eo = torch._dynamo.export(Model(), aten_graph=True)(x)
        self.assertTrue(same(model(x), eo.graph_module(x)))


class TestModelOutputBert(TestCase):
    @maybe_skip
    def test_HF_bert_model_output(self, device):
        class BertPooler(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dense = torch.nn.Linear(768, 768).to(device)
                self.activation = torch.nn.Tanh()

            def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token.
                first_token_tensor = hidden_states[:, 0]
                pooled_output = self.dense(first_token_tensor)
                pooled_output = self.activation(pooled_output)
                return pooled_output

        class BertEncoder(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(
                self,
                hidden_states: torch.Tensor,
            ) -> BaseModelOutputWithPastAndCrossAttentions:
                return BaseModelOutputWithPastAndCrossAttentions(
                    last_hidden_state=hidden_states,
                    past_key_values=None,
                    hidden_states=None,
                    attentions=None,
                    cross_attentions=None,
                )

        class BertModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.encoder = BertEncoder()
                self.pooler = BertPooler()

            def forward(
                self,
                sequence_output: torch.Tensor,
            ) -> BaseModelOutputWithPoolingAndCrossAttentions:
                encoder_outputs = self.encoder(sequence_output)
                # test __getitem__ and to_tuple
                sequence_output = encoder_outputs[0]
                pooled_output = (
                    self.pooler(sequence_output) if self.pooler is not None else None
                )
                # test CustomDictVariable.create
                result = BaseModelOutputWithPoolingAndCrossAttentions(
                    last_hidden_state=sequence_output,
                    pooler_output=pooled_output,
                    past_key_values=encoder_outputs.past_key_values,
                    hidden_states=encoder_outputs.hidden_states,
                    attentions=encoder_outputs.attentions,
                    cross_attentions=encoder_outputs.cross_attentions,
                )
                # test __setattr__
                result.pooler_output = pooled_output
                # test __setitem__
                result["pooler_output"] = pooled_output
                return result

        sequence_output = torch.rand(1, 12, 768).to(device)
        model = BertModel()
        orig_result = model(sequence_output)
        compiled_model = torch.compile(model, backend="eager")
        compiled_result = compiled_model(sequence_output)
        self.assertTrue(
            torch.allclose(
                orig_result.last_hidden_state, compiled_result.last_hidden_state
            )
        )
        self.assertTrue(
            torch.allclose(orig_result.pooler_output, compiled_result.pooler_output)
        )


devices = ["cpu", "cuda"]
if TEST_HPU:
    devices.append("hpu")

instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()