# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List, Optional

import torch
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from executorch.exir.tensor import TensorSpec
from torch.export.exported_program import ExportGraphSignature
from torch.fx.node import Node
from torch.utils import _pytree as pytree


# pyre-ignore
def make_spec(x):
    if isinstance(x, torch.Tensor):
        return TensorSpec.from_tensor(x)
    elif isinstance(x, (int, bool, float)):
        return x
    else:
        return None

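# Example behavior of make_spec above (a quick sketch; values are illustrative):
#
#   make_spec(torch.ones(2, 3))  # -> TensorSpec with shape [2, 3]
#   make_spec(7)                 # -> 7 (int/bool/float pass through as-is)
#   make_spec("other")           # -> None (anything else has no spec)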

def _is_mutable_buffer(
    node: Node, graph_signature: Optional[ExportGraphSignature] = None
) -> bool:
    """
    Check whether the node is a mutable buffer according to the provided
    graph signature.
    """
    # The graph signature is None for memory planning passes that are not called
    # from EdgeProgramManager; those paths are deprecated, so mutable buffers are
    # not supported on them.
    if graph_signature is None:
        return False
    if node.op == "placeholder":
        if isinstance(node.target, str):
            if node.target in graph_signature.inputs_to_buffers:
                fqn = graph_signature.inputs_to_buffers[node.target]
                # If the buffer's fully qualified name appears among the
                # mutated buffers, the node is a mutable buffer.
                if fqn in graph_signature.buffers_to_mutate.values():
                    return True
    return False

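# Sketch of the lookups in _is_mutable_buffer above (names are hypothetical):
#
#   inputs_to_buffers = {"b_state": "state"}  # placeholder name -> buffer FQN
#   buffers_to_mutate = {"add": "state"}      # output node name -> mutated FQN
#
# The placeholder "b_state" resolves to FQN "state", which appears among the
# mutated FQNs, so _is_mutable_buffer returns True for that node.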

class SpecPropPass(ExportPass):
    def __init__(self) -> None:
        super().__init__()

    def on_attr(self, attr: ProxyValue) -> None:
        attr.node.meta["spec"] = pytree.tree_map_only(
            torch.Tensor,
            make_spec,
            attr.data,
        )

    def update_placeholder_tensor_specs(
        self,
        exported_program: torch.export.ExportedProgram,
        graph_module: torch.fx.GraphModule,
    ) -> None:
        """
        Update the tensor specs for all placeholder nodes such that
        placeholders that are parameters, non-mutable buffers, or lifted
        tensor constants are marked as constant.
        """
        for node in graph_module.graph.nodes:
            if node.op != "placeholder":
                continue
            if "spec" not in node.meta:
                raise RuntimeError(f"Placeholder node {node} missing meta['spec']")
            spec = node.meta["spec"]
            if isinstance(node.target, str) and (
                node.target in exported_program.graph_signature.inputs_to_parameters
                or (
                    node.target in exported_program.graph_signature.inputs_to_buffers
                    and not _is_mutable_buffer(node, exported_program.graph_signature)
                )
                or node.target
                in exported_program.graph_signature.inputs_to_lifted_tensor_constants
            ):
                spec.const = True

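    # Sketch of the effect of update_placeholder_tensor_specs above
    # (placeholder names are hypothetical): with
    # inputs_to_parameters = {"p_weight": "weight"}, the placeholder
    # "p_weight" gets spec.const = True, while a user input placeholder
    # such as "x" is left non-constant.
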
    # pyre-ignore
    def placeholder(self, name: str, arg, meta):
        meta["spec"] = make_spec(arg)
        return super().placeholder(name, arg, meta)

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta):
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
        return super().call_operator(op, args, kwargs, meta)

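    # Sketch of call_operator above (shapes are illustrative): for a node
    # running torch.ops.aten.add.Tensor on two [2, 3] tensors, the op is
    # re-executed on the unwrapped fake data and meta["spec"] becomes a
    # TensorSpec with shape [2, 3].
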
    # pyre-ignore
    def call_getitem(self, value, key: int, meta):
        meta["spec"] = value.node.meta["spec"][key]
        return super().call_getitem(value, key, meta)

    # pyre-ignore
    def call_cond(self, pred, true_fn, false_fn, inputs, meta):
        # true_fn/false_fn return tensors of the same shape, so we can pick
        # either one here.
        *_, true_out_node = true_fn.graph.nodes
        meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
        return super().call_cond(pred, true_fn, false_fn, inputs, meta)

    def call_map(
        self,
        f: torch.fx.GraphModule,
        mapped_args: List[ProxyValue],
        operands: List[ProxyValue],
        meta: NodeMetadata,
    ) -> ProxyValue:
        mapped_dim_size = mapped_args[0].data.size(0)
        *_, body_out_node = f.graph.nodes
        body_out_node_fake_tensor = body_out_node.meta["val"]
        map_fake_tensor = pytree.tree_map_only(
            torch.Tensor,
            lambda x: x.new_empty(mapped_dim_size, *x.shape),
            body_out_node_fake_tensor,
        )
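        # Worked example (shapes are illustrative): if mapped_args[0] has
        # shape [8, 4], the mapped dim size is 8; if the map body returns a
        # [4, 2] tensor, the stacked fake tensor built above has shape
        # [8, 4, 2].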
        meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
        return super().call_map(f, mapped_args, operands, meta)

    # pyre-ignore
    def call_delegate(self, lowered_module, args, kwargs, meta):
        args_data, kwargs_data = pytree.tree_map_only(
            ProxyValue, lambda x: x.data, (args, kwargs)
        )
        # If the spec is missing, regenerate it from the args' data.
        if "spec" not in meta:
            meta["spec"] = pytree.tree_map(
                make_spec,
                executorch_call_delegate(lowered_module, *args_data),
            )
        return super().call_delegate(lowered_module, args, kwargs, meta)

    # pyre-ignore
    def output(self, results, meta):
        # pyre-ignore
        def get_spec(x):
            if isinstance(x, ProxyValue):
                return x.node.meta["spec"]
            else:
                return make_spec(x)

        meta["spec"] = pytree.tree_map(get_spec, results)
        return super().output(results, meta)
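
# Usage sketch (not part of this file's API; assumes a program captured with
# torch.export, and `MyModule` is a stand-in for any exportable nn.Module):
#
#   ep = torch.export.export(MyModule(), (torch.randn(2, 3),))
#   result = SpecPropPass()(ep.graph_module)
#   gm = result.graph_module  # nodes now carry TensorSpecs in meta["spec"]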