# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.exir.pass_base import ExportPass, PassResult


class ReplaceInfBuffer(ExportPass):
    """
    Replace ``+inf`` / ``-inf`` entries in floating-point buffers with finite values.

    Due to a limitation in QNN, infinite values cannot be handled during
    quantization, so they are replaced with arbitrary finite sentinels.

    Args:
        pos_inf_value: Value substituted for ``+inf`` entries (default ``255``).
        neg_inf_value: Value substituted for ``-inf`` entries (default ``-255``).
    """

    def __init__(self, pos_inf_value: float = 255, neg_inf_value: float = -255):
        super().__init__()
        self.pos_inf_value = pos_inf_value
        self.neg_inf_value = neg_inf_value

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite infinite entries of every floating-point buffer in place.

        Args:
            graph_module: Module whose buffers (including those registered on
                nested submodules, as yielded by ``named_buffers()``) are scanned.

        Returns:
            A ``PassResult`` wrapping ``graph_module``; ``modified`` is True only
            if at least one infinite value was replaced.
        """
        modified = False
        # Buffers are mutated in place, so they need no re-registration.
        # Re-assigning with setattr(graph_module, "sub.buf", ...) would only
        # create a stray plain attribute on the root module for nested buffers.
        # no_grad keeps the in-place write legal even if a buffer requires grad.
        with torch.no_grad():
            for _, tensor in graph_module.named_buffers():
                if not tensor.is_floating_point():
                    continue
                # Compute both masks before writing so a configured
                # replacement value can never be re-matched by the second mask.
                pos_mask = torch.isposinf(tensor)
                neg_mask = torch.isneginf(tensor)
                if pos_mask.any():
                    tensor[pos_mask] = self.pos_inf_value
                    modified = True
                if neg_mask.any():
                    tensor[neg_mask] = self.neg_inf_value
                    modified = True

        graph_module.recompile()
        return PassResult(graph_module, modified)