# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict


# This file contains all the helper utility functions.

from itertools import zip_longest
from math import frexp, isclose, trunc
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.fx

from executorch.exir.dialects._ops import ops as exir_ops
from torch.utils._pytree import tree_flatten


# Return the output node of the graph
def get_output_node(graph: torch.fx.Graph) -> torch.fx.Node:
    assert graph is not None, "Cannot get output of an empty graph"
    output_node = next(iter(reversed(graph.nodes)))
    assert (
        output_node and output_node.op == "output" and len(output_node.args) == 1
    ), "Failed to find output node"
    return output_node


# Return True if the node is part of the flattened output
def is_node_in_flattened_output(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
    output_node = get_output_node(graph)
    return node in tree_flatten(output_node.args[0])[0]


# Return a list with the placeholder (input) nodes of the graph
def get_placeholders(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    return list(filter(lambda x: x.op == "placeholder", graph.nodes))


# Return the shape of the incoming node.
def get_shape(
    graph_module: torch.fx.GraphModule, node: torch.fx.Node
) -> Union[torch.Size, None]:
    """
    Return the shape of the tensor corresponding to node. If the node has a
    tensor spec, return the shape from the metadata. If the node is a param,
    return its shape. Otherwise return None.
    """
    try:
        # Case 1. node is a scalar (this pass happens before tensorization)
        if isinstance(node, (float, int, bool)):
            return torch.Size([1])
        # Case 2. node has TensorSpec metadata
        fake_tensor = node.meta.get("val")
        if fake_tensor is not None:
            return fake_tensor.shape
        # Case 3. node holds a param
        if node.op == "get_attr":
            attr_node = getattr(graph_module, node.target)
            return attr_node.shape
        # Default: return None
        return None
    except RuntimeError:
        return None


# Return True if shape_2 can be broadcasted to shape_1
def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
    """
    Check if 'shape_2' can be broadcasted to 'shape_1'. The broadcast is
    feasible if:
    (1) shape_2 does not have higher dimensionality than shape_1;
    (2) the value at each dimension of shape_2 is either the same as in
        shape_1 or 1;
    (3) shape_1 or shape_2 is empty.
    """
    return (
        not shape_1
        or not shape_2
        or all(
            x == y or y == 1 or y is None
            for x, y in zip_longest(shape_1[::-1], shape_2[::-1])
        )
    )
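# A minimal illustration of the broadcast check above (hypothetical shapes,
# shown for clarity only). Shapes are aligned from the trailing dimension,
# mirroring standard broadcasting semantics:
#   broadcastable([2, 3, 4], [3, 4])     -> True   (shape_2 has fewer dims)
#   broadcastable([2, 3, 4], [2, 1, 4])  -> True   (size-1 dims broadcast)
#   broadcastable([2, 3, 4], [4, 4])     -> False  (mismatched dim that is not 1)
#   broadcastable([3, 4], [2, 3, 4])     -> False  (shape_2 has more dims)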

# Return a chain of nodes with target in op_targets
def get_cascaded_ops(
    nodes: List[torch.fx.Node],
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    op_targets: Iterable[Union[Callable[..., Any], str]],
) -> Sequence[torch.fx.Node]:
    """
    'nodes' contains a chain of ops with target in 'op_targets'. Extend that
    chain by one if nodes[-1] has a single user whose target is in 'op_targets'.
    """
    cur = nodes[-1]
    users = list(cur.users.keys())
    # Check that (a) there is only one user of cur, and (b) that user's target
    # is one of the ops in op_targets.
    if len(users) == 1 and users[0].target in op_targets:
        nodes.append(users[0])
        # Recursively extend the chain starting at the user
        return get_cascaded_ops(nodes, op_targets)

    return nodes


# Capture the effect of transpose op on incoming dimension order
def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
    """
    Given a transpose node, and the incoming dimension ordering of the input
    tensor to the transpose node, return the net effect of the transpose op on
    the dimension order.
    """
    assert node.target == exir_ops.edge.aten.transpose_copy.int
    # Assert that dims was provided
    assert dims is not None
    dim_len = len(dims)
    # Get dim0 and dim1 from the transpose op args
    transpose_dims0 = node.args[1]
    transpose_dims1 = node.args[2]
    assert isinstance(transpose_dims0, int)
    assert isinstance(transpose_dims1, int)
    dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len
    dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len
    # Perform the transpose on the dimension ordering (dims)
    dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
    return dims


# Capture the effect of permute op on incoming dimension order
def get_permuted_dims(node: torch.fx.Node, dims: Optional[List[int]]) -> List[int]:
    """
    Given a permute node, and the incoming dimension ordering of the input
    tensor to the permute node, return the net effect of the permute op on the
    dimension order.
    """
    assert node.target == exir_ops.edge.aten.permute_copy.default
    # Permute each index of the dimension ordering (dims)
    permute_dims = node.args[1]
    assert isinstance(permute_dims, list)
    assert all(isinstance(x, int) for x in permute_dims)
    # If dims is empty, we can simply return the permute order
    if not dims:
        return permute_dims
    dims = [dims[x] for x in permute_dims]
    return dims
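# A small worked example of the two helpers above (hypothetical values, shown
# for clarity only): starting from the identity order [0, 1, 2, 3], a
# permute_copy node with args[1] == [0, 2, 3, 1] maps it to [0, 2, 3, 1], and
# a subsequent transpose_copy of dims 1 and 3 swaps those two entries,
# giving [0, 1, 3, 2].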

# Return the tensor of a buffer/parameter op
def get_tensor_from_attr(
    graph_module: torch.fx.GraphModule, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
    """
    For an input node that is a named buffer or parameter, return
    the underlying tensor.
    """
    if node is None:
        return None
    assert node.op == "get_attr"
    return getattr(graph_module, node.target)


def is_node_with_op(node: torch.fx.Node, op: str) -> bool:
    """
    Return True if the incoming node has the given op type.
    """
    return node.op == op


def count_users_with_target_op_type(
    nodes: Iterable[torch.fx.Node],
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    op_target: Union[Callable[..., Any], str],
) -> int:
    """
    Given a set of nodes and a node target type `op_target`, iterate over all
    the users of the nodes, and return the total number of users with target
    `op_target`.
    """

    def contributions_per_node(
        node: torch.fx.Node,
        # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
        op_target: Union[Callable[..., Any], str],
    ) -> int:
        return [use.target for use in node.users if use.op == "call_function"].count(
            op_target
        )

    return sum([contributions_per_node(node, op_target) for node in nodes])


def contains_node_with_matching_target(
    nodes: Iterable[torch.fx.Node],
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    op_target: Union[Callable[..., Any], str],
) -> bool:
    """
    Given a list of nodes, return True if any node in the list has target
    'op_target'.
    """
    return any(node.target == op_target for node in nodes)


def is_quantized_tensor(x: torch.Tensor) -> bool:
    """
    Return True if the tensor x is quantized.
    """
    return x.is_quantized


def get_scale(x: torch.Tensor) -> torch.Tensor:
    """
    Return the scale of a quantized tensor as a float32 tensor.
    """
    return (
        x.q_per_channel_scales().to(torch.float32)
        if x.qscheme() == torch.per_channel_affine
        else torch.tensor([x.q_scale()], dtype=torch.float32)
    )


def get_zero_point(x: torch.Tensor, reduce: bool = True) -> torch.Tensor:
    """
    Return the zero point of a quantized tensor as an int32 tensor.
    """
    # If x was quantized per-tensor, simply create a tensor out of the scalar
    # zero_point, and return it.
    if x.qscheme() == torch.per_tensor_affine:
        return torch.tensor([x.q_zero_point()], dtype=torch.int32)
    # If x was quantized per-channel, check whether all channels share the same
    # zero_point. If so, we can compress the zero_point tensor to a single value.
    assert x.qscheme() == torch.per_channel_affine, "Unhandled quantization scheme"
    zero_point = x.q_per_channel_zero_points().to(torch.int32)
    return (
        torch.tensor([zero_point[0]], dtype=torch.int32)
        if reduce and all(zero_point == zero_point[0])
        else zero_point
    )
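# An illustrative usage of the two quantization helpers above (hypothetical
# values, shown for clarity only):
#   xq = torch.quantize_per_tensor(
#       torch.randn(4), scale=0.1, zero_point=5, dtype=torch.qint8
#   )
#   get_scale(xq)       # -> float32 tensor holding [0.1]
#   get_zero_point(xq)  # -> int32 tensor holding [5]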

def quantize_tensor_multiplier(
    requantize_scale_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given requantize_scale_tensor with values in the interval (0, 1),
    produce a pair of tensors (out_multiplier, right_shift) where out_multiplier
    is an int32 tensor representing fixed-point values in the interval [-1, 1),
    and right_shift is an amount to shift right by, so that the floating-point
    multiplication of some int32 input with each value of requantize_scale_tensor:
        result = int32_value * requantize_scale_tensor[i]
    is best approximated by the integer-arithmetic-only code:
        result = RoundingRightShift(FixedPointMultiplication(int32_value,
                 out_multiplier[i]), right_shift[i])
    """

    # This is identical to C++11 std::round(). Python's built-in round() uses
    # banker's rounding (round half to even), whereas C++ rounds halfway cases
    # away from zero.
    # pyre-fixme[2]: Parameter must be annotated.
    def round_away_zero(f) -> int:
        r = -0.5 if (f < 0) else 0.5
        return trunc(f + r)

    def quantize_scalar_multiplier(requantize_scale: float) -> Tuple[int, int]:
        significand, exponent = frexp(requantize_scale)
        significand_q31 = int(round_away_zero(significand * (1 << 31)))
        # Handle the special case when the real multiplier was so close to 1
        # that its fixed-point approximation was indistinguishable from 1.
        # We handle this by dividing it by two, and incrementing the exponent
        # (the right shift amount) by 1.
        if significand_q31 == (1 << 31):
            significand_q31 //= 2
            exponent += 1

        # Verify that the decomposition of requantize_scale into significand
        # and exponent is correct.
        reconstructed = significand_q31 / (1 << 31) * pow(2, exponent)
        assert isclose(
            requantize_scale, reconstructed, rel_tol=1e-4, abs_tol=1e-4
        ), "computation of significand and exponent from requantize_scale is not accurate"

        return (significand_q31, exponent)

    # Flatten the input scale tensor so that we can operate on individual values
    orig_shape = requantize_scale_tensor.shape
    flattened_tensor = requantize_scale_tensor.flatten().to(torch.float32)
    out_multiplier = torch.zeros(flattened_tensor.shape, dtype=torch.int32)
    right_shift = torch.zeros(flattened_tensor.shape, dtype=torch.int32)

    # Iterate over the flattened scale tensor and compute the decomposition of
    # each value in the scale tensor into significand (out_multiplier) and
    # exponent (right_shift).
    for idx, scale in enumerate(flattened_tensor):
        (si, ex) = quantize_scalar_multiplier(scale)
        out_multiplier[idx], right_shift[idx] = si, ex

    # Reshape the tensors back to the original shape
    out_multiplier = out_multiplier.reshape(orig_shape)
    right_shift = right_shift.reshape(orig_shape)

    return (out_multiplier, right_shift)
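# A worked example of the decomposition above (hypothetical value, shown for
# clarity only): for a scale of 0.00390625 (= 2**-8), frexp returns (0.5, -7),
# so the stored multiplier is round(0.5 * 2**31) = 1073741824 (0.5 in Q31
# fixed point) and the stored shift value is -7, since
# 1073741824 / 2**31 * 2**-7 reconstructs 0.00390625 exactly.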