# Owner(s): ["module: unknown"] from copy import copy import torch from torch import nn from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo from torch.utils.checkpoint import checkpoint from torch.utils.module_tracker import ModuleTracker class TestModuleTracker(TestCase): # "https://github.com/pytorch/pytorch/issues/127112 @xfailIfTorchDynamo def test_module_hierarchy(self): seen_fw = [] seen_bw = [] class Foo(nn.Module): def forward(self, x): x = x["a"].relu_() seen_fw.append((copy(tracker.parents), tracker.is_bw)) x.register_hook( lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw)) ) return {"a": torch.mm(x, x)} class Mod(nn.Module): def __init__(self) -> None: super().__init__() self.a = Foo() self.b = nn.ModuleDict({"nest": Foo()}) self.c = nn.ModuleList([Foo()]) def forward(self, x): x = self.c[0](x) return self.b["nest"](self.a(x)) mod = Mod() with ModuleTracker() as tracker: mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ "a" ].sum().backward() mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ "a" ].sum().backward() self.assertEqual( seen_fw, [ ({"Global", "Mod", "Mod.c.0"}, False), ({"Global", "Mod", "Mod.a"}, False), ({"Global", "Mod", "Mod.b.nest"}, False), ({"Global", "Mod", "Mod.c.0"}, False), ({"Global", "Mod", "Mod.a"}, False), ({"Global", "Mod", "Mod.b.nest"}, False), ], ) self.assertEqual( seen_bw, [ ({"Global", "Mod", "Mod.b.nest"}, True), ({"Global", "Mod", "Mod.a"}, True), ({"Global", "Mod", "Mod.c.0"}, True), ({"Global", "Mod", "Mod.b.nest"}, True), ({"Global", "Mod", "Mod.a"}, True), ({"Global", "Mod", "Mod.c.0"}, True), ], ) def test_confused_hierarchy(self): class MyMod(nn.Module): def __init__(self): super().__init__() self.inner = nn.Linear(2, 2) self.ran = False def forward(self, inp): if not self.ran: self.ran = True return self(inp) else: self.ran = False return self.inner(inp) mod = MyMod() inp = torch.rand(1, 2, requires_grad=True) # Should not fail with ModuleTracker() as tracker: res = mod(inp) res.sum().backward() # Should not fail with ModuleTracker() as tracker: res = checkpoint(lambda inp: mod(inp), inp) res.sum().backward() def test_bw_detection(self): mod = nn.Linear(2, 2) with ModuleTracker() as tracker: mod(torch.rand(2, requires_grad=True)).sum().backward() self.assertFalse(tracker.is_bw) self.assertEqual(tracker.parents, {"Global"}) if __name__ == "__main__": run_tests()