# Backend and Delegate

Audience: vendors and backend delegate developers who are interested in integrating their own compilers and hardware with ExecuTorch.

Backend delegation is an entry point for backends to process and execute PyTorch
programs, leveraging the performance and efficiency benefits of specialized
backends and hardware, while still providing PyTorch users with an experience
close to that of the PyTorch runtime.

## Backend Interfaces: Overview

At a high level, the entry point for backends is defined by two components:

- An IR to represent the program: **Edge Dialect** (which is produced through
  the `to_edge` API)
- A set of interfaces for backends to implement:
  - Ahead-of-Time (AOT)
    - Program preprocessing (e.g. ahead-of-time compilation, transformation, optimization...).
  - Runtime
    - Program initialization (e.g. runtime compilation).
    - Program execution.
    - (optional) Program destroy (e.g. releasing backend-owned resources).

A delegate backend implementation is composed of:

1) An ahead-of-time preprocessing interface
2) A runtime initialization and execution interface

The diagram looks like the following:

<img src="./_static/img/backend_interface.png" alt="drawing" style="width:600px;"/>

**Figure 1.** A high-level view of the entry points for backend interfaces, including both ahead-of-time and runtime.

## Backend Interfaces: Ahead-of-Time Preprocessing

There are mainly two ahead-of-time entry points for a backend to implement: `partition` and `preprocess`.

`partition` is an algorithm implemented by the backend to tag the nodes that should be lowered to that backend. The `to_backend` API applies the partition algorithm and lowers each subgraph, which consists of connected tagged nodes, to the targeted backend. Every subgraph
is sent to the `preprocess` function provided by the backend to be compiled into a binary blob.

During partitioning, the backend is not allowed to mutate the `exported_program`; it is only supposed to apply a tag to each node. The returned
`PartitionResult` includes both the tagged exported program and a dictionary of partition tags, which `to_backend` uses to look up each tag and
link it to the `backend_id` and `compile_spec`.

```python
def partition(
    exported_program: ExportedProgram,
) -> PartitionResult:
```

During preprocessing, backends are given an edge dialect program and
a list of compile specs specifying the values needed for compilation, and are
expected to return a compiled blob, or binary, containing the desired program to be
run on the backend. During serialization, the
compiled blob is serialized as part of the `.pte` file and loaded directly onto the device. The
API for this process is:

```python
def preprocess(
    edge_program: ExportedProgram,
    compile_specs: List[CompileSpec],
) -> PreprocessResult:
```

A demo of the preprocess function is implemented
[here](https://github.com/pytorch/executorch/blob/main/exir/backend/test/backend_with_compiler_demo.py).
The demo loops through the nodes in the graph module of the `edge_program` and
serializes the `add`, `mul`, and `sin` instructions into a string, which is later
parsed and executed at runtime.

The diagram looks like the following:

<img src="./_static/img/backend_interface_aot.png" alt="drawing" style="width:800px;"/>

**Figure 2.** The graph goes through partitioning, and each subgraph is sent to the preprocess step.
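To make the shape of this contract concrete, here is a minimal sketch of a `preprocess` implementation in the spirit of the demo linked above. `MyBackend` and the toy line-based serialization format are illustrative assumptions; `BackendDetails`, `PreprocessResult`, and `CompileSpec` follow the demo backend's imports.

```python
from typing import List

from executorch.exir.backend.backend_details import (
    BackendDetails,
    PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import ExportedProgram


class MyBackend(BackendDetails):  # hypothetical backend, for illustration
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        # Walk the edge dialect graph and serialize every call_function
        # node into a toy, line-based text format. A real backend would
        # run its compiler here and emit its own binary format.
        instructions = []
        for node in edge_program.graph.nodes:
            if node.op == "call_function":
                instructions.append(str(node.target))
        blob = "\n".join(instructions).encode("utf-8")
        # The returned bytes are embedded in the .pte file and handed
        # back, untouched, to the backend's init() at runtime.
        return PreprocessResult(processed_bytes=blob)
```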
## Backend Interfaces: Runtime Initialization and Execution

At runtime, the compiled blob from the `preprocess` function is
loaded and passed directly to the backend's custom `init` function. This
function is responsible for further processing the compiled unit, as well as
performing any backend initialization. The backend's custom `execute` function is
then called to execute the handle produced by `init`. Finally, if a backend
requires teardown, it can implement a `destroy`
function, which is called when the program reaches the end of its lifespan.

```cpp
// Runtime check
ET_NODISCARD bool is_available();

// Runtime initialization
ET_NODISCARD virtual Result<DelegateHandle*> init(
    BackendInitContext& context,
    FreeableBuffer* processed,
    ArrayRef<CompileSpec> compile_specs);

// Runtime execution
ET_NODISCARD virtual Error execute(
    BackendExecutionContext& context,
    DelegateHandle* handle,
    EValue** args);

// [optional] Runtime destroy. Destroy the resource held by the backend
virtual void destroy(ET_UNUSED DelegateHandle* handle);
```

The diagram looks like the following:

<img src="./_static/img/backend_interface_runtime.png" alt="drawing" style="width:600px;"/>

**Figure 3.** The relationship between the standard ExecuTorch runtime and the backend entry points.

In order to make a backend available to the ExecuTorch runtime, it must be registered via the `register_backend` API:
```cpp
ET_NODISCARD Error register_backend(const Backend& backend);
```

Static registration, i.e., registration at library init or load time, of a backend can be achieved as follows:
```cpp
namespace {
auto cls = BackendWithCompiler();
Backend backend{"BackendWithCompilerDemo", &cls};
static auto success_with_compiler = register_backend(backend);
} // namespace
```
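Putting these pieces together, a skeleton backend that could be registered as above might look like the sketch below. The header path, the `BackendInterface` base class, and `MyBackend` are assumptions modeled on the interface shown earlier (names and signature details, such as `const` qualifiers, vary across ExecuTorch versions); a real `init` would parse the blob produced by `preprocess`, and a real `execute` would run it.

```cpp
// Sketch only: the header path and base-class name are assumptions.
#include <executorch/runtime/backend/interface.h>

using namespace executorch::runtime;

class MyBackend final : public BackendInterface {
 public:
  bool is_available() const override {
    // Probe for the accelerator/driver here.
    return true;
  }

  Result<DelegateHandle*> init(
      BackendInitContext& context,
      FreeableBuffer* processed,
      ArrayRef<CompileSpec> compile_specs) const override {
    // Parse the blob emitted by preprocess() and build whatever state
    // execute() needs, returning it as an opaque handle.
    return (DelegateHandle*)processed;
  }

  Error execute(
      BackendExecutionContext& context,
      DelegateHandle* handle,
      EValue** args) const override {
    // Run the compiled program: read inputs from and write outputs to
    // the EValue array.
    return Error::Ok;
  }
};
```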
## Developer Tools Integration: Debuggability

Providing a consistent debugging experience, be it for runtime failures or performance profiling, is important. ExecuTorch employs native Developer Tools for this purpose, which enable correlating program instructions to the original PyTorch code via debug handles. You can read more about it [here](./etrecord).

Delegated programs or subgraphs are opaque to the ExecuTorch runtime and appear as a special `call_delegate` instruction, which asks the corresponding backend to handle the execution of the subgraph or program. Due to the opaque nature of backend delegates, the native Developer Tools have no visibility into a delegated program. As a result, the debugging experience of delegated execution, whether functional or performance-related, suffers significantly compared to its non-delegated counterpart.

In order to provide a consistent debugging experience to users, regardless of whether a model uses delegation, the Developer Tools provide an interface to correlate a delegated (sub)graph with the original (sub)graph. They do so via a debug handle map, which allows delegates to generate internal handles that can be associated with the original (sub)graph consumed by the delegate. At runtime, backend developers can then report errors or profiling information using the internal handles, and these will be mapped back to the original (sub)graph using the debug handle map. For more information, please refer to [Delegate Debugging](./delegate-debugging).

By leveraging the debug identifiers, backend developers can embed them as part of the delegated blob:

<img src="./_static/img/backend_debug_handle.png" alt="drawing" style="width:600px;"/>

In this way, during the execute stage, backend developers can use the debug identifier to associate a failed instruction inside the delegate
back to the exact line of Python code:

<img src="./_static/img/backend_debug_handle_example.png" alt="drawing" style="width:700px;"/>

## Common Questions

**1. How can we get data in backend.preprocess?**

The graph module being preprocessed is a lifted graph; this means that static
data like weights and biases are supplied as inputs to the graph. However, we
can access the weights and biases ahead of time through the exported program. To
access these parameters for a given node, we can use the `get_param` helper
provided in `torch/_export/utils.py`.
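For instance, a minimal sketch of fetching a convolution's weight inside `preprocess` might look like the following. The `get_conv_weight` helper and the conv-node argument layout are illustrative assumptions; `is_param` and `get_param` are the helpers from `torch/_export/utils.py`:

```python
from torch._export.utils import get_param, is_param
from torch.export import ExportedProgram
import torch.fx


def get_conv_weight(edge_program: ExportedProgram, node: torch.fx.Node):
    # Hypothetical helper: a convolution node's args are laid out as
    # (input, weight, bias, ...). In a lifted graph the weight arrives
    # as a placeholder node backed by a parameter of the program.
    weight_node = node.args[1]
    if is_param(edge_program, weight_node):
        return get_param(edge_program, weight_node)
    return None
```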
**2. How can we embed the data (like weight/bias) to the backend?**

It's common for backends to have some way of optimizing constant data. In this case,
we'd need to tag the placeholder nodes, which also hold the state, in the
partitioner, and during `backend.preprocess` we can follow the description in the
first question to get the weights.

**3. How can we run the lowered module in Python with the specific backend?**

We haven't added that support yet, but it's the plan!

**4. Should we expect to see `get_attr` nodes in the edge dialect program?**

`get_attr` nodes will only show up for submodules used for control flow or
delegation. They won't hold any data.

**5. Can we delegate to multiple backends?**

Yes! There are two ways to do this:

*Option 1: Run to_backend multiple times for different backends*

If we have two backends, backend_1 and backend_2, and they have their own
partitioners, backend_1_partitioner and backend_2_partitioner, we can run it
like:

```python
# First lower nodes to backend_1, as selected by backend_1_partitioner's
# partitioning algorithm
exported_program_backend_1 = to_backend(exported_program, backend_1_partitioner())
# The remaining nodes are then lowered to backend_2, as selected by
# backend_2_partitioner
exported_program_backend_1_and_2 = to_backend(exported_program_backend_1, backend_2_partitioner())
```

A more concrete example can be found
[here](https://github.com/pytorch/executorch/blob/main/exir/backend/test/demos/test_xnnpack_qnnpack.py).
In this example,
qnnpack is one backend and xnnpack is the other. We haven't open-sourced
these two backend delegates yet, so this example won't run out of the box. It can
be used as a reference to see how it can be done.

This option is easy to try because usually all backends implement their own
partitioner. However, this option may produce different results if we change the
order of the `to_backend` calls. If we want better control over the nodes, such as
which backend they should go to, option 2 is better.

*Option 2: Have a partitioner which partitions for different backends*

Another option is to create a customized partitioner, say
`backend_1_2_partitioner`, and inside the partitioner logic:

```python
class Backend_1_2_Partitioner(Partitioner):
    """
    Partitions nodes for Backend1 and Backend2 inside a single
    partitioner.
    """

    def __init__(self) -> None:
        self.delegation_spec_1 = DelegationSpec("Backend1", [])
        self.delegation_spec_2 = DelegationSpec("Backend2", [])
        self.partition_tags = {}

    def partition(
        self, exported_program: ExportedProgram
    ) -> PartitionResult:

        # Tag the nodes in the first partition for backend 1
        nodes_to_backend_1 = ...  # some logic to select the nodes from the graph
        for node in nodes_to_backend_1:
            delegation_tag = "backend_1_tag"
            node.meta["delegation_tag"] = delegation_tag
            self.partition_tags[delegation_tag] = self.delegation_spec_1

        # Tag the nodes in the second partition for backend 2
        nodes_to_backend_2 = ...  # some logic to select the nodes from the graph
        for node in nodes_to_backend_2:
            delegation_tag = "backend_2_tag"
            node.meta["delegation_tag"] = delegation_tag
            self.partition_tags[delegation_tag] = self.delegation_spec_2

        return PartitionResult(
            tagged_exported_program=exported_program,
            partition_tags=self.partition_tags,
        )
```
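With a combined partitioner like the hypothetical sketch above, a single `to_backend` call lowers to both backends at once, so the split between them no longer depends on the order of separate `to_backend` calls:

```python
# Hypothetical usage of the combined partitioner sketched above
exported_program_backend_1_and_2 = to_backend(
    exported_program, Backend_1_2_Partitioner()
)
```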
**6. Is there an easy way to write a partitioner?**

We provide some helper partitioners
[here](./compiler-custom-compiler-passes.md) to make it easy to find
nodes from decomposed operators.

**7. How do we link the node back to the source code?**

We provide a helper function:
```python
from executorch.exir.print_program import inspect_node

print(inspect_node(graph, node))
```
It will highlight the node in the graph as well as point to the source code. Example output looks like the following:
```
_param_constant1 error_msg:  Here is the node in the graph module:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
--> %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %aten_convolution_default : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%arg0_1, %_param_constant0, %_param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %_param_constant2 : [num_users=1] = get_attr[target=_param_constant2]
    %_param_constant3 : [num_users=1] = get_attr[target=_param_constant3]
    %aten_convolution_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_convolution_default, %_param_constant2, %_param_constant3, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_convolution_default, %aten_convolution_default_1), kwargs = {})
    %_param_constant4 : [num_users=1] = get_attr[target=_param_constant4]
    %_param_constant5 : [num_users=1] = get_attr[target=_param_constant5]
    %aten_convolution_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten_add_tensor, %_param_constant4, %_param_constant5, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %aten_gelu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.gelu.default](args = (%aten_convolution_default_2,), kwargs = {})
    return [aten_gelu_default]
This node _param_constant1 has metadata of:
The node stacktrace:
Traceback (most recent call last):
    File "/tmp/ipykernel_1204253/3382880687.py", line 7, in forward
      return self.test_model(x)
    File "/mnt/xarfuse/uid-25337/7b86ad0c-seed-nspid4026532987_cgpid2707357-ns-4026532984/torch/nn/modules/module.py", line 1528, in _call_impl
      return forward_call(*args, **kwargs)
    File "/tmp/ipykernel_1204253/712280972.py", line 10, in forward
      a = self.conv1(x)
```