xref: /aosp_15_r20/external/executorch/exir/backend/test/backend_with_compiler_demo.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import final, List, NamedTuple
8
9import torch
10
11from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
12from executorch.exir.backend.compile_spec_schema import CompileSpec
13from executorch.exir.dialects._ops import ops as exir_ops
14from torch.export.exported_program import ExportedProgram
15
16
17# A simple way to represent an op used in BackendWithCompilerDemo
class DemoOp(NamedTuple):
    """A single instruction emitted by ``BackendWithCompilerDemo``.

    Fields:
        op: Operator name, taken from the graph node target's ``__name__``.
        numel: Number of elements in the node's output value.
        dtype: Element dtype of the node's output value. The backend passes a
            ``torch.dtype``; it is rendered with ``str()`` in ``__repr__``.
    """

    op: str
    numel: int
    dtype: torch.dtype

    def __repr__(self):
        # BackendWithCompilerDemo.preprocess embeds this text verbatim into the
        # processed bytes, so any change here changes the emitted blob format.
        return f"op:demo::{self.op}, numel:{self.numel}, dtype:{self.dtype}"
25
26
27# Backend details are final (cannot be subclassed).
@final
class BackendWithCompilerDemo(BackendDetails):
    """
    An example backend that lowers a module into a simple, human-readable
    instruction stream. Only the operators listed in ``match_ops`` inside
    ``preprocess`` are supported: sin, mm, add, linear,
    scaled_dot_product_attention and upsample_nearest2d.

    The example origin module can be:

    class SinModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.sin(x)

    sin_module = SinModule()
    model_inputs = (torch.ones(1, 1),)

    edgeir_m = to_edge(export(sin_module, model_inputs))
    compile_specs = []
    lowered_sin_module = to_backend(
        "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
    )

    # Module composition of lowered modules is possible.
    class HugeModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lowered_linear_sin = lowered_sin_module

        def forward(self, x):
            output_from_submodule = self.lowered_linear_sin(x)
            return output_from_submodule

    The output trace through graph result is
    graph():
        %arg0_1 : [#users=2] = placeholder[target=arg0_1]
        %lowered_module_0 : [#users=1] = get_attr[target=lowered_module_0]
        %executorch_call_delegate : [#users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, forward, %arg0_1), kwargs = {})
        return [executorch_call_delegate]
    """

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """
        Serialize the supported call_function nodes of ``edge_program``.

        The returned blob is UTF-8 text of the form
        ``"<count>version:<version>#"`` followed by one
        ``"<DemoOp repr><debug_handle><handle>#"`` entry per supported node,
        in graph order. ``<handle>`` is the node's ``debug_handle`` metadata,
        or ``-1`` when the node has none.

        Args:
            edge_program: The edge-dialect program to lower.
            compile_specs: Backend-specific compile options. Unused by this
                demo backend.

        Returns:
            PreprocessResult holding the encoded blob and a debug handle map.
            The graph is not transformed, so each node's debug handle maps to
            a 1-tuple containing itself. Nodes without a ``debug_handle`` are
            left out of the map.

        Raises:
            RuntimeError: If a call_function node targets an unsupported
                operator, or the graph contains a node kind other than
                call_function, placeholder, output or get_attr.
        """
        version = "0"
        match_ops = (
            exir_ops.edge.aten.sin.default,
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.add.Tensor,
            # TODO(gasoonjia): remove the support of torch.ops.aten.sin.default after migrate serde to edge dialect.
            torch.ops.aten.sin.default,
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.scaled_dot_product_attention.default,
            exir_ops.edge.aten.upsample_nearest2d.vec,
        )
        # Graph bookkeeping nodes with no computation to serialize.
        skipped_node_ops = ("placeholder", "output", "get_attr")

        instructions: List[str] = []
        debug_handle_map = {}
        for node in edge_program.graph.nodes:
            if node.op in skipped_node_ops:
                continue
            if node.op != "call_function":
                raise RuntimeError(
                    f"{node.op} is not supported in backend BackendWithCompilerDemo"
                )
            if node.target not in match_ops:
                raise RuntimeError(
                    f"{node.op} {node.target.__name__} is not supported in backend BackendWithCompilerDemo"
                )

            out_val = node.meta["val"]
            simple_op = DemoOp(
                node.target.__name__,
                # Element count of the node's output. Same value as the
                # product of its shape, without a tensor round-trip.
                int(out_val.numel()),
                out_val.dtype,
            )
            instructions.append(
                f"{simple_op!s}<debug_handle>{node.meta.get('debug_handle', -1)!s}#"
            )

            # The graph is passed through unchanged, so each debug handle maps
            # to itself. A node without a handle is serialized with -1 above
            # and is skipped here rather than raising KeyError.
            if "debug_handle" in node.meta:
                original_debug_id = node.meta["debug_handle"]
                debug_handle_map[original_debug_id] = (original_debug_id,)

        header = f"{len(instructions)}version:{version}#"
        return PreprocessResult(
            processed_bytes=bytes(header + "".join(instructions), encoding="utf8"),
            debug_handle_map=debug_handle_map,
        )
142