1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Workerimport unittest 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.tests.models as models 10*523fa7a6SAndroid Build Coastguard Worker 11*523fa7a6SAndroid Build Coastguard Workerimport torch 12*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import EdgeCompileConfig, to_edge 13*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops 14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.lowered_backend_module import ( 15*523fa7a6SAndroid Build Coastguard Worker create_submodule_from_nodes, 16*523fa7a6SAndroid Build Coastguard Worker LoweredBackendModule, 17*523fa7a6SAndroid Build Coastguard Worker) 18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import ( 19*523fa7a6SAndroid Build Coastguard Worker BackendDelegate, 20*523fa7a6SAndroid Build Coastguard Worker BackendDelegateDataReference, 21*523fa7a6SAndroid Build Coastguard Worker DataLocation, 22*523fa7a6SAndroid Build Coastguard Worker DelegateCall, 23*523fa7a6SAndroid Build Coastguard Worker) 24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tests.common import register_additional_test_aten_ops 25*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import export 26*523fa7a6SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 27*523fa7a6SAndroid Build Coastguard Worker 28*523fa7a6SAndroid Build Coastguard Worker 29*523fa7a6SAndroid Build 
class WrapperModule(torch.nn.Module):
    """Wrap a plain callable in an ``nn.Module`` so it can be passed to ``export``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        # Pass every positional and keyword argument straight through to ``fn``.
        return self.fn(*args, **kwargs)


class TestDelegate(unittest.TestCase):
    """Tests for LoweredBackendModule, executorch_call_delegate and
    create_submodule_from_nodes."""

    @classmethod
    def setUpClass(cls) -> None:
        register_additional_test_aten_ops()

    def test_call_delegate(self) -> None:
        def g(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        inputs = (torch.ones(1, 3), torch.ones(1, 3))
        edge_ir_m = to_edge(export(WrapperModule(g), inputs))
        lowered_module: LoweredBackendModule = LoweredBackendModule(
            edge_ir_m.exported_program(), "BackendWithCompilerDemo", b"moo", []
        )

        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.ops.higher_order.executorch_call_delegate(lowered_module, x, y)

        orig_res = f(*inputs)
        gm = export(
            WrapperModule(f),
            inputs,
        )
        # The exported code should reference the lowered module (named
        # ``lowered_module_0``) and invoke it through executorch_call_delegate.
        FileCheck().check("lowered_module_0").check(
            "torch.ops.higher_order.executorch_call_delegate"
        ).run(gm.graph_module.code)
        self.assertTrue(torch.allclose(orig_res, gm.module()(*inputs)))

    def test_to_backend(self) -> None:
        """Check if we have patched a lowered module correctly (for delegation)"""

        m = models.CompositeDelegateModule()

        exec_prog = to_edge(
            export(m, m.get_random_inputs()),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        ).to_executorch()  # TODO(larryliu): fix split_copy.Tensor
        graph_module = exec_prog.exported_program().graph_module
        program = exec_prog._emitter_output.program

        # Check that there exists a call_delegate, representing the call to the
        # delegated function
        FileCheck().check("lowered_module_0").check(
            "torch.ops.higher_order.executorch_call_delegate"
        ).run(graph_module.code)

        # Check that there does not exist an add node (from the non-delegated
        # BasicModuleAdd.forward function)
        self.assertNotIn(
            exir_ops.edge.aten.add.default,
            {node.target for node in graph_module.graph.nodes},
        )

        delegate_nodes = [
            node
            for node in graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.higher_order.executorch_call_delegate
        ]
        # Guard against the per-node checks below being skipped silently.
        self.assertGreater(len(delegate_nodes), 0)
        for node in delegate_nodes:
            # Check that the first argument is the lowered backend module
            # (which we got from a getattr)
            self.assertEqual(node.args[0].op, "get_attr")
            get_attr_backend = getattr(graph_module, node.args[0].target)
            self.assertEqual(
                get_attr_backend._backend_id, m.lowered_module._backend_id
            )
            self.assertEqual(
                get_attr_backend._processed_bytes, m.lowered_module._processed_bytes
            )
            self.assertEqual(
                get_attr_backend._compile_specs, m.lowered_module._compile_specs
            )

        # Check the BackendDelegate object itself
        delegate: BackendDelegate = program.execution_plan[0].delegates[0]
        self.assertEqual(delegate.id, "backend_demo")
        processed: BackendDelegateDataReference = delegate.processed
        self.assertEqual(processed.location, DataLocation.INLINE)
        self.assertLess(processed.index, len(program.backend_delegate_data))
        self.assertEqual(
            program.backend_delegate_data[processed.index].data, b"basic_module_add"
        )

        # Check the delegate instruction
        self.assertIsInstance(
            program.execution_plan[0].chains[0].instructions[0].instr_args,
            DelegateCall,
        )

    def test_cannot_assign_attr(self) -> None:
        deleg = LoweredBackendModule(None, "", b"", [])  # pyre-ignore
        with self.assertRaises(AttributeError):
            deleg.backend_id = "123"  # pyre-ignore

    def _check_create_submodule(
        self,
        model: torch.nn.Module,
        inputs: tuple,
        targets: set,
        expected_num_outputs: int,
    ) -> None:
        """Fuse the nodes of ``model`` that call ``targets`` and validate the result.

        Exports ``model`` to edge IR and collects every ``call_function`` node
        whose target is in ``targets``. Those nodes are fused with
        ``create_submodule_from_nodes``. The helper then checks that:

        * the fused submodule has exactly one output node, whose single argument
          is a list of ``expected_num_outputs`` values, and
        * the rewritten program still matches eager ``model`` on ``inputs``.
        """
        orig_res = model(*inputs)
        prog = to_edge(export(model, inputs))
        gm = prog.exported_program().graph_module

        node_list = [
            node
            for node in gm.graph.nodes
            if node.op == "call_function" and node.target in targets
        ]
        # Fail loudly, rather than partitioning nothing, if decompositions
        # stop producing the ops this test expects to fuse.
        self.assertTrue(node_list, f"no call_function nodes matched {targets}")

        sub_gm, _ = create_submodule_from_nodes(gm, node_list, "tag")
        sub_gm.recompile()
        gm.recompile()

        # Require exactly one output node so the arity checks cannot be skipped.
        output_nodes = [node for node in sub_gm.graph.nodes if node.op == "output"]
        self.assertEqual(len(output_nodes), 1)
        output_node = output_nodes[0]
        self.assertEqual(len(output_node.args), 1)
        self.assertIsInstance(output_node.args[0], list)
        self.assertEqual(len(output_node.args[0]), expected_num_outputs)

        new_res = prog.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(new_res, orig_res))

    def test_create_submodule_single_return(self) -> None:
        """
        Original graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            sub_tensor = sub(mul_tensor, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]

        Partitioned graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            return [mul_tensor]  # Output is pytree.flatten-ed

        Final graph:
            partitioned_res = partitioned_graph(x, y)
            getitem_0 = partitioned_res[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]
        """

        class Model(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                x = x * y
                x = x - y
                x = x / y
                return x

        self._check_create_submodule(
            Model(),
            (torch.randn(1, 3), torch.randn(1, 3)),
            {exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.mul.Tensor},
            expected_num_outputs=1,
        )

    def test_create_submodule_multiple_return(self) -> None:
        """
        Original graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            sub_tensor = sub(add_tensor, mul_tensor)
            div_tensor = div(sub_tensor, mul_tensor)
            return [div_tensor]

        Partitioned graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            return [add_tensor, mul_tensor]

        Final graph:
            partitioned_res = partitioned_graph(x, y)
            getitem_0 = partitioned_res[0]
            getitem_1 = partitioned_res[1]
            sub_tensor = sub(getitem_0, getitem_1)
            div_tensor = div(sub_tensor, getitem_1)
            return [div_tensor]
        """

        class Model(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                y = x * y
                x = x - y
                x = x / y
                return x

        self._check_create_submodule(
            Model(),
            (torch.randn(1, 3), torch.randn(1, 3)),
            {exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.mul.Tensor},
            expected_num_outputs=2,
        )

    def test_create_submodule_list_return(self) -> None:
        """
        Original graph:
            split_tensor = split(x, 5)
            getitem_0 = split_tensor[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]

        Partitioned graph:
            split_tensor = split(x, 5)
            getitem_0 = split_tensor[0]
            getitem_1 = split_tensor[1]
            return [getitem_0, getitem_1]  # List output is "opened"

        Final graph:
            partitioned_res = partitioned_graph(x, y)
            getitem_0 = partitioned_res[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]
        """

        class Model(torch.nn.Module):
            def forward(self, x, y):
                x = torch.split(x, 5)
                x = x[0] - y
                x = x / y
                return x

        # TODO(ssjia): split.Tensor now gets decomposed to split_with_sizes. Due to
        # how executorch uses a pinned Pytorch nightly, the CI may not catch the
        # changes to Pytorch's core decomposition table. As a temporary workaround,
        # make the test backwards compatible with the old decomposition table.
        # Remove split_copy.Tensor from this set once Pytorch nightly has been
        # updated.
        targets = {
            exir_ops.edge.aten.split_with_sizes_copy.default,
            exir_ops.edge.aten.split_copy.Tensor,
        }

        self._check_create_submodule(
            Model(),
            (torch.randn(10), torch.randn(5)),
            targets,
            expected_num_outputs=2,
        )