# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

import warnings

import torch

from .core import is_masked_tensor
from .creation import as_masked_tensor, masked_tensor


__all__ = []  # type: ignore[var-annotated]


def _masked_all_all(data, mask=None):
    if mask is None:
        return data.all()
    return data.masked_fill(~mask, True).all()


def _masked_all_dim(data, dim, keepdim=False, mask=None):
    if mask is None:
        return torch.all(data, dim=dim, keepdim=keepdim)
    return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim)


def _masked_all(*args, **kwargs):
    if len(args) == 1 and len(kwargs) == 1:
        return _masked_all_all(args[0], mask=kwargs["mask"])
    return _masked_all_dim(*args, **kwargs)


def _multidim_any(mask, dim, keepdim):
    # Reduce the mask with torch.any over one or more dims, highest dim first so that
    # earlier reductions don't shift the remaining dims when keepdim=False.
    if isinstance(dim, int):
        return _multidim_any(mask, [dim], keepdim)
    for d in sorted(dim, reverse=True):
        mask = torch.any(mask, dim=d, keepdim=keepdim)
    return mask


def _get_masked_fn(fn):
    # "all" is implemented locally; every other reduction comes from the torch.masked namespace.
    if fn == "all":
        return _masked_all
    return getattr(torch.masked, fn)


def _torch_reduce_all(fn):
    # Build a full (all-element) reduction wrapper for a MaskedTensor.
    def reduce_all(self):
        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        mask = self.get_mask().values() if self.is_sparse else self.get_mask()
        if fn == "all":
            result_data = masked_fn(data, mask=mask)

        # For a reduction over all elements, torch.argmin/torch.argmax need to return the
        # flat index of the min/max element, but this isn't supported correctly for sparse
        # layouts. Instead, compute the flat index from the COO indices and the strides:
        # e.g. for a 2x3 tensor the strides are (3, 1), so the element at COO index (1, 2)
        # has flat index 1 * 3 + 2 * 1 = 5.
        elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
            sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
            indices = (
                data.to_sparse_coo().indices()
                if not self.is_sparse_coo()
                else data.indices()
            )
            idx = indices.unbind(1)[sparse_idx]
            stride = data.size().numel() / torch.tensor(
                data.size(), device=data.device
            ).cumprod(0)
            result_data = torch.sum(idx * stride)

        # For sparse COO/CSR tensors, simply pass in the values.
        elif self.is_sparse:
            result_data = masked_fn(masked_tensor(data.values(), mask))

        else:
            result_data = masked_fn(self, mask=mask)

        return as_masked_tensor(result_data, torch.any(mask))

    return reduce_all


def _torch_reduce_dim(fn):
    # Build a dim-wise reduction wrapper for a MaskedTensor.
    def reduce_dim(self, dim, keepdim=False, dtype=None):
        if self.is_sparse:
            msg = (
                f"The sparse version of {fn} is not implemented in reductions.\n"
                "If you would like this operator to be supported, please file an issue for a feature request at "
                "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
                "In the case that the semantics for the operator are not trivial, it would be appreciated "
                "to also include a proposal for the semantics."
            )
            warnings.warn(msg)
            return NotImplemented
        if not is_masked_tensor(self):
            raise TypeError("Input to reduce_dim must be a MaskedTensor")

        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        mask = self.get_mask()
        if fn == "all":
            result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)
        else:
            result_data = masked_fn(
                self, dim=dim, keepdim=keepdim, dtype=dtype, mask=mask
            )
        return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))

    return reduce_dim


def _torch_reduce(fn):
    def reduce_fn(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            return _torch_reduce_all(fn)(args[0])
        return _torch_reduce_dim(fn)(*args, **kwargs)

    return reduce_fn


def _reduce_dim_args(input, dim, keepdim=False, dtype=None):
    return input, dim, keepdim, dtype


def _torch_grad_reduce(fn):
    def grad_reduce(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            return _torch_reduce_all(fn)(args[0])
        # TODO: autograd.Function doesn't support kwargs
        input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs)
        return _torch_reduce_dim(fn)(input, dim, keepdim, dtype)

    return grad_reduce


REDUCE_NAMES = [
    "sum",
    "mean",
    "amin",
    "amax",
    "argmin",
    "argmax",
    "prod",
    "all",
    "norm",
    "var",
    "std",
]

# Map aten ops, torch-namespace functions, and Tensor methods to the masked reduction wrappers.
NATIVE_REDUCE_MAP = {
    getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
}
TORCH_REDUCE_MAP = {
    getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}
TENSOR_REDUCE_MAP = {
    getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}

NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())


def _is_reduction(fn):
    return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP


def _apply_reduction(fn, *args, **kwargs):
    if fn in NATIVE_REDUCE_MAP:
        return NATIVE_REDUCE_MAP[fn](*args, **kwargs)
    if fn in TORCH_REDUCE_MAP:
        return TORCH_REDUCE_MAP[fn](*args, **kwargs)
    if fn in TENSOR_REDUCE_MAP:
        return TENSOR_REDUCE_MAP[fn](*args, **kwargs)
    return NotImplemented
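

# ---------------------------------------------------------------------------
# Illustrative usage sketch: assuming the MaskedTensor dispatch routes
# torch.Tensor.sum through the maps above, a call with no extra arguments takes
# the ``_torch_reduce_all`` path and a call with ``dim`` takes the
# ``_torch_reduce_dim`` path. Guarded so it never runs on import.
if __name__ == "__main__":
    example_data = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    example_mask = torch.tensor([[True, False, True], [False, True, True]])
    mt = masked_tensor(example_data, example_mask)

    print(mt.sum())  # full reduction over the unmasked elements only
    print(mt.sum(dim=1))  # per-row reduction; result mask is torch.any(example_mask, dim=1)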