xref: /aosp_15_r20/external/executorch/exir/graph_module.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerfrom types import FunctionType as function
10*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, List, Tuple, Union
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport torch
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Worker
# Type alias: the union of torch and plain-Python value types enumerated
# below. ``None`` is a member, so the alias also admits an absent value.
LeafValue = Union[
    (
        # Tensors and built-in Python scalar/string types.
        torch.Tensor,
        str,
        int,
        float,
        bool,
        complex,
        # torch tensor-metadata types.
        torch.dtype,
        torch.device,
        torch.memory_format,
        torch.layout,
        # No value.
        None,
    )
]
28*523fa7a6SAndroid Build Coastguard Worker
# Process-wide memo tables for operator lookups, keyed by a pair of strings.
# Probing ``torch.ops`` (e.g. ``hasattr(torch.ops, name)``) is expensive, so
# caching the results noticeably speeds up deserialization.
_cache_ops_dict: Dict[
    Tuple[str, str],
    Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket],
] = dict()
# Separate table with the same key type whose values are plain Python
# functions rather than torch operator objects.
_cache_fake_ops_dict: Dict[Tuple[str, str], function] = dict()
35*523fa7a6SAndroid Build Coastguard Worker
36*523fa7a6SAndroid Build Coastguard Worker
def _get_submodule(
    graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
) -> Tuple[str, torch.nn.Module, torch.fx.Node]:
    """
    Resolve the submodule that ``node`` refers to through one positional arg.

    ``node.args[arg_index]`` must be an fx ``get_attr`` node whose target is
    a string; that string is looked up on ``graph_module`` with
    ``get_submodule``.

    Returns:
        ``(target string, resolved submodule, node)`` -- ``node`` is passed
        back unchanged so callers can keep it alongside the submodule.

    Raises:
        AssertionError: if the argument is not a ``get_attr`` ``Node`` with a
            ``str`` target. These checks are ``assert`` statements, so they
            are skipped when Python runs with ``-O``.
    """
    attr_node = node.args[arg_index]
    assert isinstance(attr_node, torch.fx.Node)
    assert attr_node.op == "get_attr"
    qualified_name = attr_node.target
    assert isinstance(qualified_name, str)
    # pyre-ignore
    return qualified_name, graph_module.get_submodule(qualified_name), node
47*523fa7a6SAndroid Build Coastguard Worker
48*523fa7a6SAndroid Build Coastguard Worker
def get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
    """
    Collect the submodules referenced by control-flow higher-order ops.

    Only the top-level graph of ``graph_module`` is scanned; submodules are
    not searched recursively. Two ops are recognised:

    * ``torch.ops.higher_order.cond`` -- submodules come from positional
      args 1 and 2 (in that order);
    * ``torch.ops.higher_order.map_impl`` -- the submodule comes from
      positional arg 0.

    Returns:
        A list of ``(attribute name on graph_module, submodule, using node)``
        tuples, in graph-node order.
    """
    found = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            target = node.target
            # Pick which positional args hold submodule references; the two
            # operators are distinct objects, so at most one branch matches.
            if target is torch.ops.higher_order.cond:
                submodule_arg_indices = (1, 2)
            elif target is torch.ops.higher_order.map_impl:
                submodule_arg_indices = (0,)
            else:
                submodule_arg_indices = ()
            found.extend(
                _get_submodule(graph_module, node, idx)
                for idx in submodule_arg_indices
            )
    return found
71