# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
import logging
import operator
import typing
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from executorch.exir import memory
from executorch.exir.control_flow import while_loop as exir_while
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.error import internal_assert, InternalError
from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.tensor import TensorSpec

from torch import fx
from torch.export.exported_program import ExportGraphSignature
from torch.fx import Node
from torch.utils._pytree import tree_flatten

REGISTERED_ALGOS: Dict[str, Callable[..., List[int]]] = {}


class Verifier:
    """
    Verify that the outcome of a memory planning algorithm makes sense.
    E.g., make sure tensors with overlapping lifetimes do not have overlapping
    storage/buffers.
    """

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        alloc_graph_input: bool,
        alloc_graph_output: bool,
        graph_signature: Optional[ExportGraphSignature] = None,
    ) -> None:
        self.graph_module = graph_module
        self.graph_signature = graph_signature
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output

    @classmethod
    def mem_obj_id_match(
        cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec, accept_both_none: bool = True
    ) -> bool:
        """
        Given two `TensorSpec`s, return whether their `mem_obj_id`s match. Note that
        if both are None, this function returns True if `accept_both_none` is True,
        and False otherwise.
        """
        if lhs_spec.mem_id != rhs_spec.mem_id:
            return False

        # both are None
        if lhs_spec.mem_obj_id is None and rhs_spec.mem_obj_id is None:
            return accept_both_none

        return lhs_spec.mem_obj_id == rhs_spec.mem_obj_id

    @classmethod
    def has_overlap(cls, lhs_ivl: List[int], rhs_ivl: List[int]) -> bool:
        r"""
        The passed-in intervals are inclusive on both ends. Return whether they
        overlap.
        """
        # empty interval
        if lhs_ivl[0] > lhs_ivl[1] or rhs_ivl[0] > rhs_ivl[1]:
            return False

        return (lhs_ivl[0] >= rhs_ivl[0] and lhs_ivl[0] <= rhs_ivl[1]) or (
            rhs_ivl[0] >= lhs_ivl[0] and rhs_ivl[0] <= lhs_ivl[1]
        )

    @classmethod
    def lifetime_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
        lhs_lifetime = lhs_spec.lifetime
        rhs_lifetime = rhs_spec.lifetime
        internal_assert(
            lhs_lifetime[0] is not None and lhs_lifetime[1] is not None,
            f"{lhs_spec} should have valid start and end",
        )
        internal_assert(
            rhs_lifetime[0] is not None and rhs_lifetime[1] is not None,
            f"{rhs_spec} should have valid start and end",
        )
        return cls.has_overlap(lhs_lifetime, rhs_lifetime)

    @classmethod
    def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
        intervals = []
        if lhs_spec.mem_id != rhs_spec.mem_id:
            return False
        for spec in [lhs_spec, rhs_spec]:
            internal_assert(
                spec.allocated_memory >= 0,
                f"{spec} should have non-negative allocated memory",
            )
            internal_assert(
                isinstance(spec.mem_offset, int) and spec.mem_offset >= 0,
                f"{spec} should have specified memory offset",
            )
            intervals.append(
                [spec.mem_offset, spec.mem_offset + spec.allocated_memory - 1]
            )
        has_overlap = cls.has_overlap(*intervals)

        return has_overlap

    def verify_storage_reuse(
        self, allow_lifetime_and_storage_overlap: bool = False
    ) -> int:
        """
        'allow_lifetime_and_storage_overlap' allows tensors to overlap in both
        lifetime and storage. If it is False and two tensors have both overlapping
        lifetime and storage, throw an exception.
        Returns:
            Number of pairs of tensors that have overlapping storage.
        """
        num_reuse_pairs = 0

        # unique tensor specs
        all_specs = list(
            collect_specs_from_nodes(
                self.graph_module.graph.nodes,
                self.graph_signature,
                ignore_const=True,
                ignore_graph_input=not self.alloc_graph_input,
                ignore_graph_output=not self.alloc_graph_output,
                do_assertion=False,
                ignore_out_var_node=False,
                dedup=True,
            )
        )

        for lhs_spec_idx, lhs_spec in enumerate(all_specs):
            for rhs_spec in all_specs[lhs_spec_idx + 1 :]:
                # Check that both specs are consistent about whether mem_obj_id is defined
                if (lhs_spec.mem_obj_id is None) != (rhs_spec.mem_obj_id is None):
                    raise InternalError(
                        "Specs do not agree on whether mem_obj_id is defined."
                    )

                has_storage_overlap = Verifier.storage_overlap(lhs_spec, rhs_spec)
                if not has_storage_overlap:
                    continue

                if not allow_lifetime_and_storage_overlap and self.lifetime_overlap(
                    lhs_spec, rhs_spec
                ):
                    raise InternalError(
                        f"Unexpected storage overlap: lhs {lhs_spec}, rhs {rhs_spec}"
                    )

                # Check that each mem_obj_id is consistent with whether the tensors
                # have storage overlap
                if not Verifier.mem_obj_id_match(lhs_spec, rhs_spec):
                    raise InternalError(
                        f"Unexpected mem_obj_id mismatch: lhs {lhs_spec}, rhs {rhs_spec}"
                    )

                num_reuse_pairs += 1

        return num_reuse_pairs

    def verify_graph_input_output(self) -> None:
        r"""
        alloc_graph_input / alloc_graph_output indicate whether memory for the
        graph input/output is allocated by the compiler. If not, the runtime
        will set them using buffers provided by users.
        """
        graph_module = self.graph_module
        # There is one tricky case here. If the graph input and graph output
        # tensors overlap, but alloc_graph_input != alloc_graph_output, then
        # the overlapping tensors will cause an assertion failure below.
        # The current behavior is that if either alloc_graph_input or
        # alloc_graph_output is false, those overlapping tensors will not have
        # memory allocated.
        #
        # Ignore the check in this case for now.
        overlap = get_graph_input_tensors(
            graph_module.graph.nodes, self.graph_signature
        ) & get_graph_output_tensors(graph_module.graph.nodes)
        if overlap and (self.alloc_graph_input != self.alloc_graph_output):
            logging.debug(
                "Graph input and output tensors overlap while the allocation "
                "decisions for graph input/output mismatch."
            )
            return

        graph_input_allocated = None
        graph_output_allocated = None

        has_dynamic_unbound_input = False
        has_dynamic_unbound_output = False

        check_list = {"placeholder", "output"} & {
            node.op for node in graph_module.graph.nodes
        }
        assert "output" in check_list, f"graph module has no output: {graph_module}"

        for nd in graph_module.graph.nodes:
            if nd.op in check_list:
                if not (specs := get_node_tensor_specs(nd)):
                    continue
                if _is_mutable_buffer(nd, self.graph_signature):
                    continue
                assert len(specs) > 0, "Expect tensor specs"
                specs = list(filter(lambda spec: not spec.const, specs))
                if len(specs) == 0:
                    continue
                allocated = any(
                    spec is None or spec.mem_offset is not None for spec in specs
                )
                has_dynamic_unbound_tensor = any(
                    spec is None
                    or spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
                    for spec in specs
                )
                assert (
                    all(spec is None or spec.mem_offset is not None for spec in specs)
                    == allocated
                ), "Either all or none of the tensors should be allocated memory"
                if nd.op == "placeholder":
                    graph_input_allocated = allocated
                    has_dynamic_unbound_input |= has_dynamic_unbound_tensor
                else:
                    graph_output_allocated = allocated
                    has_dynamic_unbound_output |= has_dynamic_unbound_tensor

        if "placeholder" in check_list:
            assert graph_input_allocated is not None, "graph_input_allocated not set"
            if not has_dynamic_unbound_input:
                assert (
                    graph_input_allocated == self.alloc_graph_input
                ), f"Misallocate graph input: {graph_input_allocated} vs. {self.alloc_graph_input}"

        assert graph_output_allocated is not None, "graph_output_allocated not set"
        if not has_dynamic_unbound_output:
            assert (
                graph_output_allocated == self.alloc_graph_output
            ), f"Misallocate graph output {graph_output_allocated} vs. {self.alloc_graph_output}"


def _is_out_var_node(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and isinstance(node.target, torch._ops.OpOverload)
        and is_out_variant(node.target._schema.name, node.target._schema.overload_name)
    )


def _is_inplace_node(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and isinstance(node.target, torch._ops.OpOverload)
        and is_inplace_variant(
            node.target._schema.name, node.target._schema.overload_name
        )
    )
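

# Illustrative sketch, added for exposition and not part of the original module:
# a tiny self-check showing the semantics of Verifier.has_overlap above, which
# treats intervals as inclusive on both ends.
def _example_has_overlap() -> None:
    # [0, 9] and [5, 14] share bytes 5..9, so they overlap.
    assert Verifier.has_overlap([0, 9], [5, 14])
    # [0, 9] and [10, 19] are adjacent but disjoint under inclusive bounds.
    assert not Verifier.has_overlap([0, 9], [10, 19])
    # A degenerate interval (start > end) is treated as empty and never overlaps.
    assert not Verifier.has_overlap([5, 4], [0, 9])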


def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
    r"""
    Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
    is represented by the indices of the first and last nodes referring to
    that tensor in their inputs/outputs.

    Arguments:
        spec: the TensorSpec for the tensor
        node_idx: extend the tensor's lifetime to cover node_idx
    """
    start, end = spec.lifetime
    start = node_idx if start is None or start > node_idx else start
    end = node_idx if end is None or end < node_idx else end
    spec.lifetime = [start, end]
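

# Illustrative sketch, added for exposition and not part of the original module:
# demonstrates how update_tensor_lifetime widens a spec's [start, end] interval.
# The spec is passed in rather than constructed here, since TensorSpec
# construction details are outside the scope of this demo.
def _example_update_tensor_lifetime(spec: TensorSpec) -> None:
    spec.lifetime = [None, None]
    update_tensor_lifetime(spec, 3)  # first sighting: lifetime becomes [3, 3]
    update_tensor_lifetime(spec, 7)  # a later use extends the end: [3, 7]
    update_tensor_lifetime(spec, 1)  # an earlier use extends the start: [1, 7]
    assert spec.lifetime == [1, 7]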


# pyre-ignore
def filter_nodes(inputs: Iterable[Any]) -> Iterable[Node]:
    """
    This method must also return Node objects embedded inside List/Dict.
    """
    return [nd for nd in tree_flatten(list(inputs))[0] if isinstance(nd, Node)]


def _is_mutable_buffer(
    node: Node, graph_signature: Optional[ExportGraphSignature] = None
) -> bool:
    """
    Check if the node is a mutable buffer according to the provided graph signature.
    """
    # The graph signature is None for memory planning passes not called from
    # EdgeProgramManager. These paths are deprecated, so mutable buffers are
    # not supported on them.
    if graph_signature is None:
        return False
    if node.op == "placeholder":
        if isinstance(node.target, str):
            if node.target in graph_signature.inputs_to_buffers:
                fqn = graph_signature.inputs_to_buffers[node.target]
                # if the buffer is mutated then record that
                if fqn in graph_signature.buffers_to_mutate.values():
                    return True
    return False


def get_graph_input_tensors(
    nodes: Iterable[Node], graph_signature: Optional[ExportGraphSignature] = None
) -> Set[TensorSpec]:
    graph_input_tensors = set()
    for node in nodes:
        if node.op == "placeholder" and not _is_mutable_buffer(node, graph_signature):
            for spec in get_node_tensor_specs(node):
                graph_input_tensors.add(spec)

    return graph_input_tensors


def get_graph_output_tensors(nodes: Iterable[Node]) -> Set[TensorSpec]:
    graph_output_tensors = set()
    for node in nodes:
        if node.op == "output":
            for spec in get_node_tensor_specs(node):
                graph_output_tensors.add(spec)

    return graph_output_tensors
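

# Illustrative sketch, added for exposition and not part of the original module:
# shows that filter_nodes (above) flattens nested containers and keeps only fx
# Nodes, dropping plain Python values.
def _example_filter_nodes() -> None:
    g = torch.fx.Graph()
    x = g.placeholder("x")
    y = g.placeholder("y")
    # Nodes may be nested inside lists/dicts and mixed with non-Node values.
    assert filter_nodes([x, [y, 1], {"k": x}]) == [x, y, x]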


def collect_specs_from_nodes(  # noqa: C901
    nodes: Iterable[Node],
    graph_signature: Optional[ExportGraphSignature] = None,
    ignore_graph_input: bool = False,
    ignore_graph_output: bool = False,
    ignore_const: bool = True,
    ignore_out_var_node: bool = True,
    dedup: bool = True,
    do_assertion: bool = True,
    ignore_dynamic_unbound_tensor: bool = True,
) -> Iterable[TensorSpec]:
    r"""
    Collect specs from the passed-in nodes, filtering as controlled by the
    arguments.
    Arguments:
        ignore_graph_input: ignore graph input tensors from placeholder nodes
        ignore_const: whether to ignore const tensors
        ignore_out_var_node: whether to ignore out-variant nodes
        dedup: whether to dedup the specs
        do_assertion: whether to assert that the filtered nodes belong to a
            restricted set like alloc, getitem
    """
    unique_spec = set()
    graph_input_tensors: Set[TensorSpec] = (
        get_graph_input_tensors(nodes, graph_signature) if ignore_graph_input else set()
    )
    graph_output_tensors: Set[TensorSpec] = (
        get_graph_output_tensors(nodes) if ignore_graph_output else set()
    )

    for node in nodes:
        # ignore the specs from irrelevant Fx ops for now.
        if node.op in ["get_attr"]:
            continue

        # don't reallocate memory for an out-variant op's output tensors,
        # since they are just input tensors.
        if ignore_out_var_node and _is_out_var_node(node):
            continue

        if not (specs := get_node_tensor_specs(node)):
            continue

        if _is_inplace_node(node):
            continue

        if do_assertion:
            internal_assert(
                node.op in ("placeholder", "output")
                or node.target
                in [
                    memory.alloc,
                    memory.view,
                    operator.getitem,
                    torch.ops.higher_order.cond,
                    exir_while,
                    torch.ops.higher_order.map_impl,
                    executorch_call_delegate,
                ],
                f"Unexpected op {node.op}, target {node.target}",
            )
        for spec in specs:
            if spec is None:
                continue
            # Dynamic unbound tensors' memory will be allocated by the runtime.
            # Memory planning should ignore them.
            if (
                ignore_dynamic_unbound_tensor
                and spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
            ):
                continue

            # Note: a graph input may be the output of other ops (e.g. the return op).
            # If ignore_graph_input is true, we should ignore those tensors so
            # we skip planning memory for graph inputs.
            if ignore_graph_input and spec in graph_input_tensors:
                continue
            if ignore_graph_output and spec in graph_output_tensors:
                continue
            if (
                ignore_const
                and spec.const
                and not node.meta.get("weight_has_gradient", False)
            ):
                continue
            if dedup:
                if spec in unique_spec:
                    continue
                else:
                    unique_spec.add(spec)
            yield spec


def update_all_tensors_lifetime(
    graph_module: torch.fx.GraphModule,
    graph_signature: Optional[ExportGraphSignature] = None,
) -> Set[TensorSpec]:
    r"""
    Set the lifetime for all the tensors encountered in the Fx graph.
    """
    specs = set()
    for node_idx, node in enumerate(graph_module.graph.nodes):
        for spec in collect_specs_from_nodes(
            filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
            graph_signature,
            ignore_graph_input=False,
            ignore_const=False,
            ignore_out_var_node=False,
            dedup=False,
            do_assertion=False,
            ignore_dynamic_unbound_tensor=False,
        ):
            update_tensor_lifetime(spec, node_idx)
            specs.add(spec)
    return specs


@dataclass
class SharedObject:
    r"""
    We define the concept of a shared object, which represents a segment in
    the memory buffer that can be shared by multiple tensors. In order to
    check if a shared object is available for a tensor, we maintain the
    last_used_index attribute. The shared object will be available for nodes
    with index greater than last_used_index.
    """

    # index of the shared object in the list of shared objects, used as a unique id
    idx: int
    # offset in the memory buffer
    offset: int
    # size of this shared object in bytes
    size: int
    # the object will be available for index (last_used_index + 1)
    last_used_index: int


def materialize_buffer(
    shared_objects: List[SharedObject], input_total_size: int = 0
) -> int:
    r"""
    Assign a concrete location in the buffer to each SharedObject.offset.

    Assumes all the passed-in shared objects belong to the same memory buffer.
    """
    total_size = input_total_size
    for sobj in shared_objects:
        sobj.offset = total_size
        total_size += sobj.size
    return total_size


def _size_abs_dif(sobj: SharedObject, spec: TensorSpec) -> int:
    r"""
    Calculate the absolute difference between the size of a shared object and
    a tensor.
    """
    return abs(sobj.size - spec.allocated_memory)


def pick_shared_obj(
    shared_objects: List[SharedObject], spec: TensorSpec
) -> SharedObject:
    r"""
    Pick the available shared object with the closest size to the tensor.
    If there are no available shared objects left, create a new one.
    """
    # TODO: do better than linear scan
    picked = None
    for sobj in shared_objects:
        if spec.lifetime[0] > sobj.last_used_index:
            if picked is None or _size_abs_dif(sobj, spec) < _size_abs_dif(
                picked, spec
            ):
                picked = sobj
    if picked is not None:
        # Only mutate the object that is finally picked. Updating tentative
        # candidates inside the loop would corrupt objects that end up unpicked
        # (their last_used_index and size would grow with no tensor assigned).
        picked.last_used_index = spec.lifetime[1]
        picked.size = max(picked.size, spec.allocated_memory)
    else:
        picked = SharedObject(
            len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1]
        )
        shared_objects.append(picked)

    return picked
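

# Illustrative sketch, added for exposition and not part of the original module:
# materialize_buffer lays shared objects out back to back, so each offset is the
# running sum of the sizes before it (plus any pre-allocated prefix).
def _example_materialize_buffer() -> None:
    objs = [
        SharedObject(idx=0, offset=-1, size=32, last_used_index=4),
        SharedObject(idx=1, offset=-1, size=16, last_used_index=7),
    ]
    total = materialize_buffer(objs, input_total_size=8)
    assert [o.offset for o in objs] == [8, 40]  # 8, then 8 + 32
    assert total == 56  # 8 + 32 + 16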


def get_node_tensor_specs(
    node: torch.fx.Node,
) -> Union[List[TensorSpec], Tuple[TensorSpec]]:
    r"""
    Return the list of the tensor specs for the node, or an empty list if the
    node has no tensor specs.
    """
    # get tensor specs
    if node.target == memory.view:
        base = node.args[0]
        assert isinstance(base, torch.fx.Node)
        specs = base.meta.get("spec")
    else:
        specs = node.meta.get("spec")

    if isinstance(specs, TensorSpec):
        specs = [specs]
    if not isinstance(specs, (list, tuple)):
        return []
    else:
        return [
            spec
            for spec in specs
            if not isinstance(spec, (int, float, bool, str, type(None)))
        ]


def greedy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:
    spec2obj = {}
    shared_objects = defaultdict(list)
    # Don't do assertion in collect_specs_from_nodes if we have already encountered
    # and ignored some to_out_variant errors.
    do_assertion = not getattr(graph_module, "encounter_to_out_var_failure", False)
    # For each tensor, pick the available shared object with the closest size to
    # the tensor. If there are no available shared objects left, create a new
    # one.
    for spec in collect_specs_from_nodes(
        graph_module.graph.nodes,
        graph_signature,
        do_assertion=do_assertion,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    ):
        if spec.mem_id is None:
            spec.mem_id = 1
        spec.realign(alignment)
        spec2obj[spec] = pick_shared_obj(shared_objects[spec.mem_id], spec)

    if len(shared_objects) == 0:
        # Cannot find any tensor in the graph that needs to be allocated.
        # Return [0, 0] to be consistent with default behavior of naive.
        total_sizes = [0, 0]
    else:
        total_sizes = [0] * (max(shared_objects.keys()) + 1)
        for mem_id in shared_objects:
            input_total_size = 0
            if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
                # pyre-fixme[6]: For 1st argument expected
                #  `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
                if len(bufsizes) > mem_id:
                    # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
                    input_total_size = bufsizes[mem_id]
            total_sizes[mem_id] = materialize_buffer(
                shared_objects[mem_id], input_total_size
            )

    # Since we now know the number of shared objects we need and the size of
    # each shared object, we can assign an offset in the memory buffer for each
    # shared object.
    for spec, sobj in spec2obj.items():
        spec.mem_obj_id = sobj.idx
        spec.mem_offset = sobj.offset

    logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
    return total_sizes


def naive(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:

    # allocate 'allocated' bytes from the buffer with id mem_id.
    # return the starting offset of the allocated region.
    def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
        if mem_id >= len(bufsizes):
            bufsizes.extend([0] * (mem_id - len(bufsizes) + 1))
        ret = bufsizes[mem_id]
        bufsizes[mem_id] += allocated
        return ret

    bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
    if bufsizes is None:
        bufsizes = [0, 0]

    bufsizes = typing.cast(List[int], bufsizes)
    for spec in collect_specs_from_nodes(
        graph_module.graph.nodes,
        graph_signature,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    ):
        # assume a single memory layer which has mem_id 1
        if spec.mem_id is None:
            spec.mem_id = 1
        # allocate spec.allocated_memory bytes in the buffer
        # with the corresponding mem_id
        spec.realign(alignment)
        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)

    logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
    return bufsizes
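

# Illustrative sketch, added for exposition and not part of the original module:
# the naive strategy above is a bump allocator per mem_id; this mirrors its
# inner _allocate_buf logic on hypothetical sizes.
def _example_bump_allocation() -> None:
    bufsizes = [0, 0]

    def allocate(mem_id: int, nbytes: int) -> int:
        if mem_id >= len(bufsizes):
            bufsizes.extend([0] * (mem_id - len(bufsizes) + 1))
        offset = bufsizes[mem_id]
        bufsizes[mem_id] += nbytes
        return offset

    assert allocate(1, 64) == 0  # the first tensor starts at offset 0
    assert allocate(1, 32) == 64  # the next one is placed right after it
    assert bufsizes == [0, 96]  # buffer 1 grew to the running total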


def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
    for nd in graph_module.graph.nodes:
        if nd.target is torch.ops.higher_order.cond:
            yield nd


def get_while_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
    for nd in graph_module.graph.nodes:
        if nd.target is exir_while:
            yield nd


def get_map_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
    for nd in graph_module.graph.nodes:
        if nd.target is torch.ops.higher_order.map_impl:
            yield nd


def get_return_specs(graph_module: fx.GraphModule) -> Set[TensorSpec]:
    return_specs = set()
    nodes = graph_module.graph.nodes
    if len(nodes) > 0:
        last_node = next(iter(reversed(nodes)))
        for spec in tree_flatten(last_node.meta["spec"])[0]:
            return_specs.add(spec)
    return return_specs


def get_input_specs(graph_module: fx.GraphModule) -> Set[TensorSpec]:
    input_specs = set()
    nodes = graph_module.graph.nodes
    for node in nodes:
        if node.op == "placeholder":
            for spec in tree_flatten(node.meta["spec"])[0]:
                input_specs.add(spec)
    return input_specs


def insert_calls_to_free(
    graph_module: fx.GraphModule, allspecs: Set[TensorSpec]
) -> None:
    """
    Insert calls to free for dynamic unbound tensors whose lifetime has ended.

    Only handles the module itself. Submodules are handled in separate calls
    to this function.

    NOTE: this method will invalidate the lifetimes recorded in TensorSpec
    because of the extra free nodes added to the graph.
    """
    # Note: we should never free an output tensor
    return_specs = get_return_specs(graph_module)
    # Note: we should never free an input tensor, since the buffer for an
    # input tensor may be passed in by the user.
    input_specs = get_input_specs(graph_module)
    idx_to_dead_specs = defaultdict(list)
    for spec in allspecs:
        if (
            spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND
            and spec not in return_specs
            and spec not in input_specs
        ):
            idx_to_dead_specs[spec.lifetime[1]].append(spec)

    num_nodes = len(graph_module.graph.nodes)
    # iterate in reverse order so the inserted nodes do not disturb the node
    # numbering.
    for node, node_idx in zip(
        reversed(graph_module.graph.nodes), range(num_nodes - 1, -1, -1)
    ):
        dead_specs = idx_to_dead_specs.get(node_idx, [])
        if not dead_specs:
            continue
        with graph_module.graph.inserting_after(node):
            for spec in dead_specs:
                graph_module.graph.call_function(memory.free, (spec,))
    graph_module.recompile()


def apply_algo(
    algo: Callable[
        [torch.fx.GraphModule, int, Optional[ExportGraphSignature], bool, bool],
        List[int],
    ],
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> List[int]:
    """
    Recursively apply algo to graph_module and its submodules for control flow.

    Quite naive right now, since it does not take the following optimizations
    into consideration:
    1. for a conditional structure, the true and false branches do not overlap
       in lifetime and can share tensor storage
    2. tensors inside a submodule (e.g. the true branch) have opportunities to
       share storage with tensors in the outer module.
    TODO: make these optimizations once we have some baseline working.
    """
    specs = update_all_tensors_lifetime(graph_module, graph_signature)
    bufsizes: List[int] = algo(
        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
    )
    insert_calls_to_free(graph_module, specs)

    def handle_submodule(
        submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
    ) -> None:
        nonlocal bufsizes
        assert submodule_nd.op == "get_attr"
        submodule = getattr(graph_module, submodule_nd.target)
        # memory planning for a submodule needs to be aware of the amount of
        # buffer already allocated.
        submodule.input_mem_buffer_sizes = bufsizes
        bufsizes = apply_algo(
            algo,
            submodule,
            alignment,
            graph_signature,
            alloc_graph_input=alloc_graph_input,
            alloc_graph_output=True,
        )
        submodule.meta.update({"non_const_buffer_sizes": bufsizes})

    for cond_node in get_cond_nodes(graph_module):
        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))

    for while_node in get_while_nodes(graph_module):
        handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
        handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
    # TODO: Add test coverage for the map operator once dynamo tracing is
    # fully supported for this. T142287208
    for map_node in get_map_nodes(graph_module):
        handle_submodule(
            typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
        )

    graph_module.meta.update({"non_const_buffer_sizes": bufsizes})

    return bufsizes
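

# Illustrative sketch, hypothetical usage and not part of the original module:
# running the greedy planner over a lowered graph module. The graph_module and
# graph_signature are assumed to come from the usual ExecuTorch lowering flow,
# and the alignment value here is an arbitrary example.
def _example_apply_greedy(
    graph_module: torch.fx.GraphModule,
    graph_signature: Optional[ExportGraphSignature] = None,
) -> List[int]:
    # Returns one total buffer size per mem_id (with the default mem_id
    # assignment, index 0 stays 0).
    return apply_algo(
        greedy,
        graph_module,
        alignment=16,
        graph_signature=graph_signature,
        alloc_graph_input=True,
        alloc_graph_output=True,
    )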