# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import logging
from typing import Any, List, Tuple

import torch
from executorch.exir import memory

from executorch.exir.dialects._ops import ops
from executorch.exir.tensor import (
    contiguous_stride_from_shape,
    determine_tensor_dynanism,
    dim_order_from_stride,
    TensorShapeDynamism,
    TensorSpec,
)
from torch.fx.passes.infra.pass_base import PassBase, PassResult

logger: logging.Logger = logging.getLogger(__name__)


def _is_view_copy(node: torch.fx.Node) -> bool:
    return node.op == "call_function" and node.target in (
        torch.ops.aten.view_copy.default,
        ops.edge.aten.view_copy.default,
    )


_VIEW_OP = memory.view
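# memory.view is ExecuTorch's non-copying view op: during memory planning, a
# memory.view node is assigned the same storage as its base rather than a new
# allocation (see the docstring of ReplaceViewCopyWithViewPass.call below).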


class _Guard:
    def __init__(
        self, name: str, field_lambda, expected_val: Any  # pyre-ignore[2]
    ) -> None:
        self.name: str = name
        self.field_lambda = field_lambda  # pyre-ignore[4]
        self.expected_val = copy.deepcopy(expected_val)  # pyre-ignore[4]

    def __call__(self, view_spec) -> None:  # pyre-ignore[2]
        assert view_spec._unguarded_access
        observed_val = self.field_lambda(view_spec)
        if observed_val != self.expected_val:
            raise Exception(
                f"Guard {self.name} failed.  Expected to see value {self.expected_val}, but saw value {observed_val}."
            )
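
# Illustrative sketch (hypothetical values; comments only): a _Guard captures
# an expected value at construction and re-checks it each time it is run by
# _ViewSpec._run_guards, which sets _unguarded_access before invoking it:
#
#     g = _Guard("dtype", lambda spec: spec.dtype, torch.float32)
#     # g(view_spec) raises if view_spec.dtype has drifted from torch.float32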


class _ViewSpec(TensorSpec):
    def __init__(self, base: TensorSpec, shape: List[int]) -> None:
        """
        A _ViewSpec is a TensorSpec that shares non-size-related fields with its base.
        The size-related fields are: shape, stride, dim_order, and shape_dynamism.

        If either the base or view spec updates a non-size-related field, the change
        is reflected in both specs.  But size-related fields are not linked and can
        be set separately.

        A _ViewSpec can only be created from a non-sparse, strided TensorSpec.
        On creation, a _ViewSpec must be compatible with its base with respect to
        shape_dynamism, dtype, and nbytes.

        A _ViewSpec contains _guards that are evaluated on every __getattribute__ call.
        The purpose of the guards is to make sure the _ViewSpec is still compatible
        with its base.
        """
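
        # Illustrative sketch (hypothetical specs; comments only):
        #
        #     base = TensorSpec.from_tensor(torch.zeros(2, 6))
        #     view = _ViewSpec(base, [3, 4])  # same dtype/nbytes, new shape
        #     view.mem_id = 1                 # non-size field: routed to base
        #     assert base.mem_id == 1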

        # Explicitly put all attributes into _self_fields or _base_fields.
        # A non-private attribute that is in neither list resolves on self
        # and logs a warning.  If TensorSpec is extended with a new attribute,
        # we should explicitly decide how _ViewSpec will handle it.
        self._self_fields = [
            # We need to get the debug method from self
            # so that the object id it prints is correct.
            "debug",  # method
            "__repr__",  # method
            # The following are related to size and should use self
            "shape",
            "stride",
            "dim_order",
            "shape_dynamism",
            "nbytes",  # method
            "allocated_memory",  # property
            "is_dynamic_shape_tensor",  # property
            "is_static_shape_tensor",  # property
            "is_upper_bound_tensor",  # property
            "is_dynamic_unbound_tensor",  # property
        ]
        self._base_fields = [
            "scalar_type",
            "const",
            "alignment",
            "storage",
            "requires_grad",
            "layout",
            "is_sparse",
            "init_mem_planning_fields",  # method
            "realign",  # method
            "from_tensor",  # class method
            "lifetime",
            "mem_id",
            "mem_obj_id",
            "mem_offset",
            "dtype",  # property
        ]

        # Make sure _self_fields and _base_fields are disjoint
        assert len(set(self._self_fields) & set(self._base_fields)) == 0

        self._guards: List[_Guard] = []
        self._unguarded_access = False

        # Make sure base is not sparse and add a guard
        if base.is_sparse:
            raise Exception(
                "_ViewSpec can only be created from non-sparse TensorSpec, but base.is_sparse=True."
            )
        self._guards.append(
            _Guard(
                "is_sparse",
                lambda view_spec: view_spec.is_sparse,
                False,
            )
        )

        # Make sure base layout is strided and add a guard
        if base.layout != torch.strided:
            raise Exception(
                f"_ViewSpec can only be created from TensorSpec with layout={torch.strided}, but got layout={base.layout}."
            )
        self._guards.append(
            _Guard(
                "layout",
                lambda view_spec: view_spec.layout,
                torch.strided,
            )
        )

        self._base = base
        self.shape: List[int] = shape
        self.stride: Tuple[int, ...] = contiguous_stride_from_shape(torch.Size(self.shape))
        self.dim_order: Tuple[bytes, ...] = dim_order_from_stride(self.stride)
        self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(
            torch.Size(self.shape)
        )
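
        # For example, shape [2, 3, 4] yields contiguous strides (12, 4, 1)
        # and dim_order (0, 1, 2).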

        # Check compatibility with base on creation
        if self.shape_dynamism != base.shape_dynamism:
            raise Exception(
                f"_ViewSpec is incompatible with its base on creation.  It has shape_dynamism={self.shape_dynamism}, but its base has shape_dynamism={base.shape_dynamism}."
            )
        self._guards.append(
            _Guard(
                "shape_dynamism_init",
                lambda view_spec: view_spec.shape_dynamism,
                base.shape_dynamism,
            )
        )
        self._guards.append(
            _Guard(
                "shape_dynamism_eq_base",
                lambda view_spec: view_spec.shape_dynamism
                == view_spec._base.shape_dynamism,
                True,
            )
        )

        if self.dtype != base.dtype:
            raise Exception(
                f"_ViewSpec is incompatible with its base on creation.  It has dtype={self.dtype}, but its base has dtype={base.dtype}."
            )
        self._guards.append(
            _Guard("dtype", lambda view_spec: view_spec.dtype, base.dtype)
        )

        # We do not guard nbytes because dynamic symints are replaced by upper
        # bounds.  We do guard on rank, though.
        if self.nbytes() != base.nbytes():
            raise Exception(
                f"_ViewSpec is incompatible with its base on creation.  It has nbytes={self.nbytes()}, but its base has nbytes={base.nbytes()}."
            )
        self._guards.append(
            _Guard("rank", lambda view_spec: len(view_spec.shape), len(shape))
        )

    def _run_guards(self) -> None:
        unguarded_access = self._unguarded_access
        try:
            self._unguarded_access = True
            for g in self._guards:
                g(self)
        finally:
            self._unguarded_access = unguarded_access

    def __getattribute__(self, name: str):  # pyre-ignore
        # Special fields accessed directly so we don't recurse infinitely
        if name in [
            "_base",
            "_self_fields",
            "_base_fields",
            "_guards",
            "_unguarded_access",
            "_run_guards",
        ]:
            return object.__getattribute__(self, name)

        # Get some attributes from self
        if name in self._self_fields:
            val = object.__getattribute__(self, name)
        elif name in self._base_fields:
            val = object.__getattribute__(self._base, name)
        else:
            if len(name) > 0 and name[0] != "_":
                logger.warning(
                    f"Getting non-private attribute {name} on self, but it is not in _self_fields or _base_fields.  Is this intended?"
                )
            val = object.__getattribute__(self, name)

        if not self._unguarded_access:
            self._run_guards()
        return val
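
    # Attribute routing sketch (hypothetical `view` spec; comments only):
    #
    #     view.shape       -> resolved on the view itself (size-related)
    #     view.mem_offset  -> resolved on view._base (shared with the base)
    #
    # Unless _unguarded_access is set, each access also re-runs the guards.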

    def __setattr__(self, name: str, val) -> None:  # pyre-ignore
        # Special fields accessed directly so we don't recurse infinitely
        if name in [
            "_base",
            "_self_fields",
            "_base_fields",
            "_guards",
            "_unguarded_access",
            "_run_guards",
        ]:
            object.__setattr__(self, name, val)
            return

        if name in self._self_fields:
            object.__setattr__(self, name, val)
            return

        if name in self._base_fields:
            object.__setattr__(self._base, name, val)
            return

        if len(name) > 0 and name[0] != "_":
            logger.warning(
                f"Setting non-private attribute {name} on self, but it is not in _self_fields or _base_fields.  Is this intended?"
            )
        object.__setattr__(self, name, val)


class ReplaceViewCopyWithViewPass(PassBase):
    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        This pass replaces view_copy nodes with view nodes.

        This should be run after the NormalizeViewCopyBasePass.

        During memory planning, view nodes share the same storage as their base.
        """
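
        # Effect on a matching node (shapes illustrative):
        #
        #     before: y = aten.view_copy.default(x, [-1, 4])  # copies into fresh storage
        #     after:  y = memory.view(x, [-1, 4])             # aliases x's storage
        #
        # The node also gets a _ViewSpec over x's spec, as set up below.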

        n_replaced = 0
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                # Note: We only replace view_copy nodes whose result is not a
                # graph output, since the output pointer could be modified at
                # runtime (T187925929)
                if _is_view_copy(node) and all(u.op != "output" for u in node.users):
                    base, _ = node.args
                    node.target = _VIEW_OP

                    # Create spec for the node.
                    # _ViewSpec gives a view into its base spec for
                    # non-size-related information.

                    # The shape is not the same as node.args[1] because
                    # node.args[1] can contain an inferred size (-1).
                    shape = node.meta["val"].shape
                    node.meta["spec"] = _ViewSpec(base.meta["spec"], shape)

                    n_replaced += 1

            module.recompile()

        logger.debug(f"Replaced {n_replaced} view_copy nodes with {_VIEW_OP} nodes.")
        return PassResult(graph_module, n_replaced > 0)
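
    # Minimal usage sketch (hypothetical GraphModule `gm`; PassBase.__call__
    # runs requires(), call(), and ensures() in order):
    #
    #     result = ReplaceViewCopyWithViewPass()(gm)
    #     gm = result.graph_module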

    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                # Note: We only replace view_copy nodes whose result is not a
                # graph output, since the output pointer could be modified at
                # runtime (T187925929)
                assert not (
                    _is_view_copy(node) and all(u.op != "output" for u in node.users)
                )
                if node.op == "call_function" and node.target == _VIEW_OP:
                    assert isinstance(node.meta["spec"], _ViewSpec)

    def requires(self, graph_module: torch.fx.GraphModule) -> None:
        """
        This pass should be called after NormalizeViewCopyBasePass.
        We check that all view_copy nodes have been normalized.
        """
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                # Note: We only replace view_copy nodes whose result is not a
                # graph output, since the output pointer could be modified at
                # runtime (T187925929)
                if _is_view_copy(node) and all(u.op != "output" for u in node.users):
                    base, _ = node.args
                    assert not _is_view_copy(base)