# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import os
from typing import Any, cast

import numpy as np
import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
    get_quant_arg_downstream,
    get_quant_arg_upstream,
    q_op,
)
from executorch.exir.dialects._ops import ops as exir_ops
from serializer.tosa_serializer import TosaOp
from torch.fx import Node

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.INFO)


def dbg_node(node):
    # Debug output of node information
    logger.info("OP")
    logger.info(f"  op is {node.op}")
    logger.info(f"  name is {node.name}")
    logger.info(f"  node target is {node.target}")
    logger.info(f"  node args is {node.args}")
    logger.info(f"  node kwargs is {node.kwargs}")
    logger.info("  node.meta = ")
    for k, v in node.meta.items():
        logger.info(f"    '{k}' = {v}")
        if isinstance(v, list):
            for i in v:
                logger.info(f"      {i} ")


# Output TOSA flatbuffer and test harness file
def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
    filename = f"output{suffix}.tosa"

    logger.info(f"Emitting debug output to: {path=}, {suffix=}")

    os.makedirs(path, exist_ok=True)

    fb = tosa_graph.serialize()
    js = tosa_graph.writeJson(filename)

    filepath_tosa_fb = os.path.join(path, filename)
    with open(filepath_tosa_fb, "wb") as f:
        f.write(fb)
    assert os.path.exists(filepath_tosa_fb), "Failed to write TOSA flatbuffer"

    filepath_desc_json = os.path.join(path, f"desc{suffix}.json")
    with open(filepath_desc_json, "w") as f:
        f.write(js)
    assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON"


def dbg_fail(node, tosa_graph, path):
    dbg_tosa_dump(tosa_graph, path)
    logger.warning("Internal error due to poorly handled node:")
    dbg_node(node)
    logger.warning(f"Debug output captured in '{path}'.")
    raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")
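

# Illustrative usage sketch (comment only, not executed): a lowering step that
# cannot handle a node is expected to call dbg_fail with the partially built
# graph, which dumps the TOSA artifacts via dbg_tosa_dump and then raises.
# The names `visitor`, `inputs`, `output` and `artifact_path` below are
# hypothetical placeholders, not part of this module:
#
#   try:
#       visitor.define_node(node, tosa_graph, inputs, output)
#   except Exception:
#       dbg_fail(node, tosa_graph, artifact_path)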


# Helper function to match TOSA's broadcasting rank requirement
# Ref: TOSA 0.80.0 specification - 1.9.3. Data Layouts from
# https://www.mlplatform.org/tosa/tosa_spec.html
def promote_shape(tosa_fb, arg, promoted_shape, out_dtype):
    assert np.prod(arg.shape) == np.prod(promoted_shape), "Incompatible promoted shape"
    reshape_res = tosa_fb.addIntermediate(promoted_shape, out_dtype)
    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(promoted_shape)
    tosa_fb.addOperator(TosaOp.Op().RESHAPE, [arg.name], [reshape_res.name], attr)
    return reshape_res


# Helper transpose function to match TOSA's shape requirements
# E.g., TOSA 0.80.0 specification - 2.3.3 CONV2D shapes:
# https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d
def transpose_helper(tosa_fb, input, new_order, out_dtype):
    # Check that new_order's length equals the input rank
    assert len(input.shape) == len(new_order), "Wrong shape order length"

    # Check for duplicated dim numbers
    assert len(set(new_order)) == len(new_order), "Contains duplicated dim numbers"

    # Check that all dims are valid
    for idx in new_order:
        assert idx >= 0, "Negative dim number"
        assert idx < len(input.shape), "Dim is greater than input rank"

    input_shape_transposed = [input.shape[i] for i in new_order]
    attr = ts.TosaSerializerAttribute()
    attr.TransposeAttribute(new_order)
    input_transposed = tosa_fb.addIntermediate(input_shape_transposed, out_dtype)
    tosa_fb.addOperator(
        TosaOp.Op().TRANSPOSE, [input.name], [input_transposed.name], attr
    )
    return input_transposed


def getNodeArgs(node: Node) -> list[TosaArg]:
    return [TosaArg(arg) for arg in node.args]


def get_input_tensor(node: Node) -> TosaArg:
    return TosaArg(node.args[0])


def get_output_node(node: Node) -> Node:
    return list(node.users)[0]


""" TOSA reshape returns a tensor with the same type/values as the input.
    No data conversion happens during a reshape operation. """


def build_reshape(tosa_fb, input_name, new_shape, output_name):
    attr = ts.TosaSerializerAttribute()
    attr.ReshapeAttribute(new_shape)
    tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr)


def is_bias_node_for_quantized_conv(node):
    consumer_node = list(node.users)[0]
    return (
        consumer_node.target == exir_ops.edge.aten.convolution.default
        and list(consumer_node.users)[0].target == q_op
    )


def is_consumer_node_depthwise_conv2d(node):
    consumer_node = list(node.users)[0]
    if consumer_node.target == exir_ops.edge.aten.convolution.default:
        inputs = getNodeArgs(consumer_node)
        group = inputs[-1]
        in_channels = inputs[0].shape[1]
        out_channels = inputs[1].shape[0]
        if (in_channels == group.number) and (out_channels % in_channels) == 0:
            return True

    return False
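

# Worked example for is_consumer_node_depthwise_conv2d (illustrative only):
# for a convolution consumer with input shape (N, 8, H, W), weight shape
# (16, 1, kH, kW) and group == 8, in_channels equals group (8) and
# out_channels (16) is divisible by in_channels, so the convolution is
# treated as depthwise.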


def build_avg_pool_2d_common(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    input_tensor: TosaArg,
    kernel_size: list,
    stride: list,
    padding: list,
    is_quant_node: bool,
    output: TosaArg,
):
    accumulator_type = input_tensor.dtype

    if is_quant_node:
        # The accumulator type is always int32 when the input tensor is an integer type.
        accumulator_type = ts.DType.INT32

    # Initialize zero points to zero.
    input_zp = 0
    output_zp = 0

    if is_quant_node:
        input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp
        output_zp = get_quant_arg_downstream(list(node.users)[0]).zp

    attr = ts.TosaSerializerAttribute()
    attr.PoolAttribute(
        kernel=kernel_size,
        stride=stride,
        pad=padding,
        input_zp=input_zp,
        output_zp=output_zp,
        accum_dtype=accumulator_type,
    )

    tosa_graph.addOperator(
        TosaOp.Op().AVG_POOL2D,
        [input_tensor.name],
        [output.name],
        attr,
    )


def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]:
    """Returns the two input nodes of 'node' in order. If 'node' only has one input,
    it is returned twice.

    Fails if there are no input nodes.
    Fails if there are >2 input nodes and 'check' is True.
    """

    num_inputs = len(node.all_input_nodes)
    assert num_inputs > 0, f"Node '{node.name}' requires >0 inputs, got {num_inputs}."

    input1 = node.all_input_nodes[0]
    if num_inputs == 1:
        input2 = node.all_input_nodes[0]
    else:
        input2 = node.all_input_nodes[1]
    if check:
        assert (
            num_inputs <= 2
        ), f"Node '{node.name}' requires <=2 inputs, got {num_inputs}."

    return input1, input2


def tosa_shape(shape, dim_order):
    return tuple([shape[dim] for dim in dim_order])


def expand_dims(
    tosa_graph: ts.TosaSerializer,
    input_node: TosaArg,
    dtype: int,
    dim: int,
) -> Any:
    """Inserts TOSA operators into tosa_graph that perform the equivalent of
    the expand_dims (a.k.a. unsqueeze) operation. A new axis is created at the
    dim location.

    Args:
        tosa_graph (ts.TosaSerializer): The TOSA graph to manipulate.
        input_node (TosaArg): The parent node of the expand dims operation.
        dtype (ts.DType): The data type of the expand dims operation.
        dim (int): The dimension to expand.

    Returns:
        Any: The output tensor of the inserted operation in the TOSA graph.
    """
    new_shape = list(input_node.shape)
    new_shape.insert(dim, 1)

    intermediate = tosa_graph.addIntermediate(new_shape, dtype)

    build_reshape(tosa_graph, input_node.name, new_shape, intermediate.name)

    return intermediate
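

# Worked example for expand_dims (illustrative only): for an input of shape
# (2, 3) and dim=1, new_shape becomes [2, 1, 3] and a single RESHAPE operator
# is emitted, mirroring torch.unsqueeze(x, 1).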


def get_resize_parameters(
    input_size: torch.Tensor,
    output_size: torch.Tensor,
    resize_mode: int,
    align_corners: bool,
):
    """Get the tosa.resize parameters based on the input and output size.

    Args:
        input_size (torch.Tensor): Size of the input
        output_size (torch.Tensor): Size of the output
        resize_mode (tosa.ResizeMode): The TOSA resize mode
        align_corners (bool): Align the corner pixels of the input and output

    Returns:
        scale_n (torch.Tensor), scale_d (torch.Tensor),
        offset (torch.Tensor), border (torch.Tensor)
    """
    assert torch.all(input_size > 0)
    assert torch.all(output_size > 0)

    scale_n = torch.tensor(
        [
            so - 1 if align_corners and si > 1 and so > 1 else so
            for si, so in zip(input_size, output_size)
        ]
    )
    scale_d = torch.tensor(
        [
            si - 1 if align_corners and si > 1 and so > 1 else si
            for si, so in zip(input_size, output_size)
        ]
    )

    # Reduce the scale fraction to lowest terms.
    gcd = torch.gcd(scale_n, scale_d)
    scale_n = scale_n // gcd
    scale_d = scale_d // gcd

    # No half-pixel centre support in PyTorch, so no offset is needed.
    offset = torch.zeros_like(input_size)
    border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset

    return scale_n, scale_d, offset, border
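

# Worked example for get_resize_parameters (illustrative only), resizing a
# dimension from 4 to 8:
#   align_corners=False: scale_n/scale_d = 8/4, reduced to 2/1,
#                        offset = 0, border = 1*(8-1) - 2*(4-1) = 1.
#   align_corners=True:  scale_n/scale_d = (8-1)/(4-1) = 7/3,
#                        offset = 0, border = 3*(8-1) - 7*(4-1) = 0.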