# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import io
import unittest
from typing import Tuple

import executorch.exir as exir

import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import CompileSpec, to_backend
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)

from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.serde.serialize import deserialize, serialize
from torch import nn
from torch.export import export
from torch.export.exported_program import ExportedProgram as TorchExportedProgram
from torch.utils import _pytree as pytree


# Tests for serializing to json and back
class TestSerde(unittest.TestCase):
    def check_ep(
        self,
        ep1: TorchExportedProgram,
        ep2: TorchExportedProgram,
        inputs: Tuple[exir.Value, ...],
    ) -> None:
        """
        Checks that two exported programs produce equivalent outputs.

        Both programs are run on ``inputs`` and their pytree-flattened outputs
        are compared pairwise. Shapes are compared explicitly because
        ``torch.allclose`` broadcasts its arguments, so e.g. a ``[1]`` tensor
        would otherwise be accepted as "close" to a ``[512]`` tensor holding
        the same value.
        """
        orig_outputs = ep1.module()(*inputs)
        loaded_outputs = ep2.module()(*inputs)

        flat_orig_outputs, _ = pytree.tree_flatten(orig_outputs)
        flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs)

        self.assertEqual(
            len(flat_orig_outputs),
            len(flat_loaded_outputs),
            "round-tripped program returns a different number of outputs",
        )
        for i, (orig, loaded) in enumerate(
            zip(flat_orig_outputs, flat_loaded_outputs, strict=True)
        ):
            self.assertEqual(orig.shape, loaded.shape, f"output {i}: shape mismatch")
            self.assertTrue(torch.allclose(orig, loaded), f"output {i}: values differ")

    @staticmethod
    def _save_and_load(ep: TorchExportedProgram) -> TorchExportedProgram:
        """Saves ``ep`` with ``exir.save`` into an in-memory buffer and loads it back."""
        buffer = io.BytesIO()
        exir.save(ep, buffer)
        buffer.seek(0)
        return exir.load(buffer)

    # pyre-ignore
    def check_serde(self, m, inputs, check_executorch=True) -> None:
        """
        Round-trips ``m`` through serde at the aten, edge and executorch stages.

        Each stage's program is serialized and deserialized, and the result is
        compared against the original with ``check_ep``. The edge and
        executorch stages are additionally round-tripped via
        ``exir.save``/``exir.load``.

        Args:
            m: eager module to export.
            inputs: example inputs used both for export and for comparison.
            check_executorch: when False, the executorch-stage program is still
                serialized and deserialized, but no outputs are compared for it
                (comparison calls ``.module()`` on that program).
        """
        aten = export(m, inputs)
        aten_new = deserialize(serialize(aten))
        self.check_ep(aten, aten_new, inputs)

        edge = to_edge(aten)
        edge_ep = edge.exported_program()
        edge_new = deserialize(serialize(edge_ep))
        self.check_ep(edge_ep, edge_new, inputs)
        self.check_ep(edge_ep, self._save_and_load(edge_ep), inputs)

        executorch = edge.to_executorch().exported_program()
        # Always exercise serialization, even when outputs are not compared.
        executorch_new = deserialize(serialize(executorch))
        if check_executorch:
            # Run under no_grad: inputs may require grad (see test_basic).
            with torch.no_grad():
                self.check_ep(executorch, executorch_new, inputs)
                self.check_ep(executorch, self._save_and_load(executorch), inputs)

    def test_basic(self) -> None:
        class MyModule(torch.nn.Module):
            def forward(self, x):
                x = x + x
                x = x * x
                x = x / x
                return x, x.clone()

        inputs = (torch.ones([512], requires_grad=True),)
        self.check_serde(MyModule(), inputs)

    # NOTE: the misspelled name ("singleon") is kept so existing test IDs stay
    # stable.
    def test_to_out_variant_singleon_tensor_list(self) -> None:
        class MyModel(torch.nn.Module):
            def forward(self, x):
                return torch.split(x, 10)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        # We set check_executorch to false for this test because this triggers
        # an edge case where calling .module() on the executorch exported program
        # will cause an unlift pass to be run on the graph and dead code elimination
        # will be subsequently run, which essentially causes the split_copy op to be
        # removed.
        self.check_serde(model, inputs, check_executorch=False)

    def test_to_out_variant_multiple_out(self) -> None:
        class MyModel(torch.nn.Module):
            def forward(self, x):
                values, indices = torch.topk(x, 5)
                return (values, indices)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        self.check_serde(model, inputs)

    def test_delegate(self) -> None:
        class SinModule(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            BackendWithCompilerDemo.__name__, edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        # Sanity check: the delegated module must run eagerly before export.
        composite_model(*model_inputs)

        edge = to_edge(export(composite_model, model_inputs))
        edge_new = deserialize(serialize(edge.exported_program()))
        self.check_ep(edge.exported_program(), edge_new, model_inputs)

    def test_model_with_weights(self) -> None:
        class LinearAdd(nn.Module):
            def __init__(self, M: int, N: int):
                super().__init__()
                self.M = M
                self.N = N
                self.linear = torch.nn.Linear(M, N)

            def forward(self, x, y):
                x = self.linear(x)
                y = self.linear(y)
                return torch.add(x, y)

            @classmethod
            def _get_random_inputs(cls):
                return (torch.rand(128, 20), torch.rand(128, 20))

        linear_add = LinearAdd(20, 30)
        model_inputs = LinearAdd._get_random_inputs()

        self.check_serde(linear_add, model_inputs)

    def test_delegate_partitioner(self) -> None:
        class Model(torch.nn.Module):
            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

        edge_program = to_edge(export(m, inputs))
        edge = edge_program.to_backend(AddMulPartitionerDemo())
        edge_new = deserialize(serialize(edge.exported_program()))
        self.check_ep(edge.exported_program(), edge_new, inputs)

    def test_meta_stack_trace_module_hierarchy(self) -> None:
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_layer = nn.Conv2d(
                    in_channels=1, out_channels=64, kernel_size=3, padding=1
                )

            def forward(self, x):
                return self.conv_layer(x)

        def conv_metadata(ep):
            """Returns (stack_trace, nn_module_stack) of the last convolution
            node in ``ep``'s graph, or ``()`` if there is none."""
            found = ()
            for node in ep.graph_module.graph.nodes:
                if "convolution" in str(node.target):
                    found = (
                        node.meta.get("stack_trace"),
                        node.meta.get("nn_module_stack"),
                    )
            return found

        m = Model()
        inputs = (torch.randn(1, 1, 32, 32),)

        edge = to_edge(export(m, inputs))
        # Read the original metadata before serializing, as before.
        metadata = conv_metadata(edge.exported_program())

        edge_new = deserialize(serialize(edge.exported_program()))
        metadata_serde = conv_metadata(edge_new)

        self.assertTrue(metadata, "no convolution node found in edge program")
        self.assertTrue(metadata_serde, "no convolution node found after serde")
        for val in (*metadata, *metadata_serde):
            self.assertIsNotNone(val)
        self.assertEqual(metadata[0], metadata_serde[0])
        self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys()))