xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/replace_inf_buffer.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import torch
7from executorch.exir.pass_base import ExportPass, PassResult
8
9
class ReplaceInfBuffer(ExportPass):
    """
    Replace infinite entries of floating-point buffers with finite values.

    Due to a limitation in QNN, inf and -inf have to be changed to an
    arbitrary finite value for quantization.

    Args:
        pos_inf_value: Value written over every +inf entry (default 255).
        neg_inf_value: Value written over every -inf entry (default -255).
    """

    def __init__(self, pos_inf_value: float = 255, neg_inf_value: float = -255):
        super().__init__()
        self.pos_inf_value = pos_inf_value
        self.neg_inf_value = neg_inf_value

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Overwrite +inf / -inf in every floating-point buffer of `graph_module`.

        Buffers are edited in place; non-floating-point buffers are skipped.

        Args:
            graph_module: Module whose buffers (as yielded by
                `named_buffers()`) are scanned.

        Returns:
            A `PassResult` wrapping the same `graph_module`. It is always
            flagged as modified, whether or not any infinite value was found.
        """
        for _buf_name, tensor in graph_module.named_buffers():
            if not tensor.is_floating_point():
                continue
            # Indexed assignment writes into the buffer's own storage, so the
            # owning module sees the new values without being re-assigned.
            tensor[tensor == float("inf")] = self.pos_inf_value
            tensor[tensor == float("-inf")] = self.neg_inf_value
            # NOTE: the buffer is deliberately not re-set by name on
            # `graph_module`. `named_buffers()` reports buffers owned by
            # submodules under dotted names, and a dotted name passed to
            # setattr on the root module does not address the nested buffer.

        # Kept from the original implementation; this pass edits buffer data
        # only and does not touch the nodes of the graph.
        graph_module.recompile()
        return PassResult(graph_module, True)
27