# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import operator

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class ReduceDynamicRange(ExportPass):
    """
    Rewrite ``torch.finfo(torch.float32).min`` operands of binary ops.

    Due to a limitation in QNN, a binary-op argument equal to
    ``torch.finfo(torch.float32).min`` is changed to ``-255.0``, a value
    intended to be representable after quantization.
    """

    # Source callables / source names whose partitions are inspected.
    binary_op_sources = [
        operator.add,
        operator.sub,
        operator.mul,
        operator.truediv,
        torch.add,
        torch.sub,
        torch.mul,
        torch.div,
        "add",
        "sub",
        "mul",
        "truediv",
    ]

    # Value searched for in binary-op args (the most negative finite float32).
    # Computed once here instead of once per argument inside the loop.
    _FLOAT32_MIN = torch.finfo(torch.float32).min
    # Replacement value written in place of _FLOAT32_MIN.
    _REDUCED_MIN = -255.0

    def __init__(self):
        super().__init__()

    @classmethod
    def _is_float32_min(cls, arg) -> bool:
        """Return True if ``arg`` is a Python number equal to float32 min.

        Only plain ``int``/``float`` values are compared, so ``==`` is never
        invoked on graph Nodes or tensor-like objects (a tensor ``==`` would
        yield a tensor whose truth value may be ambiguous).
        """
        return isinstance(arg, (int, float)) and arg == cls._FLOAT32_MIN

    def _traverse_binary_node(self, graph_module: torch.fx.GraphModule) -> None:
        """Replace float32-min args of single-input binary-op partitions in place.

        Args:
            graph_module: Graph module whose graph is mutated in place.
        """
        partitions_by_source = get_source_partitions(
            graph_module.graph, self.binary_op_sources
        )
        for src_partition in itertools.chain.from_iterable(
            partitions_by_source.values()
        ):
            # Only partitions with exactly one graph input are handled; the
            # other operand is presumably a constant stored directly in args.
            if len(src_partition.input_nodes) != 1:
                continue
            # NOTE(review): assumes the first node of the partition is the
            # binary op itself — confirm for ops that lower to several nodes.
            binary_node = src_partition.nodes[0]
            # (input node 0, constant value)
            binary_node.args = tuple(
                self._REDUCED_MIN if self._is_float32_min(arg) else arg
                for arg in binary_node.args
            )

    def call(self, graph_module: torch.fx.GraphModule):
        """Run the rewrite, recompile, and drop dead code.

        Args:
            graph_module: Graph module to transform in place.

        Returns:
            ``PassResult`` wrapping ``graph_module``, always flagged as modified.
        """
        self._traverse_binary_node(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)