# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.exir.pass_base import ExportPass, PassResult


class CastInt64ToInt32Pass(ExportPass):
    """Narrow int64 buffers of an exported program to int32.

    For every graph placeholder backed by an int64 buffer, the tensor stored
    in the exported program is replaced by an int32 copy and the node's
    fake-tensor metadata (``node.meta["val"]``) is updated to match. Other
    int64 nodes (user inputs, intermediate results) are left untouched,
    because this pass has no storage it could rewrite for them.
    """

    def __init__(self, exported_program: torch.export.ExportedProgram):
        super(CastInt64ToInt32Pass, self).__init__()
        # Program whose buffer storage is rewritten in place by this pass.
        self.exported_program = exported_program

    def _buffer_storage(self, buffer_name: str):
        """Return the dict on ``self.exported_program`` holding ``buffer_name``.

        Persistent buffers live in ``state_dict``. Non-persistent buffers are
        not part of ``state_dict``; torch.export keeps them in ``constants``.

        Raises:
            KeyError: if neither dict contains ``buffer_name``.
        """
        state_dict = self.exported_program.state_dict
        if buffer_name in state_dict:
            return state_dict
        constants = getattr(self.exported_program, "constants", None)
        if constants is not None and buffer_name in constants:
            return constants
        raise KeyError(buffer_name)

    def _to_int32(self, graph_module: torch.fx.GraphModule) -> None:
        """Cast every int64 buffer input of ``graph_module`` to int32 in place.

        Both the stored buffer tensor and the placeholder's fake-tensor
        metadata are converted, so they stay consistent with each other.

        Raises:
            ValueError: if a buffer holds a value outside the int32 range;
                casting it would silently wrap around and corrupt the data.
        """
        inputs_to_buffers = self.exported_program.graph_signature.inputs_to_buffers
        int32_info = torch.iinfo(torch.int32)
        for node in graph_module.graph.nodes:
            # Not every node carries a "val" entry; avoid a KeyError on those.
            fake_tensor = node.meta.get("val")
            if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
                continue
            if fake_tensor.dtype != torch.int64:
                continue
            # Only buffer-backed placeholders have storage this pass can
            # rewrite. Other int64 nodes (user inputs, op outputs) are skipped
            # rather than raising KeyError with half-updated metadata.
            buffer_name = inputs_to_buffers.get(node.name)
            if buffer_name is None:
                continue
            storage = self._buffer_storage(buffer_name)
            buffer = storage[buffer_name]
            if buffer.numel() > 0:
                lo, hi = int(buffer.min()), int(buffer.max())
                if lo < int32_info.min or hi > int32_info.max:
                    raise ValueError(
                        f"Buffer '{buffer_name}' holds values in [{lo}, {hi}], "
                        "which do not fit in int32; refusing to cast."
                    )
            storage[buffer_name] = buffer.to(torch.int32)
            # Update the metadata only after the real storage was converted.
            node.meta["val"] = fake_tensor.to(torch.int32)

    def call(self, graph_module: torch.fx.GraphModule):
        """Run the int64 -> int32 buffer narrowing and retrace the module.

        Returns:
            A ``PassResult`` wrapping the retraced graph module, always
            flagged as modified.
        """
        self._to_int32(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)