# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Utility functions for TOSA quantized lowerings

import math
from typing import Callable, cast, NamedTuple, Sequence

import numpy as np

import serializer.tosa_serializer as ts
import torch.fx
import tosa.Op as TosaOp
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.exir.dialects._ops import ops as exir_ops
from serializer.tosa_serializer import TosaSerializerTensor
from torch.fx import Node

q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
dq_q_ops = (q_op, dq_op)

# Ops that pass quantization parameters through unchanged. The up/downstream
# searches below may traverse these in either direction when looking for a
# q/dq node that carries the actual QuantArgs.
passable_ops = [
    exir_ops.edge.aten.view_copy.default,
    exir_ops.edge.aten.permute_copy.default,
    exir_ops.edge.aten.squeeze_copy.dims,
    exir_ops.edge.aten.unsqueeze_copy.default,
    exir_ops.edge.aten.split_with_sizes_copy.default,
    exir_ops.edge.aten.repeat.default,
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.aten.slice_copy.Tensor,
    exir_ops.edge.aten.cat.default,
]


def register_passable_op(op):
    """We need to be able to add custom ops such as tosa_transpose to the
    passable_op list after they have been created"""
    passable_ops.append(op)


class QuantArgs(NamedTuple):
    """Quantization parameters of a quantize/dequantize node.

    Fields mirror the positional args of
    quantized_decomposed.(de)quantize_per_tensor: scale, zero point,
    quantized min/max clamp values, and the quantized storage dtype.
    """

    scale: float
    zp: int
    qmin: int
    qmax: int
    dtype: torch.dtype

    def quantize_value(self, x):
        """Quantize scalar or tensor `x` to `self.dtype`, clamped to [qmin, qmax]."""
        if not isinstance(x, torch.Tensor):
            x = torch.Tensor([x])
        return torch.clip(
            torch.round(x / self.scale) + self.zp,
            self.qmin,
            self.qmax,
        ).to(self.dtype)

    def dequantize_value(self, qx: int) -> float:
        """Map quantized value `qx` back to the float domain."""
        return (qx - self.zp) * self.scale


def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
    """Quantize numpy value(s) `x` with `qargs`, clamped to [qmin, qmax]."""
    return np.clip(
        np.round(x / qargs.scale) + qargs.zp,
        qargs.qmin,
        qargs.qmax,
    ).astype(dtype)


def dequantize_value(qx, qargs: QuantArgs):
    """Map quantized numpy value(s) `qx` back to the float domain."""
    return (qx - qargs.zp) * qargs.scale


def qargs_from_qnode(node: torch.fx.Node):
    """Extract QuantArgs from a quantize/dequantize node.

    Raises (via assert) if `node` is not one of `dq_q_ops`.
    """
    assert node.target in dq_q_ops, f"Op {node} is not a quant node."
    # Positional arg layout: (tensor, scale, zero_point, qmin, qmax, dtype).
    return QuantArgs(
        scale=cast(float, node.args[1]),
        zp=cast(int, node.args[2]),
        qmin=cast(int, node.args[3]),
        qmax=cast(int, node.args[4]),
        dtype=cast(torch.dtype, node.args[5]),
    )


def get_neighbour_quant_args(
    node: torch.fx.Node,
) -> tuple[list[QuantArgs], list[QuantArgs]]:
    """Collect QuantArgs found downstream of each user and upstream of each
    input of `node`. Returns (user_q_args, input_q_args); either list may be
    empty if no quant node is reachable through passable ops."""
    user_q_args = []
    for user in node.users:
        q_args = search_quant_arg_downstream(user)
        if q_args:
            user_q_args.append(q_args)

    input_q_nodes = []
    for input_node in node.all_input_nodes:
        q_args = search_quant_arg_upstream(input_node)
        if q_args:
            input_q_nodes.append(q_args)
    return user_q_args, input_q_nodes


def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool:
    """Return True iff every QuantArgs in the (non-empty) list is equal."""
    first_q_arg = q_arg_list[0]
    for q_arg in q_arg_list:
        if q_arg != first_q_arg:
            return False
    return True


def is_node_quantized(node: torch.fx.Node) -> bool:
    """Return True if `node` is a q/dq node or has quantized neighbours.

    For passable ops, additionally asserts that all surrounding QuantArgs
    agree, since such ops must not change the quantization parameters.
    """
    if node.target in dq_q_ops:
        return True

    user_q_args, input_q_args = get_neighbour_quant_args(node)

    # If we did not find any neighbouring quant nodes, we are not quantized.
    if len(input_q_args) == 0 and len(user_q_args) == 0:
        return False

    if node.target in passable_ops:
        assert all_q_args_equal(
            user_q_args + input_q_args
        ), f"Node {node} needs same quantization parameters on all inputs and outputs."

    return True


def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None:
    """
    Iterates downward in the graph passing through 'passable_ops' to find and
    return a quantization node, starting with 'node'.
    If a passable node with multiple consumers is encountered,
    find QuantArgs for all consumers and assert that they are equal.
    If a node not in passable_ops is encountered, return None.
    If a node without consumers is encountered, return None.
    """
    if node.target in dq_q_ops:
        return qargs_from_qnode(node)
    if node.target not in passable_ops:
        return None
    consumer_nodes = list(node.users)
    if len(consumer_nodes) == 0:
        return None
    elif len(consumer_nodes) == 1:
        return search_quant_arg_downstream(consumer_nodes[0])
    else:
        # Multiple consumers: all reachable QuantArgs must agree, otherwise
        # this passable op would need to change quantization parameters.
        consumer_qargs: list[QuantArgs] = []
        for input in consumer_nodes:
            quant_args = search_quant_arg_downstream(input)
            if quant_args:
                consumer_qargs.append(quant_args)
        if len(consumer_qargs) == 0:
            return None
        assert all_q_args_equal(
            consumer_qargs
        ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers."
        return consumer_qargs[0]


def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs:
    """Calls search_quant_arg_downstream and asserts that QuantArgs are found,
    meaning return value can't be None.
    """
    qargs = search_quant_arg_downstream(node)
    assert qargs, f"Did not find QuantArgs downstream for node {node}"
    return qargs


def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None:
    """
    Iterates upward in the graph passing through 'passable_ops' to find and
    return a quantization node, starting with 'node'.
    If a passable node with multiple inputs is encountered,
    find QuantArgs for all inputs and assert that they are equal.
    If a node not in passable_ops is encountered, return None.
    If a node without inputs is encountered, return None.
    """
    if node.target in dq_q_ops:
        return qargs_from_qnode(node)
    if node.target not in passable_ops:
        return None
    input_nodes = list(node.all_input_nodes)
    if len(input_nodes) == 0:
        return None
    elif len(input_nodes) == 1:
        return search_quant_arg_upstream(input_nodes[0])
    else:
        # Multiple inputs: all reachable QuantArgs must agree, otherwise this
        # passable op would need to change quantization parameters.
        input_qargs: list[QuantArgs] = []
        for input in input_nodes:
            quant_args = search_quant_arg_upstream(input)
            if quant_args:
                input_qargs.append(quant_args)
        if len(input_qargs) == 0:
            return None
        assert all_q_args_equal(
            input_qargs
        ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs."
        return input_qargs[0]


def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs:
    """Calls search_quant_arg_upstream and asserts that QuantArgs are found,
    meaning return value can't be None.
    """
    qargs = search_quant_arg_upstream(node)
    assert qargs, f"Did not find QuantArgs upstream for node {node}"
    return qargs


def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype:
    """Return the output dtype of a quantized node.

    TOSA custom ops carry the dtype in node.meta["val"]; q/dq ops carry it in
    args[5]; otherwise the graph is searched for a neighbouring quant node.

    Raises:
        RuntimeError: if no quantized node is reachable from `node`.
    """
    if isinstance(node.target, Callable) and "tosa" in node.target.__name__:
        return node.meta["val"].dtype
    if node.target in dq_q_ops:
        return cast(torch.dtype, node.args[5])

    # if not a tosa node, nor a q/dq op, walk the graph until we find a q op
    user_q_args, input_q_args = get_neighbour_quant_args(node)
    if len(user_q_args) > 0:
        return user_q_args[0].dtype
    elif node.target in passable_ops and len(input_q_args) > 0:
        return input_q_args[0].dtype
    else:
        raise RuntimeError("No quantized node found in graph")


# Check if scale32 mode is used for given output element type
def is_scale32(type):
    return type == ts.DType.INT8


# TOSA uses the RESCALE operation to scale between values with differing precision.
# The RESCALE operator is defined using an integer multiply, add, and shift.
# This utility function is for calculating the multier and shift given a scale.
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
def compute_multiplier_and_shift(scale, scaleWidth=32):
    if scaleWidth == 16:
        offset = 15
    elif scaleWidth == 32:
        offset = 31
    else:
        raise AssertionError("unsupported scale width")

    assert isinstance(scale, float)

    mantissa, exponent = math.frexp(scale)
    shift = exponent

    const_2_power_15_or_31 = 1 << offset
    shifted_mantissa = round(mantissa * const_2_power_15_or_31)

    assert shifted_mantissa <= const_2_power_15_or_31

    if shifted_mantissa == const_2_power_15_or_31:
        # Rounding overflowed the mantissa range; renormalize. Integer
        # division keeps 'multiplier' an int (true division would make it a
        # float in the serialized RESCALE attribute); the value is identical
        # since the operand is a power of two.
        shifted_mantissa = shifted_mantissa // 2
        shift += 1

    # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
    shift = offset - shift

    # INT32_MAX, 2^31 - 1
    assert shifted_mantissa <= (const_2_power_15_or_31 - 1)

    multiplier = shifted_mantissa

    if shift > 62:
        # Shifts above 62 are not representable; fold the excess into the
        # multiplier instead (capped at 31 bits of extra shift).
        multiplier = multiplier >> min(31, shift - 62)
        shift = 62
    return multiplier, shift


def build_rescale(
    tosa_fb,
    scale,
    input_node,
    output_name,
    output_type,
    output_shape,
    input_zp,
    output_zp,
    is_double_round=False,
):
    """Add a RESCALE operator to `tosa_fb` scaling `input_node` by `scale`."""
    scale_width = 32 if is_scale32(output_type) else 16
    multiplier, shift = compute_multiplier_and_shift(scale, scale_width)

    attr_rescale = ts.TosaSerializerAttribute()
    attr_rescale.RescaleAttribute(
        input_zp=input_zp,
        output_zp=output_zp,
        multiplier=[multiplier],
        shift=[shift],
        scale32=is_scale32(output_type),
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )

    tosa_fb.addOperator(
        TosaOp.Op().RESCALE, [input_node.name], [output_name], attr_rescale
    )

    return


def build_rescale_to_int32(
    tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=False
) -> TosaSerializerTensor:
    """Add a RESCALE operator taking `input` to an INT32 intermediate tensor.

    Returns the newly created intermediate tensor.
    """
    multiplier, shift = compute_multiplier_and_shift(rescale_scale)
    attr_rescale = ts.TosaSerializerAttribute()
    attr_rescale.RescaleAttribute(
        input_zp=input_zp,
        output_zp=0,
        multiplier=[multiplier],
        shift=[shift],
        scale32=is_scale32,
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )
    input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
    tosa_fb.addOperator(
        TosaOp.Op().RESCALE,
        [input.name],
        [input_A_rescaled_to_int32.name],
        attr_rescale,
    )

    return input_A_rescaled_to_int32


def build_rescale_from_int32(
    tosa_fb,
    input_name,
    output_name,
    output_zp,
    rescale_scale,
    is_scale32=True,
    is_double_round=False,
) -> None:
    """Add a RESCALE operator taking the INT32 tensor `input_name` back to a
    quantized output tensor `output_name` with zero point `output_zp`."""
    multiplier, shift = compute_multiplier_and_shift(rescale_scale)
    attr_rescale_output = ts.TosaSerializerAttribute()
    attr_rescale_output.RescaleAttribute(
        input_zp=0,
        output_zp=output_zp,
        multiplier=[multiplier],
        shift=[shift],
        scale32=is_scale32,
        double_round=is_double_round,
        per_channel=False,
        input_unsigned=False,
        output_unsigned=False,
    )

    tosa_fb.addOperator(
        TosaOp.Op().RESCALE, [input_name], [output_name], attr_rescale_output
    )

    return


def rescale_nodes_to_int32(
    nodes: Sequence[Node], tosa_graph: ts.TosaSerializer
) -> tuple[list[TosaSerializerTensor], float]:
    """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'.
    The scales are adjusted using the smallest scale of all 'nodes'.

    Returns a list of the rescaled nodes and the scale factor used,
    needed by rescale_node_back_to_int8.
    """

    tensors = [TosaArg(node) for node in nodes]

    # Reshape tensor according to tosa dim order
    for tensor in tensors:
        dim_order = tensor.dim_order
        tensor.shape = [tensor.shape[i] for i in dim_order]

    qargs = [get_quant_arg_upstream(node) for node in nodes]

    # Scale the int8 quantized input to a common scale in the integer
    # domain
    min_scale = min([qarg.scale for qarg in qargs])
    scales = [qarg.scale / min_scale for qarg in qargs]

    rescaled_nodes: list[TosaSerializerTensor] = []
    for tensor, qarg, scale in zip(tensors, qargs, scales):
        rescaled_nodes.append(
            build_rescale_to_int32(
                tosa_graph,
                tensor,
                qarg.zp,
                scale,
            )
        )
    return rescaled_nodes, min_scale


def rescale_node_back_to_int8(
    node: Node,
    last_tensor: TosaSerializerTensor,
    scale: float,
    tosa_graph: ts.TosaSerializer,
):
    """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'.
    Parameters:
        node: The original node that is being handled by the rescales.
        last_tensor:the tosa tensor to rescale back.
        scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32'
        tosa_graph: the tosa_graph to manipulate.
    """
    qargs_out = get_quant_arg_downstream(list(node.users)[0])
    output_rescale_scale = scale / qargs_out.scale

    # Rescale Back to INT8
    build_rescale_from_int32(
        tosa_graph,
        last_tensor.name,
        node.name,
        qargs_out.zp,
        output_rescale_scale,
    )


"""
Creates a TOSA rescale op based on conv2d parameters.
"""


def build_rescale_conv_output(
    tosa_fb,
    op,
    output_name,
    output_type,
    input_scale,
    weight_scale,
    output_scale,
    output_zp,
):
    # TODO add check to verify if this is a Per-channel quantization.
    post_conv2d_scale = (input_scale * weight_scale) / output_scale

    # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0.
    build_rescale(
        tosa_fb,
        post_conv2d_scale,
        op,
        output_name,
        output_type,
        op.shape,
        0,
        output_zp,
    )
    return