# Backend Dialect
## Overview

_Backend dialect_ is a special variant of [edge dialect](./ir-exir.md): it contains backend-specific nodes and metadata introduced by backend-specific graph transformations. Backend dialect is an optional stage, only needed if we want to introduce backend-awareness into the graph. More specifically, a graph in backend dialect may contain operators or delegated lowered modules (see the [delegate doc](./compiler-delegate-and-partitioner.md)) that are only meaningful to the target backend. One use case is fusing operators into a single operator; for example, consecutive addmm + relu can be fused into a single addmm_relu operator at this stage.

This document describes how to introduce backend-specific operators.

The difference between custom ops and backend-specific ops: custom ops appear in eager mode, ATen dialect, and edge dialect, while backend-specific ops are only introduced by passes that run after edge dialect.


## When to use

This dialect allows the introduction of operators that do not conform to the schema defined in the canonical ATen operator set and that do not appear in any of the dialects above (ATen dialect and edge dialect). Consider using backend operators if your use case satisfies one or more of the following criteria:

* Your backend provides a library that optimizes a certain operator that is equivalent to a subgraph. For example, `linear_relu` (equivalent to linear followed by relu) can be executed faster on a certain backend.
* There's a need to retrace the graph module after it has already been lowered to a backend. When retraced, backend operators can be transformed back to the original subgraph (in ATen dialect), whereas a normal custom op cannot.
* Your backend-specific operator doesn't have a generic CPU kernel, only a kernel for a certain backend. Using a backend operator works around this issue by using the original subgraph as the default kernel, keeping the graph module runnable.
* Alternatively, you can use a delegate if you are concerned that this might be overkill and you just want something more lightweight that only requires Python code at the compiler stage.


## APIs

For an operator/subgraph replacement, the common flow is:

1. Register an operator that has the same input and output as the subgraph. This operator won't have a target-specific implementation (and doesn't need one at the compilation stage), but it needs to give the same result as the subgraph.
2. Create a pattern that allows the compiler to find the subgraph and substitute it with the replacement.
3. Write a pass to replace the subgraph with the new operator.

To facilitate this process, we provide an API to help reduce the effort for ExecuTorch users to do these steps.


### Pass Infra Entry Point

To lower edge ops to backend ops, a pass will perform pattern matching to identify the edge ops of interest in the graph and then replace them with equivalent backend operators. There are two APIs to register such passes:

* `transform()`. An API on `ExportProgram` that allows users to provide custom passes. Note that this is not guarded by any validator, so the soundness of the program is not guaranteed.
* [ExecutorchBackendConfig.passes](https://github.com/pytorch/executorch/blob/main/exir/capture/_config.py#L40). If added here, the pass will run as part of the lowering process from backend dialect to `ExecutorchProgram`.

Example: one such pass is QuantFusion. This pass takes a "canonical quantization pattern", i.e., "dequant - some_op - quant", and fuses it into a single backend-specific operator, i.e., `quantized_decomposed::some_op`. Another, simpler example is [here](https://github.com/pytorch/executorch/blob/main/exir/passes/replace_edge_with_backend_pass.py#L20), where we replace `sym_size` operators with ones that are understood by ExecuTorch. Both registration points are illustrated in the sketch below.
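As a minimal sketch of the two entry points, assuming the usual `torch.export.export` → `to_edge` → `to_executorch` flow and reusing the `AddReluFusionPass` defined in the Example section below (exact import paths and export APIs may differ slightly across versions):

```python
# Minimal sketch (not the only way to wire this up): register a backend-op fusion
# pass through both entry points. `f` and `AddReluFusionPass` are the ones defined
# in the Example section below; wrap `f` in an nn.Module if your torch.export
# version requires a module.
import torch
from executorch.exir import ExecutorchBackendConfig, to_edge

example_args = (torch.randn(2, 3), torch.randn(2, 3))
edge = to_edge(torch.export.export(f, example_args))

# Option 1: run the pass explicitly via transform(); note there is no validator here.
edge = edge.transform([AddReluFusionPass()])

# Option 2: schedule the pass as part of lowering to an ExecuTorch program.
executorch_program = edge.to_executorch(
    ExecutorchBackendConfig(passes=[AddReluFusionPass()])
)
```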


### Pattern Binding Decorator

We provide a decorator `bind_pattern_to_op` to help users easily register their backend operators into EXIR. This decorator takes:

* a `torch.Library` object, which indicates which library or namespace this backend operator belongs to.
* a name or schema. If we already defined the schema of the backend operator in the `torch.Library` object, only a name is needed. Otherwise, the schema is registered for us if a schema string is passed in (both forms are shown in the sketch below).

This decorator should be added to the pattern we are trying to match (and then lower to this backend op) on edge dialect. This way we are registering this pattern as a `CompositeImplicitAutograd` kernel for this backend operator.
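Here is a short sketch of both argument forms; the `my_backend` namespace, the op names, and the import path for `bind_pattern_to_op` are illustrative assumptions, not fixed APIs:

```python
# Sketch only: namespace, op names, and the bind_pattern_to_op import path are
# illustrative and may differ in your setup.
import torch
from torch.library import Library
from executorch.exir.dialects._ops import bind_pattern_to_op

lib = Library("my_backend", "DEF")

# Form 1: pass a full schema string; the decorator registers the schema for us.
@bind_pattern_to_op(lib, "add_relu(Tensor self, Tensor other) -> Tensor")
def add_relu_pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.ops.aten.relu.default(torch.ops.aten.add.Tensor(x, y))

# Form 2: the schema is already defined on the library, so a name alone is enough.
lib.define("mul_relu(Tensor self, Tensor other) -> Tensor")

@bind_pattern_to_op(lib, "mul_relu")
def mul_relu_pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.ops.aten.relu.default(torch.ops.aten.mul.Tensor(x, y))
```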

Then the operator can be accessed/used from the passes. The `CompositeImplicitAutograd` kernel makes sure:

1. There is no need for the user to write a (CPU) runnable kernel.
2. The `ExportProgram` stays retraceable. Once retraced, the backend operator will be decomposed into the ATen ops used in the pattern (see the sketch below).
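For instance, continuing the hypothetical `my_backend::add_relu` registration from the sketch above, the op is already runnable in eager mode because the pattern itself is the kernel:

```python
# Assumes the hypothetical my_backend::add_relu registration from the sketch above.
import torch

x, y = torch.randn(3), torch.randn(3)

# No backend-specific kernel exists, yet the op runs: the CompositeImplicitAutograd
# kernel (i.e., the pattern) is what gets dispatched.
out = torch.ops.my_backend.add_relu(x, y)
assert torch.allclose(out, torch.relu(x + y))

# On retrace / re-export, the op decomposes back into aten.add.Tensor and
# aten.relu.default, the ops used in the pattern.
```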


## Example

Let’s assume a simple program that contains both add and relu operators:
```python
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = x + y
    return torch.ops.aten.relu.default(z)
```
After lowering to edge dialect it becomes:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
    %aten_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.relu.default](args = (%aten_add_tensor,), kwargs = {})
    return (aten_relu_default,)
```
Now we want to write a pass to merge `add` and `relu` into `add_relu`. The first step is to write a pattern:
```python
# In the pattern, we can use edge ops and ATen ops interchangeably
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = torch.ops.aten.add.Tensor(x, y)
    out = torch.ops.aten.relu.default(z)
    return out
```
Then we need to create an operator library for the fused operator namespace and use the decorator on our pattern:

```python
lib = Library("foo_namespace", "DEF")

@bind_pattern_to_op(lib, "add_relu(Tensor self, Tensor other) -> Tensor")
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = torch.ops.aten.add.Tensor(x, y)
    out = torch.ops.aten.relu.default(z)
    return out
```
This way we register the pattern as a kernel for `add_relu`, and it is ready to be used in a pass. A simple pass looks like this:
```python
class AddReluFusionPass(ExportPass):
    def call(self, graph_module: GraphModule) -> PassResult:
        # decorator registers this pattern as a CompositeImplicitAutograd kernel, since there's no kernel registered before.
        @bind_pattern_to_op(lib, "add_relu")
        def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            z = torch.ops.aten.add.Tensor(x, y)
            out = torch.ops.aten.relu.default(z)
            return out

        def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.ops.foo_namespace.add_relu.default(x, y)

        subgraph_rewriter.replace_pattern(
            graph_module,
            _trace_and_lower_to_edge_ops(pattern),
            _trace_and_lower_to_edge_ops(replacement),
        )
        return PassResult(graph_module, True)
```
The result graph looks like this:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %foo_namespace_add_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.foo_namespace.add_relu.default](args = (%arg0_1, %arg1_1), kwargs = {})
    return (foo_namespace_add_relu_default,)
```
### Op Set

These are the backend operators currently registered through the `bind_pattern_to_op` API.

* `executorch_prims::add.int(SymInt a, SymInt b) -> SymInt`
  * pattern: builtin.add
  * backend: executor
* `executorch_prims::mul.int(SymInt a, SymInt b) -> SymInt`
  * pattern: builtin.mul
  * backend: executor
* `executorch_prims::sub.int(SymInt a, SymInt b) -> SymInt`
  * pattern: builtin.sub
  * backend: executor
* `executorch_prims::floordiv.int(SymInt a, SymInt b) -> SymInt`
  * pattern: builtin.floordiv
  * backend: executor
* `executorch_prims::truediv.int(Scalar a, Scalar b) -> Scalar`
  * pattern: builtin.div
  * backend: executor
* `executorch_prims::sym_float.Scalar(Scalar a) -> Scalar`
  * pattern: builtin.float
  * backend: executor
* `executorch_prims::gt.int(SymInt a, SymInt b) -> bool`
  * pattern: builtin.gt
  * backend: executor
* `executorch_prims::lt.int(SymInt a, SymInt b) -> bool`
  * pattern: builtin.lt
  * backend: executor
* `executorch_prims::ge.int(SymInt a, SymInt b) -> bool`
  * pattern: builtin.ge
  * backend: executor
* `executorch_prims::le.int(SymInt a, SymInt b) -> bool`
  * pattern: builtin.le
  * backend: executor
* `executorch_prims::eq.int(SymInt a, SymInt b) -> bool`
  * pattern: builtin.eq
  * backend: executor
* `executorch_prims::mod.Scalar(SymInt a, SymInt b) -> SymInt`
  * pattern: builtin.divmod
  * backend: executor
* `executorch_prims::neg.Scalar(Scalar a) -> Scalar`
  * pattern: operator.neg
  * backend: executor
* `quantized_decomposed::embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor`
  * pattern: [source](https://github.com/pytorch/executorch/blob/main/exir/passes/_quant_patterns_and_replacements.py)
  * backend: quantization
* `quantized_decomposed::add(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc`
  * pattern: [source](https://github.com/pytorch/executorch/blob/main/exir/passes/_quant_patterns_and_replacements.py)
  * backend: quantization
* `quantized_decomposed::add.scalar(Tensor qa, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, ScalarType a_dtype, Scalar b, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max, ScalarType out_dtype) -> Tensor`
  * pattern: [source](https://github.com/pytorch/executorch/blob/main/exir/passes/_quant_patterns_and_replacements.py)
  * backend: quantization
* `quantized_decomposed::add_relu(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc`
  * pattern: [source](https://github.com/pytorch/executorch/blob/main/exir/passes/_quant_patterns_and_replacements.py)
  * backend: quantization