# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import io
import unittest
from typing import Tuple

import executorch.exir as exir

import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import CompileSpec, to_backend
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)

from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.serde.serialize import deserialize, serialize
from torch import nn
from torch.export import export
from torch.export.exported_program import ExportedProgram as TorchExportedProgram
from torch.utils import _pytree as pytree


# Tests for serializing to json and back
class TestSerde(unittest.TestCase):
    def check_ep(
        self,
        ep1: TorchExportedProgram,
        ep2: TorchExportedProgram,
        inputs: Tuple[exir.Value, ...],
    ) -> None:
        """
        Checks if two graphs are equivalent.

        Both programs are converted back to callable modules with
        ``.module()``, run on the same ``inputs``, and every leaf of their
        flattened outputs is compared with ``torch.allclose``.
        ``zip(..., strict=True)`` makes the check fail if the programs return
        a different number of output leaves.
        """
        orig_outputs = ep1.module()(*inputs)
        loaded_outputs = ep2.module()(*inputs)

        flat_orig_outputs, _ = pytree.tree_flatten(orig_outputs)
        flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs)

        for idx, (orig, loaded) in enumerate(
            zip(flat_orig_outputs, flat_loaded_outputs, strict=True)
        ):
            # Report which leaf diverged; a bare assertTrue only says "False".
            self.assertTrue(
                torch.allclose(orig, loaded),
                f"Output {idx} differs between the original and round-tripped program",
            )

    # pyre-ignore
    def check_serde(self, m, inputs, check_executorch=True) -> None:
        """
        Exports ``m`` and checks that every serialization round trip preserves outputs.

        Stages covered:
          * ATen (``torch.export``): serialize/deserialize.
          * Edge (``to_edge``): serialize/deserialize and exir.save/exir.load.
          * ExecuTorch (``to_executorch``): always serialized/deserialized.
            Execution-based checks (serde and exir.save/exir.load) run only
            when ``check_executorch`` is True.

        Args:
            m: module to export.
            inputs: example inputs, used for export and for comparing outputs.
            check_executorch: set to False when calling ``.module()`` on the
                ExecuTorch program is known to misbehave (see
                ``test_to_out_variant_singleon_tensor_list``).
        """
        aten = export(m, inputs)
        aten_new = deserialize(serialize(aten))
        self.check_ep(aten, aten_new, inputs)

        edge = to_edge(aten)
        edge_new = deserialize(serialize(edge.exported_program()))
        self.check_ep(edge.exported_program(), edge_new, inputs)

        buffer = io.BytesIO()
        exir.save(edge.exported_program(), buffer)
        buffer.seek(0)
        loaded_ep = exir.load(buffer)
        self.check_ep(edge.exported_program(), loaded_ep, inputs)

        executorch = edge.to_executorch().exported_program()
        # Serialization itself is exercised unconditionally; only running the
        # program (which goes through .module()) is gated below.
        executorch_new = deserialize(serialize(executorch))
        if check_executorch:
            with torch.no_grad():
                self.check_ep(executorch, executorch_new, inputs)

                buffer = io.BytesIO()
                exir.save(executorch, buffer)
                buffer.seek(0)
                loaded_ep = exir.load(buffer)
                self.check_ep(executorch, loaded_ep, inputs)

    def test_basic(self) -> None:
        """Round-trips a small elementwise graph with multiple outputs."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x * x
                x = x / x
                return x, x.clone()

        inputs = (torch.ones([512], requires_grad=True),)
        self.check_serde(MyModule(), inputs)

    def test_to_out_variant_singleon_tensor_list(self) -> None:
        """Round-trips an op (split) whose output is a single-element tensor list."""

        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.split(x, 10)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        # We set check_executorch to False for this test because it triggers
        # an edge case: calling .module() on the ExecuTorch exported program
        # runs an unlift pass followed by dead code elimination, which removes
        # the split_copy op.
        self.check_serde(model, inputs, check_executorch=False)

    def test_to_out_variant_multiple_out(self) -> None:
        """Round-trips an op (topk) with multiple outputs."""

        class MyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                values, indices = torch.topk(x, 5)
                return (values, indices)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        self.check_serde(model, inputs)

    def test_delegate(self) -> None:
        """Round-trips an Edge program that calls a lowered (delegated) submodule."""

        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            BackendWithCompilerDemo.__name__, edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()

        # Eager smoke run: the composite module must be callable before export.
        composite_model(*model_inputs)

        edge = to_edge(export(composite_model, model_inputs))
        edge_new = deserialize(serialize(edge.exported_program()))
        self.check_ep(edge.exported_program(), edge_new, model_inputs)

    def test_model_with_weights(self) -> None:
        """Round-trips a model with parameters (a shared Linear layer)."""

        class LinearAdd(nn.Module):
            def __init__(self, M: int, N: int):
                super().__init__()
                self.M = M
                self.N = N
                self.linear = torch.nn.Linear(M, N)

            def forward(self, x, y):
                x = self.linear(x)
                y = self.linear(y)
                return torch.add(x, y)

            @classmethod
            def _get_random_inputs(cls):
                return (torch.rand(128, 20), torch.rand(128, 20))

        linear_add = LinearAdd(20, 30)
        model_inputs = LinearAdd._get_random_inputs()

        self.check_serde(linear_add, model_inputs)

    def test_delegate_partitioner(self) -> None:
        """Round-trips an Edge program partitioned by AddMulPartitionerDemo."""

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

        edge_manager = to_edge(export(m, inputs))
        edge = edge_manager.to_backend(AddMulPartitionerDemo())
        edge_new = deserialize(serialize(edge.exported_program()))
        self.check_ep(edge.exported_program(), edge_new, inputs)

    def test_meta_stack_trace_module_hierarchy(self) -> None:
        """Checks that stack_trace and nn_module_stack node metadata survive serde."""

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_layer = nn.Conv2d(
                    in_channels=1, out_channels=64, kernel_size=3, padding=1
                )

            def forward(self, x):
                return self.conv_layer(x)

        def conv_metadata(graph_module: torch.fx.GraphModule) -> Tuple[object, ...]:
            # Returns (stack_trace, nn_module_stack) of the last convolution
            # node in the graph, or () if no convolution node is present.
            metadata: Tuple[object, ...] = ()
            for node in graph_module.graph.nodes:
                if "convolution" in str(node.target):
                    metadata = (
                        node.meta.get("stack_trace"),
                        node.meta.get("nn_module_stack"),
                    )
            return metadata

        m = Model()
        inputs = (torch.randn(1, 1, 32, 32),)

        edge = to_edge(export(m, inputs))
        metadata = conv_metadata(edge.exported_program().graph_module)

        edge_new = deserialize(serialize(edge.exported_program()))
        metadata_serde = conv_metadata(edge_new.graph_module)

        self.assertTrue(metadata, "no convolution node in the original program")
        self.assertTrue(
            metadata_serde, "no convolution node in the deserialized program"
        )
        for val in (*metadata, *metadata_serde):
            self.assertIsNotNone(val)
        self.assertEqual(metadata[0], metadata_serde[0])
        self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys()))