xref: /aosp_15_r20/external/pytorch/torch/fx/passes/backends/cudagraphs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
4from torch.fx.passes.operator_support import OperatorSupport
5from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
6from torch.fx.passes.fake_tensor_prop import FakeTensorProp
7from torch.utils import _pytree as pytree
8
9import operator
10
class CudaGraphsSupport(OperatorSupport):
    # TODO: why is submodules passed here
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """Return True if ``node`` may be captured inside a CUDA graph.

        A node qualifies when it is a callable op, is not on the explicit
        denylist, and every fake-tensor value flowing into it and out of it
        lives on a CUDA device (read from the node's ``val``/``fake_result``
        metadata).
        """
        if node.op not in CALLABLE_NODE_OPS:
            return False

        # Denylisted op: known to be incompatible with CUDA graphs.
        if node.target in [torch.ops.aten.embedding_dense_backward.default]:
            return False

        # getitem merely projects an element of a multi-output node; always safe.
        if node.target in [operator.getitem]:
            return True

        def fake_value(meta):
            # Different tracers stash the fake result under different keys.
            return meta["val"] if "val" in meta else meta["fake_result"]

        # Collect every pytree leaf feeding this node, plus the node's own output.
        leaves = []
        for inp in node.all_input_nodes:
            leaves.extend(pytree.tree_flatten(fake_value(inp.meta))[0])
        leaves.extend(pytree.tree_flatten(fake_value(node.meta))[0])

        # NB: factory function is accounted for because the result would be
        # cpu or cuda
        return not any(
            isinstance(leaf, torch.Tensor) and leaf.device.type != 'cuda'
            for leaf in leaves
        )
42
def partition_cudagraphs(gm, inputs):
    """
    Partition an FX graph into sub-GraphModules that can be validly run under
    CUDA graphs.  For a subgraph to be runnable under CUDA graphs, all of the
    operations must involve CUDA tensors only.

    Args:
        gm: the ``torch.fx.GraphModule`` to partition.  Mutated in place:
            fake-tensor propagation records ``val`` metadata on its nodes.
        inputs: example inputs used to propagate fake tensors through ``gm``.

    Returns:
        A ``torch.fx.GraphModule`` in which each CUDA-graph-safe partition has
        been fused into its own submodule.
    """
    # Record fake-tensor metadata on every node; CudaGraphsSupport reads it
    # to decide which device each value lives on.
    FakeTensorProp(gm).propagate(*inputs)
    supported_ops = CudaGraphsSupport()
    # TODO: single node partition may be wrong due to the pessimization
    # from copying in and out the data.  Check in benchmarks, perhaps
    partitioner = CapabilityBasedPartitioner(
        gm, supported_ops, allows_single_node_partition=True
    )
    partitions = partitioner.propose_partitions()
    fused_graph = partitioner.fuse_partitions(partitions)
    return fused_graph
58