# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Callable, List, Optional

import torch
import torch.utils._pytree as pytree

from executorch.exir._warnings import deprecated
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassBase, PassResult
from executorch.exir.sym_util import eval_expr, eval_shape, eval_upper_bound
from executorch.exir.tensor import TensorSpec
from torch.fx import GraphModule

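# Maps an operator overload to a function that computes an upper bound on the
# operator's output sizes.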
upper_bound_shape_inference_table = {}


def register_upper_bound_inference(fn):
    def inference_deco(f: Callable):
        upper_bound_shape_inference_table[fn] = f
        return f

    return inference_deco

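# A minimal sketch of how a new entry would be registered (the op and the bound
# below are illustrative assumptions, not part of this file):
#
#   @register_upper_bound_inference(torch.ops.aten.masked_select.default)
#   def masked_select(args, kwargs) -> List[Optional[int]]:
#       # Worst case: the mask selects every element of the input.
#       return [eval_expr(args[0].numel())]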

@register_upper_bound_inference(exir_ops.edge.aten.nonzero.default)
@register_upper_bound_inference(torch.ops.aten.nonzero.default)
def nonzero(args, kwargs) -> List[Optional[int]]:
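    # Worst case: every input element is nonzero, so the result has at most
    # numel(input) rows and one column per input dimension.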
    return [eval_expr(args[0].numel()), len(args[0].shape)]


@register_upper_bound_inference(exir_ops.edge.aten.index.Tensor)
@register_upper_bound_inference(torch.ops.aten.index.Tensor)
def index_Tensor(args, kwargs) -> List[Optional[int]]:  # noqa: C901
    tensor = args[0]
    indices = args[1]

    # Compute the number of contiguous blocks of non-null indices.
    # For example, if A, B, C, D, E are non-null tensors, then
    # [None, None, A, B, None, C, D, E, None] has 2 blocks.
    index_blocks = 0
    in_block = False
    for index in indices:
        if index is not None:
            if not in_block:
                in_block = True
                index_blocks += 1
        else:
            in_block = False

    if index_blocks == 0:
        # If no dimensions are actually being indexed, either because the
        # indices list is empty or because all indices are null, the result is
        # the same as the input tensor.
        return tensor.shape

    adjacent = index_blocks == 1

    # Number of leading null indices in the indices list.
    num_leading_null_indices = 0
    for index in indices:
        if index is None:
            num_leading_null_indices += 1
        else:
            break

    # Total number of null indices in the indices list.
    num_null_indices = sum(ix is None for ix in indices)

    # Number of dimensions being indexed (bool/byte tensors are treated as
    # masks, and index as many input dimensions as their number of dimensions).
    num_indexed_dims = 0
    mask_indices = []
    int_indices = []
    for index in indices:
        if index is not None:
            if index.dtype in [torch.bool, torch.uint8]:
                num_indexed_dims += index.dim()
                mask_indices.append(index)
            else:
                num_indexed_dims += 1
                int_indices.append(index)

    broadcast_sizes = []
    if len(int_indices) > 0:
        # All of the integer index tensors (non-mask & non-null index tensors)
        # need to broadcast. We need to compute the resulting shape.
        curr_ndim = 0
        rev_shape = []
        for index in int_indices:
            for j in range(index.dim()):
                rev_j_size = eval_expr(index.size(index.dim() - j - 1))
                if j >= curr_ndim:
                    curr_ndim += 1
                    rev_shape.append(rev_j_size)
                elif rev_shape[j] == 1:
                    rev_shape[j] = rev_j_size
        broadcast_sizes = list(reversed(rev_shape))
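        # For example, integer index tensors of shapes (3, 1) and (2,)
        # broadcast to shape (3, 2).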

    # The numbers of True elements in the mask indices must broadcast (i.e.,
    # some might be 1, but the others must all be equal). They also need to
    # broadcast with broadcast_sizes[0]. Therefore, if broadcast_sizes[0] != 1,
    # we don't need to worry about the mask indices, since we are assuming that
    # the inputs are valid. However, if broadcast_sizes[0] == 1, we do need to
    # consider them. We can't know how many True elements there are in each
    # mask, but a mask can have at most numel True elements, so the broadcast
    # size can't be larger than the minimum numel across all mask indices whose
    # numel is not 1.
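    # For example, masks with numel 6 and numel 4 have at most 6 and 4 True
    # elements respectively, so the broadcast size is bounded by min(6, 4) = 4.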
    if len(mask_indices) > 0 and (len(broadcast_sizes) == 0 or broadcast_sizes[0] == 1):
        upper_bound_broadcast_size = 1
        initialized = False
        for mask in mask_indices:
            mask_numel = eval_expr(mask.numel())
            if mask_numel != 1:
                if initialized:
                    assert isinstance(
                        mask_numel, int
                    ), "Expect mask_numel to be a concrete int"
                    assert isinstance(
                        upper_bound_broadcast_size, int
                    ), "Expect upper_bound_broadcast_size to be a concrete int"
                    if upper_bound_broadcast_size > mask_numel:
                        upper_bound_broadcast_size = mask_numel
                else:
                    upper_bound_broadcast_size = mask_numel
                    initialized = True
        if len(broadcast_sizes) == 0:
            broadcast_sizes.append(upper_bound_broadcast_size)
        else:
            broadcast_sizes[0] = upper_bound_broadcast_size

    broadcast_ndim = len(broadcast_sizes)

    out_ndim = tensor.dim() + broadcast_ndim - num_indexed_dims
    out_sizes: List[Optional[int]] = [0 for _ in range(out_ndim)]

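    # For example, with tensor t of shape (5, 6, 7) and index tensors A and B,
    # each of shape (2,):
    #   t[:, A]    -> the indexed dims form one block (adjacent); the broadcast
    #                 dims replace them in place: result shape (5, 2, 7).
    #   t[A, :, B] -> the indexed dims are not adjacent; the broadcast dims
    #                 move to the front: result shape (2, 6).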
    if adjacent:
        for i in range(num_leading_null_indices):
            out_sizes[i] = eval_expr(tensor.size(i))
        for i in range(broadcast_ndim):
            out_sizes[i + num_leading_null_indices] = broadcast_sizes[i]
        for i in range(num_indexed_dims + num_leading_null_indices, tensor.dim()):
            out_sizes[i + broadcast_ndim - num_indexed_dims] = eval_expr(tensor.size(i))
    else:
        for i in range(broadcast_ndim):
            out_sizes[i] = broadcast_sizes[i]
        in_i = 0
        out_i = broadcast_ndim
        for index in indices:
            if index is None:
                out_sizes[out_i] = eval_expr(tensor.size(in_i))
                out_i += 1
                in_i += 1
            else:
                if index.dtype in [torch.bool, torch.uint8]:
                    in_i += index.dim()
                else:
                    in_i += 1

        for i in range(num_indexed_dims + num_null_indices, tensor.dim()):
            out_sizes[i + broadcast_ndim - num_indexed_dims] = eval_expr(tensor.size(i))

    return out_sizes


@deprecated(
    "`HintBasedSymShapeEvalPass` is deprecated "
    "and will be removed in a future version of ExecuTorch. "
    "Please use `ConstraintBasedSymShapeEvalPass` instead.",
    category=FutureWarning,
)
class HintBasedSymShapeEvalPass(PassBase):
    """

    .. warning::

        `HintBasedSymShapeEvalPass` is deprecated
        and will be removed in a future version of ExecuTorch.
        Please use `ConstraintBasedSymShapeEvalPass` instead.

    If we enable dynamic shape tracing, a tensor's shape may become a symbolic
    formula. We should convert those symbolic formulas to concrete values for
    static/upper-bound tensors so we can properly do memory planning for them.

    HintBasedSymShapeEvalPass evaluates the symbolic shape expressions based on
    their hints, which are the concrete integers that back the sym expressions.
    The original hints come from the sizes of the inputs the user provides for
    tracing, and hints of symbolic expressions are propagated via meta tensor
    computation. For example, when exporting f(x), we use x = torch.ones(3, 4)
    as an example input to f and suppose we constrain both dimensions of x as
    dynamic. Two symbols s0 and s1 are created, backed by hints 3 and 4
    respectively. If there is a y = x[0] operation in f, the shape of y is
    inferred to be (s1,), which is backed by hint 4.
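
    A minimal sketch of that setup (assuming the torch.export dynamic-shapes
    API)::

        import torch
        from torch.export import Dim, export

        class F(torch.nn.Module):
            def forward(self, x):
                return x[0]

        x = torch.ones(3, 4)  # hints: s0 = 3, s1 = 4
        ep = export(
            F(), (x,), dynamic_shapes={"x": {0: Dim("s0"), 1: Dim("s1")}}
        )
        # x[0] has symbolic shape (s1,); this pass evaluates s1 to its hint, 4.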

    Warning: if you're using torch.export with the constrain APIs, this pass
    doesn't respect the input constraints.

    Not inherited from ExportPass since we simply need a way to iterate through
    every node's output. PassBase is easier for that purpose.
    """

    def call(self, graph_module: GraphModule):
        for subgm in graph_module.modules():
            if not isinstance(subgm, GraphModule):
                continue
            for node in subgm.graph.nodes:
                for spec in pytree.tree_flatten(node.meta.get("spec", []))[0]:
                    # Nodes for functions like aten.sym_size do not have a spec.
                    if isinstance(spec, TensorSpec):
                        concrete_shape = eval_shape(spec.shape)
                        concrete_stride = eval_shape(spec.stride)
                        if any(s is None for s in concrete_shape) or any(
                            s is None for s in concrete_stride
                        ):
                            # None indicates unbacked symints, see: https://fburl.com/code/v7hj5zv6
                            # Use value ranges to get the upper bounds of unbacked symints.
                            from torch._guards import detect_fake_mode

                            fake_mode = detect_fake_mode(node.meta.get("val"))
                            if fake_mode is not None:
                                from torch.utils._sympy.numbers import int_oo

                                shape_env = fake_mode.shape_env
                                for i, v in enumerate(spec.shape):
                                    if concrete_shape[i] is None:
                                        # Get the updated shape from var_to_range.
                                        _value_range = shape_env.var_to_range[
                                            v._sympy_()  # pyre-fixme[16] Undefined attribute: `int` has no attribute `_sympy_`.
                                        ]
                                        # Cannot handle unbounded, unbacked symints; add a range to bound it.
                                        assert _value_range.upper is not int_oo
                                        concrete_shape[i] = int(_value_range.upper)
                                for i, v in enumerate(spec.stride):
                                    if concrete_stride[i] is None:
                                        _expr = (
                                            v.node.expr  # pyre-fixme[16] Undefined attribute: `int` has no attribute `node`.
                                        )
                                        _value_range = v.node.shape_env.var_to_range
                                        from torch.utils._sympy.value_ranges import (
                                            bound_sympy,
                                        )

                                        _bound_sympy = bound_sympy(_expr, _value_range)
                                        # Cannot handle unbounded, unbacked symints; add a range to bound it.
                                        assert _bound_sympy.upper is not int_oo
                                        concrete_stride[i] = int(_bound_sympy.upper)

                        assert all(isinstance(s, int) for s in concrete_shape) and all(
                            isinstance(s, int) for s in concrete_stride
                        )
                        spec.shape = concrete_shape
                        spec.stride = concrete_stride
        return PassResult(graph_module, True)


class ConstraintBasedSymShapeEvalPass(PassBase):
    """
    If we enable dynamic shape tracing, a tensor's shape may become a symbolic
    formula. We should convert those symbolic formulas to concrete values for
    static/upper-bound tensors so we can properly do memory planning for them.

    Not inherited from ExportPass since we simply need a way to iterate through
    every node's output. PassBase is easier for that purpose.
    """
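
    # A user-side sketch (hypothetical model code, not part of this pass) of
    # bounding a data-dependent size so this pass can resolve a concrete upper
    # bound, assuming the torch._check constraint API:
    #
    #   nz = torch.nonzero(x)   # data-dependent output size
    #   n = nz.shape[0]
    #   torch._check(n <= 128)  # gives the unbacked symint an upper bound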

    def call(self, graph_module: GraphModule):
        for subgm in graph_module.modules():
            if not isinstance(subgm, GraphModule):
                continue
            for node in subgm.graph.nodes:
                for spec in pytree.tree_flatten(node.meta.get("spec", []))[0]:
                    # Nodes for functions like aten.sym_size do not have a spec.
                    if isinstance(spec, TensorSpec):
                        concrete_shape = [eval_upper_bound(s) for s in spec.shape]
                        concrete_stride = [eval_upper_bound(s) for s in spec.stride]
                        if any(not isinstance(s, int) for s in concrete_shape) or any(
                            not isinstance(s, int) for s in concrete_stride
                        ):
                            raise RuntimeError(
                                f"Cannot evaluate the shape upper bound of a dynamic-shaped tensor to a concrete bounded integer. Got tensor spec: {spec}. "
                                f"The upper-bound shape we got is {concrete_shape}, and the upper-bound stride we got is {concrete_stride}. "
                                "This tensor could either come from 1. a data-dependent operation such as nonzero, or 2. an input that doesn't have an upper-bound constraint. "
                                "Please use export's constrain_as_size() or constrain_as_value() APIs and set a concrete upper bound to resolve this."
                            )

                        spec.shape = concrete_shape
                        spec.stride = concrete_stride  # pyre-ignore[8]: Attribute `stride` declared in class `TensorSpec` has type `Tuple[int]` but is used as type `List[int]`
        return PassResult(graph_module, True)