# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import warnings
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch import sym_float, Tensor
from torch._prims_common import corresponding_real_dtype
from torch.masked import _docs
from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
from torch.masked.maskedtensor.creation import as_masked_tensor


if TYPE_CHECKING:
    from torch.types import _dtype as DType

    DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
    DimOrDims = Optional[Tuple[int]]


__all__: List[str] = []

# All masked reduction/normalization operations have the same
# signatures. Here we introduce docstring templates that are applied
# to docstrings of reduction/normalization functions via
# _apply_docstring_templates decorator.


def _apply_docstring_templates(func):
    """Decorator that applies docstring templates to function docstring
    and returns the function instance.
    """

    doc_string = getattr(_docs, f"{func.__name__}_docstring", None)
    if doc_string is None:
        warnings.warn(
            f"No documentation string available for {func.__name__}."
            " PyTorch team should run `python tools/update_masked_docs.py`"
            " to generate the missing docstrings."
        )
    else:
        func.__doc__ = doc_string

    # Expose function as public symbol
    __all__.append(func.__name__)

    return func


def _generate_docstring(func):
    """A utility function called from tools/update_masked_docs.py
    script to update the module torch.masked._docs.py
    """
    docstring_templates = dict(
        reduction_signature="""\
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
        reduction_descr="""\
Returns {operation name} of all the elements in the :attr:`input`
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
elements are masked out according to the boolean tensor
:attr:`mask`.""",
        reduction_args="""\
If :attr:`keepdim` is ``True``, the output tensor is of the same size
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
size 1. Otherwise, :attr:`dim` is squeezed (see
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
``len(dim)``) fewer dimension(s).

The boolean tensor :attr:`mask` defines the "validity" of
:attr:`input` tensor elements: if :attr:`mask` element is True
then the corresponding element in :attr:`input` tensor will be
included in {operation name} computation, otherwise the element is
ignored.

When all elements of :attr:`input` along the given dimension
:attr:`dim` are ignored (fully masked-out), the corresponding element
of the output tensor will have undefined value: it may or may not
correspond to the identity value of {operation name} operation; the
choice may correspond to the value that leads to the most efficient
storage of :attr:`output` tensor.

The mask of the output tensor can be computed as
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
dtype=torch.bool)``.

The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
tensor must not be greater than that of the :attr:`input` tensor.

Args:
    input (Tensor): the input tensor
    {args_declarations}

Keyword args:
    {kwargs_declarations}""",
        reduction_example="""\
Example::

    >>> input = {example_input}
    >>> input
    {indent_example_input}
    >>> mask = {example_mask}
    >>> mask
    {indent_example_mask}
    >>> {full_function_name}(input, {example_args}, mask=mask)
    {indent_example_output}
""",
        reduction_identity="""\
The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""",
        reduction_identity_dtype="""\
The identity value of {operation name} operation, which is used to start the
reduction, depends on input dtype. For instance, for float32, uint8,
and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""",
        normalization_signature="""\
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
        normalization_descr="""\
Returns {operation name} of all the slices in the :attr:`input` tensor
along :attr:`dim` while the :attr:`input` elements are masked out
according to the boolean tensor :attr:`mask`.

{definition}""",
        normalization_args="""\
The boolean tensor :attr:`mask` defines the "validity" of
:attr:`input` tensor elements: if :attr:`mask` element is True then
the corresponding element in :attr:`input` tensor will be included in
{operation name} computation, otherwise the element is ignored.

The values of masked-out elements of the output tensor are undefined:
they may or may not be set to zero or nan; the choice may correspond to
the value that leads to the most efficient storage of :attr:`output`
tensor.

The mask of the {operation name} output tensor can be computed as
``torch.broadcast_to(mask, input.shape)``.

The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
tensor must not be greater than that of the :attr:`input` tensor.

Args:
    input (Tensor): the input tensor
    {args_declarations}

Keyword args:
    {kwargs_declarations}""",
        normalization_example="""\
Example::

    >>> input = {example_input}
    >>> input
    {indent_example_input}
    >>> mask = {example_mask}
    >>> mask
    {indent_example_mask}
    >>> {full_function_name}(input, {example_args}, mask=mask)
    {indent_example_output}
""",
    )

    args_and_kwargs = dict(
        # argument name suffixes separated by double underscore will
        # be removed in the final documentation string.
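        # For example, "dim__as_int" is looked up under its full name in
        # argument_declarations below, but is rendered as just "dim" in the
        # generated signature and argument list.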
        sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        cumsum=(("dim__as_int",), ("dtype=None", "mask=None")),
        cumprod=(("dim__as_int",), ("dtype=None", "mask=None")),
        amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        norm=(
            (
                "ord",
                "dim",
            ),
            ("keepdim=False", "dtype=None", "mask=None"),
        ),
        var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
        std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
        logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
        log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
        softmin=(("dim__as_int",), ("dtype=None", "mask=None")),
        normalize=(
            (
                "ord__required",
                "dim__as_int",
            ),
            ("eps=1e-12", "dtype=None", "mask=None"),
        ),
    )

    argument_declarations = dict(
        dim="""\
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
    Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
        dim__as_int="""\
dim (int): the dimension along which {operation name} is computed.""",
        ord="""\
ord (int, float, optional): the order of vector norm. Default: 2.
    See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
        ord__required="""\
ord (int, float): the order of vector norm. Default: 2.
    See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
        unbiased="""\
unbiased (bool): when True, use Bessel's correction, otherwise, compute
    the uncorrected sample variance.""",
        eps="""\
eps (float, optional): small value to avoid division by zero. Default: {default}.""",
        keepdim="""\
keepdim (bool, optional): whether the output tensor has
    :attr:`dim` retained or not. Default: {default}.""",
        dtype="""\
dtype (:class:`torch.dtype`, optional): the desired data type
    of returned tensor. If specified, the input tensor is
    cast to :attr:`dtype` before the operation is
    performed. Default: {default}.""",
        mask="""\
mask (:class:`torch.Tensor`, optional): the boolean tensor
    containing the binary mask of validity of input tensor
    elements.
    Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
    )

    definitions = dict(
        softmax="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
defined as ``exp(x[i])/sum(exp(x))``.""",
        log_softmax="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
defined as ``log(exp(x[i])/sum(exp(x)))``.""",
        softmin="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
defined as ``exp(-x[i])/sum(exp(-x))``.""",
        normalize="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
defined as ``x[i]/max(norm(x, p), eps)``.""",
        cumsum="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
defined as ``sum(x[:i])``.""",
        cumprod="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumprod of i-th element in ``x`` is
defined as ``prod(x[:i])``.""",
    )

    reduction_names = dict(
        sum="sum",
        prod="product",
        amax="maximum",
        amin="minimum",
        argmax="argmax",
        argmin="argmin",
        mean="mean",
        median="median",
        norm="norm",
        var="variance",
        std="standard_deviation",
        logsumexp="logsumexp",
    )

    normalization_names = dict(
        softmax="softmax",
        log_softmax="log_softmax",
        softmin="softmin",
        normalize="normalize",
        cumsum="cumulative_sum",
        cumprod="cumulative_prod",
    )

    operation_names = {}
    operation_names.update(reduction_names)
    operation_names.update(normalization_names)

    # Default example data:
    example_dim = 1
    example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
    example_mask = torch.tensor([[True, False, True], [False, False, False]])
    example_args: Tuple[Any, ...]
    if func.__name__ in {"norm", "normalize"}:
        example_args = (2.0, example_dim)
        example_input = example_input.to(dtype=torch.float32)
    elif func.__name__ in {"var", "std"}:
        example_args = (example_dim, False)
    elif func.__name__ == "median":
        example_args = (example_dim,)
        example_input = example_input.to(dtype=torch.float32)
    else:
        example_args = (example_dim,)

    operation_args: Tuple[str, ...]
    operation_kwargs: Tuple[str, ...]
    operation_args, operation_kwargs = args_and_kwargs[func.__name__]
    arg_declarations = [
        "\n    ".join(
            argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines()
        )
        for a in operation_args
    ]
    kwarg_declarations = [
        "\n    ".join(
            argument_declarations.get(
                a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.'
            )
            .format(default=a.split("=", 1)[1])
            .splitlines()
        )
        for a in operation_kwargs
    ]

    if func.__name__ in reduction_names:
        op_kind = "reduction"
        doc_sections = ["signature", "descr", "identity", "args", "example"]
    elif func.__name__ in normalization_names:
        op_kind = "normalization"
        doc_sections = ["signature", "descr", "args", "example"]
        example_input = example_input.to(dtype=torch.float32)
    else:
        assert 0  # add function name to operation names dictionaries
    example_output = func(example_input, *example_args, mask=example_mask)

    template_data = {
        "function_name": func.__name__,
        "full_function_name": func.__module__ + "." + func.__name__,
        "operation name": operation_names[func.__name__],
        "operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args),
        "operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs),
        # one-line representation of a tensor:
        "example_input": " ".join(str(example_input).split()),
        "example_args": ", ".join(map(str, example_args)),
        "example_mask": " ".join(str(example_mask).split()),
        # multi-line representation of a tensor with indent
        "indent_example_input": ("\n    ").join(str(example_input).splitlines()),
        "indent_example_mask": ("\n    ").join(str(example_mask).splitlines()),
        "indent_example_output": ("\n    ").join(str(example_output).splitlines()),
    }

    if func.__name__ in reduction_names:
        template_data.update(
            identity_uint8=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.uint8)
            ),
            identity_int32=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.int32)
            ),
            identity_float32=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.float32)
            ),
        )
        if func.__name__ == "norm":
            template_data.update(
                identity_ord_ninf=_reduction_identity(
                    func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf")
                )
            )
    elif func.__name__ in normalization_names:
        template_data.update(definition=definitions[func.__name__])
    else:
        assert 0  # add function name to operation names dictionaries
    template_data.update(
        args_declarations=("\n    ".join(arg_declarations)).format_map(template_data)
    )
    template_data.update(
        kwargs_declarations=("\n    ".join(kwarg_declarations)).format_map(
            template_data
        )
    )

    # Apply function name info to docstring templates:
    templates = {
        k: v.format_map(template_data)
        for k, v in docstring_templates.items()
        if k.startswith(op_kind)
    }
    templates.update(
        (k, v.format_map(template_data) if isinstance(v, str) else v)
        for k, v in template_data.items()
    )

    # Apply docstring templates to function docstring:
    if func.__doc__ is None:
        doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections])
    else:
        doc_template = func.__doc__
    return doc_template.format_map(templates)


def _reduction_identity(op_name: str, input: Tensor, *args):
    """Return identity value as scalar tensor of a reduction operation on
    given input, or None, if the identity value cannot be uniquely
    defined for the given input.

    The identity value of the operation is defined as the initial
    value of the reduction operation that has the property ``op(op_identity,
    value) == value`` for any value in the domain of the operation.
    Put another way, including or excluding the identity value in
    a list of operands will not change the reduction result.

    See https://github.com/pytorch/rfcs/pull/27 for more information.

    """
    dtype: DType = input.dtype
    device = input.device
    op_name = op_name.rsplit(".", 1)[-1]  # lstrip module name when present
    if op_name in {"sum", "cumsum"}:
        return torch.tensor(0, dtype=dtype, device=device)
    elif op_name in {"prod", "cumprod"}:
        return torch.tensor(1, dtype=dtype, device=device)
    elif op_name in {"amax", "argmax", "logaddexp"}:
        if torch.is_floating_point(input):
            return torch.tensor(-torch.inf, dtype=dtype, device=device)
        elif torch.is_signed(input) or dtype == torch.uint8:
            return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
    elif op_name in {"logsumexp"}:
        if torch.is_floating_point(input):
            return torch.tensor(-torch.inf, dtype=dtype, device=device)
        elif torch.is_complex(input):
            return torch.tensor(-torch.inf + 0j, dtype=dtype, device=device)
        elif torch.is_signed(input) or dtype == torch.uint8:
            return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
    elif op_name in {"amin", "argmin"}:
        if torch.is_floating_point(input):
            return torch.tensor(torch.inf, dtype=dtype, device=device)
        elif torch.is_signed(input) or dtype == torch.uint8:
            return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
    elif op_name == "mean":
        # Strictly speaking, the identity value of the mean operation
        # is the mean of the input. Since the mean value depends on
        # the dim argument and it may be a non-scalar tensor, we
        # consider the identity value of the mean operation ambiguous.
        # Moreover, the mean value of empty input is undefined.
        return None
    elif op_name == "norm":
        ord = args[0] if args else 2
        if ord == float("-inf"):
            assert torch.is_floating_point(input), input.dtype
            return torch.tensor(torch.inf, dtype=dtype, device=device)
        return torch.tensor(0, dtype=dtype, device=device)
    elif op_name == "median":
        # We use NaN for now because the implementation is currently using torch.nanmedian
        # and NaN is the identity for that function since it gets ignored
        dtype = input.dtype if torch.is_floating_point(input) else torch.float
        return torch.tensor(torch.nan, dtype=dtype, device=device)
    elif op_name in {"var", "std"}:
        return None
    raise NotImplementedError(f"identity of {op_name} on {dtype} input")

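# Illustrative identity values (a sketch, assuming a float32 input; the
# returned scalar tensors carry the input's dtype and device):
#
#   _reduction_identity("sum", torch.zeros(3))   -> tensor(0.)
#   _reduction_identity("prod", torch.zeros(3))  -> tensor(1.)
#   _reduction_identity("amax", torch.zeros(3))  -> tensor(-inf)
#   _reduction_identity("amin", torch.zeros(3))  -> tensor(inf)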


def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]:
    """Return dim argument as a tuple of sorted dim values."""
    dims: List[int] = []
    if dim == ():
        # Currently, `dim=()` in reduction operations means "reduce
        # over all dimensions" while in future, it will read "no
        # reduce". See https://github.com/pytorch/pytorch/issues/29137
        # When gh-29137 is resolved, this if-block must be deleted.
        dim = None
    if dim is None:
        return tuple(range(ndim))
    ndim = max(ndim, 1)
    dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim
    for d in dim_:
        if d in dims:
            raise RuntimeError(f"dim={d} appears multiple times in the list of dims")
        if d >= ndim or d < -ndim:
            raise IndexError(
                f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
            )
        dims.append(d % ndim)
    return tuple(sorted(dims))


def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple):
    # Flatten N-D indices to 1-D indices
    flat_indices = indices.new_zeros(indices.size(1))
    for d, sz in enumerate(shape):
        flat_indices.mul_(sz)
        flat_indices.add_(indices[d])
    return flat_indices


def _any(input: Tensor, dim: tuple, keepdim: bool):
    # Support torch.any with tuple dim argument.
    # Workaround of https://github.com/pytorch/pytorch/issues/56586
    r = input
    for d in reversed(dim):
        r = r.any(dim=d, keepdim=keepdim)
    return r


def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors.

    _sparse_coo_where implements the following invariant:

      _sparse_coo_where(mask, input, fill_value).to_dense(fill_value) ==
        torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))

    where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
    tensor, and `to_dense(fill_value)` is like `to_dense()` except
    that the unspecified elements are mapped to `fill_value` rather
    than to `0`.

    Returns a sparse COO tensor with the following features:

    - all specified elements correspond to masked-in elements that
      have the values of the input tensor. If there exists a masked-in
      element (as specified by mask) that is not specified in the
      input, in the result tensor, the corresponding element has value
      0. In the dense part of the sparse tensor, the masked-out
      elements are replaced with fill_value.

    - all unspecified elements correspond to masked-out elements.
    """

    assert input.layout == torch.sparse_coo
    assert mask.layout == input.layout
    assert mask.shape == input.shape
    assert mask.dense_dim() == input.dense_dim()  # TODO: eliminate this restriction

    input = input.coalesce()

    # For set operations on sparse tensor indices, we'll convert
    # multi-dimensional indices to 1-D indices for efficiency.
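    # For example, with sparse shape (2, 3) an index pair (i, j) maps to the
    # flat index i * 3 + j in row-major order, so [1, 2] -> 5.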
    input_flat_indices = _sparse_coo_flatten_indices(
        input.indices(), input.shape[: input.sparse_dim()]
    )
    mask_flat_indices = _sparse_coo_flatten_indices(
        mask.indices(), mask.shape[: mask.sparse_dim()]
    )

    # the set of mask flat indices that define masked-in elements:
    if mask.dense_dim() > 0:
        mask_values = _any(
            mask.values(), tuple(range(1, input.sparse_dim() + 1)), False
        )
    else:
        mask_values = mask.values()
    maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]]

    def intersection(i1, i2):
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return union, torch.where(counts.gt(1))

    def minus(i1, i2):
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return intersection(union[torch.where(counts.eq(1))], i1)

    def _apply(a):
        obj, w = a
        return obj[w]

    # the set of input flat indices of specified and masked-in elements:
    maskin_input_flat_indices = _apply(
        intersection(maskin_flat_indices, input_flat_indices)
    )
    _, w = intersection(input_flat_indices, maskin_input_flat_indices)

    # the indices and values of masked-in elements
    where_input_indices = input.indices()[(slice(None),) + w]
    where_input_values = input.values()[w]

    if mask.dense_dim() > 0:
        # apply mask to the dense part of the input values:
        _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices)
        where_mask_values = mask.values()[w1]
        where_input_values = torch.where(
            where_mask_values, where_input_values, fill_value
        )

    # the set of flat indices of unspecified input and masked-in elements:
    maskin_zero_flat_indices = _apply(
        minus(maskin_flat_indices, maskin_input_flat_indices)
    )

    # the indices of masked-in zero elements
    _, w = intersection(mask_flat_indices, maskin_zero_flat_indices)
    where_zero_indices = mask.indices()[(slice(None),) + w]

    # construct result
    n = where_zero_indices.size(1)
    if n == 0:
        # the input is coalesced, hence input_flat_indices are ordered
        # and the result is guaranteed to be coalesced:
        result = torch.sparse_coo_tensor(
            where_input_indices, where_input_values, input.shape
        )
        return result._coalesced_(True)

    where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1)
    where_values = torch.cat(
        [
            where_input_values,
            where_input_values.new_zeros((n,) + where_input_values.shape[1:]),
        ]
    )
    result = torch.sparse_coo_tensor(where_indices, where_values, input.shape)

    # appending zero elements leads to uncoalesced sparse tensor
    return result.coalesce()


def _sparse_coo_scatter_reduction_helper(
    op,
    mask_input: Tensor,
    dims: Tuple[int, ...],
    keepdim: bool,
    dtype: Optional[DType] = None,
) -> Tensor:
    reduce = op.__name__
    valid_reductions = ["sum", "prod", "amax", "amin"]
    if reduce not in valid_reductions:
        raise ValueError(
            f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
        )

    output_dtype = dtype
    values, indices = mask_input._values(), mask_input._indices()
    input_dims = mask_input.dim()
    num_sparse_dims = mask_input.sparse_dim()
    reduced_sparse_dims = []
    retained_sparse_dims = []
    reduced_dense_dims = []

    # promote dtype if specified
    if values.dtype != output_dtype:
        values = values.to(output_dtype)

    if keepdim:
        output_shape = tuple(
            1 if i in dims else si for (i, si) in enumerate(mask_input.shape)
        )
    else:
        output_shape = tuple(
            si for (i, si) in enumerate(mask_input.shape) if i not in dims
        )

    for d in dims:
        if d >= input_dims:
            continue

        if d < num_sparse_dims:
            reduced_sparse_dims.append(d)
        else:
            reduced_dense_dims.append(d + 1 - num_sparse_dims)

    # Reduce dense dimensions
    if len(reduced_dense_dims) > 0:
        if reduce == "sum":
            new_values = values
            new_values = op(new_values, dim=reduced_dense_dims, keepdim=bool(keepdim))
        else:
            # FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities
            return NotImplemented
    else:
        new_values = values.clone()

    # Reduce sparse dimensions
    if len(reduced_sparse_dims) == num_sparse_dims:
        if reduce in {"amax", "amin"} and new_values.size(0) == 0:
            # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
            # sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not
            # See https://github.com/pytorch/pytorch/issues/61901
            new_values = _reduction_identity(reduce, new_values)
        else:
            new_values = op(new_values, dim=0)
        if keepdim:
            for _ in range(num_sparse_dims):
                new_values = new_values.unsqueeze(0)
        return new_values.to(dtype=output_dtype).to_sparse()
    else:
        new_indices = indices.clone()
        if keepdim:
            # zero out reduced sparse dimensions if keepdim = True
            # ensures that the call to torch.unique folds duplicated indices together while preserving the dimension
            new_indices[reduced_sparse_dims, :] = 0
        else:
            # remove reduced sparse dimensions if keepdim = False
            if len(reduced_sparse_dims) > 0:
                retained_sparse_dims = [
                    i
                    for i in range(num_sparse_dims)
                    if i not in set(reduced_sparse_dims)
                ]
                new_indices = new_indices.index_select(
                    0, torch.tensor(retained_sparse_dims).to(mask_input.device)
                )

        # Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices
        if new_indices.numel() > 0:
            # lexsort indices and get index tensor for scatter reduction
            new_indices, inverse_indices = torch.unique(
                new_indices, return_inverse=True, dim=1
            )
            out_shape = list(new_values.shape)
            out_shape[0] = new_indices.shape[1]
            for _ in range(new_values.ndim - 1):
                inverse_indices = inverse_indices.unsqueeze(-1)
            scatter_indices = inverse_indices.expand(new_values.shape)
            # FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce
            if output_dtype in {torch.bfloat16, torch.float16}:
                new_values = new_values.to(torch.float)
                out = new_values.new_empty(out_shape)
                new_values = out.scatter_reduce_(
                    0, scatter_indices, new_values, reduce=reduce, include_self=False
                )
                new_values = new_values.to(dtype=output_dtype)
            else:
                out = new_values.new_empty(out_shape)
                new_values = out.scatter_reduce_(
                    0, scatter_indices, new_values, reduce=reduce, include_self=False
                )

        return torch.sparse_coo_tensor(
            new_indices,
            new_values,
            output_shape,
            dtype=output_dtype,
            device=mask_input.device,
        )


def _sparse_csr_segment_reduction_helper(
    op,
    mask_input: Tensor,
    dims: Tuple[int, ...],
    keepdim: bool,
    dtype: Optional[DType] = None,
) -> Tensor:
    # Currently, sparse CSR tensors are always 2D with no dense dimensions,
    # so keepdim must be True.
    # FIXME: when dense dimensions are implemented for CSR tensors
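    # Reduction strategy for the 2D CSR layout: along dim 0, values sharing a
    # column index are folded together with scatter_reduce_; along dim 1, each
    # row is reduced with torch._segment_reduce using crow_indices as offsets;
    # reducing over both dims collapses the tensor to a single element.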
    assert (
        keepdim
    ), "reduction operations on CSR tensors with keepdim=False are unsupported"
    reduce = op.__name__
    valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
    if reduce not in valid_reductions:
        raise ValueError(
            f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
        )
    device = mask_input.device
    output_dtype = dtype
    values, crow_indices, col_indices = (
        mask_input.values(),
        mask_input.crow_indices(),
        mask_input.col_indices(),
    )

    # promote dtype if specified
    if values.dtype != output_dtype:
        values = values.to(output_dtype)

    if len(dims) == 0:
        return mask_input
    if len(dims) == 1:
        if dims[0] == 0:
            new_col_indices, scatter_indices = torch.unique(
                col_indices, return_inverse=True
            )
            new_nnz = new_col_indices.shape[0]
            new_crow_indices = torch.tensor([0, new_nnz])
            new_values = values.new_empty(new_col_indices.shape)
            new_values.scatter_reduce_(
                0, scatter_indices, values, reduce, include_self=False
            )
            new_shape = [1, mask_input.size(1)]
        else:
            assert (
                dims[0] == 1
            ), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
            # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
            # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
            new_crow_indices = torch.cat(
                (
                    crow_indices.new_zeros(1),
                    torch.cumsum(torch.diff(crow_indices) != 0, 0),
                ),
                0,
            )
            new_nnz = new_crow_indices[-1]
            new_col_indices = col_indices.new_zeros(new_nnz)
            new_values = torch._segment_reduce(values, reduce, offsets=crow_indices)  # type: ignore[attr-defined]
            new_shape = [mask_input.size(0), 1]
    else:
        assert len(dims) == 2
        nnz = min(1, values.numel())
        if nnz == 1:
            op_kwargs = {"keepdim": True, "dtype": output_dtype}
            # amax and amin do not support dtype kwarg
            if reduce in ["amax", "amin"]:
                del op_kwargs["dtype"]
            new_values = op(values, 0, **op_kwargs)
        else:
            new_values = torch.empty(0, dtype=output_dtype)
        new_col_indices = col_indices.new_zeros(nnz)
        new_crow_indices = torch.tensor([0, nnz])
        new_shape = [1, nnz]

    return torch.sparse_csr_tensor(
        new_crow_indices,
        new_col_indices,
        new_values,
        new_shape,
        dtype=output_dtype,
        device=device,
    )


def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse CSR tensors."""
    # TODO: implement sparse CSR specific where operator for efficiency
    return _sparse_coo_where(
        mask.to_sparse_coo(), input.to_sparse_coo(), fill_value
    ).to_sparse_csr()


def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """torch.where with sparse inputs support.

    _where implements the following invariant:

      _where(mask, input, fill_value).to_dense(fill_value) ==
        torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))

    where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
    tensor, and `to_dense(fill_value)` is like `to_dense()` except
    that the unspecified elements are mapped to `fill_value` rather
    than to `0`.

    Returns a sparse tensor with the following features:

    - all specified elements correspond to masked-in elements that
      have the values of the input tensor. If there exists a masked-in
      element (as specified by mask) that is not specified in the
      input, in the result tensor, the corresponding element has value
      0. In the dense part of the sparse tensor, the masked-out
      elements are replaced with fill_value.

    - all unspecified elements correspond to masked-out elements.
    """
    if mask.layout == torch.strided:
        return torch.where(mask, input, fill_value)
    elif mask.layout == torch.sparse_coo:
        return _sparse_coo_where(mask, input, fill_value)
    elif mask.layout == torch.sparse_csr:
        return _sparse_csr_where(mask, input, fill_value)
    else:
        raise ValueError(
            f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}"
        )


def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor:
    """Return canonical input mask.

    A canonical input mask is defined as a boolean mask tensor whose
    shape and layout match the shape and layout of the input.

    The canonical input mask is computed from the :attr:`mask` tensor
    content to meet the following criteria:

    1. The shape of the canonical input mask is the same as the shape
       of :attr:`input` tensor. If the mask tensor has a smaller shape
       than the shape of the :attr:`input`, broadcasting rules will be
       applied. Downcasting of mask is not supported.

    2. The layout of the canonical input mask is the same as the
       layout of the :attr:`input` tensor. If the mask has a different
       layout, it will be converted to the expected layout. In the
       case of sparse COO layout, the canonical input mask will be
       coalesced.

    3. The dtype of the canonical input mask is torch.bool. If the
       mask dtype is not bool then it will be converted to bool dtype
       using `.to(dtype=bool)` method call.

    4. The elements of the canonical input mask have boolean values
       copied from the content of the :attr:`mask` tensor (after
       possible broadcasting and dtype conversion transforms). In
       general, the sparsity pattern of the sparse canonical input
       mask need not be the same as the sparsity pattern of the
       sparse :attr:`input` tensor.

    """
    if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
        raise ValueError(
            f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}"
        )

    mask = kwargs.get("mask")

    # default mask
    if mask is None:
        raise ValueError("_input_mask requires explicit mask")

    # mask shape must match with input shape
    if mask.shape != input.shape:
        if mask.ndim > input.ndim:
            raise IndexError(
                "_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)"
            )
        if mask.layout == torch.strided:
            mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool)
        elif mask.layout == torch.sparse_coo:
            mask = torch._sparse_broadcast_to(mask, input.shape)
        else:
            assert mask.layout == torch.sparse_csr
            # Broadcasting of CSR tensors is not implemented. Working
            # around by using COO layout.
            mask = torch._sparse_broadcast_to(
                mask.to_sparse(), input.shape
            ).to_sparse_csr()

    # mask layout must match with input layout
    if mask.layout != input.layout:
        if input.layout == torch.strided:
            mask = mask.to_dense()
        elif input.layout == torch.sparse_coo:
            if mask.layout == torch.strided:
                mask = mask.to_sparse(input.sparse_dim())
            else:
                mask = mask.to_sparse()
        else:
            assert input.layout == torch.sparse_csr
            mask = mask.to_sparse_csr()

    # sparse mask must be coalesced
    if mask.layout == torch.sparse_coo:
        mask = mask.coalesce()

    # mask is a boolean tensor
    mask = mask.to(dtype=torch.bool)

    return mask


def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
    """Return output mask of masked operation applied to given arguments."""
    if callable(op):
        is_reduction = op.__name__ in {
            "sum",
            "prod",
            "amax",
            "amin",
            "argmax",
            "argmin",
            "mean",
            "median",
            "norm",
            "var",
            "std",
            "logsumexp",
        }
        is_normalization = op.__name__ in {
            "softmax",
            "log_softmax",
            "softmin",
            "normalize",
            "cumsum",
            "cumprod",
        }
        if is_reduction:
            if op.__name__ == "norm":
                if args:
                    args = args[1:]  # lstrip ord argument
            dim = args[0] if args else kwargs.get("dim")
            outmask = _input_mask(input, *args, **kwargs)
            keepdim = kwargs.get("keepdim", False)
            dim_ = _canonical_dim(dim, input.ndim)
            return _any(outmask, dim_, bool(keepdim))
        elif is_normalization:
            return _input_mask(input, *args, **kwargs)
        else:
            raise ValueError(
                f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})"
            )
    else:
        raise ValueError(
            f"_output_mask expected masked operation (got {type(op).__name__} object)"
        )


def _combine_input_and_mask(
    op, input: Union[MaskedTensor, Tensor], mask, *args
) -> Tensor:
    def helper(input, mask):
        if mask is None:
            return input
        canonical_mask = _input_mask(input, mask=mask)
        if callable(op):
            fill_value = _reduction_identity(op.__name__, input, *args)
            return _where(canonical_mask, input, fill_value)
        else:
            raise ValueError(
                f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)"
            )

    class Combine(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, mask):
            """Return input with masked-out elements eliminated for the given operations."""
            ctx.save_for_backward(mask)

            if mask is not None:
                ctx.mark_non_differentiable(mask)

            return helper(input, mask)

        @staticmethod
        def backward(ctx, grad_output):
            (mask,) = ctx.saved_tensors
            grad_data = (
                grad_output.get_data() if is_masked_tensor(grad_output) else grad_output
            )
            result = as_masked_tensor(grad_data, mask)
            return result, None

    return (
        Combine.apply(input.get_data(), input.get_mask())  # type: ignore[union-attr]
        if is_masked_tensor(input)
        else helper(input, mask)
    )

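# A minimal sketch of how the reductions below use _combine_input_and_mask:
# masked-out elements are replaced by the reduction's identity value so that
# the regular (dense) reduction can then be applied.
#
#   >>> x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
#   >>> m = torch.tensor([[True, False, True], [False, True, False]])
#   >>> _combine_input_and_mask(sum, x, m)
#   tensor([[1., 0., 3.],
#           [0., 5., 0.]])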


@_apply_docstring_templates
def sum(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        if input.layout == torch.sparse_csr:
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = input.dtype
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                dtype = torch.int64
    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout == torch.strided:
        return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        return _sparse_coo_scatter_reduction_helper(
            torch.sum, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        return torch._sparse_csr_sum(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )
    else:
        raise ValueError(
            f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def prod(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        if input.layout == torch.sparse_csr:
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = input.dtype
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                dtype = torch.int64
    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(prod, input, mask)
    if mask_input.layout == torch.strided:
        # Workaround https://github.com/pytorch/pytorch/issues/56586
        result = mask_input
        result = result.to(dtype=dtype)
        for d in reversed(dim_):
            result = result.prod(dim=d, keepdim=bool(keepdim))
        return result
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors
            raise ValueError(
                "masked prod expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.prod, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            # mask is None corresponds to all-True mask. The
            # unspecified elements in the CSR tensor correspond to
            # zero values. Hence, the prod reduction result is
            # automatically zero unless all elements are specified.
            # A semi-optimal way to take this into account is to use:
            #
            #   masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...)
            #
            # but that requires implementing `all` and `nonzero`
            # support for sparse csr tensors.
            raise ValueError(
                "masked prod expects explicit mask for sparse_csr tensor input"
            )
        return torch._sparse_csr_prod(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )
    else:
        raise ValueError(
            f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def cumsum(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout == torch.strided:
        return torch.cumsum(mask_input, dim_, dtype=dtype).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked cumsum expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def cumprod(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(prod, input, mask)
    if mask_input.layout == torch.strided:
        return torch.cumprod(mask_input, dim_, dtype=dtype).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked cumprod expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def amax(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype

    mask_input = _combine_input_and_mask(amax, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    if mask_input.layout == torch.strided:
        return torch.amax(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            raise ValueError(
                "masked amax expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amax expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )

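# Example (a sketch; masked-out elements do not participate in the maximum,
# and fully masked-out slices produce an implementation-defined value):
#
#   >>> x = torch.tensor([[1, 9, 2], [3, 4, 5]])
#   >>> m = torch.tensor([[True, False, True], [True, True, True]])
#   >>> amax(x, dim=1, mask=m)
#   tensor([2, 5])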


@_apply_docstring_templates
def amin(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype

    mask_input = _combine_input_and_mask(amin, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    if mask_input.layout == torch.strided:
        return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            raise ValueError(
                "masked amin expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amin expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def argmax(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    mask_input = _combine_input_and_mask(argmax, input, mask)
    if mask_input.layout == torch.strided:
        return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked argmax expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def argmin(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    mask_input = _combine_input_and_mask(argmin, input, mask)
    if mask_input.layout == torch.strided:
        return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked argmin expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def mean(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

By definition, the identity value of a mean operation is the mean
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
mean is undefined. Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    if input.layout == torch.strided:
        if mask is None:
            # TODO: compute count analytically
            count = sum(
                torch.ones(input.shape, dtype=torch.int64, device=input.device),
                dim,
                keepdim=keepdim,
            )
            total = sum(input, dim, keepdim=keepdim, dtype=dtype)
        else:
            inmask = _input_mask(input, mask=mask)
            count = inmask.sum(dim=dim, keepdim=bool(keepdim))
            total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
        return total / count
    elif input.layout == torch.sparse_csr:
        mask_input = _combine_input_and_mask(mean, input, mask)
        dim_ = _canonical_dim(dim, mask_input.ndim)
        if mask is None:
            raise ValueError(
                "masked mean expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.mean, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)"
        )


@_apply_docstring_templates
def median(
    input: Union[Tensor, MaskedTensor],
    dim: int = -1,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
By definition, the identity value of a median operation is the median
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
median is undefined. Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    is_float = torch.is_floating_point(input)
    if not is_float:
        input = input.to(dtype=torch.float)
    mask_input = _combine_input_and_mask(median, input, mask)
    if mask_input.layout == torch.strided:
        output = torch.nanmedian(mask_input, dim_, keepdim).values
        if is_float:
            return output
        elif not is_float and not torch.isnan(output).any():
            return output.to(dtype=dtype)
        else:
            raise ValueError(
                "masked median expects no fully masked out rows if dtype is not floating point"
            )
    else:
        raise ValueError(
            f"masked median expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def logsumexp(
    input: Tensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(logsumexp, input, mask)
    if mask_input.layout == torch.strided:
        return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked logsumexp expects strided tensor (got {mask_input.layout} tensor)"
        )


# Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations
def logaddexp(
    input: Union[Tensor, MaskedTensor],
    other: Union[Tensor, MaskedTensor],
    *,
    dtype: Optional[DType] = None,
    input_mask: Optional[Tensor] = None,
    other_mask: Optional[Tensor] = None,
) -> Tensor:
    """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor

    Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
    tensor. The :attr:`input` elements are masked out according to the boolean tensor
    :attr:`input_mask` and the :attr:`other` elements are masked out according to the boolean tensor
    :attr:`other_mask`.

    The shapes of a mask tensor and the tensor to be masked
    don't need to match, but they must be :ref:`broadcastable
    <broadcasting-semantics>` and the dimensionality of the mask
    tensor must not be greater than that of the tensor to be masked.

    Args:
        input (Tensor): the input tensor
        other (Tensor): the second input tensor

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type
            of returned tensor. If specified, the output tensor is
            cast to :attr:`dtype` after the operation is
            performed. Default: None.
        input_mask (:class:`torch.Tensor`, optional): the boolean tensor
            containing the binary mask of validity of :attr:`input` tensor elements.
            Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
        other_mask (:class:`torch.Tensor`, optional): the boolean tensor
            containing the binary mask of validity of :attr:`other` tensor elements.
            Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.

    Example::

        >>> input = torch.tensor([-100.0, -200, -300])
        >>> input
        tensor([-100., -200., -300.])
        >>> other = torch.tensor([-1.0, -2, -3])
        >>> other
        tensor([-1., -2., -3.])
        >>> mask = torch.tensor([True, False, True])
        >>> mask
        tensor([ True, False, True])
        >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
        tensor([-1., -inf, -3.])"""
    if dtype is None:
        dtype = input.dtype
    if input.layout == torch.strided and other.layout == torch.strided:
        mask_input = _combine_input_and_mask(logaddexp, input, input_mask)
        mask_other = _combine_input_and_mask(logaddexp, other, other_mask)
        return torch.logaddexp(mask_input, mask_other).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)"
        )


@_apply_docstring_templates
def norm(
    input: Union[Tensor, MaskedTensor],
    ord: Optional[float] = 2.0,
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

The identity value of norm operation, which is used to start the
reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is
``{identity_ord_ninf}``.

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    mask_input = _combine_input_and_mask(norm, input, mask, ord)
    if mask_input.layout == torch.strided:
        dim_ = _canonical_dim(dim, input.ndim)
        return torch.linalg.vector_norm(
            mask_input, ord, dim_, bool(keepdim), dtype=dtype
        )
    else:
        raise ValueError(
            f"masked norm expects strided tensor (got {mask_input.layout} tensor)"
        )


def _std_var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims,
    unbiased: Optional[bool],
    *,
    correction_opt: Optional[Union[int, float]],
    keepdim: Optional[bool],
    dtype: Optional[DType],
    mask: Optional[Tensor],
    take_sqrt: Optional[bool],
) -> Tensor:
    assert (
        unbiased is None or correction_opt is None
    ), "Only one of unbiased and correction may be given"
    correction = 1.0
    if unbiased is not None:
        correction = 1.0 if unbiased else 0.0
    if correction_opt is not None:
        correction = sym_float(correction_opt)

    if dtype is None:
        dtype = input.dtype
        if not (dtype.is_floating_point or dtype.is_complex):
            dtype = torch.float32
    compute_dtype = dtype
    if not (compute_dtype.is_floating_point or compute_dtype.is_complex):
        compute_dtype = torch.float32
    if input.layout == torch.strided:
        if mask is None:
            # TODO: compute count analytically
            count = sum(
                torch.ones(input.shape, dtype=torch.int64, device=input.device),
                dim,
                keepdim=True,
            )
            sample_total = sum(input, dim, keepdim=True, dtype=dtype)
        else:
            inmask = _input_mask(input, mask=mask)
            count = inmask.sum(dim=dim, keepdim=True)
            sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
        # TODO: replace torch.subtract/divide/square/maximum with
        # masked subtract/divide/square/maximum when these will be
        # available.
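        # The masked sample variance computed below is
        #   sum(|x - mean|^2) / max(count - correction, 0)
        # where sum, mean and count take only masked-in elements into account.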
        sample_mean = torch.divide(sample_total, count)
        x = torch.subtract(input, sample_mean)
        if mask is None:
            total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
        else:
            total = sum(
                x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask  # type: ignore[possibly-undefined]
            )
        if not keepdim:
            count = count.reshape(total.shape)
        if correction != 0:
            real_dtype = (
                corresponding_real_dtype(compute_dtype)
                if compute_dtype.is_complex
                else compute_dtype
            )
            count = count.to(real_dtype)
            count = torch.subtract(count, correction)
            count = torch.maximum(count, count.new_zeros([]))
        output = torch.divide(total, count).to(dtype=dtype)
        if take_sqrt:
            output = torch.sqrt(output)
        return output
    else:
        raise ValueError(
            f"masked std/var expects strided tensor (got {input.layout} tensor)"
        )


@_apply_docstring_templates
def var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[Union[int, float]] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample variance operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    return _std_var(
        input=input,
        dim=dim,
        unbiased=unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=False,
    )


@_apply_docstring_templates
def std(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[int] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample standard deviation operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    return _std_var(
        input=input,
        dim=dim,
        unbiased=unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=True,
    )


@_apply_docstring_templates
def softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(amax, input, mask)
    if mask_input.layout == torch.strided:
        return torch.nn.functional.softmax(mask_input, dim_, dtype=dtype)
    else:
        raise ValueError(
            f"masked softmax expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def log_softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(amax, input, mask)
    if mask_input.layout == torch.strided:
        return torch.nn.functional.log_softmax(mask_input, dim_, dtype=dtype)
    else:
        raise ValueError(
            f"masked log_softmax expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def softmin(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(amin, input, mask)
    if mask_input.layout == torch.strided:
        return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype)
    else:
        raise ValueError(
            f"masked softmin expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def normalize(
    input: Union[Tensor, MaskedTensor],
    ord: float,
    dim: int,
    *,
    eps: float = 1e-12,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # TODO: eliminate mask_input as unnecessary when using masked divide.
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout == torch.strided:
        nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask)
        # TODO: replace torch.maximum with masked maximum when available.
        denom = torch.maximum(nrm_, nrm_.new_full([], eps))
        # TODO: replace torch.divide with masked divide when available.
        return torch.divide(mask_input, denom)
    else:
        raise ValueError(
            f"masked normalize expects strided tensor (got {mask_input.layout} tensor)"
        )
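

# A short usage sketch of the masked ops defined above (illustrative only;
# the values at masked-out positions of the output are unspecified):
#
#   >>> x = torch.tensor([1.0, 2.0, 3.0])
#   >>> m = torch.tensor([True, False, True])
#   >>> softmax(x, 0, mask=m)         # softmax over the two masked-in entries
#   >>> normalize(x, 2.0, 0, mask=m)  # x / max(norm of masked-in entries, eps)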