# mypy: ignore-errors

import torch
from functools import partial
from torch.testing import make_tensor
from torch.testing._internal.opinfo.core import (
    OpInfo,
    SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
import numpy as np

# Note: [autograd.Function db]
#
# This is a collection of autograd.Function test cases written as OpInfos
# so they can easily be consumed by OpInfo-based tests to check if a subsystem
# supports autograd.Function.
# See the usage sketch at the end of this file for how the entries are consumed.
#
# Axes:
# - saves {output, input, intermediate, non-tensor}
# - {inputs, output} x {single tensor, tensors, arbitrary objects}
# - Uses {mark_dirty, mark_non_differentiable, once_differentiable}


def to_numpy(tensor):
    return tensor.cpu().numpy()


class NumpyCube(torch.autograd.Function):
    @staticmethod
    def forward(input):
        input_np = to_numpy(input)
        dinput = torch.tensor(3 * input_np ** 2, device=input.device)
        return torch.tensor(input_np ** 3, device=input.device), dinput

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(inputs[0], output[1])
        ctx.save_for_forward(inputs[0], output[1])

    @staticmethod
    def backward(ctx, grad_output, grad_saved):
        input, dinput = ctx.saved_tensors
        return NumpyMul.apply(grad_output, dinput) + 6 * NumpyMul.apply(grad_saved, input)

    @staticmethod
    def vmap(info, in_dims, input):
        result = NumpyCube.apply(input)
        return result, (in_dims[0], in_dims[0])

    @staticmethod
    def jvp(ctx, input_tangent):
        input, dinput = ctx.saved_tensors
        return NumpyMul.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)


class CubeGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x ** 3, 3 * x ** 2

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        ctx.save_for_backward(inputs[0], outputs[1])
        ctx.save_for_forward(inputs[0], outputs[1])

    @staticmethod
    def backward(ctx, grad_output, grad_saved):
        input, dinput = ctx.saved_tensors
        result = grad_output * dinput + 6 * grad_saved * input
        return result

    @staticmethod
    def jvp(ctx, input_tangent):
        input, dinput = ctx.saved_tensors
        return MulGenVmap.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)


def sample_inputs_numpy_cube(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(1, low=0.8, high=2), args=())


class NumpyCubeNotComposable(torch.autograd.Function):
    @staticmethod
    def forward(input):
        input_np = to_numpy(input)
        return torch.tensor(input_np ** 3, device=input.device), input_np

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, input_np = output
        ctx.input_np = input_np
        ctx.device = inputs[0].device

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output, grad_saved):
        result_np = 3 * (ctx.input_np ** 2)
        return torch.tensor(result_np, device=ctx.device)


class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(*inputs)
        ctx.save_for_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if ctx.needs_input_grad[0]:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if ctx.needs_input_grad[1]:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0

    @staticmethod
    def jvp(ctx, x_tangent, y_tangent):
        x, y = ctx.saved_tensors
        return x_tangent * y + y_tangent * x

def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # Broadcasting
    yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),))

def sample_inputs_numpy_mul_scalar(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(4, low=0.9, high=2), args=(), kwargs={"scalar": 3.14})

class MulGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, y):
        return x * y

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        ctx.save_for_backward(*inputs)
        ctx.save_for_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if ctx.needs_input_grad[0]:
            gx = MulGenVmap.apply(grad_output, y)
        gy = None
        if ctx.needs_input_grad[1]:
            gy = MulGenVmap.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def jvp(ctx, x_tangent, y_tangent):
        x, y = ctx.saved_tensors
        return x_tangent * y + y_tangent * x


class NumpyExp_(torch.autograd.Function):
    @staticmethod
    def forward(x):
        x_np = to_numpy(x)
        np.exp(x_np, x_np)
        return x

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.mark_dirty(x)
        ctx.save_for_backward(output)
        ctx.save_for_forward(output)

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        return NumpyMul.apply(grad_output, output)

    @staticmethod
    def vmap(info, in_dims, x):
        NumpyExp_.apply(x)
        return x, in_dims[0]

    @staticmethod
    def jvp(ctx, x_tangent):
        # Doesn't call numpy operations because I didn't want to write NumpyMul_
        output, = ctx.saved_tensors
        x_tangent.mul_(output)
        return x_tangent

class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    @staticmethod
    def vmap(info, in_dims, x, dim):
        x_bdim, _ = in_dims
        x = x.movedim(x_bdim, 0)
        # wrap dim
        dim = dim if dim >= 0 else dim + x.dim() - 1
        return NumpySort.apply(x, dim + 1), (0, 0, 0)

    @staticmethod
    def jvp(ctx, x_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim), None, None

class SortGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, dim):
        device = x.device
        ind = torch.argsort(x, dim=dim)
        ind_inv = torch.argsort(ind, axis=dim)
        result = torch.take_along_dim(x, ind, dim=dim)
        return result, ind, ind_inv

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, dim = inputs
        _, ind, ind_inv = outputs
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim), None

    @staticmethod
    def jvp(ctx, x_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim), None, None


def sample_inputs_numpy_sort(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(3, 5), args=(1,))


def sample_inputs_numpy_take(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    tensor = make_arg(3, 5)
    dim = 1
    _, ind, ind_inv = NumpySort.apply(tensor, 1)
    yield SampleInput(tensor, args=(ind, ind_inv, dim))


class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

        # wrap dim
        logical_dim = x.dim() if x_bdim is None else x.dim() - 1
        dim = dim if dim >= 0 else dim + logical_dim

        def expand_bdim(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)

        x = expand_bdim(x, x_bdim)
        ind = expand_bdim(ind, ind_bdim)
        ind_inv = expand_bdim(ind_inv, ind_inv_bdim)

        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0

    @staticmethod
    def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
        assert ind_tangent is None
        assert ind_inv_tangent is None
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim)

class TakeGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, ind, ind_inv, dim):
        return torch.take_along_dim(x, ind, dim)

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim)

class Select(torch.autograd.Function):
    @staticmethod
    def forward(x, idx):
        return x[idx]

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, idx = inputs
        ctx.x_shape = x.shape
        ctx.idx = idx

    @staticmethod
    def backward(ctx, grad_output):
        result = grad_output.new_zeros(ctx.x_shape)
        result[ctx.idx] = grad_output
        return result, None

    @staticmethod
    def vmap(info, in_dims, x, idx):
        x_bdim, _ = in_dims
        x = x.movedim(x_bdim, 1)
        return Select.apply(x, idx), 0

    @staticmethod
    def jvp(ctx, x_tangent, _):
        return Select.apply(x_tangent, ctx.idx)

class SelectGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, idx):
        return x[idx]

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, idx = inputs
        ctx.x_shape = x.shape
        ctx.idx = idx

    @staticmethod
    def backward(ctx, grad_output):
        result = grad_output.new_zeros(ctx.x_shape)
        result[ctx.idx] = grad_output
        return result, None

    @staticmethod
    def jvp(ctx, x_tangent, _):
        return SelectGenVmap.apply(x_tangent, ctx.idx)


def sample_inputs_select(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(3, 5), args=(2,))

class ScaleGradGenVmap(torch.autograd.Function):
    generate_vmap_rule = True
    scale = 3.14

    @staticmethod
    def forward(x):
        return x.clone()

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ScaleGradGenVmap.scale

    @staticmethod
    def jvp(ctx, x_tangent):
        return x_tangent * ScaleGradGenVmap.scale

class ZeroGradientsGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, y):
        return x.clone(), y.clone()

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        pass

    @staticmethod
    def backward(ctx, gx, gy):
        # Intentionally returning torch.zeros instead of zeros_like or new_zeros.
        # Also intentionally not None.
        return (
            # Intentionally too-large gradient
            torch.zeros(3, 4, *gx.shape, dtype=gx.dtype, device=gx.device),
            torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
        )

    @staticmethod
    def jvp(ctx, gx, gy):
        # Intentionally returning torch.zeros instead of zeros_like or new_zeros.
        # Also intentionally not None.
        return (
            torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device),
            torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
        )


def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(3, 5))


class ForwardHasDefaultArgs(torch.autograd.Function):
    @staticmethod
    def forward(x, idx=(2,)):
        return x[idx]

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, idx = inputs
        ctx.x_shape = x.shape
        ctx.idx = idx

    @staticmethod
    def backward(ctx, grad_output):
        result = grad_output.new_zeros(ctx.x_shape)
        result[ctx.idx] = grad_output
        return result, None

    @staticmethod
    def vmap(info, in_dims, x, idx):
        x_bdim, _ = in_dims
        x = x.movedim(x_bdim, 1)
        return ForwardHasDefaultArgs.apply(x, idx), 0

    @staticmethod
    def jvp(ctx, x_tangent, _):
        return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx)


autograd_function_db = [
    OpInfo(
        'NumpyCubeAutogradFunction',
        op=NumpyCube.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyExpMarkDirtyAutogradFunction',
        op=lambda x: NumpyExp_.apply(x.clone()),
        inplace_variant=NumpyExp_.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyMulAutogradFunction',
        op=NumpyMul.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyCubeNotComposableAutogradFunction',
        op=lambda x: NumpyCubeNotComposable.apply(x)[0],
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpySortAutogradFunction',
        op=NumpySort.apply,
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        sample_inputs_func=sample_inputs_numpy_sort,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        gradcheck_wrapper=lambda y, ind: y,
    ),
    OpInfo(
        'NumpyTakeAutogradFunction',
        op=NumpyTake.apply,
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        sample_inputs_func=sample_inputs_numpy_take,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'SelectAutogradFunction',
        op=Select.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_select,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'CubeGenVmapAutogradFunction',
        op=CubeGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'MulGenVmapAutogradFunction',
        op=MulGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'SortGenVmapAutogradFunction',
        op=SortGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_sort,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        gradcheck_wrapper=lambda y, ind: y,
    ),
    OpInfo(
        'SelectGenVmapAutogradFunction',
        op=SelectGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_select,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'ScaleGradGenVmapAutogradFunction',
        op=ScaleGradGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'ZeroGradientsGenVmapAutogradFunction',
        op=ZeroGradientsGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'ForwardHasDefaultArgsAutogradFunction',
        op=ForwardHasDefaultArgs.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_forward_default_args,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
]
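

# Usage sketch (illustrative only, not part of the db): roughly how an
# OpInfo-based test walks `autograd_function_db`. Real tests use the `ops`
# decorator from torch.testing._internal.common_device_type; the helper name
# below is hypothetical and the function is never called from this module.
def _example_iterate_autograd_function_db(device="cpu", dtype=torch.float32):
    outputs = []
    for opinfo in autograd_function_db:
        # Each entry's `op` wraps the autograd.Function's `.apply`; a
        # SampleInput carries the primary input plus extra args/kwargs.
        for sample in opinfo.sample_inputs(device, dtype, requires_grad=False):
            outputs.append(opinfo.op(sample.input, *sample.args, **sample.kwargs))
    return outputs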