xref: /aosp_15_r20/external/pytorch/test/dynamo/test_fx_passes_pre_grad.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2from unittest import mock
3
4import torch
5import torch._dynamo
6import torch._dynamo.test_case
7from torch._inductor.utils import pass_execution_and_save
8
9
10class FxPassesPreGradTests(torch._dynamo.test_case.TestCase):
11    @mock.patch("torch._inductor.utils.ShapeProp.propagate")
12    def test_pass_execution_and_save(self, mock_shape_prop):
13        class TestModule(torch.nn.Module):
14            def __init__(self) -> None:
15                super().__init__()
16                self.param = torch.nn.Parameter(torch.ones(4, 4))
17
18            def forward(self, x: torch.Tensor) -> torch.Tensor:
19                return self.param + x
20
21        def fx_pass(graph: torch.fx.GraphModule) -> None:
22            return
23
24        sample_input = torch.randn(4, 4)
25        m = TestModule()
26        m(sample_input)
27        exported_program = torch.export.export(m, (sample_input,))
28        gm = exported_program.graph_module
29
30        pass_execution_and_save(fx_pass, gm, sample_input, "Apply testing pass")
31        mock_shape_prop.assert_called_once()
32
33
if __name__ == "__main__":
    # Script entry point: run this file's tests via the dynamo test-case
    # module's run_tests helper (module already imported at the top of the file).
    torch._dynamo.test_case.run_tests()
37    run_tests()
38