1# mypy: allow-untyped-defs 2import torch 3from torch.fx import Node 4from torch.fx._compatibility import compatibility 5from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor 6from torch.utils._pytree import tree_map_only 7from torch.utils import _pytree as pytree 8from torch.multiprocessing.reductions import StorageWeakRef 9 10import _operator 11from enum import Enum 12import itertools 13from typing import Set, Dict 14from collections import defaultdict 15 16__all__ = ['reinplace'] 17 18class _ViewType(Enum): 19 NonView = 0 20 SingleOutputView = 1 21 MultiOutputView = 2 22 23def _is_view_op(tgt): 24 if tgt is not None and isinstance(tgt, torch._ops.OpOverload): 25 schema = tgt._schema 26 if len(schema.arguments) > 0: 27 first_arg = schema.arguments[0] 28 # check if op is a view 29 return first_arg.alias_info is not None and not first_arg.alias_info.is_write 30 31def _get_view_type(tgt) -> _ViewType: 32 if tgt is not None and isinstance(tgt, torch._ops.OpOverload): 33 schema = tgt._schema 34 if len(schema.arguments) > 0: 35 first_arg = schema.arguments[0] 36 # check if op is a view 37 if first_arg.alias_info is not None and not first_arg.alias_info.is_write: 38 # check if op is a multi-output view 39 if '*' in first_arg.alias_info.after_set: 40 return _ViewType.MultiOutputView 41 else: 42 return _ViewType.SingleOutputView 43 return _ViewType.NonView 44 45 46# Stores a bunch of metadata related to functionalization each node. 47# Relevant metadata: 48# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors) 49# The fake tensor output from running the current node 50# n.meta['view_of']: Node 51# If the current node n is a view of some base tensor, the 'view_of' field tells us which 52# view node was used to generate the current node (a view tensor). 53# This information actually makes `fake_result` redundant, but we can use `fake_result` 54# to sanity check that our aliasing information is correct. 55@compatibility(is_backward_compatible=False) 56class _FunctionalizationMetadataProp(torch.fx.Interpreter): 57 58 def run_node(self, node: Node): 59 self.node_counter += 1 60 result = super().run_node(node) 61 node.meta['fake_result'] = result 62 node.meta['node_idx'] = self.node_counter 63 64 # (1) Update metadata with the list of nodes that are used by this node 65 # copy_() doesn't read from its first argument; it writes to it, overwriting previous data. 66 # We don't want to treat it as "being used as an input". 67 node_args = node.args 68 if node.target is torch.ops.aten.copy_.default: 69 node_args = node_args[1:] 70 71 # (2) Update metadata to track aliasing information about view tensor nodes. 72 if node.op == 'call_function': 73 view_type = _get_view_type(node.target) 74 if view_type == _ViewType.SingleOutputView: 75 assert isinstance(node.args[0], Node) 76 node.meta['view_of'] = node.args[0] 77 elif view_type == _ViewType.MultiOutputView: 78 self.multi_output_view_nodes[node] = node.args[0] 79 80 # Check if we returned a multi-output view, 81 # and we're now grabbing the individual views from the output. 82 # 83 # For multi-output views, we want to map each output view to the base, 84 # but this mapping involves two separate nodes in FX IR. 85 # e.g. "a, b = x_1.split(...)" becomes: 86 # %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {}) 87 # %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {}) 88 # %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {}) 89 # And we'd like to set: 90 # getitem1.meta['view_of'] = x_1 91 elif node.target is _operator.getitem: 92 list_arg = node.args[0] 93 maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None) 94 if maybe_base_of_view is not None: 95 # Note: we could also track indexing info here for multi-output views. 96 # I don't think this metadata is strictly needed for de-functionalization. 97 assert isinstance(maybe_base_of_view, Node) 98 node.meta['view_of'] = maybe_base_of_view 99 100 if 'view_of' in node.meta: 101 # We're linking the current node with its first argument as views. 102 # Assert here that this is actually the case, and their storages are the same. 103 assert isinstance(node.meta['fake_result'], FakeTensor) 104 assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor) 105 view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) 106 base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage()) 107 assert view_storage == base_storage 108 return result 109 110 111 112 def propagate(self, *args): 113 self.multi_output_view_nodes = {} 114 self.node_counter = -1 115 116 with FakeTensorMode() as mode: 117 fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args] 118 return super().run(*fake_args) 119 120def _schemas_match(functional_schema, inplace_schema): 121 names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name 122 arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all( 123 a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)) 124 # for the inplace op, its first argument should be mutable 125 assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write 126 # and its remaining arguments shouldn't be. 127 assert all(a.alias_info is None for a in inplace_schema.arguments[1:]) 128 return names_match and arg_types_match 129 130# TODO: this should be beefed up to be able to properly re-inplace with: 131# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper) 132# - out= ops (e.g. angle -> angle.out) 133# TODO: we should also figure this info out using torchgen. 134def _maybe_get_inplace_op(op): 135 # __module__ seems broken; it returns torch._ops.aten which doesn't exist 136 if not isinstance(op, torch._ops.OpOverload): 137 return None 138 # Some view ops have inplace variants (as_strided_, etc), 139 # but we do NOT want the reinplacing pass to directly add these into the program. 140 # (they'll require extra special handling, aren't aren't really useful for perf anyway) 141 if _is_view_op(op): 142 return None 143 op_namespace = op.__module__.split(".")[-1] 144 op_base_name = op.overloadpacket.__name__ 145 maybe_namespace_module = getattr(torch.ops, op_namespace) 146 maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None) 147 if maybe_inplace_op is None: 148 return None 149 150 inplace_overloads = [ 151 getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads() 152 ] 153 inplace_overloads_with_matching_schemas = [ 154 f 155 for f in inplace_overloads 156 if _schemas_match(op._schema, f._schema) 157 ] 158 # Just because foo() and foo_() are both existing operators, 159 # They aren't guaranteed to have compatible schemas. 160 # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant, 161 # Even though several overloads of pow_ exist. 162 if len(inplace_overloads_with_matching_schemas) == 0: 163 return None 164 assert len(inplace_overloads_with_matching_schemas) == 1 165 inplace_op = inplace_overloads_with_matching_schemas[0] 166 return inplace_op 167 168_VIEW_INVERSE_MAP = { 169 torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, 170 torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, 171 torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, 172 torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, 173} 174 175# This function, given a set of set of (aliased) tensor nodes, 176# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index 177# in the node ordering. 178def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int): 179 def _add_if_tensor(x, set_): 180 if isinstance(x, FakeTensor): 181 set_.add(StorageWeakRef(x._typed_storage())) 182 183 nodes_used_after = set() 184 for t in tensor_aliases: 185 # get all nodes that use the current alias 186 usage_nodes = t.users 187 for n in usage_nodes: 188 # We only care about usages after the current node 189 if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index: 190 continue 191 # We also don't care about intermediate view ops. 192 # They only matter if their output is then used elsewhere 193 # (either in an out-of-place op, or as an output to the function). 194 if n in tensor_aliases: 195 if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem: 196 continue 197 nodes_used_after.add(n) 198 return nodes_used_after 199 200# Given an op that we're trying to re-inplace, "b = foo(a)", 201# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)" 202# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF: 203# If there are any aliases in the alias_set(a) that satisfy: 204# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base" 205# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata 206# as "alias" 207def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]: 208 def matching_view_metadata(a, b): 209 return a.size() == b.size() and \ 210 a.stride() == b.stride() and \ 211 a.storage_offset() == b.storage_offset() 212 213 view_inverse_nodes = set() 214 # Go through them in node order, so we can see chains of view_scatter ops. 215 for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']): 216 if n.target not in _VIEW_INVERSE_MAP: 217 continue 218 base = n.args[0] 219 mutated_view = n.args[1] 220 assert isinstance(base, Node) 221 assert isinstance(base.meta['fake_result'], FakeTensor) 222 assert isinstance(mutated_view, Node) 223 assert isinstance(mutated_view.meta['fake_result'], FakeTensor) 224 # Check that this view_inverse op actually corresponds to taking doing the inverse 225 # of one of our existing self_alias nodes. 226 original_view = _VIEW_INVERSE_MAP[n.target] 227 for self_alias in self_aliases: 228 # We're looking for some alias of the self arg, "alias", 229 # that was created from some op `alias = foo(base, args...)` 230 # such that the current _scatter op "inverts" that foo call. 231 # We can check that by running the original op again, and checking that the strides match. 232 if 'view_of' not in self_alias.meta: 233 continue 234 self_alias_base = self_alias.meta['view_of'] 235 try: 236 # The we're trying to re-use the args from the view_scatter call inside of the corresponding 237 # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse 238 # of the current alias we're looking at. 239 view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs) 240 expected_metadata = self_alias.meta['fake_result'] 241 # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace. 242 if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \ 243 matching_view_metadata(view_replay_metadata, expected_metadata): 244 view_inverse_nodes.add(n) 245 except Exception: 246 continue 247 248 return view_inverse_nodes 249 250 251@compatibility(is_backward_compatible=True) 252def reinplace(gm, *sample_args): 253 """ 254 Given an fx.GraphModule, modifies it to perform "reinplacing", 255 mutating the nodes of the graph. 256 We look for out-of-place op call sites like `b = a.add(...)`, 257 and convert them to be inplace (`b = a.add_(...)`), 258 as long as the input to the current operator ("a") isn't re-used 259 anywhere later in the graph. 260 261 This pass currently expects to operate on a **functional, ATen** graph. 262 This can be obtained by running `make_fx(functionalize(f))`. 263 264 Sample inputs are needed to determine aliasing relationships of the inputs. 265 In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the 266 inputs to the program. 267 268 Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows: 269 270 (1) Perform some initial checks on the metadata of "a" and "args..." 271 that can disqualify them from being reinplaced. 272 273 (1a) Check that the self argument we're attempting to reinplace 274 has acceptable dtype/size metadata to reinplace with. 275 276 For example, if we have: 277 a = torch.ones(1) 278 b = torch.ones(10) 279 out = torch.add(a, b) 280 We can't turn that into 281 a.add_(b) 282 Because that would require resizing "a". 283 284 Similarly, we can't convert torch.ge(a, b) into a.ge_(b), 285 because that would require changing a's dtype (from e.g. float32 to bool). 286 Note that in this specific example, we could technically do better.. 287 288 If we see the pattern: 289 a_1 = a.ge(b) 290 a_2 = aten._to_copy(a_1, a.dtype) 291 Then we this should be valid to completely re-inplace 292 (this is exactly what functionalization will emit when it sees a.ge_(b)). 293 294 This optimization is only really important for user programs 295 that directly use inplace comparison ops though. 296 297 We also cannot re-inplace on tensors that have overlapping memory, 298 e.g. torch.ones(1).expand(4, 4).add_(1) 299 300 (1b) Check if "a" is an alias of any of the program inputs. 301 302 If it is, skip and move to the next node. 303 Inplace'ing an op that would cause it to mutate a program is not sound, 304 because that would be a side effect visible to the user. 305 306 NOTE: there's a future optimization that we should make: 307 if "a" is a (alias of a) program input, but later in the program 308 there is a node that looks like "a.copy_(...)", 309 Then re-inplacing is ok to do - we are temporarily re-using a's buffer, 310 which will later be overwritten by the copy_() call. 311 312 This will be an important optimization to have for programs that mutate 313 their inputs. It currently isn't implemented though. 314 315 (1c) Check if "a" and "args..." alias 316 317 For example, re-inplacing to create code like the below 318 isn't guaranteed to be sound: 319 320 aten.mul_(a, a) 321 322 (2) Check that "a" and all of its outstanding aliases are not used anywhere 323 later in the graph. If this is the case, then it's safe to re-inplace 324 to "b = foo_(a)". 325 326 There are a few caveats to this, explained in more detail below: 327 (a) If "a" is used later as an argument to a view op, that is okay. 328 It's only a problem if "a" (or that view) is later passed 329 into a normal operator, or if it is returned as the program output. 330 (b) If "a" is a repeat argument in `foo()`, then don't reinplace. 331 Most ATen kernels don't make any guarantees that this is sound, 332 e.g. if you do aten.mul_(a, a). 333 So we'll just ban re-inplacing in this case. 334 It's only a problem if "a" (or that view) is later passed 335 (c) If "a" is used as an input into a view "inverse" / "scatter" 336 operator, it is potentially fine to re-inplace 337 (and remove that scatter operator from the graph). 338 See below for a more detailed example. 339 340 NOTE: there is an optimization in this step that is crucial 341 to fully recovering performance from functionalization. 342 343 Given this program: 344 def f(x): 345 a = torch.ops.aten.add(x, x) 346 b = torch.ops.aten.diagonal(a) 347 torch.ops.aten.fill_(b, 0) 348 return d 349 350 Functionalization will emit the following: 351 def f(x): 352 a = torch.ops.aten.add(x, x) 353 b = torch.ops.aten.diagonal(a, 0, 1) 354 b_updated = torch.ops.aten.fill(b, 0) 355 a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1) 356 return a_updated 357 358 Ordinarily, we would not be able to reinplace the fill, 359 because "b" aliases with "a" which is used by the diagonal_scatter call. 360 361 "re-inplacing" is on the hook for figuring out that it is ok to 362 completely, the expensive diagonal_scatter call, if we re-inplace the add(). 363 364 So, for every `alias in alias_set(a)`, instead of checking 365 that "alias" is not used anywhere later in the graph, 366 we check that 367 EITHER: 368 (a) alias is not used anywhere later in the graph 369 OR: 370 (b) alias is used exactly once later on in the graph, 371 in the following op: 372 373 out = foo_scatter(alias, x, args...) 374 375 where the following must hold: 376 (i) "foo_scatter" is the "inverse" operator for foo. 377 This only applies to "foo" ops that are view operators, 378 which view into a subset of the original tensor's memory. 379 In practice, there are ~4 operators where this applies: 380 diagonal -> diagonal_scatter 381 slice -> slice_scatter 382 select -> select_scatter 383 as_strided -> as_strided_scatter 384 (ii) "args..." are the same between the foo() and foo_scatter() calls. 385 386 (3) Perform the actual re-inplacing on foo! 387 388 (3b) is the common case, but special care is needed for {view}_scatter (3a) 389 390 (3a) {view}_scatter ops. 391 392 Consider this program: 393 a = torch.zeros(2, 2) 394 b = torch.ones(2) 395 a[0] = b 396 397 Post functionalization, that will look like: 398 a = torch.zeros(2) 399 b = torch.ones(1) 400 a_updated = torch.select_scatter(a, b, 0, 0) 401 402 In this case though, there is no "functional" op to re-inplace! 403 Instead, we'd like to directly remove toe select_scatter call. 404 We already know from (3) that this is valid, 405 because "a" has no later usages in the graph. 406 407 We perform the re-inplacing on the {view}_scatter op like so 408 Before: 409 a_updated = torch.select_scatter(a, b, args...) 410 After: 411 a_slice = a.select(a, args...) 412 a_slice.copy_(b) 413 414 (3b) Otherwise, replace the functional op with its inplace variant. 415 Before: 416 b = foo(a, args...) 417 After: 418 a.foo_(args...) 419 420 (4) Finally, after converting either: 421 Before: 422 b = foo(a) 423 After: 424 foo_(a) 425 or 426 Before: 427 b = {slice}_scatter(a, mutated_slice, args...) 428 After: 429 slice = {slice}(a, args...) 430 slice.copy_(mutated_slice) 431 432 We now need to find all later nodes that use "b" as an argument 433 and update them to take in "a" instead. 434 435 Note that for the majority of inplace ops, this isn't actually necessary 436 (because most inplace ops return "self" as their output). 437 This isn't generally true for all mutable ops though, which is why 438 we need to actually replace all of the arguments. 439 440 We also need to update our metadata of Dict[StorageWeakRef, Set[Node]], 441 That maps a given tensor storage to the set of all nodes that take in that storage 442 as an input. 443 Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused 444 together. 445 446 (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them" 447 during step (3) get manually deleted from the graph. 448 Their outputs are no longer used, so technically standard DCE would be able 449 to do this, but we can no longer run FX's DCE pass now that we have mutable 450 ops in the graph. 451 """ 452 _FunctionalizationMetadataProp(gm).propagate(*sample_args) 453 454 # Useful debug printing 455 # def _print(x): 456 # if isinstance(x, FakeTensor): 457 # print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}') 458 459 # for n in gm.graph.nodes: 460 # print(n.format_node()) 461 # if hasattr(n, 'meta'): 462 # print(f'node_idx: {n.meta["node_idx"]}') 463 # if 'fake_result' in n.meta: 464 # tree_map(_print, n.meta['fake_result']) 465 # if 'view_of' in n.meta: 466 # print(f'view_of: {str(n.meta["view_of"])}') 467 # print() 468 469 # We need to know which nodes correspond to inputs (or their aliases) 470 # so we know not to re-inplace them. 471 # NOTE: later, we'll need to add an optimization for fully recovering performance 472 # on programs that mutate inputs. 473 input_storages = { 474 StorageWeakRef( 475 node.meta['fake_result']._typed_storage() 476 ) for node in gm.graph.nodes if (node.op == 'placeholder' and isinstance(node.meta['fake_result'], torch.Tensor))} 477 478 # We also need to know for a given node, what are all of its aliasing nodes. 479 storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set) 480 for n in gm.graph.nodes: 481 if 'fake_result' in n.meta: 482 # Tree-mapping because some ops can return lists of tensors. 483 def _add_to_map(x): 484 if isinstance(x, FakeTensor): 485 storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n) 486 pytree.tree_map_(_add_to_map, n.meta['fake_result']) 487 488 # inplace-ify functional ops, subject to the constraints written below. 489 all_later_view_inverse_nodes_to_delete = set() 490 for idx, node in enumerate(gm.graph.nodes): 491 if node.op == 'call_function': 492 493 # Today, the re-inplace pass on directly acts on: 494 # - functional ops with an inplace variant 495 # - {view}_scatter ops that can be potentially removed from the graph. 496 # Both of these ops take in tensor first args, so filtering on this condition 497 # makes the later code simpler. 498 # We should revisit this at some point though, particularly when we also want 499 # the reinplacer to be able to handle out= and mutable operators 500 # and tensorlist first args (like `_foreach_` ops). 501 if not isinstance(node.target, torch._ops.OpOverload): 502 continue 503 if len(node.target._schema.arguments) < 1: 504 continue 505 if type(node.target._schema.arguments[0].type) != torch.TensorType: 506 continue 507 508 # Step 1a: Check that the self argument we're attempting to reinplace 509 # has the same size/stride as the output. 510 # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor) 511 # As it would require resizing scalar_tensor. 512 # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor), 513 # this is probably an optimization to revisit later). 514 self_arg = node.args[0] 515 self_flattened = pytree.tree_leaves(self_arg.meta['fake_result']) 516 node_flattened = pytree.tree_leaves(node.meta['fake_result']) 517 self_has_wrong_metadata = False 518 if len(self_flattened) == len(node_flattened): 519 for self_meta, node_meta in zip(self_flattened, node_flattened): 520 if self_meta.numel() != node_meta.numel(): 521 self_has_wrong_metadata = True 522 if self_meta.dtype != node_meta.dtype: 523 self_has_wrong_metadata = True 524 # We also cannot re-inplace on tensors that have internal memory overlap. 525 # e.g. torch.ones(1).expand(4, 4).add_(1) 526 if torch._debug_has_internal_overlap(self_meta) == 1: 527 self_has_wrong_metadata = True 528 # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace, 529 # Since users should never really be calling the functional "torch.ops.aten.resize" 530 # op directly in their programs. 531 if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default: 532 continue 533 534 # Step 1b: ensure that the op we're trying to re-inplace isn't a program input 535 self_arg_name = self_arg.name 536 self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) 537 if self_arg_storage in input_storages: 538 # TODO: later, add the optimization for handling `copy_()` calls in the graph. 539 continue 540 if len([x for x in node.args if x is self_arg]) > 1: 541 # Step 1c: 542 # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound, 543 # so we prevent re-inplacing in this case. 544 continue 545 546 self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) 547 self_aliases = storage_to_nodes[self_arg_storage] 548 549 # First, we find all later usages of any of the aliases of self_arg. 550 later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx']) 551 # Then, we check if any of those later usages are actually view_scatter ops 552 # that are safe to fully remove. 553 later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases) 554 555 # Step 2: Check to see if the input to the op is re-used later in the graph. 556 # If not (same goes for its aliases), then this op is safe to re-in place. 557 # This is a slightly roundabout way to check that there are no later usages of the current self argument. 558 # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete) 559 can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0 560 if not can_reinplace: 561 continue 562 563 # Step 3a: Special handling for when we see *_scatter operators. 564 # When we see an operator like `b = torch.slice_scatter(a, ...)`, 565 # instead of trying to "inplace" it into a.slice_scatter_(..._), 566 # we would prefer to remove it from the graph entirely, 567 # and instead copy_() the slice directly into the larger tensor. 568 # See the description of the algorithm for a full example. 569 if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete: 570 view_op = _VIEW_INVERSE_MAP[node.target] 571 # Before: 572 # base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...) 573 # After: 574 # slice = torch.ops.aten.slice.default(base, args...) 575 # slice.copy_(mutated_slice) 576 with gm.graph.inserting_before(node): 577 mutated_slice_node = node.args[1] 578 remaining_slice_args = node.args[2:] 579 slice_node = gm.graph.create_node( 580 'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs) 581 copy_node = gm.graph.create_node( 582 'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {}) 583 # Add the slice_scatter node to our "nodes to delete" list. 584 all_later_view_inverse_nodes_to_delete.add(node) 585 586 587 else: 588 # Step 3b: Check to see if this operator has an inplace variant. 589 maybe_inplace_op = _maybe_get_inplace_op(node.target) 590 if maybe_inplace_op is None: 591 continue 592 # And if so, replace it with its inplace variant. 593 node.target = maybe_inplace_op 594 595 # At this point, 'storage_to_nodes' will be stale. 596 # Now that we're inplacing `b = foo(a)`, we need to effectively 597 # union together the dict values for b and a's storage. 598 # Hmm... morally I think we also want to keep the `fake_result` metadata 599 # up to date here, but I'm not sure how easy it is to do. 600 # Maybe it's fine to wait until the end of the pass to update it. 601 curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) 602 storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage]) 603 storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage]) 604 605 # Need to remember the view_scatter view nodes we found so we can remove them alter. 606 all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages) 607 608 # Step 4: 609 # Now that we've replaced b = a.foo() with a.foo_(), 610 # We need to replace any later usages of "b" with "a" 611 for old in itertools.chain([node], later_view_inverse_node_usages): 612 new = old.args[0] 613 nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']] 614 for node_to_update in nodes_to_update: 615 new_args = [] 616 args = node_to_update.args 617 618 def replace_arg(a): 619 if a == old: 620 return new 621 return a 622 623 # First, replace usages of "b" with "a" 624 node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args) 625 node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs) 626 627 # Second, update our storage_to_nodes data structure. 628 old_flattened_res = pytree.tree_leaves(old.meta['fake_result']) 629 node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result']) 630 631 old_res_storage = { 632 StorageWeakRef( 633 x._typed_storage() 634 ) for x in old_flattened_res if isinstance(x, FakeTensor)} 635 node_res_storage = { 636 StorageWeakRef( 637 x._typed_storage() 638 ) for x in node_flattened_res if isinstance(x, FakeTensor)} 639 640 # This will happen if we're updating a view op, e.g. 641 # e.g. replacing 642 # x = view(old) 643 # x = view(new) 644 # When that happens, we need to make sure to keep our 645 # storage mapping up to date. 646 # 647 # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor, 648 # or multiple tensors that all share the same storage. 649 # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. 650 if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage: 651 new_flattened_res = pytree.tree_leaves(new.meta['fake_result']) 652 new_res_storage = { 653 StorageWeakRef( 654 x._typed_storage() 655 ) for x in new_flattened_res if isinstance(x, FakeTensor)} 656 assert len(new_res_storage) == 1 657 (old_ref,) = old_res_storage 658 (new_ref,) = new_res_storage 659 (node_ref,) = node_res_storage 660 # Technically, "old_ref" and all its aliases will remain 661 # in our mapping. 662 # That should be fine though, since we deleted "old" 663 # from the graph at this point. 664 storage_to_nodes[node_ref].update(storage_to_nodes[new_ref]) 665 storage_to_nodes[new_ref].update(storage_to_nodes[node_ref]) 666 667 # Step 4: delete any _scatter nodes that we de-functionalized 668 # Need to take care not to delete any of these nodes until after *all* modifications 669 # to the graph are finished. 670 for to_delete in all_later_view_inverse_nodes_to_delete: 671 gm.graph.erase_node(to_delete) 672 673 674 gm.recompile() 675 return gm 676