1# Owner(s): ["oncall: export"] 2import unittest 3from typing import Any, Dict, Optional, OrderedDict, Tuple 4 5import torch 6from torch._export.passes.lift_constants_pass import ( 7 ConstantAttrMap, 8 lift_constants_pass, 9) 10from torch.export._unlift import _unlift_exported_program_lifted_states 11from torch.export.exported_program import ( 12 ExportGraphSignature, 13 InputKind, 14 InputSpec, 15 OutputKind, 16 OutputSpec, 17 TensorArgument, 18) 19from torch.export.graph_signature import CustomObjArgument 20from torch.testing._internal.common_utils import ( 21 find_library_location, 22 IS_FBCODE, 23 IS_MACOS, 24 IS_SANDCASTLE, 25 IS_WINDOWS, 26 run_tests, 27 TestCase, 28) 29 30 31class GraphBuilder: 32 def __init__(self) -> None: 33 self.graph = torch.fx.Graph() 34 self.nodes = {} 35 self.values = {} 36 self.nn_module_stack_key: Dict[str, int] = {} 37 self.latest_id = 0 38 self.input_to_kind: Dict[torch.fx.Node, InputKind] = {} 39 40 def input(self, name: str, value: torch.Tensor, kind: InputKind): 41 node = self.graph.placeholder(name) 42 node.meta["val"] = value 43 self.nodes[name] = node 44 self.values[name] = value 45 self.input_to_kind[node] = kind 46 47 def add(self, x: str, y: str, out: str, module_fqn: str = ""): 48 node = self.graph.create_node( 49 "call_function", 50 torch.ops.aten.add.Tensor, 51 (self.nodes[x], self.nodes[y]), 52 name=out, 53 ) 54 self.values[out] = self.values[x] + self.values[y] 55 node.meta["val"] = self.values[out] 56 node.meta["nn_module_stack"] = self.create_nn_module_stack(module_fqn) 57 self.nodes[out] = node 58 59 def call_function(self, target, args, out: str, module_fqn: str = ""): 60 arg_nodes = tuple(self.nodes[arg] for arg in args) 61 arg_values = tuple(self.values[arg] for arg in args) 62 node = self.graph.create_node( 63 "call_function", 64 target, 65 arg_nodes, 66 name=out, 67 ) 68 self.values[out] = target(*arg_values) 69 node.meta["val"] = self.values[out] 70 node.meta["nn_module_stack"] = 
self.create_nn_module_stack(module_fqn) 71 self.nodes[out] = node 72 73 def constant( 74 self, name: str, value: Any, target: Optional[str] = None, module_fqn: str = "" 75 ): 76 if target is None: 77 target = name 78 node = self.graph.get_attr(target) 79 node.meta["val"] = value 80 node.meta["nn_module_stack"] = self.create_nn_module_stack(module_fqn) 81 self.nodes[name] = node 82 self.values[name] = value 83 84 def output(self, out: str): 85 self.graph.output(self.nodes[out]) 86 87 def create_nn_module_stack( 88 self, module_fqn: str 89 ) -> OrderedDict[int, Tuple[str, type]]: 90 cur_name = "" 91 nn_module_stack = OrderedDict() 92 for atom in module_fqn.split("."): 93 if cur_name == "": 94 cur_name = atom 95 else: 96 cur_name = cur_name + "." + atom 97 98 if cur_name not in self.nn_module_stack_key: 99 id_counter = self.latest_id 100 self.latest_id += 1 101 self.nn_module_stack_key[cur_name] = id_counter 102 else: 103 id_counter = self.nn_module_stack_key[cur_name] 104 105 nn_module_stack[id_counter] = (cur_name, torch.nn.Module) 106 return nn_module_stack 107 108 def create_input_specs(self): 109 input_specs = [] 110 for node in self.graph.nodes: 111 if node.op == "placeholder": 112 input_specs.append( 113 InputSpec( 114 kind=self.input_to_kind[node], 115 arg=TensorArgument(name=node.name), 116 target=None, 117 persistent=( 118 True 119 if self.input_to_kind[node] == InputKind.BUFFER 120 else None 121 ), 122 ) 123 ) 124 return input_specs 125 126 # NOTE: does not handle non-user-outputs atm 127 def gen_graph_signature(self) -> ExportGraphSignature: 128 output = [n for n in self.graph.nodes if n.op == "output"] 129 assert len(output) == 1 130 output = output[0] 131 assert len(output.args) == 1, "multiple outputs NYI" 132 133 return ExportGraphSignature( 134 input_specs=self.create_input_specs(), 135 output_specs=[ 136 OutputSpec( 137 kind=OutputKind.USER_OUTPUT, 138 arg=TensorArgument(name=n.name), 139 target=None, 140 ) 141 for n in output.args 142 ], 143 ) 144 
def _load_torchbind_test_library() -> None:
    """Load the library that registers ``_TorchScriptTesting`` custom classes.

    Raises ``unittest.SkipTest`` on macOS, where the ``load_library`` call
    used here is not portable.
    """
    if IS_MACOS:
        raise unittest.SkipTest("non-portable load_library call used in test")
    if IS_SANDCASTLE or IS_FBCODE:
        torch.ops.load_library(
            "//caffe2/test/cpp/jit:test_custom_class_registrations"
        )
        return
    lib_name = "torchbind_test.dll" if IS_WINDOWS else "libtorchbind_test.so"
    torch.ops.load_library(str(find_library_location(lib_name)))


class TestLift(TestCase):
    def setUp(self):
        # Keep the base TestCase setup (seeding, disabled-test handling).
        super().setUp()
        _load_torchbind_test_library()

    def test_lift_basic(self):
        builder = GraphBuilder()

        builder.input("param", torch.rand(2, 3), InputKind.PARAMETER)
        builder.input("buffer", torch.rand(2, 3), InputKind.BUFFER)
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("y", torch.rand(2, 3), InputKind.USER_INPUT)

        builder.add("x", "y", out="foo")
        builder.add("foo", "param", out="bar")
        builder.add("bar", "buffer", out="baz")
        builder.constant("const_tensor", torch.rand(2, 3))
        builder.constant("const_obj", torch.classes._TorchScriptTesting._Foo(10, 20))
        builder.add("baz", "const_tensor", out="out")
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj", "x"),
            out="torchbind_out",
        )
        builder.add("out", "torchbind_out", out="final_out")
        builder.output("final_out")

        builder.graph.lint()
        graph = builder.graph
        const_tensor = builder.values["const_tensor"]
        const_obj = builder.values["const_obj"]

        root = {"const_tensor": const_tensor, "const_obj": const_obj}
        gm = torch.fx.GraphModule(root, graph)
        graph_signature = builder.gen_graph_signature()
        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 2)

        # The key of the constants table should match the fqn of the lifted
        # constant. The constant lives on the root module, so the key has no
        # module prefix.
        # TODO(suo): we shouldn't hardcode these names in the test, this is an
        # internal detail of the pass.
        self.assertIn("lifted_tensor_0", constants)
        self.assertEqual(constants["lifted_tensor_0"], const_tensor)
        self.assertIn("lifted_custom_0", constants)
        self.assertEqual(constants["lifted_custom_0"], const_obj)

        # The constant node should be removed.
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # The constant should be lifted to a placeholder node.
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 6)

        # The lifted constant should be placed before user inputs but after params/buffers
        lifted_tensor_placeholder = placeholder_nodes[2]
        self.assertEqual(lifted_tensor_placeholder.target, "lifted_tensor_0")
        # It should have a val equivalent to the constant
        self.assertEqual(lifted_tensor_placeholder.meta["val"], const_tensor)

        lifted_obj_placeholder = placeholder_nodes[3]
        self.assertEqual(lifted_obj_placeholder.target, "lifted_custom_0")
        # It should have a val equivalent to the constant
        self.assertEqual(
            lifted_obj_placeholder.meta["val"],
            CustomObjArgument(
                name="lifted_custom_0",
                class_fqn="__torch__.torch.classes._TorchScriptTesting._Foo",
            ),
        )

        # Graph signature should have been mutated in a way that reflects the placeholders.
        tensor_constant_input_spec = graph_signature.input_specs[2]
        self.assertEqual(tensor_constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(tensor_constant_input_spec.arg, TensorArgument)
        self.assertEqual(
            tensor_constant_input_spec.arg.name, lifted_tensor_placeholder.name
        )

        obj_constant_input_spec = graph_signature.input_specs[3]
        self.assertEqual(obj_constant_input_spec.kind, InputKind.CUSTOM_OBJ)
        self.assertIsInstance(obj_constant_input_spec.arg, CustomObjArgument)
        self.assertEqual(obj_constant_input_spec.arg.name, lifted_obj_placeholder.name)

    def test_lift_nested(self):
        builder = GraphBuilder()
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("y", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("z", torch.rand(2, 3), InputKind.USER_INPUT)

        builder.add("x", "y", out="foo")
        builder.add("foo", "z", out="bar", module_fqn="foo")
        builder.constant("const_tensor", torch.rand(2, 3), module_fqn="foo")
        builder.add("bar", "const_tensor", "out")
        builder.output("out")

        graph = builder.graph
        graph.lint()

        const_tensor = builder.values["const_tensor"]
        root = {"const_tensor": const_tensor}

        graph_signature = builder.gen_graph_signature()
        gm = torch.fx.GraphModule(root, graph)

        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 1)

        # The key of the constants table should match the fqn of the constant.
        self.assertIn("foo.lifted_tensor_0", constants)
        self.assertEqual(constants["foo.lifted_tensor_0"], const_tensor)

        # The constant node should be removed.
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # The constant should be lifted to a placeholder node.
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 4)

        # The lifted constant should be placed before user inputs but after
        # params/buffers (there are none here, so it comes first).
        lifted_constant_placeholder = placeholder_nodes[0]
        self.assertEqual(lifted_constant_placeholder.target, "lifted_tensor_0")

        # Graph signature should have been mutated in a way that reflects the placeholders.
        constant_input_spec = graph_signature.input_specs[0]
        self.assertEqual(constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(constant_input_spec.arg, TensorArgument)
        self.assertEqual(constant_input_spec.arg.name, lifted_constant_placeholder.name)

    def test_duplicate_constant_access(self):
        const = torch.rand(2, 3)
        const_obj = torch.classes._TorchScriptTesting._Foo(10, 20)

        builder = GraphBuilder()
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.constant("const_tensor", const, target="const_tensor")
        # loading the same target twice
        builder.constant("const_tensor2", const, target="const_tensor")

        # loading the same object twice with different targets
        builder.constant("const_obj", const_obj)
        builder.constant("const_obj2", const_obj)
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj", "x"),
            out="torchbind_out",
        )
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj2", "x"),
            out="torchbind_out2",
        )
        builder.add("x", "const_tensor", out="foo")
        builder.add("foo", "const_tensor2", out="tensor_out")
        builder.add("torchbind_out", "torchbind_out2", out="obj_out")
        builder.add("tensor_out", "obj_out", out="out")
        builder.output("out")
        graph = builder.graph
        graph.lint()

        input_specs = builder.create_input_specs()
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=builder.nodes["out"].name),
                target=None,
            )
        ]
        graph_signature = ExportGraphSignature(input_specs, output_specs)

        root = {"const_tensor": const, "const_obj": const_obj, "const_obj2": const_obj}
        gm = torch.fx.GraphModule(root, graph)

        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 2)

        # All get_attr nodes should be removed
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # There should only be two additional inputs (plus the existing user input)
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 3)

        # Graph signature should have been mutated in a way that reflects the placeholders.
        self.assertEqual(len(graph_signature.input_specs), 3)
        constant_input_spec = graph_signature.input_specs[0]
        self.assertEqual(constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(constant_input_spec.arg, TensorArgument)

    def test_unlift_nonpersistent_buffer(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_buffer(
                    "non_persistent_buf", torch.zeros(1), persistent=False
                )

            def forward(self, x):
                self.non_persistent_buf.add_(1)
                return x.sum() + self.non_persistent_buf.sum()

        foo = Foo()
        exported = torch.export.export(foo, (torch.ones(5, 5),), strict=False)
        stateful_gm = _unlift_exported_program_lifted_states(exported)

        # Check the unlifted stateful_gm contains the original non-persistent buffer
        self.assertTrue(hasattr(stateful_gm, "non_persistent_buf"))
        non_persistent_buf = stateful_gm.get_buffer("non_persistent_buf")
        self.assertEqual(non_persistent_buf, foo.get_buffer("non_persistent_buf"))
        self.assertIn("non_persistent_buf", stateful_gm._non_persistent_buffers_set)
        self.assertNotIn("non_persistent_buf", stateful_gm.state_dict())


class ConstantAttrMapTest(TestCase):
    def setUp(self):
        # Keep the base TestCase setup (seeding, disabled-test handling).
        super().setUp()
        _load_torchbind_test_library()

    def test_dict_api(self):
        constant_attr_map = ConstantAttrMap()
        const_obj = torch.classes._TorchScriptTesting._Foo(10, 20)
        const_tensor = torch.ones(2, 3)
        constant_attr_map.add(const_obj, "foo.bar")
        constant_attr_map.add(const_tensor, "foo.bar.baz")
        self.assertEqual(len(constant_attr_map), 2)
        self.assertEqual(list(constant_attr_map), [const_obj, const_tensor])
        self.assertEqual(list(constant_attr_map.keys()), [const_obj, const_tensor])
        self.assertEqual(
            list(constant_attr_map.values()), [["foo.bar"], ["foo.bar.baz"]]
        )
        self.assertEqual(constant_attr_map[const_obj], ["foo.bar"])
        self.assertEqual(constant_attr_map[const_tensor], ["foo.bar.baz"])
        self.assertIn(const_obj, constant_attr_map)
        with self.assertRaises(TypeError):
            constant_attr_map.add(1, "foo.bar")

        del constant_attr_map[const_obj]
        self.assertEqual(len(constant_attr_map), 1)


if __name__ == "__main__":
    run_tests()