1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-strict 8 9from types import FunctionType as function 10from typing import Dict, List, Tuple, Union 11 12import torch 13 14 15LeafValue = Union[ 16 torch.Tensor, 17 str, 18 int, 19 float, 20 bool, 21 complex, 22 torch.dtype, 23 torch.device, 24 torch.memory_format, 25 torch.layout, 26 None, 27] 28 29# We maintain a global cache of op lookups as this significantly speeds up 30# deserialization because hasattr(torch.ops, name) is an expensive call. 31_cache_ops_dict: Dict[ 32 Tuple[str, str], Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket] 33] = {} 34_cache_fake_ops_dict: Dict[Tuple[str, str], function] = {} 35 36 37def _get_submodule( 38 graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int 39) -> Tuple[str, torch.nn.Module, torch.fx.Node]: 40 submod_node = node.args[arg_index] 41 assert isinstance(submod_node, torch.fx.Node) 42 assert submod_node.op == "get_attr" 43 assert isinstance(submod_node.target, str) 44 submodule = graph_module.get_submodule(submod_node.target) 45 # pyre-ignore 46 return submod_node.target, submodule, node 47 48 49def get_control_flow_submodules( 50 graph_module: torch.fx.GraphModule, 51) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]: 52 """ 53 Returns a list of submodules used for control flow operations 54 (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look 55 into submodules). Specifically, the returned value is a list containing a 56 tuple of (name of the submodule that's stored in the graph module, the 57 submodule itself, and the fx node that uses this submodule). 58 """ 59 control_flow_submodules = [] 60 for node in graph_module.graph.nodes: 61 if node.op != "call_function": 62 continue 63 64 if node.target is torch.ops.higher_order.cond: 65 control_flow_submodules.append(_get_submodule(graph_module, node, 1)) 66 control_flow_submodules.append(_get_submodule(graph_module, node, 2)) 67 if node.target is torch.ops.higher_order.map_impl: 68 control_flow_submodules.append(_get_submodule(graph_module, node, 0)) 69 70 return control_flow_submodules 71