1# Owner(s): ["oncall: distributed"] 2 3import torch 4from torch.distributed.fsdp._trace_utils import _ExecOrderTracer 5from torch.testing._internal.common_utils import ( 6 instantiate_parametrized_tests, 7 run_tests, 8 TestCase, 9) 10 11 12class Model(torch.nn.Module): 13 def __init__(self) -> None: 14 super().__init__() 15 self.weight1 = torch.nn.Parameter(torch.randn(6, 6)) 16 self.weight2 = torch.nn.Parameter(torch.randn(6, 6)) 17 self.weight_unused = torch.nn.Parameter(torch.randn(2, 2)) 18 self.layer0 = torch.nn.Linear(6, 6) 19 self.layer1 = torch.nn.Linear(6, 6, bias=False) 20 self.layer2 = torch.nn.Sequential( 21 torch.nn.Linear(6, 3, bias=False), 22 torch.nn.ReLU(), 23 torch.nn.Linear(3, 6, bias=False), 24 ) 25 self.relu = torch.nn.ReLU() 26 27 def forward(self, x: torch.Tensor, run_all_layers: bool) -> torch.Tensor: 28 z = self.relu(self.layer0(x)) 29 z = self.relu(self.layer2(z)) 30 z = z @ self.weight1 31 if run_all_layers: 32 z = self.relu(self.layer1(z)) 33 z = z @ self.weight2 34 # Use `layer0` twice to check the handling of multiplicity in the 35 # saved data structures 36 z = self.relu(self.layer0(x)) 37 return z 38 39 40class TestSymbolicTracing(TestCase): 41 def test_symbolic_tracing_outputs(self): 42 """ 43 Tests running ``tracer.trace()`` inside ``patch_tracer()`` by checking 44 the saved data structures. 45 """ 46 model = Model() 47 tracer = torch.fx.Tracer() 48 orig_call_module = tracer.call_module 49 orig_create_proxy = tracer.create_proxy 50 exec_order_tracer = _ExecOrderTracer() 51 with exec_order_tracer.patch_tracer(tracer=tracer, root_module=model): 52 concrete_args = {"run_all_layers": True} 53 tracer.trace(model, concrete_args) 54 # Check that the tracer methods are unchanged after exiting the context 55 self.assertEqual(orig_call_module, tracer.call_module) 56 self.assertEqual(orig_create_proxy, tracer.create_proxy) 57 # Check `module_forward_order` 58 correct_module_forward_order = [ 59 model, 60 model.layer0, 61 model.relu, 62 model.layer2, 63 model.layer2[0], 64 model.layer2[1], 65 model.layer2[2], 66 model.relu, 67 model.layer1, 68 model.relu, 69 model.layer0, 70 model.relu, 71 ] 72 exec_info = exec_order_tracer.exec_info 73 self.assertEqual(exec_info.module_forward_order, correct_module_forward_order) 74 # Check `module_to_param_usage_infos` 75 self.assertEqual( 76 exec_info.module_to_param_usage_infos[model], 77 [ 78 (model.layer0, list(model.layer0.named_parameters())), 79 (model.layer2, list(model.layer2.named_parameters())), 80 (model, [("weight1", model.weight1)]), 81 (model.layer1, list(model.layer1.named_parameters())), 82 (model, [("weight2", model.weight2)]), 83 (model.layer0, list(model.layer0.named_parameters())), 84 ], 85 ) 86 self.assertEqual( 87 exec_info.module_to_param_usage_infos[model.layer0], 88 [(model.layer0, list(model.layer0.named_parameters()))], 89 ) 90 self.assertEqual( 91 exec_info.module_to_param_usage_infos[model.layer1], 92 [(model.layer1, list(model.layer1.named_parameters()))], 93 ) 94 self.assertEqual( 95 exec_info.module_to_param_usage_infos[model.layer2], 96 [ 97 (model.layer2[0], list(model.layer2[0].named_parameters())), 98 (model.layer2[2], list(model.layer2[2].named_parameters())), 99 ], 100 ) 101 self.assertEqual(exec_info.module_to_param_usage_infos[model.relu], []) 102 # Check `param_forward_order` 103 correct_param_order = [ 104 model.layer0.weight, 105 model.layer0.bias, 106 model.layer2[0].weight, 107 model.layer2[2].weight, 108 model.weight1, 109 model.layer1.weight, 110 model.weight2, 111 ] 112 
        self.assertEqual(exec_info.param_forward_order, correct_param_order)
        # Check `visited_params`
        self.assertEqual(
            len(exec_info.visited_params), len(exec_info.param_forward_order)
        )
        self.assertEqual(exec_info.visited_params, set(exec_info.param_forward_order))


instantiate_parametrized_tests(TestSymbolicTracing)

if __name__ == "__main__":
    run_tests()