xref: /aosp_15_r20/external/executorch/exir/graph_module.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9from types import FunctionType as function
10from typing import Dict, List, Tuple, Union
11
12import torch
13
14
# Union of the non-container ("leaf") value types that may appear in an
# exported graph, e.g. as constant arguments on nodes.
LeafValue = Union[
    torch.Tensor,
    str,
    int,
    float,
    bool,
    complex,
    torch.dtype,
    torch.device,
    torch.memory_format,
    torch.layout,
    None,
]

# We maintain a global cache of op lookups as this significantly speeds up
# deserialization because hasattr(torch.ops, name) is an expensive call.
# Keyed by a (str, str) pair — presumably (namespace, op name); verify
# against the code that populates it.
_cache_ops_dict: Dict[
    Tuple[str, str], Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]
] = {}
# Same key shape, but caching plain Python functions ("fake" ops per the
# name); `function` is types.FunctionType aliased in the imports above.
_cache_fake_ops_dict: Dict[Tuple[str, str], function] = {}
35
36
37def _get_submodule(
38    graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
39) -> Tuple[str, torch.nn.Module, torch.fx.Node]:
40    submod_node = node.args[arg_index]
41    assert isinstance(submod_node, torch.fx.Node)
42    assert submod_node.op == "get_attr"
43    assert isinstance(submod_node.target, str)
44    submodule = graph_module.get_submodule(submod_node.target)
45    # pyre-ignore
46    return submod_node.target, submodule, node
47
48
49def get_control_flow_submodules(
50    graph_module: torch.fx.GraphModule,
51) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
52    """
53    Returns a list of submodules used for control flow operations
54    (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
55    into submodules). Specifically, the returned value is a list containing a
56    tuple of (name of the submodule that's stored in the graph module, the
57    submodule itself, and the fx node that uses this submodule).
58    """
59    control_flow_submodules = []
60    for node in graph_module.graph.nodes:
61        if node.op != "call_function":
62            continue
63
64        if node.target is torch.ops.higher_order.cond:
65            control_flow_submodules.append(_get_submodule(graph_module, node, 1))
66            control_flow_submodules.append(_get_submodule(graph_module, node, 2))
67        if node.target is torch.ops.higher_order.map_impl:
68            control_flow_submodules.append(_get_submodule(graph_module, node, 0))
69
70    return control_flow_submodules
71