xref: /aosp_15_r20/external/pytorch/torch/jit/_fuser.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import contextlib
from typing import List, Tuple

import torch


@contextlib.contextmanager
def optimized_execution(should_optimize):
    """Context manager that controls whether the JIT's executor will run optimizations before executing a function.
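
    Example (an illustrative sketch; ``add`` is a hypothetical scripted function)::

        import torch

        @torch.jit.script
        def add(a, b):
            return a + b

        # Run the scripted function with executor optimizations disabled.
        with torch.jit.optimized_execution(False):
            add(torch.ones(1), torch.ones(1))
    """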
    stored_flag = torch._C._get_graph_executor_optimize()
    torch._C._set_graph_executor_optimize(should_optimize)
    try:
        yield
    finally:
        torch._C._set_graph_executor_optimize(stored_flag)


@contextlib.contextmanager
def fuser(name):
    """Context manager that facilitates switching between backend fusers.

    Valid names:
    * ``fuser0`` - enables only legacy fuser
    * ``fuser1`` - enables only NNC
    * ``fuser2`` - enables only nvFuser
    * ``fuser3`` - enables oneDNN Graph
    * ``none`` - disables all fusers
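
    Example (an illustrative sketch; ``fn`` is a hypothetical scripted function)::

        import torch

        @torch.jit.script
        def fn(x, y):
            return x + y

        # Run ``fn`` with only the NNC fuser enabled.
        with torch.jit.fuser("fuser1"):
            fn(torch.rand(10), torch.rand(10))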
    """
    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
    old_llga_state = torch._C._jit_llga_enabled()
    if name == "fuser0":  # legacy fuser
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser1":  # NNC
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser2":  # nvFuser
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser3":  # oneDNN Graph
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(True)
    elif name == "none":  # turn the PyTorch fuser off
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    else:
        raise Exception(f"unrecognized fuser option (name: {name})")  # noqa: TRY002
    try:
        yield
    finally:
        if name in ["fuser1", "fuser3"]:  # NNC or oneDNN Graph
            torch._C._jit_set_profiling_executor(old_profiling_executor)  # type: ignore[possibly-undefined]
            torch._C._get_graph_executor_optimize(old_profiling_mode)  # type: ignore[possibly-undefined]
        # restore the previous values
        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
        torch._C._jit_set_llga_enabled(old_llga_state)


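# Alias for the native binding that returns the optimized graph most recently
# executed by the JIT graph executor.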
last_executed_optimized_graph = torch._C._last_executed_optimized_graph


def _get_differentiable_graph_node(node, diff_node):
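    # Recursively collect every prim::DifferentiableGraph node nested under ``node``,
    # appending matches to ``diff_node``.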
    if node.kind() == "prim::DifferentiableGraph":
        diff_node.append(node)
    else:
        for block in node.blocks():
            for n in block.nodes():
                _get_differentiable_graph_node(n, diff_node)


def _graph_for(self, *args, **kwargs):
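    # ``self`` serves as its own parent here; delegate to the script-method variant below.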
    return _script_method_graph_for(self, self, *args, **kwargs)


def _script_method_graph_for(self, parent, *args, **kwargs):
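    # Return the optimized graph for ``parent``'s single execution plan, substituting each
    # differentiable subgraph with its optimized forward graph when one is available. If that
    # fails, run ``self`` once and return the last executed optimized graph instead.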
    try:
        dbs = parent.get_debug_state()
        eps = list(dbs.execution_plans.values())
        assert len(eps) == 1
        graph = eps[0].graph.copy()

        # graph executor states for the differentiable nodes
        fw_states = eps[0].code.differentiable_op_executor_states()
        diff_nodes: List[torch._C.Node] = []
        for n in graph.nodes():
            _get_differentiable_graph_node(n, diff_nodes)

        assert len(fw_states) == len(diff_nodes)
        # swap each differentiable graph with the optimized graph from its execution plan
        for n, state in zip(diff_nodes, fw_states):
            fw_execution_plans = list(state.execution_plans.values())
            # we can only update the subgraph when there's a unique execution
            # plan. Avoid asserting here so that nodes which can't be updated
            # are skipped while we make a best effort to update the others.
            if len(fw_execution_plans) == 1:
                n.g_("Subgraph", fw_execution_plans[0].graph)

        return graph
    except Exception:
        # fallback approach: just run the graph and return the recorded
        # optimized graph
        self(*args, **kwargs)
        return last_executed_optimized_graph()


def set_fusion_strategy(strategy: List[Tuple[str, int]]):
    """Set the type and number of specializations that can occur during fusion.

    Usage: provide a list of pairs (type, depth) where type is one of "STATIC" or "DYNAMIC"
    and depth is an integer.

    Behavior - static vs dynamic:
        In STATIC fusion, fused ops are compiled to have fixed input shapes. The shape is determined
        based on some initial profiling runs.
        In DYNAMIC fusion, fused ops are compiled to have variable input shapes, so that multiple
        shapes are possible.

    In both cases, we also recompile on new striding behavior, device, or dtype.

    Behavior - fallback functions & depth:
        When an input doesn't match the format required by the specialized compiled op, it will run
        a fallback function. Fallback functions are recursively compiled and specialized based
        on the observed tensor shapes. Since compilation can be slow, the "depth" parameter is provided to
        limit the number of specializations that can be compiled, before giving up on recompiling and
        falling back to a completely un-fused, un-specialized implementation.

    The list of (type, depth) pairs controls the type of specializations and the number of
    specializations. For example: [("STATIC", 2), ("DYNAMIC", 2)] indicates that the first
    two specializations will use static fusions, the following two specializations will use
    dynamic fusion, and any inputs that satisfy none of the 4 options will run an
    unfused implementation.

    NB: in the future, as more fusion backends are added there may be more granular
    APIs for specific fusers.
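
    Example (an illustrative sketch)::

        # Allow two static specializations, then two dynamic specializations,
        # before falling back to the unfused implementation.
        torch.jit.set_fusion_strategy([("STATIC", 2), ("DYNAMIC", 2)])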
    """
    return torch._C._jit_set_fusion_strategy(strategy)