xref: /aosp_15_r20/external/pytorch/test/export/test_lift_unlift.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2import unittest
3from typing import Any, Dict, Optional, OrderedDict, Tuple
4
5import torch
6from torch._export.passes.lift_constants_pass import (
7    ConstantAttrMap,
8    lift_constants_pass,
9)
10from torch.export._unlift import _unlift_exported_program_lifted_states
11from torch.export.exported_program import (
12    ExportGraphSignature,
13    InputKind,
14    InputSpec,
15    OutputKind,
16    OutputSpec,
17    TensorArgument,
18)
19from torch.export.graph_signature import CustomObjArgument
20from torch.testing._internal.common_utils import (
21    find_library_location,
22    IS_FBCODE,
23    IS_MACOS,
24    IS_SANDCASTLE,
25    IS_WINDOWS,
26    run_tests,
27    TestCase,
28)
29
30
31class GraphBuilder:
32    def __init__(self) -> None:
33        self.graph = torch.fx.Graph()
34        self.nodes = {}
35        self.values = {}
36        self.nn_module_stack_key: Dict[str, int] = {}
37        self.latest_id = 0
38        self.input_to_kind: Dict[torch.fx.Node, InputKind] = {}
39
40    def input(self, name: str, value: torch.Tensor, kind: InputKind):
41        node = self.graph.placeholder(name)
42        node.meta["val"] = value
43        self.nodes[name] = node
44        self.values[name] = value
45        self.input_to_kind[node] = kind
46
47    def add(self, x: str, y: str, out: str, module_fqn: str = ""):
48        node = self.graph.create_node(
49            "call_function",
50            torch.ops.aten.add.Tensor,
51            (self.nodes[x], self.nodes[y]),
52            name=out,
53        )
54        self.values[out] = self.values[x] + self.values[y]
55        node.meta["val"] = self.values[out]
56        node.meta["nn_module_stack"] = self.create_nn_module_stack(module_fqn)
57        self.nodes[out] = node
58
59    def call_function(self, target, args, out: str, module_fqn: str = ""):
60        arg_nodes = tuple(self.nodes[arg] for arg in args)
61        arg_values = tuple(self.values[arg] for arg in args)
62        node = self.graph.create_node(
63            "call_function",
64            target,
65            arg_nodes,
66            name=out,
67        )
68        self.values[out] = target(*arg_values)
69        node.meta["val"] = self.values[out]
70        node.meta["nn_module_stack"] = self.create_nn_module_stack(module_fqn)
71        self.nodes[out] = node
72
73    def constant(
74        self, name: str, value: Any, target: Optional[str] = None, module_fqn: str = ""
75    ):
76        if target is None:
77            target = name
78        node = self.graph.get_attr(target)
79        node.meta["val"] = value
80        node.meta["nn_module_stack"] = self.create_nn_module_stack(module_fqn)
81        self.nodes[name] = node
82        self.values[name] = value
83
84    def output(self, out: str):
85        self.graph.output(self.nodes[out])
86
87    def create_nn_module_stack(
88        self, module_fqn: str
89    ) -> OrderedDict[int, Tuple[str, type]]:
90        cur_name = ""
91        nn_module_stack = OrderedDict()
92        for atom in module_fqn.split("."):
93            if cur_name == "":
94                cur_name = atom
95            else:
96                cur_name = cur_name + "." + atom
97
98            if cur_name not in self.nn_module_stack_key:
99                id_counter = self.latest_id
100                self.latest_id += 1
101                self.nn_module_stack_key[cur_name] = id_counter
102            else:
103                id_counter = self.nn_module_stack_key[cur_name]
104
105            nn_module_stack[id_counter] = (cur_name, torch.nn.Module)
106        return nn_module_stack
107
108    def create_input_specs(self):
109        input_specs = []
110        for node in self.graph.nodes:
111            if node.op == "placeholder":
112                input_specs.append(
113                    InputSpec(
114                        kind=self.input_to_kind[node],
115                        arg=TensorArgument(name=node.name),
116                        target=None,
117                        persistent=(
118                            True
119                            if self.input_to_kind[node] == InputKind.BUFFER
120                            else None
121                        ),
122                    )
123                )
124        return input_specs
125
126    # NOTE: does not handle non-user-outputs atm
127    def gen_graph_signature(self) -> ExportGraphSignature:
128        output = [n for n in self.graph.nodes if n.op == "output"]
129        assert len(output) == 1
130        output = output[0]
131        assert len(output.args) == 1, "multiple outputs NYI"
132
133        return ExportGraphSignature(
134            input_specs=self.create_input_specs(),
135            output_specs=[
136                OutputSpec(
137                    kind=OutputKind.USER_OUTPUT,
138                    arg=TensorArgument(name=n.name),
139                    target=None,
140                )
141                for n in output.args
142            ],
143        )
144
145
class TestLift(TestCase):
    """Tests for ``lift_constants_pass`` and lifted-state unlifting.

    Graphs are hand-built with ``GraphBuilder``. The tests then check that
    ``get_attr`` constants are turned into placeholders, that those
    placeholders are recorded in the graph signature, and that duplicates are
    lifted only once.
    """

    def setUp(self):
        # The ``_TorchScriptTesting`` custom classes/ops come from a separately
        # built library that must be loaded before the tests can use them.
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        elif IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library(
                "//caffe2/test/cpp/jit:test_custom_class_registrations"
            )
        elif IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
            torch.ops.load_library(str(lib_file_path))
        else:
            lib_file_path = find_library_location("libtorchbind_test.so")
            torch.ops.load_library(str(lib_file_path))

    def test_lift_basic(self):
        builder = GraphBuilder()

        builder.input("param", torch.rand(2, 3), InputKind.PARAMETER)
        builder.input("buffer", torch.rand(2, 3), InputKind.BUFFER)
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("y", torch.rand(2, 3), InputKind.USER_INPUT)

        builder.add("x", "y", out="foo")
        builder.add("foo", "param", out="bar")
        builder.add("bar", "buffer", out="baz")
        builder.constant("const_tensor", torch.rand(2, 3))
        builder.constant("const_obj", torch.classes._TorchScriptTesting._Foo(10, 20))
        builder.add("baz", "const_tensor", out="out")
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj", "x"),
            out="torchbind_out",
        )
        builder.add("out", "torchbind_out", out="final_out")
        builder.output("final_out")

        builder.graph.lint()
        graph = builder.graph
        const_tensor = builder.values["const_tensor"]
        const_obj = builder.values["const_obj"]

        root = {"const_tensor": const_tensor, "const_obj": const_obj}
        gm = torch.fx.GraphModule(root, graph)
        graph_signature = builder.gen_graph_signature()
        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 2)

        # The key of the constants table should match the fqn of the constant.
        # In this case, it's just the name of the constant, since the constant
        # is at the root submodule.
        # TODO(suo): we shouldn't hardcode these names in the test, this is an
        # internal detail of the pass.
        self.assertIn("lifted_tensor_0", constants)
        self.assertEqual(constants["lifted_tensor_0"], const_tensor)
        self.assertIn("lifted_custom_0", constants)
        self.assertEqual(constants["lifted_custom_0"], const_obj)

        # The constant node should be removed.
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # The constant should be lifted to a placeholder node.
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 6)

        # The lifted constant should be placed before user inputs but after params/buffers
        lifted_tensor_placeholder = placeholder_nodes[2]
        self.assertEqual(lifted_tensor_placeholder.target, "lifted_tensor_0")
        # It should have a val equivalent to the constant
        self.assertEqual(lifted_tensor_placeholder.meta["val"], const_tensor)

        lifted_obj_placeholder = placeholder_nodes[3]
        self.assertEqual(lifted_obj_placeholder.target, "lifted_custom_0")
        # It should have a val equivalent to the constant
        self.assertEqual(
            lifted_obj_placeholder.meta["val"],
            CustomObjArgument(
                name="lifted_custom_0",
                class_fqn="__torch__.torch.classes._TorchScriptTesting._Foo",
            ),
        )

        # Graph signature should have been mutated in a way that reflects the placeholders.
        tensor_constant_input_spec = graph_signature.input_specs[2]
        self.assertEqual(tensor_constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(tensor_constant_input_spec.arg, TensorArgument)
        self.assertEqual(
            tensor_constant_input_spec.arg.name, lifted_tensor_placeholder.name
        )

        obj_constant_input_spec = graph_signature.input_specs[3]
        self.assertEqual(obj_constant_input_spec.kind, InputKind.CUSTOM_OBJ)
        self.assertIsInstance(obj_constant_input_spec.arg, CustomObjArgument)
        self.assertEqual(obj_constant_input_spec.arg.name, lifted_obj_placeholder.name)

    def test_lift_nested(self):
        builder = GraphBuilder()
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("y", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("z", torch.rand(2, 3), InputKind.USER_INPUT)

        builder.add("x", "y", out="foo")
        builder.add("foo", "z", out="bar", module_fqn="foo")
        builder.constant("const_tensor", torch.rand(2, 3), module_fqn="foo")
        builder.add("bar", "const_tensor", "out")
        builder.output("out")

        graph = builder.graph
        graph.lint()

        const_tensor = builder.values["const_tensor"]
        root = {"const_tensor": const_tensor}

        graph_signature = builder.gen_graph_signature()
        gm = torch.fx.GraphModule(root, graph)

        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 1)

        # The key of the constants table should match the fqn of the constant.
        self.assertIn("foo.lifted_tensor_0", constants)
        self.assertEqual(constants["foo.lifted_tensor_0"], const_tensor)

        # The constant node should be removed.
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # The constant should be lifted to a placeholder node.
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 4)

        # There are no params/buffers here, so the lifted constant is the first
        # placeholder, ahead of every user input.
        lifted_constant_placeholder = placeholder_nodes[0]
        self.assertEqual(lifted_constant_placeholder.target, "lifted_tensor_0")

        # Graph signature should have been mutated in a way that reflects the placeholders.
        constant_input_spec = graph_signature.input_specs[0]
        self.assertEqual(constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(constant_input_spec.arg, TensorArgument)
        self.assertEqual(constant_input_spec.arg.name, lifted_constant_placeholder.name)

    def test_duplicate_constant_access(self):
        const = torch.rand(2, 3)
        const_obj = torch.classes._TorchScriptTesting._Foo(10, 20)

        builder = GraphBuilder()
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.constant("const_tensor", const, target="const_tensor")
        # loading the same target twice
        builder.constant("const_tensor2", const, target="const_tensor")

        # loading the same object twice with different targets
        builder.constant("const_obj", const_obj)
        builder.constant("const_obj2", const_obj)
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj", "x"),
            out="torchbind_out",
        )
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj2", "x"),
            out="torchbind_out2",
        )
        builder.add("x", "const_tensor", out="foo")
        builder.add("foo", "const_tensor2", out="tensor_out")
        builder.add("torchbind_out", "torchbind_out2", out="obj_out")
        builder.add("tensor_out", "obj_out", out="out")
        builder.output("out")
        graph = builder.graph
        graph.lint()

        # Same signature the other tests use: one spec per placeholder plus a
        # single USER_OUTPUT for "out".
        graph_signature = builder.gen_graph_signature()

        root = {"const_tensor": const, "const_obj": const_obj, "const_obj2": const_obj}
        gm = torch.fx.GraphModule(root, graph)

        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 2)

        # All get_attr nodes should be removed
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # There should only be two additional inputs (plus the existing user input)
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 3)

        # Graph signature should have been mutated in a way that reflects the
        # placeholders: one deduplicated tensor constant, one deduplicated
        # custom object, then the original user input.
        self.assertEqual(len(graph_signature.input_specs), 3)
        constant_input_spec = graph_signature.input_specs[0]
        self.assertEqual(constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(constant_input_spec.arg, TensorArgument)
        obj_input_spec = graph_signature.input_specs[1]
        self.assertEqual(obj_input_spec.kind, InputKind.CUSTOM_OBJ)
        self.assertIsInstance(obj_input_spec.arg, CustomObjArgument)
        self.assertEqual(graph_signature.input_specs[2].kind, InputKind.USER_INPUT)

    def test_unlift_nonpersistent_buffer(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_buffer(
                    "non_persistent_buf", torch.zeros(1), persistent=False
                )

            def forward(self, x):
                self.non_persistent_buf.add_(1)
                return x.sum() + self.non_persistent_buf.sum()

        foo = Foo()
        exported = torch.export.export(foo, (torch.ones(5, 5),), strict=False)
        stateful_gm = _unlift_exported_program_lifted_states(exported)

        # Check the unlifted stateful_gm contains the original non-persistent buffer
        self.assertTrue(hasattr(stateful_gm, "non_persistent_buf"))
        non_persistent_buf = stateful_gm.get_buffer("non_persistent_buf")
        self.assertEqual(non_persistent_buf, foo.get_buffer("non_persistent_buf"))
        self.assertIn("non_persistent_buf", stateful_gm._non_persistent_buffers_set)
        self.assertNotIn("non_persistent_buf", stateful_gm.state_dict())
377
378
class ConstantAttrMapTest(TestCase):
    """Tests the mapping API of ``ConstantAttrMap`` (constant -> list of fqns)."""

    def setUp(self):
        # ``_TorchScriptTesting._Foo`` is defined in a separately built library
        # that must be loaded before the test can construct it.
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        elif IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library(
                "//caffe2/test/cpp/jit:test_custom_class_registrations"
            )
        elif IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
            torch.ops.load_library(str(lib_file_path))
        else:
            lib_file_path = find_library_location("libtorchbind_test.so")
            torch.ops.load_library(str(lib_file_path))

    def test_dict_api(self):
        constant_attr_map = ConstantAttrMap()
        const_obj = torch.classes._TorchScriptTesting._Foo(10, 20)
        const_tensor = torch.ones(2, 3)
        constant_attr_map.add(const_obj, "foo.bar")
        constant_attr_map.add(const_tensor, "foo.bar.baz")
        self.assertEqual(len(constant_attr_map), 2)
        self.assertEqual(list(constant_attr_map), [const_obj, const_tensor])
        self.assertEqual(list(constant_attr_map.keys()), [const_obj, const_tensor])
        self.assertEqual(
            list(constant_attr_map.values()), [["foo.bar"], ["foo.bar.baz"]]
        )
        self.assertEqual(constant_attr_map[const_obj], ["foo.bar"])
        self.assertEqual(constant_attr_map[const_tensor], ["foo.bar.baz"])
        self.assertIn(const_obj, constant_attr_map)
        # Only tensors and script objects are valid keys.
        with self.assertRaises(TypeError):
            constant_attr_map.add(1, "foo.bar")

        del constant_attr_map[const_obj]
        self.assertEqual(len(constant_attr_map), 1)
        # Deleting one key must drop exactly that entry and leave the rest intact.
        self.assertNotIn(const_obj, constant_attr_map)
        self.assertIn(const_tensor, constant_attr_map)
        self.assertEqual(constant_attr_map[const_tensor], ["foo.bar.baz"])
414
415
if __name__ == "__main__":
    # When executed as a script, hand control to PyTorch's shared test runner
    # (``run_tests`` from ``torch.testing._internal.common_utils``).
    run_tests()
418