# mypy: allow-untyped-defs
import contextlib
from typing import List, Tuple

import torch


@contextlib.contextmanager
def optimized_execution(should_optimize):
    """Context manager that controls whether the JIT's executor will run optimizations before executing a function."""
    # Save the current global flag so it can be restored when the context exits.
    stored_flag = torch._C._get_graph_executor_optimize()
    torch._C._set_graph_executor_optimize(should_optimize)
    try:
        yield
    finally:
        # Restore the previous setting even if the body raised.
        torch._C._set_graph_executor_optimize(stored_flag)


@contextlib.contextmanager
def fuser(name):
    """Context manager that facilitates switching between backend fusers.

    Valid names:
    * ``fuser0`` - enables only legacy fuser
    * ``fuser1`` - enables only NNC
    * ``fuser2`` - enables only nvFuser
    * ``fuser3`` - enables oneDNN Graph

    The selected fuser is enabled (and all others disabled) for the duration
    of the ``with`` block; all toggled global JIT flags are restored on exit.
    """
    # Snapshot every global fusion flag we may touch so the finally-block can
    # restore the exact prior state.
    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
    old_llga_state = torch._C._jit_llga_enabled()
    if name == "fuser0":  # legacy fuser
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser1":  # NNC
        # NNC additionally requires the profiling executor; these setters
        # return the previous value, which is restored in the finally-block.
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        # NOTE(review): _get_graph_executor_optimize is used here (and in the
        # finally-block) as a combined getter/setter that returns the previous
        # value when given an argument — confirm against the torch._C binding.
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser2":  # nvFuser
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser3":  # oneDNN Graph
        # Like NNC, oneDNN Graph needs the profiling executor enabled.
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(True)
    elif name == "none":  # Turn Pytorch fuser off
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    else:
        raise Exception(f"unrecognized fuser option (name: {name})")  # noqa: TRY002
    try:
        yield
    finally:
        # old_profiling_executor/old_profiling_mode only exist for the two
        # branches that set them — hence the possibly-undefined ignores.
        if name in ["fuser1", "fuser3"]:  # NNC or oneDNN Graph
            torch._C._jit_set_profiling_executor(old_profiling_executor)  # type: ignore[possibly-undefined]
            torch._C._get_graph_executor_optimize(old_profiling_mode)  # type: ignore[possibly-undefined]
        # recover the previous values
        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
        torch._C._jit_set_llga_enabled(old_llga_state)


# Module-level alias to the native binding that returns the most recently
# executed optimized graph (used as the fallback in _script_method_graph_for).
last_executed_optimized_graph = torch._C._last_executed_optimized_graph


def _get_differentiable_graph_node(node, diff_node):
    """Recursively collect all ``prim::DifferentiableGraph`` nodes reachable from *node*.

    Matches are appended to *diff_node* in traversal order; non-matching nodes
    are searched through their nested blocks.
    """
    if node.kind() == "prim::DifferentiableGraph":
        diff_node.append(node)
    else:
        for block in node.blocks():
            for n in block.nodes():
                _get_differentiable_graph_node(n, diff_node)


def _graph_for(self, *args, **kwargs):
    """Return the optimized graph for *self*, acting as its own parent (see _script_method_graph_for)."""
    return _script_method_graph_for(self, self, *args, **kwargs)


def _script_method_graph_for(self, parent, *args, **kwargs):
    """Return the optimized graph recorded for *self*, inlining differentiable subgraphs.

    Reads *parent*'s debug state for its (expected single) execution plan and
    patches each ``prim::DifferentiableGraph`` node's "Subgraph" attribute with
    its optimized forward graph. If anything about the debug state does not
    match expectations, falls back to actually running ``self(*args, **kwargs)``
    and returning the last executed optimized graph.
    """
    try:
        dbs = parent.get_debug_state()
        eps = list(dbs.execution_plans.values())
        # Only handle the simple case of a single execution plan; otherwise
        # the assert triggers the fallback path below.
        assert len(eps) == 1
        graph = eps[0].graph.copy()

        # graph_executor_states for differentiable node
        fw_states = eps[0].code.differentiable_op_executor_states()
        diff_nodes: List[torch._C.Node] = []
        for n in graph.nodes():
            _get_differentiable_graph_node(n, diff_nodes)

        # One forward executor state per differentiable node, in order.
        assert len(fw_states) == len(diff_nodes)
        # swap each differentiable graph with optimized graph in their execution plan
        for n, state in zip(diff_nodes, fw_states):
            fw_execution_plans = list(state.execution_plans.values())
            # we can only update the subgraph when there's a unique execution
            # plan. Avoid assert here so we would skip the ones that can't be
            # updated while try the best effort to update other nodes.
            if len(fw_execution_plans) == 1:
                n.g_("Subgraph", fw_execution_plans[0].graph)

        return graph
    except Exception:
        # fallback approach, we just ran the graph and return the recorded optimized
        # graph
        self(*args, **kwargs)
        return last_executed_optimized_graph()


def set_fusion_strategy(strategy: List[Tuple[str, int]]):
    """Set the type and number of specializations that can occur during fusion.

    Usage: provide a list of pairs (type, depth) where type is one of "STATIC" or "DYNAMIC"
    and depth is an integer.

    Behavior - static vs dynamic:
        In STATIC fusion, fused ops are compiled to have fixed input shapes. The shape is determined
        based on some initial profiling runs.
        In DYNAMIC fusion, fused ops are compiled to have variable input shapes, so that multiple
        shapes are possible.

    In both cases, we also recompile on new striding behavior, device, or dtype.

    Behavior - fallback functions & depth:
        When an input doesn't match the format required by the specialized compiled op, it will run
        a fallback function. Fallback functions are recursively compiled and specialized based
        on the observed tensor shapes. Since compilation can be slow, the "depth" parameter is provided to
        limit the number of specializations that can be compiled, before giving up on recompiling and
        falling back to a completely un-fused, un-specialized implementation.

    The list of (type, depth) pairs controls the type of specializations and the number of
    specializations. For example: [("STATIC", 2), ("DYNAMIC", 2)] indicates that the first
    two specializations will use static fusions, the following two specializations will use
    dynamic fusion, and any inputs that satisfy none of the 4 options will run an
    unfused implementation.

    NB: in the future, as more fusion backends are added there may be more granular
    apis for specific fusers.

    Returns:
        The previously active fusion strategy (as returned by the underlying
        ``torch._C._jit_set_fusion_strategy`` binding).
    """
    return torch._C._jit_set_fusion_strategy(strategy)