# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import logging
from typing import Any, FrozenSet, List, Tuple

import torch
from executorch.exir import memory

from executorch.exir.dialects._ops import ops
from executorch.exir.tensor import (
    contiguous_stride_from_shape,
    determine_tensor_dynanism,
    dim_order_from_stride,
    TensorShapeDynamism,
    TensorSpec,
)
from torch.fx.passes.infra.pass_base import PassBase, PassResult

logger: logging.Logger = logging.getLogger(__name__)


def _is_view_copy(node: torch.fx.Node) -> bool:
    """Return True if ``node`` calls the ATen or Edge ``view_copy.default`` op."""
    return node.op == "call_function" and node.target in (
        torch.ops.aten.view_copy.default,
        ops.edge.aten.view_copy.default,
    )


def _is_replaceable_view_copy(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is a view_copy that ReplaceViewCopyWithViewPass rewrites.

    Note: We only replace view_copy nodes that are not output, since
    the output pointer could be modified at runtime (T187925929)
    """
    return _is_view_copy(node) and all(u.op != "output" for u in node.users)


# The op that replaceable view_copy nodes are retargeted to.
_VIEW_OP = memory.view

# Attributes that _ViewSpec reads/writes directly on itself, bypassing the
# _self_fields/_base_fields dispatch and the guards. Without this escape hatch,
# __getattribute__/__setattr__ would recurse infinitely.
_VIEW_SPEC_INTERNAL_FIELDS: FrozenSet[str] = frozenset(
    {
        "_base",
        "_self_fields",
        "_base_fields",
        "_guards",
        "_unguarded_access",
        "_run_guards",
    }
)


class _Guard:
    """
    A named consistency check on a _ViewSpec.

    Calling the guard applies ``field_lambda`` to the view spec and raises if the
    result differs from ``expected_val``. ``expected_val`` is deep-copied at
    construction, so later mutation of the object passed in does not change
    the expectation.
    """

    def __init__(
        self, name: str, field_lambda, expected_val: Any  # pyre-ignore[2]
    ) -> None:
        self.name: str = name
        self.field_lambda = field_lambda  # pyre-ignore[4]
        self.expected_val = copy.deepcopy(expected_val)  # pyre-ignore[4]

    def __call__(self, view_spec) -> None:  # pyre-ignore[2]
        # Guards must run with guard evaluation disabled. Otherwise, reading a
        # field inside field_lambda would re-run every guard recursively.
        assert view_spec._unguarded_access
        observed_val = self.field_lambda(view_spec)
        if observed_val != self.expected_val:
            raise Exception(
                f"Guard {self.name} failed. Expected to see value {self.expected_val}, but saw value {observed_val}."
            )


class _ViewSpec(TensorSpec):
    """
    A _ViewSpec is a TensorSpec that shares non-size related fields with its base.
    The size-related fields are: shape, stride, dim_order, and shape_dynamism.

    If either the base or view spec updates a non-size related field, the change
    is reflected in both specs. But size related fields are not linked and can
    be set separately.

    A _ViewSpec can only be created from a non-sparse, strided TensorSpec.
    On creation, a _ViewSpec must be compatible with its base with respect to
    shape_dynamism, dtype, and nbytes.

    A _ViewSpec contains _guards that are evaluated on every attribute read,
    except reads of its internal bookkeeping fields and reads made while the
    guards themselves are running. The purpose of the guards is to make sure
    the _ViewSpec is still compatible with its base.
    """

    def __init__(self, base: TensorSpec, shape: List[int]) -> None:
        """
        Args:
            base: The spec whose non-size fields (storage, memory planning
                fields, dtype, ...) this view shares.
            shape: The sizes of the view. Stored as a list.

        Raises:
            Exception: If ``base`` is sparse or not strided, or if the view is
                incompatible with ``base`` on creation.
        """
        # Explicitly put all attributes into _self_fields or _base_fields.
        # An attribute that is in neither set is resolved on the view itself,
        # and a warning is logged for non-private names. If TensorSpec is
        # extended with a new attribute, we should explicitly decide how
        # _ViewSpec will handle it.
        # These are frozensets because they are consulted on every attribute access.
        self._self_fields = frozenset(
            [
                # We need to get the debug method from self
                # so that the object id it prints is correct.
                "debug",  # method
                "__repr__",  # method
                # The following are related to size and should use self
                "shape",
                "stride",
                "dim_order",
                "shape_dynamism",
                "nbytes",  # method
                "allocated_memory",  # property
                "is_dynamic_shape_tensor",  # property
                "is_static_shape_tensor",  # property
                "is_upper_bound_tensor",  # property
                "is_dynamic_unbound_tensor",  # property
            ]
        )
        self._base_fields = frozenset(
            [
                "scalar_type",
                "const",
                "alignment",
                "storage",
                "requires_grad",
                "layout",
                "is_sparse",
                "init_mem_planning_fields",  # method
                "realign",  # method
                "from_tensor",  # class method
                "lifetime",
                "mem_id",
                "mem_obj_id",
                "mem_offset",
                "dtype",  # property
            ]
        )

        # Make sure _self_fields and _base_fields are disjoint
        assert not (self._self_fields & self._base_fields)

        self._guards: List[_Guard] = []
        self._unguarded_access = False

        # Make sure base is not sparse and add a guard
        if base.is_sparse:
            raise Exception(
                "_ViewSpec can only be created from non-sparse TensorSpec, but base.is_sparse=True."
            )
        self._guards.append(
            _Guard(
                "is_sparse",
                lambda view_spec: view_spec.is_sparse,
                False,
            )
        )

        # Make sure base layout is strided and add a guard
        if base.layout != torch.strided:
            raise Exception(
                f"_ViewSpec can only be created from TensorSpec with layout={torch.strided}, but got layout={base.layout}."
            )
        self._guards.append(
            _Guard(
                "layout",
                lambda view_spec: view_spec.layout,
                torch.strided,
            )
        )

        self._base = base
        # Normalize to a list to match the List[int] annotation. The caller
        # passes a torch.Size taken from node.meta["val"].shape.
        self.shape: List[int] = list(shape)
        size = torch.Size(shape)
        self.stride: Tuple[int] = contiguous_stride_from_shape(size)
        self.dim_order: Tuple[bytes] = dim_order_from_stride(self.stride)
        self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(size)

        # Check compatibility with base on creation
        if self.shape_dynamism != base.shape_dynamism:
            raise Exception(
                f"_ViewSpec is incompatible with its base on creation. It has shape_dynamism={self.shape_dynamism}, but its base has shape_dynamism={base.shape_dynamism}."
            )
        self._guards.append(
            _Guard(
                "shape_dynamism_init",
                lambda view_spec: view_spec.shape_dynamism,
                base.shape_dynamism,
            )
        )
        self._guards.append(
            _Guard(
                "shape_dynamism_eq_base",
                lambda view_spec: view_spec.shape_dynamism
                == view_spec._base.shape_dynamism,
                True,
            )
        )

        # dtype is a base field, so self.dtype already reads base.dtype and this
        # creation-time check cannot fail; it is kept only as a defensive check.
        # The guard is what catches a later change to the shared dtype.
        if self.dtype != base.dtype:
            raise Exception(
                f"_ViewSpec is incompatible with its base on creation. It has dtype={self.dtype}, but its base has dtype={base.dtype}."
            )
        self._guards.append(
            _Guard("dtype", lambda view_spec: view_spec.dtype, base.dtype)
        )

        # We do not guard nbytes because dynamic symints are replaced by upper bounds.
        # We do guard on rank, though
        if self.nbytes() != base.nbytes():
            raise Exception(
                f"_ViewSpec is incompatible with its base on creation. It has nbytes={self.nbytes()}, but its base has nbytes={base.nbytes()}."
            )
        self._guards.append(
            _Guard("rank", lambda view_spec: len(view_spec.shape), len(shape))
        )

    def _run_guards(self) -> None:
        """Evaluate every guard, with guard evaluation disabled while they run."""
        unguarded_access = self._unguarded_access
        try:
            self._unguarded_access = True
            for g in self._guards:
                g(self)
        finally:
            # Restore the previous value rather than forcing False.
            self._unguarded_access = unguarded_access

    def __getattribute__(self, name: str):  # pyre-ignore
        # Internal bookkeeping fields are read straight off this object, with no
        # dispatch and no guards, so we don't recurse infinitely.
        if name in _VIEW_SPEC_INTERNAL_FIELDS:
            return object.__getattribute__(self, name)

        if name in self._self_fields:
            # Size-related fields live on the view itself.
            val = object.__getattribute__(self, name)
        elif name in self._base_fields:
            # Shared fields (and bound methods) come from the base spec.
            val = object.__getattribute__(self._base, name)
        else:
            if name and not name.startswith("_"):
                logger.warning(
                    "Getting non-private attribute %s on self, but it is not in _self_fields or _base_fields. Is this intended?",
                    name,
                )
            val = object.__getattribute__(self, name)

        if not self._unguarded_access:
            self._run_guards()
        return val

    def __setattr__(self, name: str, val) -> None:  # pyre-ignore
        # Internal bookkeeping fields are written straight onto this object,
        # so we don't recurse infinitely.
        if name in _VIEW_SPEC_INTERNAL_FIELDS or name in self._self_fields:
            object.__setattr__(self, name, val)
            return

        if name in self._base_fields:
            # Writes to shared fields go to the base so both specs see them.
            object.__setattr__(self._base, name, val)
            return

        if name and not name.startswith("_"):
            logger.warning(
                "Setting non-private attribute %s on self, but it is not in _self_fields or _base_fields. Is this intended?",
                name,
            )
        object.__setattr__(self, name, val)


class ReplaceViewCopyWithViewPass(PassBase):
    """
    Retarget non-output view_copy nodes to ``memory.view`` and give each one a
    _ViewSpec that shares its base's storage-related spec fields.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        This pass replaces view_copy nodes with view nodes.

        This should be run after the NormalizeViewCopyBasePass.

        During memory planning, view nodes share the same storage as their base.

        Returns:
            A PassResult whose ``modified`` flag is True iff any node was replaced.
        """
        n_replaced = 0
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if not _is_replaceable_view_copy(node):
                    continue

                base, _ = node.args
                node.target = _VIEW_OP

                # Create spec for the node.
                # _ViewSpec gives a view into its base spec for non-size
                # related information.

                # The shape is not the same as node.args[1] because node.args[1]
                # can contain inferred sizes (-1).
                shape = node.meta["val"].shape
                node.meta["spec"] = _ViewSpec(base.meta["spec"], shape)

                n_replaced += 1

            module.recompile()

        logger.debug(
            "Replaced %s view_copy nodes with %s nodes.", n_replaced, _VIEW_OP
        )
        return PassResult(graph_module, n_replaced > 0)

    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Check that no replaceable view_copy node remains and that every view
        node carries a _ViewSpec.
        """
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                assert not _is_replaceable_view_copy(node)
                if node.op == "call_function" and node.target == _VIEW_OP:
                    assert isinstance(node.meta["spec"], _ViewSpec)

    def requires(self, graph_module: torch.fx.GraphModule) -> None:
        """
        This pass should be called after NormalizeViewCopyBasePass.
        We check that every view_copy this pass will replace has been
        normalized, i.e. its base is not itself a view_copy.
        """
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if _is_replaceable_view_copy(node):
                    base, _ = node.args
                    assert not _is_view_copy(base)