xref: /aosp_15_r20/external/executorch/exir/tests/test_serde.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import io
10import unittest
11from typing import Tuple
12
13import executorch.exir as exir
14
15import torch
16from executorch.exir import to_edge
17from executorch.exir.backend.backend_api import CompileSpec, to_backend
18from executorch.exir.backend.test.backend_with_compiler_demo import (
19    BackendWithCompilerDemo,
20)
21
22from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
23from executorch.exir.serde.serialize import deserialize, serialize
24from torch import nn
25from torch.export import export
26from torch.export.exported_program import ExportedProgram as TorchExportedProgram
27from torch.utils import _pytree as pytree
28
29
30# Tests for serializing to json and back
class TestSerde(unittest.TestCase):
    """Round-trip tests for exported-program serialization.

    Each test builds a small ``nn.Module`` and exports it. The resulting
    program is pushed through ``serialize``/``deserialize`` (and, in
    ``check_serde``, also through ``exir.save``/``exir.load``). The test then
    checks that the reloaded program computes the same outputs as the
    original, or, for the metadata test, carries the same node metadata.
    """

    def check_ep(
        self,
        ep1: TorchExportedProgram,
        ep2: TorchExportedProgram,
        inputs: Tuple[exir.Value, ...],
    ) -> None:
        """
        Checks that two exported programs produce matching outputs on ``inputs``.

        Both programs are run through ``.module()``. Their outputs are
        flattened with pytree and compared leaf by leaf: tensor leaves must be
        ``torch.allclose``, and any other leaf must compare equal. This is a
        behavioral check on the given inputs, not a structural comparison of
        the two graphs.
        """
        orig_outputs = ep1.module()(*inputs)
        loaded_outputs = ep2.module()(*inputs)

        flat_orig_outputs, _ = pytree.tree_flatten(orig_outputs)
        flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs)

        # Report a count mismatch as a test failure, instead of the ValueError
        # that zip(strict=True) below would otherwise raise.
        self.assertEqual(
            len(flat_orig_outputs),
            len(flat_loaded_outputs),
            "original and reloaded programs returned a different number of outputs",
        )
        for idx, (orig, loaded) in enumerate(
            zip(flat_orig_outputs, flat_loaded_outputs, strict=True)
        ):
            msg = f"output {idx} differs between original and reloaded program"
            if isinstance(orig, torch.Tensor):
                self.assertIsInstance(loaded, torch.Tensor, msg)
                self.assertTrue(torch.allclose(orig, loaded), msg)
            else:
                # Non-tensor leaves (e.g. Python scalars) are compared exactly.
                self.assertEqual(orig, loaded, msg)

    def _check_save_load_roundtrip(
        self,
        ep: TorchExportedProgram,
        inputs: Tuple[exir.Value, ...],
    ) -> None:
        """Saves ``ep`` to an in-memory buffer with ``exir.save``, reloads it
        with ``exir.load`` and checks the reloaded program against ``ep``."""
        buffer = io.BytesIO()
        exir.save(ep, buffer)
        # Rewind so exir.load reads the buffer from the beginning.
        buffer.seek(0)
        loaded_ep = exir.load(buffer)
        self.check_ep(ep, loaded_ep, inputs)

    # pyre-ignore
    def check_serde(self, m, inputs, check_executorch=True) -> None:
        """
        Round-trips ``m`` through serialization at each lowering stage.

        Each stage is compared against its own un-serialized program:
          1. ATen: ``export(m, inputs)`` through ``serialize``/``deserialize``.
          2. Edge: the ``to_edge`` program through ``serialize``/``deserialize``
             and through ``exir.save``/``exir.load``.
          3. ExecuTorch: the ``to_executorch`` program through both paths,
             compared under ``torch.no_grad()``.

        Args:
            m: the module to export.
            inputs: example inputs, used both for ``export`` and for comparing
                outputs.
            check_executorch: when False, the ExecuTorch-stage program is still
                serialized and deserialized, but no outputs are compared and
                ``exir.save``/``exir.load`` is not exercised for that stage.
        """
        aten_ep = export(m, inputs)
        self.check_ep(aten_ep, deserialize(serialize(aten_ep)), inputs)

        edge = to_edge(aten_ep)
        edge_ep = edge.exported_program()
        self.check_ep(edge_ep, deserialize(serialize(edge_ep)), inputs)
        self._check_save_load_roundtrip(edge_ep, inputs)

        # Named et_program (not `executorch`) to avoid shadowing the
        # top-level `executorch` package this file imports from.
        et_program = edge.to_executorch().exported_program()
        # Serialization is exercised even when output comparison is disabled.
        et_program_new = deserialize(serialize(et_program))
        if not check_executorch:
            return
        with torch.no_grad():
            self.check_ep(et_program, et_program_new, inputs)
            self._check_save_load_roundtrip(et_program, inputs)

    def test_basic(self) -> None:
        class MyModule(torch.nn.Module):
            def forward(self, x):
                x = x + x
                x = x * x
                x = x / x
                return x, x.clone()

        inputs = (torch.ones([512], requires_grad=True),)
        self.check_serde(MyModule(), inputs)

    def test_to_out_variant_singleon_tensor_list(self) -> None:
        class MyModel(torch.nn.Module):
            def forward(self, x):
                return torch.split(x, 10)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        # check_executorch is False because this hits an edge case: calling
        # .module() on the ExecuTorch exported program runs an unlift pass
        # followed by dead code elimination, which removes the split_copy op.
        self.check_serde(model, inputs, check_executorch=False)

    def test_to_out_variant_multiple_out(self) -> None:
        class MyModel(torch.nn.Module):
            def forward(self, x):
                values, indices = torch.topk(x, 5)
                return (values, indices)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        self.check_serde(model, inputs)

    def test_delegate(self) -> None:
        class SinModule(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            BackendWithCompilerDemo.__name__, edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        # Eager sanity run of the module that wraps the lowered submodule.
        composite_model(*model_inputs)

        edge = to_edge(export(composite_model, model_inputs))
        edge_ep = edge.exported_program()
        self.check_ep(edge_ep, deserialize(serialize(edge_ep)), model_inputs)

    def test_model_with_weights(self) -> None:
        class LinearAdd(nn.Module):
            def __init__(self, M: int, N: int):
                super().__init__()
                self.M = M
                self.N = N
                self.linear = torch.nn.Linear(M, N)

            def forward(self, x, y):
                x = self.linear(x)
                y = self.linear(y)
                return torch.add(x, y)

            @classmethod
            def _get_random_inputs(cls):
                return (torch.rand(128, 20), torch.rand(128, 20))

        linear_add = LinearAdd(20, 30)
        model_inputs = LinearAdd._get_random_inputs()

        self.check_serde(linear_add, model_inputs)

    def test_delegate_partitioner(self) -> None:
        class Model(torch.nn.Module):
            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

        ep = to_edge(export(m, inputs))
        edge = ep.to_backend(AddMulPartitionerDemo())
        edge_ep = edge.exported_program()
        self.check_ep(edge_ep, deserialize(serialize(edge_ep)), inputs)

    def test_meta_stack_trace_module_hierarchy(self) -> None:
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_layer = nn.Conv2d(
                    in_channels=1, out_channels=64, kernel_size=3, padding=1
                )

            def forward(self, x):
                return self.conv_layer(x)

        def conv_metadata(ep: TorchExportedProgram) -> Tuple[object, ...]:
            """Returns (stack_trace, nn_module_stack) of the last node whose
            target mentions "convolution", or () if there is no such node."""
            found: Tuple[object, ...] = ()
            for node in ep.graph_module.graph.nodes:
                if "convolution" in str(node.target):
                    found = (
                        node.meta.get("stack_trace"),
                        node.meta.get("nn_module_stack"),
                    )
            return found

        m = Model()
        inputs = (torch.randn(1, 1, 32, 32),)

        edge = to_edge(export(m, inputs))
        metadata = conv_metadata(edge.exported_program())
        metadata_serde = conv_metadata(deserialize(serialize(edge.exported_program())))

        self.assertNotEqual(len(metadata), 0, "no convolution node before serde")
        self.assertNotEqual(len(metadata_serde), 0, "no convolution node after serde")
        self.assertTrue(
            all(val is not None for val in metadata),
            "missing stack_trace/nn_module_stack before serde",
        )
        self.assertTrue(
            all(val is not None for val in metadata_serde),
            "missing stack_trace/nn_module_stack after serde",
        )
        self.assertEqual(metadata[0], metadata_serde[0])
        self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys()))
243