# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import functools
from typing import FrozenSet

import torch
from executorch.exir.dialects._ops import ops

from executorch.exir.pass_base import ExportPass, ProxyValue
from torch._subclasses.fake_tensor import FakeTensor
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
from torch.utils import _pytree as pytree

__all__ = [
    "quantized_decomposed_lib",
]


# Leaf types accepted as constant by `_is_const`. FakeTensor is a subclass of
# torch.Tensor, so `_is_const` must reject it before checking this tuple.
_CONST_LEAF_TYPES = (
    float,
    int,
    bool,
    str,
    torch.Tensor,
    torch.device,
    torch.dtype,
    torch.layout,
)


def _is_const(arg: object) -> bool:
    """Return True if ``arg`` is entirely made of concrete values.

    Tuples, lists and dict values are checked recursively. A FakeTensor
    anywhere makes the whole argument non-constant. Any leaf that is not one of
    ``_CONST_LEAF_TYPES`` (for example ``None``) is treated as non-constant.
    """
    if isinstance(arg, FakeTensor):
        return False
    if isinstance(arg, _CONST_LEAF_TYPES):
        return True
    if isinstance(arg, (tuple, list)):
        return all(_is_const(item) for item in arg)
    if isinstance(arg, dict):
        return all(_is_const(value) for value in arg.values())
    return False


@functools.lru_cache(maxsize=None)
def _quant_dequant_ops() -> FrozenSet[object]:
    """Return the quantize/dequantize op overloads that get special handling.

    The set includes both the ``torch.ops.quantized_decomposed`` overloads and
    their Edge-dialect (``ops.edge``) counterparts. It is built lazily on first
    use and then cached, so the op lookups are not repeated on every
    ``ConstPropPass.call_operator`` invocation.
    """
    return frozenset(
        (
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.quantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            ops.edge.quantized_decomposed.quantize_per_tensor.default,
            ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            ops.edge.quantized_decomposed.quantize_per_channel.default,
            ops.edge.quantized_decomposed.dequantize_per_channel.default,
        )
    )


class ConstPropPass(ExportPass):
    """
    Performs constant folding and constant propagation.

    An operator is evaluated eagerly during the pass when every arg and kwarg
    is a concrete value (see ``_is_const``). Its result is then returned
    instead of delegating to ``ExportPass.call_operator``.

    The ``propogate_quant`` flag selects which operators are eligible:

    * ``False`` (default): fold every operator except quantize/dequantize ops.
    * ``True``: fold only quantize/dequantize ops.
    """

    def __init__(self, propogate_quant: bool = False) -> None:
        super().__init__()
        # The misspelled name is kept because it is part of the public
        # constructor/attribute interface.
        self.propogate_quant = propogate_quant

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta):
        op_is_q_dq = op in _quant_dequant_ops()
        # XNOR: quantize/dequantize ops are eligible only when propogate_quant
        # is set, and all other ops are eligible only when it is not.
        eligible = op_is_q_dq == bool(self.propogate_quant)
        if not (eligible and _is_const([args, kwargs])):
            return super().call_operator(op, args, kwargs, meta)

        # Run the op eagerly while a _DisableTorchDispatch guard is alive. The
        # guard object is released by the `del` in `finally`.
        guard = torch._C._DisableTorchDispatch()  # noqa
        try:
            # Replace any ProxyValue leaves with their underlying `.data`.
            args_data, kwargs_data = pytree.tree_map_only(
                ProxyValue, lambda x: x.data, (args, kwargs)
            )
            result = op(*args_data, **kwargs_data)
        finally:
            del guard
        return result.to_tensor() if isinstance(result, ProxyValue) else result