# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from types import FunctionType as function
from typing import Dict, List, Tuple, Union

import torch


# Type alias: a Union of Tensor, Python scalars (str/int/float/bool/complex),
# the torch dtype/device/memory_format/layout types, and None. Because None is
# a member, the alias is implicitly Optional.
# NOTE(review): presumably these are the non-container values allowed at the
# leaves of (possibly nested) node arguments; confirm against callers outside
# this chunk.
LeafValue = Union[
    torch.Tensor,
    str,
    int,
    float,
    bool,
    complex,
    torch.dtype,
    torch.device,
    torch.memory_format,
    torch.layout,
    None,
]

# We maintain a global cache of op lookups as this significantly speeds up
# deserialization because hasattr(torch.ops, name) is an expensive call.
_cache_ops_dict: Dict[
    Tuple[str, str], Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]
] = {}
# Companion cache with the same (str, str) key shape, mapping to plain Python
# functions. NOTE(review): presumably fake/meta implementations; it is
# populated outside this chunk.
_cache_fake_ops_dict: Dict[Tuple[str, str], function] = {}


def _get_submodule(
    graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
) -> Tuple[str, torch.nn.Module, torch.fx.Node]:
    """
    Resolve the submodule that ``node`` references at ``node.args[arg_index]``.

    The referenced argument must be a ``get_attr`` node whose string target is
    the attribute path of the submodule inside ``graph_module``.

    Returns:
        A ``(attribute path, submodule, node)`` triple.

    Raises:
        AssertionError: if the argument is not a ``get_attr`` fx node with a
            string target (only while assertions are enabled).
    """
    attr_node = node.args[arg_index]
    assert isinstance(attr_node, torch.fx.Node)
    assert attr_node.op == "get_attr"
    # Keep the target in a local: the isinstance refinement then survives the
    # get_submodule call for the type checker, so no ignore is required.
    attr_path = attr_node.target
    assert isinstance(attr_path, str)
    return attr_path, graph_module.get_submodule(attr_path), node


def get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
    """
    Collect the submodules used by control-flow higher-order ops that appear
    in the top-level graph of ``graph_module`` (nested submodules are not
    searched).

    Only ``call_function`` nodes are inspected. Two operators are recognised:

    * ``torch.ops.higher_order.cond``: the submodules referenced by positional
      args 1 and 2 are collected, in that order.
    * ``torch.ops.higher_order.map_impl``: the submodule referenced by
      positional arg 0 is collected.

    Returns:
        A list of ``(name of the submodule stored in graph_module, the
        submodule itself, the fx node that uses it)`` tuples, in graph order.
    """
    collected = []
    for fx_node in graph_module.graph.nodes:
        if fx_node.op != "call_function":
            continue
        if fx_node.target is torch.ops.higher_order.cond:
            collected.extend(
                _get_submodule(graph_module, fx_node, idx) for idx in (1, 2)
            )
        if fx_node.target is torch.ops.higher_order.map_impl:
            collected.append(_get_submodule(graph_module, fx_node, 0))
    return collected