xref: /aosp_15_r20/external/executorch/exir/tests/test_serde.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport io
10*523fa7a6SAndroid Build Coastguard Workerimport unittest
11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Tuple
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir as exir
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Workerimport torch
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import to_edge
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.backend_api import CompileSpec, to_backend
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.test.backend_with_compiler_demo import (
19*523fa7a6SAndroid Build Coastguard Worker    BackendWithCompilerDemo,
20*523fa7a6SAndroid Build Coastguard Worker)
21*523fa7a6SAndroid Build Coastguard Worker
22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
23*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.serde.serialize import deserialize, serialize
24*523fa7a6SAndroid Build Coastguard Workerfrom torch import nn
25*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import export
26*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ExportedProgram as TorchExportedProgram
27*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
28*523fa7a6SAndroid Build Coastguard Worker
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Worker# Tests for serializing to json and back
class TestSerde(unittest.TestCase):
    """Round-trip tests for exported-program serialization.

    Each test exports a small module, pushes the resulting program through
    ``serialize``/``deserialize`` (and, in ``check_serde``, through
    ``exir.save``/``exir.load``), and checks that the reloaded program still
    yields the same outputs or node metadata as the original.
    """

    def check_ep(
        self,
        ep1: TorchExportedProgram,
        ep2: TorchExportedProgram,
        inputs: Tuple[exir.Value, ...],
    ) -> None:
        """
        Checks that two exported programs produce matching outputs.

        Both programs are run on ``inputs``; their outputs are flattened with
        pytree and compared pairwise with ``torch.allclose``. Only the output
        values are compared, not the graphs themselves.

        Args:
            ep1: Reference program.
            ep2: Program expected to behave like ``ep1``.
            inputs: Positional arguments passed to both programs.

        Raises:
            ValueError: If the programs return a different number of flattened
                outputs (raised by ``zip(..., strict=True)``).
        """
        orig_outputs = ep1.module()(*inputs)
        loaded_outputs = ep2.module()(*inputs)

        flat_orig_outputs, _ = pytree.tree_flatten(orig_outputs)
        flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs)

        pairs = zip(flat_orig_outputs, flat_loaded_outputs, strict=True)
        for index, (orig, loaded) in enumerate(pairs):
            # Name the offending output so a failure is actionable instead of
            # the bare "False is not true".
            self.assertTrue(
                torch.allclose(orig, loaded),
                f"flattened output {index} differs between the two programs",
            )

    def _check_save_load(
        self,
        ep: TorchExportedProgram,
        inputs: Tuple[exir.Value, ...],
    ) -> None:
        """
        Saves ``ep`` to an in-memory buffer with ``exir.save``, reloads it with
        ``exir.load`` and checks the reloaded program against ``ep``.
        """
        buffer = io.BytesIO()
        exir.save(ep, buffer)
        # Rewind so that the load starts reading from the first byte written.
        buffer.seek(0)
        self.check_ep(ep, exir.load(buffer), inputs)

    # pyre-ignore
    def check_serde(self, m, inputs, check_executorch=True) -> None:
        """
        Exports ``m`` and checks the serde round trip at each lowering stage.

        The exported (aten) program and the edge program are always checked.
        The executorch program is always serialized and deserialized, but its
        outputs are only compared when ``check_executorch`` is true.

        Args:
            m: Module to export.
            inputs: Example inputs, used both for export and for comparing
                outputs.
            check_executorch: Whether to compare outputs of the executorch
                program and its reloaded copies.
        """
        aten = export(m, inputs)
        self.check_ep(aten, deserialize(serialize(aten)), inputs)

        edge = to_edge(aten)
        edge_ep = edge.exported_program()
        self.check_ep(edge_ep, deserialize(serialize(edge_ep)), inputs)
        self._check_save_load(edge_ep, inputs)

        executorch = edge.to_executorch().exported_program()
        # Done unconditionally so that the round trip itself is still
        # exercised when the output comparison is skipped.
        executorch_new = deserialize(serialize(executorch))
        if check_executorch:
            with torch.no_grad():
                self.check_ep(executorch, executorch_new, inputs)
                self._check_save_load(executorch, inputs)

    # pyre-ignore
    def _last_convolution_metadata(self, graph_module):
        """
        Returns ``(stack_trace, nn_module_stack)`` taken from the meta of the
        last node in ``graph_module`` whose target mentions "convolution", or
        an empty tuple when there is no such node.
        """
        metadata = ()
        for node in graph_module.graph.nodes:
            if "convolution" in str(node.target):
                metadata = (
                    node.meta.get("stack_trace"),
                    node.meta.get("nn_module_stack"),
                )
        return metadata

    def test_basic(self) -> None:
        """Round-trips a parameter-free module built from elementwise ops."""

        class MyModule(torch.nn.Module):
            def forward(self, x):
                x = x + x
                x = x * x
                x = x / x
                return x, x.clone()

        inputs = (torch.ones([512], requires_grad=True),)
        self.check_serde(MyModule(), inputs)

    def test_to_out_variant_singleon_tensor_list(self) -> None:
        """Round-trips a module whose output is the result of ``torch.split``."""

        class MyModel(torch.nn.Module):
            def forward(self, x):
                return torch.split(x, 10)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        # We set check_executorch to false for this test because this triggers
        # an edge case where calling .module() on the executorch exported program
        # will cause an unlift pass to be run on the graph and dead code elimination
        # will be subsequently run, which essentially causes the split_copy op to be
        # removed.
        self.check_serde(model, inputs, check_executorch=False)

    def test_to_out_variant_multiple_out(self) -> None:
        """Round-trips a module returning both outputs of ``torch.topk``."""

        class MyModel(torch.nn.Module):
            def forward(self, x):
                values, indices = torch.topk(x, 5)
                return (values, indices)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        self.check_serde(model, inputs)

    def test_delegate(self) -> None:
        """Round-trips an edge program that wraps a module lowered via ``to_backend``."""

        class SinModule(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            BackendWithCompilerDemo.__name__, edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()

        # Run the composite module once eagerly before exporting it.
        composite_model(*model_inputs)

        edge = to_edge(export(composite_model, model_inputs))
        edge_ep = edge.exported_program()
        edge_new = deserialize(serialize(edge_ep))
        self.check_ep(edge_ep, edge_new, model_inputs)

    def test_model_with_weights(self) -> None:
        """Round-trips a module that owns parameters (an ``nn.Linear`` layer)."""

        class LinearAdd(nn.Module):
            def __init__(self, M: int, N: int):
                super().__init__()
                self.M = M
                self.N = N
                self.linear = torch.nn.Linear(M, N)

            def forward(self, x, y):
                x = self.linear(x)
                y = self.linear(y)
                return torch.add(x, y)

            @classmethod
            def _get_random_inputs(cls):
                return (torch.rand(128, 20), torch.rand(128, 20))

        linear_add = LinearAdd(20, 30)
        model_inputs = LinearAdd._get_random_inputs()

        self.check_serde(linear_add, model_inputs)

    def test_delegate_partitioner(self) -> None:
        """Round-trips an edge program lowered with ``AddMulPartitionerDemo``."""

        class Model(torch.nn.Module):
            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

        ep = to_edge(export(m, inputs))
        edge = ep.to_backend(AddMulPartitionerDemo())
        edge_ep = edge.exported_program()
        edge_new = deserialize(serialize(edge_ep))
        self.check_ep(edge_ep, edge_new, inputs)

    def test_meta_stack_trace_module_hierarchy(self) -> None:
        """
        Checks that ``stack_trace`` and ``nn_module_stack`` node metadata of a
        convolution node are still present after a serde round trip.
        """

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_layer = nn.Conv2d(
                    in_channels=1, out_channels=64, kernel_size=3, padding=1
                )

            def forward(self, x):
                return self.conv_layer(x)

        m = Model()
        inputs = (torch.randn(1, 1, 32, 32),)

        edge = to_edge(export(m, inputs))
        edge_ep = edge.exported_program()
        metadata = self._last_convolution_metadata(edge_ep.graph_module)

        edge_new = deserialize(serialize(edge_ep))
        metadata_serde = self._last_convolution_metadata(edge_new.graph_module)

        # An empty tuple means that no convolution node was found at all.
        self.assertTrue(metadata, "no convolution node in the original graph")
        self.assertTrue(
            metadata_serde, "no convolution node in the deserialized graph"
        )

        stack_trace, module_stack = metadata
        stack_trace_serde, module_stack_serde = metadata_serde
        self.assertIsNotNone(stack_trace)
        self.assertIsNotNone(module_stack)
        self.assertIsNotNone(stack_trace_serde)
        self.assertIsNotNone(module_stack_serde)

        self.assertEqual(stack_trace, stack_trace_serde)
        # Only the keys of nn_module_stack are compared, not the values.
        self.assertEqual(list(module_stack.keys()), list(module_stack_serde.keys()))
243