# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

import warnings
from typing import Any
from typing_extensions import TypeGuard

import torch
from torch.overrides import get_default_nowrap_functions


__all__ = [
    "MaskedTensor",
    "is_masked_tensor",
]


def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]:
    r"""Returns True if the input is a MaskedTensor, else False

    Args:
        obj: any input

    Examples:

        >>> # xdoctest: +SKIP
        >>> from torch.masked import MaskedTensor
        >>> data = torch.arange(6).reshape(2,3)
        >>> mask = torch.tensor([[True, False, False], [True, True, False]])
        >>> mt = MaskedTensor(data, mask)
        >>> is_masked_tensor(mt)
        True
    """
    return isinstance(obj, MaskedTensor)


def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
    if is_masked_tensor(a) or is_masked_tensor(b):
        raise ValueError("Neither `a` nor `b` can be a MaskedTensor.")
    if a.layout != b.layout:
        raise ValueError(
            f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}"
        )

    if a.dtype != b.dtype:
        b = b.type(a.dtype)
    if a.layout == b.layout == torch.sparse_coo:
        return _tensors_match(a.values(), b.values(), exact) and _tensors_match(
            a.indices(), b.indices(), exact
        )
    elif a.layout == b.layout == torch.sparse_csr:
        return (
            _tensors_match(a.crow_indices(), b.crow_indices(), exact)
            and _tensors_match(a.col_indices(), b.col_indices(), exact)
            and _tensors_match(a.values(), b.values(), exact)
        )
    if exact:
        return (a.dim() == b.dim()) and torch.eq(a, b).all().item()
    return (a.dim() == b.dim()) and torch.allclose(a, b, rtol=rtol, atol=atol)


def _masks_match(a, b):
    if is_masked_tensor(a) and is_masked_tensor(b):
        mask_a = a.get_mask()
        mask_b = b.get_mask()
        return _tensors_match(mask_a, mask_b, exact=True)
    return True


def _map_mt_args_kwargs(args, kwargs, map_fn):
    def _helper(a, map_fn):
        if is_masked_tensor(a):
            return map_fn(a)
        elif torch.is_tensor(a):
            return a
        elif isinstance(a, list):
            a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
            return a_impl
        elif isinstance(a, tuple):
            a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
            return tuple(a_impl)
        else:
            return a

    if kwargs is None:
        kwargs = {}
    impl_args = []
    for a in args:
        impl_args.append(_helper(a, map_fn))
    impl_kwargs = {}
    for k in kwargs.keys():
        # Map each keyword value, not the last positional argument.
        impl_kwargs[k] = _helper(kwargs[k], map_fn)
    return impl_args, impl_kwargs


def _wrap_result(result_data, result_mask):
    if isinstance(result_data, list):
        return [_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)]
    if isinstance(result_data, tuple):
        return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
    if torch.is_tensor(result_data):
        return MaskedTensor(result_data, result_mask)
    # Expect result_data and result_mask to be Tensors only
    return NotImplemented


def _masked_tensor_str(data, mask, formatter):
    if data.layout in {torch.sparse_coo, torch.sparse_csr}:
        data = data.to_dense()
        mask = mask.to_dense()
    if data.dim() == 1:
        formatted_elements = [
            formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item())
            for d in data
        ]
        max_len = max(8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask))
        return (
            "["
            + ", ".join(
                [
                    "--".rjust(max_len) if m else e
                    for (e, m) in zip(formatted_elements, ~mask)
                ]
            )
            + "]"
        )
    sub_strings = [_masked_tensor_str(d, m, formatter) for (d, m) in zip(data, mask)]
    sub_strings = ["\n".join(["  " + si for si in s.split("\n")]) for s in sub_strings]
    return "[\n" + ",\n".join(sub_strings) + "\n]"


def _get_data(a):
    if is_masked_tensor(a):
        return a._masked_data
    return a


def _maybe_get_mask(a):
    if is_masked_tensor(a):
        return a.get_mask()
    return None

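# The helpers above are the building blocks for lowering an operator onto the
# underlying (data, mask) pair: unwrap the MaskedTensor arguments, run the plain
# Tensor computation, and re-attach a mask to the result. A minimal sketch of that
# pattern, assuming a pointwise op that preserves the mask of its first argument
# (`_apply_pointwise` and `fn` are hypothetical names; the real registrations live
# in torch/masked/maskedtensor/_ops_refs.py):
#
#   def _apply_pointwise(fn, *args, **kwargs):
#       mask = _maybe_get_mask(args[0])
#       data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, _get_data)
#       # _wrap_result turns plain Tensor outputs (or lists/tuples of them)
#       # back into MaskedTensors carrying `mask`.
#       return _wrap_result(fn(*data_args, **data_kwargs), mask)
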

class MaskedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, mask, requires_grad=False):
        if is_masked_tensor(data) or not torch.is_tensor(data):
            raise TypeError("data must be a Tensor")
        if is_masked_tensor(mask) or not torch.is_tensor(mask):
            raise TypeError("mask must be a Tensor")
        # Use a Tensor of the given size for the wrapper.
        kwargs = {
            "device": data.device,
            "dtype": data.dtype,
            "layout": data.layout,
            "requires_grad": requires_grad,
            "dispatch_sizes_strides_policy": "strides",
            "dispatch_layout": True,
        }
        warnings.warn(
            (
                "The PyTorch API of MaskedTensors is in prototype stage "
                "and will change in the near future. Please open a Github issue "
                "for feature requests and see our documentation on the torch.masked "
                "module for further information about the project."
            ),
            UserWarning,
            stacklevel=2,
        )
        if data.requires_grad:
            warnings.warn(
                "It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
                "To avoid this, you can use data.clone().detach()",
                UserWarning,
                stacklevel=2,
            )
        return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)  # type: ignore[attr-defined]

    def _preprocess_data(self, data, mask):
        from .._ops import _sparse_coo_where, _sparse_csr_where

        if data.layout != mask.layout:
            raise TypeError("data and mask must have the same layout.")
        if data.layout == torch.sparse_coo:
            data = data.coalesce()
            mask = mask.coalesce()
            if data._nnz() != mask._nnz():
                data = _sparse_coo_where(mask, data, torch.tensor(0))
        elif data.layout == torch.sparse_csr:
            if data._nnz() != mask._nnz():
                data = _sparse_csr_where(mask, data, torch.tensor(0))

        # Have to pick awkward names to not conflict with existing fields such as data
        self._masked_data = data.clone()
        self._masked_mask = mask.clone()

    def _validate_members(self):
        data = self._masked_data
        mask = self.get_mask()
        if type(data) != type(mask):
            raise TypeError(
                f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
            )
        if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
            raise TypeError(f"data layout of {data.layout} is not supported.")
        if data.layout == torch.sparse_coo:
            if not _tensors_match(data.indices(), mask.indices(), exact=True):
                raise ValueError(
                    "data and mask are both sparse COO tensors but do not have the same indices."
                )
        elif data.layout == torch.sparse_csr:
            if not _tensors_match(
                data.crow_indices(), mask.crow_indices(), exact=True
            ) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True):
                raise ValueError(
                    "data and mask are both sparse CSR tensors but do not have the same crow and col indices."
                )
        if mask.dtype != torch.bool:
            raise TypeError("mask must have dtype bool.")
        if not (
            data.dtype == torch.float16
            or data.dtype == torch.float32
            or data.dtype == torch.float64
            or data.dtype == torch.bool
            or data.dtype == torch.int8
            or data.dtype == torch.int16
            or data.dtype == torch.int32
            or data.dtype == torch.int64
        ):
            raise TypeError(f"{data.dtype} is not supported in MaskedTensor.")
        if data.dim() != mask.dim():
            raise ValueError("data.dim() must equal mask.dim()")
        if data.size() != mask.size():
            raise ValueError("data.size() must equal mask.size()")

    def __init__(self, data, mask, requires_grad=False):
        self._preprocess_data(data, mask)
        self._validate_members()
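    # Construction example (a sketch; MaskedTensor is a prototype API, so details
    # may change). For sparse inputs, _preprocess_data coalesces data and mask and,
    # when their nnz differ, rewrites the data so it lines up with the mask's
    # indices, which _validate_members then checks:
    #
    #   >>> # xdoctest: +SKIP
    #   >>> data = torch.arange(6, dtype=torch.float).reshape(2, 3)
    #   >>> mask = torch.tensor([[True, False, False], [True, True, False]])
    #   >>> mt = MaskedTensor(data, mask)
    #   >>> mt_sparse = MaskedTensor(data.to_sparse(), mask.to_sparse())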
    @staticmethod
    def _from_values(data, mask):
        """Differentiable constructor for MaskedTensor"""

        class Constructor(torch.autograd.Function):
            @staticmethod
            def forward(ctx, data, mask):
                return MaskedTensor(data, mask)

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output, None

        result = Constructor.apply(data, mask)
        return result

    def _set_data_mask(self, data, mask):
        self._masked_data = data
        self._masked_mask = mask
        self._validate_members()

    def __repr__(self):
        formatter = "{0:8.4f}"
        if self.dim() == 0:
            scalar_data = self.get_data().item()
            data_formatted = (
                formatter.format(scalar_data)
                if isinstance(scalar_data, float)
                else str(scalar_data)
            )
            if not self.get_mask().item():
                data_formatted = "--"
            return (
                "MaskedTensor("
                + data_formatted
                + ", "
                + str(self.get_mask().item())
                + ")"
            )
        s = _masked_tensor_str(self.get_data(), self.get_mask(), formatter)
        s = "\n".join("  " + si for si in s.split("\n"))
        return "MaskedTensor(\n" + s + "\n)"

    # Seems like this needs to be defined before torch_dispatch to work
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE

        if func in _MASKEDTENSOR_FUNCTION_TABLE:
            return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented
        with torch._C.DisableTorchFunctionSubclass():
            ret = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return ret
            else:
                return torch._tensor._convert(ret, cls)

    @classmethod
    def unary(cls, fn, data, mask):
        return MaskedTensor(fn(data), mask)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        func = func.overloadpacket

        from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE

        if func in _MASKEDTENSOR_DISPATCH_TABLE:
            return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)

        msg = (
            f"{func.__name__} is not implemented in __torch_dispatch__ for MaskedTensor.\n"
            "If you would like this operator to be supported, please file an issue for a feature request at "
            "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
            "In the case that the semantics for the operator are not trivial, it would be appreciated "
            "to also include a proposal for the semantics."
        )
        warnings.warn(msg)
        return NotImplemented
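    # Dispatch example (illustrative; whether a given op is routed here depends on
    # the entries registered in _MASKEDTENSOR_FUNCTION_TABLE and
    # _MASKEDTENSOR_DISPATCH_TABLE). A registered unary op is applied to the data
    # while the mask is carried through unchanged, matching `unary` above:
    #
    #   >>> # xdoctest: +SKIP
    #   >>> mt = MaskedTensor(torch.tensor([-1.0, 2.0]), torch.tensor([True, False]))
    #   >>> out = torch.abs(mt)  # abs is applied to the data; the mask is unchanged
    #   >>> out.get_mask()
    #   tensor([ True, False])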
    def __lt__(self, other):
        if is_masked_tensor(other):
            return MaskedTensor(self.get_data() < _get_data(other), self.get_mask())
        return MaskedTensor(self.get_data() < other, self.get_mask())

    def to_tensor(self, value):
        return self.get_data().masked_fill(~self.get_mask(), value)

    def get_data(self):
        class GetData(torch.autograd.Function):
            @staticmethod
            def forward(ctx, self):
                return self._masked_data

            @staticmethod
            def backward(ctx, grad_output):
                if is_masked_tensor(grad_output):
                    return grad_output
                return MaskedTensor(grad_output, self.get_mask())

        return GetData.apply(self)

    def get_mask(self):
        return self._masked_mask

    def is_sparse_coo(self):
        return self.layout == torch.sparse_coo

    def is_sparse_csr(self):
        return self.layout == torch.sparse_csr

    # Update later to support more sparse layouts
    @property
    def is_sparse(self):
        return self.is_sparse_coo() or self.is_sparse_csr()
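
# Accessor example (a sketch). `get_data` returns the underlying values through a
# small autograd.Function so that incoming gradients are rewrapped as MaskedTensors,
# `get_mask` returns the boolean mask, and `to_tensor(value)` materializes a regular
# Tensor by filling the masked-out (False) positions with `value`:
#
#   >>> # xdoctest: +SKIP
#   >>> mt = MaskedTensor(torch.arange(3, dtype=torch.float), torch.tensor([True, False, True]))
#   >>> mt.to_tensor(0)
#   tensor([0., 0., 2.])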