# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from typing import Dict

import torch
from executorch.exir.pass_base import PassBase, PassResult

# Maps each overload packet to the specific overload that should become the
# node's target instead.
replacements: Dict[torch._ops.OpOverloadPacket, torch._ops.OpOverload] = {
    torch.ops.aten.sym_size: torch.ops.aten.sym_size.int,
    torch.ops.aten.sym_stride: torch.ops.aten.sym_stride.int,
    torch.ops.aten.sym_numel: torch.ops.aten.sym_numel.default,
}


class ReplaceSymSizeOpPass(PassBase):
    """
    Replace overload-packet targets of symbolic shape ops with concrete overloads:

    - torch.ops.aten.sym_size   -> torch.ops.aten.sym_size.int
    - torch.ops.aten.sym_stride -> torch.ops.aten.sym_stride.int
    - torch.ops.aten.sym_numel  -> torch.ops.aten.sym_numel.default

    The rewrite is applied in place to the graph of ``graph_module`` and to the
    graph of every ``torch.fx.GraphModule`` nested inside it.

    TODO: this can be refactored into a general OpReplacementPass.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite matching node targets in place.

        Args:
            graph_module: The module to transform. Every GraphModule yielded by
                ``graph_module.modules()`` (including ``graph_module`` itself)
                has its nodes rewritten.

        Returns:
            A PassResult wrapping the same ``graph_module``. ``modified`` is True
            only when at least one node target was actually replaced, so pass
            managers can detect that the pass was a no-op.
        """
        modified = False
        for module in graph_module.modules():
            # Plain nn.Modules have no FX graph to rewrite.
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                replacement = replacements.get(node.target)
                if replacement is not None:
                    node.target = replacement
                    modified = True
        return PassResult(graph_module, modified)