1# mypy: ignore-errors 2 3import collections 4import collections.abc 5import math 6import operator 7import unittest 8from dataclasses import asdict, dataclass 9from enum import Enum 10from functools import partial 11from itertools import product 12from typing import Any, Callable, Iterable, List, Optional, Tuple, Union 13 14import torch 15from torch.testing import make_tensor 16from torch.testing._internal.common_device_type import ( 17 skipCPUIfNoFFT, 18 tol, 19 toleranceOverride, 20) 21from torch.testing._internal.common_dtype import ( 22 _dispatch_dtypes, 23 floating_and_complex_types, 24 floating_and_complex_types_and, 25 floating_types, 26 get_all_dtypes, 27) 28from torch.testing._internal.common_utils import ( 29 is_iterable_of_tensors, 30 noncontiguous_like, 31 OPINFO_SAMPLE_INPUT_INDEX, 32 TEST_WITH_ROCM, 33 torch_to_numpy_dtype_dict, 34 TrackedInputIter, 35) 36from torch.testing._internal.opinfo import utils 37from torchgen.utils import dataclass_repr 38 39 40# Reasonable testing sizes for dimensions 41L = 20 42M = 10 43S = 5 44XS = 3 45 46# Unique value to distinguish default from anything else 47_NOTHING = object() 48 49 50# Extension of getattr to support qualified names 51# e.g. _getattr_qual(torch, 'linalg.norm') -> torch.linalg.norm 52def _getattr_qual(obj, name, default=_NOTHING): 53 try: 54 for path in name.split("."): 55 obj = getattr(obj, path) 56 return obj 57 except AttributeError: 58 if default is not _NOTHING: 59 return default 60 else: 61 raise 62 63 64class DecorateInfo: 65 """Describes which test, or type of tests, should be wrapped in the given 66 decorators when testing an operator. Any test that matches all provided 67 arguments will be decorated. The decorators will only be applied if the 68 active_if argument is True.""" 69 70 __slots__ = [ 71 "decorators", 72 "cls_name", 73 "test_name", 74 "device_type", 75 "dtypes", 76 "active_if", 77 ] 78 79 def __init__( 80 self, 81 decorators, 82 cls_name=None, 83 test_name=None, 84 *, 85 device_type=None, 86 dtypes=None, 87 active_if=True, 88 ): 89 self.decorators = ( 90 list(decorators) 91 if isinstance(decorators, collections.abc.Sequence) 92 else [decorators] 93 ) 94 self.cls_name = cls_name 95 self.test_name = test_name 96 self.device_type = device_type 97 self.dtypes = dtypes 98 self.active_if = active_if 99 100 # Validate dtypes 101 if self.dtypes is not None: 102 for dtype in self.dtypes: 103 assert isinstance(dtype, torch.dtype) 104 105 def is_active(self, cls_name, test_name, device_type, dtype, param_kwargs): 106 return ( 107 self.active_if 108 and (self.cls_name is None or self.cls_name == cls_name) 109 and (self.test_name is None or self.test_name == test_name) 110 and (self.device_type is None or self.device_type == device_type) 111 and (self.dtypes is None or dtype in self.dtypes) 112 # Support callables over kwargs to determine if the decorator is active. 113 and ( 114 self.active_if(param_kwargs) 115 if isinstance(self.active_if, Callable) 116 else self.active_if 117 ) 118 ) 119 120 121# FIXME 122# Note: historically the 'input' kwarg had to be a Tensor or TensorList, but we are trying 123# to support scalar inputs, too. Some tests still depend on 'input' being a Tensor 124# or TensorList, however. 125class SampleInput: 126 """Represents sample inputs to a function.""" 127 128 __slots__ = [ 129 "input", 130 "args", 131 "kwargs", 132 "output_process_fn_grad", 133 "broadcasts_input", 134 "name", 135 ] 136 137 def __init__( 138 self, 139 input, 140 *var_args, 141 args=None, 142 kwargs=None, 143 output_process_fn_grad=None, 144 broadcasts_input=None, 145 name=None, 146 **var_kwargs, 147 ): 148 # input is the first input to the op and is typically either a Tensor or TensorList (Sequence[Tensor]). 149 # This follows the typical pattern where for Tensor inputs op(t, ...) = t.op(...). 150 self.input = input 151 152 # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as 153 # SampleInput(input, *args, **kwargs) but not to mix the two forms 154 if args is not None or kwargs is not None: 155 assert ( 156 not var_args and not var_kwargs 157 ), """ 158A SampleInput can be constructed "naturally" with *args and **kwargs or by 159explicitly setting the "args" and "kwargs" parameters, but the two 160methods of construction cannot be mixed!""" 161 elif len(var_args) or len(var_kwargs): 162 assert ( 163 output_process_fn_grad is None 164 and broadcasts_input is None 165 and name is None 166 ), """ 167A SampleInput constructed "naturally" with *args and **kwargs 168cannot specify additional metadata in keyword arguments""" 169 170 self.args = args if args is not None else var_args 171 assert isinstance(self.args, tuple) 172 self.kwargs = kwargs if kwargs is not None else var_kwargs 173 assert isinstance(self.kwargs, dict) 174 175 self.output_process_fn_grad = ( 176 output_process_fn_grad 177 if output_process_fn_grad is not None 178 else lambda x: x 179 ) 180 self.name = name if name is not None else "" 181 182 # Specifies if `self.input` is broadcasted or not, 183 # given that the operator supports broadcasting. 184 # This field is used to verify the behavior for inplace variant. 185 # 186 # If a SampleInput is marked with `broadcasts_input=True`, 187 # it is verified that we get a `RuntimeError` with this sample, 188 # and inplace variant. Also inplace grad{grad} tests are skipped, 189 # for such inputs (as they will error out otherwise). 190 self.broadcasts_input = ( 191 broadcasts_input if broadcasts_input is not None else False 192 ) 193 194 def with_metadata( 195 self, *, output_process_fn_grad=None, broadcasts_input=None, name=None 196 ): 197 if output_process_fn_grad is not None: 198 self.output_process_fn_grad = output_process_fn_grad 199 if broadcasts_input is not None: 200 self.broadcasts_input = broadcasts_input 201 if name is not None: 202 self.name = name 203 return self 204 205 def _repr_helper(self, formatter): 206 # Helper function to return the details of the SampleInput as `str` 207 # It consolidates all the fields of SampleInput and allows, 208 # formatting the fields like `input`, `args`, etc with `formatter` 209 # callable to customize the representation. 210 # Look at `summary` method for example. 211 arguments = [ 212 f"input={formatter(self.input)}", 213 f"args={formatter(self.args)}", 214 f"kwargs={formatter(self.kwargs)}", 215 f"broadcasts_input={self.broadcasts_input}", 216 f"name={repr(self.name)}", 217 ] 218 219 return f'SampleInput({", ".join(a for a in arguments if a is not None)})' 220 221 def __repr__(self): 222 return self._repr_helper(lambda x: x) 223 224 def summary(self): 225 # Returns the SampleInput details in a more 226 # friendly format. 227 # It formats `Tensor` and `TensorList` 228 # in a more condensed representation. 229 def formatter(arg): 230 # Format any instance of `Tensor` (standalone, in list, or in dict) 231 # by Tensor[TensorShape] 232 # Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4] 233 if isinstance(arg, torch.Tensor): 234 shape = str(tuple(arg.shape)) 235 dtype = str(arg.dtype) 236 device = str(arg.device) 237 contiguity_suffix = "" 238 # NB: sparse CSR tensors annoyingly return is_sparse=False 239 is_sparse = arg.is_sparse or arg.layout == torch.sparse_csr 240 if not is_sparse and not arg.is_contiguous(): 241 contiguity_suffix = ", contiguous=False" 242 return f'Tensor[size={shape}, device="{device}", dtype={dtype}{contiguity_suffix}]' 243 elif isinstance(arg, dict): 244 return {k: formatter(v) for k, v in arg.items()} 245 elif is_iterable_of_tensors(arg): 246 return "TensorList[" + ", ".join(map(formatter, arg)) + "]" 247 elif isinstance(arg, (list, tuple)): # Handle list, tuple 248 return "(" + ",".join(map(formatter, arg)) + ")" 249 250 return repr(arg) 251 252 return self._repr_helper(formatter) 253 254 # Applies the transform f(t) -> t to each tensor and dtype in the SampleInput 255 def transform(self, f): 256 def tt(t): 257 def _tt(t): 258 with torch.no_grad(): 259 return f(t) 260 261 if isinstance(t, torch.Tensor): 262 return _tt(t) 263 elif isinstance(t, torch.dtype): 264 return _tt(t) 265 elif isinstance(t, list): 266 return list(map(tt, t)) 267 elif isinstance(t, tuple): 268 return tuple(map(tt, t)) 269 elif isinstance(t, dict): 270 return {k: tt(v) for k, v in t.items()} 271 else: 272 return t 273 274 sample_tt_input, tt_args, tt_kwargs = ( 275 tt(self.input), 276 tt(self.args), 277 tt(self.kwargs), 278 ) 279 280 # Note the transformed SampleInput assumes metadata like output_process_fn_grad is still valid! 281 return SampleInput( 282 sample_tt_input, 283 args=tt_args, 284 kwargs=tt_kwargs, 285 output_process_fn_grad=self.output_process_fn_grad, 286 broadcasts_input=self.broadcasts_input, 287 name=self.name + "_transformed", 288 ) 289 290 # Returns the NumPy version of the sample input object in the form of a tuple: (input, args, kwargs) 291 # Converts tensors to ndarrays by calling .detach().cpu().numpy() on them 292 # Converts dtypes by remapping them using torch_to_numpy_dtype_dict 293 def numpy(self): 294 def to_numpy(t): 295 if isinstance(t, torch.Tensor): 296 if t.dtype is torch.bfloat16: 297 return t.detach().cpu().to(torch.float32).numpy() 298 if t.dtype is torch.chalf: 299 return t.detach().cpu().to(torch.cfloat).numpy() 300 return t.detach().cpu().numpy() 301 elif isinstance(t, torch.dtype): 302 return torch_to_numpy_dtype_dict[t] 303 304 return t 305 306 return self.transform(to_numpy) 307 308 def noncontiguous(self): 309 def to_noncontiguous(t): 310 if isinstance(t, torch.Tensor): 311 return noncontiguous_like(t) 312 elif isinstance(t, torch.dtype): 313 return t 314 315 return t 316 317 return self.transform(to_noncontiguous) 318 319 320NumericsFilter = collections.namedtuple("NumericsFilter", ["condition", "safe_val"]) 321 322 323class ErrorInput: 324 """ 325 A SampleInput that will cause the operation to throw an error plus information 326 about the resulting error. 327 """ 328 329 __slots__ = ["sample_input", "error_type", "error_regex"] 330 331 def __init__(self, sample_input, *, error_type=RuntimeError, error_regex): 332 self.sample_input = sample_input 333 self.error_type = error_type 334 self.error_regex = error_regex 335 336 337class AliasInfo: 338 """Class holds alias information. For example, torch.abs -> 339 torch.absolute, torch.Tensor.absolute, torch.Tensor.absolute_ 340 """ 341 342 def __init__(self, alias_name): 343 self.name = alias_name 344 self.op = _getattr_qual(torch, alias_name) 345 self.method_variant = getattr(torch.Tensor, alias_name, None) 346 self.inplace_variant = getattr(torch.Tensor, alias_name + "_", None) 347 348 def __call__(self, *args, **kwargs): 349 return self.op(*args, **kwargs) 350 351 352# Note [OpInfos] 353# ~~~~~~~~~~~~~~ 354# 355# The majority of this note was written shortly after the PyTorch 1.9 release. 356# If you notice it's out-of-date or think it could be improved then please 357# file an issue. 358# 359# See also: the OpInfo tracker (https://github.com/pytorch/pytorch/issues/54261) 360# See also: "Writing Test Templates" in common_device_type.py to learn how to 361# parametrize a test template using OpInfos. 362# See also: PyTorch's GitHub wiki on running and writing tests 363# https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests 364# See also: ModuleInfos, OpInfo's sister class, defined in common_modules.py 365# 366# An OpInfo is a collection of metadata related to a PyTorch operator. This 367# metadata is used to generate tests that validate properties of the operator, 368# like if it implements the correct gradient formula. 369# 370# WHY OPINFOS? 371# ~~~~~~~~~~~~ 372# 373# OpInfos are principally intended to do three things: 374# 375# 1) to allow systematic testing over all PyTorch's operators 376# 2) to simplify operating testing by autogenerating many tests 377# 3) to allow systems (like autograd, torchscript, fx, nnc...) to test 378# against every PyTorch operator 379# 380# All these goals are still a work in progress. Not every operator has an 381# OpInfo, and some operator tests that could be automatically generated 382# still have to be written manually. 383# 384# It's helpful to understand that OpInfos are both about test simplification and 385# modularity. PyTorch is a complicated framework with many interrelated systems, 386# too many for any one person to keep track of. An OpInfo can be thought of as the 387# interface between an operator implementer and those other systems. Instead of 388# requiring the implementer of torch.foo understand how to test its forward 389# mode AD or NNC support that's typically handled automatically just by 390# defining an OpInfo. 391# 392# It's often surprising to OpInfo writers that just implementing an OpInfo 393# typically can't verify an operator is actually implemented correctly: 394# 395# "If an OpInfo doesn't validate my op works as expected, what's the point 396# of it?" 397# 398# But the point of is the above. OpInfos are intended to let you focus on testing 399# the operator logic you're familiar with instead of having to write tests for 400# how the operator interacts with each of PyTorch's many systems. 401# 402# And, OK, it turns out that SOMETIMES just writing an OpInfo DOES 403# validate your op works as expected, but that's only in special 404# cases. See below for details. 405# 406# WHAT'S AN OPINFO? 407# ~~~~~~~~~~~~~~~~~ 408# 409# So what is an OpInfo? It's a Python class that describes an operator's properties, 410# like which dtypes it supports on the CPU and whether it has any aliases. 411# These properties can be divided into three categories: 412# 413# 1) Metadata describing the operator, like the operator's name and if it 414# "supports" the out kwarg. 415# 2) Test directives, like "skips" that tell the test suite to skip some 416# tests. 417# 3) A "sample inputs" function that generates valid inputs for the operator. 418# 419# OpInfo attributes are described in more detail below. 420# 421# THE SAMPLE INPUTS FUNCTION 422# ~~~~~~~~~~~~~~~~~~~~~~~~~~ 423# 424# The "sample inputs" function merits special elaboration. This function is 425# crucial to testing with OpInfos. A typical OpInfo test has to treat the operator 426# as a black box. There's no structure for the test to understand or exploit. 427# Without "sample inputs" it wouldn't even know how to call the OpInfo's 428# operator. The sample input function saves the day by providing different 429# "SampleInputs" that can be used to call the operator. A sample input 430# function should have the following signature: 431# 432# def sample_inputs_foo(op_info, device, dtype, requires_grad, **kwargs): 433# 434# And should return an iterable of SampleInputs (see the class description 435# above). Each SampleInput defines an "input", "args", "kwargs", an 436# "output_process_fn_grad" function, the "broadcasts_input" bool and a 437# "name". 438# 439# All the "sample_inputs" functions are invoked within a `torch.no_grad()` 440# environment for efficiency and correctness. As such remember to set the 441# "requires_grad" flag on the inputs **after** performing any transformations 442# on them. 443# 444# The "input" is the first argument to the operator, or the tensor that 445# the method or inplace variants of the operator should be called on, and 446# should be on the requested device, of the requested dtype, and its 447# requires_grad attribute should be set to the requires_grad argument. 448# 449# "args" should contain positional arguments, and "kwargs" keyword arguments. 450# 451# "output_process_fn_grad" has an interesting name. It's a function that maps 452# the operator's output (when given the input, args, and kwargs) to the 453# portion of the output to gradcheck. For example, consider an operator 454# like torch.linalg.slogdet 455# (https://pytorch.org/docs/main/generated/torch.linalg.slogdet.html). 456# This operator returns a tuple of two tensors, but the first tensor 457# cannot be backwarded through. Its "output_process_fn_grad" filters 458# this output tuple to just the second argument, which we can call backward 459# on. Functions that produce a single tensor can ignore this argument. 460# 461# "broadcasts_input" is a bool indicated if the SampleInput causes the operator 462# to broadcast the "input" argument. This is important for tests to understand 463# because inplace variants of operations throw a runtime error if they 464# would broadcast their input arguments, so tests that work with inplace 465# variants filter SampleInputs that broadcast their input. 466# 467# "name" is a string that's just used for debugging. It appears when printing 468# the SampleInput. 469# 470# Sample inputs are designed to be used with many tests, some 471# that are very time consuming, so they should be a small 472# set with small tensors. An elaborated set of sample inputs 473# can be specified using the "reference_inputs_func" attribute. 474# The "reference inputs" for an operation are an extended 475# set of sample inputs that can more exhausively test an 476# operator. They are used by only a few tests that are careful 477# not to take too long to run. Adding reference inputs 478# is highly encouraged! 479# 480# THE (OPTIONAL) ERROR INPUTS FUNCTION 481# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 482# 483# OpInfos may optionally specify "error inputs" through an error function. If 484# specified test_errors in test_ops.py will call the op with these inputs 485# and validate that the desired error is thrown. 486# 487# Error inputs automate a common testing pattern where multiple inputs are 488# passed to an operation and the errors they thrown are reviewed. Tests 489# written in this style should be ported to the new OpInfo pattern. 490# 491# Error inputs are specified using the ErrorInputs class, which contains 492# a SampleInput (see above) and data about the expected error. 493# 494# OPINFO FILE ORGANIZATION 495# ~~~~~~~~~~~~~~~~~~~~~~~~ 496# 497# All OpInfos are currently defined in this file. Most OpInfo tests are defined 498# in test_ops.py, but some system-specific tests are defined in those 499# systems' test files, and subclass-specific tests are defined in the test 500# file that corresponds to that subclass (see the below). 501# Expect a reorganization in the future. 502# 503# WHAT'S TESTED? 504# ~~~~~~~~~~~~~~ 505# 506# Every OpInfo in the op_db sequence has the following properties validated in 507# test_ops.py: 508# 509# - that its supported dtypes are specified correctly 510# - that the operation produces the same results when called with noncontiguous inputs 511# - that it supports the out= argument properly (if it allows out=), 512# see https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch 513# - that it works with the conjugate view bit properly 514# - that its function, method, and inplace variants perform the same operation 515# (that is, that torch.add, torch.Tensor.add, and torch.Tensor.add_ all 516# do the same thing). 517# - that its inplace variant preserves the input's storage 518# - that its gradient formula is implemented correctly, and that it supports 519# gradgrad and complex grad and gradgrad and forward mode AD properly for 520# the op's function and inplace variants (method variants are skipped 521# to reduce test time). 522# - that the operation performs the same operation when traced or scripted 523# using the jit 524# - that the operation is autodifferentiated by the jit as expected 525# - that the operator's aliases, if any, perform the same operation and that 526# the jit understands the alias 527# - that the operator throws the correct errors (if error_inputs is defined) 528# - that the operator produces the same results as a NumPy reference (if ref is defined) 529# - that the operator produces the same results as a NumPy reference on an extended 530# set of "reference inputs" (if both ref and reference_inputs_func are defined) 531# (NOTE: elementwise unary and elementwise binary OpInfos do this even if only 532# ref is defined, because they effectively autogenerate reference inputs) 533# - that the operator works on different CUDA devices 534# 535# Additional OpInfo tests are in test_jit_fuser_te.py, test_fx_experimental.py, 536# and test_fx.py. These tests validate that operators work with NNC and FX 537# as expected. 538# 539# For performance, some of the above tests may only run on the first 540# SampleInput returned by an OpInfo's sample input function. 541# 542# In addition to these tests, some subclasses (discussed in the next section) 543# define additional tests. 544# 545# Critically, as mentioned above, what's not necessarily tested is that the operator 546# works as expected. When implementing an OpInfo an engineer must still 547# typically write one or more tests validating the operator's behavior. 548# The exception to this is if reference testing is sufficient, or if 549# the operation belongs to an OpInfo subclass that has more exhaustive 550# operator testing. Elementwise unary and elementwise binary operators, 551# in particular, usually don't require additional testing beyond 552# writing an Opinfo. 553# 554# 555# OPINFO (SUB)CLASSES 556# ~~~~~~~~~~~~~~~~~~~ 557# 558# In addition to the OpInfo base class there are several specialized OpInfo 559# subclasses. For example, the UnaryUfuncInfo subclass is used for 560# unary elementwise operations. These operations have a common structure 561# that test_unary_ufuncs.py exploits with additional automated testing. 562# The automated testing in test_unary_ufuncs.py is so thorough, comparing 563# the operator to a NumPy reference function on a plethora of values, that 564# just implementing an OpInfo for a unary elementwise operation is often 565# sufficient testing. 566# 567# The ForeachFuncInfo is another OpInfo subclass that is hyper-specialized to a 568# very unique class of operations. These OpInfos aren't included in the 569# op_db sequence and have their own tests. 570# 571# Other OpInfo subclasses, like SpectralFuncInfo, are just for convenience 572# when writing OpInfos. 573# 574# TESTING A NEW OPERATOR 575# ~~~~~~~~~~~~~~~~~~~~~~ 576# 577# If you're adding a new operator to any of the following namespaces: 578# - torch 579# - torch.fft 580# - torch.linalg, 581# - torch.special 582# - torch.nn.functional 583# then you should typically add an OpInfo for it. 584# 585# As mentioned a couple times above, implementing an OpInfo is not 586# usually sufficient testing (unless the operator is a unary or binary elementwise 587# operator). The OpInfo will only test the properties described in the 588# "WHAT'S TESTED" section. It DOES NOT necessarily verify that the operator is 589# implemented correctly. 590# 591# TIPS FOR WRITING AN OPINFO AND OPINFO TESTS 592# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 593# 594# Writing an OpInfo can be a little daunting. Since the point of an OpInfo is to 595# be consumed by a variety of systems it can be hard to understand how to 596# deal with test failures or how to set the OpInfo metadata properly. 597# 598# Before adding an OpInfo it helps to look at other OpInfos. A sample inputs 599# function must be defined, and the operator's dtypes must be specified. 600# Once that's done you should run the operator's tests in test_ops.py 601# (these can be filtered using the "-k" argument in pytest). Tests that 602# fail should provide an error message that describes what to change about 603# your OpInfo. You don't need to worry about changing an OpInfo's default 604# values unless a test yells at you. 605# 606# Similarly, if you're writing a test that consumes OpInfos then it's critical 607# your test provides a clear error message describing what to do when it 608# fails. You should not assume the OpInfo implementer is familiar with your 609# system. 610# 611# If you see a confusing error message while developing an OpInfo then please 612# file an issue describing what happened. 613# 614# This trial-and-error approach to writing an OpInfo can be frustrating, 615# but it's probably necessary as long as OpInfos don't require 616# learning about all the systems that consume them. One thing that can help 617# is the get_supported_dtypes() function defined in utils.py. This 618# function can be used to programmatically specify the dtypes an operator 619# supports, and is especially useful if writing an OpInfo on a machine 620# without a CUDA device. See its documentation for more details. 621# 622# THE FUTURE OF OPINFOS AND OPINFO TESTING 623# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 624# 625# In the future we expect OpInfo coverage to improve and cover 626# the great majority of PyTorch's (public) operators. 627# 628 629 630# Classes and methods for the operator database 631@dataclass 632class OpInfo: 633 """Operator information and helper functions for acquiring it.""" 634 635 # the string name of the function 636 name: str 637 638 # An optional reference function that accepts ndarrays (AKA "NumPy arrays"). 639 # If given, the op will be compared with its reference on each of its sample inputs. 640 ref: Optional[Callable] = None 641 642 # the following metadata describes the operator, its variants, and its aliases, if any 643 644 # iterable of aliases, e.g. ("absolute",) for torch.abs 645 aliases: Iterable = None 646 647 # additional string to include in the test name 648 # this is useful when an op needs multiple OpInfos, 649 # like divide does, often because it's really several 650 # different ops behind the scenes 651 variant_test_name: str = "" 652 653 # the function variant of the operation, populated as torch.<name> if None 654 op: Callable = None 655 656 # allows the method variant of this operation to be specified as follows: 657 # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name 658 # - if None, then the OpInfo explicitly specifies is has no associated method 659 # - if a Callable, then that callable should be the method associated with this operation 660 method_variant: Callable = _NOTHING 661 662 # allows the inplace variant of this operation to be specified as follows: 663 # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name 664 # - if None, then the OpInfo explicitly specifies is has no associated inplace variant 665 # - if a Callable, then that callable should be the inplace variant associated with this operation 666 inplace_variant: Callable = _NOTHING 667 668 # allows the operator variant of this operation to be specified as follows: 669 # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name 670 # - if None, then the OpInfo explicitly specifies is has no associated operator 671 # - if a Callable, then that callable should be the operator associated with this operation 672 operator_variant: Callable = _NOTHING 673 674 # allows the inplace operator variant of this operation to be specified as follows: 675 # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name 676 # - if None, then the OpInfo explicitly specifies is has no associated inplace operator 677 # - if a Callable, then that callable should be the inplace operator associated with this operation 678 inplace_operator_variant: Callable = _NOTHING 679 680 # the following metadata are test directives for skipping or modifying tests 681 682 # information about which tests to skip 683 skips: Tuple = () 684 685 # decorators to apply to generated tests 686 decorators: Tuple = () 687 688 # the following are pointers to functions to generate certain classes of inputs 689 690 # function to generate sample inputs with strided layouts 691 sample_inputs_func: Callable = None 692 693 # function to generate a more thorough set of samples inputs with strided layouts 694 reference_inputs_func: Callable = None 695 696 # function to generate inputs that will throw errors 697 error_inputs_func: Callable = None 698 699 # function to generate sparse (coo, csr, csc, bsr, bsc) inputs that will throw errors 700 error_inputs_sparse_func: Callable = None 701 702 # function to generate sample inputs with sparse coo layouts 703 sample_inputs_sparse_coo_func: Callable = None 704 705 # function to generate sample inputs with sparse csr layouts 706 sample_inputs_sparse_csr_func: Callable = None 707 708 # function to generate sample inputs with sparse csc layouts 709 sample_inputs_sparse_csc_func: Callable = None 710 711 # function to generate sample inputs with sparse bsr layouts 712 sample_inputs_sparse_bsr_func: Callable = None 713 714 # function to generate sample inputs with sparse bsc layouts 715 sample_inputs_sparse_bsc_func: Callable = None 716 717 # the following metadata relates to dtype support and is tested for correctness in test_ops.py 718 719 # dtypes this function works with on the CPU, 720 # inherited by other device types that don't specify their own dtypes 721 dtypes: _dispatch_dtypes = None 722 723 # the following dtypesIf... options override the dtypes value on their respective device types 724 725 # dtypes this function is expected to work with on CUDA 726 dtypesIfCUDA: _dispatch_dtypes = None 727 728 # dtypes this function is expected to work with on ROCM 729 dtypesIfROCM: _dispatch_dtypes = None 730 731 dtypesIfHpu: _dispatch_dtypes = None 732 733 # dtypes this function is expected to work with on XPU 734 dtypesIfXPU: _dispatch_dtypes = None 735 736 # backward dtypes this function is expected to work with 737 backward_dtypes: _dispatch_dtypes = None 738 739 # backward dtypes this function is expected to work with on CUDA 740 backward_dtypesIfCUDA: _dispatch_dtypes = None 741 742 # backward dtypes this function is expected to work with on ROCM 743 backward_dtypesIfROCM: _dispatch_dtypes = None 744 745 backward_dtypesIfHpu: _dispatch_dtypes = None 746 747 # the following metadata describes the operators out= support 748 749 # whether the op supports the out kwarg 750 # defaults to True, if the op does not allow the out kwarg or 751 # supports it incorrectly then test_out in test_ops.py should fail 752 supports_out: bool = True 753 754 # the following metadata relates to autograd support 755 # whether the operation supports backward mode AD 756 # if true, gradient correctness is tested in test_ops.py 757 # using the op's sample inputs 758 supports_autograd: bool = True 759 760 # whether the op supports second order gradients 761 # if true, gradgrad correctness is tested in test_ops.py 762 # defaults to support_autograd's value 763 # TODO: rename this to supports_bwgrad_bwgrad to be consistent with below 764 supports_gradgrad: bool = None 765 766 # whether the ops supports second order gradients via 767 # forward-over-reverse. If True, forward-over-reverse gradgrad correctness 768 # is tested. If False, test that forward grad is not implemented. 769 # Defaults to False. 770 supports_fwgrad_bwgrad: bool = False 771 772 # whether the operation supports inplace autograd 773 # if true, tested in test_ops.py 774 # defaults to supports_autograd's value 775 supports_inplace_autograd: bool = None 776 777 # Whether the operation support forward mode AD 778 # If the value is True, we check that the gradients are correct 779 # If the value is False, we test that forward grad is not implemented 780 supports_forward_ad: bool = False 781 782 # Whether the operation has a varargs variant 783 # (e.g. functions like ones, zeros, methods like view, permute) 784 supports_varargs: bool = False 785 786 # Whether the forward operation avoids materializing COW tensor inputs 787 supports_cow_input_no_materialize_forward: bool = True 788 789 # Whether the backward operation avoids materializing COW tensor inputs 790 supports_cow_input_no_materialize_backward: bool = True 791 792 # Whether to skip the backward part of the COW tensor input test 793 skip_cow_input_backward: bool = False 794 795 # If `supports_cow_input_no_materialize_forward == True`, this list contains 796 # the arg indices or kwarg names of inputs that are expected to materialize 797 allow_cow_input_materialize_forward: List[Union[int, str]] = None 798 799 # If `supports_cow_input_no_materialize_backward == True`, this list contains 800 # the arg indices or kwarg names of inputs that are expected to materialize 801 allow_cow_input_materialize_backward: List[Union[int, str]] = None 802 803 # wrapper function for gradcheck 804 gradcheck_wrapper: Callable = lambda op, *args, **kwargs: op(*args, **kwargs) 805 806 # whether to check batched grad when doing gradcheck 807 # defaults to support_autograd's value 808 check_batched_grad: bool = None 809 810 # whether to check batched grad grad when doing gradgradcheck 811 # default's to support_gradgrad's value 812 check_batched_gradgrad: bool = None 813 814 # whether to check batched forward grad when doing gradcheck 815 # defaults to the value of `supports_forward_ad` 816 check_batched_forward_grad: bool = None 817 818 # whether to check batched forward grad when doing gradcheck 819 # defaults to the value of `check_batched_forward_grad` 820 check_inplace_batched_forward_grad: bool = None 821 822 # tolerance for nondeterminism while performing gradcheck 823 gradcheck_nondet_tol: float = 0.0 824 825 # Whether to use the fast implmentation for gradcheck/gradgradcheck. 826 # When set to None, defers to the default value provided by the wrapper 827 # function around gradcheck (testing._internal.common_utils.gradcheck) 828 gradcheck_fast_mode: bool = None 829 830 # the following metadata relates to JIT support and is tested for correctness in test_ops.py 831 832 # name of the corresponding aten:: operator 833 aten_name: str = None 834 835 # if this is a composite implicit autograd op, the decomposed op 836 decomp_aten_name: Optional[str] = None 837 838 # name of the corresponding aten:: operator for backwards 839 aten_backward_name: Optional[str] = None 840 841 # if a op's aten::node is expected to be symbolically autodiffed 842 assert_autodiffed: bool = False 843 844 # a list of strings with node names that are expected to be in a 845 # DifferentiableGraph when autodiffed. Ex: ['aten::add', 'aten::mm'], 846 # default is populated to be ['aten::(name of Python operator)'] 847 autodiff_nonfusible_nodes: List[str] = None 848 849 # a list of strings with node names that are expected to be in FusionGroups 850 # inside of DifferentiableGraphs when this operation is autodiffed. 851 # Ex: ['aten::add', 'aten::mm'], defaults to an empty list 852 # Note: currently no ops use fusible nodes 853 autodiff_fusible_nodes: List[str] = None 854 855 # the following metadata relates to sparse support and is used in test_sparse.py 856 857 # whether the op supports sparse coo inputs, defaults to False 858 # TODO: rename supports_sparse to supports_sparse_coo 859 supports_sparse: bool = None 860 861 # only run tracing tests 862 supports_scripting: bool = True 863 864 # if the operator can be traced 865 supports_tracing: bool = True 866 867 # the following metadata relates to sparse compressed support and 868 # is used in test_sparse_csr.py and test_sparse.py 869 870 # whether the op supports sparse csr inputs, defaults to False 871 supports_sparse_csr: bool = None 872 # whether the op supports sparse csc inputs, defaults to False 873 supports_sparse_csc: bool = None 874 # whether the op supports sparse bsr inputs, defaults to False 875 supports_sparse_bsr: bool = None 876 # whether the op supports sparse bsc inputs, defaults to False 877 supports_sparse_bsc: bool = None 878 # whether the op supports nested jagged inputs, defaults to False 879 supports_njt: bool = None 880 881 # whether the op promotes integer inputs to float 882 promotes_int_to_float: bool = False 883 884 # the following metadata relates to complex support and is checked in test_ops.py 885 886 test_conjugated_samples: bool = True 887 888 test_neg_view: bool = True 889 890 # assert that jit shape analysis fully propagates shape 891 assert_jit_shape_analysis: bool = False 892 893 # the following metadata relates to ExpandedWeights support and is checked in test_expanded_weights.py 894 895 supports_expanded_weight: bool = False 896 897 is_factory_function: bool = False 898 899 def __post_init__(self): 900 self._original_opinfo_args = asdict(self).copy() 901 902 assert self.dtypes is not None, f"OpInfo for {self.name} has no dtypes!" 903 904 dtypes_args = ( 905 self.dtypes, 906 self.dtypesIfCUDA, 907 self.dtypesIfROCM, 908 self.dtypesIfXPU, 909 ) 910 911 # Validates the dtypes are generated from the dispatch-related functions 912 for dtype_list in dtypes_args: 913 assert isinstance(dtype_list, (_dispatch_dtypes, type(None))) 914 915 if self.aten_name is None: 916 self.aten_name = self.name 917 918 # Attribute to verify dynamic_dtypes are used. 919 self.dynamic_dtypes = any( 920 isinstance(dtypes, utils._dynamic_dispatch_dtypes) for dtypes in dtypes_args 921 ) 922 923 if self.dynamic_dtypes: 924 # Make sure `dtyesIfCUDA` is dynamic, if dynamic dispatch is used for CPU 925 # This is because, below we set dtypesIfCUDA to dtypes if they are None. 926 assert isinstance(self.dtypesIfCUDA, utils._dynamic_dispatch_dtypes), ( 927 f"To use dynamic dypes for operator {self.name}, " 928 "acquire the dtypes dynamically for argument `dtypesIfCUDA`." 929 "This is to ensure that CUDA dtypes are acquired correctly as they" 930 "differ from CPU dtypes occasionally" 931 ) 932 933 self.dtypes = set(self.dtypes) 934 935 # NOTE: backward dtypes must be acquired before forward dtypes 936 # since they fallback to explicit (not implicit!) specifications of 937 # forward dtypes 938 self.backward_dtypesIfROCM = ( 939 set(self.backward_dtypesIfROCM) 940 if self.backward_dtypesIfROCM is not None 941 else ( 942 self.backward_dtypesIfCUDA 943 if self.backward_dtypesIfCUDA is not None 944 else self.backward_dtypes 945 if self.backward_dtypes is not None 946 else self.dtypesIfROCM 947 if self.dtypesIfROCM is not None 948 else self.dtypesIfCUDA 949 if self.dtypesIfCUDA is not None 950 else self.dtypes 951 ) 952 ) 953 self.backward_dtypesIfCUDA = ( 954 set(self.backward_dtypesIfCUDA) 955 if self.backward_dtypesIfCUDA is not None 956 else ( 957 self.backward_dtypes 958 if self.backward_dtypes is not None 959 else self.dtypesIfCUDA 960 if self.dtypesIfCUDA is not None 961 else self.dtypes 962 ) 963 ) 964 self.backward_dtypesIfHpu = ( 965 set(self.backward_dtypesIfHpu) 966 if self.backward_dtypesIfHpu is not None 967 else ( 968 self.backward_dtypes 969 if self.backward_dtypes is not None 970 else self.dtypes 971 ) 972 ) 973 974 self.backward_dtypes = ( 975 set(self.backward_dtypes) 976 if self.backward_dtypes is not None 977 else self.dtypes 978 ) 979 980 self.dtypesIfCUDA = ( 981 set(self.dtypesIfCUDA) if self.dtypesIfCUDA is not None else self.dtypes 982 ) 983 self.dtypesIfROCM = ( 984 set(self.dtypesIfROCM) 985 if self.dtypesIfROCM is not None 986 else self.dtypesIfCUDA 987 ) 988 self.dtypesIfXPU = ( 989 set(self.dtypesIfXPU) if self.dtypesIfXPU is not None else self.dtypesIfCUDA 990 ) 991 992 self.dtypesIfHpu = ( 993 set(self.dtypesIfHpu) if self.dtypesIfHpu is not None else self.dtypes 994 ) 995 996 # NOTE: if the op is unspecified it is assumed to be under the torch namespace 997 if not self.op: 998 self.op = _getattr_qual(torch, self.name) 999 1000 if self.method_variant is _NOTHING: 1001 self.method_variant = getattr(torch.Tensor, self.name, None) 1002 1003 # attributes like real, imag are not callable 1004 if not callable(self.method_variant): 1005 self.method_variant = None 1006 1007 if self.inplace_variant is _NOTHING: 1008 inplace_name = self.name + "_" 1009 self.inplace_variant = getattr(torch.Tensor, inplace_name, None) 1010 1011 if self.operator_variant is _NOTHING: 1012 self.operator_variant = getattr(operator, self.name, None) 1013 1014 if self.inplace_operator_variant is _NOTHING: 1015 # Note: operator.i<op> will use operator.<op> and assign the result to the lhs when no 1016 # __i<op>__ method is found. This results in the appearance of an inplace operator variant which 1017 # does not have the correct inplace behavior. To avoid this, we guard automatic detection of the inplace 1018 # operator with a check that an inplace variant exists. 1019 if self.inplace_variant is not None: 1020 inplace_operator_name = "i" + self.name 1021 self.inplace_operator_variant = getattr( 1022 operator, inplace_operator_name, None 1023 ) 1024 else: 1025 self.inplace_operator_variant = None 1026 1027 self.decorators = (*self.decorators, *self.skips) 1028 1029 # Specifying sample inputs function without specifying the 1030 # corresponding layout support implies the layout support: 1031 if self.supports_sparse is None: 1032 self.supports_sparse = self.sample_inputs_sparse_coo_func is not None 1033 if self.sample_inputs_sparse_coo_func is None: 1034 self.sample_inputs_sparse_coo_func = self._sample_inputs_unspecified 1035 1036 if self.supports_sparse_csr is None: 1037 self.supports_sparse_csr = self.sample_inputs_sparse_csr_func is not None 1038 if self.sample_inputs_sparse_csr_func is None: 1039 self.sample_inputs_sparse_csr_func = self._sample_inputs_unspecified 1040 1041 if self.supports_sparse_csc is None: 1042 self.supports_sparse_csc = self.sample_inputs_sparse_csc_func is not None 1043 if self.sample_inputs_sparse_csc_func is None: 1044 self.sample_inputs_sparse_csc_func = self._sample_inputs_unspecified 1045 1046 if self.supports_sparse_bsr is None: 1047 self.supports_sparse_bsr = self.sample_inputs_sparse_bsr_func is not None 1048 if self.sample_inputs_sparse_bsr_func is None: 1049 self.sample_inputs_sparse_bsr_func = self._sample_inputs_unspecified 1050 1051 if self.supports_sparse_bsc is None: 1052 self.supports_sparse_bsc = self.sample_inputs_sparse_bsc_func is not None 1053 if self.sample_inputs_sparse_bsc_func is None: 1054 self.sample_inputs_sparse_bsc_func = self._sample_inputs_unspecified 1055 1056 if self.supports_njt is None: 1057 self.supports_njt = False 1058 1059 # We run the sampling functions without tracking the gradiends of the creation of inputs 1060 self.sample_inputs_func = torch.no_grad()(self.sample_inputs_func) 1061 self.sample_inputs_sparse_coo_func = torch.no_grad()( 1062 self.sample_inputs_sparse_coo_func 1063 ) 1064 self.sample_inputs_sparse_csr_func = torch.no_grad()( 1065 self.sample_inputs_sparse_csr_func 1066 ) 1067 self.sample_inputs_sparse_csc_func = torch.no_grad()( 1068 self.sample_inputs_sparse_csc_func 1069 ) 1070 self.sample_inputs_sparse_bsr_func = torch.no_grad()( 1071 self.sample_inputs_sparse_bsr_func 1072 ) 1073 self.sample_inputs_sparse_bsc_func = torch.no_grad()( 1074 self.sample_inputs_sparse_bsc_func 1075 ) 1076 if self.reference_inputs_func is not None: 1077 self.reference_inputs_func = torch.no_grad()(self.reference_inputs_func) 1078 1079 if not self.autodiff_fusible_nodes: 1080 self.autodiff_fusible_nodes = [] 1081 1082 if self.autodiff_nonfusible_nodes is None: 1083 self.autodiff_nonfusible_nodes = ["aten::" + self.name] 1084 1085 # Autograd support 1086 1087 # Autograd flags that depend on backward AD only 1088 # - If setting has been explicitly set, raise error if inconsistent 1089 if self.supports_gradgrad is None: 1090 self.supports_gradgrad = self.supports_autograd 1091 else: 1092 assert not (self.supports_gradgrad and not self.supports_autograd), ( 1093 "supports_gradgrad refines the part of autograd is supported, so it should " 1094 "not be set if supports_autograd is False" 1095 ) 1096 if self.check_batched_grad is None: 1097 self.check_batched_grad = self.supports_autograd or self.supports_forward_ad 1098 else: 1099 assert not ( 1100 self.check_batched_grad 1101 and not (self.supports_autograd or self.supports_forward_ad) 1102 ), ( 1103 "check_batched_grad refines the part of autograd that will be checked (by gradcheck), so " 1104 "it should not be set if supports_autograd is False" 1105 ) 1106 if self.check_batched_gradgrad is None: 1107 self.check_batched_gradgrad = self.supports_gradgrad 1108 else: 1109 assert not (self.check_batched_gradgrad and not self.supports_gradgrad), ( 1110 "check_batched_gradgrad refines the part of autograd that will be checked (by " 1111 "gradgradcheck), so it should not be set if either supports_gradgrad or supports_autograd " 1112 "is False." 1113 ) 1114 if self.check_batched_forward_grad is None: 1115 self.check_batched_forward_grad = self.supports_forward_ad 1116 else: 1117 assert not ( 1118 self.check_batched_forward_grad and not self.supports_forward_ad 1119 ), ( 1120 "check_batched_forward_grad should only be used when supports_forward_ad " 1121 "is True. It is used to disable the test in the specific cases " 1122 "where the op supports forward ad but fails to compute " 1123 "batched forward grad." 1124 ) 1125 1126 if self.check_inplace_batched_forward_grad is None: 1127 self.check_inplace_batched_forward_grad = self.check_batched_forward_grad 1128 else: 1129 assert not ( 1130 self.check_inplace_batched_forward_grad 1131 and not self.check_batched_forward_grad 1132 ), ( 1133 "check_batched_forward_grad should only be used when check_batched_forward_grad " 1134 "is True. It is used to disable the test in the specific cases " 1135 "where the op supports batched forward grad but fails to compute batched forward " 1136 "grad for the inplace variant of the op." 1137 ) 1138 1139 assert not (self.supports_fwgrad_bwgrad and not self.supports_autograd), ( 1140 "supports_fwgrad_bwgrad enables forward-over-backward gradgrad checks and should only be " 1141 "True if backward ad is also checked, i.e., supports_forward_ad should be True.", 1142 self.name, 1143 ) 1144 1145 # Autograd flags that depend on both forward AD and backward AD 1146 if self.supports_inplace_autograd is None: 1147 self.supports_inplace_autograd = ( 1148 self.supports_autograd or self.supports_forward_ad 1149 ) 1150 else: 1151 assert not ( 1152 self.supports_inplace_autograd 1153 and not self.supports_autograd 1154 and not self.supports_forward_ad 1155 ), ( 1156 "supports_inplace_autograd refines the part of autograd that is supported, so " 1157 "it should not be set if both supports_autograd and supports_forward_ad are False" 1158 ) 1159 1160 if self.aliases is not None: 1161 self.aliases = tuple(AliasInfo(a) for a in self.aliases) # type: ignore[assignment] 1162 else: 1163 self.aliases = () 1164 1165 def __call__(self, *args, **kwargs): 1166 """Calls the function variant of the operator.""" 1167 return self.op(*args, **kwargs) 1168 1169 def __str__(self): 1170 return dataclass_repr(self) 1171 1172 def get_op(self): 1173 """Returns the function variant of the operator, torch.<op_name>.""" 1174 return self.op 1175 1176 def get_method(self): 1177 """Returns the method variant of the operator, torch.Tensor.<op_name>. 1178 Returns None if the operator has no method variant. 1179 """ 1180 return self.method_variant 1181 1182 def get_inplace(self): 1183 """Returns the inplace variant of the operator, torch.Tensor.<op_name>_. 1184 Returns None if the operator has no inplace variant. 1185 """ 1186 return self.inplace_variant 1187 1188 def get_operator(self): 1189 """Returns operator variant of the operator, e.g. operator.neg 1190 Returns None if the operator has no operator variant. 1191 """ 1192 return self.operator_variant 1193 1194 def get_inplace_operator(self): 1195 """Returns the inplace operator variant of the operator, e.g operator.iadd 1196 Returns None if the operator has no inplace operator variant""" 1197 return self.inplace_operator_variant 1198 1199 def conjugate_sample_inputs(self, device, dtype, requires_grad=False, **kwargs): 1200 """Returns an iterable of SampleInputs but with the tensor input or first 1201 tensor in a sequence input conjugated. 1202 """ 1203 1204 set_seed = kwargs.pop("set_seed", True) 1205 samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs) 1206 conj_samples = list(samples) 1207 1208 def conjugate(tensor): 1209 _requires_grad = tensor.requires_grad 1210 tensor = tensor.conj() 1211 return tensor.requires_grad_(_requires_grad) 1212 1213 for i, sample in enumerate(samples): 1214 sample = conj_samples[i] 1215 # Note: it is assumed that the input here is either a tensor or tensorlist 1216 if isinstance(sample.input, torch.Tensor): 1217 sample.input = conjugate(sample.input) 1218 else: 1219 sample.input[0] = conjugate(sample.input[0]) 1220 1221 return TrackedInputIter( 1222 iter(conj_samples), 1223 "conjugate sample input", 1224 set_seed=set_seed, 1225 restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, 1226 ) 1227 1228 def sample_inputs(self, device, dtype, requires_grad=False, **kwargs): 1229 """ 1230 Returns an iterable of SampleInputs. 1231 1232 These samples should be sufficient to test the function works correctly 1233 with autograd, TorchScript, etc. 1234 """ 1235 set_seed = kwargs.pop("set_seed", True) 1236 samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs) 1237 1238 if kwargs.get("include_conjugated_inputs", False): 1239 conj_samples = self.conjugate_sample_inputs( 1240 device, dtype, requires_grad, **kwargs 1241 ) 1242 samples_list = list(samples) 1243 samples_list.extend(conj_samples) 1244 samples = tuple(samples_list) 1245 1246 return TrackedInputIter( 1247 iter(samples), 1248 "sample input", 1249 set_seed=set_seed, 1250 restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, 1251 ) 1252 1253 def reference_inputs(self, device, dtype, requires_grad=False, **kwargs): 1254 """ 1255 Returns an iterable of SampleInputs. 1256 1257 Distinct from sample_inputs() above because this returns an expanded set 1258 of inputs when reference_inputs_func is defined. If undefined this returns 1259 the sample inputs. 1260 """ 1261 set_seed = kwargs.pop("set_seed", True) 1262 if self.reference_inputs_func is None: 1263 samples = self.sample_inputs_func( 1264 self, device, dtype, requires_grad, **kwargs 1265 ) 1266 return TrackedInputIter( 1267 iter(samples), 1268 "reference input", 1269 set_seed=set_seed, 1270 restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, 1271 ) 1272 1273 if kwargs.get("include_conjugated_inputs", False): 1274 raise NotImplementedError 1275 1276 references = self.reference_inputs_func( 1277 self, device, dtype, requires_grad, **kwargs 1278 ) 1279 return TrackedInputIter( 1280 iter(references), 1281 "reference input", 1282 set_seed=set_seed, 1283 restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, 1284 ) 1285 1286 def error_inputs(self, device, **kwargs): 1287 """ 1288 Returns an iterable of ErrorInputs. 1289 """ 1290 set_seed = kwargs.pop("set_seed", True) 1291 errs = self.error_inputs_func(self, device, **kwargs) 1292 return TrackedInputIter( 1293 iter(errs), 1294 "error input", 1295 callback=lambda e: e.sample_input, 1296 set_seed=set_seed, 1297 restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX, 1298 ) 1299 1300 def error_inputs_sparse(self, device, layout, **kwargs): 1301 """ 1302 Returns an iterable of ErrorInputs that contain sparse sample 1303 inputs with a specified layout. 1304 """ 1305 if not self.supports_sparse_layout(layout): 1306 raise unittest.SkipTest("unsupported sparse layout") 1307 return self.error_inputs_sparse_func(self, device, layout, **kwargs) 1308 1309 def supports_sparse_layout(self, layout): 1310 """Return True if OpInfo supports the specified sparse layout.""" 1311 layout_name = str(layout).split(".")[-1] 1312 # map torch.sparse_coo to OpInfo.supports_sparse: 1313 layout_name = layout_name.replace("_coo", "") 1314 return getattr(self, f"supports_{layout_name}") 1315 1316 def sample_inputs_sparse( 1317 self, layout, device, dtype, requires_grad=False, **kwargs 1318 ): 1319 """Returns an iterable of SampleInputs that contain inputs with a 1320 specified sparse layout. 1321 """ 1322 layout_name = str(layout).split(".")[-1] 1323 sample_inputs_mth = getattr(self, "sample_inputs_" + layout_name) 1324 1325 def non_empty_sampler(op, generator): 1326 found_sample = False 1327 for sample in generator: 1328 found_sample = True 1329 yield sample 1330 if not found_sample: 1331 raise unittest.SkipTest("NO SAMPLES!") 1332 1333 return non_empty_sampler( 1334 self, 1335 sample_inputs_mth(device, dtype, requires_grad=requires_grad, **kwargs), 1336 ) 1337 1338 def _sample_inputs_unspecified(self, *args, **kwargs): 1339 """Raises an NotImplemented exception in a OpInfo instance creation 1340 that specifies supports_sparse(|_csr|_csc|_bsr|_bsc)=True 1341 without specifying the corresponding sample function as 1342 sample_inputs_sparse_(coo|csr|csc|bsr|bsc)_func. 1343 1344 To avoid this, either define the corresponding sample function, 1345 or re-map unsupported samples to error inputs in an appropiate 1346 1347 opinfo/definitions/sparse.py:_validate_sample_input_sparse_<op> 1348 1349 function. 1350 """ 1351 raise NotImplementedError("no sample function specified") 1352 1353 def sample_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs): 1354 """Returns an iterable of SampleInputs that contain inputs with sparse 1355 coo layout. 1356 """ 1357 return self.sample_inputs_sparse_coo_func( 1358 self, device, dtype, requires_grad, **kwargs 1359 ) 1360 1361 def sample_inputs_sparse_csr(self, device, dtype, requires_grad=False, **kwargs): 1362 """Returns an iterable of SampleInputs that contain inputs with sparse 1363 csr layout. 1364 """ 1365 return self.sample_inputs_sparse_csr_func( 1366 self, device, dtype, requires_grad, **kwargs 1367 ) 1368 1369 def sample_inputs_sparse_csc(self, device, dtype, requires_grad=False, **kwargs): 1370 """Returns an iterable of SampleInputs that contain inputs with sparse 1371 csc layout. 1372 """ 1373 return self.sample_inputs_sparse_csc_func( 1374 self, device, dtype, requires_grad, **kwargs 1375 ) 1376 1377 def sample_inputs_sparse_bsr(self, device, dtype, requires_grad=False, **kwargs): 1378 """Returns an iterable of SampleInputs that contain inputs with sparse 1379 bsr layout. 1380 """ 1381 return self.sample_inputs_sparse_bsr_func( 1382 self, device, dtype, requires_grad, **kwargs 1383 ) 1384 1385 def sample_inputs_sparse_bsc(self, device, dtype, requires_grad=False, **kwargs): 1386 """Returns an iterable of SampleInputs that contain inputs with sparse 1387 bsc layout. 1388 """ 1389 return self.sample_inputs_sparse_bsc_func( 1390 self, device, dtype, requires_grad, **kwargs 1391 ) 1392 1393 def get_decorators(self, test_class, test_name, device, dtype, param_kwargs): 1394 """Returns the decorators targeting the given test.""" 1395 result = [] 1396 for decorator in self.decorators: 1397 if isinstance(decorator, DecorateInfo): 1398 if decorator.is_active( 1399 test_class, test_name, device, dtype, param_kwargs 1400 ): 1401 result.extend(decorator.decorators) 1402 else: 1403 result.append(decorator) 1404 return result 1405 1406 def supported_dtypes(self, device_type): 1407 if device_type == "privateuse1": 1408 device_type = torch._C._get_privateuse1_backend_name() 1409 device_type = torch.device(device_type).type 1410 if device_type == "cuda": 1411 return self.dtypesIfROCM if TEST_WITH_ROCM else self.dtypesIfCUDA 1412 if device_type == "xpu": 1413 return self.dtypesIfXPU 1414 if device_type == "hpu": 1415 return self.dtypesIfHpu 1416 return self.dtypes 1417 1418 def supported_backward_dtypes(self, device_type): 1419 if not self.supports_autograd: 1420 return set() 1421 1422 if device_type == "privateuse1": 1423 device_type = torch._C._get_privateuse1_backend_name() 1424 device_type = torch.device(device_type).type 1425 backward_dtypes = None 1426 if device_type == "cuda": 1427 backward_dtypes = ( 1428 self.backward_dtypesIfROCM 1429 if TEST_WITH_ROCM 1430 else self.backward_dtypesIfCUDA 1431 ) 1432 elif device_type == "hpu": 1433 backward_dtype = self.backward_dtypesIfHpu 1434 else: 1435 backward_dtypes = self.backward_dtypes 1436 1437 allowed_backward_dtypes = floating_and_complex_types_and( 1438 torch.bfloat16, torch.float16, torch.complex32 1439 ) 1440 return set(allowed_backward_dtypes).intersection(backward_dtypes) 1441 1442 def supports_dtype(self, dtype, device_type) -> bool: 1443 return dtype in self.supported_dtypes(device_type) 1444 1445 @property 1446 def full_name(self): 1447 """Returns a full name that helps to uniquely identify this OpInfo.""" 1448 variant = "." + self.variant_test_name if self.variant_test_name else "" 1449 # example: "normal.in_place" where "normal" is the name and "in_place" is the variant 1450 return f"{self.name}{variant}" 1451 1452 @property 1453 def formatted_name(self): 1454 """Returns a formatted full name for this OpInfo that can be used in test names.""" 1455 return self.full_name.replace(".", "_") 1456 1457 1458def _generate_reduction_inputs(device, dtype, requires_grad, **kwargs): 1459 """Generates input tensors for testing reduction operators""" 1460 yield make_tensor([], dtype=dtype, device=device, requires_grad=requires_grad) 1461 yield make_tensor([2], dtype=dtype, device=device, requires_grad=requires_grad) 1462 yield make_tensor([3, 5], dtype=dtype, device=device, requires_grad=requires_grad) 1463 yield make_tensor( 1464 [3, 2, 1, 2], dtype=dtype, device=device, requires_grad=requires_grad 1465 ) 1466 1467 1468def _generate_reduction_kwargs(ndim, supports_multiple_dims=True): 1469 """Generates a subset of all valid dim and keepdim kwargs given ndim that 1470 is appropriate for testing reduction operators. 1471 """ 1472 1473 # Test default dim and keepdim 1474 yield {} 1475 1476 # Test reducing inner and outer most dimensions 1477 yield {"dim": 0, "keepdim": True} 1478 yield {"dim": -1, "keepdim": False} 1479 1480 # Test reducing middle dimension 1481 if ndim > 2: 1482 yield {"dim": ndim // 2, "keepdim": True} 1483 1484 if supports_multiple_dims: 1485 # Test reducing all dimensions 1486 yield {"dim": tuple(range(ndim)), "keepdim": False} 1487 1488 # Test reducing both first and last dimensions 1489 if ndim > 1: 1490 yield {"dim": (0, -1), "keepdim": True} 1491 1492 # Test reducing every other dimension starting with the second 1493 if ndim > 3: 1494 yield {"dim": tuple(range(1, ndim, 2)), "keepdim": False} 1495 1496 1497def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs): 1498 """Sample inputs for reduction operators.""" 1499 1500 # TODO(@heitorschueroff) Once all reduction operators are using 1501 # ReductionOpInfo use op_info.supports_multiple_dims directly. 1502 supports_multiple_dims: bool = kwargs.get("supports_multiple_dims", True) 1503 1504 # TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo 1505 # use op_info.generate_args_kwargs directly. 1506 generate_args_kwargs = kwargs.get( 1507 "generate_args_kwargs", lambda *args, **kwargs: (yield (), {}) 1508 ) 1509 1510 for t in _generate_reduction_inputs(device, dtype, requires_grad): 1511 for reduction_kwargs in _generate_reduction_kwargs( 1512 t.ndim, supports_multiple_dims 1513 ): 1514 for args, kwargs in generate_args_kwargs(t, **reduction_kwargs): 1515 kwargs.update(reduction_kwargs) 1516 yield SampleInput( 1517 t.detach().requires_grad_(requires_grad), args=args, kwargs=kwargs 1518 ) 1519 1520 1521# NOTE [Reductions]: 1522# 1523# For testing purposes, we relax the definition of a reduction operator 1524# as defined in the docstring below. We do this to capture operators with 1525# a similar API so they can be tested automatically. However... 1526# 1527# Strictly speaking a reduction operator is an operator that can reduce an 1528# array to a single scalar value and that can be computed from the partial 1529# result of reducing subarrays. This usually means that the reduction operation 1530# should be commutative and associative. This definition is important when it 1531# comes to implementation as it determines how a reduction can be parallelized. 1532# 1533# For example, many summary statistics such as median, mode and quantile cannot 1534# be computed from partial results because these are sorting and counting based 1535# algorithms that need information that would be lost in the reduced value. 1536class ReductionOpInfo(OpInfo): 1537 """Reduction operator information. 1538 1539 An operator is a reduction operator if it reduces one or more dimensions of 1540 the input tensor to a single value. Reduction operators must implement the 1541 following signature: 1542 1543 - `op(input, *args, *, dim=None, keepdim=False, **kwargs) -> Tensor` 1544 1545 ReductionOpInfo tests that reduction operators implement a consistent API. 1546 Optional features such as reducing over multiple dimensions are captured in 1547 the optional keyword parameters of the ReductionOpInfo constructor. 1548 1549 If a reduction operator does not yet implement the full required API of 1550 reduction operators, this should be documented by xfailing the failing 1551 tests rather than adding optional parameters to ReductionOpInfo. 1552 1553 NOTE 1554 The API for reduction operators has not yet been finalized and some 1555 requirements may change. 1556 1557 See tests in test/test_reductions.py 1558 """ 1559 1560 def __init__( 1561 self, 1562 name, 1563 *, 1564 # The identity value for the operator if it has one. 1565 identity: Optional[Any] = None, 1566 # The nan policy for the operator if it implements one. 1567 # - propagate: NaN values are propagated to the output 1568 # - omit: NaN values are discarded during the reduction 1569 nan_policy: Optional[str] = None, 1570 # Whether the operator supports reducing multiple dimensions. 1571 supports_multiple_dims: bool = True, 1572 # Whether the operator promotes integral to floating point dtypes. 1573 promotes_int_to_float: bool = False, 1574 # Whether the operator promotes all integral dtypes to int64. 1575 promotes_int_to_int64: bool = False, 1576 # If a specific dtype is given, then the operator always returns that 1577 # dtype irrespective of the input dtype. If None, the operator returns 1578 # the dtype according to the type promotion rules above. 1579 result_dtype: Optional[torch.dtype] = None, 1580 # Casts complex results to real (e.g. linalg.norm or torch.var) 1581 complex_to_real: bool = False, 1582 # ReductionOpInfo tests generate their own input, dim and keepdim 1583 # arguments and call this function to generate tuples of extra args and 1584 # kwargs to use when calling the op. This is required for operators that 1585 # have other required parameters besides the input tensor. 1586 generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: ( 1587 yield (), 1588 {}, 1589 ), 1590 # Options from the OpInfo base class 1591 **kwargs, 1592 ): 1593 self._original_reduction_args = locals().copy() 1594 assert nan_policy in (None, "propagate", "omit") 1595 1596 # These are mutually exclusive options 1597 assert not (result_dtype and promotes_int_to_float) 1598 assert not (result_dtype and promotes_int_to_int64) 1599 assert not (result_dtype and complex_to_real) 1600 assert not (promotes_int_to_float and promotes_int_to_int64) 1601 1602 # Default sample_inputs_func for ReductionOpInfo which augments sample 1603 # inputs from sample_inputs_reduction with the args and kwargs from 1604 # generate_args_kwargs. This is only used if sample_inputs_func is None. 1605 def sample_inputs_func(*args, **kwargs): 1606 kwargs["supports_multiple_dims"] = supports_multiple_dims 1607 kwargs["generate_args_kwargs"] = generate_args_kwargs 1608 yield from sample_inputs_reduction(*args, **kwargs) 1609 1610 # Override OpInfo defaults and call base class __init__ 1611 kwargs.setdefault("inplace_variant", None) 1612 kwargs.setdefault("sample_inputs_func", sample_inputs_func) 1613 super().__init__(name, promotes_int_to_float=promotes_int_to_float, **kwargs) 1614 1615 self.identity = identity 1616 self.nan_policy = nan_policy 1617 self.supports_multiple_dims = supports_multiple_dims 1618 self.promotes_int_to_int64 = promotes_int_to_int64 1619 self.complex_to_real = complex_to_real 1620 self.result_dtype = result_dtype 1621 self.generate_args_kwargs = generate_args_kwargs 1622 1623 1624# The base reference input generation for elementwise binary operations 1625def _reference_inputs_elementwise_binary( 1626 op, device, dtype, requires_grad, exclude_zero, **kwargs 1627): 1628 yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs) 1629 yield from generate_elementwise_binary_tensors( 1630 op, 1631 device=device, 1632 dtype=dtype, 1633 requires_grad=requires_grad, 1634 exclude_zero=exclude_zero, 1635 ) 1636 if dtype is not torch.bool: 1637 yield from generate_elementwise_binary_small_value_tensors( 1638 op, device=device, dtype=dtype, requires_grad=requires_grad 1639 ) 1640 if dtype not in (torch.bool, torch.uint8, torch.int8): 1641 yield from generate_elementwise_binary_large_value_tensors( 1642 op, device=device, dtype=dtype, requires_grad=requires_grad 1643 ) 1644 yield from generate_elementwise_binary_broadcasting_tensors( 1645 op, 1646 device=device, 1647 dtype=dtype, 1648 requires_grad=requires_grad, 1649 exclude_zero=exclude_zero, 1650 ) 1651 yield from generate_elementwise_binary_with_scalar_samples( 1652 op, device=device, dtype=dtype, requires_grad=requires_grad 1653 ) 1654 1655 yield from generate_elementwise_binary_with_scalar_and_type_promotion_samples( 1656 op, device=device, dtype=dtype, requires_grad=requires_grad 1657 ) 1658 1659 if dtype.is_floating_point or dtype.is_complex: 1660 yield from generate_elementwise_binary_extremal_value_tensors( 1661 op, device=device, dtype=dtype, requires_grad=requires_grad 1662 ) 1663 1664 1665# Note that these references inputs use scalars for the SampleInput.input value, 1666# and many tests require SampleInput.input be a tensor or a list of tensors 1667def reference_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): 1668 if hasattr(op, "rhs_make_tensor_kwargs"): 1669 exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) 1670 1671 gen = partial( 1672 _reference_inputs_elementwise_binary, 1673 op, 1674 device, 1675 dtype, 1676 requires_grad, 1677 exclude_zero, 1678 **kwargs, 1679 ) 1680 1681 # yields "normal" samples 1682 yield from gen() 1683 1684 # yields noncontiguous samples 1685 for sample in gen(): 1686 yield sample.noncontiguous() 1687 1688 yield from generate_elementwise_binary_noncontiguous_tensors( 1689 op, 1690 device=device, 1691 dtype=dtype, 1692 requires_grad=requires_grad, 1693 exclude_zero=exclude_zero, 1694 ) 1695 1696 yield from generate_elementwise_binary_arbitrarily_strided_tensors( 1697 op, 1698 device=device, 1699 dtype=dtype, 1700 requires_grad=requires_grad, 1701 exclude_zero=exclude_zero, 1702 ) 1703 1704 1705# A functional that extends an elementwise binary operator's bespoke error inputs 1706# with generic error inputs for the class of elementwise binary operations 1707def make_error_inputs_elementwise_binary(error_inputs_func): 1708 def error_inputs_func_wrapper(op, device, **kwargs): 1709 if error_inputs_func is not None: 1710 yield from error_inputs_func(op, device, **kwargs) 1711 1712 if not op.supports_rhs_python_scalar: 1713 si = SampleInput(torch.tensor((1, 2, 3), device=device), args=(2,)) 1714 yield ErrorInput(si, error_type=Exception, error_regex="") 1715 1716 if not op.supports_one_python_scalar: 1717 si = SampleInput(2, args=(torch.tensor((1, 2, 3), device=device),)) 1718 yield ErrorInput(si, error_type=Exception, error_regex="") 1719 1720 if ( 1721 not kwargs.get("skip_two_python_scalars", False) 1722 and not op.supports_two_python_scalars 1723 ): 1724 si = SampleInput(2, args=(3,)) 1725 yield ErrorInput(si, error_type=Exception, error_regex="") 1726 1727 return error_inputs_func_wrapper 1728 1729 1730# The following functions and classes are for testing elementwise binary operators. 1731 1732 1733# Returns a generator of pairs of contiguous tensors on the requested device 1734# and with the requested dtype. 1735# 1736# This function is intended to test the non-vectorized and vectorized code 1737# paths of elementwise binary functions, as well as their handling of odd tensor 1738# sizes (like zero-dim tensors and tensors with zero elements). 1739# 1740# Each iterable will include an a tensor with no elements, 1741# zero dim (scalar) tensors, small 1D tensors, a medium 1D tensor, and 1742# a large 2D tensor. 1743def generate_elementwise_binary_tensors( 1744 op, *, device, dtype, requires_grad=False, exclude_zero=False 1745): 1746 shapes = ( 1747 # tensors with no elements 1748 (0,), 1749 (1, 0, 3), 1750 # zero dim (scalar) tensor 1751 (), 1752 # small 1D tensor 1753 (20,), 1754 # medium 1D tensor 1755 (812,), 1756 # large 2D tensor 1757 (1029, 917), 1758 ) 1759 1760 make_arg = partial( 1761 make_tensor, 1762 device=device, 1763 dtype=dtype, 1764 requires_grad=requires_grad, 1765 exclude_zero=exclude_zero, 1766 ) 1767 for shape in shapes: 1768 lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) 1769 rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) 1770 yield SampleInput(lhs, args=(rhs,)) 1771 1772 1773def generate_elementwise_binary_arbitrarily_strided_tensors( 1774 op, *, device, dtype, requires_grad=False, exclude_zero=False 1775): 1776 # shape, strides, offset 1777 strided_cases = ( 1778 ((5, 6, 2), (1, 1, 7), 2), 1779 ((5, 5, 4), (1, 1, 7), 2), 1780 ((5, 5, 2), (4, 5, 7), 3), 1781 ((5, 5, 2), (5, 5, 7), 3), 1782 ((5, 5, 2), (5, 5, 5), 3), 1783 ((9, 5, 2), (0, 1, 7), 3), 1784 ) 1785 1786 make_arg = partial( 1787 make_tensor, 1788 device=device, 1789 dtype=dtype, 1790 requires_grad=requires_grad, 1791 exclude_zero=exclude_zero, 1792 ) 1793 for shape, strides, offset in strided_cases: 1794 a = make_arg( 1795 500, 1796 ).as_strided(shape, strides, offset) 1797 b = make_arg(shape) 1798 yield SampleInput(a, args=(b,)) 1799 1800 1801# Returns a generator of pairs of contiguous tensors on the requested device and with 1802# the requested dtype. 1803# 1804# Unlike the previous function, the values in these tensors are specified manually. 1805def generate_elementwise_binary_small_value_tensors( 1806 op, *, device, dtype, requires_grad=False, exclude_zero=None 1807): 1808 if exclude_zero is None: 1809 if hasattr(op, "rhs_make_tensor_kwargs"): 1810 exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) 1811 1812 # defines interesting values 1813 _unsigned_int_vals = (0, 1, 55, 127, 128, 190, 210, 220, 254) 1814 _int_vals = (0, -1, 1, -55, 55, -127, 127, -128) 1815 _float_vals = ( 1816 0.0, 1817 -0.0, 1818 -0.001, 1819 0.001, 1820 -0.25, 1821 0.25, 1822 -1.0, 1823 1.0, 1824 -math.pi / 2, 1825 math.pi / 2, 1826 -math.pi + 0.00001, 1827 math.pi - 0.00001, 1828 -math.pi, 1829 math.pi, 1830 -math.pi - 0.00001, 1831 math.pi + 0.00001, 1832 ) 1833 1834 l_vals = [] 1835 r_vals = [] 1836 1837 if dtype.is_floating_point: 1838 prod = product(_float_vals, _float_vals) 1839 elif dtype.is_complex: 1840 complex_vals = product(_float_vals, _float_vals) 1841 # Note the use of list is required here or the map generator will be 1842 # emptied by the following product and it won't produce the desired cross-product 1843 complex_vals = [complex(*x) for x in complex_vals] 1844 prod = product(complex_vals, complex_vals) 1845 elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64): 1846 prod = product(_int_vals, _int_vals) 1847 elif dtype is torch.uint8: 1848 prod = product(_unsigned_int_vals, _unsigned_int_vals) 1849 else: 1850 raise ValueError("Unsupported dtype!") 1851 1852 for l, r in prod: 1853 l_vals.append(l) 1854 if r == 0 and exclude_zero: 1855 r_vals.append(1) 1856 else: 1857 r_vals.append(r) 1858 1859 lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) 1860 rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) 1861 1862 yield SampleInput(lhs, args=(rhs,)) 1863 1864 1865def generate_elementwise_binary_large_value_tensors( 1866 op, *, device, dtype, requires_grad=False 1867): 1868 _large_int_vals = (-1113, 1113, -10701, 10701) 1869 _large_float16_vals = (-501, 501, -1001.2, 1001.2, -13437.7, 13437.7) 1870 _large_float_vals = _large_float16_vals + (-4988429.2, 4988429.2, -1e20, 1e20) 1871 1872 l_vals = [] 1873 r_vals = [] 1874 1875 if dtype == torch.float16: 1876 prod = product(_large_float16_vals, _large_float16_vals) 1877 elif dtype.is_floating_point: 1878 prod = product(_large_float_vals, _large_float_vals) 1879 elif dtype.is_complex: 1880 complex_vals = product(_large_float_vals, _large_float_vals) 1881 # Note the use of list is required here or the map generator will be 1882 # emptied by the following product and it won't produce the desired cross-product 1883 complex_vals = [complex(*x) for x in complex_vals] 1884 prod = product(complex_vals, complex_vals) 1885 elif dtype in (torch.int16, torch.int32, torch.int64): 1886 prod = product(_large_int_vals, _large_int_vals) 1887 else: 1888 raise ValueError("Unsupported dtype!") 1889 1890 for l, r in prod: 1891 l_vals.append(l) 1892 r_vals.append(r) 1893 1894 lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) 1895 rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) 1896 1897 yield SampleInput(lhs, args=(rhs,)) 1898 1899 1900def generate_elementwise_binary_extremal_value_tensors( 1901 op, *, device, dtype, requires_grad=False 1902): 1903 _float_extremals = (float("inf"), float("-inf"), float("nan")) 1904 1905 l_vals = [] 1906 r_vals = [] 1907 1908 if dtype.is_floating_point: 1909 prod = product(_float_extremals, _float_extremals) 1910 elif dtype.is_complex: 1911 complex_vals = product(_float_extremals, _float_extremals) 1912 # Note the use of list is required here or the map generator will be 1913 # emptied by the following product and it won't produce the desired cross-product 1914 complex_vals = [complex(*x) for x in complex_vals] 1915 prod = product(complex_vals, complex_vals) 1916 else: 1917 raise ValueError("Unsupported dtype!") 1918 1919 for l, r in prod: 1920 l_vals.append(l) 1921 r_vals.append(r) 1922 1923 lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad) 1924 rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad) 1925 1926 yield SampleInput(lhs, args=(rhs,)) 1927 1928 # Test case for NaN propagation 1929 nan = ( 1930 float("nan") if dtype.is_floating_point else complex(float("nan"), float("nan")) 1931 ) 1932 lhs = make_tensor( 1933 (128, 128), device=device, dtype=dtype, requires_grad=requires_grad 1934 ) 1935 lhs.view(-1)[::3] = nan 1936 rhs = make_tensor( 1937 (128, 128), device=device, dtype=dtype, requires_grad=requires_grad 1938 ) 1939 rhs.view(-1)[::3] = nan 1940 1941 yield SampleInput(lhs, args=(rhs,)) 1942 1943 1944# Returns a generator of pairs of contiguous and noncontiguous tensors that 1945# require broadcasting 1946def generate_elementwise_binary_broadcasting_tensors( 1947 op, *, device, dtype, requires_grad=False, exclude_zero=False 1948): 1949 shapes = ( 1950 ((1,), ()), 1951 ((2,), ()), 1952 ((1,), (2,)), 1953 ((2, 1), (2,)), 1954 ((1, 2), (2,)), 1955 ((3, 2), (2,)), 1956 ((1, 3, 2), (2,)), 1957 ((1, 3, 2), (3, 2)), 1958 ((3, 1, 2), (3, 2)), 1959 ((2, 3, 2), ()), 1960 ((3, 1, 2), (1, 3, 2)), 1961 ) 1962 1963 make_arg = partial( 1964 make_tensor, 1965 device=device, 1966 dtype=dtype, 1967 requires_grad=requires_grad, 1968 exclude_zero=exclude_zero, 1969 ) 1970 for shape, noncontiguous in product(shapes, [True, False]): 1971 shape_lhs, shape_rhs = shape 1972 lhs = make_arg( 1973 shape_lhs, noncontiguous=noncontiguous, **op.lhs_make_tensor_kwargs 1974 ) 1975 rhs = make_arg( 1976 shape_rhs, noncontiguous=noncontiguous, **op.rhs_make_tensor_kwargs 1977 ) 1978 1979 yield SampleInput(lhs, args=(rhs,), broadcasts_input=True) 1980 1981 1982# Returns a generator of pairs of contiguous tensors and scalars 1983def generate_elementwise_binary_with_scalar_samples( 1984 op, *, device, dtype, requires_grad=False 1985): 1986 make_arg = partial( 1987 make_tensor, device=device, dtype=dtype, requires_grad=requires_grad 1988 ) 1989 1990 shapes = ((), (3,), (5, 3), (0, 1, 3), (1, 5)) 1991 if op.supports_rhs_python_scalar: 1992 for shape in shapes: 1993 lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) 1994 rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) 1995 lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item() 1996 rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item() 1997 1998 yield SampleInput(lhs, args=(rhs_scalar,)) 1999 2000 # Extends with scalar lhs 2001 if op.supports_one_python_scalar: 2002 yield SampleInput(lhs_scalar, args=(rhs,)) 2003 2004 if op.supports_two_python_scalars: 2005 lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item() 2006 rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item() 2007 2008 yield SampleInput(lhs_scalar, args=(rhs_scalar,)) 2009 2010 2011# Returns a generator of pairs of contiguous tensors and 0d tensors and scalars and type promotion 2012def generate_elementwise_binary_with_scalar_and_type_promotion_samples( 2013 op, *, device, dtype, requires_grad=False 2014): 2015 # add these samples only for logical and comparison ops, arithmetic ops are not happy about extremal scalars 2016 if op.name in ( 2017 "eq", 2018 "ne", 2019 "gt", 2020 "ge", 2021 "lt", 2022 "le", 2023 "logical_and", 2024 "logical_or", 2025 "logical_xor", 2026 ): 2027 make_arg = partial( 2028 make_tensor, device=device, dtype=dtype, requires_grad=requires_grad 2029 ) 2030 shape = ( 2031 23, 2032 ) # this shape is big enough to trigger vectorization, and has non-vectorized tail 2033 values = (float("nan"), float("inf"), -float("inf")) 2034 scalar_tensors = tuple(torch.tensor(val) for val in values) 2035 if op.supports_rhs_python_scalar: 2036 lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) 2037 rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) 2038 for scalar in values + scalar_tensors: 2039 yield SampleInput(lhs, args=(scalar,)) 2040 # Extends with scalar lhs 2041 if op.supports_one_python_scalar: 2042 yield SampleInput(scalar, args=(rhs,)) 2043 2044 2045# Returns a generator of pairs of noncontiguous tensors 2046def generate_elementwise_binary_noncontiguous_tensors( 2047 op, *, device, dtype, requires_grad=False, exclude_zero=False 2048): 2049 make_arg = partial( 2050 make_tensor, 2051 device=device, 2052 dtype=dtype, 2053 requires_grad=requires_grad, 2054 exclude_zero=exclude_zero, 2055 ) 2056 2057 # Generic noncontiguity 2058 lhs = make_arg((1026,), noncontiguous=True, **op.lhs_make_tensor_kwargs) 2059 rhs = make_arg((1026,), noncontiguous=True, **op.rhs_make_tensor_kwargs) 2060 2061 yield SampleInput(lhs.clone(), args=(rhs.clone(),)) 2062 yield SampleInput(lhs.contiguous(), args=(rhs,)) 2063 2064 # Transposed 2065 lhs = make_arg((789, 357), **op.lhs_make_tensor_kwargs) 2066 rhs = make_arg((789, 357), **op.rhs_make_tensor_kwargs) 2067 2068 yield SampleInput(lhs.T, args=(rhs.T,)) 2069 2070 # More noncontiguity 2071 shapes = ((5, 7), (1024,)) 2072 2073 for shape in shapes: 2074 lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) 2075 rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) 2076 2077 lhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] 2078 lhs_non_contig.copy_(lhs) 2079 2080 rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] 2081 rhs_non_contig.copy_(rhs) 2082 2083 yield SampleInput(lhs_non_contig.clone(), args=(rhs_non_contig.clone(),)) 2084 yield SampleInput(lhs_non_contig.contiguous(), args=(rhs_non_contig,)) 2085 2086 # Noncontiguous indices 2087 shape = (2, 2, 1, 2) 2088 lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) 2089 rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) 2090 2091 lhs_non_contig = lhs[:, 1, ...] 2092 rhs_non_contig = rhs[:, 1, ...] 2093 2094 yield SampleInput(lhs_non_contig.clone(), args=(rhs_non_contig.clone(),)) 2095 yield SampleInput(lhs_non_contig.contiguous(), args=(rhs_non_contig,)) 2096 2097 # Expanded tensors 2098 shapes = ((1, 3), (1, 7), (5, 7)) 2099 2100 for shape in shapes: 2101 lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) 2102 rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) 2103 2104 lhs_non_contig = lhs.expand(3, -1, -1) 2105 rhs_non_contig = rhs.expand(3, -1, -1) 2106 2107 yield SampleInput(lhs_non_contig, args=(rhs_non_contig,)) 2108 2109 2110# Sample inputs for elementwise binary operators, like add 2111def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): 2112 _M = S if kwargs.get("small_inputs_only", False) else M 2113 _S = XS if kwargs.get("small_inputs_only", False) else S 2114 2115 if hasattr(op, "rhs_make_tensor_kwargs"): 2116 exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) 2117 2118 make_arg = partial( 2119 make_tensor, 2120 device=device, 2121 dtype=dtype, 2122 requires_grad=requires_grad, 2123 exclude_zero=exclude_zero, 2124 ) 2125 2126 shapes = ( 2127 ((), ()), 2128 ((_S,), ()), 2129 ((_S, 1), (_S,)), 2130 ((_M, _S), ()), 2131 ((_S, _M, _S), (_M, _S)), 2132 ((_S, _M, _S), (_S, _M, _S)), 2133 ((_M, 1, _S), (_M, _S)), 2134 ((_M, 1, _S), (1, _M, _S)), 2135 ((0, 1, XS), (0, _M, XS)), 2136 ) 2137 2138 sample_kwargs = kwargs.get("sample_kwargs", {}) 2139 2140 for shape_lhs, shape_rhs in shapes: 2141 lhs = make_arg(shape_lhs, **op.lhs_make_tensor_kwargs) 2142 rhs = make_arg(shape_rhs, **op.rhs_make_tensor_kwargs) 2143 broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs) 2144 2145 yield SampleInput( 2146 lhs, args=(rhs,), kwargs=sample_kwargs, broadcasts_input=broadcasts_input 2147 ) 2148 2149 2150# Metadata class for binary "universal functions (ufuncs)" that accept two 2151# tensor and have common properties 2152class BinaryUfuncInfo(OpInfo): 2153 """Operator information for 'universal binary functions (binary ufuncs).' 2154 These are functions of two tensors with common properties like: 2155 - they are elementwise functions 2156 - the output shape is determined by the input shape 2157 - they typically have method and inplace variants 2158 - they typically support the out kwarg 2159 - they typically have NumPy or SciPy references 2160 See NumPy's universal function documentation 2161 (https://numpy.org/doc/stable/reference/ufuncs.html) for more details 2162 about the concept of ufuncs. 2163 """ 2164 2165 def __init__( 2166 self, 2167 name, 2168 *, 2169 sample_inputs_func=sample_inputs_elementwise_binary, 2170 reference_inputs_func=reference_inputs_elementwise_binary, 2171 error_inputs_func=None, 2172 lhs_make_tensor_kwargs=None, 2173 rhs_make_tensor_kwargs=None, 2174 always_returns_bool=False, # Set to true if the op always returns bool tensors 2175 supports_rhs_python_scalar=True, # Whether the operator allows Tensor x scalar inputs 2176 supports_one_python_scalar=False, # Whether the operator allows scalar x tensor and tensor x scalar inputs 2177 supports_two_python_scalars=False, # Whether the operator allows scalar x scalar inputs 2178 **kwargs, 2179 ): 2180 self._original_binary_ufunc_args = locals().copy() 2181 2182 # Elementwise binary operations perform the equivalent of test_numpy_refs 2183 # in test_binary_ufuncs, but with additional test granularity. So the 2184 # generic test_ops.py test is skipped because it's redundant. 2185 common_skips = ( 2186 DecorateInfo( 2187 unittest.skip("Skipping redundant test."), 2188 "TestCommon", 2189 "test_numpy_refs", 2190 ), 2191 ) 2192 kwargs["skips"] = kwargs.get("skips", ()) + common_skips 2193 super().__init__( 2194 name, 2195 sample_inputs_func=sample_inputs_func, 2196 reference_inputs_func=reference_inputs_func, 2197 error_inputs_func=make_error_inputs_elementwise_binary(error_inputs_func), 2198 **kwargs, 2199 ) 2200 2201 # [lr]hs_make_tensor_kwargs are part of the OpInfo to be able to dynamically generate valid samples later on. 2202 if lhs_make_tensor_kwargs is None: 2203 lhs_make_tensor_kwargs = {} 2204 self.lhs_make_tensor_kwargs = lhs_make_tensor_kwargs 2205 2206 if rhs_make_tensor_kwargs is None: 2207 rhs_make_tensor_kwargs = {} 2208 self.rhs_make_tensor_kwargs = rhs_make_tensor_kwargs 2209 2210 self.always_returns_bool = always_returns_bool 2211 self.supports_rhs_python_scalar = supports_rhs_python_scalar 2212 self.supports_one_python_scalar = supports_one_python_scalar 2213 self.supports_two_python_scalars = supports_two_python_scalars 2214 2215 if self.supports_two_python_scalars: 2216 self.supports_one_python_scalar = True 2217 2218 if self.supports_one_python_scalar: 2219 assert ( 2220 supports_rhs_python_scalar 2221 ), "Can't support lhs and rhs Python scalars but not rhs scalars!" 2222 2223 2224# The following functions and classes are for testing elementwise unary operators. 2225def sample_inputs_elementwise_unary( 2226 op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs 2227): 2228 if not op_kwargs: 2229 op_kwargs = {} 2230 2231 _L = S if kwargs.get("small_inputs_only", False) else L 2232 2233 low, high = op_info.domain 2234 is_floating = dtype.is_floating_point or dtype.is_complex 2235 low = low if low is None or not is_floating else low + op_info._domain_eps 2236 high = high if high is None or not is_floating else high - op_info._domain_eps 2237 if ( 2238 op_info.supports_sparse_csr 2239 or op_info.supports_sparse_csc 2240 or op_info.supports_sparse_bsr 2241 or op_info.supports_sparse_bsc 2242 ): 2243 # Tensors with dim=2 for sparse compressed testing 2244 yield SampleInput( 2245 make_tensor( 2246 (_L, _L), 2247 device=device, 2248 dtype=dtype, 2249 low=low, 2250 high=high, 2251 requires_grad=requires_grad, 2252 ), 2253 kwargs=op_kwargs, 2254 ) 2255 else: 2256 # Creates a 1D, empty, and scalar tensor 2257 for shape in ((_L,), (1, 0, 3), ()): 2258 yield SampleInput( 2259 make_tensor( 2260 shape, 2261 device=device, 2262 dtype=dtype, 2263 low=low, 2264 high=high, 2265 requires_grad=requires_grad, 2266 ), 2267 kwargs=op_kwargs, 2268 ) 2269 2270 2271# Replace values satisfying condition with a safe value. This is used to block 2272# out values the could cause singularity like tan(pi/2) 2273def _replace_values_in_tensor(tensor, condition, safe_value): 2274 mask = condition(tensor) 2275 tensor.masked_fill_(mask, safe_value) 2276 2277 2278# Helper to create a unary elementwise tensor with valid inputs 2279def _make_unary_elementwise_tensor(shape, *, op, dtype, **kwargs): 2280 low, high = op.domain 2281 is_floating = dtype.is_floating_point or dtype.is_complex 2282 low = low if low is None or not is_floating else low + op._domain_eps 2283 high = high if high is None or not is_floating else high - op._domain_eps 2284 2285 a = make_tensor(shape, low=low, high=high, dtype=dtype, **kwargs) 2286 2287 if op.reference_numerics_filter is not None and dtype is not torch.bool: 2288 condition, safe_value = op.reference_numerics_filter 2289 _replace_values_in_tensor(a, condition, safe_value) 2290 2291 return a 2292 2293 2294# Restricts the values in the tensor to the domain of the 2295# given elementwise unary operator 2296def _filter_unary_elementwise_tensor(a, *, op): 2297 # short-circuits for boolean tensors 2298 if a.dtype is torch.bool: 2299 return a 2300 2301 low, high = op.domain 2302 is_floating = a.dtype.is_floating_point or a.dtype.is_complex 2303 low = low if low is None or not is_floating else low + op._domain_eps 2304 high = high if high is None or not is_floating else high - op._domain_eps 2305 2306 if a.dtype is torch.uint8 and low is not None: 2307 low = max(low, 0) 2308 2309 if not a.dtype.is_floating_point and not a.dtype.is_complex: 2310 low = math.ceil(low) if low is not None else None 2311 high = math.floor(high) if high is not None else None 2312 2313 if op.reference_numerics_filter is not None: 2314 condition, safe_value = op.reference_numerics_filter 2315 _replace_values_in_tensor(a, condition, safe_value) 2316 2317 if low is not None or high is not None: 2318 if a.dtype.is_complex: 2319 a.real.clamp_(low, high) 2320 a.imag.clamp_(low, high) 2321 else: 2322 a.clamp_(min=low, max=high) 2323 2324 return a 2325 2326 2327def generate_elementwise_unary_tensors(op, *, device, dtype, requires_grad, **kwargs): 2328 # Special-cases bool 2329 if dtype is torch.bool: 2330 tensors = ( 2331 torch.empty(0, device=device, dtype=torch.bool), 2332 torch.tensor(True, device=device), 2333 torch.tensor(False, device=device), 2334 torch.tensor((True, False), device=device), 2335 make_tensor((812,), device=device, dtype=dtype), 2336 make_tensor((1029, 917), device=device, dtype=dtype), 2337 ) 2338 for a in tensors: 2339 yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) 2340 2341 shapes = ( 2342 (1029, 917), 2343 (812,), 2344 # Empty sizes 2345 (0,), 2346 (0, 3, 3), 2347 (1, 0, 5), 2348 (6, 0, 0, 0), 2349 (3, 0, 1, 0), 2350 ) 2351 2352 make_arg = partial( 2353 _make_unary_elementwise_tensor, 2354 op=op, 2355 device=device, 2356 dtype=dtype, 2357 requires_grad=requires_grad, 2358 ) 2359 for shape in shapes: 2360 a = make_arg(shape) 2361 yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) 2362 2363 2364def generate_elementwise_unary_small_value_tensors( 2365 op, *, device, dtype, requires_grad=False 2366): 2367 for sample in generate_elementwise_binary_small_value_tensors( 2368 op, device=device, dtype=dtype, requires_grad=requires_grad 2369 ): 2370 a = _filter_unary_elementwise_tensor(sample.input, op=op) 2371 yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) 2372 2373 2374def generate_elementwise_unary_large_value_tensors( 2375 op, *, device, dtype, requires_grad=False 2376): 2377 for sample in generate_elementwise_binary_large_value_tensors( 2378 op, device=device, dtype=dtype, requires_grad=requires_grad 2379 ): 2380 a = _filter_unary_elementwise_tensor(sample.input, op=op) 2381 yield SampleInput(sample.input, kwargs=op.sample_kwargs(device, dtype, a)[0]) 2382 2383 2384def generate_elementwise_unary_extremal_value_tensors( 2385 op, *, device, dtype, requires_grad=False 2386): 2387 for sample in generate_elementwise_binary_extremal_value_tensors( 2388 op, device=device, dtype=dtype, requires_grad=requires_grad 2389 ): 2390 yield SampleInput( 2391 sample.input, kwargs=op.sample_kwargs(device, dtype, sample.input)[0] 2392 ) 2393 2394 2395def generate_elementwise_unary_noncontiguous_tensors( 2396 op, *, device, dtype, requires_grad=False 2397): 2398 make_arg = partial( 2399 _make_unary_elementwise_tensor, 2400 op=op, 2401 device=device, 2402 dtype=dtype, 2403 requires_grad=requires_grad, 2404 ) 2405 2406 # Generic noncontiguity 2407 t = make_arg((1026,), noncontiguous=True) 2408 yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0]) 2409 2410 # Transposed 2411 t = make_arg((1024, 1024)).T 2412 yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0]) 2413 2414 # Expanded tensors 2415 shapes = ((1, 3), (1, 7), (5, 7)) 2416 2417 for shape in shapes: 2418 t = make_arg(shape) 2419 t_non_contig = t.expand(3, -1, -1) 2420 yield SampleInput( 2421 t_non_contig, kwargs=op.sample_kwargs(device, dtype, t_non_contig)[0] 2422 ) 2423 2424 2425def generate_elementwise_unary_arbitrarily_strided_tensors( 2426 op, *, device, dtype, requires_grad=False 2427): 2428 # shape, strides, offset 2429 strided_cases = ( 2430 ((5, 6, 2), (1, 1, 7), 2), 2431 ((5, 5, 4), (1, 1, 7), 2), 2432 ((5, 5, 2), (4, 5, 7), 3), 2433 ((5, 5, 2), (5, 5, 7), 3), 2434 ((5, 5, 2), (5, 5, 5), 3), 2435 ((9, 5, 2), (0, 1, 7), 3), 2436 ) 2437 2438 make_arg = partial( 2439 make_tensor, device=device, dtype=dtype, requires_grad=requires_grad 2440 ) 2441 for shape, strides, offset in strided_cases: 2442 a = make_arg( 2443 500, 2444 ).as_strided(shape, strides, offset) 2445 yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) 2446 2447 2448# Reuses the elementwise binary generators for consistency 2449# TODO: in the future generalize the reference generators to handle n-ary elementwise operations 2450def _reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs): 2451 yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs) 2452 2453 yield from generate_elementwise_unary_tensors( 2454 op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs 2455 ) 2456 2457 if dtype is not torch.bool: 2458 yield from generate_elementwise_unary_small_value_tensors( 2459 op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs 2460 ) 2461 if dtype not in (torch.bool, torch.uint8, torch.int8) and ( 2462 op.handles_large_floats 2463 or (not dtype.is_floating_point and not dtype.is_complex) 2464 ): 2465 yield from generate_elementwise_unary_large_value_tensors( 2466 op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs 2467 ) 2468 2469 if dtype.is_floating_point or ( 2470 op.handles_complex_extremal_values and dtype.is_complex 2471 ): 2472 yield from generate_elementwise_unary_extremal_value_tensors( 2473 op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs 2474 ) 2475 2476 2477def reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs): 2478 gen = partial( 2479 _reference_inputs_elementwise_unary, op, device, dtype, requires_grad, **kwargs 2480 ) 2481 2482 # yields "normal" samples 2483 yield from gen() 2484 2485 # yields noncontiguous samples 2486 for sample in gen(): 2487 yield sample.noncontiguous() 2488 2489 yield from generate_elementwise_unary_noncontiguous_tensors( 2490 op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs 2491 ) 2492 2493 yield from generate_elementwise_unary_arbitrarily_strided_tensors( 2494 op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs 2495 ) 2496 2497 2498# Metadata class for unary "universal functions (ufuncs)" that accept a single 2499# tensor and have common properties like: 2500class UnaryUfuncInfo(OpInfo): 2501 """Operator information for 'universal unary functions (unary ufuncs).' 2502 These are functions of a single tensor with common properties like: 2503 - they are elementwise functions 2504 - the input shape is the output shape 2505 - they typically have method and inplace variants 2506 - they typically support the out kwarg 2507 - they typically have NumPy or SciPy references 2508 See NumPy's universal function documentation 2509 (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details 2510 about the concept of ufuncs. 2511 """ 2512 2513 def __init__( 2514 self, 2515 name, # the string name of the function 2516 *, 2517 dtypes=floating_types(), 2518 domain=(None, None), # the [low, high) domain of the function 2519 handles_complex_extremal_values=True, # whether the op correctly handles extremal values (like nan/inf) 2520 handles_large_floats=True, # whether the op correctly handles large float values (like 1e20) 2521 supports_complex_to_float=False, # op supports casting from complex input to real output safely eg. angle 2522 sample_inputs_func=sample_inputs_elementwise_unary, 2523 reference_inputs_func=reference_inputs_elementwise_unary, 2524 sample_kwargs=lambda device, dtype, input: ({}, {}), 2525 reference_numerics_filter=None, # Filters values in the range of the domain specified above but that should not be tested 2526 **kwargs, 2527 ): 2528 self._original_unary_ufunc_args = locals().copy() 2529 2530 super().__init__( 2531 name, 2532 dtypes=dtypes, 2533 sample_inputs_func=sample_inputs_func, 2534 reference_inputs_func=reference_inputs_func, 2535 **kwargs, 2536 ) 2537 self.domain = domain 2538 self.handles_complex_extremal_values = handles_complex_extremal_values 2539 self.handles_large_floats = handles_large_floats 2540 self.supports_complex_to_float = supports_complex_to_float 2541 self.reference_numerics_filter = reference_numerics_filter 2542 2543 # test_unary_ufuncs.py generates its own inputs to test the consistency 2544 # of the operator on sliced tensors, non-contig tensors, etc. 2545 # `sample_kwargs` is a utility function to provide kwargs 2546 # along with those inputs if required (eg. clamp). 2547 # It should return two dictionaries, first holding kwarg for 2548 # torch operator and second one for reference NumPy operator. 2549 self.sample_kwargs = sample_kwargs 2550 2551 # Epsilon to ensure grad and gradgrad checks don't test values 2552 # outside a function's domain. 2553 self._domain_eps = 1e-5 2554 2555 2556def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs): 2557 is_fp16_or_chalf = dtype == torch.complex32 or dtype == torch.half 2558 if not is_fp16_or_chalf: 2559 nd_tensor = partial( 2560 make_tensor, 2561 (S, S + 1, S + 2), 2562 device=device, 2563 dtype=dtype, 2564 requires_grad=requires_grad, 2565 ) 2566 oned_tensor = partial( 2567 make_tensor, (31,), device=device, dtype=dtype, requires_grad=requires_grad 2568 ) 2569 else: 2570 # cuFFT supports powers of 2 for half and complex half precision 2571 # NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args 2572 # where output_size n=2*(input_size - 1), we make sure that logical fft size is a power of two 2573 low = None 2574 high = None 2575 if self.name in ["fft.hfft", "fft.irfft", "_refs.fft.hfft", "_refs.fft.irfft"]: 2576 shapes = ((2, 9, 9), (33,)) 2577 elif self.name in [ 2578 "fft.hfft2", 2579 "fft.irfft2", 2580 "_refs.fft.hfft2", 2581 "_refs.fft.irfft2", 2582 ]: 2583 shapes = ((2, 8, 9), (33,)) 2584 elif self.name in [ 2585 "fft.hfftn", 2586 "fft.irfftn", 2587 "_refs.fft.hfftn", 2588 "_refs.fft.irfftn", 2589 ]: 2590 shapes = ((2, 2, 33), (33,)) 2591 # Adjusting the limits because the test would be flaky due to over-saturation of float16 2592 # See: https://github.com/pytorch/pytorch/pull/81416 2593 low = -1.0 2594 high = 1.0 2595 else: 2596 shapes = ((2, 8, 16), (32,)) 2597 nd_tensor = partial( 2598 make_tensor, 2599 shapes[0], 2600 device=device, 2601 low=low, 2602 high=high, 2603 dtype=dtype, 2604 requires_grad=requires_grad, 2605 ) 2606 oned_tensor = partial( 2607 make_tensor, 2608 shapes[1], 2609 device=device, 2610 low=low, 2611 high=high, 2612 dtype=dtype, 2613 requires_grad=requires_grad, 2614 ) 2615 2616 if self.ndimensional == SpectralFuncType.ND: 2617 yield SampleInput( 2618 nd_tensor(), 2619 s=(3, 10) if not is_fp16_or_chalf else (4, 8), 2620 dim=(1, 2), 2621 norm="ortho", 2622 ) 2623 yield SampleInput(nd_tensor(), norm="ortho") 2624 yield SampleInput(nd_tensor(), s=(8,)) 2625 yield SampleInput(oned_tensor()) 2626 yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3, (0, -1)]) 2627 elif self.ndimensional == SpectralFuncType.TwoD: 2628 yield SampleInput( 2629 nd_tensor(), 2630 s=(3, 10) if not is_fp16_or_chalf else (4, 8), 2631 dim=(1, 2), 2632 norm="ortho", 2633 ) 2634 yield SampleInput(nd_tensor(), norm="ortho") 2635 yield SampleInput(nd_tensor(), s=(6, 8) if not is_fp16_or_chalf else (4, 8)) 2636 yield SampleInput(nd_tensor(), dim=0) 2637 yield SampleInput(nd_tensor(), dim=(0, -1)) 2638 yield SampleInput(nd_tensor(), dim=(-3, -2, -1)) 2639 else: 2640 yield SampleInput( 2641 nd_tensor(), 2642 n=10 if not is_fp16_or_chalf else 8, 2643 dim=1, 2644 norm="ortho", 2645 ) 2646 yield SampleInput(nd_tensor(), norm="ortho") 2647 yield SampleInput(nd_tensor(), n=7 if not is_fp16_or_chalf else 8) 2648 yield SampleInput(oned_tensor()) 2649 yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3]) 2650 2651 2652SpectralFuncType = Enum("SpectralFuncType", ("OneD", "TwoD", "ND")) 2653 2654 2655# Metadata class for Fast Fourier Transforms in torch.fft. 2656class SpectralFuncInfo(OpInfo): 2657 """Operator information for torch.fft transforms.""" 2658 2659 def __init__( 2660 self, 2661 name, # the string name of the function 2662 *, 2663 ref=None, # Reference implementation (probably in np.fft namespace) 2664 dtypes=floating_and_complex_types(), 2665 ndimensional: SpectralFuncType, 2666 sample_inputs_func=sample_inputs_spectral_ops, 2667 decorators=None, 2668 **kwargs, 2669 ): 2670 self._original_spectral_func_args = dict(locals()).copy() 2671 self._original_spectral_func_args.update(kwargs) 2672 2673 decorators = list(decorators) if decorators is not None else [] 2674 decorators += [ 2675 skipCPUIfNoFFT, 2676 DecorateInfo( 2677 toleranceOverride({torch.chalf: tol(4e-2, 4e-2)}), 2678 "TestCommon", 2679 "test_complex_half_reference_testing", 2680 ), 2681 ] 2682 2683 super().__init__( 2684 name=name, 2685 dtypes=dtypes, 2686 decorators=decorators, 2687 sample_inputs_func=sample_inputs_func, 2688 **kwargs, 2689 ) 2690 self.ref = ref 2691 self.ndimensional = ndimensional 2692 2693 2694class ShapeFuncInfo(OpInfo): 2695 """Early version of a specialized OpInfo for Shape manipulating operations like tile and roll""" 2696 2697 def __init__( 2698 self, 2699 name, # the string name of the function 2700 *, 2701 ref, # a reference function 2702 dtypes=floating_types(), 2703 dtypesIfCUDA=None, 2704 dtypesIfROCM=None, 2705 dtypesIfXPU=None, 2706 sample_inputs_func=None, 2707 **kwargs, 2708 ): 2709 super().__init__( 2710 name, 2711 dtypes=dtypes, 2712 dtypesIfCUDA=dtypesIfCUDA, 2713 dtypesIfROCM=dtypesIfROCM, 2714 dtypesIfXPU=dtypesIfXPU, 2715 sample_inputs_func=sample_inputs_func, 2716 **kwargs, 2717 ) 2718 self.ref = ref 2719 2720 2721def sample_inputs_foreach( 2722 self, 2723 device, 2724 dtype, 2725 N, 2726 *, 2727 noncontiguous=False, 2728 same_size=False, 2729 low=None, 2730 high=None, 2731 zero_size: bool, 2732 requires_grad: bool, 2733 # mutually exclusive from same_size and zero_size, which are all or nothing 2734 intersperse_empty_tensors: bool = False, 2735): 2736 if zero_size: 2737 return [torch.empty(0, dtype=dtype, device=device) for _ in range(N)] 2738 if same_size: 2739 return [ 2740 make_tensor( 2741 (N, N), 2742 dtype=dtype, 2743 device=device, 2744 noncontiguous=noncontiguous, 2745 low=low, 2746 high=high, 2747 requires_grad=requires_grad, 2748 ) 2749 for _ in range(N) 2750 ] 2751 else: 2752 # interweave some empty tensors + have the last 2 tensors be empty (see #100701) 2753 return [ 2754 torch.empty(0, dtype=dtype, device=device, requires_grad=requires_grad) 2755 if (i % 3 == 0 or i >= N - 2) and intersperse_empty_tensors 2756 else make_tensor( 2757 (N - i, N - i), 2758 dtype=dtype, 2759 device=device, 2760 noncontiguous=noncontiguous, 2761 low=low, 2762 high=high, 2763 requires_grad=requires_grad, 2764 ) 2765 for i in range(N) 2766 ] 2767 2768 2769def get_foreach_method_names(name): 2770 # get torch inplace reference function 2771 op_name = "_foreach_" + name 2772 inplace_op_name = op_name + "_" 2773 2774 op = getattr(torch, op_name, None) 2775 inplace_op = getattr(torch, inplace_op_name, None) 2776 2777 ref = getattr(torch, name, None) 2778 ref_inplace = getattr(torch.Tensor, name + "_", None) 2779 return op, inplace_op, ref, ref_inplace 2780 2781 2782@dataclass 2783class ForeachFuncInfo(OpInfo): 2784 """Early version of a specialized OpInfo for foreach functions 2785 2786 The main differences from the parent class are (a) `dtypes`, `dtypesIfCUDA`, and `dtypesIfROCM` 2787 are set to `get_all_dtypes(include_qint=False)`, and (b) the following arguments. 2788 2789 ``supports_alpha_param=True`` means that the function supports a python scalar (``numbers.Number``) 2790 as the last keyword argument such as `_foreach_add`. 2791 ``supports_scalar_self_arg=True`` means that the function can take a python scalar as its first argument. 2792 Currently only `_foreach_pow` supports this. 2793 ``backward_requires_result=True``, which could sound self-explanatory, means that the function uses 2794 the forward result for its backward computation. 2795 """ 2796 2797 supports_alpha_param: bool = False 2798 supports_scalar_self_arg: bool = False 2799 backward_requires_result: bool = False 2800 2801 def __post_init__(self): 2802 ( 2803 foreach_method, 2804 foreach_method_inplace, 2805 torch_ref_method, 2806 torch_ref_inplace, 2807 ) = get_foreach_method_names(self.name) 2808 if not self.supports_out: 2809 # note(crcrpar): `foreach_method` for `"zero"` is `None` but `None` would call 2810 # `_getattr_qual` in `OpInfo.__post_init__` which should fail since `_foreach_zero` 2811 # is not defined at the moment. Thus to skip the qualification, set a similar torch 2812 # function. 2813 assert foreach_method is None 2814 assert torch_ref_method is None 2815 foreach_method = foreach_method_inplace 2816 torch_ref_method = torch_ref_inplace 2817 2818 self.dtypes = _dispatch_dtypes(get_all_dtypes(include_qint=False)) 2819 2820 self.op = foreach_method 2821 self.method_variant = foreach_method 2822 self.ref = torch_ref_method 2823 self.inplace_variant = foreach_method_inplace 2824 self.ref_inplace = torch_ref_inplace 2825 self.has_no_in_place = self.inplace_variant is None 2826 2827 name = self.name 2828 self.name = f"_foreach_{name}" 2829 if name == "norm": 2830 self.ref = torch.linalg.vector_norm 2831 elif name == "minimum": 2832 # because minimum ref does not support inplace or scalar 2833 self.ref = torch.clamp_max 2834 self.ref_inplace = torch.Tensor.clamp_max_ 2835 elif name == "maximum": 2836 # because maximum ref does not support inplace or scalar 2837 self.ref = torch.clamp_min 2838 self.ref_inplace = torch.Tensor.clamp_min_ 2839 2840 # The following sets `dtypesIfCUDA` and `dtypesIfROCM` accordingly. 2841 super().__post_init__() 2842 2843 def sample_zero_size_inputs(self, device, dtype, requires_grad=False, **kwargs): 2844 if not hasattr(self.sample_inputs_func, "sample_zero_size_tensor_inputs"): 2845 return [] 2846 return self.sample_inputs_func.sample_zero_size_tensor_inputs( 2847 self, device, dtype, requires_grad, **kwargs 2848 ) 2849 2850 2851def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs): 2852 """Gradcheck wrapper for functions that take Hermitian matrices as input. 2853 2854 They require a modified function because the finite-difference algorithm 2855 for calculating derivatives does not preserve the Hermitian property of the input. 2856 """ 2857 return op(input + input.mH, *args, **kwargs) 2858 2859 2860def gradcheck_wrapper_triangular_input(op, *args, upper=False, idx=0, **kwargs): 2861 """Gradcheck wrapper for functions that take lower or upper triangular matrices as input. 2862 2863 They require a modified function because the finite-difference algorithm 2864 for calculating derivatives does not preserve the triangular property of the input. 2865 `idx` is used to specific which `args[idx]` is to be triangularized. 2866 """ 2867 triangular_arg = args[idx].triu() if upper else args[idx].tril() 2868 return op(*args[:idx], triangular_arg, *args[idx + 1 :], upper, **kwargs) 2869 2870 2871def gradcheck_wrapper_triangular_input_real_positive_diagonal( 2872 op, *args, upper=False, idx=0, **kwargs 2873): 2874 """Gradcheck wrapper for functions that take lower/upper triangular matrices 2875 with real and positive diagonals, for example, cholesky-like operations. 2876 """ 2877 arg = args[idx] 2878 arg_diag = arg.diagonal(0, -2, -1) 2879 arg_diag_embed = torch.diag_embed(arg_diag) 2880 id_diag_tensor = torch.ones_like(arg_diag) 2881 id_tensor = torch.diag_embed(id_diag_tensor) 2882 # new_arg = arg - diag(arg) + I 2883 new_arg = arg - arg_diag_embed + id_tensor 2884 return gradcheck_wrapper_triangular_input( 2885 op, *args[:idx], new_arg, *args[idx + 1 :], upper=upper, idx=idx, **kwargs 2886 ) 2887 2888 2889def gradcheck_wrapper_masked_operation(op, input, *args, **kwargs): 2890 """Gradcheck wrapper for masked operations. 2891 2892 When mask is specified, replaces masked-out elements with zeros. 2893 2894 Use for operations that produce non-finite masked-out elements, 2895 for instance, for minimum and maximum reductions. 2896 """ 2897 output = op(input, *args, **kwargs) 2898 mask = kwargs.get("mask") 2899 if mask is not None: 2900 output_mask = torch.masked._output_mask(op, input, *args, **kwargs) 2901 output = torch.where(output_mask, output, output.new_zeros([])) 2902 return output 2903 2904 2905def gradcheck_wrapper_masked_pointwise_operation(op, input, *args, **kwargs): 2906 """Gradcheck wrapper for masked pointwise operations. Assumes that the result 2907 will be masked iff both tensors are masked at a specific index 2908 2909 When mask is specified, replaces masked-out elements with zeros. 2910 2911 Use for operations that produce non-finite masked-out elements, 2912 for instance, for minimum and maximum reductions. 2913 """ 2914 output = op(input, *args, **kwargs) 2915 input_mask = kwargs.get("input_mask") 2916 other_mask = kwargs.get("other_mask") 2917 if input_mask is not None and other_mask is not None: 2918 combined_mask = torch.logical_and(input_mask, other_mask) 2919 new_kwargs = dict(mask=combined_mask, **kwargs) 2920 output_mask = torch.masked._input_mask(input, *args, **new_kwargs) 2921 output = torch.where(output_mask, output, output.new_zeros([])) 2922 return output 2923 2924 2925def clone_sample(sample, **kwargs): 2926 """ 2927 Given a SampleInput, this function analyzes its input, args and kwargs, 2928 and produces a copy with each non-Tensor entry being copied by reference, 2929 and with each Tensor entry cloned with `t.clone().requires_grad_(t.requires_grad)` 2930 """ 2931 2932 def clone_tensor(t): 2933 if isinstance(t, torch.Tensor): 2934 return t.detach().clone().requires_grad_(t.requires_grad) 2935 else: 2936 return t 2937 2938 sample_kwargs = kwargs if kwargs else sample.kwargs 2939 2940 return SampleInput( 2941 clone_tensor(sample.input), 2942 args=tuple(map(clone_tensor, sample.args)), 2943 kwargs={k: clone_tensor(v) for k, v in sample_kwargs.items()}, 2944 ) 2945