xref: /aosp_15_r20/external/pytorch/torch/distributed/_shard/op_registry_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3from inspect import signature
4
5from .common_op_utils import _basic_validation
6
7
8"""
9Common utilities to register ops on ShardedTensor
10and PartialTensor.
11"""
12
13
def _register_op(op, func, op_table):
    """
    Check the arity of ``func`` and store it in ``op_table`` under ``op``.

    Args:
        op: key under which the handler is stored.
        func: callable to register; its signature must declare exactly
            four parameters.
        op_table: mutable mapping that is updated in place.

    Raises:
        TypeError: if the signature of ``func`` does not declare exactly
            four parameters.
    """
    func_signature = signature(func)

    # Happy path first: a four-parameter callable is recorded as-is.
    if len(func_signature.parameters) == 4:
        op_table[op] = func
        return

    raise TypeError(
        "Custom sharded op function expects signature: "
        "(types, args, kwargs, process_group), but received "
        f"signature: {func_signature}"
    )
27
28
def _decorator_func(wrapped_func, op, op_table):
    """
    Wrap ``wrapped_func`` with a validation step and register the result
    for ``op`` in ``op_table``.

    The returned callable first runs ``_basic_validation(op, args, kwargs)``
    and then forwards all four arguments to ``wrapped_func`` unchanged.

    Args:
        wrapped_func: the op implementation, called as
            ``wrapped_func(types, args, kwargs, process_group)``.
        op: the op being registered; also passed to ``_basic_validation``.
        op_table: table handed to ``_register_op`` to receive the wrapper.

    Returns:
        The validating wrapper (the same object stored in ``op_table``).
    """

    def _validated_op(types, args, kwargs, process_group):
        _basic_validation(op, args, kwargs)
        return wrapped_func(types, args, kwargs, process_group)

    # Copy name/docstring/etc. from the wrapped function and set
    # ``__wrapped__``; ``inspect.signature`` follows ``__wrapped__`` by
    # default, so signature inspection of the wrapper reports
    # ``wrapped_func``'s signature.
    registered = functools.update_wrapper(_validated_op, wrapped_func)

    _register_op(op, registered, op_table)
    return registered
42