xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/registration.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""Module for handling symbolic function registration."""
3
4import warnings
5from typing import (
6    Callable,
7    Collection,
8    Dict,
9    Generic,
10    Optional,
11    Sequence,
12    Set,
13    TypeVar,
14    Union,
15)
16
17from torch.onnx import _constants, errors
18
19
# ONNX opset versions are plain integers.
OpsetVersion = int


def _dispatch_opset_version(
    target: OpsetVersion, registered_opsets: Collection[OpsetVersion]
) -> Optional[OpsetVersion]:
    """Picks the registered opset that should serve a requested opset version.

    For targets at or above ``_constants.ONNX_BASE_OPSET`` the answer is the
    newest registered version that does not exceed ``target``. For older
    (legacy) targets the answer is the oldest registered version in the closed
    range ``[target, _constants.ONNX_BASE_OPSET]``.

    Args:
        target: The opset version being exported to.
        registered_opsets: The opset versions that have a registered function.

    Returns:
        The chosen registered opset version, or ``None`` when none qualifies.
    """
    if not registered_opsets:
        return None

    base_opset = _constants.ONNX_BASE_OPSET

    if target >= base_opset:
        # Look downward, all the way to opset 1 if necessary: a custom op
        # registered at opset 1 must still be found as a fallback for every
        # target >= the base opset.
        return max(
            (version for version in registered_opsets if version <= target),
            default=None,
        )

    # Legacy path (opset 7 / 8, originally for caffe2): look upward from the
    # target, but never beyond the base opset.
    return min(
        (
            version
            for version in registered_opsets
            if target <= version <= base_opset
        ),
        default=None,
    )
59
60
_K = TypeVar("_K")
_V = TypeVar("_V")


class OverrideDict(Collection[_K], Generic[_K, _V]):
    """A collection that layers custom overrides on top of base entries.

    Three dictionaries are kept in sync:

    * ``_base`` holds the built-in values,
    * ``_overrides`` holds the custom values,
    * ``_merged`` is the view seen by lookups, ``in``, iteration and ``len``.

    An override always wins over a base entry with the same key. Removing the
    override makes the base entry (if any) visible again.
    """

    def __init__(self) -> None:
        self._base: Dict[_K, _V] = {}
        self._overrides: Dict[_K, _V] = {}
        self._merged: Dict[_K, _V] = {}

    def set_base(self, key: _K, value: _V) -> None:
        """Stores a base value; it becomes visible only if not overridden."""
        self._base[key] = value
        if key in self._overrides:
            # The custom value keeps precedence over the new base value.
            return
        self._merged[key] = value

    def in_base(self, key: _K) -> bool:
        """Tells whether ``key`` has a base value (overridden or not)."""
        return key in self._base

    def override(self, key: _K, value: _V) -> None:
        """Shadows any base value of ``key`` with ``value``."""
        self._overrides[key] = self._merged[key] = value

    def remove_override(self, key: _K) -> None:
        """Drops the override of ``key`` and restores its base value, if any."""
        for table in (self._overrides, self._merged):
            table.pop(key, None)  # type: ignore[arg-type]
        # Re-inserting (rather than overwriting) moves the key to the end of the
        # merged iteration order.
        if key in self._base:
            self._merged[key] = self._base[key]

    def overridden(self, key: _K) -> bool:
        """Tells whether ``key`` currently has a custom override."""
        return key in self._overrides

    def __getitem__(self, key: _K) -> _V:
        return self._merged[key]

    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
        """Returns the visible value of ``key``, or ``default`` if absent."""
        return self._merged.get(key, default)

    def __contains__(self, key: object) -> bool:
        return key in self._merged

    def __iter__(self):
        return iter(self._merged.keys())

    def __len__(self) -> int:
        return len(self._merged)

    def __repr__(self) -> str:
        return f"OverrideDict(base={self._base}, overrides={self._overrides})"

    def __bool__(self) -> bool:
        return len(self._merged) > 0
122
123
class _SymbolicFunctionGroup:
    """Different versions of symbolic functions registered to the same name.

    Functions are keyed by opset version. Built-in functions (``add``) are
    stored as base entries and custom functions (``add_custom``) as overrides,
    so a custom function shadows a built-in one registered at the same opset
    until it is removed with ``remove_custom``. Every lookup searches the
    registered versions with ``_dispatch_opset_version`` to pick the one that
    serves the requested opset.
    """

    def __init__(self, name: str) -> None:
        self._name = name
        # Symbolic functions keyed by opset version; custom functions are
        # stored as overrides of the built-in ones.
        self._functions: OverrideDict[OpsetVersion, Callable] = OverrideDict()

    def __repr__(self) -> str:
        return f"_SymbolicFunctionGroup({self._name}, registered={self._functions})"

    def __getitem__(self, key: OpsetVersion) -> Callable:
        """Returns the function serving opset ``key``.

        Raises:
            KeyError: If no registered version serves ``key``.
        """
        result = self.get(key)
        if result is None:
            raise KeyError(key)
        return result

    # TODO(justinchuby): Add @functools.lru_cache(maxsize=None) if lookup time becomes
    # a problem.
    def get(self, opset: OpsetVersion) -> Optional[Callable]:
        """Finds the function that serves ``opset``.

        Returns:
            The function registered at the version chosen by
            ``_dispatch_opset_version``, or ``None`` if no version qualifies.
        """
        version = _dispatch_opset_version(opset, self._functions)
        if version is None:
            return None

        return self._functions[version]

    def add(self, func: Callable, opset: OpsetVersion) -> None:
        """Adds a built-in symbolic function.

        Replacing an existing built-in function for the same opset is allowed
        but emits an ``errors.OnnxExporterWarning``.

        Args:
            func: The function to add.
            opset: The opset version of the function to add.
        """
        if self._functions.in_base(opset):
            warnings.warn(
                f"Symbolic function '{self._name}' already registered for opset {opset}. "
                f"Replacing the existing function with new function. This is unexpected. "
                f"Please report it on {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
                errors.OnnxExporterWarning,
            )
        self._functions.set_base(opset, func)

    def add_custom(self, func: Callable, opset: OpsetVersion) -> None:
        """Adds a custom symbolic function that overrides any built-in one.

        Args:
            func: The symbolic function to register.
            opset: The corresponding opset version.
        """
        self._functions.override(opset, func)

    def remove_custom(self, opset: OpsetVersion) -> None:
        """Removes a custom symbolic function.

        A built-in function registered at the same opset becomes visible again.
        If no custom function is registered at ``opset``, an
        ``errors.OnnxExporterWarning`` is emitted and nothing changes.

        Args:
            opset: The opset version of the custom function to remove.
        """
        if not self._functions.overridden(opset):
            # Same category as the warning in `add`, so users can filter all
            # exporter registration warnings together. OnnxExporterWarning is a
            # UserWarning, so existing UserWarning filters still match.
            warnings.warn(
                f"No custom function registered for '{self._name}' opset {opset}",
                errors.OnnxExporterWarning,
            )
            return
        self._functions.remove_override(opset)

    def get_min_supported(self) -> OpsetVersion:
        """Returns the lowest opset version registered in this group.

        Both built-in and custom functions are taken into account.

        Raises:
            ValueError: If the group holds no function at all (e.g. its only
                custom function was removed).
        """
        if not self._functions:
            raise ValueError(
                f"No symbolic function is registered for '{self._name}'."
            )
        return min(self._functions)
201
202
class SymbolicRegistry:
    """Registry for symbolic functions.

    Maps qualified names (``'domain::op'``) to the group of symbolic functions
    registered under that name, one function per opset version. It is used to
    register new symbolic functions and to look up the function to dispatch a
    call to.
    """

    def __init__(self) -> None:
        # Qualified op name -> all versions registered under that name.
        self._registry: Dict[str, _SymbolicFunctionGroup] = {}

    def register(
        self, name: str, opset: OpsetVersion, func: Callable, custom: bool = False
    ) -> None:
        """Registers a symbolic function.

        Args:
            name: The qualified name of the function to register. In the form of 'domain::op'.
                E.g. 'aten::add'.
            opset: The opset version of the function to register.
            func: The symbolic function to register.
            custom: Whether the function is a custom function that overrides existing ones.

        Raises:
            ValueError: If the separator '::' is not in the name.
        """
        if "::" not in name:
            raise ValueError(
                f"The name must be in the form of 'domain::op', not '{name}'"
            )
        # Look up first instead of `setdefault(name, _SymbolicFunctionGroup(name))`,
        # which would build and throw away a new group every time an
        # already-known name is registered again.
        symbolic_functions = self._registry.get(name)
        if symbolic_functions is None:
            symbolic_functions = _SymbolicFunctionGroup(name)
            self._registry[name] = symbolic_functions
        if custom:
            symbolic_functions.add_custom(func, opset)
        else:
            symbolic_functions.add(func, opset)

    def unregister(self, name: str, opset: OpsetVersion) -> None:
        """Unregisters a custom symbolic function.

        Only custom functions can be removed; a built-in function registered at
        the same opset becomes visible again. Unknown names are ignored; a
        known name without a custom function at ``opset`` emits a warning.

        Args:
            name: The qualified name of the function to unregister.
            opset: The opset version of the function to unregister.
        """
        if name not in self._registry:
            return
        self._registry[name].remove_custom(opset)

    def get_function_group(self, name: str) -> Optional[_SymbolicFunctionGroup]:
        """Returns the function group for the given name, or ``None``."""
        return self._registry.get(name)

    def is_registered_op(self, name: str, version: int) -> bool:
        """Returns whether the given op is registered for the given opset version."""
        functions = self.get_function_group(name)
        if functions is None:
            return False
        return functions.get(version) is not None

    def all_functions(self) -> Set[str]:
        """Returns the set of all registered function names."""
        return set(self._registry)
266
267
def onnx_symbolic(
    name: str,
    opset: Union[OpsetVersion, Sequence[OpsetVersion]],
    decorate: Optional[Sequence[Callable]] = None,
    custom: bool = False,
) -> Callable:
    """Registers a symbolic function.

    Usage::

        @onnx_symbolic(
            "aten::symbolic_b",
            opset=10,
            decorate=[quantized_aten_handler(scale=1 / 128, zero_point=0)],
        )
        @symbolic_helper.parse_args("v", "v", "b")
        def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value: ...

    Args:
        name: The qualified name of the function in the form of 'domain::op'.
            E.g. 'aten::add'.
        opset: A single opset version, or a sequence of opset versions, to
            register the function at.
        decorate: A sequence of decorators to apply, in order, to the function
            that gets registered (the first one wraps the function directly).
            The function returned to the caller is left undecorated.
        custom: Whether the function is a custom symbolic function.

    Returns:
        A decorator that registers the function and returns it unchanged.

    Raises:
        ValueError: If the separator '::' is not in the name. This is raised
            when the returned decorator is applied.
    """
    # Normalize once, outside the decorator. The decorator then neither rebinds
    # the enclosing `opset` nor depends on it being re-iterable, so it can be
    # applied more than once.
    opset_versions = (opset,) if isinstance(opset, OpsetVersion) else tuple(opset)

    def wrapper(func: Callable) -> Callable:
        decorated = func
        if decorate is not None:
            for decorate_func in decorate:
                decorated = decorate_func(decorated)

        for opset_version in opset_versions:
            registry.register(name, opset_version, decorated, custom=custom)

        # Return the original function because the decorators in "decorate" are only
        # specific to the instance being registered.
        return func

    return wrapper
317
318
def custom_onnx_symbolic(
    name: str,
    opset: Union[OpsetVersion, Sequence[OpsetVersion]],
    decorate: Optional[Sequence[Callable]] = None,
) -> Callable:
    """Builds a decorator that registers a *custom* symbolic function.

    Shorthand for ``onnx_symbolic(name, opset, decorate, custom=True)``. The
    decorated function is registered as an override of any built-in symbolic
    function with the same name and opset.

    Args:
        name: The qualified name of the function, in the form 'domain::op'.
        opset: The opset version (or versions) to register the function at.
        decorate: Decorators to apply to the registered copy of the function.

    Returns:
        The decorator.

    Raises:
        ValueError: If the separator '::' is not in the name.
    """
    return onnx_symbolic(name, opset, decorate=decorate, custom=True)
338
339
# The module-level registry for all symbolic functions. The `onnx_symbolic`
# decorator (and `custom_onnx_symbolic`, which delegates to it) registers
# functions into this single instance.
registry = SymbolicRegistry()
342