1# Owner(s): ["oncall: jit"] 2 3import io 4import os 5import sys 6 7import torch 8 9 10# Make the helper files in test/ importable 11pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 12sys.path.append(pytorch_test_dir) 13from torch.testing._internal.common_utils import suppress_warnings 14from torch.testing._internal.jit_utils import JitTestCase 15 16 17if __name__ == "__main__": 18 raise RuntimeError( 19 "This test file is not meant to be run directly, use:\n\n" 20 "\tpython test/test_jit.py TESTNAME\n\n" 21 "instead." 22 ) 23 24 25class TestTypeSharing(JitTestCase): 26 def assertSameType(self, m1, m2): 27 if not isinstance(m1, torch.jit.ScriptModule): 28 m1 = torch.jit.script(m1) 29 if not isinstance(m2, torch.jit.ScriptModule): 30 m2 = torch.jit.script(m2) 31 self.assertEqual(m1._c._type(), m2._c._type()) 32 33 def assertDifferentType(self, m1, m2): 34 if not isinstance(m1, torch.jit.ScriptModule): 35 m1 = torch.jit.script(m1) 36 if not isinstance(m2, torch.jit.ScriptModule): 37 m2 = torch.jit.script(m2) 38 self.assertNotEqual(m1._c._type(), m2._c._type()) 39 40 def test_basic(self): 41 class M(torch.nn.Module): 42 def __init__(self, a, b, c): 43 super().__init__() 44 self.a = a 45 self.b = b 46 self.c = c 47 48 def forward(self, x): 49 return x 50 51 a = torch.rand(2, 3) 52 b = torch.rand(2, 3) 53 c = torch.rand(2, 3) 54 m1 = M(a, b, c) 55 m2 = M(a, b, c) 56 self.assertSameType(m1, m2) 57 58 def test_diff_attr_values(self): 59 """ 60 Types should be shared even if attribute values differ 61 """ 62 63 class M(torch.nn.Module): 64 def __init__(self, a, b, c): 65 super().__init__() 66 self.a = a 67 self.b = b 68 self.c = c 69 70 def forward(self, x): 71 return x 72 73 a = torch.rand(2, 3) 74 b = torch.rand(2, 3) 75 c = torch.rand(2, 3) 76 m1 = M(a, b, c) 77 m2 = M(a * 2, b * 3, c * 4) 78 self.assertSameType(m1, m2) 79 80 def test_constants(self): 81 """ 82 Types should be shared for identical constant values, and different for different constant values 83 """ 84 85 class M(torch.nn.Module): 86 __constants__ = ["const"] 87 88 def __init__(self, attr, const): 89 super().__init__() 90 self.attr = attr 91 self.const = const 92 93 def forward(self): 94 return self.const 95 96 attr = torch.rand(2, 3) 97 m1 = M(attr, 1) 98 m2 = M(attr, 1) 99 self.assertSameType(m1, m2) 100 101 # a different constant value 102 m3 = M(attr, 2) 103 self.assertDifferentType(m1, m3) 104 105 def test_linear(self): 106 """ 107 Simple example with a real nn Module 108 """ 109 a = torch.nn.Linear(5, 5) 110 b = torch.nn.Linear(5, 5) 111 c = torch.nn.Linear(10, 10) 112 a = torch.jit.script(a) 113 b = torch.jit.script(b) 114 c = torch.jit.script(c) 115 116 self.assertSameType(a, b) 117 self.assertDifferentType(a, c) 118 119 def test_submodules(self): 120 """ 121 If submodules differ, the types should differ. 122 """ 123 124 class M(torch.nn.Module): 125 def __init__(self, in1, out1, in2, out2): 126 super().__init__() 127 self.submod1 = torch.nn.Linear(in1, out1) 128 self.submod2 = torch.nn.Linear(in2, out2) 129 130 def forward(self, x): 131 x = self.submod1(x) 132 x = self.submod2(x) 133 return x 134 135 a = M(1, 1, 2, 2) 136 b = M(1, 1, 2, 2) 137 self.assertSameType(a, b) 138 self.assertSameType(a.submod1, b.submod1) 139 c = M(2, 2, 2, 2) 140 self.assertDifferentType(a, c) 141 142 self.assertSameType(b.submod2, c.submod1) 143 self.assertDifferentType(a.submod1, b.submod2) 144 145 def test_param_vs_attribute(self): 146 """ 147 The same module with an `foo` as a parameter vs. attribute shouldn't 148 share types 149 """ 150 151 class M(torch.nn.Module): 152 def __init__(self, foo): 153 super().__init__() 154 self.foo = foo 155 156 def forward(self, x): 157 return x + self.foo 158 159 as_param = torch.nn.Parameter(torch.ones(2, 2)) 160 as_attr = torch.ones(2, 2) 161 param_mod = M(as_param) 162 attr_mod = M(as_attr) 163 self.assertDifferentType(attr_mod, param_mod) 164 165 def test_same_but_different_classes(self): 166 """ 167 Even if everything about the module is the same, different originating 168 classes should prevent type sharing. 169 """ 170 171 class A(torch.nn.Module): 172 __constants__ = ["const"] 173 174 def __init__(self, in1, out1, in2, out2): 175 super().__init__() 176 self.submod1 = torch.nn.Linear(in1, out1) 177 self.submod2 = torch.nn.Linear(in2, out2) 178 self.const = 5 179 180 def forward(self, x): 181 x = self.submod1(x) 182 x = self.submod2(x) 183 return x * self.const 184 185 class B(torch.nn.Module): 186 __constants__ = ["const"] 187 188 def __init__(self, in1, out1, in2, out2): 189 super().__init__() 190 self.submod1 = torch.nn.Linear(in1, out1) 191 self.submod2 = torch.nn.Linear(in2, out2) 192 self.const = 5 193 194 def forward(self, x): 195 x = self.submod1(x) 196 x = self.submod2(x) 197 return x * self.const 198 199 a = A(1, 1, 2, 2) 200 b = B(1, 1, 2, 2) 201 self.assertDifferentType(a, b) 202 203 def test_mutate_attr_value(self): 204 """ 205 Mutating the value of an attribute should not change type sharing 206 """ 207 208 class M(torch.nn.Module): 209 def __init__(self, in1, out1, in2, out2): 210 super().__init__() 211 self.submod1 = torch.nn.Linear(in1, out1) 212 self.submod2 = torch.nn.Linear(in2, out2) 213 self.foo = torch.ones(in1, in1) 214 215 def forward(self, x): 216 x = self.submod1(x) 217 x = self.submod2(x) 218 return x + self.foo 219 220 a = M(1, 1, 2, 2) 221 b = M(1, 1, 2, 2) 222 a.foo = torch.ones(2, 2) 223 b.foo = torch.rand(2, 2) 224 self.assertSameType(a, b) 225 226 def test_assign_python_attr(self): 227 """ 228 Assigning a new (python-only) attribute should not change type sharing 229 """ 230 231 class M(torch.nn.Module): 232 def __init__(self, in1, out1, in2, out2): 233 super().__init__() 234 self.submod1 = torch.nn.Linear(in1, out1) 235 self.submod2 = torch.nn.Linear(in2, out2) 236 self.foo = torch.ones(in1, in1) 237 238 def forward(self, x): 239 x = self.submod1(x) 240 x = self.submod2(x) 241 return x + self.foo 242 243 # explicitly call script() to freeze the type 244 a = torch.jit.script(M(1, 1, 2, 2)) 245 b = torch.jit.script(M(1, 1, 2, 2)) 246 a.new_attr = "foo bar baz" 247 self.assertSameType(a, b) 248 249 # but if we assign attributes *before* calling script(), the types 250 # should be different, since `new_attr` should be turned into a Script 251 # attribute 252 a = M(1, 1, 2, 2) 253 b = M(1, 1, 2, 2) 254 a.new_attr = "foo bar baz" 255 self.assertDifferentType(a, b) 256 257 def test_failed_attribute_compilation(self): 258 """ 259 Attributes whose type cannot be inferred should fail cleanly with nice hints 260 """ 261 262 class M(torch.nn.Module): 263 def __init__(self) -> None: 264 super().__init__() 265 # assign a type we know can't be converted to TorchScript 266 self.foo = object 267 268 def forward(self): 269 # try to use it in forward 270 return self.foo 271 272 m = M() 273 with self.assertRaisesRegexWithHighlight( 274 RuntimeError, "failed to convert Python type", "self.foo" 275 ): 276 torch.jit.script(m) 277 278 def test_script_function_attribute_different(self): 279 """ 280 Different functions passed in should lead to different types 281 """ 282 283 @torch.jit.script 284 def fn1(x): 285 return x + x 286 287 @torch.jit.script 288 def fn2(x): 289 return x - x 290 291 class M(torch.nn.Module): 292 def __init__(self, fn): 293 super().__init__() 294 self.fn = fn 295 296 def forward(self, x): 297 return self.fn(x) 298 299 fn1_mod = M(fn1) 300 fn2_mod = M(fn2) 301 302 self.assertDifferentType(fn1_mod, fn2_mod) 303 304 def test_builtin_function_same(self): 305 class Caller(torch.nn.Module): 306 def __init__(self, fn): 307 super().__init__() 308 self.fn = fn 309 310 def forward(self, input): 311 return self.fn(input, input) 312 313 c1 = Caller(torch.add) 314 c2 = Caller(torch.add) 315 316 self.assertSameType(c1, c2) 317 318 def test_builtin_function_different(self): 319 class Caller(torch.nn.Module): 320 def __init__(self, fn): 321 super().__init__() 322 self.fn = fn 323 324 def forward(self, input): 325 return self.fn(input, input) 326 327 c1 = Caller(torch.add) 328 c2 = Caller(torch.sub) 329 330 self.assertDifferentType(c1, c2) 331 332 def test_script_function_attribute_same(self): 333 """ 334 Same functions passed in should lead to same types 335 """ 336 337 @torch.jit.script 338 def fn(x): 339 return x + x 340 341 class M(torch.nn.Module): 342 def __init__(self, fn): 343 super().__init__() 344 self.fn = fn 345 346 def forward(self, x): 347 return self.fn(x) 348 349 fn1_mod = M(fn) 350 fn2_mod = M(fn) 351 352 self.assertSameType(fn1_mod, fn2_mod) 353 354 def test_python_function_attribute_different(self): 355 """ 356 Different functions passed in should lead to different types 357 """ 358 359 def fn1(x): 360 return x + x 361 362 def fn2(x): 363 return x - x 364 365 class M(torch.nn.Module): 366 def __init__(self, fn): 367 super().__init__() 368 self.fn = fn 369 370 def forward(self, x): 371 return self.fn(x) 372 373 fn1_mod = M(fn1) 374 fn2_mod = M(fn2) 375 376 self.assertDifferentType(fn1_mod, fn2_mod) 377 378 def test_python_function_attribute_same(self): 379 """ 380 Same functions passed in should lead to same types 381 """ 382 383 def fn(x): 384 return x + x 385 386 class M(torch.nn.Module): 387 def __init__(self, fn): 388 super().__init__() 389 self.fn = fn 390 391 def forward(self, x): 392 return self.fn(x) 393 394 fn1_mod = M(fn) 395 fn2_mod = M(fn) 396 397 self.assertSameType(fn1_mod, fn2_mod) 398 399 @suppress_warnings 400 def test_tracing_gives_different_types(self): 401 """ 402 Since we can't guarantee that methods are the same between different 403 trace runs, tracing must always generate a unique type. 404 """ 405 406 class M(torch.nn.Module): 407 def forward(self, x, y): 408 if x.sum() > y.sum(): 409 return x 410 else: 411 return y 412 413 a = torch.jit.trace(M(), (torch.zeros(1, 1), torch.ones(1, 1))) 414 b = torch.jit.trace(M(), (torch.ones(1, 1), torch.zeros(1, 1))) 415 self.assertDifferentType(a, b) 416 417 def test_ignored_fns(self): 418 class M(torch.nn.Module): 419 def __init__(self, foo): 420 super().__init__() 421 self.foo = foo 422 423 @torch.jit.ignore 424 def ignored(self): 425 return self.foo 426 427 def forward(self): 428 return self.ignored() 429 430 a = torch.jit.script(M(torch.ones(1))) 431 b = torch.jit.script(M(torch.ones(2))) 432 self.assertSameType(a, b) 433 self.assertNotEqual(a(), b()) 434 435 @suppress_warnings 436 def test_script_module_containing_traced_module(self): 437 class Traced(torch.nn.Module): 438 def forward(self, x): 439 if x.sum() > 0: 440 return x 441 else: 442 return x + x 443 444 class M(torch.nn.Module): 445 def __init__(self, input): 446 super().__init__() 447 self.traced = torch.jit.trace(Traced(), input) 448 449 def forward(self, x): 450 return self.traced(x) 451 452 a = M((torch.ones(1),)) 453 b = M((torch.zeros(1),)) 454 self.assertDifferentType(a, b) 455 456 def test_loaded_modules_work(self): 457 class AB(torch.nn.Module): 458 def __init__(self) -> None: 459 super().__init__() 460 self.a = 1 461 self.b = 1 462 463 def forward(self): 464 return self.a + self.b 465 466 class A(torch.nn.Module): 467 def __init__(self) -> None: 468 super().__init__() 469 self.a = 1 470 471 def forward(self): 472 return self.a 473 474 class Wrapper(torch.nn.Module): 475 def __init__(self, sub): 476 super().__init__() 477 self.sub = sub 478 479 def forward(self): 480 return self.sub() 481 482 def package(x): 483 buffer = io.BytesIO() 484 torch.jit.save(torch.jit.script(x), buffer) 485 buffer.seek(0) 486 return torch.jit.script(Wrapper(torch.jit.load(buffer))) 487 488 a = package(AB()) 489 a() 490 b = package(A()) 491 b() 492 493 def test_module_dict_same_type_different_name(self): 494 """ 495 We should be able to differentiate between two ModuleDict instances 496 that have different keys but the same value types. 497 """ 498 499 class A(torch.nn.Module): 500 def forward(self, x): 501 return x 502 503 class Foo(torch.nn.Module): 504 def __init__(self, s): 505 super().__init__() 506 self.dict = torch.nn.ModuleDict(s) 507 508 def forward(self, x): 509 return x 510 511 a = Foo({"foo": A()}) 512 b = Foo({"bar": A()}) 513 c = Foo({"bar": A()}) 514 self.assertDifferentType(a, b) 515 self.assertSameType(b, c) 516 517 def test_type_sharing_define_in_init(self): 518 """ 519 Tests that types between instances of a ScriptModule 520 subclass that defines methods in its __init__ are not 521 shared. 522 """ 523 524 class A(torch.jit.ScriptModule): 525 def __init__(self, val): 526 super().__init__() 527 self.define( 528 f""" 529 def forward(self) -> int: 530 return {val} 531 """ 532 ) 533 534 one = A(1) 535 two = A(2) 536 537 self.assertEqual(one(), 1) 538 self.assertEqual(two(), 2) 539 540 def test_type_sharing_disabled(self): 541 """ 542 Test that type sharing can be disabled. 543 """ 544 545 class A(torch.nn.Module): 546 def __init__(self, sub): 547 super().__init__() 548 self.sub = sub 549 550 def forward(self, x): 551 return x 552 553 class B(torch.nn.Module): 554 def forward(self, x): 555 return x 556 557 top1 = A(A(B())) 558 top2 = A(A(B())) 559 560 top1_s = torch.jit._recursive.create_script_module( 561 top1, 562 torch.jit._recursive.infer_methods_to_compile, 563 share_types=False, 564 ) 565 top2_s = torch.jit._recursive.create_script_module( 566 top2, 567 torch.jit._recursive.infer_methods_to_compile, 568 share_types=False, 569 ) 570 571 self.assertDifferentType(top1_s, top2_s) 572 self.assertDifferentType(top1_s, top1_s.sub) 573 self.assertDifferentType(top1_s, top2_s.sub) 574 self.assertDifferentType(top2_s, top2_s.sub) 575 self.assertDifferentType(top2_s, top1_s.sub) 576 577 def test_type_shared_ignored_attributes(self): 578 """ 579 Test that types are shared if the exclusion of their 580 ignored attributes makes them equal. 581 """ 582 583 class A(torch.nn.Module): 584 __jit_ignored_attributes__ = ["a"] 585 586 def __init__(self, a, b): 587 super().__init__() 588 self.a = a 589 self.b = b 590 591 def forward(self, x): 592 return x 593 594 a_with_linear = A(torch.nn.Linear(5, 5), 5) 595 a_with_string = A("string", 10) 596 597 # Both should have the same type because the attribute 598 # that differs in type is ignored and the common attribute 599 # has the same type. 600 self.assertSameType(a_with_linear, a_with_string) 601 602 def test_type_not_shared_ignored_attributes(self): 603 """ 604 Test that types are not shared if the exclusion of their 605 ignored attributes makes them not equal. 606 """ 607 608 class A(torch.nn.Module): 609 __jit_ignored_attributes__ = ["a"] 610 611 def __init__(self, a, b, c): 612 super().__init__() 613 self.a = a 614 self.b = b 615 self.c = c 616 617 def forward(self, x): 618 return x 619 620 mod = A(torch.nn.Linear(5, 5), 5, "string") 621 s1 = torch.jit.script(mod) 622 A.__jit_ignored_attributes__ = ["a", "b"] 623 s2 = torch.jit.script(mod) 624 625 # The types of s1 and s2 should differ. Although they are instances 626 # of A, __jit_ignored_attributes__ was modified before scripting s2, 627 # so the set of ignored attributes is different between s1 and s2. 628 self.assertDifferentType(s1, s2) 629