# Owner(s): ["module: fx"]

from collections import defaultdict
from typing import Dict, List, Tuple

import torch
from torch.fx.passes.split_utils import split_by_tags
from torch.testing._internal.common_utils import TestCase


class TestFXSplit(TestCase):
    def test_split_preserve_node_meta(self):
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                x = x + x
                y = y * y
                return x - y

        gm = torch.fx.symbolic_trace(TestModule())
        for node in gm.graph.nodes:
            node.meta["name"] = node.name
            if node.name == "add":
                node.tag = "a"
            elif node.name == "mul":
                node.tag = "b"
            elif node.name == "sub":
                node.tag = "c"

        split_gm = split_by_tags(gm, ["a", "b", "c"])
        for m in split_gm.children():
            for n in m.graph.nodes:
                if n.op != "output":
                    self.assertIn("name", n.meta)
                    self.assertEqual(n.meta["name"], n.name)

        # Validate that metadata is copied correctly for graph placeholder nodes
        for node in split_gm.graph.nodes:
            if node.op == "placeholder":
                self.assertIn("name", node.meta)
                self.assertEqual(node.meta["name"], node.name)


class TestSplitByTags(TestCase):
    class TestModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = torch.nn.Linear(2, 3)
            self.linear2 = torch.nn.Linear(4, 5)
            self.linear3 = torch.nn.Linear(6, 7)
            self.linear4 = torch.nn.Linear(8, 6)

        def forward(
            self,
            x1: torch.Tensor,
            x2: torch.Tensor,
            x3: torch.Tensor,
        ) -> torch.Tensor:
            v1 = self.linear1(x1)
            v2 = self.linear2(x2)
            v3 = self.linear3(x3)
            v4 = torch.cat([v1, v2, v3])
            return self.linear4(v4)

    @staticmethod
    def trace_and_tag(
        module: torch.nn.Module, tags: List[str]
    ) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
        """
        Test a simple gm whose nodes are tagged as follows (only call_module
        nodes are shown here):
            linear1 - tag: "red"
            linear2 - tag: "blue"
            linear3, linear4 - tag: "green"

        At the beginning we have:
            gm:
                linear1
                linear2
                linear3
                linear4

        split_gm = split_by_tags(gm, tags)

        Then we have:
            split_gm:
                red:
                    linear1
                blue:
                    linear2
                green:
                    linear3
                    linear4
        """
        tag_node = defaultdict(list)
        gm: torch.fx.GraphModule = torch.fx.symbolic_trace(module)

        # Tag every node and record the call_module node names for each tag.
        for node in gm.graph.nodes:
            if "linear1" in node.name:
                node.tag = tags[0]
                tag_node[tags[0]].append(node.name)
            elif "linear2" in node.name:
                node.tag = tags[1]
                tag_node[tags[1]].append(node.name)
            else:
                node.tag = tags[2]
                if node.op == "call_module":
                    tag_node[tags[2]].append(node.name)
        return gm, tag_node

    def test_split_by_tags(self) -> None:
        tags = ["red", "blue", "green"]
        module = TestSplitByTags.TestModule()
        gm, tag_node = TestSplitByTags.trace_and_tag(module, tags)
        split_gm, orig_to_split_fqn_mapping = split_by_tags(
            gm, tags, return_fqn_mapping=True
        )
        # Ensure split_gm has (and only has) ordered submodules named
        # red, blue, green.
        for idx, (name, _) in enumerate(split_gm.named_children()):
            if idx < len(tags):
                self.assertTrue(
                    name == tags[idx],
                    f"split_gm has an incorrect submodule named {name}",
                )

        # Ensure each submodule has the expected call_module node(s) in order.
        # For example, the submodule split_gm.red has (and only has) linear1;
        # split_gm.green has (and only has) linear3 and linear4, in that order.
        for sub_name, sub_graph_module in split_gm.named_children():
            node_idx = 0
            for node in sub_graph_module.graph.nodes:
                if node.op != "call_module":
                    continue
                self.assertTrue(
                    node.name == tag_node[f"{sub_name}"][node_idx],
                    f"{sub_name} incorrectly includes {node.name}",
                )
                node_idx += 1

        self.assertEqual(
            orig_to_split_fqn_mapping,
            {
                "linear1": "red.linear1",
                "linear2": "blue.linear2",
                "linear3": "green.linear3",
                "linear4": "green.linear4",
            },
            f"{orig_to_split_fqn_mapping=}",
        )


class TestSplitOutputType(TestCase):
    class TestModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            conv = self.conv(x)
            conv = conv * 0.5
            relu = self.relu(conv)
            return relu

    @staticmethod
    def trace_and_tag(
        module: torch.nn.Module, inputs: torch.Tensor, tags: List[str]
    ) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
        """
        Test a simple gm whose nodes are tagged as follows (only call_module
        nodes are shown here):
            conv - tag: "red"
            mul - tag: "blue"
            relu - tag: "green"

        At the beginning we have:
            gm:
                conv
                mul
                relu

        split_gm = split_by_tags(gm, tags)

        Then we have:
            split_gm:
                red:
                    conv
                blue:
                    mul
                green:
                    relu
        """
        tag_node = defaultdict(list)
        gm: torch.fx.GraphModule = torch.export.export(module, (inputs,)).module()
        # Tag every node and record the call_module node names for each tag.
        for node in gm.graph.nodes:
            if "conv" in node.name:
                node.tag = tags[0]
                tag_node[tags[0]].append(node.name)
            elif "mul" in node.name:
                node.tag = tags[1]
                tag_node[tags[1]].append(node.name)
            else:
                node.tag = tags[2]
                if node.op == "call_module":
                    tag_node[tags[2]].append(node.name)
        return gm, tag_node

    def test_split_by_tags(self) -> None:
        tags = ["red", "blue", "green"]
        module = TestSplitOutputType.TestModule()

        inputs = torch.randn((1, 3, 224, 224))

        gm, tag_node = TestSplitOutputType.trace_and_tag(module, inputs, tags)
        split_gm, orig_to_split_fqn_mapping = split_by_tags(
            gm, tags, return_fqn_mapping=True
        )

        gm_output = module(inputs)
        split_gm_output = split_gm(inputs)

        self.assertTrue(type(gm_output) == type(split_gm_output))
        self.assertTrue(torch.equal(gm_output, split_gm_output))
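

# A minimal, optional entry point (an assumption, not part of the original
# test file): run_tests() from torch.testing._internal.common_utils is the
# standard way to execute PyTorch TestCase-based files when run directly.
if __name__ == "__main__":
    from torch.testing._internal.common_utils import run_tests

    run_tests()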