1# mypy: ignore-errors 2 3from abc import abstractmethod 4import tempfile 5import unittest 6 7from copy import deepcopy 8from functools import reduce, partial 9from itertools import product 10from operator import mul 11 12 13import torch 14import torch.cuda 15import torch.nn as nn 16import torch.nn.functional as F 17from torch.nn import _reduction as _Reduction 18from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \ 19 gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo 20from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater 21from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors 22from torch.autograd import Variable 23from torch.types import _TensorOrTensors 24import torch.backends.cudnn 25 26from typing import Dict, Callable, Tuple, List, Sequence, Union, Any 27 28TemporaryFile = tempfile.TemporaryFile 29PRECISION = 1e-5 30 31 32def get_reduction(m): 33 result = getattr(m, 'reduction', None) 34 if result is None: 35 result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False) 36 assert result is not None 37 return result 38 39 40def get_weight(m): 41 result = getattr(m, 'weight', None) 42 if result is not None: 43 return result 44 return getattr(m, 'weights', None) 45 46# NOTE [How to check NN module / functional API parity between Python and C++ frontends] 47# 48# The way to check API parity is to add parity tests for the NN module / functional of interest. 49# Here are the detailed steps: 50# 51# For NN module: 52# 1. Make sure you already have a test dict with the module configuration you want to test. 53# 2. Add `cpp_constructor_args` entry to the test dict, with its value exactly matching 54# the Python module constructor arguments. For example, if in the test dict we pass 55# `(10, 8)` to `torch.nn.Linear` constructor, then we should pass `torch::nn::LinearOptions(10, 8)` 56# as the corresponding C++ constructor argument to `torch::nn::Linear`. 57# 3. If in the process of performing the above step you referenced any variables 58# in the `cpp_constructor_args` entry, you must add `cpp_var_map` entry 59# to the test dict to make sure that those variables are populated with the right Python values. 60# For example, if the Python constructor call is 61# `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`, 62# the corresponding C++ constructor argument is 63# `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`, 64# and the `cpp_var_map` entry must be 65# `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples` 66# used in the C++ constructor argument with the Python tensor value `random_samples`. 67# 68# For NN functional: 69# 1. Make sure you already have a test dict with the functional configuration you want to test. 70# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`, 71# then you must add `cpp_options_args` entry to the test dict, with its value exactly matching the Python 72# functional optional arguments. For example, if the test dict's `constructor` entry is 73# `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`, 74# then the `cpp_options_args` entry should be 75# "F::InterpolateFuncOptions().size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)". 76# 3. Otherwise, if the test dict's `constructor` entry looks like 77# `wrap_functional(lambda i: F.some_functional_name(...))`, 78# then you must add `cpp_function_call` entry to the test dict, with its value exactly matching the Python 79# functional function call. For example, if the test dict's `constructor` entry is 80# `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`, 81# then the `cpp_function_call` entry should be 82# "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))". 83# 4. If in the process of performing the above two steps you referenced any variables 84# in the `cpp_options_args` or `cpp_function_call` entry, you must 85# add `cpp_var_map` entry to the test dict to make sure that those variables 86# are populated with the right Python values. For example, if the test dict's `constructor` entry is 87# `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`, 88# then the `cpp_function_call` entry should be 89# "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))". 90# Notice that there are two variables `i` and `t` that need to have their values provided, 91# and the way to do so is to add a `cpp_var_map` entry: `cpp_var_map={'i': '_get_input()', 't': t}`. 92# (Note that for `i`, since we want it to take the Python input value, we pass '_get_input()' string as value 93# and the C++ parity test mechanism will populate `i` with the Python input value correctly.) 94# 95# There are also a few optional flags in the test dict to control the C++ parity test behavior: 96# 97# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True. 98# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True. 99 100 101module_tests = [ 102 dict( 103 module_name='Linear', 104 constructor_args=(10, 8), 105 cpp_constructor_args='torch::nn::LinearOptions(10, 8)', 106 input_size=(4, 10), 107 reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8), 108 with_tf32=True, 109 tf32_precision=0.005, 110 default_dtype=torch.double, 111 ), 112 dict( 113 module_name='Linear', 114 constructor_args=(10, 8, False), 115 cpp_constructor_args='torch::nn::LinearOptions(10, 8).bias(false)', 116 input_size=(4, 10), 117 desc='no_bias', 118 reference_fn=lambda i, p, _: torch.mm(i, p[0].t()), 119 with_tf32=True, 120 tf32_precision=0.005, 121 default_dtype=torch.double, 122 ), 123 dict( 124 module_name='RReLU', 125 input_size=(1, 2, 2), 126 test_cuda=False, 127 default_dtype=torch.double, 128 ), 129 dict( 130 module_name='RReLU', 131 constructor_args=(0.1, 0.9), 132 cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)', 133 input_size=(4, 4, 5), 134 desc='with_up_down', 135 test_cuda=False, 136 default_dtype=torch.double, 137 ), 138 dict( 139 module_name='Flatten', 140 input_size=(2, 3, 4, 5), 141 reference_fn=lambda i, *_: torch.flatten(i, 1), 142 default_dtype=torch.double, 143 ), 144 # TODO: reference function 145 dict( 146 module_name='CrossMapLRN2d', 147 constructor_args=(5, 5e-3, 1e-3, 2), 148 cpp_constructor_args='torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)', 149 input_size=(2, 3, 6, 6), 150 check_gradgrad=False, 151 # TODO(#50743): Figure out the error. "RuntimeError: Unrecognized tensor type ID: Batched" 152 check_batched_grad=False, 153 default_dtype=torch.double, 154 ), 155] 156 157 158# Generates rand tensor with non-equal values. This ensures that duplicate 159# values won't be causing test failure for modules like MaxPooling. 160# size should be small, otherwise randperm fails / long overflows. 161def _rand_tensor_non_equal(*size): 162 total = reduce(mul, size, 1) 163 return torch.randperm(total).view(*size).double() 164 165 166def wrap_functional(fn, **kwargs): 167 class FunctionalModule(nn.Module): 168 def forward(self, *args): 169 return fn(*args, **kwargs) 170 return FunctionalModule 171 172 173def poissonnllloss_no_reduce_test(): 174 t = torch.randn(10, 10) 175 return dict( 176 fullname='PoissonNLLLoss_no_reduce', 177 constructor=wrap_functional( 178 lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')), 179 cpp_function_call='F::poisson_nll_loss(' 180 'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))', 181 input_fn=lambda: torch.rand(10, 10), 182 cpp_var_map={'i': '_get_input()', 't': t}, 183 reference_fn=lambda i, *_: i.exp() - t.mul(i), 184 pickle=False, 185 default_dtype=torch.double) 186 187 188def bceloss_no_reduce_test(): 189 t = Variable(torch.randn(15, 10).gt(0).to(torch.double)) 190 return dict( 191 fullname='BCELoss_no_reduce', 192 constructor=wrap_functional( 193 lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')), 194 cpp_function_call='F::binary_cross_entropy(' 195 'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))', 196 input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), 197 cpp_var_map={'i': '_get_input()', 't': t}, 198 reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()), 199 pickle=False, 200 precision=7e-4, 201 default_dtype=torch.double) 202 203 204def bceloss_no_reduce_scalar_test(): 205 t = torch.randn(()).gt(0).to(torch.double) 206 return dict( 207 fullname='BCELoss_no_reduce_scalar', 208 constructor=wrap_functional( 209 lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')), 210 cpp_function_call='F::binary_cross_entropy(' 211 'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))', 212 input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), 213 cpp_var_map={'i': '_get_input()', 't': t}, 214 reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()), 215 pickle=False, 216 default_dtype=torch.double) 217 218 219def bceloss_weights_no_reduce_test(): 220 t = Variable(torch.randn(15, 10, dtype=torch.double).gt(0).to(torch.double)) 221 weights = torch.rand(10, dtype=torch.double) 222 return dict( 223 fullname='BCELoss_weights_no_reduce', 224 constructor=wrap_functional( 225 lambda i: F.binary_cross_entropy(i, t.type_as(i), 226 weight=weights.type_as(i), reduction='none')), 227 cpp_function_call='F::binary_cross_entropy(' 228 'i, t.to(i.options()), ' 229 'F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))', 230 input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), 231 cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, 232 reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights, 233 pickle=False, 234 precision=3e-4, 235 default_dtype=torch.double, 236 ) 237 238 239def bceloss_weights_no_reduce_scalar_test(): 240 t = torch.randn(()).gt(0).to(torch.double) 241 weights = torch.rand((), dtype=torch.double) 242 return dict( 243 fullname='BCELoss_weights_no_reduce_scalar', 244 constructor=wrap_functional( 245 lambda i: F.binary_cross_entropy(i, t.type_as(i), 246 weight=weights.type_as(i), reduction='none')), 247 cpp_function_call='''F::binary_cross_entropy( 248 i, t.to(i.options()), 249 F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''', 250 cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, 251 input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), 252 reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights, 253 pickle=False, 254 default_dtype=torch.double, 255 ) 256 257 258def bce_with_logistic_legacy_enum_test(): 259 t = Variable(torch.randn(15, 10).gt(0).to(torch.double)) 260 sigmoid = nn.Sigmoid() 261 return dict( 262 fullname='BCEWithLogitsLoss_legacy_enum', 263 constructor=wrap_functional( 264 lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)), 265 cpp_function_call='''F::binary_cross_entropy_with_logits( 266 i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''', 267 input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), 268 cpp_var_map={'i': '_get_input()', 't': t}, 269 reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()), 270 check_gradgrad=False, 271 pickle=False, 272 default_dtype=torch.double, 273 ) 274 275 276def bce_with_logistic_no_reduce_test(): 277 t = Variable(torch.randn(15, 10).gt(0).to(torch.double)) 278 sigmoid = nn.Sigmoid() 279 return dict( 280 fullname='BCEWithLogitsLoss_no_reduce', 281 constructor=wrap_functional( 282 lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')), 283 cpp_function_call='''F::binary_cross_entropy_with_logits( 284 i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''', 285 input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), 286 cpp_var_map={'i': '_get_input()', 't': t}, 287 reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()), 288 check_gradgrad=False, 289 pickle=False, 290 default_dtype=torch.double, 291 ) 292 293 294def bce_with_logistic_no_reduce_scalar_test(): 295 t = torch.randn(()).gt(0).to(torch.double) 296 sigmoid = nn.Sigmoid() 297 return dict( 298 fullname='BCEWithLogitsLoss_no_reduce_scalar', 299 constructor=wrap_functional( 300 lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')), 301 cpp_function_call='''F::binary_cross_entropy_with_logits( 302 i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''', 303 input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), 304 cpp_var_map={'i': '_get_input()', 't': t}, 305 reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()), 306 check_gradgrad=False, 307 pickle=False, 308 default_dtype=torch.double, 309 ) 310 311 312def kldivloss_with_target_no_reduce_test(): 313 t = torch.rand(10, 10, dtype=torch.double) 314 return dict( 315 fullname='KLDivLoss_with_target_no_reduce', 316 constructor=wrap_functional( 317 lambda i: F.kl_div(i, t.type_as(i), reduction='none')), 318 cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))', 319 input_fn=lambda: torch.rand(10, 10).log(), 320 cpp_var_map={'i': '_get_input()', 't': t}, 321 reference_fn=lambda i, *_: 322 loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'), 323 supports_forward_ad=True, 324 pickle=False, 325 default_dtype=torch.double) 326 327 328def kldivloss_no_reduce_test(): 329 t = torch.rand(10, 10, dtype=torch.double) 330 return dict( 331 fullname='KLDivLoss_no_reduce', 332 constructor=wrap_functional( 333 lambda i: F.kl_div(i, t.type_as(i), reduction='none')), 334 cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))', 335 input_fn=lambda: torch.rand(10, 10).log(), 336 cpp_var_map={'i': '_get_input()', 't': t}, 337 reference_fn=lambda i, *_: 338 loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'), 339 supports_forward_ad=True, 340 pickle=False, 341 default_dtype=torch.double, 342 ) 343 344 345def kldivloss_no_reduce_scalar_test(): 346 t = torch.rand((), dtype=torch.double) 347 return dict( 348 fullname='KLDivLoss_no_reduce_scalar', 349 constructor=wrap_functional( 350 lambda i: F.kl_div(i, t.type_as(i), reduction='none')), 351 cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))', 352 input_fn=lambda: torch.rand(()).log(), 353 cpp_var_map={'i': '_get_input()', 't': t}, 354 reference_fn=lambda i, *_: 355 loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'), 356 supports_forward_ad=True, 357 pickle=False, 358 default_dtype=torch.double) 359 360 361def kldivloss_with_log_target_no_reduce_test(): 362 t = torch.rand(10, 10, dtype=torch.double).log() 363 return dict( 364 fullname='KLDivLoss_with_log_target_no_reduce', 365 constructor=wrap_functional( 366 lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)), 367 cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))', 368 input_fn=lambda: torch.rand(10, 10).log(), 369 cpp_var_map={'i': '_get_input()', 't': t}, 370 reference_fn=lambda i, *_: 371 loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'), 372 supports_forward_ad=True, 373 pickle=False, 374 default_dtype=torch.double) 375 376 377def kldivloss_no_reduce_log_target_test(): 378 t = torch.rand(10, 10, dtype=torch.double).log() 379 return dict( 380 fullname='KLDivLoss_no_reduce_log_target', 381 constructor=wrap_functional( 382 lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)), 383 cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))', 384 input_fn=lambda: torch.rand(10, 10).log(), 385 cpp_var_map={'i': '_get_input()', 't': t}, 386 reference_fn=lambda i, *_: 387 loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'), 388 supports_forward_ad=True, 389 pickle=False, 390 default_dtype=torch.double, 391 ) 392 393 394def kldivloss_no_reduce_scalar_log_target_test(): 395 t = torch.rand((), dtype=torch.double).log() 396 return dict( 397 fullname='KLDivLoss_no_reduce_scalar_log_target', 398 constructor=wrap_functional( 399 lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)), 400 cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))', 401 input_fn=lambda: torch.rand(()).log(), 402 cpp_var_map={'i': '_get_input()', 't': t}, 403 reference_fn=lambda i, *_: 404 loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'), 405 supports_forward_ad=True, 406 pickle=False, 407 default_dtype=torch.double) 408 409 410def l1loss_no_reduce_test(): 411 t = torch.randn(2, 3, 4, dtype=torch.double) 412 return dict( 413 fullname='L1Loss_no_reduce', 414 constructor=wrap_functional( 415 lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), 416 cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))', 417 input_fn=lambda: torch.randn(2, 3, 4), 418 cpp_var_map={'i': '_get_input()', 't': t}, 419 reference_fn=lambda i, *_: (i - t.type_as(i)).abs(), 420 supports_forward_ad=True, 421 pickle=False, 422 default_dtype=torch.double) 423 424 425def l1loss_no_reduce_complex_test(): 426 t = torch.randn(2, 3, 4, dtype=torch.cdouble) 427 return dict( 428 fullname='L1Loss_no_reduce_complex', 429 constructor=wrap_functional( 430 lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), 431 cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))', 432 input_fn=lambda: torch.randn(2, 3, 4, dtype=torch.cdouble), 433 cpp_var_map={'i': '_get_input()', 't': t}, 434 reference_fn=lambda i, *_: (i - t.type_as(i)).abs(), 435 supports_forward_ad=True, 436 pickle=False) 437 438 439def l1loss_no_reduce_scalar_test(): 440 t = torch.randn((), dtype=torch.double) 441 return dict( 442 fullname='L1Loss_no_reduce_scalar', 443 constructor=wrap_functional( 444 lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), 445 cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))', 446 input_fn=lambda: torch.randn(()), 447 cpp_var_map={'i': '_get_input()', 't': t}, 448 reference_fn=lambda i, *_: (i - t.type_as(i)).abs(), 449 supports_forward_ad=True, 450 pickle=False, 451 default_dtype=torch.double) 452 453 454def mseloss_no_reduce_test(): 455 input_size = (2, 3, 4, 5) 456 target = torch.randn(*input_size, dtype=torch.double) 457 return dict( 458 fullname='MSELoss_no_reduce', 459 constructor=wrap_functional( 460 lambda i: F.mse_loss(i, target.type_as(i), reduction='none')), 461 cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))', 462 input_size=input_size, 463 cpp_var_map={'i': '_get_input()', 'target': target}, 464 reference_fn=lambda i, *_: (i - target).pow(2), 465 supports_forward_ad=True, 466 pickle=False, 467 default_dtype=torch.double) 468 469 470def mseloss_no_reduce_scalar_test(): 471 input_size = () 472 target = torch.randn(input_size, dtype=torch.double) 473 return dict( 474 fullname='MSELoss_no_reduce_scalar', 475 constructor=wrap_functional( 476 lambda i: F.mse_loss(i, target.type_as(i), reduction='none')), 477 cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))', 478 input_size=input_size, 479 cpp_var_map={'i': '_get_input()', 'target': target}, 480 reference_fn=lambda i, *_: (i - target).pow(2), 481 supports_forward_ad=True, 482 pickle=False, 483 default_dtype=torch.double) 484 485 486def nllloss_no_reduce_test(): 487 t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) 488 kwargs = {'reduction': 'none'} 489 return dict( 490 fullname='NLLLoss_no_reduce', 491 constructor=wrap_functional( 492 lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), 493 cpp_function_call='''F::nll_loss( 494 i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', 495 input_fn=lambda: torch.rand(15, 10).log(), 496 cpp_var_map={'i': '_get_input()', 't': t}, 497 reference_fn=lambda i, *_: 498 loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs), 499 pickle=False, 500 default_dtype=torch.double) 501 502 503def nllloss_no_reduce_ignore_index_test(): 504 t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) 505 kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'} 506 return dict( 507 fullname='NLLLoss_no_reduce_ignore_index', 508 constructor=wrap_functional( 509 lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), 510 reduction=str(kwargs['reduction']))), 511 cpp_function_call='''F::nll_loss( 512 i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''', 513 input_fn=lambda: torch.rand(15, 10).log(), 514 cpp_var_map={'i': '_get_input()', 't': t}, 515 reference_fn=lambda i, *_: 516 loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs), 517 pickle=False, 518 default_dtype=torch.double) 519 520 521def nllloss_no_reduce_weights_test(): 522 t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) 523 weight = torch.rand(10) 524 525 def kwargs(i): 526 return {'weight': weight.type_as(i), 'reduction': 'none'} 527 528 return dict( 529 fullname='NLLLoss_no_reduce_weights', 530 constructor=wrap_functional( 531 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), 532 cpp_function_call='''F::nll_loss( 533 i, t.to(i.options()).to(torch::kLong), 534 F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''', 535 input_fn=lambda: torch.rand(15, 10).add(1e-2).log(), 536 cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, 537 reference_fn=lambda i, *_: 538 loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)), 539 pickle=False, 540 default_dtype=torch.double) 541 542 543def nllloss_no_reduce_weights_ignore_index_test(): 544 t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) 545 weight = torch.rand(10) 546 547 def kwargs(i): 548 return {'weight': weight.type_as(i), 'reduction': 'none', 549 'ignore_index': 2} 550 551 return dict( 552 fullname='NLLLoss_no_reduce_weights_ignore_index', 553 constructor=wrap_functional( 554 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))), 555 cpp_function_call='''F::nll_loss( 556 i, t.to(i.options()).to(torch::kLong), 557 F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))''', 558 input_fn=lambda: torch.rand(15, 10).add(1e-2).log(), 559 cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, 560 reference_fn=lambda i, *_: 561 loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)), 562 pickle=False, 563 default_dtype=torch.double) 564 565 566def nllloss_no_reduce_weights_ignore_index_neg_test(): 567 t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) 568 weight = torch.rand(10) 569 570 def kwargs(i): 571 return {'weight': weight.type_as(i), 'reduction': 'none', 572 'ignore_index': -1} 573 574 return dict( 575 fullname='NLLLoss_no_reduce_weights_ignore_index_neg', 576 constructor=wrap_functional( 577 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), 578 cpp_function_call='''F::nll_loss( 579 i, t.to(i.options()).to(torch::kLong), 580 F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))''', 581 input=torch.rand(15, 10, dtype=torch.double).add(1e-2).log(), 582 cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, 583 reference_fn=lambda i, *_: 584 loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)), 585 pickle=False, 586 default_dtype=torch.double) 587 588 589def nllloss2d_no_reduce_test(): 590 t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) 591 kwargs = {'reduction': 'none'} 592 return dict( 593 fullname='NLLLoss2d_no_reduce', 594 constructor=wrap_functional( 595 lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), 596 cpp_function_call='''F::nll_loss( 597 i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', 598 input_fn=lambda: torch.rand(2, 3, 5, 5).log(), 599 cpp_var_map={'i': '_get_input()', 't': t}, 600 reference_fn=lambda i, *_: 601 loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), 602 pickle=False, 603 default_dtype=torch.double) 604 605 606def nllloss2d_no_reduce_ignore_index_test(): 607 t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) 608 kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'} 609 return dict( 610 fullname='NLLLoss2d_no_reduce_ignore_index', 611 constructor=wrap_functional( 612 lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), 613 reduction=str(kwargs['reduction']))), 614 cpp_function_call='''F::nll_loss( 615 i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''', 616 input_fn=lambda: torch.rand(2, 3, 5, 5).log(), 617 cpp_var_map={'i': '_get_input()', 't': t}, 618 reference_fn=lambda i, *_: 619 loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), 620 pickle=False, 621 default_dtype=torch.double) 622 623 624def nllloss2d_no_reduce_weights_test(): 625 t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) 626 weight = torch.rand(3) 627 628 def kwargs(i): 629 return {'weight': weight.type_as(i), 'reduction': 'none'} 630 631 return dict( 632 fullname='NLLLoss2d_no_reduce_weights', 633 constructor=wrap_functional( 634 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), 635 cpp_function_call='''F::nll_loss( 636 i, t.to(i.options()).to(torch::kLong), 637 F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''', 638 input_fn=lambda: torch.rand(2, 3, 5, 5).log(), 639 cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, 640 reference_fn=lambda i, *_: 641 loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)), 642 pickle=False, 643 default_dtype=torch.double) 644 645 646def nlllossNd_no_reduce_test(): 647 t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) 648 kwargs = {'reduction': 'none'} 649 return dict( 650 fullname='NLLLossNd_no_reduce', 651 constructor=wrap_functional( 652 lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), 653 cpp_function_call='''F::nll_loss( 654 i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', 655 input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(), 656 cpp_var_map={'i': '_get_input()', 't': t}, 657 reference_fn=lambda i, *_: 658 loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), 659 pickle=False, 660 default_dtype=torch.double) 661 662 663def nlllossNd_no_reduce_ignore_index_test(): 664 t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) 665 kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'} 666 return dict( 667 fullname='NLLLossNd_no_reduce_ignore_index', 668 constructor=wrap_functional( 669 lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), 670 reduction=str(kwargs['reduction']))), 671 cpp_function_call='''F::nll_loss( 672 i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''', 673 input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(), 674 cpp_var_map={'i': '_get_input()', 't': t}, 675 reference_fn=lambda i, *_: 676 loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), 677 pickle=False, 678 default_dtype=torch.double) 679 680 681def nlllossNd_no_reduce_weights_test(): 682 t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) 683 weight = torch.rand(3) 684 685 def kwargs(i): 686 return {'weight': weight.type_as(i), 'reduction': 'none'} 687 688 return dict( 689 fullname='NLLLossNd_no_reduce_weights', 690 constructor=wrap_functional( 691 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), 692 cpp_function_call='''F::nll_loss( 693 i, t.to(i.options()).to(torch::kLong), 694 F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''', 695 input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(), 696 cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, 697 reference_fn=lambda i, *_: 698 loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)), 699 pickle=False, 700 default_dtype=torch.double) 701 702 703def smoothl1loss_no_reduce_test(): 704 t = torch.randn(2, 3, 4, dtype=torch.double) 705 return dict( 706 fullname='SmoothL1Loss_no_reduce', 707 constructor=wrap_functional( 708 lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')), 709 cpp_function_call='''F::smooth_l1_loss( 710 i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''', 711 input_fn=lambda: torch.randn(2, 3, 4), 712 cpp_var_map={'i': '_get_input()', 't': t}, 713 reference_fn=lambda i, *_: 714 loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'), 715 supports_forward_ad=True, 716 pickle=False, 717 default_dtype=torch.double) 718 719 720def smoothl1loss_no_reduce_scalar_test(): 721 t = torch.randn((), dtype=torch.double) 722 return dict( 723 fullname='SmoothL1Loss_no_reduce_scalar', 724 constructor=wrap_functional( 725 lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')), 726 cpp_function_call='''F::smooth_l1_loss( 727 i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''', 728 input_fn=lambda: torch.randn(()), 729 cpp_var_map={'i': '_get_input()', 't': t}, 730 reference_fn=lambda i, *_: 731 loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'), 732 supports_forward_ad=True, 733 pickle=False, 734 default_dtype=torch.double) 735 736 737def smoothl1loss_beta_test(): 738 t = torch.randn(2, 3, 4, dtype=torch.double) 739 return dict( 740 fullname='SmoothL1Loss_beta', 741 constructor=wrap_functional( 742 lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0.5)), 743 cpp_function_call='''F::smooth_l1_loss( 744 i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0.5)''', 745 input_fn=lambda: torch.randn(2, 3, 4), 746 cpp_var_map={'i': '_get_input()', 't': t}, 747 reference_fn=lambda i, *_: 748 loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0.5), 749 supports_forward_ad=True, 750 pickle=False, 751 default_dtype=torch.double) 752 753 754def smoothl1loss_zero_beta_test(): 755 t = torch.randn(2, 3, 4, dtype=torch.double) 756 return dict( 757 fullname='SmoothL1Loss_zero_beta', 758 constructor=wrap_functional( 759 lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0)), 760 cpp_function_call='''F::smooth_l1_loss( 761 i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)''', 762 input_fn=lambda: torch.randn(2, 3, 4), 763 cpp_var_map={'i': '_get_input()', 't': t}, 764 reference_fn=lambda i, *_: 765 loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0), 766 supports_forward_ad=True, 767 pickle=False, 768 default_dtype=torch.double) 769 770 771def huberloss_delta_test(): 772 t = torch.randn(2, 3, 4) 773 return dict( 774 fullname='HuberLoss_delta', 775 constructor=wrap_functional( 776 lambda i: F.huber_loss(i, t.type_as(i), reduction='none', delta=0.5)), 777 cpp_function_call='''F::huber_loss( 778 i, t.to(i.options()), F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5))''', 779 input_fn=lambda: torch.randn(2, 3, 4), 780 cpp_var_map={'i': '_get_input()', 't': t}, 781 reference_fn=lambda i, *_: 782 loss_reference_fns['HuberLoss'](i, t.type_as(i), reduction='none', delta=0.5), 783 supports_forward_ad=True, 784 pickle=False, 785 default_dtype=torch.double) 786 787 788def multilabelmarginloss_0d_no_reduce_test(): 789 t = torch.zeros(()).long() 790 return dict( 791 fullname='MultiLabelMarginLoss_0d_no_reduce', 792 constructor=wrap_functional( 793 lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), 794 cpp_function_call='''F::multilabel_margin_loss( 795 i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', 796 input_fn=lambda: torch.randn(()), 797 cpp_var_map={'i': '_get_input()', 't': t}, 798 reference_fn=lambda i, *_: 799 loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), 800 check_sum_reduction=True, 801 check_gradgrad=False, 802 pickle=False) 803 804 805def multilabelmarginloss_1d_no_reduce_test(): 806 t = Variable(torch.rand(10).mul(10).floor().long()) 807 return dict( 808 fullname='MultiLabelMarginLoss_1d_no_reduce', 809 constructor=wrap_functional( 810 lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), 811 cpp_function_call='''F::multilabel_margin_loss( 812 i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', 813 input_fn=lambda: torch.randn(10), 814 cpp_var_map={'i': '_get_input()', 't': t}, 815 reference_fn=lambda i, *_: 816 loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), 817 check_sum_reduction=True, 818 check_gradgrad=False, 819 pickle=False, 820 default_dtype=torch.double) 821 822 823def multilabelmarginloss_index_neg_test(): 824 t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1)) 825 return dict( 826 fullname='MultiLabelMarginLoss_index_neg', 827 constructor=wrap_functional( 828 lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), 829 cpp_function_call='''F::multilabel_margin_loss( 830 i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', 831 input_fn=lambda: torch.randn(5, 10), 832 cpp_var_map={'i': '_get_input()', 't': t}, 833 reference_fn=lambda i, *_: 834 loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), 835 check_sum_reduction=True, 836 check_gradgrad=False, 837 pickle=False, 838 default_dtype=torch.double) 839 840 841def multilabelmarginloss_no_reduce_test(): 842 t = Variable(torch.rand(5, 10).mul(10).floor().long()) 843 return dict( 844 fullname='MultiLabelMarginLoss_no_reduce', 845 constructor=wrap_functional( 846 lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), 847 cpp_function_call='''F::multilabel_margin_loss( 848 i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', 849 input_fn=lambda: torch.randn(5, 10), 850 cpp_var_map={'i': '_get_input()', 't': t}, 851 reference_fn=lambda i, *_: 852 loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), 853 check_sum_reduction=True, 854 check_gradgrad=False, 855 pickle=False, 856 default_dtype=torch.double) 857 858 859def hingeembeddingloss_no_reduce_test(): 860 t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1)) 861 return dict( 862 fullname='HingeEmbeddingLoss_no_reduce', 863 constructor=wrap_functional( 864 lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')), 865 cpp_function_call='''F::hinge_embedding_loss( 866 i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))''', 867 input_fn=lambda: torch.randn(10), 868 cpp_var_map={'i': '_get_input()', 't': t}, 869 reference_fn=lambda i, *_: 870 loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'), 871 check_sum_reduction=True, 872 pickle=False, 873 default_dtype=torch.double) 874 875 876def hingeembeddingloss_margin_no_reduce_test(): 877 t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1)) 878 return dict( 879 fullname='HingeEmbeddingLoss_margin_no_reduce', 880 constructor=wrap_functional( 881 lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')), 882 cpp_function_call='''F::hinge_embedding_loss( 883 i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))''', 884 input_fn=lambda: torch.randn(10), 885 cpp_var_map={'i': '_get_input()', 't': t}, 886 reference_fn=lambda i, *_: 887 loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'), 888 check_sum_reduction=True, 889 pickle=False, 890 default_dtype=torch.double) 891 892 893def softmarginloss_no_reduce_test(): 894 t = torch.randn(5, 5, dtype=torch.double) 895 return dict( 896 fullname='SoftMarginLoss_no_reduce', 897 constructor=wrap_functional( 898 lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')), 899 cpp_function_call='''F::soft_margin_loss( 900 i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))''', 901 input_fn=lambda: torch.randn(5, 5), 902 cpp_var_map={'i': '_get_input()', 't': t}, 903 reference_fn=lambda i, *_: 904 loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'), 905 supports_forward_ad=True, 906 pickle=False, 907 default_dtype=torch.double) 908 909 910def multilabelsoftmarginloss_no_reduce_test(): 911 t = torch.rand(5, 10).mul(2).floor() 912 return dict( 913 fullname='MultiLabelSoftMarginLoss_no_reduce', 914 constructor=wrap_functional( 915 lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')), 916 cpp_function_call='''F::multilabel_soft_margin_loss( 917 i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))''', 918 input_fn=lambda: torch.randn(5, 10), 919 cpp_var_map={'i': '_get_input()', 't': t}, 920 reference_fn=lambda i, *_: 921 (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1), 922 check_gradgrad=False, 923 pickle=False, 924 default_dtype=torch.double) 925 926 927def multilabelsoftmarginloss_weights_no_reduce_test(): 928 t = torch.rand(5, 10).mul(2).floor() 929 weights = torch.rand(10) 930 return dict( 931 fullname='MultiLabelSoftMarginLoss_weights_no_reduce', 932 constructor=wrap_functional( 933 lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), 934 weight=weights.type_as(i), reduction='none')), 935 cpp_function_call='''F::multilabel_soft_margin_loss( 936 i, t.to(i.options()), 937 F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''', 938 input_fn=lambda: torch.randn(5, 10), 939 cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, 940 reference_fn=lambda i, *_: 941 (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1), 942 check_sum_reduction=True, 943 check_gradgrad=False, 944 pickle=False, 945 default_dtype=torch.double) 946 947 948def multimarginloss_no_reduce_test(): 949 t = torch.rand(5).mul(8).floor().long() 950 return dict( 951 fullname='MultiMarginLoss_no_reduce', 952 constructor=wrap_functional( 953 lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')), 954 cpp_function_call='''F::multi_margin_loss( 955 i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''', 956 input_fn=lambda: torch.randn(5, 10), 957 cpp_var_map={'i': '_get_input()', 't': t}, 958 reference_fn=lambda i, *_: 959 loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), 960 check_sum_reduction=True, 961 check_gradgrad=False, 962 pickle=False, 963 default_dtype=torch.double) 964 965 966def multimarginloss_1d_no_reduce_test(): 967 t = torch.rand(1).mul(8).floor().long() 968 return dict( 969 fullname='MultiMarginLoss_1d_no_reduce', 970 constructor=wrap_functional( 971 lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')), 972 cpp_function_call='''F::multi_margin_loss( 973 i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''', 974 input_fn=lambda: torch.randn(10), 975 cpp_var_map={'i': '_get_input()', 't': t}, 976 reference_fn=lambda i, *_: 977 loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), 978 check_sum_reduction=True, 979 check_gradgrad=False, 980 pickle=False, 981 default_dtype=torch.double) 982 983 984def multimarginloss_1d_input_0d_target_no_reduce_test(): 985 t = torch.rand(()).mul(8).floor().long() 986 return dict( 987 fullname='multimarginloss_1d_input_0d_target_no_reduce', 988 constructor=wrap_functional( 989 lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')), 990 cpp_function_call='''F::multi_margin_loss( 991 i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''', 992 input_fn=lambda: torch.randn(10), 993 cpp_var_map={'i': '_get_input()', 't': t}, 994 reference_fn=lambda i, *_: 995 loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), 996 check_sum_reduction=True, 997 check_gradgrad=False, 998 pickle=False, 999 default_dtype=torch.double) 1000 1001 1002def multimarginloss_p_no_reduce_test(): 1003 t = torch.rand(5).mul(8).floor().long() 1004 return dict( 1005 fullname='MultiMarginLoss_p_no_reduce', 1006 constructor=wrap_functional( 1007 lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')), 1008 cpp_function_call='''F::multi_margin_loss( 1009 i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))''', 1010 input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2), 1011 cpp_var_map={'i': '_get_input()', 't': t}, 1012 reference_fn=lambda i, *_: 1013 loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'), 1014 check_sum_reduction=True, 1015 check_gradgrad=False, 1016 pickle=False, 1017 default_dtype=torch.double) 1018 1019 1020def multimarginloss_margin_no_reduce_test(): 1021 t = torch.rand(5).mul(8).floor().long() 1022 return dict( 1023 fullname='MultiMarginLoss_margin_no_reduce', 1024 constructor=wrap_functional( 1025 lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')), 1026 cpp_function_call='''F::multi_margin_loss( 1027 i, t.to(i.options()).to(torch::kLong), 1028 F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))''', 1029 input_fn=lambda: torch.randn(5, 10), 1030 cpp_var_map={'i': '_get_input()', 't': t}, 1031 reference_fn=lambda i, *_: 1032 loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), 1033 margin=0.5, reduction='none'), 1034 check_sum_reduction=True, 1035 check_gradgrad=False, 1036 pickle=False, 1037 default_dtype=torch.double) 1038 1039 1040def multimarginloss_weights_no_reduce_test(): 1041 t = torch.rand(5).mul(8).floor().long() 1042 weights = torch.rand(10, dtype=torch.double) 1043 return dict( 1044 fullname='MultiMarginLoss_weights_no_reduce', 1045 constructor=wrap_functional( 1046 lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i), 1047 reduction='none')), 1048 cpp_function_call='''F::multi_margin_loss( 1049 i, t.to(i.options()).to(torch::kLong), 1050 F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''', 1051 input_fn=lambda: torch.randn(5, 10), 1052 cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, 1053 reference_fn=lambda i, *_: 1054 loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), 1055 weight=weights, reduction='none'), 1056 check_sum_reduction=True, 1057 check_gradgrad=False, 1058 pickle=False, 1059 default_dtype=torch.double) 1060 1061 1062def single_batch_reference_fn(input, parameters, module): 1063 """Reference function for modules supporting no batch dimensions. 1064 1065 The module is passed the input and target in batched form with a single item. 1066 The output is squeezed to compare with the no-batch input. 1067 """ 1068 def unsqueeze_inp(inp): 1069 if isinstance(inp, (list, tuple)): 1070 return [t.unsqueeze(0) for t in inp] 1071 return inp.unsqueeze(0) 1072 1073 single_batch_input = unsqueeze_inp(input) 1074 single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input 1075 with freeze_rng_state(): 1076 return module(*single_batch_input).squeeze(0) 1077 1078 1079new_module_tests = [ 1080 poissonnllloss_no_reduce_test(), 1081 bceloss_no_reduce_test(), 1082 bceloss_weights_no_reduce_test(), 1083 bce_with_logistic_legacy_enum_test(), 1084 bce_with_logistic_no_reduce_test(), 1085 bceloss_no_reduce_scalar_test(), 1086 bceloss_weights_no_reduce_scalar_test(), 1087 bce_with_logistic_no_reduce_scalar_test(), 1088 kldivloss_with_target_no_reduce_test(), 1089 kldivloss_no_reduce_test(), 1090 kldivloss_no_reduce_scalar_test(), 1091 kldivloss_with_log_target_no_reduce_test(), 1092 kldivloss_no_reduce_log_target_test(), 1093 kldivloss_no_reduce_scalar_log_target_test(), 1094 l1loss_no_reduce_test(), 1095 l1loss_no_reduce_complex_test(), 1096 l1loss_no_reduce_scalar_test(), 1097 mseloss_no_reduce_test(), 1098 mseloss_no_reduce_scalar_test(), 1099 nllloss_no_reduce_test(), 1100 nllloss_no_reduce_ignore_index_test(), 1101 nllloss_no_reduce_weights_test(), 1102 nllloss_no_reduce_weights_ignore_index_test(), 1103 nllloss_no_reduce_weights_ignore_index_neg_test(), 1104 nllloss2d_no_reduce_test(), 1105 nllloss2d_no_reduce_weights_test(), 1106 nllloss2d_no_reduce_ignore_index_test(), 1107 nlllossNd_no_reduce_test(), 1108 nlllossNd_no_reduce_weights_test(), 1109 nlllossNd_no_reduce_ignore_index_test(), 1110 smoothl1loss_no_reduce_test(), 1111 smoothl1loss_no_reduce_scalar_test(), 1112 smoothl1loss_beta_test(), 1113 smoothl1loss_zero_beta_test(), 1114 huberloss_delta_test(), 1115 multilabelmarginloss_0d_no_reduce_test(), 1116 multilabelmarginloss_1d_no_reduce_test(), 1117 multilabelmarginloss_index_neg_test(), 1118 multilabelmarginloss_no_reduce_test(), 1119 hingeembeddingloss_no_reduce_test(), 1120 hingeembeddingloss_margin_no_reduce_test(), 1121 softmarginloss_no_reduce_test(), 1122 multilabelsoftmarginloss_no_reduce_test(), 1123 multilabelsoftmarginloss_weights_no_reduce_test(), 1124 multimarginloss_no_reduce_test(), 1125 multimarginloss_1d_no_reduce_test(), 1126 multimarginloss_1d_input_0d_target_no_reduce_test(), 1127 multimarginloss_p_no_reduce_test(), 1128 multimarginloss_margin_no_reduce_test(), 1129 multimarginloss_weights_no_reduce_test(), 1130 dict( 1131 module_name='Conv1d', 1132 constructor_args=(4, 5, 3), 1133 cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)', 1134 input_size=(2, 4, 10), 1135 cudnn=True, 1136 with_tf32=True, 1137 tf32_precision=0.005, 1138 default_dtype=torch.double, 1139 ), 1140 dict( 1141 module_name='Conv1d', 1142 constructor_args=(4, 5, 3, 2), 1143 cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)', 1144 input_size=(2, 4, 10), 1145 cudnn=True, 1146 desc='stride', 1147 with_tf32=True, 1148 tf32_precision=0.005, 1149 default_dtype=torch.double, 1150 ), 1151 dict( 1152 module_name='Conv1d', 1153 constructor_args=(4, 5, 3, 1, 1), 1154 cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)', 1155 input_size=(2, 4, 10), 1156 cudnn=True, 1157 desc='pad1', 1158 with_tf32=True, 1159 tf32_precision=0.01, 1160 default_dtype=torch.double, 1161 ), 1162 dict( 1163 module_name='Conv1d', 1164 constructor_args=(4, 5, 5, 1, 2), 1165 cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)', 1166 input_size=(2, 4, 10), 1167 cudnn=True, 1168 desc='pad2', 1169 with_tf32=True, 1170 tf32_precision=0.005, 1171 default_dtype=torch.double, 1172 ), 1173 dict( 1174 module_name='Conv1d', 1175 constructor_args=(4, 4, 3, 1, 1), 1176 cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)', 1177 input_size=(1, 4, 1), 1178 cudnn=True, 1179 desc='pad1size1', 1180 with_tf32=True, 1181 tf32_precision=0.005, 1182 default_dtype=torch.double, 1183 ), 1184 dict( 1185 module_name='Conv1d', 1186 constructor_args=(4, 4, 5, 1, 2), 1187 cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)', 1188 input_size=(1, 4, 1), 1189 cudnn=True, 1190 desc='pad2size1', 1191 with_tf32=True, 1192 tf32_precision=0.005, 1193 default_dtype=torch.double, 1194 ), 1195 dict( 1196 module_name='Conv1d', 1197 constructor_args=(4, 5, 3), 1198 cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)', 1199 input_size=(0, 4, 10), 1200 cudnn=True, 1201 desc='zero_batch', 1202 with_tf32=True, 1203 tf32_precision=0.005, 1204 ), 1205 dict( 1206 fullname='Conv1d_dilated', 1207 constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2), 1208 cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)', 1209 input_size=(2, 4, 10), 1210 with_tf32=True, 1211 tf32_precision=0.005, 1212 default_dtype=torch.double, 1213 ), 1214 dict( 1215 fullname='Conv1d_groups', 1216 constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2), 1217 cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)', 1218 input_size=(2, 4, 6), 1219 cudnn=True, 1220 with_tf32=True, 1221 tf32_precision=0.005, 1222 default_dtype=torch.double, 1223 ), 1224 dict( 1225 fullname='Conv1d_pad_valid', 1226 constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"), 1227 cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)', 1228 input_size=(2, 4, 10), 1229 cudnn=True, 1230 with_tf32=True, 1231 tf32_precision=0.005, 1232 default_dtype=torch.double, 1233 ), 1234 dict( 1235 fullname='Conv1d_pad_same', 1236 constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"), 1237 cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)', 1238 input_size=(2, 4, 10), 1239 cudnn=True, 1240 with_tf32=True, 1241 tf32_precision=0.005, 1242 default_dtype=torch.double, 1243 ), 1244 dict( 1245 fullname='Conv1d_pad_same2', 1246 constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"), 1247 cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)', 1248 input_size=(2, 4, 10), 1249 cudnn=True, 1250 with_tf32=True, 1251 tf32_precision=0.005, 1252 default_dtype=torch.double, 1253 ), 1254 dict( 1255 fullname='Conv1d_pad_same_dilated', 1256 constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2), 1257 cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame).dilation(2)', 1258 input_size=(2, 4, 10), 1259 cudnn=True, 1260 with_tf32=True, 1261 tf32_precision=0.005, 1262 default_dtype=torch.double, 1263 ), 1264 dict( 1265 fullname='ConvTranspose1d', 1266 constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)), 1267 cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)', 1268 cudnn=True, 1269 input_size=(1, 3, 7), 1270 with_tf32=True, 1271 tf32_precision=0.005, 1272 default_dtype=torch.double, 1273 ), 1274 dict( 1275 module_name='ConvTranspose1d', 1276 constructor_args=(3, 4, 3, 2, 1, 1, 1, False), 1277 cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3) 1278 .stride(2).padding(1).output_padding(1).groups(1).bias(false)''', 1279 input_size=(1, 3, 6), 1280 cudnn=True, 1281 desc='no_bias', 1282 with_tf32=True, 1283 tf32_precision=0.005, 1284 default_dtype=torch.double, 1285 ), 1286 dict( 1287 module_name='ConvTranspose1d', 1288 constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2), 1289 cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3) 1290 .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''', 1291 input_size=(1, 3, 6), 1292 cudnn=True, 1293 desc='dilated', 1294 with_tf32=True, 1295 tf32_precision=0.005, 1296 default_dtype=torch.double, 1297 ), 1298 dict( 1299 fullname='ConvTranspose1d_groups', 1300 constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2), 1301 cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3) 1302 .stride(3).padding(1).output_padding(1).groups(2)''', 1303 cudnn=True, 1304 input_size=(2, 4, 7), 1305 with_tf32=True, 1306 tf32_precision=0.005, 1307 default_dtype=torch.double, 1308 ), 1309 dict( 1310 module_name='Conv2d', 1311 constructor_args=(3, 4, (3, 2)), 1312 cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})', 1313 input_size=(2, 3, 7, 5), 1314 cudnn=True, 1315 check_with_long_tensor=True, 1316 with_tf32=True, 1317 tf32_precision=0.005, 1318 default_dtype=torch.double, 1319 ), 1320 dict( 1321 module_name='Conv2d', 1322 constructor_args=(3, 4, (3, 3), (2, 2)), 1323 cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})', 1324 input_size=(2, 3, 6, 6), 1325 cudnn=True, 1326 desc='strided', 1327 check_with_long_tensor=True, 1328 with_tf32=True, 1329 tf32_precision=0.005, 1330 default_dtype=torch.double, 1331 ), 1332 dict( 1333 module_name='Conv2d', 1334 constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)), 1335 cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})', 1336 input_size=(2, 3, 6, 6), 1337 cudnn=True, 1338 desc='padding', 1339 check_with_long_tensor=True, 1340 with_tf32=True, 1341 tf32_precision=0.005, 1342 default_dtype=torch.double, 1343 ), 1344 dict( 1345 module_name='Conv2d', 1346 constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)), 1347 cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})', 1348 input_size=(2, 3, 8, 8), 1349 cudnn=True, 1350 desc='dilated', 1351 check_with_long_tensor=True, 1352 with_tf32=True, 1353 tf32_precision=0.005, 1354 default_dtype=torch.double, 1355 ), 1356 dict( 1357 module_name='Conv2d', 1358 constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False), 1359 cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2}) 1360 .stride(1).padding(0).dilation(1).groups(1).bias(false)''', 1361 input_size=(2, 3, 6, 5), 1362 cudnn=True, 1363 desc='no_bias', 1364 check_with_long_tensor=True, 1365 with_tf32=True, 1366 tf32_precision=0.015, 1367 default_dtype=torch.double, 1368 ), 1369 dict( 1370 module_name='Conv2d', 1371 constructor_args=(3, 4, (3, 2)), 1372 cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})', 1373 input_size=(0, 3, 7, 5), 1374 cudnn=True, 1375 desc='zero_batch', 1376 check_with_long_tensor=True, 1377 with_tf32=True, 1378 ), 1379 dict( 1380 fullname='Conv2d_groups', 1381 constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2), 1382 cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)', 1383 input_size=(2, 4, 6, 5), 1384 cudnn=True, 1385 check_with_long_tensor=True, 1386 with_tf32=True, 1387 tf32_precision=0.015, 1388 default_dtype=torch.double, 1389 ), 1390 dict( 1391 fullname='Conv2d_groups_thnn', 1392 constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2), 1393 cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)', 1394 input_size=(2, 4, 6, 5), 1395 check_with_long_tensor=True, 1396 with_tf32=True, 1397 tf32_precision=0.015, 1398 default_dtype=torch.double, 1399 ), 1400 dict( 1401 fullname='Conv2d_pad_valid', 1402 constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"), 1403 cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)', 1404 input_size=(2, 2, 6, 5), 1405 cudnn=True, 1406 with_tf32=True, 1407 tf32_precision=0.005, 1408 default_dtype=torch.double, 1409 ), 1410 dict( 1411 fullname='Conv2d_pad_same', 1412 constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"), 1413 cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)', 1414 input_size=(2, 2, 6, 5), 1415 cudnn=True, 1416 with_tf32=True, 1417 tf32_precision=0.01, 1418 default_dtype=torch.double, 1419 ), 1420 dict( 1421 fullname='Conv2d_pad_same_dilated', 1422 constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2), 1423 cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)', 1424 input_size=(2, 2, 6, 5), 1425 cudnn=True, 1426 with_tf32=True, 1427 tf32_precision=0.01, 1428 default_dtype=torch.double, 1429 ), 1430 dict( 1431 module_name='ConvTranspose2d', 1432 constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)), 1433 cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) 1434 .stride({3, 2}).padding(1).output_padding({1, 1})''', 1435 cudnn=True, 1436 input_size=(1, 3, 7, 6), 1437 check_with_long_tensor=True, 1438 with_tf32=True, 1439 tf32_precision=0.01, 1440 default_dtype=torch.double, 1441 ), 1442 dict( 1443 module_name='ConvTranspose2d', 1444 constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)), 1445 cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) 1446 .stride({2, 3}) 1447 .padding(1) 1448 .output_padding({1, 1}) 1449 .groups(1) 1450 .bias(false) 1451 .dilation({2, 2})''', 1452 input_size=(1, 3, 6, 7), 1453 cudnn=True, 1454 desc='dilated', 1455 check_with_long_tensor=True, 1456 with_tf32=True, 1457 tf32_precision=0.01, 1458 default_dtype=torch.double, 1459 ), 1460 dict( 1461 module_name='ConvTranspose2d', 1462 constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False), 1463 cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) 1464 .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''', 1465 input_size=(1, 3, 6, 7), 1466 cudnn=True, 1467 desc='no_bias', 1468 check_with_long_tensor=True, 1469 with_tf32=True, 1470 tf32_precision=0.01, 1471 default_dtype=torch.double, 1472 ), 1473 dict( 1474 fullname='ConvTranspose2d_groups', 1475 constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2), 1476 cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)', 1477 input_size=(1, 2, 4, 5), 1478 cudnn=True, 1479 check_with_long_tensor=True, 1480 with_tf32=True, 1481 tf32_precision=0.01, 1482 default_dtype=torch.double, 1483 ), 1484 dict( 1485 fullname='Conv2d_depthwise', 1486 constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4), 1487 cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)', 1488 input_size=(2, 4, 6, 6), 1489 with_tf32=True, 1490 tf32_precision=0.005, 1491 default_dtype=torch.double, 1492 ), 1493 dict( 1494 fullname='Conv2d_depthwise_with_multiplier', 1495 constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4), 1496 cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)', 1497 input_size=(2, 4, 6, 6), 1498 with_tf32=True, 1499 tf32_precision=0.005, 1500 default_dtype=torch.double, 1501 ), 1502 dict( 1503 fullname='Conv2d_depthwise_strided', 1504 constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4), 1505 cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)', 1506 input_size=(2, 4, 6, 6), 1507 with_tf32=True, 1508 tf32_precision=0.005, 1509 default_dtype=torch.double, 1510 ), 1511 dict( 1512 fullname='Conv2d_depthwise_padded', 1513 constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4), 1514 cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)', 1515 input_size=(2, 4, 6, 6), 1516 with_tf32=True, 1517 tf32_precision=0.005, 1518 default_dtype=torch.double, 1519 ), 1520 dict( 1521 fullname='Conv2d_depthwise_dilated', 1522 constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4), 1523 cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)', 1524 input_size=(2, 4, 5, 5), 1525 with_tf32=True, 1526 tf32_precision=0.005, 1527 default_dtype=torch.double, 1528 ), 1529 dict( 1530 module_name='Conv3d', 1531 constructor_args=(2, 3, (2, 3, 2)), 1532 cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})', 1533 input_size=(1, 2, 4, 5, 4), 1534 cudnn=True, 1535 check_with_long_tensor=True, 1536 with_tf32=True, 1537 tf32_precision=0.05, 1538 default_dtype=torch.double, 1539 ), 1540 dict( 1541 module_name='Conv3d', 1542 constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False), 1543 cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4}) 1544 .stride(1).padding(0).dilation(1).groups(1).bias(false)''', 1545 input_size=(1, 2, 3, 4, 5), 1546 cudnn=True, 1547 desc='no_bias', 1548 check_with_long_tensor=True, 1549 with_tf32=True, 1550 tf32_precision=0.05, 1551 default_dtype=torch.double, 1552 ), 1553 dict( 1554 module_name='Conv3d', 1555 constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False), 1556 cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4}) 1557 .stride(1).padding(0).dilation(1).groups(1).bias(false)''', 1558 input_size=(1, 2, 3, 4, 5), 1559 cudnn=True, 1560 desc='1x1x1_no_bias', 1561 check_with_long_tensor=False, 1562 with_tf32=True, 1563 tf32_precision=0.05, 1564 default_dtype=torch.double, 1565 ), 1566 dict( 1567 module_name='Conv3d', 1568 constructor_args=(3, 4, 2, 2), 1569 cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)', 1570 input_size=(2, 3, 5, 5, 5), 1571 cudnn=True, 1572 desc='stride', 1573 check_with_long_tensor=True, 1574 with_tf32=True, 1575 tf32_precision=0.05, 1576 default_dtype=torch.double, 1577 ), 1578 dict( 1579 module_name='Conv3d', 1580 constructor_args=(3, 4, 2, 2, 1), 1581 cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)', 1582 input_size=(2, 3, 5, 5, 5), 1583 cudnn=True, 1584 desc='stride_padding', 1585 check_with_long_tensor=True, 1586 with_tf32=True, 1587 tf32_precision=0.05, 1588 default_dtype=torch.double, 1589 ), 1590 dict( 1591 module_name='Conv3d', 1592 constructor_args=(3, 4, (2, 3, 4)), 1593 cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})', 1594 input_size=(0, 3, 3, 4, 5), 1595 cudnn=True, 1596 check_with_long_tensor=True, 1597 desc='zero_batch', 1598 with_tf32=True, 1599 ), 1600 dict( 1601 fullname='Conv3d_groups', 1602 constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2), 1603 cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)', 1604 input_size=(1, 2, 4, 5, 4), 1605 cudnn=True, 1606 check_with_long_tensor=True, 1607 with_tf32=True, 1608 tf32_precision=0.005, 1609 default_dtype=torch.double, 1610 ), 1611 dict( 1612 fullname='Conv3d_dilated', 1613 constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2), 1614 cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)', 1615 input_size=(2, 3, 5, 5, 5), 1616 with_tf32=True, 1617 tf32_precision=0.05, 1618 default_dtype=torch.double, 1619 ), 1620 dict( 1621 fullname='Conv3d_dilated_strided', 1622 constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2), 1623 cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)', 1624 input_size=(2, 3, 5, 5, 5), 1625 with_tf32=True, 1626 tf32_precision=0.05, 1627 default_dtype=torch.double, 1628 ), 1629 dict( 1630 fullname='Conv3d_pad_valid', 1631 constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"), 1632 cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)', 1633 input_size=(2, 3, 6, 5, 4), 1634 cudnn=True, 1635 with_tf32=True, 1636 tf32_precision=0.05, 1637 default_dtype=torch.double, 1638 ), 1639 dict( 1640 fullname='Conv3d_pad_same', 1641 constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"), 1642 cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)', 1643 input_size=(2, 3, 6, 5, 4), 1644 cudnn=True, 1645 with_tf32=True, 1646 tf32_precision=0.05, 1647 default_dtype=torch.double, 1648 ), 1649 dict( 1650 fullname='Conv3d_pad_same_dilated', 1651 constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2), 1652 cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)', 1653 input_size=(2, 3, 6, 5, 4), 1654 cudnn=True, 1655 with_tf32=True, 1656 tf32_precision=0.05, 1657 default_dtype=torch.double, 1658 ), 1659 dict( 1660 module_name='ConvTranspose3d', 1661 constructor_args=(2, 3, (2, 3, 2)), 1662 cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})', 1663 cudnn=True, 1664 input_size=(1, 2, 4, 5, 4), 1665 with_tf32=True, 1666 tf32_precision=0.05, 1667 default_dtype=torch.double, 1668 ), 1669 dict( 1670 module_name='ConvTranspose3d', 1671 constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)), 1672 cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2}) 1673 .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''', 1674 cudnn=True, 1675 input_size=(1, 2, 4, 5, 4), 1676 desc='dilated', 1677 with_tf32=True, 1678 tf32_precision=0.05, 1679 default_dtype=torch.double, 1680 ), 1681 dict( 1682 module_name='ReplicationPad3d', 1683 constructor_args=((1, 2, 3, 3, 2, 1),), 1684 cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', 1685 input_size=(2, 3, 2, 2, 2), 1686 default_dtype=torch.double, 1687 ), 1688 dict( 1689 module_name='ReplicationPad3d', 1690 constructor_args=((1, 2, 3, 3, 2, 1),), 1691 cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', 1692 input_size=(3, 2, 2, 2), 1693 reference_fn=single_batch_reference_fn, 1694 desc='no_batch_dim', 1695 default_dtype=torch.double, 1696 ), 1697 dict( 1698 module_name='ReplicationPad3d', 1699 constructor_args=((1, 2, 3, 3, 2, 1),), 1700 cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', 1701 input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True), 1702 skip_half=True, 1703 desc='complex' 1704 ), 1705 dict( 1706 module_name='Embedding', 1707 constructor_args=(4, 3), 1708 cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', 1709 input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), 1710 check_gradgrad=False, 1711 default_dtype=torch.double, 1712 decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971") 1713 ), 1714 dict( 1715 module_name='Embedding', 1716 constructor_args=(4, 3), 1717 cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', 1718 input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), 1719 check_gradgrad=False, 1720 desc='discontiguous', 1721 default_dtype=torch.double, 1722 decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971") 1723 ), 1724 dict( 1725 module_name='EmbeddingBag', 1726 constructor_args=(4, 3), 1727 cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', 1728 input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), 1729 check_gradgrad=False, 1730 desc='mean', 1731 default_dtype=torch.double, 1732 ), 1733 dict( 1734 module_name='EmbeddingBag', 1735 constructor_args=(4, 3), 1736 cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', 1737 input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), 1738 check_gradgrad=False, 1739 desc='discontiguous', 1740 default_dtype=torch.double, 1741 ), 1742 dict( 1743 module_name='EmbeddingBag', 1744 constructor_args=(4, 3, None, 2., False, 'sum'), 1745 cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) 1746 .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''', 1747 input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), 1748 check_gradgrad=False, 1749 desc='sum', 1750 default_dtype=torch.double, 1751 ), 1752 dict( 1753 module_name='EmbeddingBag', 1754 constructor_args=(4, 3, None, 2., False, 'max'), 1755 cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) 1756 .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''', 1757 input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), 1758 check_gradgrad=False, 1759 desc='max', 1760 default_dtype=torch.double, 1761 ), 1762 dict( 1763 fullname='EmbeddingBag_mean_padding_idx', 1764 constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1), 1765 cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)', 1766 input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), 1767 check_gradgrad=False, 1768 default_dtype=torch.double, 1769 ), 1770 dict( 1771 fullname='EmbeddingBag_sum_padding_idx', 1772 constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1), 1773 cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) 1774 .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''', 1775 input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), 1776 check_gradgrad=False, 1777 default_dtype=torch.double, 1778 ), 1779 dict( 1780 fullname='EmbeddingBag_max_padding_idx', 1781 constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1), 1782 cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) 1783 .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''', 1784 input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), 1785 check_gradgrad=False, 1786 default_dtype=torch.double, 1787 ), 1788 dict( 1789 fullname='EmbeddingBag_sparse', 1790 constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double), 1791 cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))', 1792 input_fn=lambda: torch.randperm(2).repeat(1, 2), 1793 check_gradgrad=False, 1794 has_sparse_gradients=True, 1795 ), 1796 dict( 1797 constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True), 1798 cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))', 1799 input_fn=lambda: torch.randperm(2).repeat(1, 2), 1800 fullname='Embedding_sparse', 1801 check_gradgrad=False, 1802 has_sparse_gradients=True, 1803 ), 1804 dict( 1805 module_name='PixelShuffle', 1806 constructor_args=(3,), 1807 cpp_constructor_args='torch::nn::PixelShuffleOptions(3)', 1808 input_size=(1, 9, 4, 4), 1809 default_dtype=torch.double, 1810 ), 1811 dict( 1812 module_name='PixelUnshuffle', 1813 constructor_args=(3,), 1814 cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)', 1815 input_size=(1, 1, 12, 12), 1816 default_dtype=torch.double, 1817 ), 1818 dict( 1819 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), 1820 cpp_options_args='''F::InterpolateFuncOptions() 1821 .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', 1822 input_size=(1, 2, 4), 1823 fullname='interpolate_nearest_1d', 1824 pickle=False, 1825 default_dtype=torch.double, 1826 ), 1827 dict( 1828 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), 1829 cpp_options_args='''F::InterpolateFuncOptions() 1830 .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', 1831 input_size=(0, 2, 4), 1832 fullname='interpolate_nearest_1d_zero_dim', 1833 pickle=False, 1834 ), 1835 dict( 1836 constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'), 1837 cpp_options_args='''F::InterpolateFuncOptions() 1838 .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', 1839 input_size=(1, 2, 3), 1840 fullname='interpolate_nearest_tuple_1d', 1841 pickle=False, 1842 default_dtype=torch.double, 1843 ), 1844 dict( 1845 constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), 1846 cpp_options_args='''F::InterpolateFuncOptions() 1847 .size(std::nullopt).scale_factor(std::vector<double>({4.})).mode(torch::kNearest)''', 1848 input_size=(1, 2, 4), 1849 fullname='interpolate_nearest_scale_1d', 1850 pickle=False, 1851 default_dtype=torch.double, 1852 ), 1853 dict( 1854 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False), 1855 cpp_options_args='''F::InterpolateFuncOptions() 1856 .size(std::vector<int64_t>({12})) 1857 .scale_factor(std::nullopt) 1858 .mode(torch::kLinear) 1859 .align_corners(false)''', 1860 input_size=(1, 2, 4), 1861 fullname='interpolate_linear_1d', 1862 pickle=False, 1863 default_dtype=torch.double, 1864 ), 1865 dict( 1866 constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False), 1867 cpp_options_args='''F::InterpolateFuncOptions() 1868 .size(std::vector<int64_t>({4})) 1869 .scale_factor(std::nullopt) 1870 .mode(torch::kLinear) 1871 .align_corners(false)''', 1872 input_size=(1, 2, 3), 1873 fullname='interpolate_linear_tuple_1d', 1874 pickle=False, 1875 default_dtype=torch.double, 1876 ), 1877 dict( 1878 constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False), 1879 cpp_options_args='''F::InterpolateFuncOptions() 1880 .size(std::nullopt) 1881 .scale_factor(std::vector<double>({4.})) 1882 .mode(torch::kLinear) 1883 .align_corners(false)''', 1884 input_size=(1, 2, 4), 1885 fullname='interpolate_linear_scale_1d', 1886 pickle=False, 1887 default_dtype=torch.double, 1888 ), 1889 dict( 1890 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False), 1891 cpp_options_args='''F::InterpolateFuncOptions() 1892 .size(std::vector<int64_t>({12})) 1893 .scale_factor(std::nullopt) 1894 .mode(torch::kLinear) 1895 .align_corners(false)''', 1896 input_size=(0, 2, 4), 1897 fullname='interpolate_linear_1d_zero_dim', 1898 pickle=False, 1899 ), 1900 dict( 1901 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True), 1902 cpp_options_args='''F::InterpolateFuncOptions() 1903 .size(std::vector<int64_t>({12})) 1904 .scale_factor(std::nullopt) 1905 .mode(torch::kLinear) 1906 .align_corners(true)''', 1907 input_size=(1, 2, 4), 1908 fullname='interpolate_linear_1d_align_corners', 1909 pickle=False, 1910 default_dtype=torch.double, 1911 ), 1912 dict( 1913 constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True), 1914 cpp_options_args='''F::InterpolateFuncOptions() 1915 .size(std::nullopt) 1916 .scale_factor(std::vector<double>({4.})) 1917 .mode(torch::kLinear) 1918 .align_corners(true)''', 1919 input_size=(1, 2, 4), 1920 fullname='interpolate_linear_scale_1d_align_corners', 1921 pickle=False, 1922 default_dtype=torch.double, 1923 ), 1924 dict( 1925 constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'), 1926 cpp_options_args='''F::InterpolateFuncOptions() 1927 .size(std::vector<int64_t>({2, 2})) 1928 .scale_factor(std::nullopt) 1929 .mode(torch::kNearest)''', 1930 input_size=(1, 128, 1, 1), 1931 fullname='interpolate_nearest_2d_launch_configs', 1932 pickle=False, 1933 default_dtype=torch.double, 1934 ), 1935 dict( 1936 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), 1937 cpp_options_args='''F::InterpolateFuncOptions() 1938 .size(std::vector<int64_t>({12, 12})) 1939 .scale_factor(std::nullopt) 1940 .mode(torch::kNearest)''', 1941 input_size=(1, 2, 4, 4), 1942 fullname='interpolate_nearest_2d', 1943 pickle=False, 1944 default_dtype=torch.double, 1945 ), 1946 dict( 1947 constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'), 1948 cpp_options_args='''F::InterpolateFuncOptions() 1949 .size(std::vector<int64_t>({12, 16})) 1950 .scale_factor(std::nullopt) 1951 .mode(torch::kNearest)''', 1952 input_size=(1, 2, 3, 4), 1953 fullname='interpolate_nearest_tuple_2d', 1954 pickle=False, 1955 default_dtype=torch.double, 1956 ), 1957 dict( 1958 constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), 1959 cpp_options_args='''F::InterpolateFuncOptions() 1960 .size(std::nullopt) 1961 .scale_factor(std::vector<double>({4., 4.})) 1962 .mode(torch::kNearest)''', 1963 input_size=(1, 2, 4, 4), 1964 fullname='interpolate_nearest_scale_2d', 1965 pickle=False, 1966 default_dtype=torch.double, 1967 ), 1968 dict( 1969 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), 1970 cpp_options_args='''F::InterpolateFuncOptions() 1971 .size(std::vector<int64_t>({12, 12})) 1972 .scale_factor(std::nullopt) 1973 .mode(torch::kNearest)''', 1974 input_size=(0, 2, 4, 4), 1975 fullname='interpolate_nearest_2d_zero_dim', 1976 pickle=False, 1977 ), 1978 dict( 1979 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False), 1980 cpp_options_args='''F::InterpolateFuncOptions() 1981 .size(std::vector<int64_t>({12, 12})) 1982 .scale_factor(std::nullopt) 1983 .mode(torch::kBilinear) 1984 .align_corners(false)''', 1985 input_size=(1, 2, 4, 4), 1986 fullname='interpolate_bilinear_2d', 1987 pickle=False, 1988 default_dtype=torch.double, 1989 ), 1990 dict( 1991 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False), 1992 cpp_options_args='''F::InterpolateFuncOptions() 1993 .size(std::vector<int64_t>({12, 12})) 1994 .scale_factor(std::nullopt) 1995 .mode(torch::kBilinear) 1996 .align_corners(false)''', 1997 input_size=(0, 2, 4, 4), 1998 fullname='interpolate_bilinear_2d_zero_dim', 1999 pickle=False, 2000 ), 2001 dict( 2002 constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, 2003 mode='bilinear', align_corners=False), 2004 cpp_options_args='''F::InterpolateFuncOptions() 2005 .size(std::vector<int64_t>({4, 6})) 2006 .scale_factor(std::nullopt) 2007 .mode(torch::kBilinear) 2008 .align_corners(false)''', 2009 input_size=(1, 2, 2, 3), 2010 fullname='interpolate_bilinear_tuple_2d', 2011 pickle=False, 2012 default_dtype=torch.double, 2013 ), 2014 dict( 2015 constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., 2016 mode='bilinear', align_corners=False), 2017 cpp_options_args='''F::InterpolateFuncOptions() 2018 .size(std::nullopt) 2019 .scale_factor(std::vector<double>({4., 4.})) 2020 .mode(torch::kBilinear) 2021 .align_corners(false)''', 2022 input_size=(1, 2, 4, 4), 2023 fullname='interpolate_bilinear_scale_2d', 2024 pickle=False, 2025 default_dtype=torch.double, 2026 ), 2027 dict( 2028 constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), 2029 mode='bilinear', align_corners=False), 2030 cpp_options_args='''F::InterpolateFuncOptions() 2031 .size(std::nullopt) 2032 .scale_factor(std::vector<double>({2., 2.})) 2033 .mode(torch::kBilinear) 2034 .align_corners(false)''', 2035 input_size=(1, 2, 4, 4), 2036 fullname='interpolate_bilinear_scale_tuple_shared_2d', 2037 pickle=False, 2038 default_dtype=torch.double, 2039 ), 2040 dict( 2041 constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), 2042 mode='bilinear', align_corners=False), 2043 cpp_options_args='''F::InterpolateFuncOptions() 2044 .size(std::nullopt) 2045 .scale_factor(std::vector<double>({2., 1.})) 2046 .mode(torch::kBilinear) 2047 .align_corners(false)''', 2048 input_size=(1, 2, 4, 4), 2049 fullname='interpolate_bilinear_scale_tuple_skewed_2d', 2050 pickle=False, 2051 default_dtype=torch.double, 2052 ), 2053 dict( 2054 constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True), 2055 cpp_options_args='''F::InterpolateFuncOptions() 2056 .size(std::vector<int64_t>({4, 6})) 2057 .scale_factor(std::nullopt) 2058 .mode(torch::kBilinear) 2059 .align_corners(true)''', 2060 input_size=(1, 2, 4, 4), 2061 fullname='interpolate_bilinear_tuple_2d_align_corners', 2062 pickle=False, 2063 default_dtype=torch.double, 2064 ), 2065 dict( 2066 constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), 2067 mode='bilinear', align_corners=True), 2068 cpp_options_args='''F::InterpolateFuncOptions() 2069 .size(std::nullopt) 2070 .scale_factor(std::vector<double>({2., 1.})) 2071 .mode(torch::kBilinear) 2072 .align_corners(true)''', 2073 input_size=(1, 2, 4, 4), 2074 fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners', 2075 pickle=False, 2076 default_dtype=torch.double, 2077 ), 2078 dict( 2079 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False), 2080 cpp_options_args='''F::InterpolateFuncOptions() 2081 .size(std::vector<int64_t>({12, 12})) 2082 .scale_factor(std::nullopt) 2083 .mode(torch::kBicubic) 2084 .align_corners(false)''', 2085 input_size=(1, 2, 4, 4), 2086 fullname='interpolate_bicubic_2d', 2087 pickle=False, 2088 default_dtype=torch.double, 2089 ), 2090 dict( 2091 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False), 2092 cpp_options_args='''F::InterpolateFuncOptions() 2093 .size(std::vector<int64_t>({12, 12})) 2094 .scale_factor(std::nullopt) 2095 .mode(torch::kBicubic) 2096 .align_corners(false)''', 2097 input_size=(0, 2, 4, 4), 2098 fullname='interpolate_bicubic_2d_zero_dim', 2099 pickle=False, 2100 ), 2101 dict( 2102 constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, 2103 mode='bicubic', align_corners=False), 2104 cpp_options_args='''F::InterpolateFuncOptions() 2105 .size(std::vector<int64_t>({4, 6})) 2106 .scale_factor(std::nullopt) 2107 .mode(torch::kBicubic) 2108 .align_corners(false)''', 2109 input_size=(1, 2, 2, 3), 2110 fullname='interpolate_bicubic_tuple_2d', 2111 pickle=False, 2112 default_dtype=torch.double, 2113 ), 2114 dict( 2115 constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False), 2116 cpp_options_args='''F::InterpolateFuncOptions() 2117 .size(std::nullopt) 2118 .scale_factor(std::vector<double>({4., 4.})) 2119 .mode(torch::kBicubic) 2120 .align_corners(false)''', 2121 input_size=(1, 2, 4, 4), 2122 fullname='interpolate_bicubic_scale_2d', 2123 pickle=False, 2124 default_dtype=torch.double, 2125 ), 2126 dict( 2127 constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), 2128 mode='bicubic', align_corners=False), 2129 cpp_options_args='''F::InterpolateFuncOptions() 2130 .size(std::nullopt) 2131 .scale_factor(std::vector<double>({2., 2.})) 2132 .mode(torch::kBicubic) 2133 .align_corners(false)''', 2134 input_size=(1, 2, 4, 4), 2135 fullname='interpolate_bicubic_scale_tuple_shared_2d', 2136 pickle=False, 2137 default_dtype=torch.double, 2138 ), 2139 dict( 2140 constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), 2141 mode='bicubic', align_corners=False), 2142 cpp_options_args='''F::InterpolateFuncOptions() 2143 .size(std::nullopt) 2144 .scale_factor(std::vector<double>({2., 1.})) 2145 .mode(torch::kBicubic) 2146 .align_corners(false)''', 2147 input_size=(1, 2, 4, 4), 2148 fullname='interpolate_bicubic_scale_tuple_skewed_2d', 2149 pickle=False, 2150 default_dtype=torch.double, 2151 ), 2152 dict( 2153 constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True), 2154 cpp_options_args='''F::InterpolateFuncOptions() 2155 .size(std::vector<int64_t>({4, 6})) 2156 .scale_factor(std::nullopt) 2157 .mode(torch::kBicubic) 2158 .align_corners(true)''', 2159 input_size=(1, 2, 4, 4), 2160 fullname='interpolate_bicubic_tuple_2d_align_corners', 2161 pickle=False, 2162 default_dtype=torch.double, 2163 ), 2164 dict( 2165 constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), 2166 mode='bicubic', align_corners=True), 2167 cpp_options_args='''F::InterpolateFuncOptions() 2168 .size(std::nullopt) 2169 .scale_factor(std::vector<double>({2., 1.})) 2170 .mode(torch::kBicubic) 2171 .align_corners(true)''', 2172 input_size=(1, 2, 4, 4), 2173 fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners', 2174 pickle=False, 2175 default_dtype=torch.double, 2176 ), 2177 dict( 2178 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), 2179 cpp_options_args='''F::InterpolateFuncOptions() 2180 .size(std::vector<int64_t>({12, 12, 12})) 2181 .scale_factor(std::nullopt) 2182 .mode(torch::kNearest)''', 2183 input_size=(1, 2, 4, 4, 4), 2184 fullname='interpolate_nearest_3d', 2185 pickle=False, 2186 default_dtype=torch.double, 2187 ), 2188 dict( 2189 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), 2190 cpp_options_args='''F::InterpolateFuncOptions() 2191 .size(std::vector<int64_t>({12, 12, 12})) 2192 .scale_factor(std::nullopt) 2193 .mode(torch::kNearest)''', 2194 input_size=(0, 2, 4, 4, 4), 2195 fullname='interpolate_nearest_3d_zero_dim', 2196 pickle=False, 2197 ), 2198 dict( 2199 constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'), 2200 cpp_options_args='''F::InterpolateFuncOptions() 2201 .size(std::vector<int64_t>({12, 16, 16})) 2202 .scale_factor(std::nullopt) 2203 .mode(torch::kNearest)''', 2204 input_size=(1, 2, 3, 4, 4), 2205 fullname='interpolate_nearest_tuple_3d', 2206 pickle=False, 2207 default_dtype=torch.double, 2208 ), 2209 dict( 2210 constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), 2211 cpp_options_args='''F::InterpolateFuncOptions() 2212 .size(std::nullopt) 2213 .scale_factor(std::vector<double>({4., 4., 4.})) 2214 .mode(torch::kNearest)''', 2215 input_size=(1, 2, 4, 4, 4), 2216 fullname='interpolate_nearest_scale_3d', 2217 pickle=False, 2218 default_dtype=torch.double, 2219 ), 2220 dict( 2221 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False), 2222 cpp_options_args='''F::InterpolateFuncOptions() 2223 .size(std::vector<int64_t>({12, 12, 12})) 2224 .scale_factor(std::nullopt) 2225 .mode(torch::kTrilinear) 2226 .align_corners(false)''', 2227 input_size=(1, 2, 4, 4, 4), 2228 fullname='interpolate_trilinear_3d', 2229 pickle=False, 2230 default_dtype=torch.double, 2231 ), 2232 dict( 2233 constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False), 2234 cpp_options_args='''F::InterpolateFuncOptions() 2235 .size(std::vector<int64_t>({12, 12, 12})) 2236 .scale_factor(std::nullopt) 2237 .mode(torch::kTrilinear) 2238 .align_corners(false)''', 2239 input_size=(0, 2, 4, 4, 4), 2240 fullname='interpolate_trilinear_3d_zero_dim', 2241 pickle=False, 2242 ), 2243 dict( 2244 constructor=wrap_functional(F.interpolate, size=(4, 6, 6), 2245 scale_factor=None, mode='trilinear', align_corners=False), 2246 cpp_options_args='''F::InterpolateFuncOptions() 2247 .size(std::vector<int64_t>({4, 6, 6})) 2248 .scale_factor(std::nullopt) 2249 .mode(torch::kTrilinear) 2250 .align_corners(false)''', 2251 input_size=(1, 2, 2, 3, 3), 2252 fullname='interpolate_trilinear_tuple_3d', 2253 pickle=False, 2254 default_dtype=torch.double, 2255 ), 2256 dict( 2257 constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False), 2258 cpp_options_args='''F::InterpolateFuncOptions() 2259 .size(std::nullopt) 2260 .scale_factor(std::vector<double>({3., 3., 3.})) 2261 .mode(torch::kTrilinear) 2262 .align_corners(false)''', 2263 input_size=(1, 2, 3, 4, 5), 2264 fullname='interpolate_trilinear_scale_3d', 2265 # See https://github.com/pytorch/pytorch/issues/5006 2266 precision=3e-4, 2267 pickle=False, 2268 default_dtype=torch.double, 2269 ), 2270 dict( 2271 constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None, 2272 mode='trilinear', align_corners=True), 2273 cpp_options_args='''F::InterpolateFuncOptions() 2274 .size(std::vector<int64_t>({4, 6, 6})) 2275 .scale_factor(std::nullopt) 2276 .mode(torch::kTrilinear) 2277 .align_corners(true)''', 2278 input_size=(1, 2, 2, 3, 3), 2279 fullname='interpolate_trilinear_tuple_3d_align_corners', 2280 pickle=False, 2281 default_dtype=torch.double 2282 ), 2283 dict( 2284 constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True), 2285 cpp_options_args='''F::InterpolateFuncOptions() 2286 .size(std::nullopt) 2287 .scale_factor(std::vector<double>({3., 3., 3.})) 2288 .mode(torch::kTrilinear) 2289 .align_corners(true)''', 2290 input_size=(1, 2, 3, 4, 4), 2291 fullname='interpolate_trilinear_scale_3d_align_corners', 2292 # See https://github.com/pytorch/pytorch/issues/5006 2293 precision=3e-4, 2294 pickle=False, 2295 default_dtype=torch.double, 2296 ), 2297 dict( 2298 constructor=wrap_functional(F.softmax, dim=-1), 2299 cpp_options_args='F::SoftmaxFuncOptions(-1)', 2300 input_size=(2, 128), # trigger the last-dim algo in CUDA 2301 fullname='softmax_lastdim', 2302 pickle=False, 2303 default_dtype=torch.double, 2304 ), 2305 dict( 2306 constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64), 2307 cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)', 2308 input_size=(2, 128), 2309 fullname='softmax_lastdim_dtype', 2310 pickle=False, 2311 test_cuda=False, 2312 default_dtype=torch.double, 2313 ), 2314 dict( 2315 constructor=wrap_functional(F.softmax, dim=1), 2316 cpp_options_args='F::SoftmaxFuncOptions(1)', 2317 input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo 2318 fullname='softmax_spatial_special', 2319 pickle=False, 2320 default_dtype=torch.double, 2321 ), 2322 dict( 2323 constructor=wrap_functional(F.softmax, dim=1), 2324 cpp_options_args='F::SoftmaxFuncOptions(1)', 2325 input_size=(2, 2, 4, 4), # regular spatial algorithm 2326 fullname='softmax_spatial', 2327 pickle=False, 2328 default_dtype=torch.double, 2329 ), 2330 dict( 2331 constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64), 2332 cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)', 2333 input_size=(2, 2, 4, 4), # regular spatial algorithm 2334 fullname='softmax_spatial_dtype', 2335 pickle=False, 2336 test_cuda=False, 2337 default_dtype=torch.double, 2338 ), 2339 dict( 2340 constructor=wrap_functional(F.softmax, dim=0), 2341 cpp_options_args='F::SoftmaxFuncOptions(0)', 2342 input_size=(2, 3, 4, 5), 2343 fullname='softmax_functional_dim0', 2344 test_cuda=False, 2345 pickle=False, 2346 default_dtype=torch.double, 2347 ), 2348 dict( 2349 constructor=wrap_functional(F.softmax, dim=3), 2350 cpp_options_args='F::SoftmaxFuncOptions(3)', 2351 input_size=(2, 3, 4, 5), 2352 fullname='softmax_functional_dim3', 2353 test_cuda=False, 2354 pickle=False, 2355 default_dtype=torch.double, 2356 ), 2357 dict( 2358 constructor=wrap_functional(F.softmax, dim=-1), 2359 cpp_options_args='F::SoftmaxFuncOptions(-1)', 2360 input_size=(), 2361 fullname='softmax_functional_scalar', 2362 test_cuda=False, 2363 pickle=False, 2364 ), 2365 dict( 2366 constructor=wrap_functional(F.log_softmax, dim=-1), 2367 cpp_options_args='F::LogSoftmaxFuncOptions(-1)', 2368 input_size=(2, 128), # trigger the last-dim algo in CUDA 2369 fullname='log_softmax_lastdim', 2370 pickle=False, 2371 default_dtype=torch.double, 2372 ), 2373 dict( 2374 constructor=wrap_functional(F.log_softmax, dim=1), 2375 cpp_options_args='F::LogSoftmaxFuncOptions(1)', 2376 input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo 2377 fullname='log_softmax_spatial_special', 2378 pickle=False, 2379 default_dtype=torch.double, 2380 ), 2381 dict( 2382 constructor=wrap_functional(F.log_softmax, dim=1), 2383 cpp_options_args='F::LogSoftmaxFuncOptions(1)', 2384 input_size=(2, 2, 4, 4), # regular spatial algorithm 2385 fullname='log_softmax_spatial', 2386 pickle=False, 2387 default_dtype=torch.double, 2388 ), 2389 dict( 2390 constructor=wrap_functional(F.log_softmax, dim=0), 2391 cpp_options_args='F::LogSoftmaxFuncOptions(0)', 2392 input_size=(2, 3, 4, 5), 2393 fullname='log_softmax_dim0', 2394 pickle=False, 2395 default_dtype=torch.double, 2396 ), 2397 dict( 2398 constructor=wrap_functional(F.log_softmax, dim=3), 2399 cpp_options_args='F::LogSoftmaxFuncOptions(3)', 2400 input_size=(2, 3, 4, 5), 2401 fullname='log_softmax_dim3', 2402 pickle=False, 2403 default_dtype=torch.double, 2404 ), 2405 dict( 2406 constructor=wrap_functional(F.log_softmax, dim=0), 2407 cpp_options_args='F::LogSoftmaxFuncOptions(0)', 2408 input_size=(), 2409 fullname='log_softmax_scalar', 2410 pickle=False, 2411 ), 2412 dict( 2413 fullname='Unfold', 2414 constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)), 2415 cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', 2416 input_size=(2, 4, 3, 3), 2417 check_gradgrad=False, 2418 test_cuda=True, 2419 default_dtype=torch.double, 2420 ), 2421 dict( 2422 fullname='Fold', 2423 constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)), 2424 cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', 2425 input_size=(2, 16, 4), 2426 check_gradgrad=False, 2427 test_cuda=True, 2428 default_dtype=torch.double, 2429 ), 2430 dict( 2431 fullname='Fold_no_batch_dim_input', 2432 constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)), 2433 cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', 2434 input_size=(16, 4), 2435 check_gradgrad=False, 2436 ref=single_batch_reference_fn, 2437 test_cuda=True, 2438 default_dtype=torch.double, 2439 ), 2440 dict( 2441 fullname='Unfold_int_input', 2442 constructor=lambda: nn.Unfold(2, 1, 0, 1), 2443 cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)', 2444 input_size=(2, 4, 3, 3), 2445 check_gradgrad=False, 2446 test_cuda=True, 2447 default_dtype=torch.double, 2448 ), 2449 dict( 2450 fullname='Fold_int_input', 2451 constructor=lambda: nn.Fold(3, 2, 1, 0, 1), 2452 cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)', 2453 input_size=(2, 16, 4), 2454 check_gradgrad=False, 2455 test_cuda=True, 2456 default_dtype=torch.double, 2457 ), 2458 dict( 2459 fullname='Fold_no_batch_dim_int_input', 2460 constructor=lambda: nn.Fold(3, 2, 1, 0, 1), 2461 cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)', 2462 input_size=(16, 4), 2463 ref=single_batch_reference_fn, 2464 check_gradgrad=False, 2465 test_cuda=True, 2466 default_dtype=torch.double, 2467 ), 2468 dict( 2469 module_name='RReLU', 2470 constructor_args=(0.1, 0.9), 2471 cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)', 2472 input_size=(), 2473 desc='with_up_down_scalar', 2474 test_cuda=False, 2475 default_dtype=torch.double, 2476 ), 2477 dict( 2478 module_name='PairwiseDistance', 2479 input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), 2480 default_dtype=torch.double, 2481 ), 2482 dict( 2483 module_name='PairwiseDistance', 2484 input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)), 2485 desc='broadcast_lhs', 2486 default_dtype=torch.double, 2487 ), 2488 dict( 2489 module_name='PairwiseDistance', 2490 input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)), 2491 desc='broadcast_rhs', 2492 default_dtype=torch.double, 2493 ), 2494 dict( 2495 module_name='PairwiseDistance', 2496 constructor_args=(1.5, 1e-05, True), 2497 cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)', 2498 input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), 2499 desc='with_non_default_args', 2500 default_dtype=torch.double, 2501 ), 2502 dict( 2503 module_name='PairwiseDistance', 2504 input_fn=lambda: (torch.randn(8), torch.randn(8)), 2505 reference_fn=single_batch_reference_fn, 2506 desc='no_batch_dim', 2507 default_dtype=torch.double, 2508 ), 2509 dict( 2510 module_name='TransformerEncoderLayer', 2511 constructor_args=(4, 2, 16, 0.0), 2512 cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) 2513 .dim_feedforward(16) 2514 .dropout(0.0)''', 2515 input_size=(2, 3, 4), 2516 desc='relu_activation', 2517 with_tf32=True, 2518 tf32_precision=0.1, 2519 # TODO(#50743): figure out the error 2520 # RuntimeError: The size of tensor a (6) must match the size of tensor b (4) 2521 # at non-singleton dimension 2 2522 check_batched_grad=False, 2523 check_gradgrad=False, 2524 default_dtype=torch.double, 2525 ), 2526 dict( 2527 module_name='TransformerEncoderLayer', 2528 constructor_args=(4, 2, 8, 0.0, F.gelu), 2529 cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) 2530 .dim_feedforward(8) 2531 .dropout(0.0) 2532 .activation(torch::kGELU)''', 2533 input_size=(2, 3, 4), 2534 check_gradgrad=False, 2535 desc='gelu_activation', 2536 with_tf32=True, 2537 tf32_precision=0.08 if SM90OrLater else 0.05, 2538 default_dtype=torch.double, 2539 ), 2540 dict( 2541 module_name='TransformerDecoderLayer', 2542 constructor_args=(4, 2, 8, 0.0), 2543 cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) 2544 .dim_feedforward(8) 2545 .dropout(0.0)''', 2546 input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), 2547 check_gradgrad=False, 2548 desc='relu_activation', 2549 with_tf32=True, 2550 tf32_precision=0.05, 2551 default_dtype=torch.double, 2552 ), 2553 dict( 2554 module_name='TransformerDecoderLayer', 2555 constructor_args=(4, 2, 8, 0.0, F.gelu), 2556 cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) 2557 .dim_feedforward(8) 2558 .dropout(0.0) 2559 .activation(torch::kGELU)''', 2560 input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), 2561 check_gradgrad=False, 2562 desc='gelu_activation', 2563 with_tf32=True, 2564 tf32_precision=0.05, 2565 default_dtype=torch.double, 2566 ), 2567 dict( 2568 module_name='Transformer', 2569 constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu), 2570 cpp_constructor_args='''torch::nn::TransformerOptions() 2571 .d_model(4) 2572 .nhead(2) 2573 .num_encoder_layers(2) 2574 .num_decoder_layers(2) 2575 .dim_feedforward(8) 2576 .dropout(0.0) 2577 .activation(torch::kReLU)''', 2578 input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)), 2579 check_gradgrad=False, 2580 desc='multilayer_coder', 2581 with_tf32=True, 2582 tf32_precision=0.05 if SM90OrLater else 0.03, 2583 default_dtype=torch.double, 2584 ), 2585 dict( 2586 module_name='Linear', 2587 constructor_args=(3, 5), 2588 cpp_constructor_args='torch::nn::LinearOptions(3, 5)', 2589 input_fn=lambda: torch.rand(3), 2590 reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1], 2591 desc="no_batch_dim", 2592 with_tf32=True, 2593 tf32_precision=0.005, 2594 default_dtype=torch.double, 2595 ), 2596 dict( 2597 module_name='Flatten', 2598 cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)', 2599 constructor_args=(-3, -1), 2600 input_size=(3, 4, 5), 2601 reference_fn=single_batch_reference_fn, 2602 desc="no_batch_dim", 2603 default_dtype=torch.double, 2604 ), 2605 dict( 2606 module_name='Unflatten', 2607 cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})', 2608 constructor_args=(-2, torch.Size([2, 2])), 2609 input_size=(3, 4, 5), 2610 reference_fn=single_batch_reference_fn, 2611 desc="no_batch_dim", 2612 default_dtype=torch.double, 2613 ), 2614 dict( 2615 module_name='LayerNorm', 2616 constructor_args=([56, 56, 56], 1e-5, False), 2617 cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)', 2618 input_size=(4, 56, 56, 56), 2619 cudnn=True, 2620 check_eval=True, 2621 gradcheck_fast_mode=True, 2622 check_half=True, 2623 desc='3d_no_affine_large_feature', 2624 ), 2625] 2626 2627# add conv padding mode tests: 2628for padding_mode, cpp_padding_mode in zip( 2629 ['reflect', 'circular', 'replicate', 'zeros'], 2630 ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros']): 2631 # conv signature: 2632 # in_channels, out_channels, kernel_size, stride=1, 2633 # padding=0, dilation=1, groups=1, 2634 # bias=True, padding_mode='zeros' 2635 for d in (1, 2, 3): 2636 if d == 3 and padding_mode == 'reflect': 2637 # FIXME: remove after implementing reflection pad 3d 2638 # https://github.com/pytorch/pytorch/issues/27655 2639 continue 2640 padding = tuple(range(1, d + 1)) 2641 cpp_padding = '{' + ', '.join(map(str, padding)) + '}' 2642 input_size = (2, 2) + (4,) * d 2643 output_size = (2, 3) + tuple(p + 1 for p in padding) # simplified from `(4 + 2 * p - 3) // 2 + 1` 2644 new_module_tests.append( 2645 dict( 2646 module_name=f'Conv{d}d', 2647 constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode), 2648 cpp_constructor_args=f'''torch::nn::Conv{d}dOptions(2, 3, 3) 2649 .stride(2) 2650 .padding({cpp_padding}) 2651 .dilation(1) 2652 .groups(1) 2653 .bias(true) 2654 .padding_mode({cpp_padding_mode})''', 2655 input_size=input_size, 2656 output_size=output_size, 2657 cudnn=True, 2658 desc=f'{padding_mode}_stride2_pad2', 2659 with_tf32=True, 2660 tf32_precision=0.05, 2661 default_dtype=torch.double, 2662 ), 2663 ) 2664 2665# Check that non linear activations work with no batch dimensions 2666non_linear_activations_no_batch = [ 2667 'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU', 2668 'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU', 2669 'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh', 2670 'Tanhshrink', 'Threshold' 2671] 2672non_linear_activations_extra_info: Dict[str, dict] = { 2673 'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double}, 2674 'Threshold': {'constructor_args': (2., 1.)}, 2675 'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double}, 2676 'Hardswish': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double}, 2677 # For RRelu, test that compare CPU and GPU results fail because RNG 2678 # is different between CPU and GPU 2679 'RReLU': {'test_cuda': False, 'default_dtype': torch.double}, 2680 'ELU': {'default_dtype': torch.double}, 2681 'GELU': {'default_dtype': torch.double}, 2682 'GLU': {'default_dtype': torch.double}, 2683 'Hardshrink': {'default_dtype': torch.double}, 2684 'Hardtanh': {'default_dtype': torch.double}, 2685 'LeakyReLU': {'default_dtype': torch.double}, 2686 'LogSigmoid': {'default_dtype': torch.double}, 2687 'Mish': {'default_dtype': torch.double}, 2688 'PReLU': {'default_dtype': torch.double}, 2689 'ReLU6': {'default_dtype': torch.double}, 2690 'ReLU': {'default_dtype': torch.double}, 2691 'SELU': {'default_dtype': torch.double}, 2692 'SiLU': {'default_dtype': torch.double}, 2693 'Sigmoid': {'default_dtype': torch.double}, 2694 'Softplus': {'default_dtype': torch.double}, 2695 'Softshrink': {'default_dtype': torch.double}, 2696 'Softsign': {'default_dtype': torch.double}, 2697 'Tanh': {'default_dtype': torch.double}, 2698 'Tanhshrink': {'default_dtype': torch.double}, 2699} 2700for non_linear_activation in non_linear_activations_no_batch: 2701 activation_test_info = dict( 2702 module_name=non_linear_activation, 2703 input_size=(4,), 2704 reference_fn=single_batch_reference_fn, 2705 desc='no_batch_dim', 2706 test_cpp_api_parity=False, 2707 ) 2708 extra_info = non_linear_activations_extra_info.get(non_linear_activation, {}) 2709 activation_test_info.update(extra_info) 2710 new_module_tests.append(activation_test_info) 2711 2712 2713def kldivloss_reference(input, target, reduction='mean', log_target=False): 2714 if log_target: 2715 result = torch.exp(target) * (target - input) 2716 else: 2717 result = target * (target.log() - input) 2718 if reduction == 'mean': 2719 return result.mean() 2720 elif reduction == 'sum': 2721 return result.sum() 2722 elif reduction == 'batchmean' and result.dim() != 0: 2723 return result.sum() / result.size(0) 2724 return result 2725 2726 2727def nlllossNd_reference(input, target, weight=None, ignore_index=-100, 2728 reduction='mean'): 2729 assert input.dim() >= 3 2730 N = input.size(0) 2731 C = input.size(1) 2732 out_size = (N,) + input.size()[2:] 2733 output = torch.zeros(out_size).type_as(input) 2734 2735 if weight is None: 2736 weight = torch.ones(C).type_as(input) 2737 total_weight = 0 2738 for tup in product(*[range(size) for size in out_size]): 2739 t_nx = target[tup] 2740 norm = 0. if ignore_index == t_nx else weight[t_nx].item() 2741 input_index = list(tup) 2742 input_index.insert(1, t_nx) 2743 output[tup] = -input[tuple(input_index)] * norm 2744 total_weight += norm 2745 2746 if reduction == 'mean': 2747 return output.sum() / total_weight 2748 elif reduction == 'sum': 2749 return output.sum() 2750 return output 2751 2752 2753def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean', 2754 label_smoothing=0.0): 2755 assert input.dim() >= 2 2756 2757 input = torch.log_softmax(input, 1) 2758 C = input.size(1) 2759 if weight is None: 2760 weight = torch.ones(C).type_as(input) 2761 weight = weight.view(1, C, *(1 for _ in input.shape[2:])) 2762 2763 if label_smoothing > 0.0: 2764 assert label_smoothing <= 1.0 2765 target = (target * (1 - label_smoothing) + label_smoothing / C) 2766 2767 output = -(input * target * weight).sum(dim=1) 2768 if reduction == 'mean': 2769 return output.mean() 2770 elif reduction == 'sum': 2771 return output.sum() 2772 return output 2773 2774 2775def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100, 2776 reduction='mean', label_smoothing=0.0): 2777 log_softmax_input = torch.log_softmax(input, 1) 2778 nllloss = F.nll_loss( 2779 log_softmax_input, 2780 target, 2781 weight, 2782 ignore_index=ignore_index, 2783 reduction=reduction) 2784 2785 if label_smoothing == 0.0: 2786 return nllloss 2787 2788 assert 0.0 < label_smoothing <= 1.0 2789 2790 input = torch.log_softmax(input, 1) 2791 C = input.size(1) 2792 if weight is not None: 2793 input = input * weight.view(1, C, *(1 for _ in input.shape[2:])) 2794 2795 smooth_loss = -torch.sum(input, 1) 2796 2797 ignore_mask = target == ignore_index 2798 smooth_loss.masked_fill_(ignore_mask, 0.0) 2799 2800 if reduction == 'mean': 2801 if weight is not None: 2802 # TODO: This code can path can be removed if #61309 is resolved 2803 # loss is normalized by the weights to be consistent with nll_loss_nd 2804 ret = torch.sum(smooth_loss) / weight.gather(0, target.masked_select(ignore_mask.logical_not()).flatten()).sum() 2805 else: 2806 ret = torch.mean(smooth_loss.masked_select(ignore_mask.logical_not())) 2807 elif reduction == 'sum': 2808 ret = torch.sum(smooth_loss) 2809 else: 2810 ret = smooth_loss 2811 2812 return (1 - label_smoothing) * nllloss + ret * (label_smoothing / C) 2813 2814 2815def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean', 2816 label_smoothing=0.0): 2817 if input.shape == target.shape: 2818 return cross_entropy_loss_prob_target_reference( 2819 input, 2820 target, 2821 weight=weight, 2822 reduction=reduction, 2823 label_smoothing=label_smoothing) 2824 else: 2825 return cross_entropy_loss_indices_target_reference( 2826 input, target, weight=weight, reduction=reduction, 2827 ignore_index=ignore_index, label_smoothing=label_smoothing 2828 ) 2829 2830 2831def nllloss_reference(input, target, weight=None, ignore_index=-100, 2832 reduction='mean'): 2833 2834 def nll_loss_helper(input, target, weight, ignore_index): 2835 if target == ignore_index: 2836 return (0, 0) 2837 norm = 1 if weight is None else weight[target] 2838 result = -input[target] * norm 2839 return (result, norm) 2840 2841 losses_and_weights = [nll_loss_helper(i, t, weight, ignore_index) 2842 for i, t in zip(input, target)] 2843 losses, weights = zip(*losses_and_weights) 2844 losses_tensor = input.new_tensor(losses) 2845 if reduction == 'mean': 2846 return sum(losses_tensor) / sum(weights) 2847 elif reduction == 'sum': 2848 return sum(losses_tensor) 2849 else: 2850 return losses_tensor 2851 2852 2853def smoothl1loss_reference(input, target, reduction='mean', beta=1.0): 2854 abs_diff = (input - target).abs() 2855 ge_beta_mask = (abs_diff >= beta).type_as(abs_diff) 2856 lt_beta_mask = (abs_diff < beta).type_as(abs_diff) 2857 # when beta <= 0 we should just use l1_loss 2858 if beta == 0: 2859 output = abs_diff 2860 else: 2861 output = ge_beta_mask * (abs_diff - 0.5 * beta) + lt_beta_mask * 0.5 * (abs_diff ** 2) / beta 2862 if reduction == 'mean': 2863 return output.mean() 2864 elif reduction == 'sum': 2865 return output.sum() 2866 return output 2867 2868 2869def huberloss_reference(input, target, reduction='mean', delta=1.0): 2870 abs_diff = (input - target).abs() 2871 ge_delta_mask = (abs_diff >= delta) 2872 lt_delta_mask = (abs_diff < delta) 2873 output = ge_delta_mask * delta * (abs_diff - 0.5 * delta) + lt_delta_mask * 0.5 * (abs_diff ** 2) 2874 if reduction == 'mean': 2875 return output.mean() 2876 elif reduction == 'sum': 2877 return output.sum() 2878 return output 2879 2880 2881def _multilabelmarginloss_reference(input, target): 2882 targets = [] 2883 for target_index in target: 2884 if target_index < 0: 2885 break 2886 targets.append(target_index) 2887 2888 sum = 0 2889 for target_index in targets: 2890 for i in range(0, len(input)): 2891 if i not in targets: 2892 sum += max(0, 1 - input[target_index] + input[i]) 2893 2894 return sum 2895 2896 2897def multilabelmarginloss_reference(input, target, reduction='mean'): 2898 # make everything 2-dimensional 2899 input_dim = input.dim() 2900 if input.dim() < 2: 2901 assert target.dim() < 2 2902 input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0) 2903 target = target.unsqueeze(0) if target.dim() == 1 else target.unsqueeze(0).unsqueeze(0) 2904 2905 n = input.size(0) 2906 dim = input.size(1) 2907 output = input.new(n).zero_() 2908 for i in range(0, n): 2909 output[i] = _multilabelmarginloss_reference(input[i], target[i]) 2910 2911 if reduction == 'mean': 2912 return output.mean() / dim 2913 elif reduction == 'sum': 2914 return output.sum() / dim 2915 elif input_dim < 2: 2916 # we know we have (1, C) X (1, C) -> (1,), so squeeze will get us 2917 # back to correct dimensionality 2918 return output.squeeze() / dim 2919 else: 2920 return output / dim 2921 2922 2923def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'): 2924 margin_clamp = (margin - input).clamp(min=0).type_as(input) 2925 output = torch.where(target == 1, input, margin_clamp) 2926 2927 if reduction == 'mean': 2928 return output.mean() 2929 elif reduction == 'sum': 2930 return output.sum() 2931 return output 2932 2933 2934def softmarginloss_reference(input, target, reduction='mean'): 2935 output = (1 + (-input * target).exp()).log() 2936 2937 if reduction == 'mean': 2938 return output.mean() 2939 elif reduction == 'sum': 2940 return output.sum() 2941 return output 2942 2943 2944def _multimarginloss_reference(input, target_idx, p, margin, weight): 2945 if weight is None: 2946 weight = input.new(len(input)).fill_(1) 2947 2948 output = 0 2949 for i in range(0, len(input)): 2950 if i != target_idx: 2951 output += weight[target_idx] * (max(0, (margin - input[target_idx] + input[i])) ** p) 2952 return output 2953 2954 2955def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'): 2956 if input.dim() < 2: 2957 input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0) 2958 2959 target_dim = target.dim() 2960 if target.dim() == 0: 2961 target = target.unsqueeze(0) 2962 2963 n = input.size(0) 2964 dim = input.size(1) 2965 output = input.new(n) 2966 for x in range(0, n): 2967 output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight) 2968 2969 if reduction == 'mean': 2970 return output.mean() / dim 2971 elif reduction == 'sum': 2972 return output.sum() / dim 2973 elif target_dim == 0: 2974 return output.squeeze(0) / dim 2975 return output / dim 2976 2977 2978def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'): 2979 def _cos(a, b): 2980 cos = a.new(a.size(0)) 2981 for i in range(0, a.size(0)): 2982 cos[i] = (a[i] * b[i]).sum() / ((((a[i] * a[i]).sum() + 1e-12) * ((b[i] * b[i]).sum() + 1e-12)) ** 0.5) 2983 return cos 2984 2985 output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0)) 2986 2987 if reduction == 'mean': 2988 return output.mean() 2989 elif reduction == 'sum': 2990 return output.sum() 2991 return output 2992 2993 2994def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, 2995 reduction='mean'): 2996 d_p = torch.pairwise_distance(anchor, positive, p, eps) 2997 d_n = torch.pairwise_distance(anchor, negative, p, eps) 2998 if swap: 2999 d_s = torch.pairwise_distance(positive, negative, p, eps) 3000 d_n = torch.min(d_n, d_s) 3001 3002 output = torch.clamp(margin + d_p - d_n, min=0.0) 3003 if reduction == 'mean': 3004 return output.mean() 3005 elif reduction == 'sum': 3006 return output.sum() 3007 return output 3008 3009 3010def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'): 3011 output = (-target * (input1 - input2) + margin).clamp(min=0) 3012 if reduction == 'mean': 3013 return output.mean() 3014 elif reduction == 'sum': 3015 return output.sum() 3016 return output 3017 3018 3019# this directly follows Graves et al.'s paper, in contrast to the production implementation, it does not use log-space 3020def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'): 3021 input_lengths = torch.as_tensor(input_lengths, dtype=torch.long) 3022 target_lengths = torch.as_tensor(target_lengths, dtype=torch.long) 3023 dt = log_probs.dtype 3024 log_probs = log_probs.double() # we need the accuracy as we are not in logspace 3025 targets = targets.long() 3026 cum_target_lengths = target_lengths.cumsum(0) 3027 losses = [] 3028 for i in range(log_probs.size(1)): 3029 input_length = input_lengths[i].item() 3030 target_length = target_lengths[i].item() 3031 cum_target_length = cum_target_lengths[i].item() 3032 targets_prime = targets.new_full((2 * target_length + 1,), blank) 3033 if targets.dim() == 2: 3034 targets_prime[1::2] = targets[i, :target_length] 3035 else: 3036 targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length] 3037 probs = log_probs[:input_length, i].exp() 3038 alpha = log_probs.new_zeros((target_length * 2 + 1,)) 3039 alpha[0] = probs[0, blank] 3040 alpha[1] = probs[0, targets_prime[1]] 3041 mask_third = (targets_prime[:-2] != targets_prime[2:]) 3042 for t in range(1, input_length): 3043 alpha_next = alpha.clone() 3044 alpha_next[1:] += alpha[:-1] 3045 alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1)) 3046 alpha = probs[t, targets_prime] * alpha_next 3047 losses.append(-alpha[-2:].sum().log()[None]) 3048 output = torch.cat(losses, 0) 3049 if reduction == 'mean': 3050 output = (output / target_lengths.to(dtype=output.dtype, device=output.device)).mean() 3051 elif reduction == 'sum': 3052 output = output.sum() 3053 output = output.to(dt) 3054 return output 3055 3056 3057loss_reference_fns: Dict['str', Callable] = { 3058 'KLDivLoss': kldivloss_reference, 3059 'KLDivLoss_log_target': partial(kldivloss_reference, log_target=True), 3060 'NLLLoss': nllloss_reference, 3061 'NLLLossNd': nlllossNd_reference, 3062 'SmoothL1Loss': smoothl1loss_reference, 3063 'HuberLoss': huberloss_reference, 3064 'MultiLabelMarginLoss': multilabelmarginloss_reference, 3065 'HingeEmbeddingLoss': hingeembeddingloss_reference, 3066 'SoftMarginLoss': softmarginloss_reference, 3067 'MultiMarginLoss': multimarginloss_reference, 3068 'CosineEmbeddingLoss': cosineembeddingloss_reference, 3069 'TripletMarginLoss': tripletmarginloss_reference, 3070 'MarginRankingLoss': marginrankingloss_reference, 3071 'CTCLoss': ctcloss_reference, 3072 'CrossEntropyLoss': cross_entropy_loss_reference 3073} 3074 3075 3076criterion_tests = [] 3077 3078 3079def single_batch_reference_criterion_fn(*args): 3080 """Reference function for criterion supporting no batch dimensions. 3081 3082 The criterion is passed the input and target in batched form with a single item. 3083 The output is squeezed to compare with the no-batch input. 3084 """ 3085 criterion = args[-1] 3086 3087 def unsqueeze_inp(inp): 3088 if isinstance(inp, (list, tuple)): 3089 return [t.unsqueeze(0) for t in inp] 3090 return inp.unsqueeze(0) 3091 3092 def flatten(xs): 3093 result = [] 3094 if isinstance(xs, (list, tuple)): 3095 for x in xs: 3096 result.extend(flatten(x)) 3097 else: 3098 result.append(xs) 3099 return result 3100 3101 single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]]) 3102 3103 output = criterion(*single_batch_input_args) 3104 reduction = get_reduction(criterion) 3105 3106 if reduction == 'none': 3107 return output.squeeze(0) 3108 # reduction is 'sum' or 'mean' which results in a scalar 3109 return output 3110 3111 3112# Check that regression criterion work with no batch dimensions 3113regression_criterion_no_batch = [ 3114 'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss' 3115] 3116reductions = ['none', 'mean', 'sum'] 3117for name, reduction in product(regression_criterion_no_batch, reductions): 3118 regression_test_info = dict( 3119 fullname=f"{name}_no_batch_dim_{reduction}", 3120 constructor=lambda *args, name=name: getattr(nn, name)(reduction=reduction), 3121 input_size=(3, ), 3122 target_size=(3, ), 3123 reference_fn=single_batch_reference_criterion_fn, 3124 test_cpp_api_parity=False, 3125 default_dtype=torch.double, 3126 ) 3127 criterion_tests.append(regression_test_info) 3128 3129 3130for reduction in reductions: 3131 regression_test_info = dict( 3132 fullname=f"KLDivLoss_no_batch_dim_{reduction}", 3133 constructor=lambda: nn.KLDivLoss(reduction=reduction), 3134 input_fn=lambda: torch.rand((3,)).log(), 3135 target_fn=lambda: torch.rand((3,)), 3136 reference_fn=single_batch_reference_criterion_fn, 3137 test_cpp_api_parity=False, 3138 default_dtype=torch.double, 3139 ) 3140 criterion_tests.append(regression_test_info) 3141 3142 3143# Check that classification criterion work with no batch dimensions 3144# List of tuples of (name, input_fn, target_fn) 3145classification_criterion_no_batch = [ 3146 ( 3147 'BCELoss', 3148 lambda: torch.sigmoid(torch.randn(9, dtype=torch.double)), 3149 lambda: torch.randn(9, dtype=torch.double).gt(0).to(torch.double) 3150 ), 3151 ('BCEWithLogitsLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9, dtype=torch.double)), 3152 ('HingeEmbeddingLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)), 3153 ('MultiLabelMarginLoss', lambda: torch.randn(4, dtype=torch.double), lambda: torch.tensor([3, 0, -1, 1])), 3154 ('SoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)), 3155 ('NLLLoss', lambda: F.log_softmax(torch.randn(3, dtype=torch.double), dim=0), lambda: torch.tensor(1)), 3156 ( 3157 'CosineEmbeddingLoss', 3158 lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)), 3159 lambda: torch.tensor(1, dtype=torch.double) 3160 ), 3161 # For MarginRankingLoss, input_fn : (x1, x2) and target_fn : target 3162 ('MarginRankingLoss', lambda: (torch.randn(()), torch.randn(())), lambda: torch.randn(()).sign()), 3163 # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative 3164 ( 3165 'TripletMarginLoss', 3166 lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)), 3167 lambda: torch.randn(9, dtype=torch.double) 3168 ), 3169 ('MultiLabelSoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9)), 3170] 3171classification_criterion_no_batch_extra_info: Dict[str, dict] = { 3172 'MultiLabelMarginLoss': {'check_gradgrad': False}, 3173} 3174# TODO : Fix these discrepancies 3175classification_cpp_parity = { 3176 'BCELoss': False, 3177 'BCEWithLogitsLoss': False, 3178 'HingeEmbeddingLoss': False, 3179 'NLLLoss': False, 3180 'SoftMarginLoss': False, 3181} 3182reductions = ['none', 'mean', 'sum'] 3183for (name, input_fn, target_fn), reduction in product(classification_criterion_no_batch, 3184 reductions): 3185 classification_test_info = dict( 3186 fullname=f"{name}_no_batch_dim_{reduction}", 3187 constructor=lambda *args, name=name: getattr(nn, name)(reduction=reduction), 3188 input_fn=lambda f=input_fn: f(), 3189 target_fn=lambda f=target_fn: f(), 3190 reference_fn=single_batch_reference_criterion_fn, 3191 test_cpp_api_parity=True, 3192 has_parity=classification_cpp_parity.get(name, True) 3193 ) 3194 extra_info = classification_criterion_no_batch_extra_info.get(name, {}) 3195 classification_test_info.update(extra_info) 3196 criterion_tests.append(classification_test_info) 3197 3198 3199class NNTestCase(TestCase): 3200 3201 # _forward is defined in classes inheriting from NNTestCase 3202 @abstractmethod 3203 def _forward(self, *args, **kwargs): 3204 raise NotImplementedError 3205 3206 @abstractmethod 3207 def _get_parameters(self, module: nn.Module) -> Tuple[List[nn.Parameter], List[nn.Parameter]]: 3208 raise NotImplementedError 3209 3210 @abstractmethod 3211 def _zero_grad_parameters(self, module: nn.Module) -> None: 3212 raise NotImplementedError 3213 3214 @abstractmethod 3215 def _backward(self, module: nn.Module, 3216 input: _TensorOrTensors, output: torch.Tensor, 3217 grad_output: Union[torch.Tensor, Sequence[torch.Tensor]], 3218 create_graph: bool = False): 3219 raise NotImplementedError 3220 3221 def _jacobian(self, input, num_out): 3222 if isinstance(input, tuple): 3223 return tuple(self._jacobian(elem, num_out) for elem in input) 3224 elif isinstance(input, list): 3225 return [self._jacobian(elem, num_out) for elem in input] 3226 else: 3227 return torch.zeros(input.nelement(), num_out) 3228 3229 def _flatten_tensors(self, x): 3230 if isinstance(x, torch.Tensor): 3231 if x.is_sparse: 3232 return x.to_dense().view(-1) 3233 else: 3234 return x.view(-1) 3235 else: 3236 return tuple(self._flatten_tensors(a) for a in x) 3237 3238 def _zero_grad_input(self, input): 3239 if isinstance(input, torch.Tensor): 3240 if input.requires_grad and input.grad is not None: 3241 input.grad.zero_() 3242 input.grad.detach_() 3243 else: 3244 for i in input: 3245 self._zero_grad_input(i) 3246 3247 def _analytical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True): 3248 output = self._forward(module, input) 3249 output_size = output.nelement() 3250 3251 if jacobian_input: 3252 jacobian_inp = self._jacobian(input, output_size) 3253 flat_jacobian_input = list(_iter_tensors(jacobian_inp)) 3254 3255 if jacobian_parameters: 3256 num_param = sum(p.numel() for p in self._get_parameters(module)[0]) 3257 jacobian_param = torch.zeros(num_param, output_size) 3258 3259 for i in range(output_size): 3260 param, d_param = self._get_parameters(module) 3261 # make non grad zeros 3262 d_param = [torch.zeros_like(p) if d is None else d for (p, d) in zip(param, d_param)] 3263 3264 d_out = torch.zeros_like(output) 3265 flat_d_out = d_out.view(-1) 3266 flat_d_out[i] = 1 3267 3268 if jacobian_parameters: 3269 self._zero_grad_parameters(module) 3270 # Tensors will accumulate gradient from multiple steps 3271 if jacobian_input: 3272 self._zero_grad_input(input) 3273 d_input = self._backward(module, input, output, d_out) 3274 3275 if jacobian_input: 3276 for jacobian_x, d_x in zip(flat_jacobian_input, _iter_tensors(d_input)): 3277 jacobian_x[:, i] = d_x.contiguous().view(-1) 3278 if jacobian_parameters: 3279 jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0) 3280 3281 res: Tuple[torch.Tensor, ...] = () 3282 if jacobian_input: 3283 res += jacobian_inp, 3284 if jacobian_parameters: 3285 res += jacobian_param, 3286 3287 return res 3288 3289 def _numerical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True): 3290 def fw(*input): 3291 return self._forward(module, input).detach() 3292 3293 res: Tuple[torch.Tensor, ...] = () 3294 if jacobian_input: 3295 res += _get_numerical_jacobian(fw, input, eps=1e-6), 3296 if jacobian_parameters: 3297 param, _ = self._get_parameters(module) 3298 to_cat = [] 3299 for p in param: 3300 jacobian = _get_numerical_jacobian(fw, input, target=p, eps=1e-6) 3301 # get_numerical_jacobian returns a list of tuples but we require a tensor 3302 to_cat.append(jacobian[0][0]) 3303 res += (torch.cat(to_cat, 0),) 3304 return res 3305 3306 def check_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True): 3307 jacobian_parameters = bool(self._get_parameters(module)[0]) 3308 analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters) 3309 numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters) 3310 analytical_t = list(_iter_tensors(analytical)) 3311 numerical_t = list(_iter_tensors(numerical)) 3312 3313 differences = [] 3314 for a, n in zip(analytical_t, numerical_t): 3315 if a.numel() != 0: 3316 differences.append(a.add(n, alpha=-1).abs().max()) 3317 # TODO: compare structure (ensure analytic jacobian has correct shape) 3318 if len(differences) > 0: 3319 self.assertLessEqual(max(differences), PRECISION) # type: ignore[type-var] 3320 3321 3322class TestBase: 3323 3324 _required_arg_names = {'constructor_args', 'input', 'extra_args'} 3325 3326 def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs): 3327 self.desc = desc 3328 self.fullname = fullname 3329 self.constructor = constructor 3330 self.reference_fn = reference_fn 3331 for name in self._required_arg_names: 3332 if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs: 3333 if name in {'constructor_args', 'extra_args'}: 3334 kwargs[name] = () 3335 else: 3336 raise ValueError(f"{self.get_name()}: Specify {name} by a value, a function to generate it, or it's size!") 3337 self._extra_kwargs = kwargs 3338 self._arg_cache = {} 3339 3340 def get_name(self): 3341 if self.fullname is not None: 3342 return 'test_' + self.fullname 3343 3344 test_name = 'test_' + self.constructor.__name__ 3345 if self.desc: 3346 test_name += '_' + self.desc 3347 return test_name 3348 3349 def _unpack(self, value): 3350 if isinstance(value, torch.Tensor): 3351 return value 3352 elif is_iterable(value): 3353 return type(value)(self._unpack(v) for v in value) 3354 else: 3355 return value 3356 3357 @property 3358 def constructor_args(self): 3359 return self._get_arg('constructor_args', True) 3360 3361 @property 3362 def extra_args(self): 3363 return self._get_arg('extra_args', True) 3364 3365 def _get_arg(self, name, unpack): 3366 assert name in self._required_arg_names 3367 3368 if name not in self._arg_cache: 3369 fn_name = name + '_fn' 3370 size_name = name + '_size' 3371 3372 if name in self._extra_kwargs: 3373 self._arg_cache[name] = self._extra_kwargs[name] 3374 elif fn_name in self._extra_kwargs: 3375 self._arg_cache[name] = self._extra_kwargs[fn_name]() 3376 else: 3377 assert size_name in self._extra_kwargs, \ 3378 f"Missing `{name}`, `{size_name}` or `{fn_name}` for {self.get_name()}" 3379 3380 def map_tensor_sizes(sizes): 3381 if isinstance(sizes, list): 3382 return [map_tensor_sizes(s) for s in sizes] 3383 elif isinstance(sizes, torch.Tensor): 3384 return sizes.double() 3385 else: 3386 return torch.randn(sizes) 3387 3388 self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name]) 3389 3390 return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name] 3391 3392 def _get_input(self, unpack=True): 3393 return self._get_arg('input', unpack) 3394 3395 def __call__(self, test_case): 3396 raise NotImplementedError 3397 3398 3399class ModuleTest(TestBase): 3400 3401 @abstractmethod 3402 def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any: 3403 raise NotImplementedError 3404 3405 def __init__(self, *args, **kwargs): 3406 super().__init__(*args, **kwargs) 3407 self.jacobian_input = kwargs.get('jacobian_input', True) 3408 self.should_test_cuda = kwargs.get('test_cuda', True) 3409 self.should_test_pickle = kwargs.get('pickle', True) 3410 self.check_gradgrad = kwargs.get('check_gradgrad', True) 3411 self.FIXME_no_cuda_gradgrad_comparison = \ 3412 kwargs.get('FIXME_no_cuda_gradgrad_comparison', False) 3413 self.precision = kwargs.get('precision', 2e-4) 3414 self.check_forward_only = kwargs.get('check_forward_only', False) 3415 self.default_dtype = kwargs.get('default_dtype', None) 3416 if self.default_dtype is None: 3417 self.default_dtype = torch.get_default_dtype() 3418 3419 def __call__(self, test_case): 3420 with set_default_dtype(self.default_dtype): 3421 module = self.constructor(*self.constructor_args) 3422 input = self._get_input() 3423 3424 if self.reference_fn is not None: 3425 out = test_case._forward(module, input) 3426 ref_input = deepcopy(input) 3427 ref_module = deepcopy(module) 3428 expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0], ref_module) 3429 test_case.assertEqual(out, expected_out, exact_dtype=False) 3430 if self.check_forward_only: 3431 return 3432 self.test_noncontig(test_case, module, input) 3433 3434 if self.should_test_pickle: 3435 # TODO: do this with in-memory files as soon as torch.save will support it 3436 with tempfile.TemporaryFile() as f: 3437 test_case._forward(module, input) 3438 torch.save(module, f) 3439 f.seek(0) 3440 # weights_only=False as this is legacy code that saves the model 3441 module_copy = torch.load(f, weights_only=False) 3442 test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input)) 3443 3444 self._do_test(test_case, module, input) 3445 3446 def noncontiguize(self, obj): 3447 if isinstance(obj, list): 3448 return [self.noncontiguize(o) for o in obj] 3449 elif isinstance(obj, tuple): 3450 return tuple(self.noncontiguize(o) for o in obj) 3451 tensor = obj 3452 ndim = tensor.dim() 3453 # Always making only the last dimension noncontiguous is easy to hide 3454 # bugs because .view(-1) will still work. So try to find a dim with size 3455 # > 1 and make that non-contiguous, i.e., stack + select on the 3456 # dimension directly after that. 3457 dim = ndim 3458 for d in range(ndim): 3459 if tensor.size(d) > 1: 3460 dim = d + 1 3461 break 3462 noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach() 3463 assert noncontig.numel() == 1 or noncontig.numel() == 0 or not noncontig.is_contiguous() 3464 noncontig.requires_grad = tensor.requires_grad 3465 return noncontig 3466 3467 def test_noncontig(self, test_case, module, input): 3468 # check no scalars, can't make non-contig 3469 if isinstance(input, torch.Tensor) and input.dim() == 0: 3470 return 3471 if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)): 3472 return 3473 3474 test_case._zero_grad_parameters(module) 3475 test_case._zero_grad_input(input) 3476 with freeze_rng_state(): 3477 output = test_case._forward(module, input) 3478 if getattr(module, "return_indices", False): 3479 output = output[0] 3480 grad_output = output.new(output.shape).normal_() 3481 output = output.clone() 3482 d_input = deepcopy(test_case._backward(module, input, output, grad_output)) 3483 d_param = deepcopy(test_case._get_parameters(module)[1]) 3484 3485 nc_input = self.noncontiguize(input) 3486 nc_grad_output = self.noncontiguize(grad_output) 3487 for contig_i, contig_g in product((True, False), repeat=2): 3488 i = input if contig_i else nc_input 3489 # Some ops, e.g., nn.Flatten, return gradient that shares 3490 # storage with the grad_output. Hence we copy here. 3491 go = deepcopy(grad_output if contig_g else nc_grad_output) 3492 test_case._zero_grad_parameters(module) 3493 test_case._zero_grad_input(i) 3494 with freeze_rng_state(): 3495 out = test_case._forward(module, i) 3496 if getattr(module, "return_indices", False): 3497 out = out[0] 3498 grad = test_case._backward(module, i, out, go) 3499 3500 test_case.assertEqual(out, output) 3501 test_case.assertEqual(grad, d_input, atol=1e-4, rtol=0) 3502 test_case.assertEqual(test_case._get_parameters(module)[1], d_param) 3503 3504 def test_cuda(self, test_case): 3505 if not TEST_CUDA or not self.should_test_cuda: 3506 raise unittest.SkipTest('Excluded from CUDA tests') 3507 3508 with set_default_dtype(self.default_dtype): 3509 cpu_input = self._get_input() 3510 3511 type_map = {torch.double: torch.float} 3512 cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,) 3513 3514 is_any_input_complex = any(isinstance(t, torch.Tensor) and t.dtype.is_complex for t in cpu_input_tuple) 3515 3516 gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map) 3517 3518 cpu_module = self.constructor(*self.constructor_args) 3519 gpu_module = self.constructor(*self.constructor_args).float().cuda() 3520 cpu_param = test_case._get_parameters(cpu_module) 3521 gpu_param = test_case._get_parameters(gpu_module) 3522 for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]): 3523 gpu_p.data.copy_(cpu_p) 3524 3525 test_case._zero_grad_input(cpu_input_tuple) 3526 test_case._zero_grad_input(gpu_input_tuple) 3527 test_case._zero_grad_parameters(cpu_module) 3528 test_case._zero_grad_parameters(gpu_module) 3529 cpu_output = test_case._forward(cpu_module, cpu_input_tuple) 3530 gpu_output = test_case._forward(gpu_module, gpu_input_tuple) 3531 if getattr(cpu_module, "return_indices", False): 3532 cpu_output = cpu_output[0] 3533 gpu_output = gpu_output[0] 3534 test_case.assertEqual(cpu_output, gpu_output, atol=self.precision, rtol=0, exact_dtype=False) 3535 3536 # Run backwards on CPU and GPU and compare results 3537 for _ in range(5): 3538 cpu_gradOutput = cpu_output.clone().normal_() 3539 gpu_gradOutput = cpu_gradOutput.type_as(gpu_output) 3540 cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput) 3541 gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput) 3542 test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False) 3543 for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]): 3544 test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0) 3545 3546 # Run double-backwards on CPU and GPU and compare results 3547 if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison: 3548 cpu_output = cpu_module(*cpu_input_tuple) 3549 gpu_output = gpu_module(*gpu_input_tuple) 3550 if getattr(cpu_module, "return_indices", False): 3551 cpu_output = cpu_output[0] 3552 gpu_output = gpu_output[0] 3553 3554 cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True) 3555 gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach() 3556 gpu_gradOutput.requires_grad = True 3557 3558 cpu_gradInputs = torch.autograd.grad( 3559 cpu_output, 3560 cpu_input_tuple + tuple(cpu_module.parameters()), 3561 cpu_gradOutput, 3562 create_graph=True) 3563 gpu_gradInputs = torch.autograd.grad( 3564 gpu_output, 3565 gpu_input_tuple + tuple(gpu_module.parameters()), 3566 gpu_gradOutput, 3567 create_graph=True) 3568 3569 for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs): 3570 test_case.assertEqual(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0, exact_dtype=False) 3571 3572 # We mix output into the second backwards computation so that 3573 # torch.autograd.grad doesn't complain that some inputs 3574 # are unreachable (which can happen if you differentiate 3575 # only on the gradient. 3576 if is_any_input_complex: 3577 outputs_cpu = cpu_output.sum().abs() + sum(x.sum().abs() for x in cpu_gradInputs) 3578 outputs_gpu = gpu_output.sum().abs() + sum(x.sum().abs() for x in gpu_gradInputs) 3579 else: 3580 outputs_cpu = cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs) 3581 outputs_gpu = gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs) 3582 3583 cpu_gg = torch.autograd.grad( 3584 outputs_cpu, 3585 cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()), 3586 retain_graph=True) 3587 gpu_gg = torch.autograd.grad( 3588 outputs_gpu, 3589 gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()), 3590 retain_graph=True) 3591 test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False) 3592 for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg): 3593 test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0, exact_dtype=False) 3594 3595 self.test_noncontig(test_case, gpu_module, gpu_input_tuple) 3596 3597 3598class InputVariableMixin: 3599 def _get_input(self): 3600 input = TestBase._get_input(self, False) # type: ignore[arg-type] 3601 3602 def map_variables(i): 3603 if isinstance(i, torch.Tensor): 3604 if i.is_floating_point() or i.is_complex(): 3605 i.requires_grad = True 3606 return i 3607 else: 3608 return type(i)(map_variables(elem) for elem in i) 3609 3610 return map_variables(input) 3611 3612 3613class NewModuleTest(InputVariableMixin, ModuleTest): # type: ignore[misc] 3614 def __init__(self, *args, **kwargs): 3615 super().__init__(*args, **kwargs) 3616 self.cudnn = kwargs.get('cudnn', False) 3617 self.check_inplace = kwargs.get('check_inplace', False) 3618 self.check_gradgrad = kwargs.get('check_gradgrad', True) 3619 self.skip_double = kwargs.get('skip_double', False) 3620 self.skip_half = kwargs.get('skip_half', False) 3621 self.with_tf32 = kwargs.get('with_tf32', False) 3622 self.tf32_precision = kwargs.get('tf32_precision', 0.001) 3623 self.test_cpu = kwargs.get('test_cpu', True) 3624 self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False) 3625 self.check_batched_grad = kwargs.get('check_batched_grad', True) 3626 self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None) 3627 self.supports_forward_ad = kwargs.get('supports_forward_ad', False) 3628 self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False) 3629 3630 def _check_gradients(self, test_case, module, input_tuple): 3631 params = tuple(x for x in module.parameters()) 3632 num_inputs = len(input_tuple) 3633 3634 def fn_to_gradcheck(*inputs_and_params, **kwargs): 3635 assert not kwargs 3636 return test_case._forward(module, inputs_and_params[:num_inputs]) 3637 3638 # gradcheck doesn't support operators that take in dense inputs but 3639 # return sparse parameters. This only happens in the case of nn.Embedding 3640 # and nn.EmbeddingBag. Instead, we call `self.check_jacobian`, which 3641 # is a slightly different version of gradcheck that can handle this. 3642 if self.has_sparse_gradients: 3643 assert num_inputs == 1 3644 test_input_jacobian = torch.is_floating_point(input_tuple[0]) 3645 test_case.check_jacobian(module, input_tuple[0], test_input_jacobian) 3646 else: 3647 test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params, 3648 check_batched_grad=self.check_batched_grad, 3649 fast_mode=self.gradcheck_fast_mode, 3650 check_forward_ad=self.supports_forward_ad)) 3651 3652 if self.check_gradgrad: 3653 test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params, 3654 check_batched_grad=self.check_batched_grad, 3655 fast_mode=self.gradcheck_fast_mode, 3656 check_fwd_over_rev=self.supports_fwgrad_bwgrad)) 3657 3658 def _do_test(self, test_case, module, input): 3659 num_threads = torch.get_num_threads() 3660 torch.set_num_threads(1) 3661 input_tuple = input if isinstance(input, tuple) else (input,) 3662 3663 self._check_gradients(test_case, module, input_tuple) 3664 3665 # check if module can be printed 3666 module.__repr__() 3667 3668 if self.check_inplace: 3669 # check if the inplace variant of the module gives the same result 3670 # as the out-of-place 3671 3672 # check_inplace doesn't support multiple input tensors, since we don't have any modules 3673 # that modify the inputs in-place and that accept more than one input 3674 assert len(input_tuple) == 1 3675 input = input_tuple[0] 3676 3677 module_ip = self.constructor(*self.constructor_args, inplace=True) 3678 3679 input_version = input._version 3680 with freeze_rng_state(): 3681 output = module(input) 3682 test_case.assertEqual(input._version, input_version) 3683 3684 input_ip = deepcopy(input) 3685 input_ip_clone = input_ip.clone() 3686 with freeze_rng_state(): 3687 output_ip = module_ip(input_ip_clone) 3688 test_case.assertNotEqual(input_ip_clone._version, input_version) 3689 test_case.assertEqual(output, output_ip) 3690 grad = output.data.clone().normal_() 3691 if input.grad is not None: 3692 with torch.no_grad(): 3693 input.grad.zero_() 3694 if input_ip.grad is not None: 3695 with torch.no_grad(): 3696 input_ip.grad.zero_() 3697 output.backward(grad) 3698 output_ip.backward(grad) 3699 test_case.assertEqual(input.grad, input_ip.grad) 3700 3701 def assert_module_parameters_are(tensor_type, device_id=None): 3702 for p in module.parameters(): 3703 test_case.assertIsInstance(p, tensor_type) 3704 if device_id is not None: 3705 test_case.assertEqual(p.get_device(), device_id) 3706 3707 if all(isinstance(t, torch.LongTensor) for t in input_tuple) and TEST_CUDA: 3708 # check that cuda() moves module parameters to correct GPU device, 3709 # and that float() casts parameters correctly 3710 input_tuple = tuple(t.cuda() for t in input_tuple) 3711 module.float().cuda() 3712 module(*input_tuple) 3713 assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] 3714 3715 if torch.cuda.device_count() > 1: 3716 input_tuple = tuple(t.cuda(1) for t in input_tuple) 3717 module.cuda(1) 3718 with torch.cuda.device(1): 3719 module(*input_tuple) 3720 assert_module_parameters_are(torch.cuda.FloatTensor, 1) # type: ignore[attr-defined] 3721 else: 3722 # check that float()/double() casters work correctly 3723 def to_type(tensor, real, complex): 3724 if tensor.is_complex(): 3725 return tensor.to(complex) 3726 elif tensor.is_floating_point(): 3727 return tensor.to(real) 3728 else: 3729 return tensor 3730 3731 def to_half(x): 3732 # TODO: torch.complex32 when properly supported 3733 return to_type(x, torch.float16, None) 3734 3735 def to_single(x): 3736 return to_type(x, torch.float32, torch.complex64) 3737 3738 def to_double(x): 3739 return to_type(x, torch.float64, torch.complex128) 3740 3741 # to float 3742 input_tuple = tuple(to_single(t) for t in input_tuple) 3743 module.float() 3744 module(*input_tuple) 3745 assert_module_parameters_are(torch.FloatTensor) 3746 3747 # and back to double 3748 input_tuple = tuple(to_double(t) for t in input_tuple) 3749 module.double() 3750 module(*input_tuple) 3751 assert_module_parameters_are(torch.DoubleTensor) 3752 3753 if TEST_CUDA and self.should_test_cuda: 3754 # check that cuda() moves module parameters to correct GPU device, 3755 # and that float() casts parameters correctly 3756 3757 # to GPU0 3758 input_tuple = tuple(to_single(t).cuda() for t in input_tuple) 3759 module.float().cuda() 3760 module(*input_tuple) 3761 assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] 3762 3763 # to CPU 3764 input_tuple = tuple(t.cpu() for t in input_tuple) 3765 module.cpu() 3766 module(*input_tuple) 3767 assert_module_parameters_are(torch.FloatTensor) 3768 3769 # back to GPU0 3770 input_tuple = tuple(t.cuda() for t in input_tuple) 3771 module.cuda() 3772 module(*input_tuple) 3773 assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] 3774 3775 # test that forwards of module runs correctly without cuDNN 3776 if self.cudnn: 3777 with torch.backends.cudnn.flags(enabled=False): 3778 module(*input_tuple) 3779 assert_module_parameters_are(torch.cuda.FloatTensor, 0) # type: ignore[attr-defined] 3780 3781 if torch.cuda.device_count() >= 2: 3782 # test cross-GPU transfer works 3783 # to GPU1 3784 input_tuple = tuple(t.cuda(1) for t in input_tuple) 3785 module.cuda(1) 3786 with torch.cuda.device(1): 3787 module(*input_tuple) 3788 assert_module_parameters_are(torch.cuda.FloatTensor, 1) # type: ignore[attr-defined] 3789 3790 if not self.skip_double: 3791 # test double() 3792 input_tuple = tuple(to_double(t).cuda() for t in input_tuple) 3793 module.double().cuda() 3794 module(*input_tuple) 3795 assert_module_parameters_are(torch.cuda.DoubleTensor, 0) # type: ignore[attr-defined] 3796 3797 # test half() 3798 if not self.skip_half: 3799 input_tuple = tuple(to_half(t).cuda() for t in input_tuple) 3800 module.half().cuda() 3801 module(*input_tuple) 3802 assert_module_parameters_are(torch.cuda.HalfTensor, 0) # type: ignore[attr-defined] 3803 torch.set_num_threads(num_threads) 3804 3805 def _get_target(self): 3806 return self._get_arg('target', False) 3807 3808 @property 3809 def constructor_args(self): 3810 return self._get_arg('constructor_args', False) 3811 3812 3813class CriterionTest(InputVariableMixin, TestBase): # type: ignore[misc] 3814 # TODO: check that criterions don't ignore grad_output 3815 3816 _required_arg_names = TestBase._required_arg_names.union({'target'}) 3817 3818 def __init__(self, *args, **kwargs): 3819 super().__init__(*args, **kwargs) 3820 self.should_test_cuda = kwargs.get('test_cuda', True) 3821 self.check_forward_only = kwargs.get('check_forward_only', False) 3822 self.check_gradgrad = kwargs.get('check_gradgrad', True) 3823 self.check_half = kwargs.get('check_half', True) 3824 self.check_bfloat16 = kwargs.get('check_bfloat16', False) 3825 self.check_complex = kwargs.get('check_complex', False) 3826 self.test_cpu = kwargs.get('test_cpu', True) 3827 self.with_tf32 = kwargs.get('with_tf32', True) 3828 self.tf32_precision = kwargs.get('tf32_precision', 0.001) 3829 self.check_batched_grad = kwargs.get('check_batched_grad', True) 3830 self.default_dtype = kwargs.get('default_dtype', None) 3831 if self.default_dtype is None: 3832 self.default_dtype = torch.get_default_dtype() 3833 3834 def __call__(self, test_case): 3835 with set_default_dtype(self.default_dtype): 3836 module = self.constructor(*self.constructor_args) 3837 input = self._get_input() 3838 3839 # Check that these methods don't raise errors 3840 module.__repr__() 3841 str(module) 3842 3843 target = self._get_target() 3844 3845 if self.reference_fn is not None: 3846 out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args) 3847 ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,) 3848 expected_out = self.reference_fn(*ref_args) 3849 test_case.assertEqual(out, expected_out) 3850 3851 if self.check_forward_only: 3852 return 3853 3854 params = tuple(x for x in module.parameters()) 3855 if not isinstance(input, tuple): 3856 inputs = (input,) + params + (target,) 3857 3858 def apply_fn(input, target, *params): 3859 return module(input, target) 3860 else: 3861 inputs = input + params + (target,) 3862 3863 def apply_fn(input1, input2, target, *params): # type: ignore[misc] 3864 return module(input1, input2, target) 3865 3866 gradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad) 3867 3868 if self.check_gradgrad: 3869 gradgradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad) 3870 3871 def test_cuda(self, test_case, dtype, extra_args=None): 3872 def convert_dtype(obj, dtype, requires_grad=False): 3873 if isinstance(obj, torch.Tensor): 3874 return obj.detach().to(dtype=dtype).requires_grad_(requires_grad) 3875 elif isinstance(obj, tuple): 3876 return tuple(convert_dtype(o, dtype, requires_grad) for o in obj) 3877 else: 3878 return obj 3879 3880 if not TEST_CUDA or not self.should_test_cuda: 3881 raise unittest.SkipTest('Excluded from CUDA tests') 3882 3883 with set_default_dtype(self.default_dtype): 3884 cpu_input = self._get_input() 3885 cpu_target = self._get_target() 3886 cpu_module = self.constructor(*self.constructor_args) 3887 gpu_module = self.constructor(*self.constructor_args) 3888 3889 # Convert input, target and module parameters to dtype 3890 cpu_input = convert_dtype(cpu_input, dtype, True) 3891 if cpu_target.is_floating_point() or cpu_target.is_complex(): 3892 cpu_target = convert_dtype(cpu_target, dtype) 3893 cpu_module.type(dtype) 3894 gpu_module.type(dtype) 3895 3896 # GPU setup 3897 gpu_input = to_gpu(cpu_input) 3898 gpu_target = to_gpu(cpu_target) 3899 gpu_module.cuda() 3900 3901 # torch.HalfTensor doesn't support most operations, converting back to default 3902 if dtype in {torch.half, torch.bfloat16}: 3903 cpu_input = self._get_input() 3904 cpu_target = self._get_target() 3905 # Loss modules with weights require consistent input/module weight types 3906 cpu_module = self.constructor(*self.constructor_args) 3907 3908 cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args) 3909 gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args) 3910 # dtype used to be able to be None, so set precision in this way instead of a precision map 3911 test_case.assertEqual(cpu_output, gpu_output, 3912 atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False) 3913 3914 cpu_gradInput = test_case._backward_criterion( 3915 cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args) 3916 gpu_gradInput = test_case._backward_criterion( 3917 gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args) 3918 # dtype used to be able to be None, so set precision in this way instead of a precision map 3919 test_case.assertEqual(cpu_gradInput, gpu_gradInput, 3920 atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False) 3921 3922 def _get_target(self): 3923 return self._get_arg('target', False) 3924 3925 @property 3926 def constructor_args(self): 3927 return self._get_arg('constructor_args', False) 3928 3929 @property 3930 def extra_args(self): 3931 return self._get_arg('extra_args', False) 3932 3933 3934def _test_bfloat16_ops(test_case, op, device, inp_dims=(), prec=1e-2, scale_factor=None): 3935 # fp32 compute 3936 input1 = torch.randn(inp_dims, dtype=torch.float32, device=device, requires_grad=True) 3937 if scale_factor is not None: 3938 input1 = (torch.rand(inp_dims, dtype=torch.bfloat16, device=device) * scale_factor).float().requires_grad_() 3939 out1 = op(input1) 3940 grad_input1 = torch.randn_like(out1, device=device) 3941 out1.backward(grad_input1) 3942 3943 # bfloat16 compute 3944 op_bfp16 = op.bfloat16() 3945 input2 = input1.detach().bfloat16().requires_grad_() 3946 grad_input2 = grad_input1.bfloat16() 3947 out2 = op_bfp16(input2) 3948 out2.backward(grad_input2) 3949 3950 test_case.assertEqual(out1, out2, atol=prec, rtol=prec, exact_dtype=False) 3951 test_case.assertEqual(input1.grad.data, input2.grad.data, atol=prec, rtol=prec, exact_dtype=False) 3952 3953def _test_module_empty_input(test_case, module, inp, check_size=True, inference=False): 3954 if not inference: 3955 inp.requires_grad_(True) 3956 out = module(inp) 3957 if not inference: 3958 gO = torch.rand_like(out) 3959 out.backward(gO) 3960 if check_size: 3961 test_case.assertEqual(out.size(), inp.size()) 3962 if not inference: 3963 for p in module.parameters(): 3964 if p.requires_grad: 3965 test_case.assertEqual(p.grad, torch.zeros_like(p.grad)) 3966 test_case.assertEqual(inp.grad, torch.zeros_like(inp)) 3967 3968 3969def _create_basic_net(): 3970 class Layer(nn.Module): 3971 def __init__(self) -> None: 3972 super().__init__() 3973 self.layer_dummy_param = nn.Parameter(torch.empty(3, 5)) 3974 self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7)) 3975 3976 class Net(nn.Module): 3977 def __init__(self) -> None: 3978 super().__init__() 3979 self.l1 = Layer() 3980 self.dummy_param = nn.Parameter(torch.empty(3, 5)) 3981 self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1)) 3982 3983 l = Layer() 3984 n = Net() 3985 s = nn.Sequential(n, n) 3986 3987 return l, n, s 3988