xref: /aosp_15_r20/external/pytorch/torch/_python_dispatcher.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import re
3
4import torch._C as C
5
6
7"""
The PythonDispatcher class is a thin Python binding to the C++ dispatcher,
designed to show how the dispatcher precomputes its dispatch table. In
particular, it shows, for an example op `foo`, what the computed dispatch
table looks like after users register their kernels to certain dispatch keys.
12
The real C++ dispatcher supports many dispatch keys for different
functionalities. For simplicity, PythonDispatcher only supports dispatch
keys for a single example of each use case. These use cases are listed below:
16
- CPU/AutogradCPU: represents in-tree backends for which we usually have dedicated inference &
    autograd kernels in the PyTorch core library.
    E.g. CPU, CUDA
- FPGA/AutogradOther: represents in-tree backends for which we usually have backend-specific
    inference kernels, but which share the same autograd kernel specified in AutogradOther.
    E.g. FPGA, SparseCsrCPU
- XLA/AutogradXLA: represents out-of-tree backends for which we have neither inference nor autograd
    kernels defined in the PyTorch core library. The backend owner is responsible for registering both
    inference & autograd kernels in their extension (e.g. torch-xla) for the operators they support.
    E.g. XLA, XPU, MPS
27- CompositeExplicitAutograd: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc.
28    Kernels registered to this key MUST work for inference for all backends.
29- Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther.
30    Kernels registered to this key MUST work for autograd for all backends.
31- CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd
32    Kernels registered to this key MUST work for both inference + autograd for all backends.
33
Note that we only allow registrations to alias keys inside the PyTorch core library.
E.g. you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd
kernel from the torch-xla extension; instead, you should upstream the kernel into
the pytorch/pytorch repo so that it's available for all backends and continuously
tested even without the extension.
39
40Usage:
41  dispatcher = PythonDispatcher()
42  dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"])
43  print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend.
44  # For more debugging information
45  # print(dispatcher.keys())
46  # print(dispatcher.registrations())
47  # print(dispatcher.rawRegistrations())
48  # print(dispatcher.rawDispatchTable())
PythonDispatcher calls the C++ dispatcher under the hood to precompute the dispatch table.
This file only provides the simplified API for developers; the relevant test code is located in
test/test_dispatch.py.
52"""
53
54
class PythonDispatcher:
    """Toy front-end to the C++ dispatcher for a single example operator.

    Each instance opens a ``FRAGMENT`` library in the ``__test__`` namespace,
    defines ``foo(Tensor x) -> Tensor`` in it, and lets you register
    auto-generated placeholder kernels (debug-named ``fn_<key>``) to a small,
    fixed set of dispatch keys. You can then inspect both the user
    registrations and the dispatch table the C++ dispatcher computed from them.
    """

    namespace = "__test__"
    name = "foo"
    # Runtime (non-alias) dispatch keys: one representative pair per use case
    # described at the top of this file.
    # fmt: off
    runtime_keys = [
        "CPU", "AutogradCPU",
        "FPGA", "AutogradOther",
        "XLA", "AutogradXLA",
        "Lazy", "AutogradLazy",
    ]
    # fmt: on
    # Alias keys: registrations to these are decoded to their mapped runtime
    # keys in the computed dispatch table.
    alias_keys = [
        "CompositeExplicitAutograd",
        "Autograd",
        "CompositeImplicitAutograd",
    ]
    supported_keys = runtime_keys + alias_keys

    # Matches the "registered at ...FallbackKernel.cpp..." source-location text
    # in raw dispatch-table lines, up to the next "[" (greedy). dispatchTable()
    # replaces it with a bare "[", presumably so the output does not embed
    # build-specific paths.
    _fallback_location_re = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")

    def __init__(self) -> None:
        C._dispatch_check_invariants(self.name)  # type: ignore[attr-defined]
        # Library handle kept on the instance; the operator schema below and
        # every kernel added by register() go through it.
        self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
        self.ref.def_(f"{self.name}(Tensor x) -> Tensor")

    def keys(self):
        """Return the dispatch keys supported by PythonDispatcher.

        You can register kernels to any of these keys via ``register``.
        A fresh list is returned so callers cannot mutate the shared
        class-level ``supported_keys``.
        """
        return list(self.supported_keys)

    def register(self, dispatchKeys):
        """Register auto-generated kernels to the target dispatch keys.

        Args:
            dispatchKeys (list[str]): dispatch keys you want to register a
                kernel to. You don't need to write the kernels yourself: for
                each key a kernel with debug name ``fn_<key>`` (e.g.
                ``fn_CPU`` for CPU) is registered automatically.

        Raises:
            RuntimeError: if ``dispatchKeys`` contains duplicates, contains
                both CompositeImplicitAutograd and CompositeExplicitAutograd,
                or contains a key not in ``supported_keys``. All three checks
                run before any kernel is registered, so a call rejected by
                them registers nothing.
        """
        # Overriding is not supported and triggers a warning in C++ dispatcher.
        if len(set(dispatchKeys)) != len(dispatchKeys):
            raise RuntimeError(
                f"Overriden is not allowed but found duplicates in {dispatchKeys}."
            )
        # We currently forbid this in codegen instead of C++ dispatcher.
        if (
            "CompositeImplicitAutograd" in dispatchKeys
            and "CompositeExplicitAutograd" in dispatchKeys
        ):
            raise RuntimeError(
                "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed."
            )
        # Validate every key up front: an unsupported key late in the list must
        # not leave the earlier keys already registered.
        for key in dispatchKeys:
            if key not in self.supported_keys:
                raise RuntimeError(
                    f"{key} is not supported, please select a dispatch key in {self.supported_keys}."
                )
        for key in dispatchKeys:
            self.ref.impl_t_t(self.name, dispatch=key, debug="fn_" + key)

    def _format_line(self, key, kernel):
        """Format one ``(key, kernel)`` table row, left-justifying key to 15 columns."""
        return f"{key:<15} {kernel}\n"

    def _format_header(self, header):
        """Return a table title line followed by the column header and a rule."""
        s = f"\n{header}\n"
        s += self._format_line("key", "kernel")
        s += "---------------------------\n"
        return s

    def rawRegistrations(self):
        """Return the raw dump of all registration info, for debugging only.

        Use ``registrations()`` for a simplified version.
        """
        return C._dispatch_dump(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def rawDispatchTable(self):
        """Return the raw dump of the computed dispatch table, for debugging only.

        Use ``dispatchTable()`` for a simplified version.
        """
        return C._dispatch_dump_table(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def registrations(self):
        """Return a table (str) of all the registrations made by users.

        This includes registrations to both runtime keys and alias keys. Each
        key is printed exactly as it appears before the first ':' of its line
        in the raw dump.
        """
        output = self._format_header("Registered Kernels")
        # Prefix match rather than equality, so key text with extra trailing
        # annotations before the ':' in the raw dump is still recognized.
        prefixes = tuple(self.supported_keys)
        for line in self.rawRegistrations().split("\n"):
            first = line.split(":")[0]
            if first.startswith(prefixes):
                kernel = line.split("::")[0].split(" ")[1]
                output += self._format_line(first, kernel)
        return output

    def dispatchTable(self):
        """Return the computed dispatch table (str).

        Only runtime keys are listed; registrations to alias keys have already
        been decoded to their mapped runtime keys.
        """
        output = self._format_header("Computed Dispatch Table")
        for line in self.rawDispatchTable().split("\n"):
            k = line.split(":")[0]
            if k in self.runtime_keys:
                entry = self._fallback_location_re.sub("[", line)
                # maxsplit=1 keeps the full entry even if it contains ": ".
                output += self._format_line(k, entry.split(": ", 1)[1])
        return output
183