# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict


# This file contains all the helper utility functions.

from itertools import zip_longest
from math import frexp, isclose, trunc
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.fx

from executorch.exir.dialects._ops import ops as exir_ops
from torch.utils._pytree import tree_flatten


# Return the output node of the graph
def get_output_node(graph: torch.fx.Graph) -> torch.fx.Node:
    assert graph is not None, "Cannot get output of an empty graph"
    output_node = next(iter(reversed(graph.nodes)))
    assert (
        output_node and output_node.op == "output" and len(output_node.args) == 1
    ), "Failed to find output node"
    return output_node


# Return true if the node is part of the flattened output
def is_node_in_flattened_output(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
    output_node = get_output_node(graph)
    return node in tree_flatten(output_node.args[0])[0]


# Returns a list with placeholders/inputs
def get_placeholders(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    return list(filter(lambda x: x.op == "placeholder", graph.nodes))


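# Illustrative usage sketch: the helper below and the toy traced function are
# hypothetical, shown only to demonstrate the graph helpers above. It is never
# called, so importing this module stays side-effect free.
def _example_graph_helpers() -> None:
    def f(x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)

    gm = torch.fx.symbolic_trace(f)
    assert get_output_node(gm.graph).op == "output"
    assert len(get_placeholders(gm.graph)) == 1
    # The relu node is what the graph returns, so it is part of the flattened output.
    relu_node = next(n for n in gm.graph.nodes if n.target is torch.relu)
    assert is_node_in_flattened_output(gm.graph, relu_node)

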
# Return the shape of the incoming node.
def get_shape(
    graph_module: torch.fx.GraphModule, node: torch.fx.Node
) -> Union[torch.Size, None]:
    """
    Return the shape of the tensor corresponding to node. If the node has a
    tensor spec, return the shape from the metadata. If the node is a param,
    return its shape. Otherwise return None.
    """
    try:
        # Case 1. node is a scalar (this pass happens before tensorization)
        if isinstance(node, (float, int, bool)):
            return torch.Size([1])
        # Case 2. node has TensorSpec metadata
        fake_tensor = node.meta.get("val")
        if fake_tensor is not None:
            return fake_tensor.shape
        # Case 3. node holds a param
        if node.op == "get_attr":
            attr_node = getattr(graph_module, node.target)
            return attr_node.shape
        # Default: return None
        return None
    except RuntimeError:
        return None


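# Illustrative sketch with a hypothetical toy module: get_shape() resolves the
# shape of a get_attr node from the owning GraphModule.
def _example_get_shape() -> None:
    class _M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(3, 4))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.weight

    gm = torch.fx.symbolic_trace(_M())
    attr_node = next(n for n in gm.graph.nodes if n.op == "get_attr")
    assert get_shape(gm, attr_node) == torch.Size([3, 4])

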
# Return true if shape_2 can be broadcasted to shape_1
def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
    """
    Check if 'shape_2' can be broadcasted to 'shape_1'. The broadcast is
    feasible if either shape is empty, or if, aligning the two shapes at the
    trailing dimension, every dimension of shape_2 is either equal to the
    corresponding dimension of shape_1 or is 1.
    """
    return (
        not shape_1
        or not shape_2
        or all(
            x == y or y == 1 or y is None
            for x, y in zip_longest(shape_1[::-1], shape_2[::-1])
        )
    )


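# Illustrative sketch with hypothetical shapes: broadcastable() aligns the
# shapes at the trailing dimension, as in PyTorch broadcasting.
def _example_broadcastable() -> None:
    assert broadcastable([4, 3, 2], [3, 2])      # trailing dims match
    assert broadcastable([4, 3, 2], [3, 1])      # size-1 dim broadcasts
    assert not broadcastable([4, 3, 2], [2, 2])  # 3 vs. 2 cannot broadcast
    assert broadcastable([], [5])                # an empty shape is always accepted

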
# Return a chain of nodes with target in op_targets
def get_cascaded_ops(
    nodes: List[torch.fx.Node],
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    op_targets: Iterable[Union[Callable[..., Any], str]],
) -> Sequence[torch.fx.Node]:
    """
    'nodes' contains a chain of ops with target in 'op_targets'. Extend that chain
    by one if nodes[-1] has a single user whose op target is in 'op_targets'.
    """
    cur = nodes[-1]
    users = list(cur.users.keys())
    # Extend the chain only if (a) there is exactly one user of cur, and
    # (b) that user's target is one of the ops in op_targets.
    if len(users) == 1 and users[0].target in op_targets:
        nodes.append(users[0])
        # Recursively extend the chain starting at the user
        return get_cascaded_ops(nodes, op_targets)

    return nodes


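# Illustrative sketch with a hypothetical traced function: two back-to-back
# relu calls form a single cascaded chain.
def _example_get_cascaded_ops() -> None:
    def f(x: torch.Tensor) -> torch.Tensor:
        return torch.relu(torch.relu(x))

    gm = torch.fx.symbolic_trace(f)
    first_relu = next(n for n in gm.graph.nodes if n.target is torch.relu)
    chain = get_cascaded_ops([first_relu], (torch.relu,))
    assert len(chain) == 2  # the chain was extended to cover both relu nodes

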
# Capture the effect of transpose op on incoming dimension order
def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
    """
    Given a transpose node, and the incoming dimension ordering of the input
    tensor to the transpose node, return the net effect of the transpose op on
    the dimension order.
    """
    assert node.target == exir_ops.edge.aten.transpose_copy.int
    # Assert that dims was provided
    assert dims is not None
    dim_len = len(dims)
    # Get dim0 and dim1 from the transpose op args
    transpose_dims0 = node.args[1]
    transpose_dims1 = node.args[2]
    assert isinstance(transpose_dims0, int)
    assert isinstance(transpose_dims1, int)
    dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len
    dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len
    # Perform the transpose on the dimension ordering (dims)
    dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
    return dims


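# Illustrative sketch with a hypothetical hand-built graph: a transpose_copy
# node that swaps dims 1 and 3 turns the identity ordering [0, 1, 2, 3] into
# [0, 3, 2, 1].
def _example_get_transposed_dims() -> None:
    g = torch.fx.Graph()
    x = g.placeholder("x")
    t = g.call_function(exir_ops.edge.aten.transpose_copy.int, args=(x, 1, 3))
    assert get_transposed_dims(t, [0, 1, 2, 3]) == [0, 3, 2, 1]

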
# Capture the effect of permute op on incoming dimension order
def get_permuted_dims(node: torch.fx.Node, dims: Optional[List[int]]) -> List[int]:
    """
    Given a permute node, and the incoming dimension ordering of the input
    tensor to the permute node, return the net effect of the permute op on
    the dimension order.
    """
    assert node.target == exir_ops.edge.aten.permute_copy.default
    # Permute each index of the dimension ordering (dims)
    permute_dims = node.args[1]
    assert isinstance(permute_dims, List)
    assert all(isinstance(x, int) for x in permute_dims)
    # If dims is empty, we can simply return the permute order
    if not dims:
        return permute_dims
    dims = [dims[x] for x in permute_dims]
    return dims


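# Illustrative sketch with a hypothetical hand-built graph: a permute_copy node
# with order [2, 0, 1] reorders the incoming dimension ordering accordingly.
def _example_get_permuted_dims() -> None:
    g = torch.fx.Graph()
    x = g.placeholder("x")
    p = g.call_function(exir_ops.edge.aten.permute_copy.default, args=(x, [2, 0, 1]))
    assert get_permuted_dims(p, [0, 1, 2]) == [2, 0, 1]
    assert get_permuted_dims(p, None) == [2, 0, 1]  # empty dims returns the permute order

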
# Return the tensor of buffer/parameter op
def get_tensor_from_attr(
    graph_module: torch.fx.GraphModule, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
    """
    For an input node that is a named buffer or parameter, return
    the underlying tensor.
    """
    if node is None:
        return None
    assert node.op == "get_attr"
    return getattr(graph_module, node.target)


def is_node_with_op(node: torch.fx.Node, op: str) -> bool:
    """
    Return true if the incoming node has the given op type
    """
    return node.op == op


def count_users_with_target_op_type(
    nodes: Iterable[torch.fx.Node],
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    op_target: Union[Callable[..., Any], str],
) -> int:
    """
    Given a set of nodes and a node target type `op_target`, iterate over all
    the users of nodes, and return the total number of users with target
    op_target.
    """

    def contributions_per_node(
        node: torch.fx.Node,
        # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
        op_target: Union[Callable[..., Any], str],
    ) -> int:
        return [use.target for use in node.users if use.op == "call_function"].count(
            op_target
        )

    return sum([contributions_per_node(node, op_target) for node in nodes])


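# Illustrative sketch with a hypothetical traced function: the single
# placeholder below feeds two relu users, so the count is 2.
def _example_count_users_with_target_op_type() -> None:
    def f(x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + torch.relu(x)

    gm = torch.fx.symbolic_trace(f)
    assert count_users_with_target_op_type(get_placeholders(gm.graph), torch.relu) == 2

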
def contains_node_with_matching_target(
    nodes: Iterable[torch.fx.Node],
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    op_target: Union[Callable[..., Any], str],
) -> bool:
    """
    Given a list of nodes, return true if any node in the list has target
    'op_target'.
    """
    return any(node.target == op_target for node in nodes)


def is_quantized_tensor(x: torch.Tensor) -> bool:
    """
    Return true if the tensor x is quantized
    """
    return x.is_quantized


def get_scale(x: torch.Tensor) -> torch.Tensor:
    """
    Return the scale of a quantized tensor as a float32 tensor.
    """
    return (
        x.q_per_channel_scales().to(torch.float32)
        if x.qscheme() == torch.per_channel_affine
        else torch.tensor([x.q_scale()], dtype=torch.float32)
    )


def get_zero_point(x: torch.Tensor, reduce: bool = True) -> torch.Tensor:
    """
    Return the zero point of a quantized tensor as an int32 tensor.
    """
    # If x was quantized per-tensor, simply create a tensor out of the scalar
    # zero_point, and return it.
    if x.qscheme() == torch.per_tensor_affine:
        return torch.tensor([x.q_zero_point()], dtype=torch.int32)
    # If x was quantized per-channel, check if all the zero_point values are
    # identical. If so, we can compress the zero_point tensor to a single element.
    assert x.qscheme() == torch.per_channel_affine, "Unhandled quantization scheme"
    zero_point = x.q_per_channel_zero_points().to(torch.int32)
    return (
        torch.tensor([zero_point[0]], dtype=torch.int32)
        if reduce and all(zero_point == zero_point[0])
        else zero_point
    )


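# Illustrative sketch with a hypothetical per-tensor quantized tensor:
# get_scale() and get_zero_point() return single-element tensors here.
def _example_quant_params() -> None:
    x = torch.quantize_per_tensor(
        torch.randn(4), scale=0.05, zero_point=3, dtype=torch.qint8
    )
    assert is_quantized_tensor(x)
    assert torch.allclose(get_scale(x), torch.tensor([0.05], dtype=torch.float32))
    assert get_zero_point(x).item() == 3

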
def quantize_tensor_multiplier(
    requantize_scale_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given requantize_scale_tensor with values in the interval (0, 1),
    produce a pair of tensors (out_multiplier, right_shift) where out_multiplier
    is an int32 tensor representing fixed-point values in the interval [-1, 1),
    and right_shift is an amount to shift right by, so that the floating-point
    multiplication of some int32 input with each value of requantize_scale_tensor:
        result = int32_value * requantize_scale_tensor[i]
    is best approximated by the integer-arithmetic-only code:
        result = RoundingRightShift(FixedPointMultiplication(int32_value,
                                    out_multiplier[i]), right_shift[i])
    """

    # This is identical to C++11 std::round(). Python's built-in round() rounds
    # halfway cases to the nearest even value, whereas C++ std::round() rounds
    # them away from zero.
    # pyre-fixme[2]: Parameter must be annotated.
    def round_away_zero(f) -> int:
        r = -0.5 if (f < 0) else 0.5
        return trunc(f + r)

    def quantize_scalar_multiplier(requantize_scale: float) -> Tuple[int, int]:
        significand, exponent = frexp(requantize_scale)
        significand_q31 = int(round_away_zero(significand * (1 << 31)))
        # Handle the special case when the real multiplier was so close to 1
        # that its fixed-point approximation was indistinguishable from 1.
        # We handle this by dividing it by two, and incrementing the exponent
        # (the right shift amount) by 1.
        if significand_q31 == (1 << 31):
            significand_q31 //= 2
            exponent += 1

        # Verify that the decomposition of requantize_scale into significand
        # and exponent is correct.
        reconstructed = significand_q31 / (1 << 31) * pow(2, exponent)
        assert isclose(
            requantize_scale, reconstructed, rel_tol=1e-4, abs_tol=1e-4
        ), "computation of significand and exponent from requantize_scale is not accurate"

        return (significand_q31, exponent)

    # Flatten the input scale tensor so that we can operate on individual values
    orig_shape = requantize_scale_tensor.shape
    flattened_tensor = requantize_scale_tensor.flatten().to(torch.float32)
    out_multiplier = torch.zeros(flattened_tensor.shape, dtype=torch.int32)
    right_shift = torch.zeros(flattened_tensor.shape, dtype=torch.int32)

    # Iterate over the flattened scale tensor and compute the decomposition of
    # each value in the scale tensor into a significand (out_multiplier) and
    # an exponent (right_shift)
    for idx, scale in enumerate(flattened_tensor):
        (si, ex) = quantize_scalar_multiplier(scale)
        out_multiplier[idx], right_shift[idx] = si, ex

    # Reshape the tensors back to the original shape
    out_multiplier = out_multiplier.reshape(orig_shape)
    right_shift = right_shift.reshape(orig_shape)

    return (out_multiplier, right_shift)

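
# Illustrative worked example with a hypothetical scale: 0.25 = 0.5 * 2**-1,
# and 0.5 in Q31 fixed point is 1 << 30, so the decomposition is (1 << 30, -1).
def _example_quantize_tensor_multiplier() -> None:
    scales = torch.tensor([0.25], dtype=torch.float32)
    out_multiplier, right_shift = quantize_tensor_multiplier(scales)
    assert out_multiplier[0].item() == (1 << 30)
    assert right_shift[0].item() == -1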