# mypy: allow-untyped-defs
"""
PythonDispatcher class is a thin python-binding to C++ dispatcher and it
is designed to show how dispatcher precompute works. In particular,
it shows for a certain op `foo`, what the computed dispatch table looks
like after users register their kernels to certain dispatch keys.

In the real C++ dispatcher we support many dispatch keys for different
functionalities. For simplicity PythonDispatcher only supports dispatch
keys for a single example of each use case. These use cases are listed below:

- CPU/AutogradCPU: represents in-tree backends for which we usually have dedicated inference &
    autograd kernel in pytorch core library.
    E.g. CPU, CUDA
- FPGA/AutogradOther: represents in-tree backends for which we usually have backend specific
    inference kernels, but they share the same autograd kernel specified in AutogradOther.
    E.g. FPGA, SparseCsrCPU
- XLA/AutogradXLA: represents out-of-tree backends for which we have neither inference nor autograd
    kernel defined in pytorch core library. Backend owner is responsible for registering both
    inference & autograd kernels in their extensions (e.g. torch-xla) for the operators they support.
    E.g. XLA, XPU, MPS
    (Lazy/AutogradLazy are also supported and follow the same out-of-tree pattern.)
- CompositeExplicitAutograd: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc.
    Kernels registered to this key MUST work for inference for all backends.
- Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther.
    Kernels registered to this key MUST work for autograd for all backends.
- CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd
    Kernels registered to this key MUST work for both inference + autograd for all backends.

Note we only allow registrations to alias keys inside pytorch core library. E.g.,
you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd
kernel from torch-xla extension, instead you should upstream the kernel into
pytorch/pytorch repo so that it's available for all backends and continuously
tested even without the extension.

Usage:
  dispatcher = PythonDispatcher()
  dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
  print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend.
  # For more debugging information
  # print(dispatcher.keys())
  # print(dispatcher.registrations())
  # print(dispatcher.rawRegistrations())
  # print(dispatcher.rawDispatchTable())
PythonDispatcher calls C++ dispatcher under the hood to precompute the dispatch table.
This file only provides the simplified API for developers, relevant test code is located in
test/test_dispatch.py
"""

import re

import torch._C as C


class PythonDispatcher:
    """Toy front-end to the C++ dispatcher for the single operator ``__test__::foo``.

    Constructing an instance defines the schema ``foo(Tensor x) -> Tensor`` in a
    ``FRAGMENT`` library for the ``__test__`` namespace; ``register`` then attaches
    auto-generated debug kernels (named ``fn_<key>``) to the requested keys.
    """

    namespace = "__test__"
    name = "foo"
    # fmt: off
    runtime_keys = [
        "CPU", "AutogradCPU",
        "FPGA", "AutogradOther",
        "XLA", "AutogradXLA",
        "Lazy", "AutogradLazy",
    ]
    # fmt: on
    alias_keys = [
        "CompositeExplicitAutograd",
        "Autograd",
        "CompositeImplicitAutograd",
    ]
    supported_keys = runtime_keys + alias_keys

    def __init__(self) -> None:
        C._dispatch_check_invariants(self.name)  # type: ignore[attr-defined]
        # Library handle used by register() to attach kernels to "foo".
        self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
        self.ref.def_("foo(Tensor x) -> Tensor")

    def keys(self):
        """
        Returns a list of dispatch keys supported by PythonDispatcher.
        You can register kernels to these keys.
        """
        # Return a copy so callers cannot mutate the class-level list that
        # register() validates against.
        return list(self.supported_keys)

    def register(self, dispatchKeys):
        """
        Register kernels to the target dispatchKeys.

        dispatchKeys(list[str]): a list of dispatch keys that you want to register
        your own kernel. Note that you don't need to write the kernel yourself in
        this PythonDispatcher. E.g., for CPU key, a kernel (e.g. fn_CPU for CPU) is
        automatically generated and registered.

        Raises RuntimeError if dispatchKeys contains duplicates, contains both
        CompositeImplicitAutograd and CompositeExplicitAutograd, or contains a
        key not in supported_keys. All keys are validated before any kernel is
        registered, so a rejected call registers nothing.
        """
        # Overriding is not supported and triggers a warning in C++ dispatcher.
        if len(set(dispatchKeys)) != len(dispatchKeys):
            raise RuntimeError(
                f"Overriden is not allowed but found duplicates in {dispatchKeys}."
            )
        # We currently forbid this in codegen instead of C++ dispatcher.
        if (
            "CompositeImplicitAutograd" in dispatchKeys
            and "CompositeExplicitAutograd" in dispatchKeys
        ):
            raise RuntimeError(
                "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed."
            )
        # Validate every key up front so an invalid key later in the list cannot
        # leave the earlier keys already registered.
        for key in dispatchKeys:
            if key not in self.supported_keys:
                raise RuntimeError(
                    f"{key} is not supported, please select a dispatch key in {self.supported_keys}."
                )
        for key in dispatchKeys:
            self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key)

    def _format_line(self, key, kernel):
        """
        Helper function to format (key, kernel) as one left-aligned table row.
        """
        return f"{key:<15} {kernel}\n"

    def _format_header(self, header):
        """
        Helper function to build a table header: a blank line, the title,
        the column names and a dashed separator line.
        """
        s = f"\n{header}\n"
        s += self._format_line("key", "kernel")
        s += "---------------------------\n"
        return s

    def rawRegistrations(self):
        """
        Returns raw output of all registration info for debugging only.
        Use registrations() for a simplified version.
        """
        return C._dispatch_dump(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def rawDispatchTable(self):
        """
        Returns raw output of computed dispatch table for debugging only.
        Use dispatchTable() for a simplified version.
        """
        return C._dispatch_dump_table(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def registrations(self):
        """
        Returns a table(str) including all the registrations from users.
        Note this includes registrations to both runtime keys and alias keys.
        """
        output = self._format_header("Registered Kernels")
        # Prefix match (not equality) so a decorated key label in the dump,
        # e.g. presumably "CompositeImplicitAutograd[alias]", still matches.
        prefixes = tuple(self.supported_keys)
        for line in self.rawRegistrations().split("\n"):
            first = line.split(":")[0]
            if first.startswith(prefixes):
                # Assumes dump rows look like "<key>: <kernel> :: <schema> ..."
                # — the kernel name is the token right after "<key>: ".
                kernel = line.split("::")[0].split(" ")[1]
                output += self._format_line(first, kernel)
        return output

    def dispatchTable(self):
        """
        Returns the computed dispatch table(str). Note this only includes
        runtime keys, registrations to alias keys have been decoded to their
        mapped runtime keys.
        """
        output = self._format_header("Computed Dispatch Table")
        # Drop the "registered at ...FallbackKernel.cpp..." source location so a
        # fallback entry is rendered as just its bracketed tag.
        regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")
        for line in self.rawDispatchTable().split("\n"):
            k = line.split(":")[0]
            if k in self.runtime_keys:
                entry = regex.sub("[", line)
                # maxsplit=1 keeps any later ": " inside the kernel text intact.
                output += self._format_line(k, entry.split(": ", 1)[1])
        return output