# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging

import torch
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.dim_order_utils import get_dim_order, get_memory_format
from executorch.exir.pass_base import ExportPass, ProxyValue
from executorch.exir.passes.dim_order_ops_registry import (
    DimOrderOpsMap,
    MemoryFormatOpsMap,
)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)

# TODO - these passes are too specialized on a single to_copy op.
# We should be able to replace (or revert) any of the dim_order ops in the future.


class MemoryFormatOpsPass(ExportPass):
    """
    This pass replaces ops that take torch.memory_format as an argument with
    'equivalent' ops that take dim_order. This works toward the larger
    ExecuTorch goal of moving away from torch.memory_format. There is a 1:1
    mapping between the aten op and the new edge dialect dim_order op.
    """

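    # An illustrative before/after (a sketch; the exact op names come from
    # DimOrderOpsMap). For a rank-4 input, a call such as
    #   aten._to_copy(x, memory_format=torch.channels_last)
    # is rewritten to the edge dialect equivalent
    #   dim_order_ops._to_dim_order_copy(x, dim_order=[0, 2, 3, 1])
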
    def call_operator(self, op, args, kwargs, meta):
        if not (isinstance(op, EdgeOpOverload) and op.__name__ in DimOrderOpsMap):
            return super().call_operator(
                op,
                args,
                kwargs,
                meta,
            )
        # new kwargs with dim_order, and no memory_format for the new op
        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable

        # get the "to" memory format for the EdgeOp
        mem_format = nkwargs.pop("memory_format", torch.contiguous_format)

        # can always get the shape, assuming rank is specialized
        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
            ndim = args[0].to_tensor().dim()
        elif isinstance(args[0], torch.Tensor):
            ndim = args[0].dim()
        else:
            assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"

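        # get_dim_order maps a memory format to an explicit dim permutation,
        # e.g. for ndim=4, torch.contiguous_format gives [0, 1, 2, 3] and
        # torch.channels_last gives [0, 2, 3, 1].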
        nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
        logger.debug(
            f"_to_copy = rank: {ndim}, memory_format: {mem_format}."
            f" _to_dim_order_copy = dim_order: {nkwargs['dim_order']}"
        )

        t = DimOrderOpsMap.get(op.__name__, None)
        assert t is not None, f"{op.__name__} not found in DimOrderOpsMap"

        return super().call_operator(
            t,
            args,
            nkwargs,
            meta,
        )


class DimOrderOpsRevertPass(ExportPass):
    """
    This pass reverts the dim_order ops back to the memory format ops.
    """

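    # An illustrative before/after (a sketch mirroring MemoryFormatOpsPass):
    #   dim_order_ops._to_dim_order_copy(x, dim_order=[0, 2, 3, 1])
    # is rewritten back to
    #   aten._to_copy(x, memory_format=torch.channels_last)
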
    def call_operator(self, op, args, kwargs, meta):
        if not (isinstance(op, EdgeOpOverload) and op.__name__ in MemoryFormatOpsMap):
            return super().call_operator(
                op,
                args,
                kwargs,
                meta,
            )

        # new kwargs with memory_format, and no dim_order for the new op
        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable

        # can always get the shape, assuming rank is specialized
        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
            ndim = args[0].to_tensor().dim()
        elif isinstance(args[0], torch.Tensor):
            ndim = args[0].dim()
        else:
            assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"

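        # get_memory_format is the inverse of get_dim_order, e.g. [0, 1, 2, 3]
        # recovers torch.contiguous_format and [0, 2, 3, 1] recovers
        # torch.channels_last (rank-4 illustration).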
        # get the "to" dim_order for the EdgeOp; a missing dim_order means the
        # default (contiguous) dim order
        default_dim_order = list(range(ndim))
        dim_order = nkwargs.pop("dim_order", default_dim_order)

        nkwargs["memory_format"] = get_memory_format(dim_order)

        logger.debug(
            f"_to_dim_order_copy = dim_order: {dim_order}."
            f" _to_copy = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
        )

        t = MemoryFormatOpsMap.get(op.__name__, None)
        assert t is not None, f"{op.__name__} not found in MemoryFormatOpsMap"

        return super().call_operator(
            t,
            args,
            nkwargs,
            meta,
        )
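

# A minimal usage sketch (illustration only, not part of this module). ExportPass
# instances are callable on a torch.fx GraphModule and return a PassResult; the
# names `exported_program` and `gm` below are hypothetical placeholders.
#
#   gm = exported_program.graph_module
#   gm = MemoryFormatOpsPass()(gm).graph_module    # memory_format -> dim_order
#   gm = DimOrderOpsRevertPass()(gm).graph_module  # dim_order -> memory_format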