# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Utility functions for TOSA quantized lowerings

import math
from typing import Callable, cast, NamedTuple, Sequence

import numpy as np

import serializer.tosa_serializer as ts
import torch.fx
import tosa.Op as TosaOp
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.exir.dialects._ops import ops as exir_ops
from serializer.tosa_serializer import TosaSerializerTensor
from torch.fx import Node


q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
dq_q_ops = (q_op, dq_op)

# Ops that the upstream/downstream QuantArgs searches below walk through.
# is_node_quantized() requires such an op to see identical QuantArgs on all of
# its quantized inputs and outputs. Kept as a mutable list so that ops created
# later can be added with register_passable_op().
passable_ops = [
    exir_ops.edge.aten.view_copy.default,
    exir_ops.edge.aten.permute_copy.default,
    exir_ops.edge.aten.squeeze_copy.dims,
    exir_ops.edge.aten.unsqueeze_copy.default,
    exir_ops.edge.aten.split_with_sizes_copy.default,
    exir_ops.edge.aten.repeat.default,
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.aten.slice_copy.Tensor,
    exir_ops.edge.aten.cat.default,
]


def register_passable_op(op):
    """Add a custom op (such as tosa_transpose) to passable_ops after creation.

    Registering an op that is already present leaves the list unchanged, so
    repeated registration does not grow passable_ops.
    """
    if op not in passable_ops:
        passable_ops.append(op)


class QuantArgs(NamedTuple):
    """Per-tensor quantization parameters read from a quantize/dequantize node."""

    scale: float
    zp: int  # zero point
    qmin: int  # lower clamp bound of the quantized range
    qmax: int  # upper clamp bound of the quantized range
    dtype: torch.dtype

    def quantize_value(self, x):
        """Quantize ``x`` as clip(round(x / scale) + zp, qmin, qmax).

        Non-tensor inputs are first wrapped in a one-element ``torch.Tensor``.
        The result is cast to ``self.dtype``.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.Tensor([x])
        return torch.clip(
            torch.round(x / self.scale) + self.zp,
            self.qmin,
            self.qmax,
        ).to(self.dtype)

    def dequantize_value(self, qx: int) -> float:
        """Map a quantized value back to the real domain: (qx - zp) * scale."""
        return (qx - self.zp) * self.scale


def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
    """NumPy counterpart of QuantArgs.quantize_value.

    Computes clip(round(x / scale) + zp, qmin, qmax) and casts to ``dtype``.
    Note that ``qargs.dtype`` is not consulted; the output element type comes
    solely from the ``dtype`` argument.
    """
    return np.clip(
        np.round(x / qargs.scale) + qargs.zp,
        qargs.qmin,
        qargs.qmax,
    ).astype(dtype)


def dequantize_value(qx, qargs: QuantArgs):
    """Map a quantized value back to the real domain: (qx - zp) * scale."""
    return (qx - qargs.zp) * qargs.scale


def qargs_from_qnode(node: torch.fx.Node):
    """Build QuantArgs from a quantize/dequantize_per_tensor node.

    Positional args 1..5 of the node are read as scale, zero point, qmin,
    qmax and dtype, respectively.

    Raises:
        AssertionError: if ``node`` is not one of ``dq_q_ops``.
    """
    assert node.target in dq_q_ops, f"Op {node} is not a quant node."

    return QuantArgs(
        scale=cast(float, node.args[1]),
        zp=cast(int, node.args[2]),
        qmin=cast(int, node.args[3]),
        qmax=cast(int, node.args[4]),
        dtype=cast(torch.dtype, node.args[5]),
    )


def get_neighbour_quant_args(
    node: torch.fx.Node,
) -> tuple[list[QuantArgs], list[QuantArgs]]:
    """Collect the QuantArgs reachable from ``node``'s users and inputs.

    Returns:
        ``(user_q_args, input_q_args)``: QuantArgs found by searching
        downstream from each user and upstream from each input node.
        Neighbours for which no QuantArgs are found are omitted.
    """
    user_q_args = [
        q_args
        for q_args in map(search_quant_arg_downstream, node.users)
        if q_args is not None
    ]
    input_q_args = [
        q_args
        for q_args in map(search_quant_arg_upstream, node.all_input_nodes)
        if q_args is not None
    ]
    return user_q_args, input_q_args


def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool:
    """Return True if every entry of ``q_arg_list`` equals the first one.

    An empty list is vacuously considered equal (previously it raised
    IndexError).
    """
    if not q_arg_list:
        return True
    first_q_arg = q_arg_list[0]
    return all(q_arg == first_q_arg for q_arg in q_arg_list)


def is_node_quantized(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is a q/dq op or has quantized neighbours.

    Neighbouring QuantArgs are searched through passable_ops in both
    directions. For a passable op, all QuantArgs found must be identical.

    Raises:
        AssertionError: if ``node`` is a passable op whose neighbouring
            QuantArgs differ.
    """
    if node.target in dq_q_ops:
        return True

    user_q_args, input_q_args = get_neighbour_quant_args(node)

    # If we did not find any neighbouring quant nodes, we are not quantized.
    if not input_q_args and not user_q_args:
        return False

    if node.target in passable_ops:
        assert all_q_args_equal(
            user_q_args + input_q_args
        ), f"Node {node} needs same quantization parameters on all inputs and outputs."

    return True


def _search_quant_arg(
    node: torch.fx.Node,
    next_nodes: Callable[[torch.fx.Node], list[torch.fx.Node]],
    neighbour_kind: str,
) -> QuantArgs | None:
    """Shared graph walk behind search_quant_arg_downstream/upstream.

    ``next_nodes`` yields the nodes the walk continues from (users when going
    downstream, inputs when going upstream). ``neighbour_kind`` ("consumers"
    or "inputs") is only used to word the assertion message.
    """
    if node.target in dq_q_ops:
        return qargs_from_qnode(node)
    if node.target not in passable_ops:
        return None

    # Recurse into every neighbour. With zero or one neighbour this reduces to
    # "no result" or "that neighbour's result"; with several, the found
    # QuantArgs must agree.
    found = [
        q_args
        for q_args in (
            _search_quant_arg(neighbour, next_nodes, neighbour_kind)
            for neighbour in next_nodes(node)
        )
        if q_args is not None
    ]
    if not found:
        return None
    assert all_q_args_equal(
        found
    ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different {neighbour_kind}."
    return found[0]


def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None:
    """
    Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node,
    starting with 'node'.
    If a passable node with multiple consumers is encountered,
    find QuantArgs for all consumers and assert that they are equal.
    If a node not in passable_ops is encountered, return None.
    If a node without consumers is encountered, return None.
    """
    return _search_quant_arg(node, lambda n: list(n.users), "consumers")


def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs:
    """Calls search_quant_arg_downstream and asserts that QuantArgs are found,
    meaning return value can't be None.
    """
    qargs = search_quant_arg_downstream(node)
    assert qargs is not None, f"Did not find QuantArgs downstream for node {node}"
    return qargs


def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None:
    """
    Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node,
    starting with 'node'.
    If a passable node with multiple inputs is encountered,
    find QuantArgs for all inputs and assert that they are equal.
    If a node not in passable_ops is encountered, return None.
    If a node without inputs is encountered, return None.
    """
    return _search_quant_arg(node, lambda n: list(n.all_input_nodes), "inputs")


def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs:
    """Calls search_quant_arg_upstream and asserts that QuantArgs are found,
    meaning return value can't be None.
    """
    qargs = search_quant_arg_upstream(node)
    assert qargs is not None, f"Did not find QuantArgs upstream for node {node}"
    return qargs


def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype:
    """Return the quantized output dtype of ``node``.

    Lookup order:
      1. Callable targets whose ``__name__`` contains "tosa": the dtype of
         ``node.meta["val"]``.
      2. q/dq ops: their dtype argument (``node.args[5]``).
      3. Otherwise the dtype of the first QuantArgs found downstream, or, for
         passable ops only, the first found upstream.

    Raises:
        RuntimeError: if no quantization information can be found.
    """
    # getattr guards against callables (e.g. partial objects) without __name__.
    if isinstance(node.target, Callable) and "tosa" in getattr(
        node.target, "__name__", ""
    ):
        return node.meta["val"].dtype
    if node.target in dq_q_ops:
        return cast(torch.dtype, node.args[5])

    # if not a tosa node, nor a q/dq op, walk the graph until we find a q op
    user_q_args, input_q_args = get_neighbour_quant_args(node)
    if user_q_args:
        return user_q_args[0].dtype
    if node.target in passable_ops and input_q_args:
        return input_q_args[0].dtype
    raise RuntimeError("No quantized node found in graph")


# Check if scale32 mode is used for given output element type
def is_scale32(type):
    """Return True when ``type`` is ts.DType.INT8, i.e. use a 32-bit multiplier."""
    return type == ts.DType.INT8


# TOSA uses the RESCALE operation to scale between values with differing precision.
# The RESCALE operator is defined using an integer multiply, add, and shift.
# This utility function is for calculating the multiplier and shift given a scale.
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
def compute_multiplier_and_shift(scale, scaleWidth=32):
    """Decompose a float ``scale`` into an integer multiplier and right shift.

    The result satisfies ``scale ~= multiplier * 2 ** -shift`` with
    ``multiplier`` fitting in a signed ``scaleWidth``-bit integer and
    ``shift`` clamped to at most 62.

    Args:
        scale: The scale to encode; must be a Python ``float``.
        scaleWidth: Bit width of the multiplier, 16 or 32.

    Returns:
        ``(multiplier, shift)`` as Python ints.

    Raises:
        AssertionError: for an unsupported ``scaleWidth``, a non-float
            ``scale``, or a mantissa that does not fit the multiplier width.
    """
    if scaleWidth == 16:
        offset = 15
    elif scaleWidth == 32:
        offset = 31
    else:
        raise AssertionError("unsupported scale width")

    assert isinstance(scale, float)

    # scale == mantissa * 2**exponent, with 0.5 <= |mantissa| < 1 for nonzero scale.
    mantissa, exponent = math.frexp(scale)
    shift = exponent

    const_2_power_15_or_31 = 1 << offset
    shifted_mantissa = round(mantissa * const_2_power_15_or_31)

    assert shifted_mantissa <= const_2_power_15_or_31

    if shifted_mantissa == const_2_power_15_or_31:
        # Rounding pushed the mantissa up to exactly 2**offset, which does not
        # fit; halve it and compensate in the exponent. Integer division keeps
        # the multiplier an int (true division produced a float here, which
        # then broke the ">>" below and the serialized RESCALE attribute).
        shifted_mantissa = shifted_mantissa // 2
        shift += 1

    # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
    shift = offset - shift

    # The multiplier must fit in a signed scaleWidth-bit integer
    # (2**15 - 1 or 2**31 - 1).
    assert shifted_mantissa <= (const_2_power_15_or_31 - 1)

    multiplier = shifted_mantissa

    # Clamp the shift to 62, moving the excess into the multiplier instead.
    if shift > 62:
        multiplier = multiplier >> min(31, shift - 62)
        shift = 62
    return multiplier, shift


def build_rescale(
    tosa_fb,
    scale,
    input_node,
    output_name,
    output_type,
    output_shape,
    input_zp,
    output_zp,
    is_double_round=False,
):
    """Add a per-tensor RESCALE op from ``input_node`` to ``output_name``.

    The multiplier width (32 or 16 bits) and the ``scale32`` attribute are
    both derived from ``output_type`` via is_scale32(). ``output_shape`` is
    not used by this function.
    """
    use_scale32 = is_scale32(output_type)
    scale_width = 32 if use_scale32 else 16
    multiplier, shift = compute_multiplier_and_shift(scale, scale_width)

    attr_rescale = ts.TosaSerializerAttribute()
    attr_rescale.RescaleAttribute(
        input_zp=input_zp,
        output_zp=output_zp,
        multiplier=[multiplier],
        shift=[shift],
        scale32=use_scale32,
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )

    tosa_fb.addOperator(
        TosaOp.Op().RESCALE, [input_node.name], [output_name], attr_rescale
    )


def build_rescale_to_int32(
    tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=False
) -> TosaSerializerTensor:
    """Rescale ``input`` into a new INT32 intermediate tensor and return it.

    The RESCALE uses ``input_zp`` as input zero point and 0 as output zero
    point. Note: the ``is_scale32`` parameter shadows the module-level helper
    of the same name; here it selects the multiplier width, so the computed
    multiplier now matches the ``scale32`` attribute (previously a 32-bit
    multiplier was always computed).
    """
    scale_width = 32 if is_scale32 else 16
    multiplier, shift = compute_multiplier_and_shift(rescale_scale, scale_width)
    attr_rescale = ts.TosaSerializerAttribute()
    attr_rescale.RescaleAttribute(
        input_zp=input_zp,
        output_zp=0,
        multiplier=[multiplier],
        shift=[shift],
        scale32=is_scale32,
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )
    input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
    tosa_fb.addOperator(
        TosaOp.Op().RESCALE,
        [input.name],
        [input_A_rescaled_to_int32.name],
        attr_rescale,
    )

    return input_A_rescaled_to_int32


def build_rescale_from_int32(
    tosa_fb,
    input_name,
    output_name,
    output_zp,
    rescale_scale,
    is_scale32=True,
    is_double_round=False,
) -> None:
    """Add a RESCALE from the tensor ``input_name`` (zero point 0) to ``output_name``.

    As in build_rescale_to_int32, ``is_scale32`` selects both the multiplier
    width and the ``scale32`` attribute so the two stay consistent.
    """
    scale_width = 32 if is_scale32 else 16
    multiplier, shift = compute_multiplier_and_shift(rescale_scale, scale_width)
    attr_rescale_output = ts.TosaSerializerAttribute()
    attr_rescale_output.RescaleAttribute(
        input_zp=0,
        output_zp=output_zp,
        multiplier=[multiplier],
        shift=[shift],
        scale32=is_scale32,
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )

    tosa_fb.addOperator(
        TosaOp.Op().RESCALE, [input_name], [output_name], attr_rescale_output
    )


def rescale_nodes_to_int32(
    nodes: Sequence[Node], tosa_graph: ts.TosaSerializer
) -> tuple[list[TosaSerializerTensor], float]:
    """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'.
    The scales are adjusted using the smallest scale of all 'nodes'.

    Returns a list of the rescaled nodes and the scale factor used,
    needed by rescale_node_back_to_int8.
    """

    tensors = [TosaArg(node) for node in nodes]

    # Permute each tensor's shape into its TOSA dim order.
    for tensor in tensors:
        dim_order = tensor.dim_order
        tensor.shape = [tensor.shape[i] for i in dim_order]

    qargs = [get_quant_arg_upstream(node) for node in nodes]

    # Scale the int8 quantized inputs to a common scale in the integer
    # domain: each input is rescaled relative to the smallest input scale,
    # so (for positive scales) every rescale factor is >= 1.
    min_scale = min(qarg.scale for qarg in qargs)
    scales = [qarg.scale / min_scale for qarg in qargs]

    rescaled_nodes: list[TosaSerializerTensor] = [
        build_rescale_to_int32(tosa_graph, tensor, qarg.zp, scale)
        for tensor, qarg, scale in zip(tensors, qargs, scales)
    ]
    return rescaled_nodes, min_scale


def rescale_node_back_to_int8(
    node: Node,
    last_tensor: TosaSerializerTensor,
    scale: float,
    tosa_graph: ts.TosaSerializer,
):
    """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'.
    Parameters:
        node: The original node that is being handled by the rescales.
        last_tensor: the tosa tensor to rescale back.
        scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32'
        tosa_graph: the tosa_graph to manipulate.
    """
    # Output quantization parameters are taken from the first user's
    # downstream quant node.
    qargs_out = get_quant_arg_downstream(list(node.users)[0])
    output_rescale_scale = scale / qargs_out.scale

    # Rescale Back to INT8
    build_rescale_from_int32(
        tosa_graph,
        last_tensor.name,
        node.name,
        qargs_out.zp,
        output_rescale_scale,
    )


def build_rescale_conv_output(
    tosa_fb,
    op,
    output_name,
    output_type,
    input_scale,
    weight_scale,
    output_scale,
    output_zp,
):
    """Creates a TOSA rescale op based on conv2d parameters.

    The applied scale is ``(input_scale * weight_scale) / output_scale``, with
    input zero point 0 and ``output_zp`` as output zero point.
    """
    # TODO add check to verify if this is a Per-channel quantization.
    post_conv2d_scale = (input_scale * weight_scale) / output_scale

    # Since we assume the input tensor that is being rescaled is int32 data type, zero point must be 0.
    build_rescale(
        tosa_fb,
        post_conv2d_scale,
        op,
        output_name,
        output_type,
        op.shape,
        0,
        output_zp,
    )