# mypy: allow-untyped-defs
import operator

import torch
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree as pytree


class CudaGraphsSupport(OperatorSupport):
    # TODO: why is submodules passed here
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        if node.op not in CALLABLE_NODE_OPS:
            return False

        # Explicitly disallowed ops
        if node.target in [torch.ops.aten.embedding_dense_backward.default]:
            return False

        # getitem just extracts an element from a multi-output node, so it is
        # always treated as supported
        if node.target in [operator.getitem]:
            return True

        found_not_cuda = False

        def meta_fk(meta):
            # Prefer the "val" entry written by FakeTensorProp; fall back to
            # "fake_result"
            return meta["val"] if "val" in meta else meta["fake_result"]

        def find_not_cuda(t):
            nonlocal found_not_cuda
            if isinstance(t, torch.Tensor) and t.device.type != 'cuda':
                found_not_cuda = True

        for n in node.all_input_nodes:
            pytree.tree_map_(find_not_cuda, meta_fk(n.meta))

        pytree.tree_map_(find_not_cuda, meta_fk(node.meta))

        # NB: factory functions are accounted for because the result would be
        # a cpu or cuda tensor

        return not found_not_cuda


def partition_cudagraphs(gm, inputs):
    """
    Partition an FX graph into sub-GraphModules that can be validly run under
    CUDA graphs.  For a subgraph to be runnable under CUDA graphs, all of the
    operations must involve CUDA tensors only.
    """
    FakeTensorProp(gm).propagate(*inputs)
    supported_ops = CudaGraphsSupport()
    # TODO: single node partition may be wrong due to the pessimization
    # from copying in and out the data.  Check in benchmarks, perhaps
    partitioner = CapabilityBasedPartitioner(
        gm, supported_ops, allows_single_node_partition=True
    )
    partitions = partitioner.propose_partitions()
    fused_graph = partitioner.fuse_partitions(partitions)
    return fused_graph
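

# Example usage (a minimal sketch, not part of the module's API): trace a toy
# module, then let partition_cudagraphs group its CUDA-only ops into fused
# submodules.  `_ToyModule` and the shapes below are hypothetical, and the
# example needs a CUDA device to run.
if __name__ == "__main__":
    if torch.cuda.is_available():

        class _ToyModule(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x) * 2 + 1

        gm = torch.fx.symbolic_trace(_ToyModule())
        example_inputs = [torch.randn(8, 8, device="cuda")]
        fused = partition_cudagraphs(gm, example_inputs)
        # Supported nodes end up in `fused_*` call_module nodes; each such
        # submodule is a candidate for CUDA graph capture and replay.
        print(fused.graph)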
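
# Follow-up sketch (assumptions: static input shapes, and warmup already done
# on a side stream as the CUDA graphs docs recommend): a fused submodule can
# then be captured once with torch.cuda.CUDAGraph and replayed cheaply, e.g.
#
#     g = torch.cuda.CUDAGraph()
#     static_in = torch.randn(8, 8, device="cuda")
#     with torch.cuda.graph(g):
#         static_out = fused_submodule(static_in)
#     static_in.copy_(next_batch)  # refill the captured input buffer
#     g.replay()                   # static_out now holds the new result
#
# `fused_submodule`, `static_in`, and `next_batch` are hypothetical names.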