# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import copy
import os
import tempfile
import unittest
from typing import List, Optional, Tuple

import executorch.exir as exir

# Import passes
import executorch.exir.memory_planning  # noqa
import torch
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.emit import emit_program
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import (
    dead_code_elimination_pass,
    DebugPass,
    HintBasedSymShapeEvalPass,
    MemoryPlanningPass,
    propagate_dynamic_shape,
    RemoveNoopPass,
    ReplaceSymSizeOpPass,
    ToOutVarPass,
)
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.debug_handle_generator_pass import (
    DebugHandleGeneratorPass,
    generate_missing_debug_handles,
)
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
    insert_write_back_for_buffers_pass,
)

from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
from executorch.exir.passes.normalize_view_copy_base_pass import (
    NormalizeViewCopyBasePass,
)
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
from executorch.exir.passes.replace_view_copy_with_view_pass import (
    ReplaceViewCopyWithViewPass,
)
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
from executorch.exir.program._program import lift_constant_tensor_pass
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.tensor import TensorSpec
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
from executorch.exir.tests.models import MLP, Mul
from functorch.experimental import control_flow

from torch import nn

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export
from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
from torch.fx import GraphModule, subgraph_rewriter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.library import impl, Library
from torch.testing import FileCheck
from torch.utils import _pytree as pytree


# pyre-ignore
def collect_ops(gm: torch.fx.GraphModule):
    """
    Return the set of targets of every ``call_function`` node reachable from
    ``gm``.

    ``gm.modules()`` yields ``gm`` itself followed by all of its nested
    submodules; only those that are ``torch.fx.GraphModule`` instances carry a
    graph, so everything else is filtered out before the nodes are scanned.
    Duplicate targets collapse because the result is a set.
    """
    # Local names deliberately avoid ``ops``, which is a module-level import.
    graph_modules = (m for m in gm.modules() if isinstance(m, torch.fx.GraphModule))
    return {
        node.target
        for graph_module in graph_modules
        for node in graph_module.graph.nodes
        if node.op == "call_function"
    }


# Test-only custom-op namespace; the schemas below are implemented further
# down in this module via ``@impl``.
lib = Library("DO_NOT_USE_TEST_ONLY", "DEF")

lib.define("foo(Tensor self) -> (Tensor, Tensor)")
lib.define("add_relu(Tensor self, Tensor other) -> Tensor")
Worker@impl(lib, "foo", "CompositeExplicitAutograd") 103*523fa7a6SAndroid Build Coastguard Workerdef foo(a: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 104*523fa7a6SAndroid Build Coastguard Worker return a + 1, None 105*523fa7a6SAndroid Build Coastguard Worker 106*523fa7a6SAndroid Build Coastguard Worker 107*523fa7a6SAndroid Build Coastguard Workerlib.define( 108*523fa7a6SAndroid Build Coastguard Worker "foo.out(Tensor self, *, Tensor(a!) out1, Tensor(b!) out2) -> (Tensor(a!), Tensor(b!))" 109*523fa7a6SAndroid Build Coastguard Worker) 110*523fa7a6SAndroid Build Coastguard Worker 111*523fa7a6SAndroid Build Coastguard Worker 112*523fa7a6SAndroid Build Coastguard Worker@impl(lib, "foo.out", "CompositeExplicitAutograd") 113*523fa7a6SAndroid Build Coastguard Workerdef foo_out( 114*523fa7a6SAndroid Build Coastguard Worker a: torch.Tensor, out1: torch.Tensor, out2: torch.Tensor 115*523fa7a6SAndroid Build Coastguard Worker) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 116*523fa7a6SAndroid Build Coastguard Worker return a + 1, None 117*523fa7a6SAndroid Build Coastguard Worker 118*523fa7a6SAndroid Build Coastguard Worker 119*523fa7a6SAndroid Build Coastguard Workerclass TestPasses(unittest.TestCase): 120*523fa7a6SAndroid Build Coastguard Worker @classmethod 121*523fa7a6SAndroid Build Coastguard Worker def setUpClass(cls) -> None: 122*523fa7a6SAndroid Build Coastguard Worker register_additional_test_aten_ops() 123*523fa7a6SAndroid Build Coastguard Worker 124*523fa7a6SAndroid Build Coastguard Worker def test_remove_mixed_type_operators(self) -> None: 125*523fa7a6SAndroid Build Coastguard Worker class Add(torch.nn.Module): 126*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 127*523fa7a6SAndroid Build Coastguard Worker return (x + y) + x 128*523fa7a6SAndroid Build Coastguard Worker 129*523fa7a6SAndroid Build Coastguard Worker add = Add() 130*523fa7a6SAndroid Build Coastguard Worker 
131*523fa7a6SAndroid Build Coastguard Worker int_tensor = torch.tensor([[1, 2, 3]]) 132*523fa7a6SAndroid Build Coastguard Worker float_tensor = torch.tensor([[1.0, 2.0, 3.0]]) 133*523fa7a6SAndroid Build Coastguard Worker edge_prog = to_edge( 134*523fa7a6SAndroid Build Coastguard Worker export( 135*523fa7a6SAndroid Build Coastguard Worker add, 136*523fa7a6SAndroid Build Coastguard Worker (int_tensor, float_tensor), 137*523fa7a6SAndroid Build Coastguard Worker ) 138*523fa7a6SAndroid Build Coastguard Worker ) 139*523fa7a6SAndroid Build Coastguard Worker 140*523fa7a6SAndroid Build Coastguard Worker new_prog = edge_prog.transform([RemoveMixedTypeOperators()]) 141*523fa7a6SAndroid Build Coastguard Worker new_graph_module = new_prog.exported_program().graph_module 142*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_graph_module) 143*523fa7a6SAndroid Build Coastguard Worker 144*523fa7a6SAndroid Build Coastguard Worker add_count = 0 145*523fa7a6SAndroid Build Coastguard Worker 146*523fa7a6SAndroid Build Coastguard Worker for node in new_graph_module.graph.nodes: 147*523fa7a6SAndroid Build Coastguard Worker if ( 148*523fa7a6SAndroid Build Coastguard Worker node.op == "call_function" 149*523fa7a6SAndroid Build Coastguard Worker and node.target == exir_ops.edge.aten.add.Tensor 150*523fa7a6SAndroid Build Coastguard Worker ): 151*523fa7a6SAndroid Build Coastguard Worker add_count += 1 152*523fa7a6SAndroid Build Coastguard Worker node_args = node.args 153*523fa7a6SAndroid Build Coastguard Worker for arg in node_args: 154*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(arg.meta["val"].dtype, torch.float) 155*523fa7a6SAndroid Build Coastguard Worker 156*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(add_count, 2) 157*523fa7a6SAndroid Build Coastguard Worker 158*523fa7a6SAndroid Build Coastguard Worker double_tensor = torch.tensor([[1.0, 2.0, 3.0]]) 159*523fa7a6SAndroid Build Coastguard Worker double_tensor = double_tensor.to(torch.double) 
160*523fa7a6SAndroid Build Coastguard Worker 161*523fa7a6SAndroid Build Coastguard Worker double_prog = to_edge(export(add, (int_tensor, double_tensor))) 162*523fa7a6SAndroid Build Coastguard Worker 163*523fa7a6SAndroid Build Coastguard Worker double_prog.transform([RemoveMixedTypeOperators()]) 164*523fa7a6SAndroid Build Coastguard Worker new_graph_module_double = double_prog.exported_program().graph_module 165*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_graph_module_double) 166*523fa7a6SAndroid Build Coastguard Worker 167*523fa7a6SAndroid Build Coastguard Worker add_count_double = 0 168*523fa7a6SAndroid Build Coastguard Worker 169*523fa7a6SAndroid Build Coastguard Worker for node in new_graph_module_double.graph.nodes: 170*523fa7a6SAndroid Build Coastguard Worker if ( 171*523fa7a6SAndroid Build Coastguard Worker node.op == "call_function" 172*523fa7a6SAndroid Build Coastguard Worker and node.target == exir_ops.edge.aten.add.Tensor 173*523fa7a6SAndroid Build Coastguard Worker ): 174*523fa7a6SAndroid Build Coastguard Worker add_count_double += 1 175*523fa7a6SAndroid Build Coastguard Worker node_args = node.args 176*523fa7a6SAndroid Build Coastguard Worker for arg in node_args: 177*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(arg.meta["val"].dtype, torch.double) 178*523fa7a6SAndroid Build Coastguard Worker 179*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(add_count_double, 2) 180*523fa7a6SAndroid Build Coastguard Worker 181*523fa7a6SAndroid Build Coastguard Worker class Mult(torch.nn.Module): 182*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 183*523fa7a6SAndroid Build Coastguard Worker return x * y 184*523fa7a6SAndroid Build Coastguard Worker 185*523fa7a6SAndroid Build Coastguard Worker mult = Mult() 186*523fa7a6SAndroid Build Coastguard Worker 187*523fa7a6SAndroid Build Coastguard Worker float_tensor_vert = float_tensor.T 188*523fa7a6SAndroid Build 
Coastguard Worker mult_prog = to_edge( 189*523fa7a6SAndroid Build Coastguard Worker export( 190*523fa7a6SAndroid Build Coastguard Worker mult, 191*523fa7a6SAndroid Build Coastguard Worker (int_tensor, float_tensor_vert), 192*523fa7a6SAndroid Build Coastguard Worker ) 193*523fa7a6SAndroid Build Coastguard Worker ) 194*523fa7a6SAndroid Build Coastguard Worker 195*523fa7a6SAndroid Build Coastguard Worker # graph_module_mult.graph.print_tabular() 196*523fa7a6SAndroid Build Coastguard Worker 197*523fa7a6SAndroid Build Coastguard Worker mult_prog = mult_prog.transform([RemoveMixedTypeOperators()]) 198*523fa7a6SAndroid Build Coastguard Worker new_graph_module_mult = mult_prog.exported_program().graph_module 199*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_graph_module_mult) 200*523fa7a6SAndroid Build Coastguard Worker 201*523fa7a6SAndroid Build Coastguard Worker mult_count = 0 202*523fa7a6SAndroid Build Coastguard Worker 203*523fa7a6SAndroid Build Coastguard Worker for node in new_graph_module_mult.graph.nodes: 204*523fa7a6SAndroid Build Coastguard Worker if ( 205*523fa7a6SAndroid Build Coastguard Worker node.op == "call_function" 206*523fa7a6SAndroid Build Coastguard Worker and node.target == exir_ops.edge.aten.mul.Tensor 207*523fa7a6SAndroid Build Coastguard Worker ): 208*523fa7a6SAndroid Build Coastguard Worker mult_count += 1 209*523fa7a6SAndroid Build Coastguard Worker node_args = node.args 210*523fa7a6SAndroid Build Coastguard Worker for arg in node_args: 211*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(arg.meta["val"].dtype, torch.float) 212*523fa7a6SAndroid Build Coastguard Worker 213*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(mult_count, 1) 214*523fa7a6SAndroid Build Coastguard Worker 215*523fa7a6SAndroid Build Coastguard Worker def test_remove_noop_pass(self) -> None: 216*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 217*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: 
torch.Tensor) -> torch.Tensor: 218*523fa7a6SAndroid Build Coastguard Worker return x.to(dtype=torch.float32) 219*523fa7a6SAndroid Build Coastguard Worker 220*523fa7a6SAndroid Build Coastguard Worker foo = Foo() 221*523fa7a6SAndroid Build Coastguard Worker 222*523fa7a6SAndroid Build Coastguard Worker # Turn off functionalization so that we can get the actual to.dtype op 223*523fa7a6SAndroid Build Coastguard Worker edge_prog = to_edge( 224*523fa7a6SAndroid Build Coastguard Worker export( 225*523fa7a6SAndroid Build Coastguard Worker foo, 226*523fa7a6SAndroid Build Coastguard Worker (torch.ones(1, dtype=torch.float32),), 227*523fa7a6SAndroid Build Coastguard Worker ) 228*523fa7a6SAndroid Build Coastguard Worker ) 229*523fa7a6SAndroid Build Coastguard Worker edge_prog = edge_prog.transform([RemoveNoopPass()]) 230*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(edge_prog.exported_program().graph_module) 231*523fa7a6SAndroid Build Coastguard Worker new_graph_module = edge_prog.exported_program().graph_module 232*523fa7a6SAndroid Build Coastguard Worker for node in new_graph_module.graph.nodes: 233*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_function": 234*523fa7a6SAndroid Build Coastguard Worker self.assertNotEqual(node.target, torch.ops.aten.to.dtype) 235*523fa7a6SAndroid Build Coastguard Worker 236*523fa7a6SAndroid Build Coastguard Worker def test_redundant_slice_copy_removal(self) -> None: 237*523fa7a6SAndroid Build Coastguard Worker class FooWithNoSlice(torch.nn.Module): 238*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 239*523fa7a6SAndroid Build Coastguard Worker return x[:, :, :] 240*523fa7a6SAndroid Build Coastguard Worker 241*523fa7a6SAndroid Build Coastguard Worker foo_with_no_slice = FooWithNoSlice() 242*523fa7a6SAndroid Build Coastguard Worker 243*523fa7a6SAndroid Build Coastguard Worker class FooWithOneSlice(torch.nn.Module): 244*523fa7a6SAndroid Build Coastguard Worker def 
forward(self, x: torch.Tensor) -> torch.Tensor: 245*523fa7a6SAndroid Build Coastguard Worker return x[:1, :, :] 246*523fa7a6SAndroid Build Coastguard Worker 247*523fa7a6SAndroid Build Coastguard Worker foo_with_one_slice = FooWithOneSlice() 248*523fa7a6SAndroid Build Coastguard Worker 249*523fa7a6SAndroid Build Coastguard Worker class FooWithAllSlices(torch.nn.Module): 250*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 251*523fa7a6SAndroid Build Coastguard Worker return x[:1, :2, 2:4] 252*523fa7a6SAndroid Build Coastguard Worker 253*523fa7a6SAndroid Build Coastguard Worker foo_with_all_slices = FooWithAllSlices() 254*523fa7a6SAndroid Build Coastguard Worker 255*523fa7a6SAndroid Build Coastguard Worker # Turn off functionalization so that we can get the actual to.dtype op 256*523fa7a6SAndroid Build Coastguard Worker x = torch.ones((3, 8, 8)) 257*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 258*523fa7a6SAndroid Build Coastguard Worker export( 259*523fa7a6SAndroid Build Coastguard Worker foo_with_no_slice, 260*523fa7a6SAndroid Build Coastguard Worker (x,), 261*523fa7a6SAndroid Build Coastguard Worker ) 262*523fa7a6SAndroid Build Coastguard Worker ) 263*523fa7a6SAndroid Build Coastguard Worker prog = prog.transform([RemoveNoopPass()]) 264*523fa7a6SAndroid Build Coastguard Worker new_graph_module = prog.exported_program().graph_module 265*523fa7a6SAndroid Build Coastguard Worker FileCheck().check_count( 266*523fa7a6SAndroid Build Coastguard Worker "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 0, exactly=True 267*523fa7a6SAndroid Build Coastguard Worker ).run(new_graph_module.code) 268*523fa7a6SAndroid Build Coastguard Worker 269*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 270*523fa7a6SAndroid Build Coastguard Worker export( 271*523fa7a6SAndroid Build Coastguard Worker foo_with_one_slice, 272*523fa7a6SAndroid Build Coastguard Worker (x,), 273*523fa7a6SAndroid Build Coastguard Worker ) 
274*523fa7a6SAndroid Build Coastguard Worker ) 275*523fa7a6SAndroid Build Coastguard Worker prog = prog.transform([RemoveNoopPass()]) 276*523fa7a6SAndroid Build Coastguard Worker new_graph_module = prog.exported_program().graph_module 277*523fa7a6SAndroid Build Coastguard Worker FileCheck().check_count( 278*523fa7a6SAndroid Build Coastguard Worker "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 1, exactly=True 279*523fa7a6SAndroid Build Coastguard Worker ).run(new_graph_module.code) 280*523fa7a6SAndroid Build Coastguard Worker 281*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 282*523fa7a6SAndroid Build Coastguard Worker export( 283*523fa7a6SAndroid Build Coastguard Worker foo_with_all_slices, 284*523fa7a6SAndroid Build Coastguard Worker (x,), 285*523fa7a6SAndroid Build Coastguard Worker ) 286*523fa7a6SAndroid Build Coastguard Worker ) 287*523fa7a6SAndroid Build Coastguard Worker prog = prog.transform([RemoveNoopPass()]) 288*523fa7a6SAndroid Build Coastguard Worker new_graph_module = prog.exported_program().graph_module 289*523fa7a6SAndroid Build Coastguard Worker FileCheck().check_count( 290*523fa7a6SAndroid Build Coastguard Worker "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 3, exactly=True 291*523fa7a6SAndroid Build Coastguard Worker ).run(new_graph_module.code) 292*523fa7a6SAndroid Build Coastguard Worker 293*523fa7a6SAndroid Build Coastguard Worker def test_compile_to_edge(self) -> None: 294*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 295*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 296*523fa7a6SAndroid Build Coastguard Worker return x * 2 297*523fa7a6SAndroid Build Coastguard Worker 298*523fa7a6SAndroid Build Coastguard Worker f = Foo() 299*523fa7a6SAndroid Build Coastguard Worker 300*523fa7a6SAndroid Build Coastguard Worker x = (torch.randn(2, 3),) 301*523fa7a6SAndroid Build Coastguard Worker 302*523fa7a6SAndroid Build Coastguard Worker to_edge( 
303*523fa7a6SAndroid Build Coastguard Worker export( 304*523fa7a6SAndroid Build Coastguard Worker f, 305*523fa7a6SAndroid Build Coastguard Worker x, 306*523fa7a6SAndroid Build Coastguard Worker ) 307*523fa7a6SAndroid Build Coastguard Worker ).exported_program().graph_module 308*523fa7a6SAndroid Build Coastguard Worker # TODO(angelayi): Add a utility function that verifies a model is in 309*523fa7a6SAndroid Build Coastguard Worker # the edge dialect 310*523fa7a6SAndroid Build Coastguard Worker 311*523fa7a6SAndroid Build Coastguard Worker def test_to_out_variant_none_output(self) -> None: 312*523fa7a6SAndroid Build Coastguard Worker class CompositeModel(torch.nn.Module): 313*523fa7a6SAndroid Build Coastguard Worker def __init__(self, _weight): 314*523fa7a6SAndroid Build Coastguard Worker super().__init__() 315*523fa7a6SAndroid Build Coastguard Worker self.weight = _weight 316*523fa7a6SAndroid Build Coastguard Worker self.lstm = torch.nn.LSTM( 317*523fa7a6SAndroid Build Coastguard Worker input_size=32, 318*523fa7a6SAndroid Build Coastguard Worker hidden_size=32, 319*523fa7a6SAndroid Build Coastguard Worker num_layers=1, 320*523fa7a6SAndroid Build Coastguard Worker ) 321*523fa7a6SAndroid Build Coastguard Worker 322*523fa7a6SAndroid Build Coastguard Worker def forward(self, x_raw, h, c): 323*523fa7a6SAndroid Build Coastguard Worker output, (hn, cn) = self.lstm(x_raw, (h, c)) 324*523fa7a6SAndroid Build Coastguard Worker return output 325*523fa7a6SAndroid Build Coastguard Worker 326*523fa7a6SAndroid Build Coastguard Worker # Prepare input and trace it 327*523fa7a6SAndroid Build Coastguard Worker input_x = torch.ones([1, 32]) 328*523fa7a6SAndroid Build Coastguard Worker input_h = torch.ones([1, 32]) 329*523fa7a6SAndroid Build Coastguard Worker input_c = torch.ones([1, 32]) 330*523fa7a6SAndroid Build Coastguard Worker inputs = (input_x, input_h, input_c) 331*523fa7a6SAndroid Build Coastguard Worker 332*523fa7a6SAndroid Build Coastguard Worker composite_m = CompositeModel(3) 
333*523fa7a6SAndroid Build Coastguard Worker 334*523fa7a6SAndroid Build Coastguard Worker edge_prog = to_edge( 335*523fa7a6SAndroid Build Coastguard Worker export( 336*523fa7a6SAndroid Build Coastguard Worker composite_m, 337*523fa7a6SAndroid Build Coastguard Worker inputs, 338*523fa7a6SAndroid Build Coastguard Worker ) 339*523fa7a6SAndroid Build Coastguard Worker # torch._ops.aten.t.default 340*523fa7a6SAndroid Build Coastguard Worker , 341*523fa7a6SAndroid Build Coastguard Worker compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 342*523fa7a6SAndroid Build Coastguard Worker ) 343*523fa7a6SAndroid Build Coastguard Worker 344*523fa7a6SAndroid Build Coastguard Worker new_prog = edge_prog.transform([SpecPropPass()]) 345*523fa7a6SAndroid Build Coastguard Worker 346*523fa7a6SAndroid Build Coastguard Worker new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module) 347*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_gm_res) 348*523fa7a6SAndroid Build Coastguard Worker new_gm = new_gm_res.graph_module 349*523fa7a6SAndroid Build Coastguard Worker for node in new_gm.graph.nodes: 350*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_function" and node.target in [ 351*523fa7a6SAndroid Build Coastguard Worker torch.ops.DO_NOT_USE_TEST_ONLY.foo.out, 352*523fa7a6SAndroid Build Coastguard Worker torch.ops.my_awesome_3rdparty_ns.awesome_op.out, 353*523fa7a6SAndroid Build Coastguard Worker ]: 354*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(len(node.kwargs), 2) 355*523fa7a6SAndroid Build Coastguard Worker out1_node = node.kwargs["out1"] 356*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(out1_node.op, "call_function") 357*523fa7a6SAndroid Build Coastguard Worker self.assertIs(out1_node.target, memory.alloc) 358*523fa7a6SAndroid Build Coastguard Worker self.assertIs(node.kwargs["out2"], None) 359*523fa7a6SAndroid Build Coastguard Worker 360*523fa7a6SAndroid Build Coastguard Worker new_gm_res = 
MemoryPlanningPass()(new_gm) 361*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_gm_res) 362*523fa7a6SAndroid Build Coastguard Worker new_gm = new_gm_res.graph_module 363*523fa7a6SAndroid Build Coastguard Worker new_prog.exported_program().graph_module.graph = new_gm.graph 364*523fa7a6SAndroid Build Coastguard Worker emit_program(new_prog.exported_program()) 365*523fa7a6SAndroid Build Coastguard Worker 366*523fa7a6SAndroid Build Coastguard Worker def test_to_out_variant_singleon_tensor_list(self) -> None: 367*523fa7a6SAndroid Build Coastguard Worker class MyModel(nn.Module): 368*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 369*523fa7a6SAndroid Build Coastguard Worker super().__init__() 370*523fa7a6SAndroid Build Coastguard Worker 371*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 372*523fa7a6SAndroid Build Coastguard Worker return torch.split(x, 10) 373*523fa7a6SAndroid Build Coastguard Worker 374*523fa7a6SAndroid Build Coastguard Worker def get_random_inputs(self): 375*523fa7a6SAndroid Build Coastguard Worker return (torch.randn(10),) 376*523fa7a6SAndroid Build Coastguard Worker 377*523fa7a6SAndroid Build Coastguard Worker model = MyModel() 378*523fa7a6SAndroid Build Coastguard Worker inputs = model.get_random_inputs() 379*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 380*523fa7a6SAndroid Build Coastguard Worker export( 381*523fa7a6SAndroid Build Coastguard Worker model, 382*523fa7a6SAndroid Build Coastguard Worker inputs, 383*523fa7a6SAndroid Build Coastguard Worker ), 384*523fa7a6SAndroid Build Coastguard Worker compile_config=EdgeCompileConfig(_check_ir_validity=False), 385*523fa7a6SAndroid Build Coastguard Worker ) # TODO(larryliu): fix split_copy 386*523fa7a6SAndroid Build Coastguard Worker new_gm_res = ToOutVarPass()(prog.exported_program().graph_module) 387*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_gm_res) 388*523fa7a6SAndroid Build Coastguard Worker new_gm = 
new_gm_res.graph_module 389*523fa7a6SAndroid Build Coastguard Worker 390*523fa7a6SAndroid Build Coastguard Worker for nd in new_gm.graph.nodes: 391*523fa7a6SAndroid Build Coastguard Worker if nd.target is exir_ops.edge.aten.split_copy.Tensor_out: 392*523fa7a6SAndroid Build Coastguard Worker break 393*523fa7a6SAndroid Build Coastguard Worker 394*523fa7a6SAndroid Build Coastguard Worker val = nd.meta["val"] 395*523fa7a6SAndroid Build Coastguard Worker 396*523fa7a6SAndroid Build Coastguard Worker # We must return a spec which is a list of a signle TensorSpec item. 397*523fa7a6SAndroid Build Coastguard Worker # Returning the TensorSpec item directly cause future getitem op fails. 398*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(isinstance(val, (tuple, list))) 399*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(1, len(val)) 400*523fa7a6SAndroid Build Coastguard Worker 401*523fa7a6SAndroid Build Coastguard Worker def test_to_out_variant_multiple_out(self) -> None: 402*523fa7a6SAndroid Build Coastguard Worker class MyModel(nn.Module): 403*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 404*523fa7a6SAndroid Build Coastguard Worker super().__init__() 405*523fa7a6SAndroid Build Coastguard Worker 406*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 407*523fa7a6SAndroid Build Coastguard Worker return torch.topk(x, 5) 408*523fa7a6SAndroid Build Coastguard Worker 409*523fa7a6SAndroid Build Coastguard Worker def get_random_inputs(self): 410*523fa7a6SAndroid Build Coastguard Worker return (torch.randn(10),) 411*523fa7a6SAndroid Build Coastguard Worker 412*523fa7a6SAndroid Build Coastguard Worker model = MyModel() 413*523fa7a6SAndroid Build Coastguard Worker inputs = model.get_random_inputs() 414*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 415*523fa7a6SAndroid Build Coastguard Worker export( 416*523fa7a6SAndroid Build Coastguard Worker model, 417*523fa7a6SAndroid Build Coastguard Worker inputs, 418*523fa7a6SAndroid Build 
Coastguard Worker ), 419*523fa7a6SAndroid Build Coastguard Worker compile_config=EdgeCompileConfig(_check_ir_validity=False), 420*523fa7a6SAndroid Build Coastguard Worker ) # TODO(larryliu): fix topk 421*523fa7a6SAndroid Build Coastguard Worker new_gm_res = ToOutVarPass()(prog.exported_program().graph_module) 422*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_gm_res) 423*523fa7a6SAndroid Build Coastguard Worker new_gm = new_gm_res.graph_module 424*523fa7a6SAndroid Build Coastguard Worker 425*523fa7a6SAndroid Build Coastguard Worker for nd in new_gm.graph.nodes: 426*523fa7a6SAndroid Build Coastguard Worker if nd.target is torch.ops.aten.topk.values: 427*523fa7a6SAndroid Build Coastguard Worker break 428*523fa7a6SAndroid Build Coastguard Worker 429*523fa7a6SAndroid Build Coastguard Worker val = nd.meta["val"] 430*523fa7a6SAndroid Build Coastguard Worker 431*523fa7a6SAndroid Build Coastguard Worker # We must return a spec which is a list of a signle TensorSpec item. 432*523fa7a6SAndroid Build Coastguard Worker # Returning the TensorSpec item directly cause future getitem op fails. 
433*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(isinstance(val, (tuple, list))) 434*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(2, len(val)) 435*523fa7a6SAndroid Build Coastguard Worker 436*523fa7a6SAndroid Build Coastguard Worker def test_to_out_variant_to_copy(self) -> None: 437*523fa7a6SAndroid Build Coastguard Worker class Module(torch.nn.Module): 438*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 439*523fa7a6SAndroid Build Coastguard Worker super().__init__() 440*523fa7a6SAndroid Build Coastguard Worker 441*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 442*523fa7a6SAndroid Build Coastguard Worker return x.to(torch.int32) 443*523fa7a6SAndroid Build Coastguard Worker 444*523fa7a6SAndroid Build Coastguard Worker model = Module() 445*523fa7a6SAndroid Build Coastguard Worker 446*523fa7a6SAndroid Build Coastguard Worker inputs = torch.tensor(1.0, dtype=torch.float) 447*523fa7a6SAndroid Build Coastguard Worker model_res = model(inputs) 448*523fa7a6SAndroid Build Coastguard Worker 449*523fa7a6SAndroid Build Coastguard Worker edge_dialect = to_edge( 450*523fa7a6SAndroid Build Coastguard Worker export( 451*523fa7a6SAndroid Build Coastguard Worker model, 452*523fa7a6SAndroid Build Coastguard Worker (inputs,), 453*523fa7a6SAndroid Build Coastguard Worker ) 454*523fa7a6SAndroid Build Coastguard Worker ) 455*523fa7a6SAndroid Build Coastguard Worker edge_res = edge_dialect.exported_program().module()(inputs) 456*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(model_res, edge_res)) 457*523fa7a6SAndroid Build Coastguard Worker 458*523fa7a6SAndroid Build Coastguard Worker def test_export_pass(self) -> None: 459*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 460*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 461*523fa7a6SAndroid Build Coastguard Worker y = torch.cat([x, x]) 462*523fa7a6SAndroid Build Coastguard Worker return 
torch.ops.aten.tensor_split.sections(y, 2) 463*523fa7a6SAndroid Build Coastguard Worker 464*523fa7a6SAndroid Build Coastguard Worker f = Foo() 465*523fa7a6SAndroid Build Coastguard Worker 466*523fa7a6SAndroid Build Coastguard Worker class NullPass(ExportPass): 467*523fa7a6SAndroid Build Coastguard Worker pass 468*523fa7a6SAndroid Build Coastguard Worker 469*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 470*523fa7a6SAndroid Build Coastguard Worker export( 471*523fa7a6SAndroid Build Coastguard Worker f, 472*523fa7a6SAndroid Build Coastguard Worker (torch.ones(3, 2),), 473*523fa7a6SAndroid Build Coastguard Worker ), 474*523fa7a6SAndroid Build Coastguard Worker compile_config=EdgeCompileConfig(_check_ir_validity=False), 475*523fa7a6SAndroid Build Coastguard Worker ) # TODO(larryliu): fix cat 476*523fa7a6SAndroid Build Coastguard Worker new_prog = prog.transform([NullPass()]) 477*523fa7a6SAndroid Build Coastguard Worker new_nodes = new_prog.exported_program().graph_module.graph.nodes 478*523fa7a6SAndroid Build Coastguard Worker for node in new_nodes: 479*523fa7a6SAndroid Build Coastguard Worker if node.op != "call_function": 480*523fa7a6SAndroid Build Coastguard Worker continue 481*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(hasattr(node, "stack_trace")) 482*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(node.stack_trace) 483*523fa7a6SAndroid Build Coastguard Worker 484*523fa7a6SAndroid Build Coastguard Worker old_nodes = prog.exported_program().graph_module.graph.nodes 485*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(len(new_nodes), len(old_nodes)) 486*523fa7a6SAndroid Build Coastguard Worker for new_node, old_node in zip(new_nodes, old_nodes): 487*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(new_node.op, old_node.op) 488*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(new_node.target, old_node.target) 489*523fa7a6SAndroid Build Coastguard Worker 490*523fa7a6SAndroid Build Coastguard Worker def 
test_export_pass_pt2(self) -> None: 491*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 492*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 493*523fa7a6SAndroid Build Coastguard Worker y = torch.cat([x, x]) 494*523fa7a6SAndroid Build Coastguard Worker return torch.ops.aten.tensor_split.sections(y, 2) 495*523fa7a6SAndroid Build Coastguard Worker 496*523fa7a6SAndroid Build Coastguard Worker f = Foo() 497*523fa7a6SAndroid Build Coastguard Worker 498*523fa7a6SAndroid Build Coastguard Worker class NullPass(ExportPass): 499*523fa7a6SAndroid Build Coastguard Worker pass 500*523fa7a6SAndroid Build Coastguard Worker 501*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 502*523fa7a6SAndroid Build Coastguard Worker export( 503*523fa7a6SAndroid Build Coastguard Worker f, 504*523fa7a6SAndroid Build Coastguard Worker (torch.ones(3, 2),), 505*523fa7a6SAndroid Build Coastguard Worker ), 506*523fa7a6SAndroid Build Coastguard Worker compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 507*523fa7a6SAndroid Build Coastguard Worker ) 508*523fa7a6SAndroid Build Coastguard Worker new_prog = prog.transform([NullPass()]) 509*523fa7a6SAndroid Build Coastguard Worker new_nodes = new_prog.exported_program().graph_module.graph.nodes 510*523fa7a6SAndroid Build Coastguard Worker for node in new_nodes: 511*523fa7a6SAndroid Build Coastguard Worker if node.op != "call_function": 512*523fa7a6SAndroid Build Coastguard Worker continue 513*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(hasattr(node, "stack_trace")) 514*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(node.stack_trace) 515*523fa7a6SAndroid Build Coastguard Worker 516*523fa7a6SAndroid Build Coastguard Worker old_nodes = prog.exported_program().graph_module.graph.nodes 517*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(len(new_nodes), len(old_nodes)) 518*523fa7a6SAndroid Build Coastguard Worker for new_node, 
old_node in zip(new_nodes, old_nodes): 519*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(new_node.op, old_node.op) 520*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(new_node.target, old_node.target) 521*523fa7a6SAndroid Build Coastguard Worker 522*523fa7a6SAndroid Build Coastguard Worker def test_export_scalar_to_tensor_pass(self) -> None: 523*523fa7a6SAndroid Build Coastguard Worker class Mul(torch.nn.Module): 524*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 525*523fa7a6SAndroid Build Coastguard Worker return x * 3.14 526*523fa7a6SAndroid Build Coastguard Worker 527*523fa7a6SAndroid Build Coastguard Worker mul = Mul() 528*523fa7a6SAndroid Build Coastguard Worker 529*523fa7a6SAndroid Build Coastguard Worker expo_prog = to_edge(export(mul, (torch.ones(1),))) 530*523fa7a6SAndroid Build Coastguard Worker new_prog = expo_prog.transform([ScalarToTensorPass()]) 531*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_prog.exported_program().graph_module) 532*523fa7a6SAndroid Build Coastguard Worker new_graph_module = new_prog.exported_program().graph_module 533*523fa7a6SAndroid Build Coastguard Worker 534*523fa7a6SAndroid Build Coastguard Worker inp = torch.zeros(1) 535*523fa7a6SAndroid Build Coastguard Worker self.assertTrue( 536*523fa7a6SAndroid Build Coastguard Worker torch.allclose( 537*523fa7a6SAndroid Build Coastguard Worker expo_prog.exported_program().module()(inp), 538*523fa7a6SAndroid Build Coastguard Worker new_prog.exported_program().module()(inp), 539*523fa7a6SAndroid Build Coastguard Worker ) 540*523fa7a6SAndroid Build Coastguard Worker ) 541*523fa7a6SAndroid Build Coastguard Worker for node in new_graph_module.graph.nodes: 542*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_function": 543*523fa7a6SAndroid Build Coastguard Worker for arg in node.args + tuple(node.kwargs.values()): 544*523fa7a6SAndroid Build Coastguard Worker self.assertFalse(isinstance(arg, 
float)) 545*523fa7a6SAndroid Build Coastguard Worker 546*523fa7a6SAndroid Build Coastguard Worker def test_remove_mixed_types_symfloats(self) -> None: 547*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 548*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 549*523fa7a6SAndroid Build Coastguard Worker return torch.nn.functional.interpolate( 550*523fa7a6SAndroid Build Coastguard Worker x, 551*523fa7a6SAndroid Build Coastguard Worker size=(x.shape[2] * 2, x.shape[3] * 3), 552*523fa7a6SAndroid Build Coastguard Worker mode="bilinear", 553*523fa7a6SAndroid Build Coastguard Worker align_corners=False, 554*523fa7a6SAndroid Build Coastguard Worker antialias=False, 555*523fa7a6SAndroid Build Coastguard Worker ) 556*523fa7a6SAndroid Build Coastguard Worker 557*523fa7a6SAndroid Build Coastguard Worker f = Foo() 558*523fa7a6SAndroid Build Coastguard Worker 559*523fa7a6SAndroid Build Coastguard Worker example_inputs = (torch.randn(2, 3, 4, 5),) 560*523fa7a6SAndroid Build Coastguard Worker 561*523fa7a6SAndroid Build Coastguard Worker gm = to_edge( 562*523fa7a6SAndroid Build Coastguard Worker export( 563*523fa7a6SAndroid Build Coastguard Worker f, 564*523fa7a6SAndroid Build Coastguard Worker example_inputs, 565*523fa7a6SAndroid Build Coastguard Worker ) 566*523fa7a6SAndroid Build Coastguard Worker ) 567*523fa7a6SAndroid Build Coastguard Worker new_gm = gm.transform( 568*523fa7a6SAndroid Build Coastguard Worker [ReplaceSymSizeOpPass(), ScalarToTensorPass(), RemoveMixedTypeOperators()] 569*523fa7a6SAndroid Build Coastguard Worker ) 570*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_gm.exported_program().graph_module) 571*523fa7a6SAndroid Build Coastguard Worker 572*523fa7a6SAndroid Build Coastguard Worker self.assertTrue( 573*523fa7a6SAndroid Build Coastguard Worker torch.allclose( 574*523fa7a6SAndroid Build Coastguard Worker gm.exported_program().module()(*example_inputs), 575*523fa7a6SAndroid Build 
Coastguard Worker new_gm.exported_program().module()(*example_inputs), 576*523fa7a6SAndroid Build Coastguard Worker ) 577*523fa7a6SAndroid Build Coastguard Worker ) 578*523fa7a6SAndroid Build Coastguard Worker 579*523fa7a6SAndroid Build Coastguard Worker def test_spec_prop_pass(self) -> None: 580*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 581*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 582*523fa7a6SAndroid Build Coastguard Worker return x + x 583*523fa7a6SAndroid Build Coastguard Worker 584*523fa7a6SAndroid Build Coastguard Worker f = Foo() 585*523fa7a6SAndroid Build Coastguard Worker 586*523fa7a6SAndroid Build Coastguard Worker gm = ( 587*523fa7a6SAndroid Build Coastguard Worker to_edge( 588*523fa7a6SAndroid Build Coastguard Worker export( 589*523fa7a6SAndroid Build Coastguard Worker f, 590*523fa7a6SAndroid Build Coastguard Worker (torch.ones(3, 2),), 591*523fa7a6SAndroid Build Coastguard Worker ) 592*523fa7a6SAndroid Build Coastguard Worker ) 593*523fa7a6SAndroid Build Coastguard Worker .exported_program() 594*523fa7a6SAndroid Build Coastguard Worker .graph_module 595*523fa7a6SAndroid Build Coastguard Worker ) 596*523fa7a6SAndroid Build Coastguard Worker new_gm = SpecPropPass()(gm) 597*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_gm) 598*523fa7a6SAndroid Build Coastguard Worker new_nodes = new_gm.graph_module.graph.nodes 599*523fa7a6SAndroid Build Coastguard Worker counter = 0 600*523fa7a6SAndroid Build Coastguard Worker for node in new_nodes: 601*523fa7a6SAndroid Build Coastguard Worker if node.op != "output": 602*523fa7a6SAndroid Build Coastguard Worker continue 603*523fa7a6SAndroid Build Coastguard Worker counter += 1 604*523fa7a6SAndroid Build Coastguard Worker self.assertIs(node.meta["spec"][0], node.args[0][0].meta["spec"]) 605*523fa7a6SAndroid Build Coastguard Worker 606*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(counter, 1) 607*523fa7a6SAndroid 
Build Coastguard Worker 608*523fa7a6SAndroid Build Coastguard Worker def test_spec_prop_pass_tuple_output(self) -> None: 609*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 610*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: 611*523fa7a6SAndroid Build Coastguard Worker return (x + x,) 612*523fa7a6SAndroid Build Coastguard Worker 613*523fa7a6SAndroid Build Coastguard Worker f = Foo() 614*523fa7a6SAndroid Build Coastguard Worker 615*523fa7a6SAndroid Build Coastguard Worker gm = ( 616*523fa7a6SAndroid Build Coastguard Worker to_edge( 617*523fa7a6SAndroid Build Coastguard Worker export( 618*523fa7a6SAndroid Build Coastguard Worker f, 619*523fa7a6SAndroid Build Coastguard Worker (torch.ones(3, 2),), 620*523fa7a6SAndroid Build Coastguard Worker ) 621*523fa7a6SAndroid Build Coastguard Worker ) 622*523fa7a6SAndroid Build Coastguard Worker .exported_program() 623*523fa7a6SAndroid Build Coastguard Worker .graph_module 624*523fa7a6SAndroid Build Coastguard Worker ) 625*523fa7a6SAndroid Build Coastguard Worker new_gm = SpecPropPass()(gm) 626*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_gm) 627*523fa7a6SAndroid Build Coastguard Worker new_nodes = new_gm.graph_module.graph.nodes 628*523fa7a6SAndroid Build Coastguard Worker counter = 0 629*523fa7a6SAndroid Build Coastguard Worker for node in new_nodes: 630*523fa7a6SAndroid Build Coastguard Worker if node.op != "output": 631*523fa7a6SAndroid Build Coastguard Worker continue 632*523fa7a6SAndroid Build Coastguard Worker counter += 1 633*523fa7a6SAndroid Build Coastguard Worker self.assertIs(node.meta["spec"][0], node.args[0][0].meta["spec"]) 634*523fa7a6SAndroid Build Coastguard Worker 635*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(counter, 1) 636*523fa7a6SAndroid Build Coastguard Worker 637*523fa7a6SAndroid Build Coastguard Worker def test_compile_fix_broken_ops(self) -> None: 638*523fa7a6SAndroid Build Coastguard Worker # 
When pass an input of more than 4 dimensions to Linear 639*523fa7a6SAndroid Build Coastguard Worker # aten._unsafe_view is used under the hood 640*523fa7a6SAndroid Build Coastguard Worker x = torch.randn([2, 3, 4, 5]) 641*523fa7a6SAndroid Build Coastguard Worker model: torch.nn.Linear = torch.nn.Linear(5, 5) 642*523fa7a6SAndroid Build Coastguard Worker 643*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 644*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 645*523fa7a6SAndroid Build Coastguard Worker super().__init__() 646*523fa7a6SAndroid Build Coastguard Worker self.model = model 647*523fa7a6SAndroid Build Coastguard Worker 648*523fa7a6SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 649*523fa7a6SAndroid Build Coastguard Worker return self.model(inp) 650*523fa7a6SAndroid Build Coastguard Worker 651*523fa7a6SAndroid Build Coastguard Worker f = Foo() 652*523fa7a6SAndroid Build Coastguard Worker 653*523fa7a6SAndroid Build Coastguard Worker # ReplaceBrokenOpsWithFunctionalOpsPass is used in to_edge() 654*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 655*523fa7a6SAndroid Build Coastguard Worker export( 656*523fa7a6SAndroid Build Coastguard Worker f, 657*523fa7a6SAndroid Build Coastguard Worker (x,), 658*523fa7a6SAndroid Build Coastguard Worker ), 659*523fa7a6SAndroid Build Coastguard Worker compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 660*523fa7a6SAndroid Build Coastguard Worker ) 661*523fa7a6SAndroid Build Coastguard Worker gm = prog.exported_program().graph_module 662*523fa7a6SAndroid Build Coastguard Worker count_after = 0 663*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes: 664*523fa7a6SAndroid Build Coastguard Worker if node.target == torch.ops.aten._unsafe_view.default: 665*523fa7a6SAndroid Build Coastguard Worker count_after += 1 666*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(count_after, 0) 667*523fa7a6SAndroid Build 
Coastguard Worker self.assertTrue(torch.allclose(prog.exported_program().module()(x), f(x))) 668*523fa7a6SAndroid Build Coastguard Worker 669*523fa7a6SAndroid Build Coastguard Worker def test_convert_symb_ops(self) -> None: 670*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 671*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 672*523fa7a6SAndroid Build Coastguard Worker return torch.add(x, x.shape[0] - 1) 673*523fa7a6SAndroid Build Coastguard Worker 674*523fa7a6SAndroid Build Coastguard Worker f = Foo() 675*523fa7a6SAndroid Build Coastguard Worker 676*523fa7a6SAndroid Build Coastguard Worker # Mark the 0th dimension of X as dynamic with a max value of 3. 677*523fa7a6SAndroid Build Coastguard Worker dim_x = torch.export.Dim("dim_x", max=3) 678*523fa7a6SAndroid Build Coastguard Worker 679*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 680*523fa7a6SAndroid Build Coastguard Worker export( 681*523fa7a6SAndroid Build Coastguard Worker f, 682*523fa7a6SAndroid Build Coastguard Worker (torch.ones(3, 2),), 683*523fa7a6SAndroid Build Coastguard Worker dynamic_shapes={"x": {0: dim_x}}, 684*523fa7a6SAndroid Build Coastguard Worker ), 685*523fa7a6SAndroid Build Coastguard Worker compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 686*523fa7a6SAndroid Build Coastguard Worker ) 687*523fa7a6SAndroid Build Coastguard Worker new_prog = prog.transform([EdgeToBackendOpsPass()]) 688*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_prog.exported_program().graph_module) 689*523fa7a6SAndroid Build Coastguard Worker converted_gm = new_prog.exported_program().graph_module 690*523fa7a6SAndroid Build Coastguard Worker 691*523fa7a6SAndroid Build Coastguard Worker FileCheck().check("torch.ops.aten.sym_size.int").check( 692*523fa7a6SAndroid Build Coastguard Worker "executorch_exir_dialects_backend__ops_executorch_prim_sub_Scalar" 693*523fa7a6SAndroid Build Coastguard Worker 
).check_not("operator.sub").run(converted_gm.code) 694*523fa7a6SAndroid Build Coastguard Worker 695*523fa7a6SAndroid Build Coastguard Worker def test_alloc_node_spec(self) -> None: 696*523fa7a6SAndroid Build Coastguard Worker """ 697*523fa7a6SAndroid Build Coastguard Worker Make sure every memory.alloc node including those in sub graph modules 698*523fa7a6SAndroid Build Coastguard Worker have a TensorSpec. 699*523fa7a6SAndroid Build Coastguard Worker """ 700*523fa7a6SAndroid Build Coastguard Worker eager_model = FTMapBasic() 701*523fa7a6SAndroid Build Coastguard Worker inputs = eager_model.get_random_inputs() 702*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 703*523fa7a6SAndroid Build Coastguard Worker export( 704*523fa7a6SAndroid Build Coastguard Worker eager_model, 705*523fa7a6SAndroid Build Coastguard Worker inputs, 706*523fa7a6SAndroid Build Coastguard Worker ), 707*523fa7a6SAndroid Build Coastguard Worker compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 708*523fa7a6SAndroid Build Coastguard Worker ) 709*523fa7a6SAndroid Build Coastguard Worker passes = [ 710*523fa7a6SAndroid Build Coastguard Worker SpecPropPass(), 711*523fa7a6SAndroid Build Coastguard Worker HintBasedSymShapeEvalPass(), 712*523fa7a6SAndroid Build Coastguard Worker ] 713*523fa7a6SAndroid Build Coastguard Worker new_prog = prog.transform(passes) 714*523fa7a6SAndroid Build Coastguard Worker 715*523fa7a6SAndroid Build Coastguard Worker new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module) 716*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_gm_res) 717*523fa7a6SAndroid Build Coastguard Worker new_gm = new_gm_res.graph_module 718*523fa7a6SAndroid Build Coastguard Worker 719*523fa7a6SAndroid Build Coastguard Worker new_gm_res = MemoryPlanningPass()(new_gm) 720*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_gm_res) 721*523fa7a6SAndroid Build Coastguard Worker new_gm = new_gm_res.graph_module 722*523fa7a6SAndroid Build 
Coastguard Worker 723*523fa7a6SAndroid Build Coastguard Worker alloc_nodes = [] 724*523fa7a6SAndroid Build Coastguard Worker for subgm in new_gm.modules(): 725*523fa7a6SAndroid Build Coastguard Worker if isinstance(subgm, torch.fx.GraphModule): 726*523fa7a6SAndroid Build Coastguard Worker for node in subgm.graph.nodes: 727*523fa7a6SAndroid Build Coastguard Worker if node.target == memory.alloc: 728*523fa7a6SAndroid Build Coastguard Worker alloc_nodes.append(node) 729*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(len(alloc_nodes) > 0) 730*523fa7a6SAndroid Build Coastguard Worker for node in alloc_nodes: 731*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(isinstance(node.meta.get("spec", None), TensorSpec)) 732*523fa7a6SAndroid Build Coastguard Worker 733*523fa7a6SAndroid Build Coastguard Worker def test_debug_pass_file_log(self) -> None: 734*523fa7a6SAndroid Build Coastguard Worker eager_model = Mul() 735*523fa7a6SAndroid Build Coastguard Worker inputs = eager_model.get_random_inputs() 736*523fa7a6SAndroid Build Coastguard Worker 737*523fa7a6SAndroid Build Coastguard Worker # the debug pass works with a graph generated with make_fx directly 738*523fa7a6SAndroid Build Coastguard Worker gm = make_fx(eager_model)(*inputs) 739*523fa7a6SAndroid Build Coastguard Worker 740*523fa7a6SAndroid Build Coastguard Worker try: 741*523fa7a6SAndroid Build Coastguard Worker fd, path = tempfile.mkstemp() 742*523fa7a6SAndroid Build Coastguard Worker 743*523fa7a6SAndroid Build Coastguard Worker print(f"Write DebugPass output to {path}") 744*523fa7a6SAndroid Build Coastguard Worker DebugPass(log_filename=path)(gm) 745*523fa7a6SAndroid Build Coastguard Worker with open(path) as f: 746*523fa7a6SAndroid Build Coastguard Worker file_cont = f.read() 747*523fa7a6SAndroid Build Coastguard Worker self.assertTrue("torch.ops.aten.mul" in file_cont) 748*523fa7a6SAndroid Build Coastguard Worker finally: 749*523fa7a6SAndroid Build Coastguard Worker os.close(fd) 
750*523fa7a6SAndroid Build Coastguard Worker os.unlink(path) 751*523fa7a6SAndroid Build Coastguard Worker 752*523fa7a6SAndroid Build Coastguard Worker def test_dce_recursive(self) -> None: 753*523fa7a6SAndroid Build Coastguard Worker eager_model = FTCondDeadCode() 754*523fa7a6SAndroid Build Coastguard Worker inputs = eager_model.get_random_inputs() 755*523fa7a6SAndroid Build Coastguard Worker gm = export( 756*523fa7a6SAndroid Build Coastguard Worker eager_model, 757*523fa7a6SAndroid Build Coastguard Worker inputs, 758*523fa7a6SAndroid Build Coastguard Worker ).graph_module 759*523fa7a6SAndroid Build Coastguard Worker 760*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(torch.ops.aten.sub.Tensor in collect_ops(gm)) 761*523fa7a6SAndroid Build Coastguard Worker dead_code_elimination_pass(gm) 762*523fa7a6SAndroid Build Coastguard Worker gm.print_readable() 763*523fa7a6SAndroid Build Coastguard Worker self.assertFalse(torch.ops.aten.sub.Tensor in collect_ops(gm)) 764*523fa7a6SAndroid Build Coastguard Worker 765*523fa7a6SAndroid Build Coastguard Worker def test_propagate_dynamic_shape(self) -> None: 766*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 767*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 768*523fa7a6SAndroid Build Coastguard Worker y = x 769*523fa7a6SAndroid Build Coastguard Worker for _ in range(2): 770*523fa7a6SAndroid Build Coastguard Worker y = y + x 771*523fa7a6SAndroid Build Coastguard Worker return y 772*523fa7a6SAndroid Build Coastguard Worker 773*523fa7a6SAndroid Build Coastguard Worker f = Foo() 774*523fa7a6SAndroid Build Coastguard Worker 775*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 776*523fa7a6SAndroid Build Coastguard Worker export( 777*523fa7a6SAndroid Build Coastguard Worker f, 778*523fa7a6SAndroid Build Coastguard Worker (torch.rand(5),), 779*523fa7a6SAndroid Build Coastguard Worker ), 780*523fa7a6SAndroid Build Coastguard Worker # missing dispatch key 
781*523fa7a6SAndroid Build Coastguard Worker compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 782*523fa7a6SAndroid Build Coastguard Worker ).transform(propagate_dynamic_shape()) 783*523fa7a6SAndroid Build Coastguard Worker gm = prog.exported_program().graph_module 784*523fa7a6SAndroid Build Coastguard Worker nspec = 0 785*523fa7a6SAndroid Build Coastguard Worker for n in gm.graph.nodes: 786*523fa7a6SAndroid Build Coastguard Worker for spec in pytree.tree_flatten(n.meta["spec"])[0]: 787*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(all(isinstance(x, int) for x in spec.shape)) 788*523fa7a6SAndroid Build Coastguard Worker nspec += 1 789*523fa7a6SAndroid Build Coastguard Worker 790*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(nspec > 0) 791*523fa7a6SAndroid Build Coastguard Worker 792*523fa7a6SAndroid Build Coastguard Worker def test_losing_symbolic_info(self) -> None: 793*523fa7a6SAndroid Build Coastguard Worker """ 794*523fa7a6SAndroid Build Coastguard Worker Guard against an issue that after calling ConvertSymbolicOpsPass(), 795*523fa7a6SAndroid Build Coastguard Worker future ExportPass will encounter symbolic information loss. 
796*523fa7a6SAndroid Build Coastguard Worker """ 797*523fa7a6SAndroid Build Coastguard Worker 798*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 799*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 800*523fa7a6SAndroid Build Coastguard Worker return torch.add(x, x.shape[0] - 1) 801*523fa7a6SAndroid Build Coastguard Worker 802*523fa7a6SAndroid Build Coastguard Worker f = Foo() 803*523fa7a6SAndroid Build Coastguard Worker 804*523fa7a6SAndroid Build Coastguard Worker dim_x = torch.export.Dim("dim_x", max=3) 805*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 806*523fa7a6SAndroid Build Coastguard Worker export( 807*523fa7a6SAndroid Build Coastguard Worker f, 808*523fa7a6SAndroid Build Coastguard Worker (torch.ones(3, 2),), 809*523fa7a6SAndroid Build Coastguard Worker dynamic_shapes={"x": {0: dim_x}}, 810*523fa7a6SAndroid Build Coastguard Worker ), 811*523fa7a6SAndroid Build Coastguard Worker compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 812*523fa7a6SAndroid Build Coastguard Worker ) 813*523fa7a6SAndroid Build Coastguard Worker 814*523fa7a6SAndroid Build Coastguard Worker new_prog = prog.transform([EdgeToBackendOpsPass()]) 815*523fa7a6SAndroid Build Coastguard Worker gm = new_prog.exported_program().graph_module 816*523fa7a6SAndroid Build Coastguard Worker gm.print_readable() 817*523fa7a6SAndroid Build Coastguard Worker *_, ones, out = gm.graph.nodes 818*523fa7a6SAndroid Build Coastguard Worker print(f"Before ExportPass: {ones.format_node()}") 819*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt)) 820*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0) 821*523fa7a6SAndroid Build Coastguard Worker 822*523fa7a6SAndroid Build Coastguard Worker new_prog = new_prog.transform([ExportPass()]) 823*523fa7a6SAndroid Build Coastguard Worker gm = 
new_prog.exported_program().graph_module 824*523fa7a6SAndroid Build Coastguard Worker gm.print_readable() 825*523fa7a6SAndroid Build Coastguard Worker *_, ones, out = gm.graph.nodes 826*523fa7a6SAndroid Build Coastguard Worker print(f"After ExportPass: {ones.format_node()}") 827*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt)) 828*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0) 829*523fa7a6SAndroid Build Coastguard Worker 830*523fa7a6SAndroid Build Coastguard Worker def test_to_edge_with_edge_ops(self) -> None: 831*523fa7a6SAndroid Build Coastguard Worker x = torch.randn([2, 3, 4, 5]) 832*523fa7a6SAndroid Build Coastguard Worker 833*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 834*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 835*523fa7a6SAndroid Build Coastguard Worker return x + x 836*523fa7a6SAndroid Build Coastguard Worker 837*523fa7a6SAndroid Build Coastguard Worker f = Foo() 838*523fa7a6SAndroid Build Coastguard Worker 839*523fa7a6SAndroid Build Coastguard Worker gm = ( 840*523fa7a6SAndroid Build Coastguard Worker to_edge( 841*523fa7a6SAndroid Build Coastguard Worker export( 842*523fa7a6SAndroid Build Coastguard Worker f, 843*523fa7a6SAndroid Build Coastguard Worker (x,), 844*523fa7a6SAndroid Build Coastguard Worker ) 845*523fa7a6SAndroid Build Coastguard Worker ) 846*523fa7a6SAndroid Build Coastguard Worker .exported_program() 847*523fa7a6SAndroid Build Coastguard Worker .graph_module 848*523fa7a6SAndroid Build Coastguard Worker ) 849*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes: 850*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_function": 851*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(type(node.target), EdgeOpOverload) 852*523fa7a6SAndroid Build Coastguard Worker 853*523fa7a6SAndroid Build Coastguard Worker # 
TODO(T143084047) 854*523fa7a6SAndroid Build Coastguard Worker @unittest.expectedFailure 855*523fa7a6SAndroid Build Coastguard Worker def test_backend_fused_op_retraceable(self) -> None: 856*523fa7a6SAndroid Build Coastguard Worker """This test makes sure the backend op is still retraceable, with the pattern being registered as kernel.""" 857*523fa7a6SAndroid Build Coastguard Worker 858*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 859*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 860*523fa7a6SAndroid Build Coastguard Worker z = x + y 861*523fa7a6SAndroid Build Coastguard Worker return torch.ops.aten.relu.default(z) 862*523fa7a6SAndroid Build Coastguard Worker 863*523fa7a6SAndroid Build Coastguard Worker f = Foo() 864*523fa7a6SAndroid Build Coastguard Worker 865*523fa7a6SAndroid Build Coastguard Worker gm = export( 866*523fa7a6SAndroid Build Coastguard Worker f, 867*523fa7a6SAndroid Build Coastguard Worker ( 868*523fa7a6SAndroid Build Coastguard Worker torch.randn(2, 2), 869*523fa7a6SAndroid Build Coastguard Worker torch.randn(2, 2), 870*523fa7a6SAndroid Build Coastguard Worker ), 871*523fa7a6SAndroid Build Coastguard Worker ) 872*523fa7a6SAndroid Build Coastguard Worker # should look like: 873*523fa7a6SAndroid Build Coastguard Worker # graph(): 874*523fa7a6SAndroid Build Coastguard Worker # %ph_0 : [#users=1] = placeholder[target=ph_0] 875*523fa7a6SAndroid Build Coastguard Worker # %ph_1 : [#users=1] = placeholder[target=ph_1] 876*523fa7a6SAndroid Build Coastguard Worker # %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, %ph_1), kwargs = {}) 877*523fa7a6SAndroid Build Coastguard Worker # %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%add_tensor,), kwargs = {}) 878*523fa7a6SAndroid Build Coastguard Worker # return [relu_default] 879*523fa7a6SAndroid Build Coastguard Worker 
FileCheck().check("torch.ops.aten.add.Tensor").check( 880*523fa7a6SAndroid Build Coastguard Worker "torch.ops.aten.relu.default" 881*523fa7a6SAndroid Build Coastguard Worker ).run(gm.graph_module.code) 882*523fa7a6SAndroid Build Coastguard Worker 883*523fa7a6SAndroid Build Coastguard Worker class AddReluFusionPass(ExportPass): 884*523fa7a6SAndroid Build Coastguard Worker def call(self, graph_module: GraphModule) -> PassResult: 885*523fa7a6SAndroid Build Coastguard Worker # decorator registers this pattern as a CompositeExplicitAutograd kernel, since there's no kernel registered before. 886*523fa7a6SAndroid Build Coastguard Worker @bind_pattern_to_op(lib, "add_relu") 887*523fa7a6SAndroid Build Coastguard Worker def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 888*523fa7a6SAndroid Build Coastguard Worker z = torch.ops.aten.add.Tensor(x, y) 889*523fa7a6SAndroid Build Coastguard Worker out = torch.ops.aten.relu.default(z) 890*523fa7a6SAndroid Build Coastguard Worker return out 891*523fa7a6SAndroid Build Coastguard Worker 892*523fa7a6SAndroid Build Coastguard Worker def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 893*523fa7a6SAndroid Build Coastguard Worker return ops.backend.DO_NOT_USE_TEST_ONLY.add_relu.default(x, y) 894*523fa7a6SAndroid Build Coastguard Worker 895*523fa7a6SAndroid Build Coastguard Worker subgraph_rewriter.replace_pattern(graph_module, pattern, replacement) 896*523fa7a6SAndroid Build Coastguard Worker return PassResult(graph_module, True) 897*523fa7a6SAndroid Build Coastguard Worker 898*523fa7a6SAndroid Build Coastguard Worker # TODO: larryliu this pass needs to be in to_executorch() 899*523fa7a6SAndroid Build Coastguard Worker class OpReplacePass(ExportPass): 900*523fa7a6SAndroid Build Coastguard Worker def call_operator(self, op, args, kwargs, meta): 901*523fa7a6SAndroid Build Coastguard Worker if op == torch.ops.DO_NOT_USE_TEST_ONLY.add_relu.default: 902*523fa7a6SAndroid Build Coastguard Worker return 
super().call_operator( 903*523fa7a6SAndroid Build Coastguard Worker ops.backend.DO_NOT_USE_TEST_ONLY.add_relu.default, 904*523fa7a6SAndroid Build Coastguard Worker args, 905*523fa7a6SAndroid Build Coastguard Worker kwargs, 906*523fa7a6SAndroid Build Coastguard Worker meta, 907*523fa7a6SAndroid Build Coastguard Worker ) 908*523fa7a6SAndroid Build Coastguard Worker return super().call_operator(op, args, kwargs, meta) 909*523fa7a6SAndroid Build Coastguard Worker 910*523fa7a6SAndroid Build Coastguard Worker gm_lowered = to_edge( 911*523fa7a6SAndroid Build Coastguard Worker gm, 912*523fa7a6SAndroid Build Coastguard Worker compile_config=EdgeCompileConfig( 913*523fa7a6SAndroid Build Coastguard Worker _check_ir_validity=False, 914*523fa7a6SAndroid Build Coastguard Worker ), 915*523fa7a6SAndroid Build Coastguard Worker ).transform([AddReluFusionPass(), OpReplacePass()]) 916*523fa7a6SAndroid Build Coastguard Worker 917*523fa7a6SAndroid Build Coastguard Worker FileCheck().check( 918*523fa7a6SAndroid Build Coastguard Worker "executorch_exir_dialects_backend__ops_DO_NOT_USE_TEST_ONLY_add_relu_default" 919*523fa7a6SAndroid Build Coastguard Worker ).run(gm_lowered.exported_program().graph_module.code) 920*523fa7a6SAndroid Build Coastguard Worker # lowered module: 921*523fa7a6SAndroid Build Coastguard Worker # def forward(self, ph_0, ph_1): 922*523fa7a6SAndroid Build Coastguard Worker # do_not_use_test_only_add_relu_default = executorch_exir_dialects_backend__ops_DO_NOT_USE_TEST_ONLY_add_relu_default(ph_0, ph_1); ph_0 = ph_1 = None 923*523fa7a6SAndroid Build Coastguard Worker # return [do_not_use_test_only_add_relu_default] 924*523fa7a6SAndroid Build Coastguard Worker 925*523fa7a6SAndroid Build Coastguard Worker # Retrace: 926*523fa7a6SAndroid Build Coastguard Worker # If not backend op retrace will error out because no CPU/CompositeExplicitAutograd kernel registered. 
927*523fa7a6SAndroid Build Coastguard Worker gm_retraced = to_edge( 928*523fa7a6SAndroid Build Coastguard Worker export( 929*523fa7a6SAndroid Build Coastguard Worker gm_lowered.exported_program().module(), 930*523fa7a6SAndroid Build Coastguard Worker ( 931*523fa7a6SAndroid Build Coastguard Worker torch.randn(2, 2), 932*523fa7a6SAndroid Build Coastguard Worker torch.randn(2, 2), 933*523fa7a6SAndroid Build Coastguard Worker ), 934*523fa7a6SAndroid Build Coastguard Worker ) 935*523fa7a6SAndroid Build Coastguard Worker ) 936*523fa7a6SAndroid Build Coastguard Worker # Retrace-able, the graph "promote" back to ATen dialect, showing up add and relu, which is expected. 937*523fa7a6SAndroid Build Coastguard Worker FileCheck().check("torch.ops.aten.add.Tensor").check( 938*523fa7a6SAndroid Build Coastguard Worker "torch.ops.aten.relu.default" 939*523fa7a6SAndroid Build Coastguard Worker ).run(gm_retraced.exported_program().graph_module.code) 940*523fa7a6SAndroid Build Coastguard Worker 941*523fa7a6SAndroid Build Coastguard Worker def test_debug_handle_generator_pass(self) -> None: 942*523fa7a6SAndroid Build Coastguard Worker eager_model = MLP(2, output_size=4) 943*523fa7a6SAndroid Build Coastguard Worker inputs = eager_model.get_random_inputs() 944*523fa7a6SAndroid Build Coastguard Worker 945*523fa7a6SAndroid Build Coastguard Worker graph_module = ( 946*523fa7a6SAndroid Build Coastguard Worker to_edge( 947*523fa7a6SAndroid Build Coastguard Worker export( 948*523fa7a6SAndroid Build Coastguard Worker eager_model, 949*523fa7a6SAndroid Build Coastguard Worker inputs, 950*523fa7a6SAndroid Build Coastguard Worker ) 951*523fa7a6SAndroid Build Coastguard Worker ) 952*523fa7a6SAndroid Build Coastguard Worker .exported_program() 953*523fa7a6SAndroid Build Coastguard Worker .graph_module 954*523fa7a6SAndroid Build Coastguard Worker ) 955*523fa7a6SAndroid Build Coastguard Worker for node in graph_module.graph.nodes: 956*523fa7a6SAndroid Build Coastguard Worker 
self.assertIn("debug_handle", node.meta) 957*523fa7a6SAndroid Build Coastguard Worker ScalarToTensorPass()(graph_module) 958*523fa7a6SAndroid Build Coastguard Worker for node in graph_module.graph.nodes: 959*523fa7a6SAndroid Build Coastguard Worker self.assertIn("debug_handle", node.meta) 960*523fa7a6SAndroid Build Coastguard Worker 961*523fa7a6SAndroid Build Coastguard Worker def test_generate_missing_debug_handles(self) -> None: 962*523fa7a6SAndroid Build Coastguard Worker eager_model = MLP(2, output_size=4) 963*523fa7a6SAndroid Build Coastguard Worker inputs = eager_model.get_random_inputs() 964*523fa7a6SAndroid Build Coastguard Worker 965*523fa7a6SAndroid Build Coastguard Worker ep = to_edge( 966*523fa7a6SAndroid Build Coastguard Worker export( 967*523fa7a6SAndroid Build Coastguard Worker eager_model, 968*523fa7a6SAndroid Build Coastguard Worker inputs, 969*523fa7a6SAndroid Build Coastguard Worker ) 970*523fa7a6SAndroid Build Coastguard Worker ).exported_program() 971*523fa7a6SAndroid Build Coastguard Worker 972*523fa7a6SAndroid Build Coastguard Worker list(ep.graph.nodes)[0].meta.pop("debug_handle") 973*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None) 974*523fa7a6SAndroid Build Coastguard Worker generate_missing_debug_handles(ep) 975*523fa7a6SAndroid Build Coastguard Worker self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None) 976*523fa7a6SAndroid Build Coastguard Worker 977*523fa7a6SAndroid Build Coastguard Worker def test_debug_handle_generator_pass_with_control_flow(self) -> None: 978*523fa7a6SAndroid Build Coastguard Worker def true_nested(y: torch.Tensor) -> torch.Tensor: 979*523fa7a6SAndroid Build Coastguard Worker y = y + y 980*523fa7a6SAndroid Build Coastguard Worker y = torch.mm(y, y) 981*523fa7a6SAndroid Build Coastguard Worker return y 982*523fa7a6SAndroid Build Coastguard Worker 983*523fa7a6SAndroid Build Coastguard Worker def false_nested(y: torch.Tensor) -> 
torch.Tensor: 984*523fa7a6SAndroid Build Coastguard Worker return torch.mm(y, y) 985*523fa7a6SAndroid Build Coastguard Worker 986*523fa7a6SAndroid Build Coastguard Worker def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor: 987*523fa7a6SAndroid Build Coastguard Worker z = control_flow.cond(pred2, true_nested, false_nested, [x]) 988*523fa7a6SAndroid Build Coastguard Worker return x + z 989*523fa7a6SAndroid Build Coastguard Worker 990*523fa7a6SAndroid Build Coastguard Worker def false_fn(x: torch.Tensor, _) -> torch.Tensor: 991*523fa7a6SAndroid Build Coastguard Worker return x.cos() 992*523fa7a6SAndroid Build Coastguard Worker 993*523fa7a6SAndroid Build Coastguard Worker def map_fn( 994*523fa7a6SAndroid Build Coastguard Worker x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor 995*523fa7a6SAndroid Build Coastguard Worker ) -> torch.Tensor: 996*523fa7a6SAndroid Build Coastguard Worker x = x.cos() 997*523fa7a6SAndroid Build Coastguard Worker y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2]) 998*523fa7a6SAndroid Build Coastguard Worker x = x + y 999*523fa7a6SAndroid Build Coastguard Worker return x.sin() 1000*523fa7a6SAndroid Build Coastguard Worker 1001*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 1002*523fa7a6SAndroid Build Coastguard Worker def forward( 1003*523fa7a6SAndroid Build Coastguard Worker self, 1004*523fa7a6SAndroid Build Coastguard Worker xs: torch.Tensor, 1005*523fa7a6SAndroid Build Coastguard Worker pred1: torch.Tensor, 1006*523fa7a6SAndroid Build Coastguard Worker pred2: torch.Tensor, 1007*523fa7a6SAndroid Build Coastguard Worker y: torch.Tensor, 1008*523fa7a6SAndroid Build Coastguard Worker ) -> torch.Tensor: 1009*523fa7a6SAndroid Build Coastguard Worker y = torch.mm(y, y) 1010*523fa7a6SAndroid Build Coastguard Worker return control_flow.map(map_fn, xs, pred1, pred2, y) 1011*523fa7a6SAndroid Build Coastguard Worker 1012*523fa7a6SAndroid Build Coastguard Worker f = Foo() 
1013*523fa7a6SAndroid Build Coastguard Worker 1014*523fa7a6SAndroid Build Coastguard Worker inputs = ( 1015*523fa7a6SAndroid Build Coastguard Worker torch.ones(2, 2), 1016*523fa7a6SAndroid Build Coastguard Worker torch.tensor([False]), 1017*523fa7a6SAndroid Build Coastguard Worker torch.tensor([False]), 1018*523fa7a6SAndroid Build Coastguard Worker torch.ones(2, 2), 1019*523fa7a6SAndroid Build Coastguard Worker ) 1020*523fa7a6SAndroid Build Coastguard Worker 1021*523fa7a6SAndroid Build Coastguard Worker ep = to_edge( 1022*523fa7a6SAndroid Build Coastguard Worker export( 1023*523fa7a6SAndroid Build Coastguard Worker f, 1024*523fa7a6SAndroid Build Coastguard Worker inputs, 1025*523fa7a6SAndroid Build Coastguard Worker ) 1026*523fa7a6SAndroid Build Coastguard Worker ).exported_program() 1027*523fa7a6SAndroid Build Coastguard Worker graph_module = ep.graph_module 1028*523fa7a6SAndroid Build Coastguard Worker 1029*523fa7a6SAndroid Build Coastguard Worker def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None: 1030*523fa7a6SAndroid Build Coastguard Worker queue = [graph_module] 1031*523fa7a6SAndroid Build Coastguard Worker while queue: 1032*523fa7a6SAndroid Build Coastguard Worker current_graph_module = queue.pop(0) 1033*523fa7a6SAndroid Build Coastguard Worker for node in current_graph_module.graph.nodes: 1034*523fa7a6SAndroid Build Coastguard Worker self.assertIn("debug_handle", node.meta) 1035*523fa7a6SAndroid Build Coastguard Worker control_flow_submodules = [ 1036*523fa7a6SAndroid Build Coastguard Worker submodule 1037*523fa7a6SAndroid Build Coastguard Worker for _, submodule, _ in get_control_flow_submodules( 1038*523fa7a6SAndroid Build Coastguard Worker current_graph_module 1039*523fa7a6SAndroid Build Coastguard Worker ) 1040*523fa7a6SAndroid Build Coastguard Worker ] 1041*523fa7a6SAndroid Build Coastguard Worker queue.extend(control_flow_submodules) 1042*523fa7a6SAndroid Build Coastguard Worker 1043*523fa7a6SAndroid Build Coastguard Worker 
DebugHandleGeneratorPass()(graph_module) 1044*523fa7a6SAndroid Build Coastguard Worker check_debug_handle_metadata(graph_module) 1045*523fa7a6SAndroid Build Coastguard Worker generate_missing_debug_handles(ep) 1046*523fa7a6SAndroid Build Coastguard Worker 1047*523fa7a6SAndroid Build Coastguard Worker # Check debug handle still preserved after ScalarToTensorPass 1048*523fa7a6SAndroid Build Coastguard Worker ScalarToTensorPass()(graph_module) 1049*523fa7a6SAndroid Build Coastguard Worker check_debug_handle_metadata(graph_module) 1050*523fa7a6SAndroid Build Coastguard Worker 1051*523fa7a6SAndroid Build Coastguard Worker def test_symint_conversion(self) -> None: 1052*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 1053*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 1054*523fa7a6SAndroid Build Coastguard Worker return torch.add(x, x.shape[0] - 1) 1055*523fa7a6SAndroid Build Coastguard Worker 1056*523fa7a6SAndroid Build Coastguard Worker f = Foo() 1057*523fa7a6SAndroid Build Coastguard Worker 1058*523fa7a6SAndroid Build Coastguard Worker dim_x = torch.export.Dim("dim_x", max=3) 1059*523fa7a6SAndroid Build Coastguard Worker prog = to_edge( 1060*523fa7a6SAndroid Build Coastguard Worker export( 1061*523fa7a6SAndroid Build Coastguard Worker f, 1062*523fa7a6SAndroid Build Coastguard Worker (torch.ones(3, 2),), 1063*523fa7a6SAndroid Build Coastguard Worker dynamic_shapes={"x": {0: dim_x}}, 1064*523fa7a6SAndroid Build Coastguard Worker ), 1065*523fa7a6SAndroid Build Coastguard Worker compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 1066*523fa7a6SAndroid Build Coastguard Worker ) 1067*523fa7a6SAndroid Build Coastguard Worker prog = prog.transform([SymToTensorPass()]) 1068*523fa7a6SAndroid Build Coastguard Worker 1069*523fa7a6SAndroid Build Coastguard Worker FileCheck().check("torch.ops.aten.scalar_tensor.default").run( 1070*523fa7a6SAndroid Build Coastguard Worker 
prog.exported_program().graph_module.code 1071*523fa7a6SAndroid Build Coastguard Worker ) 1072*523fa7a6SAndroid Build Coastguard Worker self.assertTrue( 1073*523fa7a6SAndroid Build Coastguard Worker torch.allclose( 1074*523fa7a6SAndroid Build Coastguard Worker f(torch.ones(3, 2)), prog.exported_program().module()(torch.ones(3, 2)) 1075*523fa7a6SAndroid Build Coastguard Worker ) 1076*523fa7a6SAndroid Build Coastguard Worker ) 1077*523fa7a6SAndroid Build Coastguard Worker self.assertTrue( 1078*523fa7a6SAndroid Build Coastguard Worker torch.allclose( 1079*523fa7a6SAndroid Build Coastguard Worker f(torch.zeros(3, 2)), 1080*523fa7a6SAndroid Build Coastguard Worker prog.exported_program().module()(torch.zeros(3, 2)), 1081*523fa7a6SAndroid Build Coastguard Worker ) 1082*523fa7a6SAndroid Build Coastguard Worker ) 1083*523fa7a6SAndroid Build Coastguard Worker 1084*523fa7a6SAndroid Build Coastguard Worker def test_remove_assert_pass(self) -> None: 1085*523fa7a6SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 1086*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 1087*523fa7a6SAndroid Build Coastguard Worker assert x.shape[0] == 5 1088*523fa7a6SAndroid Build Coastguard Worker return x * x 1089*523fa7a6SAndroid Build Coastguard Worker 1090*523fa7a6SAndroid Build Coastguard Worker f = Foo() 1091*523fa7a6SAndroid Build Coastguard Worker 1092*523fa7a6SAndroid Build Coastguard Worker gm = to_edge( 1093*523fa7a6SAndroid Build Coastguard Worker export( 1094*523fa7a6SAndroid Build Coastguard Worker f, 1095*523fa7a6SAndroid Build Coastguard Worker (torch.randn(5),), 1096*523fa7a6SAndroid Build Coastguard Worker ), 1097*523fa7a6SAndroid Build Coastguard Worker compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 1098*523fa7a6SAndroid Build Coastguard Worker ) 1099*523fa7a6SAndroid Build Coastguard Worker new_gm = gm.transform([RemoveGraphAssertsPass()]) 1100*523fa7a6SAndroid Build Coastguard Worker num_asserts = [ 
1101*523fa7a6SAndroid Build Coastguard Worker node 1102*523fa7a6SAndroid Build Coastguard Worker for node in new_gm.exported_program().graph.nodes 1103*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_function" 1104*523fa7a6SAndroid Build Coastguard Worker and node.target == torch.ops.aten._assert_async.msg 1105*523fa7a6SAndroid Build Coastguard Worker ] 1106*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(len(num_asserts), 0) 1107*523fa7a6SAndroid Build Coastguard Worker 1108*523fa7a6SAndroid Build Coastguard Worker def test_arange(self) -> None: 1109*523fa7a6SAndroid Build Coastguard Worker class M(torch.nn.Module): 1110*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1111*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1112*523fa7a6SAndroid Build Coastguard Worker self.a = torch.ones(2) 1113*523fa7a6SAndroid Build Coastguard Worker 1114*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 1115*523fa7a6SAndroid Build Coastguard Worker return torch.arange(start=0, end=2) + x 1116*523fa7a6SAndroid Build Coastguard Worker 1117*523fa7a6SAndroid Build Coastguard Worker _ = to_edge( 1118*523fa7a6SAndroid Build Coastguard Worker export( 1119*523fa7a6SAndroid Build Coastguard Worker M(), 1120*523fa7a6SAndroid Build Coastguard Worker (torch.randn(2),), 1121*523fa7a6SAndroid Build Coastguard Worker ) 1122*523fa7a6SAndroid Build Coastguard Worker ).to_executorch() 1123*523fa7a6SAndroid Build Coastguard Worker 1124*523fa7a6SAndroid Build Coastguard Worker def test_replace_slice(self) -> None: 1125*523fa7a6SAndroid Build Coastguard Worker class M(torch.nn.Module): 1126*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1127*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1128*523fa7a6SAndroid Build Coastguard Worker self.a = torch.ones(10) 1129*523fa7a6SAndroid Build Coastguard Worker 1130*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 1131*523fa7a6SAndroid Build Coastguard Worker 
return self.a[:2] + x 1132*523fa7a6SAndroid Build Coastguard Worker 1133*523fa7a6SAndroid Build Coastguard Worker gm = ( 1134*523fa7a6SAndroid Build Coastguard Worker to_edge( 1135*523fa7a6SAndroid Build Coastguard Worker export( 1136*523fa7a6SAndroid Build Coastguard Worker M(), 1137*523fa7a6SAndroid Build Coastguard Worker (torch.randn(2),), 1138*523fa7a6SAndroid Build Coastguard Worker ) 1139*523fa7a6SAndroid Build Coastguard Worker ) 1140*523fa7a6SAndroid Build Coastguard Worker .exported_program() 1141*523fa7a6SAndroid Build Coastguard Worker .graph_module 1142*523fa7a6SAndroid Build Coastguard Worker ) 1143*523fa7a6SAndroid Build Coastguard Worker FileCheck().check( 1144*523fa7a6SAndroid Build Coastguard Worker "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor" 1145*523fa7a6SAndroid Build Coastguard Worker ).run(gm.code) 1146*523fa7a6SAndroid Build Coastguard Worker 1147*523fa7a6SAndroid Build Coastguard Worker def test_constant_prop_pass_for_add(self) -> None: 1148*523fa7a6SAndroid Build Coastguard Worker class Add(torch.nn.Module): 1149*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 1150*523fa7a6SAndroid Build Coastguard Worker return x + 3 1151*523fa7a6SAndroid Build Coastguard Worker 1152*523fa7a6SAndroid Build Coastguard Worker add = Add() 1153*523fa7a6SAndroid Build Coastguard Worker 1154*523fa7a6SAndroid Build Coastguard Worker edge = to_edge( 1155*523fa7a6SAndroid Build Coastguard Worker export(add, (torch.ones(1),)), 1156*523fa7a6SAndroid Build Coastguard Worker compile_config=EdgeCompileConfig(_skip_dim_order=False), 1157*523fa7a6SAndroid Build Coastguard Worker ) 1158*523fa7a6SAndroid Build Coastguard Worker edge = edge.transform([ScalarToTensorPass(), RemoveMixedTypeOperators()]) 1159*523fa7a6SAndroid Build Coastguard Worker exported_program = lift_constant_tensor_pass(edge.exported_program()) 1160*523fa7a6SAndroid Build Coastguard Worker 1161*523fa7a6SAndroid Build Coastguard Worker # Check 
there is a lifted tensor followed by a to_copy node 1162*523fa7a6SAndroid Build Coastguard Worker FileCheck().check("_lifted_tensor_constant0").check( 1163*523fa7a6SAndroid Build Coastguard Worker "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" 1164*523fa7a6SAndroid Build Coastguard Worker ).run(exported_program.graph_module.code) 1165*523fa7a6SAndroid Build Coastguard Worker 1166*523fa7a6SAndroid Build Coastguard Worker new_ep = constant_prop_pass(exported_program) 1167*523fa7a6SAndroid Build Coastguard Worker 1168*523fa7a6SAndroid Build Coastguard Worker # Check (_lifted_tensor_constant + to_copy) node is replaced by prop tensor 1169*523fa7a6SAndroid Build Coastguard Worker FileCheck().check_not("_lifted_tensor_constant").check( 1170*523fa7a6SAndroid Build Coastguard Worker "_prop_tensor_constant0" 1171*523fa7a6SAndroid Build Coastguard Worker ).check_not( 1172*523fa7a6SAndroid Build Coastguard Worker "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" 1173*523fa7a6SAndroid Build Coastguard Worker ).run( 1174*523fa7a6SAndroid Build Coastguard Worker new_ep.graph_module.code 1175*523fa7a6SAndroid Build Coastguard Worker ) 1176*523fa7a6SAndroid Build Coastguard Worker 1177*523fa7a6SAndroid Build Coastguard Worker def test_constant_prop_pass_for_parameter(self) -> None: 1178*523fa7a6SAndroid Build Coastguard Worker def count_additions(gm: torch.fx.GraphModule) -> int: 1179*523fa7a6SAndroid Build Coastguard Worker return sum( 1180*523fa7a6SAndroid Build Coastguard Worker (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes 1181*523fa7a6SAndroid Build Coastguard Worker ) 1182*523fa7a6SAndroid Build Coastguard Worker 1183*523fa7a6SAndroid Build Coastguard Worker class M(torch.nn.Module): 1184*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1185*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1186*523fa7a6SAndroid Build Coastguard Worker self.a = 
torch.nn.Parameter(torch.ones(1, 2, 3)) 1187*523fa7a6SAndroid Build Coastguard Worker 1188*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 1189*523fa7a6SAndroid Build Coastguard Worker b = self.a + self.a 1190*523fa7a6SAndroid Build Coastguard Worker c = torch.cat([self.a, b]) 1191*523fa7a6SAndroid Build Coastguard Worker return (c + c) + x 1192*523fa7a6SAndroid Build Coastguard Worker 1193*523fa7a6SAndroid Build Coastguard Worker aten = export( 1194*523fa7a6SAndroid Build Coastguard Worker M(), 1195*523fa7a6SAndroid Build Coastguard Worker (torch.zeros(2, 2, 3),), 1196*523fa7a6SAndroid Build Coastguard Worker ) 1197*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(count_additions(aten.graph_module), 3) 1198*523fa7a6SAndroid Build Coastguard Worker new_ep = constant_prop_pass(aten) 1199*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(count_additions(new_ep.graph_module), 1) 1200*523fa7a6SAndroid Build Coastguard Worker 1201*523fa7a6SAndroid Build Coastguard Worker def test_constant_prop_pass_graph_signature(self) -> None: 1202*523fa7a6SAndroid Build Coastguard Worker def count_additions(gm: torch.fx.GraphModule) -> int: 1203*523fa7a6SAndroid Build Coastguard Worker return sum( 1204*523fa7a6SAndroid Build Coastguard Worker (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes 1205*523fa7a6SAndroid Build Coastguard Worker ) 1206*523fa7a6SAndroid Build Coastguard Worker 1207*523fa7a6SAndroid Build Coastguard Worker class M(torch.nn.Module): 1208*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1209*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1210*523fa7a6SAndroid Build Coastguard Worker self.a = torch.nn.Parameter(torch.ones(1, 2, 3)) 1211*523fa7a6SAndroid Build Coastguard Worker 1212*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 1213*523fa7a6SAndroid Build Coastguard Worker b = self.a + self.a 1214*523fa7a6SAndroid Build Coastguard Worker c = torch.cat([self.a, b]) 
1215*523fa7a6SAndroid Build Coastguard Worker return (c + c) + x 1216*523fa7a6SAndroid Build Coastguard Worker 1217*523fa7a6SAndroid Build Coastguard Worker aten = export( 1218*523fa7a6SAndroid Build Coastguard Worker M(), 1219*523fa7a6SAndroid Build Coastguard Worker (torch.zeros(2, 2, 3),), 1220*523fa7a6SAndroid Build Coastguard Worker ) 1221*523fa7a6SAndroid Build Coastguard Worker # Input signature will have two entries: 1222*523fa7a6SAndroid Build Coastguard Worker # (1) parameter `a` and (2) user input `x`. 1223*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(len(aten.graph_signature.input_specs), 2) 1224*523fa7a6SAndroid Build Coastguard Worker new_ep = constant_prop_pass(aten) 1225*523fa7a6SAndroid Build Coastguard Worker # Check that there are exactly two propagated tensors - (1) propagated 1226*523fa7a6SAndroid Build Coastguard Worker # constant and (2) user input. 1227*523fa7a6SAndroid Build Coastguard Worker self.assertEqual( 1228*523fa7a6SAndroid Build Coastguard Worker new_ep.graph_signature.input_specs, 1229*523fa7a6SAndroid Build Coastguard Worker [ 1230*523fa7a6SAndroid Build Coastguard Worker InputSpec( 1231*523fa7a6SAndroid Build Coastguard Worker kind=InputKind.CONSTANT_TENSOR, 1232*523fa7a6SAndroid Build Coastguard Worker arg=TensorArgument(name="_prop_tensor_constant0"), 1233*523fa7a6SAndroid Build Coastguard Worker target="_prop_tensor_constant0", 1234*523fa7a6SAndroid Build Coastguard Worker persistent=True, 1235*523fa7a6SAndroid Build Coastguard Worker ), 1236*523fa7a6SAndroid Build Coastguard Worker # User input graph signature. 
1237*523fa7a6SAndroid Build Coastguard Worker aten.graph_signature.input_specs[-1], 1238*523fa7a6SAndroid Build Coastguard Worker ], 1239*523fa7a6SAndroid Build Coastguard Worker ) 1240*523fa7a6SAndroid Build Coastguard Worker 1241*523fa7a6SAndroid Build Coastguard Worker def test_constant_prop_pass_for_parameter_slice(self) -> None: 1242*523fa7a6SAndroid Build Coastguard Worker def count_slice(gm: torch.fx.GraphModule) -> int: 1243*523fa7a6SAndroid Build Coastguard Worker return sum( 1244*523fa7a6SAndroid Build Coastguard Worker (node.target == torch.ops.aten.slice_copy.Tensor) 1245*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes 1246*523fa7a6SAndroid Build Coastguard Worker ) 1247*523fa7a6SAndroid Build Coastguard Worker 1248*523fa7a6SAndroid Build Coastguard Worker class M(torch.nn.Module): 1249*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1250*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1251*523fa7a6SAndroid Build Coastguard Worker self.a = torch.nn.Parameter(torch.ones(3, 2, 2)) 1252*523fa7a6SAndroid Build Coastguard Worker 1253*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 1254*523fa7a6SAndroid Build Coastguard Worker # Create slice of shape (1, 2, 2) 1255*523fa7a6SAndroid Build Coastguard Worker slice_tensor = torch.slice_copy(self.a, dim=0, start=0, end=1) 1256*523fa7a6SAndroid Build Coastguard Worker return torch.cat([x, slice_tensor]) 1257*523fa7a6SAndroid Build Coastguard Worker 1258*523fa7a6SAndroid Build Coastguard Worker aten = export( 1259*523fa7a6SAndroid Build Coastguard Worker M(), 1260*523fa7a6SAndroid Build Coastguard Worker (torch.zeros(2, 2, 2),), 1261*523fa7a6SAndroid Build Coastguard Worker ) 1262*523fa7a6SAndroid Build Coastguard Worker self.assertIn("a", aten.state_dict) 1263*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(count_slice(aten.graph_module), 1) 1264*523fa7a6SAndroid Build Coastguard Worker 1265*523fa7a6SAndroid Build Coastguard Worker new_ep = 
constant_prop_pass(aten) 1266*523fa7a6SAndroid Build Coastguard Worker # Check there is a propagated tensor. 1267*523fa7a6SAndroid Build Coastguard Worker FileCheck().check("_prop_tensor_constant0").run(aten.graph_module.code) 1268*523fa7a6SAndroid Build Coastguard Worker self.assertIn("_prop_tensor_constant0", new_ep.constants) 1269*523fa7a6SAndroid Build Coastguard Worker self.assertNotIn("a", new_ep.state_dict) 1270*523fa7a6SAndroid Build Coastguard Worker # No more slice copy. 1271*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(count_slice(new_ep.graph_module), 0) 1272*523fa7a6SAndroid Build Coastguard Worker 1273*523fa7a6SAndroid Build Coastguard Worker def test_constant_prop_pass_no_propagate(self) -> None: 1274*523fa7a6SAndroid Build Coastguard Worker def count_placeholder(gm: torch.fx.GraphModule) -> int: 1275*523fa7a6SAndroid Build Coastguard Worker return sum((node.op == "placeholder") for node in gm.graph.nodes) 1276*523fa7a6SAndroid Build Coastguard Worker 1277*523fa7a6SAndroid Build Coastguard Worker class M(torch.nn.Module): 1278*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1279*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1280*523fa7a6SAndroid Build Coastguard Worker self.a = torch.nn.Parameter(torch.ones(3, 2, 4)) 1281*523fa7a6SAndroid Build Coastguard Worker 1282*523fa7a6SAndroid Build Coastguard Worker def forward(self, x, y): 1283*523fa7a6SAndroid Build Coastguard Worker # y is unused. 
1284*523fa7a6SAndroid Build Coastguard Worker return x + self.a 1285*523fa7a6SAndroid Build Coastguard Worker 1286*523fa7a6SAndroid Build Coastguard Worker aten = export( 1287*523fa7a6SAndroid Build Coastguard Worker M(), 1288*523fa7a6SAndroid Build Coastguard Worker (torch.zeros(3, 2, 4), torch.zeros(3, 2, 4)), 1289*523fa7a6SAndroid Build Coastguard Worker ) 1290*523fa7a6SAndroid Build Coastguard Worker self.assertIn("a", aten.state_dict) 1291*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(count_placeholder(aten.graph_module), 3) 1292*523fa7a6SAndroid Build Coastguard Worker 1293*523fa7a6SAndroid Build Coastguard Worker new_ep = constant_prop_pass(aten) 1294*523fa7a6SAndroid Build Coastguard Worker # Check there is no propagated tensor. 1295*523fa7a6SAndroid Build Coastguard Worker FileCheck().check("p_a").check("x").check("y").run(aten.graph_module.code) 1296*523fa7a6SAndroid Build Coastguard Worker self.assertNotIn("_prop_tensor_constant0", new_ep.constants) 1297*523fa7a6SAndroid Build Coastguard Worker self.assertIn("a", new_ep.state_dict) 1298*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(count_placeholder(new_ep.graph_module), 3) 1299*523fa7a6SAndroid Build Coastguard Worker 1300*523fa7a6SAndroid Build Coastguard Worker def test_constant_prop_pass_for_control_flow(self) -> None: 1301*523fa7a6SAndroid Build Coastguard Worker class Module(torch.nn.Module): 1302*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1303*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1304*523fa7a6SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(3, 3) 1305*523fa7a6SAndroid Build Coastguard Worker 1306*523fa7a6SAndroid Build Coastguard Worker def t(self, val): 1307*523fa7a6SAndroid Build Coastguard Worker return val + 1 1308*523fa7a6SAndroid Build Coastguard Worker 1309*523fa7a6SAndroid Build Coastguard Worker def f(self, val): 1310*523fa7a6SAndroid Build Coastguard Worker return val - 1 1311*523fa7a6SAndroid Build Coastguard 
Worker 1312*523fa7a6SAndroid Build Coastguard Worker def true_fn(self, val): 1313*523fa7a6SAndroid Build Coastguard Worker return self.linear(val) + self.t(val) 1314*523fa7a6SAndroid Build Coastguard Worker 1315*523fa7a6SAndroid Build Coastguard Worker def false_fn(self, val): 1316*523fa7a6SAndroid Build Coastguard Worker return self.linear(val) - self.f(val) 1317*523fa7a6SAndroid Build Coastguard Worker 1318*523fa7a6SAndroid Build Coastguard Worker def forward(self, pred, x): 1319*523fa7a6SAndroid Build Coastguard Worker return torch.ops.higher_order.cond( 1320*523fa7a6SAndroid Build Coastguard Worker pred, self.true_fn, self.false_fn, [x] 1321*523fa7a6SAndroid Build Coastguard Worker ) 1322*523fa7a6SAndroid Build Coastguard Worker 1323*523fa7a6SAndroid Build Coastguard Worker mod = Module() 1324*523fa7a6SAndroid Build Coastguard Worker x = torch.randn([3, 3]) 1325*523fa7a6SAndroid Build Coastguard Worker pred = torch.tensor(x[0][0].item() < 0) 1326*523fa7a6SAndroid Build Coastguard Worker edge = to_edge( 1327*523fa7a6SAndroid Build Coastguard Worker export(mod, (pred, x)), 1328*523fa7a6SAndroid Build Coastguard Worker compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 1329*523fa7a6SAndroid Build Coastguard Worker ) 1330*523fa7a6SAndroid Build Coastguard Worker error_msg = r"constant_prop_pass for control flow is not supported yet." 
1331*523fa7a6SAndroid Build Coastguard Worker 1332*523fa7a6SAndroid Build Coastguard Worker # TODO(chenlai): enable constant prop pass for control flow 1333*523fa7a6SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1334*523fa7a6SAndroid Build Coastguard Worker RuntimeError, 1335*523fa7a6SAndroid Build Coastguard Worker error_msg, 1336*523fa7a6SAndroid Build Coastguard Worker ): 1337*523fa7a6SAndroid Build Coastguard Worker _ = constant_prop_pass(edge.exported_program()) 1338*523fa7a6SAndroid Build Coastguard Worker 1339*523fa7a6SAndroid Build Coastguard Worker def test_mutable_buffers(self) -> None: 1340*523fa7a6SAndroid Build Coastguard Worker def count_copies(gm: torch.fx.GraphModule) -> int: 1341*523fa7a6SAndroid Build Coastguard Worker return sum( 1342*523fa7a6SAndroid Build Coastguard Worker (node.target == torch.ops.aten.copy_.default) for node in gm.graph.nodes 1343*523fa7a6SAndroid Build Coastguard Worker ) 1344*523fa7a6SAndroid Build Coastguard Worker 1345*523fa7a6SAndroid Build Coastguard Worker class MutableStateModule(torch.nn.Module): 1346*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1347*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1348*523fa7a6SAndroid Build Coastguard Worker self.register_buffer("state", torch.zeros(1)) 1349*523fa7a6SAndroid Build Coastguard Worker 1350*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 1351*523fa7a6SAndroid Build Coastguard Worker y = x + self.state 1352*523fa7a6SAndroid Build Coastguard Worker self.state.add_(1) 1353*523fa7a6SAndroid Build Coastguard Worker return y 1354*523fa7a6SAndroid Build Coastguard Worker 1355*523fa7a6SAndroid Build Coastguard Worker model = to_edge( 1356*523fa7a6SAndroid Build Coastguard Worker export( 1357*523fa7a6SAndroid Build Coastguard Worker MutableStateModule(), 1358*523fa7a6SAndroid Build Coastguard Worker (torch.zeros(1),), 1359*523fa7a6SAndroid Build Coastguard Worker ) 1360*523fa7a6SAndroid Build Coastguard Worker ) 
1361*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(count_copies(model.exported_program().graph_module), 0) 1362*523fa7a6SAndroid Build Coastguard Worker # Before 1363*523fa7a6SAndroid Build Coastguard Worker # graph(): 1364*523fa7a6SAndroid Build Coastguard Worker # %arg0_1 : [num_users=2] = placeholder[target=arg0_1] 1365*523fa7a6SAndroid Build Coastguard Worker # %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1] 1366*523fa7a6SAndroid Build Coastguard Worker # %arg1_1 : [num_users=1] = placeholder[target=arg1_1] 1367*523fa7a6SAndroid Build Coastguard Worker # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {}) 1368*523fa7a6SAndroid Build Coastguard Worker # %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32}) 1369*523fa7a6SAndroid Build Coastguard Worker # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {}) 1370*523fa7a6SAndroid Build Coastguard Worker # return (aten_add_tensor_1, aten_add_tensor) 1371*523fa7a6SAndroid Build Coastguard Worker gm, _ = insert_write_back_for_buffers_pass(model.exported_program()) 1372*523fa7a6SAndroid Build Coastguard Worker 1373*523fa7a6SAndroid Build Coastguard Worker # After 1374*523fa7a6SAndroid Build Coastguard Worker # graph(): 1375*523fa7a6SAndroid Build Coastguard Worker # %arg0_1 : [num_users=3] = placeholder[target=arg0_1] 1376*523fa7a6SAndroid Build Coastguard Worker # %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1] 1377*523fa7a6SAndroid Build Coastguard Worker # %arg1_1 : [num_users=1] = placeholder[target=arg1_1] 1378*523fa7a6SAndroid Build Coastguard Worker # %aten_add_tensor : 
[num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {}) 1379*523fa7a6SAndroid Build Coastguard Worker # %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32}) 1380*523fa7a6SAndroid Build Coastguard Worker # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {}) 1381*523fa7a6SAndroid Build Coastguard Worker # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {}) 1382*523fa7a6SAndroid Build Coastguard Worker # return (copy__default, aten_add_tensor) 1383*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(count_copies(gm), 1) 1384*523fa7a6SAndroid Build Coastguard Worker 1385*523fa7a6SAndroid Build Coastguard Worker def test_remove_quantized_op_noop_pass(self) -> None: 1386*523fa7a6SAndroid Build Coastguard Worker class TestAddSliceNoop(torch.nn.Module): 1387*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1388*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1389*523fa7a6SAndroid Build Coastguard Worker 1390*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 1391*523fa7a6SAndroid Build Coastguard Worker x = x + x 1392*523fa7a6SAndroid Build Coastguard Worker x = x + x[:] 1393*523fa7a6SAndroid Build Coastguard Worker return x 1394*523fa7a6SAndroid Build Coastguard Worker 1395*523fa7a6SAndroid Build Coastguard Worker class TestAddSliceNotNoop(torch.nn.Module): 1396*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1397*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1398*523fa7a6SAndroid Build Coastguard Worker 1399*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 1400*523fa7a6SAndroid Build 
Coastguard Worker x = x + x 1401*523fa7a6SAndroid Build Coastguard Worker x = x + x[:1] 1402*523fa7a6SAndroid Build Coastguard Worker return x 1403*523fa7a6SAndroid Build Coastguard Worker 1404*523fa7a6SAndroid Build Coastguard Worker def count_dq_nodes(gm: torch.fx.GraphModule) -> int: 1405*523fa7a6SAndroid Build Coastguard Worker return sum( 1406*523fa7a6SAndroid Build Coastguard Worker ( 1407*523fa7a6SAndroid Build Coastguard Worker node.target 1408*523fa7a6SAndroid Build Coastguard Worker in ( 1409*523fa7a6SAndroid Build Coastguard Worker torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1410*523fa7a6SAndroid Build Coastguard Worker exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, 1411*523fa7a6SAndroid Build Coastguard Worker ) 1412*523fa7a6SAndroid Build Coastguard Worker ) 1413*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes 1414*523fa7a6SAndroid Build Coastguard Worker ) 1415*523fa7a6SAndroid Build Coastguard Worker 1416*523fa7a6SAndroid Build Coastguard Worker def count_q_nodes(gm: torch.fx.GraphModule) -> int: 1417*523fa7a6SAndroid Build Coastguard Worker return sum( 1418*523fa7a6SAndroid Build Coastguard Worker ( 1419*523fa7a6SAndroid Build Coastguard Worker node.target 1420*523fa7a6SAndroid Build Coastguard Worker in ( 1421*523fa7a6SAndroid Build Coastguard Worker torch.ops.quantized_decomposed.quantize_per_tensor.default, 1422*523fa7a6SAndroid Build Coastguard Worker exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, 1423*523fa7a6SAndroid Build Coastguard Worker ) 1424*523fa7a6SAndroid Build Coastguard Worker ) 1425*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes 1426*523fa7a6SAndroid Build Coastguard Worker ) 1427*523fa7a6SAndroid Build Coastguard Worker 1428*523fa7a6SAndroid Build Coastguard Worker def quantize_model( 1429*523fa7a6SAndroid Build Coastguard Worker m_eager: torch.nn.Module, example_inputs: Tuple[torch.Tensor] 1430*523fa7a6SAndroid Build Coastguard Worker ) 
-> Tuple[EdgeProgramManager, int, int]: 1431*523fa7a6SAndroid Build Coastguard Worker # program capture 1432*523fa7a6SAndroid Build Coastguard Worker m = torch.export.export_for_training( 1433*523fa7a6SAndroid Build Coastguard Worker m_eager, 1434*523fa7a6SAndroid Build Coastguard Worker example_inputs, 1435*523fa7a6SAndroid Build Coastguard Worker ).module() 1436*523fa7a6SAndroid Build Coastguard Worker 1437*523fa7a6SAndroid Build Coastguard Worker quantizer = XNNPACKQuantizer() 1438*523fa7a6SAndroid Build Coastguard Worker quantization_config = get_symmetric_quantization_config() 1439*523fa7a6SAndroid Build Coastguard Worker quantizer.set_global(quantization_config) 1440*523fa7a6SAndroid Build Coastguard Worker m = prepare_pt2e(m, quantizer) # pyre-fixme[6] 1441*523fa7a6SAndroid Build Coastguard Worker m = convert_pt2e(m, fold_quantize=True) 1442*523fa7a6SAndroid Build Coastguard Worker ep = torch.export.export(m, example_inputs) 1443*523fa7a6SAndroid Build Coastguard Worker dq_nodes_pre = count_dq_nodes(ep.graph_module) 1444*523fa7a6SAndroid Build Coastguard Worker q_nodes_pre = count_q_nodes(ep.graph_module) 1445*523fa7a6SAndroid Build Coastguard Worker edge = to_edge( 1446*523fa7a6SAndroid Build Coastguard Worker ep, compile_config=EdgeCompileConfig(_check_ir_validity=False) 1447*523fa7a6SAndroid Build Coastguard Worker ) 1448*523fa7a6SAndroid Build Coastguard Worker return edge, dq_nodes_pre, q_nodes_pre 1449*523fa7a6SAndroid Build Coastguard Worker 1450*523fa7a6SAndroid Build Coastguard Worker example_inputs = (torch.randn(9, 8),) 1451*523fa7a6SAndroid Build Coastguard Worker model = TestAddSliceNoop() 1452*523fa7a6SAndroid Build Coastguard Worker m_eager = model.eval() 1453*523fa7a6SAndroid Build Coastguard Worker edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs) 1454*523fa7a6SAndroid Build Coastguard Worker 1455*523fa7a6SAndroid Build Coastguard Worker dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module) 
1456*523fa7a6SAndroid Build Coastguard Worker q_nodes_post = count_q_nodes(edge.exported_program().graph_module) 1457*523fa7a6SAndroid Build Coastguard Worker # One dq and one q node around the slice copy should have been removed. 1458*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(dq_nodes_pre - dq_nodes_post, 1) 1459*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(q_nodes_pre - q_nodes_post, 1) 1460*523fa7a6SAndroid Build Coastguard Worker 1461*523fa7a6SAndroid Build Coastguard Worker # Check that the slice_copy is removed by the RemoveNoopPass. 1462*523fa7a6SAndroid Build Coastguard Worker for node in edge.exported_program().graph_module.graph.nodes: 1463*523fa7a6SAndroid Build Coastguard Worker self.assertFalse("slice" in str(node.target)) 1464*523fa7a6SAndroid Build Coastguard Worker 1465*523fa7a6SAndroid Build Coastguard Worker model = TestAddSliceNotNoop() 1466*523fa7a6SAndroid Build Coastguard Worker m_eager = model.eval() 1467*523fa7a6SAndroid Build Coastguard Worker edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs) 1468*523fa7a6SAndroid Build Coastguard Worker 1469*523fa7a6SAndroid Build Coastguard Worker dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module) 1470*523fa7a6SAndroid Build Coastguard Worker q_nodes_post = count_q_nodes(edge.exported_program().graph_module) 1471*523fa7a6SAndroid Build Coastguard Worker # One dq and one q node around the slice copy should have been removed. 1472*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(dq_nodes_pre, dq_nodes_post) 1473*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(q_nodes_pre, q_nodes_post) 1474*523fa7a6SAndroid Build Coastguard Worker 1475*523fa7a6SAndroid Build Coastguard Worker # Check that the slice_copy is not removed by the RemoveNoopPass. 
1476*523fa7a6SAndroid Build Coastguard Worker self.assertTrue( 1477*523fa7a6SAndroid Build Coastguard Worker any( 1478*523fa7a6SAndroid Build Coastguard Worker "slice" in str(node.target) 1479*523fa7a6SAndroid Build Coastguard Worker for node in edge.exported_program().graph_module.graph.nodes 1480*523fa7a6SAndroid Build Coastguard Worker ) 1481*523fa7a6SAndroid Build Coastguard Worker ) 1482*523fa7a6SAndroid Build Coastguard Worker 1483*523fa7a6SAndroid Build Coastguard Worker def test_dq_q_no_op_pass(self) -> None: 1484*523fa7a6SAndroid Build Coastguard Worker class TestDqQ(torch.nn.Module): 1485*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1486*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1487*523fa7a6SAndroid Build Coastguard Worker 1488*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 1489*523fa7a6SAndroid Build Coastguard Worker dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default( 1490*523fa7a6SAndroid Build Coastguard Worker x, 1.0, 0, -128, 127, torch.int8 1491*523fa7a6SAndroid Build Coastguard Worker ) 1492*523fa7a6SAndroid Build Coastguard Worker q = torch.ops.quantized_decomposed.quantize_per_tensor.default( 1493*523fa7a6SAndroid Build Coastguard Worker dq, 1.0, 0, -128, 127, torch.int8 1494*523fa7a6SAndroid Build Coastguard Worker ) 1495*523fa7a6SAndroid Build Coastguard Worker return q 1496*523fa7a6SAndroid Build Coastguard Worker 1497*523fa7a6SAndroid Build Coastguard Worker model = TestDqQ() 1498*523fa7a6SAndroid Build Coastguard Worker m_eager = model.eval() 1499*523fa7a6SAndroid Build Coastguard Worker ep = torch.export.export(m_eager, (torch.randn(9, 8),)) 1500*523fa7a6SAndroid Build Coastguard Worker edge = to_edge(ep) 1501*523fa7a6SAndroid Build Coastguard Worker # Check that the dq and q nodes are not touched by the RemoveNoopPass. 
1502*523fa7a6SAndroid Build Coastguard Worker self.assertTrue( 1503*523fa7a6SAndroid Build Coastguard Worker any( 1504*523fa7a6SAndroid Build Coastguard Worker "dequantize" in str(node.target) 1505*523fa7a6SAndroid Build Coastguard Worker for node in edge.exported_program().graph_module.graph.nodes 1506*523fa7a6SAndroid Build Coastguard Worker ) 1507*523fa7a6SAndroid Build Coastguard Worker ) 1508*523fa7a6SAndroid Build Coastguard Worker self.assertTrue( 1509*523fa7a6SAndroid Build Coastguard Worker any( 1510*523fa7a6SAndroid Build Coastguard Worker "quantize" in str(node.target) 1511*523fa7a6SAndroid Build Coastguard Worker for node in edge.exported_program().graph_module.graph.nodes 1512*523fa7a6SAndroid Build Coastguard Worker ) 1513*523fa7a6SAndroid Build Coastguard Worker ) 1514*523fa7a6SAndroid Build Coastguard Worker 1515*523fa7a6SAndroid Build Coastguard Worker def test_dq_q_different_qparams(self) -> None: 1516*523fa7a6SAndroid Build Coastguard Worker class TestDqQDifferentQParam(torch.nn.Module): 1517*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1518*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1519*523fa7a6SAndroid Build Coastguard Worker 1520*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 1521*523fa7a6SAndroid Build Coastguard Worker dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default( 1522*523fa7a6SAndroid Build Coastguard Worker x, 1.0, 0, -128, 127, torch.int8 1523*523fa7a6SAndroid Build Coastguard Worker ) 1524*523fa7a6SAndroid Build Coastguard Worker slice_copy_output = torch.ops.aten.slice_copy.Tensor(dq, 0, 0) 1525*523fa7a6SAndroid Build Coastguard Worker q = torch.ops.quantized_decomposed.quantize_per_tensor.default( 1526*523fa7a6SAndroid Build Coastguard Worker slice_copy_output, 1.0, 0, -127, 127, torch.int8 1527*523fa7a6SAndroid Build Coastguard Worker ) 1528*523fa7a6SAndroid Build Coastguard Worker return q 1529*523fa7a6SAndroid Build Coastguard Worker 1530*523fa7a6SAndroid Build 
Coastguard Worker model = TestDqQDifferentQParam() 1531*523fa7a6SAndroid Build Coastguard Worker m_eager = model.eval() 1532*523fa7a6SAndroid Build Coastguard Worker ep = torch.export.export(m_eager, (torch.randn(9, 8),)) 1533*523fa7a6SAndroid Build Coastguard Worker edge = to_edge(ep) 1534*523fa7a6SAndroid Build Coastguard Worker print(edge.exported_program().graph_module.graph) 1535*523fa7a6SAndroid Build Coastguard Worker # Check that the dq and q nodes are not touched by the RemoveNoopPass. 1536*523fa7a6SAndroid Build Coastguard Worker self.assertTrue( 1537*523fa7a6SAndroid Build Coastguard Worker any( 1538*523fa7a6SAndroid Build Coastguard Worker "dequantize" in str(node.target) 1539*523fa7a6SAndroid Build Coastguard Worker for node in edge.exported_program().graph_module.graph.nodes 1540*523fa7a6SAndroid Build Coastguard Worker ) 1541*523fa7a6SAndroid Build Coastguard Worker ) 1542*523fa7a6SAndroid Build Coastguard Worker self.assertTrue( 1543*523fa7a6SAndroid Build Coastguard Worker any( 1544*523fa7a6SAndroid Build Coastguard Worker "quantize" in str(node.target) 1545*523fa7a6SAndroid Build Coastguard Worker for node in edge.exported_program().graph_module.graph.nodes 1546*523fa7a6SAndroid Build Coastguard Worker ) 1547*523fa7a6SAndroid Build Coastguard Worker ) 1548*523fa7a6SAndroid Build Coastguard Worker self.assertFalse( 1549*523fa7a6SAndroid Build Coastguard Worker any( 1550*523fa7a6SAndroid Build Coastguard Worker "slice" in str(node.target) 1551*523fa7a6SAndroid Build Coastguard Worker for node in edge.exported_program().graph_module.graph.nodes 1552*523fa7a6SAndroid Build Coastguard Worker ) 1553*523fa7a6SAndroid Build Coastguard Worker ) 1554*523fa7a6SAndroid Build Coastguard Worker 1555*523fa7a6SAndroid Build Coastguard Worker def test_normalize_view_copy_base_pass(self) -> None: 1556*523fa7a6SAndroid Build Coastguard Worker 1557*523fa7a6SAndroid Build Coastguard Worker class ViewChain(torch.nn.Module): 1558*523fa7a6SAndroid Build Coastguard Worker 
def forward(self, x): 1559*523fa7a6SAndroid Build Coastguard Worker x = torch.ops.aten.view_copy.default(x, [30, 1]) 1560*523fa7a6SAndroid Build Coastguard Worker x = torch.ops.aten.view_copy.default(x, [5, 6]) 1561*523fa7a6SAndroid Build Coastguard Worker x = torch.ops.aten.view_copy.default(x, [2, 15]) 1562*523fa7a6SAndroid Build Coastguard Worker x = torch.ops.aten.view_copy.default(x, [3, -1]) 1563*523fa7a6SAndroid Build Coastguard Worker return x 1564*523fa7a6SAndroid Build Coastguard Worker 1565*523fa7a6SAndroid Build Coastguard Worker def is_view_copy(node: torch.fx.Node) -> bool: 1566*523fa7a6SAndroid Build Coastguard Worker return ( 1567*523fa7a6SAndroid Build Coastguard Worker node.op == "call_function" 1568*523fa7a6SAndroid Build Coastguard Worker and node.target == torch.ops.aten.view_copy.default 1569*523fa7a6SAndroid Build Coastguard Worker ) 1570*523fa7a6SAndroid Build Coastguard Worker 1571*523fa7a6SAndroid Build Coastguard Worker gm = export(ViewChain(), (torch.ones(30),)).graph_module 1572*523fa7a6SAndroid Build Coastguard Worker 1573*523fa7a6SAndroid Build Coastguard Worker # Check before transformation 1574*523fa7a6SAndroid Build Coastguard Worker n_view_copy_before = 0 1575*523fa7a6SAndroid Build Coastguard Worker n_view_copy_bases_before = 0 1576*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes: 1577*523fa7a6SAndroid Build Coastguard Worker if is_view_copy(node): 1578*523fa7a6SAndroid Build Coastguard Worker n_view_copy_before += 1 1579*523fa7a6SAndroid Build Coastguard Worker base = node.args[0] 1580*523fa7a6SAndroid Build Coastguard Worker if is_view_copy(base): 1581*523fa7a6SAndroid Build Coastguard Worker n_view_copy_bases_before += 1 1582*523fa7a6SAndroid Build Coastguard Worker 1583*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(n_view_copy_before, 4) 1584*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(n_view_copy_bases_before, 3) 1585*523fa7a6SAndroid Build Coastguard Worker 1586*523fa7a6SAndroid 
Build Coastguard Worker # Do transformation 1587*523fa7a6SAndroid Build Coastguard Worker p = NormalizeViewCopyBasePass() 1588*523fa7a6SAndroid Build Coastguard Worker gm_res = p(gm) 1589*523fa7a6SAndroid Build Coastguard Worker assert gm_res is not None 1590*523fa7a6SAndroid Build Coastguard Worker gm = gm_res.graph_module 1591*523fa7a6SAndroid Build Coastguard Worker 1592*523fa7a6SAndroid Build Coastguard Worker # Check after transformation 1593*523fa7a6SAndroid Build Coastguard Worker n_view_copy_after = 0 1594*523fa7a6SAndroid Build Coastguard Worker n_view_copy_bases_after = 0 1595*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes: 1596*523fa7a6SAndroid Build Coastguard Worker if is_view_copy(node): 1597*523fa7a6SAndroid Build Coastguard Worker n_view_copy_after += 1 1598*523fa7a6SAndroid Build Coastguard Worker base = node.args[0] 1599*523fa7a6SAndroid Build Coastguard Worker if is_view_copy(base): 1600*523fa7a6SAndroid Build Coastguard Worker n_view_copy_bases_after += 1 1601*523fa7a6SAndroid Build Coastguard Worker 1602*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(n_view_copy_after, 4) 1603*523fa7a6SAndroid Build Coastguard Worker self.assertEqual(n_view_copy_bases_after, 0) 1604*523fa7a6SAndroid Build Coastguard Worker 1605*523fa7a6SAndroid Build Coastguard Worker def test_replace_view_copy_with_view_pass(self) -> None: # noqa: C901 1606*523fa7a6SAndroid Build Coastguard Worker 1607*523fa7a6SAndroid Build Coastguard Worker # Helper functions 1608*523fa7a6SAndroid Build Coastguard Worker def is_view_copy(node: torch.fx.Node) -> bool: 1609*523fa7a6SAndroid Build Coastguard Worker return ( 1610*523fa7a6SAndroid Build Coastguard Worker node.op == "call_function" 1611*523fa7a6SAndroid Build Coastguard Worker and node.target == torch.ops.aten.view_copy.default 1612*523fa7a6SAndroid Build Coastguard Worker ) 1613*523fa7a6SAndroid Build Coastguard Worker 1614*523fa7a6SAndroid Build Coastguard Worker def is_memory_view(node: 
torch.fx.Node) -> bool: 1615*523fa7a6SAndroid Build Coastguard Worker return node.op == "call_function" and node.target == memory.view 1616*523fa7a6SAndroid Build Coastguard Worker 1617*523fa7a6SAndroid Build Coastguard Worker # Test example set up 1618*523fa7a6SAndroid Build Coastguard Worker class TestViewCopies(torch.nn.Module): 1619*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1620*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1621*523fa7a6SAndroid Build Coastguard Worker self.parameter = torch.nn.Parameter(torch.ones(1)) 1622*523fa7a6SAndroid Build Coastguard Worker 1623*523fa7a6SAndroid Build Coastguard Worker def forward(self, x): 1624*523fa7a6SAndroid Build Coastguard Worker o1 = torch.ops.aten.view_copy.default(x, [1]) 1625*523fa7a6SAndroid Build Coastguard Worker o2 = torch.ops.aten.view_copy.default(self.parameter, [1]) 1626*523fa7a6SAndroid Build Coastguard Worker # view_copys at the end of a function are not replaced, so add 1627*523fa7a6SAndroid Build Coastguard Worker # a computation before the end of the graph. 
1628*523fa7a6SAndroid Build Coastguard Worker return torch.ops.aten.add.Tensor(o1, o2) 1629*523fa7a6SAndroid Build Coastguard Worker 1630*523fa7a6SAndroid Build Coastguard Worker ep = torch.export.export( 1631*523fa7a6SAndroid Build Coastguard Worker TestViewCopies(), 1632*523fa7a6SAndroid Build Coastguard Worker args=(torch.ones(1),), 1633*523fa7a6SAndroid Build Coastguard Worker ) 1634*523fa7a6SAndroid Build Coastguard Worker for node in ep.graph.nodes: 1635*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder": 1636*523fa7a6SAndroid Build Coastguard Worker node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1)) 1637*523fa7a6SAndroid Build Coastguard Worker node.meta["spec"].shape_dynamism = TensorShapeDynamism.STATIC 1638*523fa7a6SAndroid Build Coastguard Worker 1639*523fa7a6SAndroid Build Coastguard Worker # Run tests 1640*523fa7a6SAndroid Build Coastguard Worker gm = ep.graph_module 1641*523fa7a6SAndroid Build Coastguard Worker 1642*523fa7a6SAndroid Build Coastguard Worker # Check before transformation 1643*523fa7a6SAndroid Build Coastguard Worker FileCheck().check_count( 1644*523fa7a6SAndroid Build Coastguard Worker "torch.ops.aten.view_copy.default", 2, exactly=True 1645*523fa7a6SAndroid Build Coastguard Worker ).run(gm.code) 1646*523fa7a6SAndroid Build Coastguard Worker FileCheck().check_count("executorch_exir_memory_view", 0, exactly=True).run( 1647*523fa7a6SAndroid Build Coastguard Worker gm.code 1648*523fa7a6SAndroid Build Coastguard Worker ) 1649*523fa7a6SAndroid Build Coastguard Worker 1650*523fa7a6SAndroid Build Coastguard Worker # Do transformation 1651*523fa7a6SAndroid Build Coastguard Worker p = ReplaceViewCopyWithViewPass() 1652*523fa7a6SAndroid Build Coastguard Worker gm_res = p(gm) 1653*523fa7a6SAndroid Build Coastguard Worker assert gm_res is not None 1654*523fa7a6SAndroid Build Coastguard Worker gm = gm_res.graph_module 1655*523fa7a6SAndroid Build Coastguard Worker 1656*523fa7a6SAndroid Build Coastguard Worker # Check after 
transformation 1657*523fa7a6SAndroid Build Coastguard Worker FileCheck().check_count( 1658*523fa7a6SAndroid Build Coastguard Worker "torch.ops.aten.view_copy.default", 0, exactly=True 1659*523fa7a6SAndroid Build Coastguard Worker ).run(gm.code) 1660*523fa7a6SAndroid Build Coastguard Worker FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run( 1661*523fa7a6SAndroid Build Coastguard Worker gm.code 1662*523fa7a6SAndroid Build Coastguard Worker ) 1663*523fa7a6SAndroid Build Coastguard Worker 1664*523fa7a6SAndroid Build Coastguard Worker def test_constant_prop_pass_for_no_grad(self) -> None: 1665*523fa7a6SAndroid Build Coastguard Worker class LSTM(torch.nn.Module): 1666*523fa7a6SAndroid Build Coastguard Worker def __init__(self, input_size, hidden_size, num_layers): 1667*523fa7a6SAndroid Build Coastguard Worker super(LSTM, self).__init__() 1668*523fa7a6SAndroid Build Coastguard Worker self.hidden_size = hidden_size 1669*523fa7a6SAndroid Build Coastguard Worker self.num_layers = num_layers 1670*523fa7a6SAndroid Build Coastguard Worker self.lstm = torch.nn.LSTM( 1671*523fa7a6SAndroid Build Coastguard Worker input_size, hidden_size, num_layers, batch_first=True 1672*523fa7a6SAndroid Build Coastguard Worker ) 1673*523fa7a6SAndroid Build Coastguard Worker 1674*523fa7a6SAndroid Build Coastguard Worker def forward(self, text_tokens): 1675*523fa7a6SAndroid Build Coastguard Worker # input: (seq_len, batch, input_size) 1676*523fa7a6SAndroid Build Coastguard Worker lstm_out, (new_hidden_state, new_cell_state) = self.lstm( 1677*523fa7a6SAndroid Build Coastguard Worker input=text_tokens, hx=None 1678*523fa7a6SAndroid Build Coastguard Worker ) 1679*523fa7a6SAndroid Build Coastguard Worker return lstm_out 1680*523fa7a6SAndroid Build Coastguard Worker 1681*523fa7a6SAndroid Build Coastguard Worker lstm = LSTM(input_size=200, hidden_size=203, num_layers=2) 1682*523fa7a6SAndroid Build Coastguard Worker example_input = (torch.rand(2, 10, 200),) 1683*523fa7a6SAndroid 
Build Coastguard Worker 1684*523fa7a6SAndroid Build Coastguard Worker aten = torch.export.export(lstm, example_input, strict=False) 1685*523fa7a6SAndroid Build Coastguard Worker _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig( 1686*523fa7a6SAndroid Build Coastguard Worker _check_ir_validity=True, 1687*523fa7a6SAndroid Build Coastguard Worker _skip_dim_order=True, # TODO(T189114319): Reuse dim order op after solving the ios oss issue 1688*523fa7a6SAndroid Build Coastguard Worker ) 1689*523fa7a6SAndroid Build Coastguard Worker 1690*523fa7a6SAndroid Build Coastguard Worker edge_manager: EdgeProgramManager = to_edge( 1691*523fa7a6SAndroid Build Coastguard Worker aten, 1692*523fa7a6SAndroid Build Coastguard Worker compile_config=_EDGE_COMPILE_CONFIG, 1693*523fa7a6SAndroid Build Coastguard Worker ) 1694*523fa7a6SAndroid Build Coastguard Worker new_ep = constant_prop_pass(edge_manager._edge_programs["forward"]) 1695*523fa7a6SAndroid Build Coastguard Worker _ = copy.deepcopy(new_ep.module_call_graph) 1696*523fa7a6SAndroid Build Coastguard Worker 1697*523fa7a6SAndroid Build Coastguard Worker def test_dim_order_revert_pass(self) -> None: 1698*523fa7a6SAndroid Build Coastguard Worker aten_op_str = "torch.ops.aten._to_copy.default" 1699*523fa7a6SAndroid Build Coastguard Worker edge_aten_op_str = "executorch_exir_dialects_edge__ops_aten__to_copy_default" 1700*523fa7a6SAndroid Build Coastguard Worker edge_dim_order_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" 1701*523fa7a6SAndroid Build Coastguard Worker 1702*523fa7a6SAndroid Build Coastguard Worker class Module(torch.nn.Module): 1703*523fa7a6SAndroid Build Coastguard Worker """ 1704*523fa7a6SAndroid Build Coastguard Worker A simple module that has a single to op that converts to channels last and then back to contiguous. 1705*523fa7a6SAndroid Build Coastguard Worker Assuming contiguous input. 
1706*523fa7a6SAndroid Build Coastguard Worker """ 1707*523fa7a6SAndroid Build Coastguard Worker 1708*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 1709*523fa7a6SAndroid Build Coastguard Worker super().__init__() 1710*523fa7a6SAndroid Build Coastguard Worker 1711*523fa7a6SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 1712*523fa7a6SAndroid Build Coastguard Worker return x.to(memory_format=torch.channels_last).to( 1713*523fa7a6SAndroid Build Coastguard Worker memory_format=torch.contiguous_format 1714*523fa7a6SAndroid Build Coastguard Worker ) + x.to(memory_format=torch.channels_last).to( 1715*523fa7a6SAndroid Build Coastguard Worker memory_format=torch.contiguous_format 1716*523fa7a6SAndroid Build Coastguard Worker ) 1717*523fa7a6SAndroid Build Coastguard Worker 1718*523fa7a6SAndroid Build Coastguard Worker @staticmethod 1719*523fa7a6SAndroid Build Coastguard Worker def to_copy_count(): 1720*523fa7a6SAndroid Build Coastguard Worker return 4 1721*523fa7a6SAndroid Build Coastguard Worker 1722*523fa7a6SAndroid Build Coastguard Worker def _do_checks( 1723*523fa7a6SAndroid Build Coastguard Worker test_str: str, allowed: str, allowed_count: int, not_allowed_list: List[str] 1724*523fa7a6SAndroid Build Coastguard Worker ) -> None: 1725*523fa7a6SAndroid Build Coastguard Worker for not_allowed in not_allowed_list: 1726*523fa7a6SAndroid Build Coastguard Worker FileCheck().check_count(allowed, allowed_count, exactly=True).check_not( 1727*523fa7a6SAndroid Build Coastguard Worker not_allowed 1728*523fa7a6SAndroid Build Coastguard Worker ).run(test_str) 1729*523fa7a6SAndroid Build Coastguard Worker 1730*523fa7a6SAndroid Build Coastguard Worker m = Module() 1731*523fa7a6SAndroid Build Coastguard Worker n = m.to_copy_count() 1732*523fa7a6SAndroid Build Coastguard Worker input = torch.randn([2, 3, 4, 5]).to(memory_format=torch.contiguous_format) 1733*523fa7a6SAndroid Build Coastguard Worker 1734*523fa7a6SAndroid Build Coastguard 
Worker # 1. vanilla export, no edge ops 1735*523fa7a6SAndroid Build Coastguard Worker ep = export( 1736*523fa7a6SAndroid Build Coastguard Worker m, 1737*523fa7a6SAndroid Build Coastguard Worker (input,), 1738*523fa7a6SAndroid Build Coastguard Worker ).run_decompositions({}) 1739*523fa7a6SAndroid Build Coastguard Worker _do_checks( 1740*523fa7a6SAndroid Build Coastguard Worker ep.graph_module.code, 1741*523fa7a6SAndroid Build Coastguard Worker aten_op_str, 1742*523fa7a6SAndroid Build Coastguard Worker n, 1743*523fa7a6SAndroid Build Coastguard Worker [edge_aten_op_str, edge_dim_order_op_str], 1744*523fa7a6SAndroid Build Coastguard Worker ) 1745*523fa7a6SAndroid Build Coastguard Worker 1746*523fa7a6SAndroid Build Coastguard Worker # 2a. to edge without dim orders, we should see edge aten ops but not dim order ops 1747*523fa7a6SAndroid Build Coastguard Worker edge_prog = to_edge( 1748*523fa7a6SAndroid Build Coastguard Worker ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=True) 1749*523fa7a6SAndroid Build Coastguard Worker )._edge_programs["forward"] 1750*523fa7a6SAndroid Build Coastguard Worker _do_checks( 1751*523fa7a6SAndroid Build Coastguard Worker edge_prog.graph_module.code, 1752*523fa7a6SAndroid Build Coastguard Worker edge_aten_op_str, 1753*523fa7a6SAndroid Build Coastguard Worker n, 1754*523fa7a6SAndroid Build Coastguard Worker [aten_op_str, edge_dim_order_op_str], 1755*523fa7a6SAndroid Build Coastguard Worker ) 1756*523fa7a6SAndroid Build Coastguard Worker 1757*523fa7a6SAndroid Build Coastguard Worker # 3a. 
expect no change after the pass, we should see edge aten ops but not dim order ops 1758*523fa7a6SAndroid Build Coastguard Worker new_res = DimOrderOpsRevertPass()(edge_prog.graph_module) 1759*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_res) 1760*523fa7a6SAndroid Build Coastguard Worker _do_checks( 1761*523fa7a6SAndroid Build Coastguard Worker new_res.graph_module.code, 1762*523fa7a6SAndroid Build Coastguard Worker edge_aten_op_str, 1763*523fa7a6SAndroid Build Coastguard Worker n, 1764*523fa7a6SAndroid Build Coastguard Worker [aten_op_str, edge_dim_order_op_str], 1765*523fa7a6SAndroid Build Coastguard Worker ) 1766*523fa7a6SAndroid Build Coastguard Worker 1767*523fa7a6SAndroid Build Coastguard Worker # 2b. let's try with dim order enabled, we should see edge dim order ops but not edge aten ops 1768*523fa7a6SAndroid Build Coastguard Worker edge_prog_dim_order = to_edge( 1769*523fa7a6SAndroid Build Coastguard Worker ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=False) 1770*523fa7a6SAndroid Build Coastguard Worker )._edge_programs["forward"] 1771*523fa7a6SAndroid Build Coastguard Worker _do_checks( 1772*523fa7a6SAndroid Build Coastguard Worker edge_prog_dim_order.graph_module.code, 1773*523fa7a6SAndroid Build Coastguard Worker edge_dim_order_op_str, 1774*523fa7a6SAndroid Build Coastguard Worker n, 1775*523fa7a6SAndroid Build Coastguard Worker [aten_op_str, edge_aten_op_str], 1776*523fa7a6SAndroid Build Coastguard Worker ) 1777*523fa7a6SAndroid Build Coastguard Worker 1778*523fa7a6SAndroid Build Coastguard Worker # 3b. 
expect edge aten ops after the pass, we should see not see the edge dim order ops 1779*523fa7a6SAndroid Build Coastguard Worker new_res_dim_order = DimOrderOpsRevertPass()(edge_prog_dim_order.graph_module) 1780*523fa7a6SAndroid Build Coastguard Worker self.assertIsNotNone(new_res_dim_order) 1781*523fa7a6SAndroid Build Coastguard Worker _do_checks( 1782*523fa7a6SAndroid Build Coastguard Worker new_res_dim_order.graph_module.code, 1783*523fa7a6SAndroid Build Coastguard Worker edge_aten_op_str, 1784*523fa7a6SAndroid Build Coastguard Worker n, 1785*523fa7a6SAndroid Build Coastguard Worker [aten_op_str, edge_dim_order_op_str], 1786*523fa7a6SAndroid Build Coastguard Worker ) 1787*523fa7a6SAndroid Build Coastguard Worker 1788*523fa7a6SAndroid Build Coastguard Worker output_no_dim_order = new_res.graph_module(input) 1789*523fa7a6SAndroid Build Coastguard Worker output_no_dim_order_revert = new_res_dim_order.graph_module(input) 1790*523fa7a6SAndroid Build Coastguard Worker self.assertTrue( 1791*523fa7a6SAndroid Build Coastguard Worker torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0]) 1792*523fa7a6SAndroid Build Coastguard Worker ) 1793