1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-unsafe 8 9from typing import Callable, List, Optional 10 11import torch 12import torch.utils._pytree as pytree 13 14from executorch.exir._warnings import deprecated 15from executorch.exir.dialects._ops import ops as exir_ops 16from executorch.exir.pass_base import PassBase, PassResult 17from executorch.exir.sym_util import eval_expr, eval_shape, eval_upper_bound 18from executorch.exir.tensor import TensorSpec 19from torch.fx import GraphModule 20 21upper_bound_shape_inference_table = {} 22 23 24def register_upper_bound_inference(fn): 25 def inference_deco(f: Callable): 26 upper_bound_shape_inference_table[fn] = f 27 return f 28 29 return inference_deco 30 31 32@register_upper_bound_inference(exir_ops.edge.aten.nonzero.default) 33@register_upper_bound_inference(torch.ops.aten.nonzero.default) 34def nonzero(args, kwargs) -> List[Optional[int]]: 35 return [eval_expr(args[0].numel()), len(args[0].shape)] 36 37 38@register_upper_bound_inference(exir_ops.edge.aten.index.Tensor) 39@register_upper_bound_inference(torch.ops.aten.index.Tensor) 40def index_Tensor(args, kwargs) -> List[Optional[int]]: # noqa: C901 41 tensor = args[0] 42 indices = args[1] 43 44 # Compute numbers of contiguous blocks of non-null indices. 45 # For example, if A, B, C, D, E are non-null tensors, then 46 # [None, None, A, B, None, C, D, E, None] has 2 blocks. 47 index_blocks = 0 48 in_block = False 49 for index in indices: 50 if index is not None: 51 if not in_block: 52 in_block = True 53 index_blocks += 1 54 else: 55 in_block = False 56 57 if index_blocks == 0: 58 # If no dimensions are actually being indexed, either because the indices list is empty 59 # or all indices are null, then the result is just the same as the input tensor. 60 return tensor.shape 61 62 adjacent = index_blocks == 1 63 64 # Number of leading null indices in the indices list. 65 num_leading_null_indices = 0 66 for index in indices: 67 if index is None: 68 num_leading_null_indices += 1 69 else: 70 break 71 72 # Number of null indices in total in the indices list. 73 num_null_indices = sum([ix is None for ix in indices]) 74 75 # Number of dimensions being indexed (bool/byte tensors are treated as masks, and index as 76 # many input dimensions as their number of dimensions. 77 num_indexed_dims = 0 78 mask_indices = [] 79 int_indices = [] 80 for index in indices: 81 if index is not None: 82 if index.dtype in [torch.bool, torch.uint8]: 83 num_indexed_dims += index.dim() 84 mask_indices.append(index) 85 else: 86 num_indexed_dims += 1 87 int_indices.append(index) 88 89 broadcast_sizes = [] 90 if len(int_indices) > 0: 91 # All of the integer index tensors (non-mask & non-null index tensors) need to broadcast. 92 # We need to compute the resulting shape. 93 curr_ndim = 0 94 rev_shape = [] 95 for index in int_indices: 96 for j in range(index.dim()): 97 rev_j_size = eval_expr(index.size(index.dim() - j - 1)) 98 if j >= curr_ndim: 99 curr_ndim += 1 100 rev_shape.append(rev_j_size) 101 elif rev_shape[j] == 1: 102 rev_shape[j] = rev_j_size 103 broadcast_sizes = list(reversed(rev_shape)) 104 105 # The number of True elements in the mask indices must broadcast (i.e some might be 1 106 # but the others must all be equal). They also need to broadcast with broadcast_sizes[0] 107 # Therefore, if broadcast_sizes[0] != 1, we don't need to worry about the mask indices, 108 # since we are assuming that the inputs are valid. However, if broadcast_sizes[0] = 1, 109 # we do need to consider them. We can't know how many True elements there are in each mask, 110 # but we know that the broadcasted size, can't be larger than the minimum number of True 111 # elements across all mask indices with a number of elements other than 1. 112 if len(mask_indices) > 0 and (len(broadcast_sizes) == 0 or broadcast_sizes[0] == 1): 113 upper_bound_broadcast_size = 1 114 intialized = False 115 for mask in mask_indices: 116 mask_numel = eval_expr(mask.numel()) 117 if mask_numel != 1: 118 if intialized: 119 assert isinstance( 120 mask_numel, int 121 ), "Expect mask_numel to be a concrete int" 122 assert isinstance( 123 upper_bound_broadcast_size, int 124 ), "Expect upper_bound_broadcast_size to be a concrete int" 125 if upper_bound_broadcast_size > mask_numel: 126 upper_bound_broadcast_size = mask_numel 127 else: 128 upper_bound_broadcast_size = mask_numel 129 intialized = True 130 if len(broadcast_sizes) == 0: 131 broadcast_sizes.append(upper_bound_broadcast_size) 132 else: 133 broadcast_sizes[0] = upper_bound_broadcast_size 134 135 broadcast_ndim = len(broadcast_sizes) 136 137 out_ndim = tensor.dim() + broadcast_ndim - num_indexed_dims 138 out_sizes: List[Optional[int]] = [0 for _ in range(out_ndim)] 139 140 if adjacent: 141 for i in range(num_leading_null_indices): 142 out_sizes[i] = eval_expr(tensor.size(i)) 143 for i in range(broadcast_ndim): 144 out_sizes[i + num_leading_null_indices] = broadcast_sizes[i] 145 for i in range(num_indexed_dims + num_leading_null_indices, tensor.dim()): 146 out_sizes[i + broadcast_ndim - num_indexed_dims] = eval_expr(tensor.size(i)) 147 else: 148 for i in range(broadcast_ndim): 149 out_sizes[i] = broadcast_sizes[i] 150 in_i = 0 151 out_i = broadcast_ndim 152 for index in indices: 153 if index is None: 154 out_sizes[out_i] = eval_expr(tensor.size(in_i)) 155 out_i += 1 156 in_i += 1 157 else: 158 if index.dtype in [torch.bool, torch.uint8]: 159 in_i += index.dim() 160 else: 161 in_i += 1 162 163 for i in range(num_indexed_dims + num_null_indices, tensor.dim()): 164 out_sizes[i + broadcast_ndim - num_indexed_dims] = eval_expr(tensor.size(i)) 165 166 return out_sizes 167 168 169@deprecated( 170 "`HintBasedSymShapeEvalPass` is deprecated " 171 "and will be removed in a future version of ExecuTorch. " 172 "Please use `ConstraintBasedSymShapeEvalPass` instead.", 173 category=FutureWarning, 174) 175class HintBasedSymShapeEvalPass(PassBase): 176 """ 177 178 .. warning:: 179 180 'HintBasedSymShapeEvalPass` is deprecated 181 and will be removed in a future version of ExecuTorch. 182 Please use `ConstraintBasedSymShapeEvalPass` instead. 183 184 If we enable dynamic shape tracing, a tensor's shape may become a symbolic 185 formula. We should convert those symbolic formula to concrete value for 186 static/upperbound tensors so we can properly do memory planning for them. 187 188 HintBasedSymShapeEvalPass evalutes the symbolic expression of shapes based 189 on its hint, which is a concrete integer that backs the sym expression. The original 190 hint comes from the sizes of the inputs that user uses for tracing and hints of 191 symbolic expressions are propagated via meta tensor computation. 192 For example, when export f(x), we use x = torch.ones(3, 4) as an exmaple input to f and 193 suppose we constrain both dimensions of x as dynamic. We'll have two symbols s0, s1 created 194 and they are backed up with hints 3 and 4 respectively. If there is a y = x[0] operation in f, 195 the shape of y is inferred to be s1, which is backed up with hint 4. 196 197 Warning: if you're using torch.export with constrain API, this method doesn't respect the input constraints. 198 199 Not inherited from ExportPass since we simply need a way to iterate thru 200 every node's output. PassBase is easier for that purpose. 201 """ 202 203 def call(self, graph_module: GraphModule): 204 for subgm in graph_module.modules(): 205 if not isinstance(subgm, GraphModule): 206 continue 207 for node in subgm.graph.nodes: 208 for spec in pytree.tree_flatten(node.meta.get("spec", []))[0]: 209 # Node for function like aten.sym_size does not have spec 210 if isinstance(spec, TensorSpec): 211 concrete_shape = eval_shape(spec.shape) 212 concrete_spec = eval_shape(spec.stride) 213 if any(s is None for s in concrete_shape) or any( 214 s is None for s in concrete_spec 215 ): 216 # None indicates unbacked symints, see: https://fburl.com/code/v7hj5zv6 217 # Use value range to get the upper bounds of unbacked symints. 218 from torch._guards import detect_fake_mode 219 220 fake_mode = detect_fake_mode(node.meta.get("val")) 221 if fake_mode is not None: 222 from torch.utils._sympy.numbers import int_oo 223 224 shape_env = fake_mode.shape_env 225 for i, v in enumerate(spec.shape): 226 if concrete_shape[i] is None: 227 # get updated shape from var_to_range 228 _value_range = shape_env.var_to_range[ 229 v._sympy_() # pyre-fixme[16] Undefined attribute: `int` has no attribute `_sympy_`. 230 ] 231 # cannot handle unbounded, unbacked symints; add a range to bound it. 232 assert _value_range.upper is not int_oo 233 concrete_shape[i] = int(_value_range.upper) 234 for i, v in enumerate(spec.stride): 235 if concrete_spec[i] is None: 236 _expr = ( 237 v.node.expr # pyre-fixme[16] Undefined attribute: `int` has no attribute `node`. 238 ) 239 _value_range = v.node.shape_env.var_to_range 240 from torch.utils._sympy.value_ranges import ( 241 bound_sympy, 242 ) 243 244 _bound_sympy = bound_sympy(_expr, _value_range) 245 # cannot handle unbounded, unbacked symints; add a range to bound it. 246 assert _bound_sympy.upper is not int_oo 247 concrete_spec[i] = int(_bound_sympy.upper) 248 249 assert all(isinstance(s, int) for s in concrete_shape) and all( 250 isinstance(s, int) for s in concrete_spec 251 ) 252 spec.shape = concrete_shape 253 spec.stride = concrete_spec 254 return PassResult(graph_module, True) 255 256 257class ConstraintBasedSymShapeEvalPass(PassBase): 258 """ 259 If we enable dynamic shape tracing, a tensor's shape may become a symbolic 260 formula. We should convert those symbolic formula to concrete value for 261 static/upperbound tensors so we can properly do memory planning for them. 262 263 Not inherited from ExportPass since we simply need a way to iterate through 264 every node's output. PassBase is easier for that purpose. 265 """ 266 267 def call(self, graph_module: GraphModule): 268 for subgm in graph_module.modules(): 269 if not isinstance(subgm, GraphModule): 270 continue 271 for node in subgm.graph.nodes: 272 for spec in pytree.tree_flatten(node.meta.get("spec", []))[0]: 273 # Node for function like aten.sym_size does not have spec 274 if isinstance(spec, TensorSpec): 275 concrete_shape = [eval_upper_bound(s) for s in spec.shape] 276 concrete_stride = [eval_upper_bound(s) for s in spec.stride] 277 if any(not isinstance(s, int) for s in concrete_shape) or any( 278 not isinstance(s, int) for s in concrete_stride 279 ): 280 raise RuntimeError( 281 f"Cannot evalute the shape upper bound of a dynamic-shaped tensor to a concrete bounded integer. Got tensor spec: {spec}." 282 f"The upper bound shape we get {concrete_shape}, the upper bound stride we get {concrete_stride}" 283 "This tensor could either be from 1. a data-dependent operation such as nonzero. Or 2. an input, whose don't have a constraint for the upper bound." 284 "Please use export's constrain_as_size() or constrain_as_value() apis and set a concrete upper bound to resolve this." 285 ) 286 287 spec.shape = concrete_shape 288 spec.stride = concrete_stride # pyre-ignore[8]: Attribute `stride` declared in class `TensorSpec` has type `Tuple[int]` but is used as type `List[int]` 289 return PassResult(graph_module, True) 290