# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List, Optional

import torch
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from executorch.exir.tensor import TensorSpec
from torch.export.exported_program import ExportGraphSignature
from torch.fx.node import Node
from torch.utils import _pytree as pytree


# pyre-ignore
def make_spec(x):
    """
    Convert a single traced value into the form stored in ``node.meta["spec"]``.

    - ``torch.Tensor`` -> ``TensorSpec.from_tensor(x)``
    - ``int`` / ``bool`` / ``float`` -> returned unchanged
    - anything else -> ``None``
    """
    if isinstance(x, torch.Tensor):
        return TensorSpec.from_tensor(x)
    elif isinstance(x, (int, bool, float)):
        return x
    else:
        return None


def _is_mutable_buffer(
    node: Node, graph_signature: Optional[ExportGraphSignature] = None
) -> bool:
    """
    Check if the node is a mutable buffer according to the provided graph signature.

    A node counts as a mutable buffer only when it is a ``placeholder`` whose
    string target is listed in ``graph_signature.inputs_to_buffers`` and the
    buffer FQN it maps to appears among the values of
    ``graph_signature.buffers_to_mutate``.

    Args:
        node: The FX node to inspect.
        graph_signature: Signature of the exported program. When ``None`` the
            node is never reported as a mutable buffer.

    Returns:
        True if the node is a mutated buffer input, False otherwise.
    """
    # graph_signature is None for memory planning passes not called from
    # EdgeProgramManager; these paths are deprecated, so mutable buffers are
    # not supported on them.
    if graph_signature is None:
        return False
    if node.op != "placeholder" or not isinstance(node.target, str):
        return False
    inputs_to_buffers = graph_signature.inputs_to_buffers
    if node.target not in inputs_to_buffers:
        return False
    fqn = inputs_to_buffers[node.target]
    # The buffer is mutable iff the signature records a mutation for its FQN.
    return fqn in graph_signature.buffers_to_mutate.values()


class SpecPropPass(ExportPass):
    """
    Export pass that populates ``node.meta["spec"]`` on the nodes it visits.

    Tensor values are described with ``TensorSpec`` (via ``make_spec``);
    Python scalars are stored as-is. Specs for operator outputs are computed
    by running the operator on the ``.data`` of its ``ProxyValue`` inputs.
    """

    def __init__(self) -> None:
        super().__init__()

    def on_attr(self, attr: ProxyValue) -> None:
        """Attach specs to a ``get_attr`` value; only tensor leaves are converted."""
        attr.node.meta["spec"] = pytree.tree_map_only(
            torch.Tensor,
            make_spec,
            attr.data,
        )

    def update_placeholder_tensor_specs(
        self,
        exported_program: torch.export.ExportedProgram,
        graph_module: torch.fx.GraphModule,
    ) -> None:
        """
        Mark the specs of constant placeholder inputs as ``const``.

        A placeholder is treated as constant when its target is a parameter,
        a buffer that is not mutated (see ``_is_mutable_buffer``), or a lifted
        tensor constant, according to ``exported_program.graph_signature``.

        Args:
            exported_program: Program whose graph signature classifies inputs.
            graph_module: Graph whose placeholder specs are updated in place.

        Raises:
            RuntimeError: If a placeholder node has no ``meta["spec"]``.
        """
        graph_signature = exported_program.graph_signature
        for node in graph_module.graph.nodes:
            if node.op != "placeholder":
                continue
            if "spec" not in node.meta:
                raise RuntimeError(f"Placeholder node {node} missing meta['spec']")
            if not isinstance(node.target, str):
                continue
            target = node.target
            is_constant = (
                target in graph_signature.inputs_to_parameters
                or (
                    target in graph_signature.inputs_to_buffers
                    and not _is_mutable_buffer(node, graph_signature)
                )
                or target in graph_signature.inputs_to_lifted_tensor_constants
            )
            if is_constant:
                node.meta["spec"].const = True

    # pyre-ignore
    def placeholder(self, name: str, arg, meta):
        """Record the spec of a graph input."""
        meta["spec"] = make_spec(arg)
        return super().placeholder(name, arg, meta)

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta):
        """Run ``op`` on the unwrapped input data and record its output spec(s)."""
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
        return super().call_operator(op, args, kwargs, meta)

    # pyre-ignore
    def call_getitem(self, value, key: int, meta):
        """Select the ``key``-th entry of the producing node's spec."""
        meta["spec"] = value.node.meta["spec"][key]
        return super().call_getitem(value, key, meta)

    # pyre-ignore
    def call_cond(self, pred, true_fn, false_fn, inputs, meta):
        """Record the spec of a ``cond`` from its true branch's output value."""
        # true_fn/false_fn return tensors of the same shape, so we can pick
        # either one here. The last node of an FX graph is its output node.
        *_, true_out_node = true_fn.graph.nodes
        meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
        return super().call_cond(pred, true_fn, false_fn, inputs, meta)

    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Record the spec of a ``map``: each body output gains a leading dim
        equal to the size of dim 0 of the first mapped arg.
        """
        # Assumes all mapped args share the same leading dim; only the first
        # one is consulted.
        mapped_dim_size = mapped_args[0].data.size(0)
        *_, body_out_node = f.graph.nodes
        body_out_node_fake_tensor = body_out_node.meta["val"]
        map_fake_tensor = pytree.tree_map_only(
            torch.Tensor,
            lambda x: x.new_empty(mapped_dim_size, *x.shape),
            body_out_node_fake_tensor,
        )
        meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
        return super().call_map(f, mapped_args, operands, meta)

    # pyre-ignore
    def call_delegate(self, lowered_module, args, kwargs, meta):
        """Record the spec of a delegate call, computing it only if absent."""
        # If spec is missing, re-generate it by running the delegate on the
        # args' data. Only positional args are forwarded to the delegate.
        if "spec" not in meta:
            args_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, args)
            meta["spec"] = pytree.tree_map(
                make_spec,
                executorch_call_delegate(lowered_module, *args_data),
            )
        return super().call_delegate(lowered_module, args, kwargs, meta)

    # pyre-ignore
    def output(self, results, meta):
        """Record the spec of the graph outputs, reusing producer node specs."""

        # pyre-ignore
        def get_spec(x):
            if isinstance(x, ProxyValue):
                return x.node.meta["spec"]
            else:
                return make_spec(x)

        meta["spec"] = pytree.tree_map(get_spec, results)
        return super().output(results, meta)