# Owner(s): ["module: onnx"]
import pytorch_test_common

import torch
import torch._dynamo
import torch.fx
from torch.onnx._internal.fx.passes import _utils as pass_utils
from torch.testing._internal import common_utils


class TestFxPasses(common_utils.TestCase):
    """Tests for FX pass utilities and unsupported-node reporting in the ONNX exporter."""

    def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self):
        """Renaming the first node to an already-taken name must keep all names unique."""

        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
        torch._dynamo.reset()

        # Purposely name the nodes in a way that will cause a recursive collision later.
        # See :func:`set_node_name` for name collision renaming logic.
        base_name = "tensor"
        nodes = list(gm.graph.nodes)
        for i, node in enumerate(nodes[1:]):
            if i == 0:
                node.name = base_name
            else:
                node.name = f"{base_name}.{i}"

        # Run `set_node_name` and verify that the names are correct.
        name_to_node = {node.name: node for node in gm.graph.nodes}
        pass_utils.set_node_name(nodes[0], base_name, name_to_node)
        assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}"
        assert len({node.name for node in nodes}) == len(
            nodes
        ), f"Expected all names to be unique, got {nodes}"

    def test_set_node_name_succeeds_when_no_name_collisions(self):
        """Renaming a node to an unused name must apply that exact name."""

        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
        torch._dynamo.reset()

        # Run `set_node_name` and verify that the names are correct.
        new_name = "some_tensor"
        nodes = list(gm.graph.nodes)
        name_to_node = {node.name: node for node in nodes}
        pass_utils.set_node_name(nodes[1], new_name, name_to_node)
        # Report the node that was actually renamed (nodes[1]) in the failure message.
        assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[1].name}"
        assert len({node.name for node in nodes}) == len(
            nodes
        ), f"Expected all names to be unique, got {nodes}"

    def test_onnx_dynamo_export_raises_when_model_contains_unsupported_fx_nodes(self):
        """Exporting a function that uses custom ops must raise and name both ops."""

        @torch.library.custom_op(
            "mylibrary::foo_op", device_types="cpu", mutates_args=()
        )
        def foo_op(x: torch.Tensor) -> torch.Tensor:
            return x + 1

        @torch.library.custom_op(
            "mylibrary::bar_op", device_types="cpu", mutates_args=()
        )
        def bar_op(x: torch.Tensor) -> torch.Tensor:
            return x + 2

        @foo_op.register_fake
        def _(x):
            return torch.empty_like(x)

        @bar_op.register_fake
        def _(x):
            return torch.empty_like(x)

        def func(x, y, z):
            return foo_op(x) + bar_op(y) + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        with self.assertRaises(torch.onnx.OnnxExporterError) as ctx:
            torch.onnx.dynamo_export(func, x, y, z)
        # The message under test lives on the chained cause, not the outer error.
        inner_exception = ctx.exception.__cause__
        self.assertRegex(
            str(inner_exception),
            r"Unsupported FX nodes.*mylibrary\.foo_op.*mylibrary\.bar_op",
        )

        torch._dynamo.reset()


@common_utils.instantiate_parametrized_tests
class TestModularizePass(common_utils.TestCase):
    """Tests that check the function names present in the exported ONNX model proto.

    Each test runs once with a ``torch.export`` exported program and once with
    the plain ``nn.Module``.
    """

    @pytorch_test_common.xfail(
        error_message="'torch_nn_modules_activation_GELU_used_gelu_1' not found",
        reason="optimizer",
    )
    @common_utils.parametrize(
        "is_exported_program",
        [
            common_utils.subtest(
                True,
                name="exported_program",
            ),
            common_utils.subtest(
                False,
                name="nn_module",
            ),
        ],
    )
    def test_modularize_pass_succeeds_when_submodule_output_is_unused(
        self, is_exported_program
    ):
        """A submodule whose output is unused must not appear among the ONNX functions."""

        # This is an ill-formed model, but the exporter must not crash.
        # It is illegal for a submodule to have zero outputs. For the modularization pass
        # this can happen when the submodule output is unused, so no inner node is
        # connected to any outer node.
        # However, this also means the entire submodule should be erased by DCE. Hence
        # it should never occur.
        #
        # Minified repro from Background_Matting. https://github.com/pytorch/benchmark/issues/1768
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.unused_relu = torch.nn.ReLU()
                self.used_gelu = torch.nn.GELU()

            def forward(self, x, y):
                result = self.used_gelu(x + y)
                # Deliberately computed and discarded: this is the scenario under test.
                unused_relu_result = self.unused_relu(x)
                return result

        if is_exported_program:
            model = torch.export.export(
                TestModule(), args=(torch.randn(3), torch.randn(3))
            )
        else:
            model = TestModule()

        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
        model_proto = onnx_program.model_proto
        function_proto_names = [function.name for function in model_proto.functions]
        self.assertIn(
            "torch_nn_modules_activation_GELU_used_gelu_1", function_proto_names
        )
        self.assertFalse(any("ReLU" in name for name in function_proto_names))

    @pytorch_test_common.xfail(
        error_message="'torch_nn_modules_activation_ReLU_relu_1' not found",
        reason="optimizer",
    )
    @common_utils.parametrize(
        "is_exported_program",
        [
            common_utils.subtest(
                True,
                name="exported_program",
            ),
            common_utils.subtest(
                False,
                name="nn_module",
            ),
        ],
    )
    def test_modularize_pass_succeeds_when_a_submodule_is_called_multiple_times(
        self, is_exported_program
    ):
        """Two calls to the same submodule must yield two distinctly named functions."""

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x, y):
                out = x + y
                out = self.relu(out)
                out = out + x
                out = self.relu(out)
                return out

        if is_exported_program:
            model = torch.export.export(
                TestModule(), args=(torch.randn(3), torch.randn(3))
            )
        else:
            model = TestModule()

        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
        model_proto = onnx_program.model_proto
        function_proto_names = [function.name for function in model_proto.functions]
        self.assertIn("torch_nn_modules_activation_ReLU_relu_1", function_proto_names)
        self.assertIn("torch_nn_modules_activation_ReLU_relu_2", function_proto_names)

    @pytorch_test_common.xfail(
        error_message="'torch_nn_modules_activation_ReLU_inner_module_relu_1' not found",
        reason="optimizer",
    )
    @common_utils.parametrize(
        "is_exported_program",
        [
            common_utils.subtest(
                True,
                name="exported_program",
            ),
            common_utils.subtest(
                False,
                name="nn_module",
            ),
        ],
    )
    def test_modularize_pass_succeeds_when_a_submodule_is_called_from_multiple_layers(
        self, is_exported_program
    ):
        """A nested submodule called via its parent and directly must yield two functions."""

        # Minified repro from basic_gnn_edgecnn.
        class InnerModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner_module = InnerModule()

            def forward(self, x, y):
                out = x + y
                out = self.inner_module(out)
                out = out + x
                # Second call reaches the same ReLU directly, bypassing InnerModule.forward.
                out = self.inner_module.relu(out)
                return out

        if is_exported_program:
            model = torch.export.export(
                TestModule(), args=(torch.randn(3), torch.randn(3))
            )
        else:
            model = TestModule()

        onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
        model_proto = onnx_program.model_proto
        function_proto_names = [function.name for function in model_proto.functions]
        self.assertIn(
            "torch_nn_modules_activation_ReLU_inner_module_relu_1", function_proto_names
        )
        self.assertIn(
            "torch_nn_modules_activation_ReLU_inner_module_relu_2", function_proto_names
        )
        # local module qualified name is unstable in test environment depending on different test
        # invocation methods.
        self.assertTrue(
            any("InnerModule_inner_module_1" in name for name in function_proto_names)
        )


if __name__ == "__main__":
    common_utils.run_tests()