1# Owner(s): ["module: dynamo"] 2from unittest import mock 3 4import torch 5import torch._dynamo 6import torch._dynamo.test_case 7from torch._inductor.utils import pass_execution_and_save 8 9 10class FxPassesPreGradTests(torch._dynamo.test_case.TestCase): 11 @mock.patch("torch._inductor.utils.ShapeProp.propagate") 12 def test_pass_execution_and_save(self, mock_shape_prop): 13 class TestModule(torch.nn.Module): 14 def __init__(self) -> None: 15 super().__init__() 16 self.param = torch.nn.Parameter(torch.ones(4, 4)) 17 18 def forward(self, x: torch.Tensor) -> torch.Tensor: 19 return self.param + x 20 21 def fx_pass(graph: torch.fx.GraphModule) -> None: 22 return 23 24 sample_input = torch.randn(4, 4) 25 m = TestModule() 26 m(sample_input) 27 exported_program = torch.export.export(m, (sample_input,)) 28 gm = exported_program.graph_module 29 30 pass_execution_and_save(fx_pass, gm, sample_input, "Apply testing pass") 31 mock_shape_prop.assert_called_once() 32 33 34if __name__ == "__main__": 35 from torch._dynamo.test_case import run_tests 36 37 run_tests() 38