xref: /aosp_15_r20/external/executorch/exir/passes/const_prop_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import torch
10from executorch.exir.dialects._ops import ops
11
12from executorch.exir.pass_base import ExportPass, ProxyValue
13from torch._subclasses.fake_tensor import FakeTensor
14from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
15from torch.utils import _pytree as pytree
16
# Public names of this module. Only the quantized_decomposed library handle
# imported above is listed; presumably this also keeps that otherwise-unused
# import from being flagged by linters (TODO confirm).
__all__ = ["quantized_decomposed_lib"]
20
21
class ConstPropPass(ExportPass):
    """
    Constant folding / constant propagation pass.

    Operator calls whose arguments are all constants are executed eagerly and
    their result is returned directly; every other call is forwarded to the
    base ``ExportPass`` implementation.

    The ``propogate_quant`` flag (spelling kept for caller compatibility)
    selects which operators are eligible:
      * ``False`` (default): every operator except the quantize/dequantize
        ops listed in ``call_operator``.
      * ``True``: only those quantize/dequantize ops.
    """

    def __init__(self, propogate_quant: bool = False) -> None:
        super().__init__()
        # Read by call_operator to decide which operators may be folded.
        self.propogate_quant = propogate_quant

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta):
        """
        Fold ``op`` if it is eligible and all of its arguments are constant.

        Returns the eagerly computed result (converted with ``to_tensor()``
        when it is a ``ProxyValue``), or whatever the base class returns when
        the call is not folded.
        """
        # Leaf types that count as compile-time constants.
        constant_leaf_types = (
            float,
            int,
            bool,
            str,
            torch.Tensor,
            torch.device,
            torch.dtype,
            torch.layout,
        )

        # pyre-ignore
        def is_const(arg) -> bool:
            # FakeTensor is a torch.Tensor subclass, so it has to be rejected
            # before the leaf-type check below would accept it.
            if isinstance(arg, FakeTensor):
                return False
            if isinstance(arg, constant_leaf_types):
                return True
            # Containers are constant only if every element/value is.
            if isinstance(arg, (tuple, list)):
                children = arg
            elif isinstance(arg, dict):
                children = arg.values()
            else:
                return False
            return all(is_const(child) for child in children)

        quantized = torch.ops.quantized_decomposed
        edge_quantized = ops.edge.quantized_decomposed
        dequant_quant_ops = {
            quantized.quantize_per_tensor.default,
            quantized.dequantize_per_tensor.default,
            quantized.quantize_per_channel.default,
            quantized.dequantize_per_channel.default,
            edge_quantized.quantize_per_tensor.default,
            edge_quantized.dequantize_per_tensor.default,
            edge_quantized.quantize_per_channel.default,
            edge_quantized.dequantize_per_channel.default,
        }
        op_is_q_dq = op in dequant_quant_ops

        # An op is eligible exactly when "is a quant/dequant op" agrees with
        # the propogate_quant setting: quant ops only when the flag is set,
        # everything else only when it is not.
        eligible = op_is_q_dq == bool(self.propogate_quant)
        if not (eligible and is_const((args, kwargs))):
            return super().call_operator(op, args, kwargs, meta)

        # pyre-ignore
        def unwrap(proxy):
            return proxy.data

        # Hold the dispatch-disabling guard for the duration of the eager
        # call; dropping it in `finally` releases it even if the op raises.
        guard = torch._C._DisableTorchDispatch()  # noqa
        try:
            unwrapped_args, unwrapped_kwargs = pytree.tree_map_only(
                ProxyValue, unwrap, (args, kwargs)
            )
            result = op(*unwrapped_args, **unwrapped_kwargs)
        finally:
            del guard

        if isinstance(result, ProxyValue):
            return result.to_tensor()
        return result
85