xref: /aosp_15_r20/external/pytorch/test/distributed/fsdp/test_fsdp_fx.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import torch
4from torch.distributed.fsdp._trace_utils import _ExecOrderTracer
5from torch.testing._internal.common_utils import (
6    instantiate_parametrized_tests,
7    run_tests,
8    TestCase,
9)
10
11
class Model(torch.nn.Module):
    """Small module used as the tracing target for the execution-order test.

    It mixes parameters owned directly by the root module (``weight1``,
    ``weight2``), a parameter that ``forward`` never touches
    (``weight_unused``), plain submodules, a nested ``Sequential``, and a
    ``ReLU`` instance that is reused several times.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE: the assignment order below fixes the parameter/submodule
        # registration order, so it is kept stable.
        self.weight1 = torch.nn.Parameter(torch.randn(6, 6))
        self.weight2 = torch.nn.Parameter(torch.randn(6, 6))
        # Registered but never read in `forward`
        self.weight_unused = torch.nn.Parameter(torch.randn(2, 2))
        self.layer0 = torch.nn.Linear(6, 6)
        self.layer1 = torch.nn.Linear(6, 6, bias=False)
        self.layer2 = torch.nn.Sequential(
            torch.nn.Linear(6, 3, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(3, 6, bias=False),
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor, run_all_layers: bool) -> torch.Tensor:
        """Run the module on ``x``.

        Args:
            x: Input tensor; its last dimension must be 6 to match
                ``layer0``.
            run_all_layers: If ``False``, return after the ``weight1``
                matmul. If ``True``, additionally run ``layer1``, the
                ``weight2`` matmul, and ``layer0`` a second time.

        Returns:
            With ``run_all_layers=False``:
            ``relu(layer2(relu(layer0(x)))) @ weight1``.
            With ``run_all_layers=True``: ``relu(layer0(x))``, because the
            final statement recomputes ``z`` from ``x``.
        """
        z = self.relu(self.layer0(x))
        z = self.relu(self.layer2(z))
        z = z @ self.weight1
        if run_all_layers:
            z = self.relu(self.layer1(z))
            z = z @ self.weight2
            # Use `layer0` twice to check the handling of multiplicity in the
            # saved data structures. This intentionally overwrites `z` from
            # `x`; the earlier layers have still been executed by this point.
            z = self.relu(self.layer0(x))
        return z
38
39
class TestSymbolicTracing(TestCase):
    """Checks the data recorded by ``_ExecOrderTracer`` during FX tracing."""

    def test_symbolic_tracing_outputs(self):
        """
        Tests running ``tracer.trace()`` inside ``patch_tracer()`` by checking
        the saved data structures.

        ``Model`` is traced with ``run_all_layers=True`` so that every layer
        is executed and ``layer0`` is executed twice.
        """
        model = Model()
        tracer = torch.fx.Tracer()
        # Capture the tracer methods before patching so that we can verify
        # they are restored afterwards
        orig_call_module = tracer.call_module
        orig_create_proxy = tracer.create_proxy
        exec_order_tracer = _ExecOrderTracer()
        with exec_order_tracer.patch_tracer(tracer=tracer, root_module=model):
            # Fix the boolean argument so the branch in `forward` is taken
            # concretely while tracing
            concrete_args = {"run_all_layers": True}
            tracer.trace(model, concrete_args)
        # Check that the tracer methods are unchanged after exiting the context
        self.assertEqual(orig_call_module, tracer.call_module)
        self.assertEqual(orig_create_proxy, tracer.create_proxy)
        # Check `module_forward_order`: expected to list the root first and
        # then every module call in order, with repeated entries for modules
        # that are called more than once (`relu`, `layer0`)
        correct_module_forward_order = [
            model,
            model.layer0,
            model.relu,
            model.layer2,
            model.layer2[0],
            model.layer2[1],
            model.layer2[2],
            model.relu,
            model.layer1,
            model.relu,
            model.layer0,
            model.relu,
        ]
        exec_info = exec_order_tracer.exec_info
        self.assertEqual(exec_info.module_forward_order, correct_module_forward_order)
        # Check `module_to_param_usage_infos`: for the root, the expected
        # entries include the directly owned `weight1`/`weight2` usages and a
        # second entry for the repeated `layer0` call
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model],
            [
                (model.layer0, list(model.layer0.named_parameters())),
                (model.layer2, list(model.layer2.named_parameters())),
                (model, [("weight1", model.weight1)]),
                (model.layer1, list(model.layer1.named_parameters())),
                (model, [("weight2", model.weight2)]),
                (model.layer0, list(model.layer0.named_parameters())),
            ],
        )
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model.layer0],
            [(model.layer0, list(model.layer0.named_parameters()))],
        )
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model.layer1],
            [(model.layer1, list(model.layer1.named_parameters()))],
        )
        # Only the two `Linear` children of the `Sequential` are expected;
        # its inner `ReLU` does not appear
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model.layer2],
            [
                (model.layer2[0], list(model.layer2[0].named_parameters())),
                (model.layer2[2], list(model.layer2[2].named_parameters())),
            ],
        )
        self.assertEqual(exec_info.module_to_param_usage_infos[model.relu], [])
        # Check `param_forward_order`: each used parameter is expected exactly
        # once (even though `layer0` runs twice), and `weight_unused` is
        # expected to be absent
        correct_param_order = [
            model.layer0.weight,
            model.layer0.bias,
            model.layer2[0].weight,
            model.layer2[2].weight,
            model.weight1,
            model.layer1.weight,
            model.weight2,
        ]
        self.assertEqual(exec_info.param_forward_order, correct_param_order)
        # Check `visited_params`: it must hold exactly the parameters in
        # `param_forward_order`, with no duplicates on either side
        self.assertEqual(
            len(exec_info.visited_params), len(exec_info.param_forward_order)
        )
        self.assertEqual(exec_info.visited_params, set(exec_info.param_forward_order))
118
119
# Run the parametrized-test instantiation hook on the test class. No
# parametrized tests are visible in this file, so this is presumably kept for
# consistency with other test files -- TODO confirm.
instantiate_parametrized_tests(TestSymbolicTracing)

# Allow running this file directly as a script
if __name__ == "__main__":
    run_tests()
124