# mypy: allow-untyped-defs
"""
Common utilities to register ops on ShardedTensor
and PartialTensor.
"""

import functools
from inspect import signature

from .common_op_utils import _basic_validation


def _register_op(op, func, op_table):
    """
    Check that ``func`` takes four parameters and register it for ``op``.

    Args:
        op: Key under which ``func`` is stored in ``op_table``.
        func: Handler to register. It must accept exactly four parameters,
            expected to be ``(types, args, kwargs, process_group)``.
        op_table: Mutable mapping from op to handler; updated in place.

    Raises:
        TypeError: If ``func`` does not take exactly four parameters.
    """
    # Compute the signature once; it is used for both the check and the message.
    # ``inspect.signature`` follows ``__wrapped__`` by default, so for a function
    # produced by ``functools.wraps`` this inspects the wrapped function instead.
    func_sig = signature(func)
    if len(func_sig.parameters) != 4:
        raise TypeError(
            f"Custom sharded op function expects signature: "
            f"(types, args, kwargs, process_group), but received "
            f"signature: {func_sig}"
        )

    op_table[op] = func


def _decorator_func(wrapped_func, op, op_table):
    """
    Wrap ``wrapped_func`` with argument validation and register it for ``op``.

    The returned wrapper calls ``_basic_validation(op, args, kwargs)`` (from
    ``common_op_utils``) before delegating to ``wrapped_func``. The wrapper,
    not ``wrapped_func`` itself, is stored in ``op_table[op]``.

    Args:
        wrapped_func: Handler taking ``(types, args, kwargs, process_group)``.
        op: Key under which the wrapper is registered.
        op_table: Mutable mapping from op to handler; updated in place.

    Returns:
        The registered wrapper function.

    Raises:
        TypeError: If ``wrapped_func`` does not take exactly four parameters.
    """

    @functools.wraps(wrapped_func)
    def wrapper(types, args, kwargs, process_group):
        _basic_validation(op, args, kwargs)
        return wrapped_func(types, args, kwargs, process_group)

    # ``functools.wraps`` sets ``wrapper.__wrapped__ = wrapped_func``, and the
    # signature check in ``_register_op`` follows it, so this validates the
    # user-supplied function's arity at decoration time.
    _register_op(op, wrapper, op_table)
    return wrapper