xref: /aosp_15_r20/external/executorch/exir/passes/replace_sym_size_op_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8from typing import Dict
9
10import torch
11from executorch.exir.pass_base import PassBase, PassResult
12
# Lookup table used by ReplaceSymSizeOpPass: each key is an aten overload
# packet for a symbolic-shape query, and its value is the concrete overload
# that matching graph nodes get retargeted to.
replacements: Dict[torch._ops.OpOverloadPacket, torch._ops.OpOverload] = dict(
    (
        (torch.ops.aten.sym_size, torch.ops.aten.sym_size.int),
        (torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int),
        (torch.ops.aten.sym_numel, torch.ops.aten.sym_numel.default),
    )
)
18
19
class ReplaceSymSizeOpPass(PassBase):
    """
    Retarget calls made through symbolic-shape op overload packets onto a
    concrete overload, as listed in the module-level ``replacements`` table:

    - ``torch.ops.aten.sym_size``   -> ``torch.ops.aten.sym_size.int``
    - ``torch.ops.aten.sym_stride`` -> ``torch.ops.aten.sym_stride.int``
    - ``torch.ops.aten.sym_numel``  -> ``torch.ops.aten.sym_numel.default``

    The rewrite is applied to the root module and to every nested
    ``torch.fx.GraphModule`` reachable through ``modules()``.

    TODO: this can be refactored into a general OpReplacementPass.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Replace matching node targets in place.

        Args:
            graph_module: Root module. It and every nested
                ``torch.fx.GraphModule`` found via ``modules()`` are scanned.

        Returns:
            ``PassResult`` wrapping the same ``graph_module`` object (mutated
            in place). ``modified`` is True only if at least one node target
            was replaced.
        """
        modified = False
        for module in graph_module.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            module_modified = False
            for node in module.graph.nodes:
                # Only call_function nodes carry an op callable as their
                # target; other node kinds use string targets.
                if node.op != "call_function":
                    continue
                replacement = replacements.get(node.target)
                if replacement is None:
                    continue
                node.target = replacement
                module_modified = True
            if module_modified:
                # Regenerate this module's Python code so it reflects the new
                # targets; nested modules are not recompiled by callers that
                # only recompile the root.
                module.recompile()
                modified = True
        return PassResult(graph_module, modified)
36