xref: /aosp_15_r20/external/executorch/backends/arm/_passes/cast_int64_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8import torch
9from executorch.exir.pass_base import ExportPass, PassResult
10
11
class CastInt64ToInt32Pass(ExportPass):
    """Cast the int64 buffers of an exported program to int32.

    For every graph node that is a buffer input of ``exported_program`` and
    whose fake tensor is int64, the stored buffer tensor and the node's
    ``meta["val"]`` are both converted to ``torch.int32``.

    Args:
        exported_program: The program whose buffers are rewritten in place.
    """

    def __init__(self, exported_program: torch.export.ExportedProgram):
        super().__init__()
        self.exported_program = exported_program

    def _to_int32(self, graph_module: torch.fx.GraphModule):
        """Convert int64 buffer nodes (metadata and stored tensor) to int32.

        Nodes that are not buffer inputs are left untouched: they have no
        backing tensor to cast, and looking them up in the buffer mapping
        would raise ``KeyError``.
        """
        # Read the mapping once; it does not change while this loop runs.
        inputs_to_buffers = self.exported_program.graph_signature.inputs_to_buffers
        state_dict = self.exported_program.state_dict
        # NOTE(review): non-persistent buffers are expected to live in
        # ``ExportedProgram.constants`` rather than ``state_dict`` -- confirm
        # against the torch version in use.
        constants = getattr(self.exported_program, "constants", {})

        for node in graph_module.graph.nodes:
            # Not every node is guaranteed to carry a "val" entry.
            fake_tensor = node.meta.get("val")
            if not isinstance(
                fake_tensor, torch._subclasses.fake_tensor.FakeTensor
            ):
                continue
            if fake_tensor.dtype != torch.int64:
                continue

            buffer_name = inputs_to_buffers.get(node.name)
            if buffer_name is None:
                # int64 value that is not a buffer (user input, op output...).
                continue

            if buffer_name in state_dict:
                storage = state_dict
            elif buffer_name in constants:
                storage = constants
            else:
                # No backing tensor found; keep the metadata consistent with
                # the (unchanged) stored data instead of half-converting.
                continue

            # Cast the stored tensor first so a failure here cannot leave the
            # node metadata claiming int32 while the buffer is still int64.
            storage[buffer_name] = storage[buffer_name].to(torch.int32)
            node.meta["val"] = fake_tensor.to(torch.int32)

    def call(self, graph_module: torch.fx.GraphModule):
        """Run the cast, then re-run the base pass over the updated graph.

        Returns:
            A ``PassResult`` holding the graph module produced by the base
            ``ExportPass.call``; it is always reported as modified.
        """
        self._to_int32(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
38