1# Owner(s): ["module: unknown"] 2 3from copy import copy 4 5import torch 6from torch import nn 7from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo 8from torch.utils.checkpoint import checkpoint 9from torch.utils.module_tracker import ModuleTracker 10 11 12class TestModuleTracker(TestCase): 13 # "https://github.com/pytorch/pytorch/issues/127112 14 @xfailIfTorchDynamo 15 def test_module_hierarchy(self): 16 seen_fw = [] 17 seen_bw = [] 18 19 class Foo(nn.Module): 20 def forward(self, x): 21 x = x["a"].relu_() 22 seen_fw.append((copy(tracker.parents), tracker.is_bw)) 23 x.register_hook( 24 lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw)) 25 ) 26 return {"a": torch.mm(x, x)} 27 28 class Mod(nn.Module): 29 def __init__(self) -> None: 30 super().__init__() 31 self.a = Foo() 32 self.b = nn.ModuleDict({"nest": Foo()}) 33 self.c = nn.ModuleList([Foo()]) 34 35 def forward(self, x): 36 x = self.c[0](x) 37 return self.b["nest"](self.a(x)) 38 39 mod = Mod() 40 41 with ModuleTracker() as tracker: 42 mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ 43 "a" 44 ].sum().backward() 45 mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ 46 "a" 47 ].sum().backward() 48 49 self.assertEqual( 50 seen_fw, 51 [ 52 ({"Global", "Mod", "Mod.c.0"}, False), 53 ({"Global", "Mod", "Mod.a"}, False), 54 ({"Global", "Mod", "Mod.b.nest"}, False), 55 ({"Global", "Mod", "Mod.c.0"}, False), 56 ({"Global", "Mod", "Mod.a"}, False), 57 ({"Global", "Mod", "Mod.b.nest"}, False), 58 ], 59 ) 60 61 self.assertEqual( 62 seen_bw, 63 [ 64 ({"Global", "Mod", "Mod.b.nest"}, True), 65 ({"Global", "Mod", "Mod.a"}, True), 66 ({"Global", "Mod", "Mod.c.0"}, True), 67 ({"Global", "Mod", "Mod.b.nest"}, True), 68 ({"Global", "Mod", "Mod.a"}, True), 69 ({"Global", "Mod", "Mod.c.0"}, True), 70 ], 71 ) 72 73 def test_confused_hierarchy(self): 74 class MyMod(nn.Module): 75 def __init__(self): 76 super().__init__() 77 self.inner = nn.Linear(2, 2) 78 self.ran = False 79 80 def forward(self, inp): 81 if not self.ran: 82 self.ran = True 83 return self(inp) 84 else: 85 self.ran = False 86 return self.inner(inp) 87 88 mod = MyMod() 89 inp = torch.rand(1, 2, requires_grad=True) 90 91 # Should not fail 92 with ModuleTracker() as tracker: 93 res = mod(inp) 94 res.sum().backward() 95 96 # Should not fail 97 with ModuleTracker() as tracker: 98 res = checkpoint(lambda inp: mod(inp), inp) 99 res.sum().backward() 100 101 def test_bw_detection(self): 102 mod = nn.Linear(2, 2) 103 104 with ModuleTracker() as tracker: 105 mod(torch.rand(2, requires_grad=True)).sum().backward() 106 self.assertFalse(tracker.is_bw) 107 self.assertEqual(tracker.parents, {"Global"}) 108 109 110if __name__ == "__main__": 111 run_tests() 112