1# mypy: allow-untyped-defs 2from typing import Optional 3 4import torch 5from torch.utils import _pytree as pytree 6 7 8def _basic_validation(op, args=(), kwargs=None): 9 """ 10 Common validation across all ops go in here. 11 """ 12 from torch.distributed._shard.sharded_tensor import ShardedTensor 13 14 if len(args) == 0 and (kwargs is None or len(kwargs) == 0): 15 raise ValueError(f" No input for '{op.__name__}'!") 16 17 # Validate types 18 has_distributed_tensor = False 19 20 def is_distributed_tensor(e): 21 nonlocal has_distributed_tensor 22 if isinstance(e, ShardedTensor): 23 has_distributed_tensor = True 24 25 pytree.tree_map_(is_distributed_tensor, args) 26 pytree.tree_map_(is_distributed_tensor, kwargs) 27 28 if not has_distributed_tensor: 29 raise TypeError( 30 f"torch function '{op.__name__}', with args: {args} and " 31 f"kwargs: {kwargs} are called without any distributed tensor!" 32 ) 33 34 # Validate all distributed tensors use the same PG. 35 cur_pg: Optional[torch.distributed.ProcessGroup] = None 36 37 def validate_pg(e): 38 nonlocal cur_pg 39 if isinstance(e, ShardedTensor): 40 if cur_pg is not None and e._process_group is not cur_pg: 41 raise RuntimeError( 42 "All distributed tensors should use the " 43 "same ProcessGroup if used together in an op." 44 ) 45 cur_pg = e._process_group 46 47 pytree.tree_map_(validate_pg, args) 48 pytree.tree_map_(validate_pg, kwargs) 49 50 51def _register_default_op(op, decorator): 52 @decorator(op) 53 def tensor_default_op(types, args=(), kwargs=None, pg=None): 54 """ 55 Handles ``__torch_function__`` dispatch for the default tensor ops that 56 behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or 57 ``torch.Tensor.dtype``. We simply lower to the real op call with 58 DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__`` 59 to avoid recursions. 60 """ 61 if kwargs is None: 62 kwargs = {} 63 64 with torch._C.DisableTorchFunctionSubclass(): 65 return op(*args, **kwargs) 66