# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import final, List, NamedTuple

import torch

from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram


# A simple way to represent an op used in BackendWithCompilerDemo
class DemoOp(NamedTuple):
    op: str
    numel: int
    dtype: str

    def __repr__(self):
        return f"op:demo::{self.op}, numel:{self.numel}, dtype:{self.dtype}"


# The backend implementation is final (cannot be subclassed).
@final
class BackendWithCompilerDemo(BackendDetails):
    """
    An example implementation of lowering a module to a backend. This demo
    backend supports only a small set of operators: sin, mm, add, linear,
    scaled_dot_product_attention and upsample_nearest2d.

    An example module to be lowered:

    class SinModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.sin(x)

    sin_module = SinModule()
    model_inputs = (torch.ones(1, 1),)

    edgeir_m = to_edge(export(sin_module, model_inputs))
    compile_specs = []
    lowered_sin_module = to_backend(
        "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
    )

    # Module composition of lowered modules is possible.
    class HugeModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lowered_linear_sin = lowered_sin_module

        def forward(self, x):
            output_from_submodule = self.lowered_linear_sin(x)
            return output_from_submodule

    Tracing through the composed module produces a graph like:

    graph():
        %arg0_1 : [#users=2] = placeholder[target=arg0_1]
        %lowered_module_0 : [#users=1] = get_attr[target=lowered_module_0]
        %executorch_call_delegate : [#users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, forward, %arg0_1), kwargs = {})
        return [executorch_call_delegate]

    Args:
        edge_program: The edge dialect program (ExportedProgram) to be lowered.
        compile_specs: List of backend-specific objects needed for the compilation process.

    Returns:
        PreprocessResult: The result holding the compiled blob - a binary that can run the desired program in the backend.

    Raises:
        RuntimeError: The module cannot be processed by the backend.
    """

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        processed_bytes = ""
        number_of_instruction = 0
        version = "0"
        debug_handle_map = {}
        match_ops = [
            exir_ops.edge.aten.sin.default,
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.add.Tensor,
            torch.ops.aten.sin.default,
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.scaled_dot_product_attention.default,
            exir_ops.edge.aten.upsample_nearest2d.vec,
        ]

        for node in edge_program.graph.nodes:
            if node.op == "call_function":
                # TODO(gasoonjia): remove the support of torch.ops.aten.sin.default after migrating serde to the edge dialect.
                if node.target in match_ops:
                    simple_op = DemoOp(
                        node.target.__name__,
                        int(torch.prod(torch.tensor(node.meta["val"].shape), 0).item()),
                        node.meta["val"].dtype,
                    )
                    number_of_instruction += 1
                    processed_bytes += (
                        str(simple_op)
                        + "<debug_handle>"
                        + str(node.meta.get("debug_handle", -1))
                        + "#"
                    )
                else:
                    raise RuntimeError(
                        f"{node.op} {node.target.__name__} is not supported in backend BackendWithCompilerDemo"
                    )
            elif node.op in ("placeholder", "output", "get_attr"):
                continue
            else:
                raise RuntimeError(
                    f"{node.op} is not supported in backend BackendWithCompilerDemo"
                )
            # Since the graph remains the same, the debug handle remains the same.
            original_debug_id = node.meta["debug_handle"]
            new_debug_id = original_debug_id
            debug_handle_map[new_debug_id] = (original_debug_id,)
        return PreprocessResult(
            processed_bytes=bytes(
                str(number_of_instruction)
                + "version:"
                + version
                + "#"
                + processed_bytes,
                encoding="utf8",
            ),
            debug_handle_map=debug_handle_map,
        )
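

# A minimal end-to-end sketch (not part of the backend implementation): it
# mirrors the usage shown in the class docstring, assuming the `to_edge` API
# from executorch.exir and the direct `to_backend(backend_id, exported_program,
# compile_specs)` overload from executorch.exir.backend.backend_api, and prints
# the blob that preprocess() produced for a single-sin graph.
if __name__ == "__main__":
    from executorch.exir import to_edge
    from executorch.exir.backend.backend_api import to_backend
    from torch.export import export

    class SinModule(torch.nn.Module):
        def forward(self, x):
            return torch.sin(x)

    # Export to an ExportedProgram, convert to the edge dialect, then lower
    # the whole program to this demo backend with no compile specs.
    edgeir_m = to_edge(export(SinModule(), (torch.ones(1, 1),)))
    lowered = to_backend("BackendWithCompilerDemo", edgeir_m.exported_program(), [])

    # The blob is the versioned op string built by preprocess(), e.g.
    # b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32...".
    print(lowered.processed_bytes)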