# mypy: allow-untyped-defs
import collections
import functools
import warnings
from itertools import product
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing_extensions import deprecated

import torch
import torch.testing
from torch._vmap_internals import _vmap, vmap
from torch.overrides import is_tensor_like
from torch.types import _TensorOrTensors


# Note: `get_*_jacobian` functions are added here even though we didn't intend to make them public
# since they have been exposed since before we added `__all__` and we already maintain BC for them.
# We should eventually deprecate them and remove them from `__all__`.
__all__ = [
    "gradcheck",
    "gradgradcheck",
    "GradcheckError",
    "get_numerical_jacobian",
    "get_analytical_jacobian",
    "get_numerical_jacobian_wrt_specific_input",
]


class GradcheckError(RuntimeError):
    r"""Error raised by :func:`gradcheck` and :func:`gradgradcheck`."""


def _is_sparse_compressed_tensor(obj: torch.Tensor):
    return obj.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }


def _is_sparse_any_tensor(obj: torch.Tensor):
    return _is_sparse_compressed_tensor(obj) or obj.layout is torch.sparse_coo


def _is_float_or_complex_tensor(obj):
    return is_tensor_like(obj) and (obj.is_floating_point() or obj.is_complex())


def _allocate_jacobians_with_inputs(
    input_tensors: Tuple, numel_output
) -> Tuple[torch.Tensor, ...]:
    # Makes zero-filled tensors from inputs. For each tensor `t` in `input_tensors`
    # that requires grad, returns a new zero-filled tensor with height `t.numel()`
    # and width `numel_output`. Each new tensor will be strided and have the same
    # dtype and device as those of the corresponding input.
    out: List[torch.Tensor] = []
    for t in input_tensors:
        if _is_float_or_complex_tensor(t) and t.requires_grad:
            out.append(t.new_zeros((t.numel(), numel_output), layout=torch.strided))
    return tuple(out)


def _allocate_jacobians_with_outputs(
    output_tensors: Tuple, numel_input, dtype=None, device=None
) -> Tuple[torch.Tensor, ...]:
    # Makes zero-filled tensors from outputs. For each tensor `t` in
    # `output_tensors`, returns a new zero-filled tensor with height `numel_input`
    # and width `t.numel()`.
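    # Note: in both allocation helpers, entry [i, j] of an allocated Jacobian block
    # is later filled with d output_flat[j] / d input_flat[i]: rows follow the
    # flattened input and columns follow the flattened output.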
    out: List[torch.Tensor] = []
    options = {"dtype": dtype, "device": device, "layout": torch.strided}
    for t in output_tensors:
        if _is_float_or_complex_tensor(t):
            out.append(t.new_zeros((numel_input, t.numel()), **options))
    return tuple(out)


def _iter_tensors(
    x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False
) -> Iterable[torch.Tensor]:
    if is_tensor_like(x):
        # mypy doesn't narrow type of `x` to torch.Tensor
        if x.requires_grad or not only_requiring_grad:  # type: ignore[union-attr]
            yield x  # type: ignore[misc]
    elif isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        for elem in x:
            yield from _iter_tensors(elem, only_requiring_grad)


def _densify(x):
    # return a copy of sparse x with all unspecified elements
    # "replaced" with zero-valued elements
    if isinstance(x, (list, tuple)):
        return type(x)(map(_densify, x))
    elif not is_tensor_like(x) or x.layout in {torch.strided, torch._mkldnn}:  # type: ignore[attr-defined] # no attr _mkldnn
        return x
    elif x.layout is torch.sparse_coo:
        device = x.device
        indices_dtype = x._indices().dtype
        tmp = torch.ones(x.shape[: x.sparse_dim()], dtype=torch.int8, device=device)
        indices = tmp.nonzero().t().to(dtype=indices_dtype)
        values = torch.zeros(
            (tmp.numel(), *x.shape[x.sparse_dim() :]), dtype=x.dtype, device=device
        )
        x_coalesced = x.detach().coalesce()
        if x_coalesced.numel() > 0:
            stride = tmp.stride()
            flat_indices = (
                x_coalesced.indices()
                .mul(
                    torch.tensor(stride, dtype=indices_dtype, device=device).unsqueeze(
                        1
                    )
                )
                .sum(0)
            )
            values[flat_indices] = x_coalesced.values()
        return (
            torch.sparse_coo_tensor(indices, values, x.shape)
            ._coalesced_(True)
            .requires_grad_(x.requires_grad)
        )
    elif _is_sparse_compressed_tensor(x):
        blocksize = (
            x.values().shape[1:3]
            if x.layout in {torch.sparse_bsr, torch.sparse_bsc}
            else None
        )
        compressed_indices = (
            x.crow_indices()
            if x.layout in {torch.sparse_csr, torch.sparse_bsr}
            else x.ccol_indices()
        )
        # We'll use intermediate sparse COO for simplicity
        r = _densify(x.detach().to_sparse(layout=torch.sparse_coo)).to_sparse(
            layout=x.layout, blocksize=blocksize
        )
        # Check that all elements are specified also after `to_sparse` op:
        dense_numel = r.values().numel() // max(1, r.values().shape[0])
        batch_numel = compressed_indices.numel() // compressed_indices.shape[-1]
        sparse_numel = r.numel() // max(1, dense_numel * batch_numel)
        if sparse_numel != r._nnz():
            raise AssertionError(
                f"{x.layout} densify failed: expected nnz={sparse_numel} but got {r._nnz()}"
            )
        return r.requires_grad_(x.requires_grad)
    elif _is_sparse_any_tensor(x):
        raise NotImplementedError(x.layout)
    return x


def _iter_tensor(x_tensor):
    # (Only used for slow gradcheck) Returns a generator that yields the following
    # elements at each iteration:
    # 1) a tensor: the same tensor is returned across all iterations. The tensor
    #    is not the same as the original x_tensor as given as input - it is
    #    prepared so that it can be modified in-place. Depending on whether the
    #    input tensor is strided, sparse, or dense, the returned tensor may or may
    #    not share storage with x_tensor.
    # 2) a tuple of indices that can be used with advanced indexing (yielded in
    #    dictionary order)
    # 3) flattened index that will be used to index into the Jacobian tensor
    #
    # For a tensor t with size (2, 2), _iter_tensor yields:
    #     `x, (0, 0), 0`, `x, (0, 1), 1`, `x, (1, 0), 2`, `x, (1, 1), 3`
    #
    # where x is the t.data of the original tensor. Perturbing the entry of x
    # at index (1, 1) corresponds to the column with flattened index 3 of the
    # overall Jacobian matrix.
    if _is_sparse_any_tensor(x_tensor):

        def get_stride(size):
            dim = len(size)
            tmp = 1
            stride = [0] * dim
            for i in reversed(range(dim)):
                stride[i] = tmp
                tmp *= size[i]
            return stride

        x_nnz = x_tensor._nnz()
        x_size = list(x_tensor.size())
        if x_tensor.layout is torch.sparse_coo:
            x_indices = x_tensor._indices().t()
            x_values = x_tensor._values()
        elif x_tensor.layout is torch.sparse_csr:
            x_indices = torch._convert_indices_from_csr_to_coo(
                x_tensor.crow_indices(), x_tensor.col_indices()
            ).t()
            x_values = x_tensor.values()
        elif x_tensor.layout is torch.sparse_csc:
            x_indices = torch._convert_indices_from_csr_to_coo(
                x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
            ).t()
            x_values = x_tensor.values()
        elif x_tensor.layout is torch.sparse_bsr:
            x_block_values = x_tensor.values()
            x_blocksize = x_block_values.size()[1:3]
            x_indices = (
                torch._convert_indices_from_csr_to_coo(
                    x_tensor.crow_indices(), x_tensor.col_indices()
                )
                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
                .add_(
                    torch.stack(
                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
                    ).repeat(1, x_nnz)
                )
                .t()
            )
            x_values = x_block_values.flatten(0, 2)
            x_nnz = x_values.size(0)
        elif x_tensor.layout is torch.sparse_bsc:
            x_block_values = x_tensor.values()
            x_blocksize = x_block_values.size()[1:3]
            x_indices = (
                torch._convert_indices_from_csr_to_coo(
                    x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
                )
                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
                .add_(
                    torch.stack(
                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
                    ).repeat(1, x_nnz)
                )
                .t()
            )
            x_values = x_block_values.flatten(0, 2)
            x_nnz = x_values.size(0)
        else:
            raise NotImplementedError(f"_iter_tensor for {x_tensor.layout} input")
        x_stride = get_stride(x_size)
        # Use .data here to get around the version check
        x_values = x_values.data
        for i in range(x_nnz):
            x_value = x_values[i]
            for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
                indices = x_indices[i].tolist() + list(x_idx)
                d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
                yield x_value, x_idx, d_idx
    elif x_tensor.layout == torch._mkldnn:  # type: ignore[attr-defined]
        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
            # this is really inefficient, but without indexing implemented, there's
            # not really a better way than converting back and forth
            x_tensor_dense = x_tensor.to_dense()
            yield x_tensor_dense, x_idx, d_idx
    else:
        # Use .data here to get around the version check
        x_tensor = x_tensor.data
        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
            yield x_tensor, x_idx, d_idx
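

# Illustrative sketch (not used anywhere in this module): for a small strided
# tensor, `_iter_tensor` walks the elements in row-major order and reports the
# flattened index of the Jacobian column that perturbing each element corresponds to.
def _example_iter_tensor_sketch():
    t = torch.zeros(2, 2, dtype=torch.float64)
    # Prints (0, 0) 0, then (0, 1) 1, (1, 0) 2, (1, 1) 3
    for _x, idx, d_idx in _iter_tensor(t):
        print(idx, d_idx)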


def _get_numerical_jacobian(
    fn, inputs, outputs=None, target=None, eps=1e-3, is_forward_ad=False
) -> List[Tuple[torch.Tensor, ...]]:
    """Compute the numerical Jacobian of `fn(inputs)` with respect to `target`.

    If not specified, targets are the inputs. Returns M * N Jacobians where N is the
    number of tensors in target that require grad and M is the number of non-integral
    outputs.

    Args:
        fn: the function to compute the Jacobian for
        inputs: inputs to `fn`
        outputs: provide precomputed outputs to avoid one extra invocation of fn
        target: the Tensors with respect to which Jacobians are calculated
            (default=`inputs`)
        eps: the magnitude of the perturbation during finite differencing
            (default=`1e-3`)
        is_forward_ad: if this numerical jacobian is computed to be checked wrt
            forward AD gradients (this is used for error checking only)

    Returns:
        A list of N M-tuples of tensors

    Note that `target` may not even be part of `input` to `fn`, so please be
    **very careful** not to clone `target`.
    """
    jacobians: List[Tuple[torch.Tensor, ...]] = []
    if outputs is None:
        outputs = _as_tuple(fn(*_as_tuple(inputs)))
    if not is_forward_ad and any(o.is_complex() for o in outputs):
        raise ValueError(
            "Expected output to be non-complex. get_numerical_jacobian no "
            "longer supports functions that return complex outputs."
        )
    if target is None:
        target = inputs
    inp_indices = [
        i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad
    ]
    for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)):
        jacobians += [
            get_numerical_jacobian_wrt_specific_input(
                fn,
                inp_idx,
                inputs,
                outputs,
                eps,
                input=inp,
                is_forward_ad=is_forward_ad,
            )
        ]
    return jacobians


@deprecated(
    "`get_numerical_jacobian` was part of PyTorch's private API and not "
    "meant to be exposed. We are deprecating it and it will be removed "
    "in a future version of PyTorch. If you have a specific use for "
    "this or feature request for this to be a stable API, please file "
    "us an issue at https://github.com/pytorch/pytorch/issues/new",
    category=FutureWarning,
)
def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0):
    """Compute the numerical Jacobian for a given fn and its inputs.

    This is a deprecated API.

    Args:
        fn: the function to compute the Jacobian for (must take inputs as a tuple)
        inputs: inputs to `fn`
        target: the Tensors with respect to which Jacobians are calculated
            (default=`inputs`)
        eps: the magnitude of the perturbation during finite differencing
            (default=`1e-3`)

    Returns:
        A list of Jacobians of `fn` (restricted to its first output) with respect to
        each input or target, if provided.

    Note that `target` may not even be part of `input` to `fn`, so please be
    **very careful** not to clone `target`.
    """
    if (
        grad_out != 1.0
    ):  # grad_out param is only kept for backward compatibility reasons
        raise ValueError(
            "Expected grad_out to be 1.0. get_numerical_jacobian no longer "
            "supports values of grad_out != 1.0."
        )

    def fn_pack_inps(*inps):
        return fn(inps)

    jacobians = _get_numerical_jacobian(fn_pack_inps, inputs, None, target, eps)

    return tuple(jacobian_for_each_output[0] for jacobian_for_each_output in jacobians)


def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):
    # Computes the numerical directional derivative as a central finite difference
    # of function `fn` at input `entry`, perturbed by vector `v`.
    if _is_sparse_compressed_tensor(entry):
        # sparse compressed tensors don't implement sub/add/copy_
        # yet. However, in non-masked semantics context entry and v
        # have the same sparse indices ...
        assert entry.layout == v.layout, (entry.layout, v.layout)
        assert entry._nnz() == v._nnz(), (entry._nnz(), v._nnz(), entry.shape)
        # ... the finite differencing can be performed on values only:
        entry = entry.values()
        v = v.values()
        # we'll detach to avoid backward computations that sparse
        # tensors have limited support for.
        entry = entry.detach()

    orig = entry.clone()
    entry.copy_(orig - v)
    outa = fn()
    entry.copy_(orig + v)
    outb = fn()
    entry.copy_(orig)

    def compute(a, b):
        nbhd_checks_fn(a, b)
        ret = (b - a) / (2 * norm_v)  # use central difference approx
        return ret.detach().reshape(-1)

    return tuple(compute(a, b) for (a, b) in zip(outa, outb))


def _compute_numerical_jvps_wrt_specific_input(
    jvp_fn, delta, input_is_complex, is_forward_ad=False
) -> List[torch.Tensor]:
    # Computing the jacobian only works for real delta
    # For details on the algorithm used here, refer:
    # Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf
    # s = fn(z) where z = x for real valued input
    # and z = x + yj for complex valued input
    jvps: List[torch.Tensor] = []
    ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta)

    if input_is_complex:  # C -> R
        ds_dy_tup = (
            jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j)
        )
        for ds_dx, ds_dy in zip(ds_dx_tup, ds_dy_tup):
            assert not ds_dx.is_complex()
            # conjugate Wirtinger derivative
            conj_w_d = ds_dx + ds_dy * 1j
            jvps.append(conj_w_d)
    else:
        for ds_dx in ds_dx_tup:  # R -> R or (R -> C for the forward AD case)
            assert is_forward_ad or not ds_dx.is_complex()
            jvps.append(ds_dx)
    return jvps


def _combine_jacobian_cols(
    jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input, numel
) -> Tuple[torch.Tensor, ...]:
    # jacobians_cols maps column_idx -> output_idx -> single column of jacobian Tensor
    # we return a tuple that maps output_idx -> full jacobian Tensor
    jacobians = _allocate_jacobians_with_outputs(
        outputs, numel, dtype=input.dtype if input.dtype.is_complex else None
    )
    for i, jacobian in enumerate(jacobians):
        for k, v in jacobians_cols.items():
            jacobian[k] = v[i]
    return jacobians


def _prepare_input(
    input: torch.Tensor, maybe_perturbed_input: Optional[torch.Tensor], fast_mode=False
) -> torch.Tensor:
    # Prepares the inputs to be passed into the function while including the new
    # modified input.
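    # Note: for mkldnn inputs the perturbed dense copy has to be converted back;
    # for sparse inputs in fast mode the perturbed clone is passed through as-is;
    # for strided inputs the perturbation was applied in-place to `input.data`,
    # which shares storage with `input`, so `input` itself is returned.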
    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
        # Convert back to mkldnn
        if maybe_perturbed_input is not None:
            return maybe_perturbed_input.to_mkldnn()
        else:
            return input
    elif _is_sparse_any_tensor(input):
        if fast_mode and maybe_perturbed_input is not None:
            # entry is already a "cloned" version of the original tensor
            # thus changes to entry are not reflected in the input
            return maybe_perturbed_input
        else:
            return input
    else:
        # We cannot use entry (input.data) if we want gradgrad to work because
        # fn (in the gradgrad case) needs to compute grad wrt input
        return input


def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None:
    # Check that the returned outputs don't have different dtype or shape when you
    # perturb the input
    on_index = f"on index {idx} " if idx is not None else ""
    assert output1.shape == output2.shape, (
        f"Expected `func` to return outputs with the same shape"
        f" when inputs are perturbed {on_index}by {eps}, but got:"
        f" shapes {output1.shape} and {output2.shape}."
    )
    assert output1.dtype == output2.dtype, (
        f"Expected `func` to return outputs with the same dtype"
        f" when inputs are perturbed {on_index}by {eps}, but got:"
        f" dtypes {output1.dtype} and {output2.dtype}."
    )


def get_numerical_jacobian_wrt_specific_input(
    fn, input_idx, inputs, outputs, eps, input=None, is_forward_ad=False
) -> Tuple[torch.Tensor, ...]:
    # Computes the numerical Jacobians wrt a single input. Returns N Jacobian
    # tensors, where N is the number of outputs. We use a dictionary for
    # jacobian_cols because indices aren't necessarily consecutive for sparse inputs.
    # When we perturb only a single element of the input tensor at a time, the jvp
    # is equivalent to a single col of the Jacobian matrix of fn.
    jacobian_cols: Dict[int, List[torch.Tensor]] = {}
    input = inputs[input_idx] if input is None else input
    assert input.requires_grad
    for x, idx, d_idx in _iter_tensor(input):
        wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, x)
        input_to_perturb = x[idx]
        nbhd_checks_fn = functools.partial(
            _check_outputs_same_dtype_and_shape, idx=idx, eps=eps
        )
        jvp_fn = _get_numerical_jvp_fn(
            wrapped_fn, input_to_perturb, eps, nbhd_checks_fn
        )
        jacobian_cols[d_idx] = _compute_numerical_jvps_wrt_specific_input(
            jvp_fn, eps, x.is_complex(), is_forward_ad
        )
    return _combine_jacobian_cols(jacobian_cols, outputs, input, input.numel())


def _get_analytical_jacobian_forward_ad(
    fn, inputs, outputs, *, check_grad_dtypes=False, all_u=None
) -> Tuple[Tuple[torch.Tensor, ...], ...]:
    """Compute the analytical Jacobian of `fn(inputs)` using forward mode AD with respect to `target`.

    Return N * M Jacobians where N is the number of tensors in target that require grad and
    M is the number of non-integral outputs.
    Contrary to other functions here, this function requires "inputs" to actually be used by the function.
    The computed value is expected to be wrong if the function captures the inputs by side effect instead of
    using the passed ones (many torch.nn tests do this).

    Args:
        fn: the function to compute the Jacobian for
        inputs: inputs to `fn`
        outputs: provide precomputed outputs to avoid one extra invocation of fn
        check_grad_dtypes: if True, will check that the gradient dtypes are valid
        all_u (optional): if provided, the Jacobian will be right-multiplied by this vector

    Returns:
        A tuple of N M-tuples of tensors
    """
    # To avoid early import issues
    fwAD = torch.autograd.forward_ad

    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)

    if any(i.is_complex() for i in tensor_inputs):
        raise ValueError(
            "Expected inputs to be non-complex for _get_analytical_jacobian_forward_ad."
        )

    if all_u:
        jacobians = tuple(
            _allocate_jacobians_with_outputs(outputs, 1) for i in tensor_inputs
        )
    else:
        jacobians = tuple(
            _allocate_jacobians_with_outputs(outputs, i.numel()) for i in tensor_inputs
        )

    with fwAD.dual_level():
        fw_grads = []
        dual_inputs = []
        for i, inp in enumerate(inputs):
            if is_tensor_like(inp) and inp.requires_grad:
                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                    raise ValueError(
                        "MKLDNN inputs are not supported for forward AD gradcheck."
                    )

                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
                # If inp is a differentiable view, the dual might not be the tangent given to
                # make_dual, so read it explicitly from the dual tensor
                fw_grads.append(fwAD.unpack_dual(inp)[1])
                dual_inputs.append(inp)

        if all_u:
            # Do the full reduction in one pass
            # To be consistent with numerical evaluation, we actually compute one reduction per input
            for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
                fw_grad.copy_(u.view_as(fw_grad))
                raw_outputs = _as_tuple(fn(*dual_inputs))
                dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
                for index_o, d_o in enumerate(dual_outputs):
                    val, res = fwAD.unpack_dual(d_o)
                    if (
                        check_grad_dtypes
                        and res is not None
                        and val.is_complex() != res.is_complex()
                    ):
                        raise GradcheckError("Forward AD gradient has dtype mismatch.")

                    # Remove extra dimension of size 1 corresponding to the reduced input
                    jacobians[i][index_o].squeeze_(0)
                    if res is None:
                        jacobians[i][index_o].zero_()
                    else:
                        jacobians[i][index_o].copy_(res.reshape(-1))
                fw_grad.zero_()
        else:
            # Reconstruct the full Jacobian column by column
            for i, fw_grad in enumerate(fw_grads):
                for lin_idx, grad_idx in enumerate(
                    product(*[range(m) for m in fw_grad.size()])
                ):
                    fw_grad[grad_idx] = 1.0
                    raw_outputs = _as_tuple(fn(*dual_inputs))
                    dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
                    for index_o, d_o in enumerate(dual_outputs):
                        val, res = fwAD.unpack_dual(d_o)
                        if (
                            check_grad_dtypes
                            and res is not None
                            and val.is_complex() != res.is_complex()
                        ):
                            raise GradcheckError(
                                "Forward AD gradient has dtype mismatch."
                            )

                        if res is None:
                            jacobians[i][index_o][lin_idx].zero_()
                        else:
                            jacobians[i][index_o][lin_idx].copy_(res.reshape(-1))
                    fw_grad[grad_idx] = 0.0

    return jacobians


def _get_input_to_perturb(input):
    # Prepare the input so that it can be modified in-place and do certain
    # operations that require the tensor to have strides. If fast_mode=False,
    # _iter_tensor would handle the below cases:
    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
        # Convert to dense so we can perform operations that require strided tensors
        input_to_perturb = input.to_dense()
    elif _is_sparse_any_tensor(input):
        # Clone because input may require grad, and copy_ calls resize_,
        # which is not allowed for .data
        input_to_perturb = input.clone()
    else:
        input_to_perturb = input.data
    return input_to_perturb


def _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, fast_mode=False):
    # Wraps `fn` so that its inputs are already supplied
    def wrapped_fn():
        inp = tuple(
            _prepare_input(a, input_to_perturb if i == input_idx else None, fast_mode)
            if is_tensor_like(a)
            else a
            for i, a in enumerate(_as_tuple(inputs))
        )
        return tuple(a.clone() for a in _as_tuple(fn(*inp)))

    return wrapped_fn


def _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn):
    # Wraps jvp_fn so that certain arguments are already supplied
    def jvp_fn(delta):
        return _compute_numerical_gradient(
            wrapped_fn, input_to_perturb, delta, eps, nbhd_checks_fn
        )

    return jvp_fn


def _reshape_tensor_or_tuple(u, shape):
    # We don't need to reshape when the input corresponding to u is sparse
    if isinstance(u, tuple):
        if not _is_sparse_any_tensor(u[0]):
            return (u[0].reshape(shape), u[1].reshape(shape))
    else:
        if not _is_sparse_any_tensor(u):
            return u.reshape(shape)
    return u


def _mul_tensor_or_tuple(u, k):
    if isinstance(u, tuple):
        return (k * u[0], k * u[1])
    else:
        return k * u


def _get_numerical_jvp_wrt_specific_input(
    fn, input_idx, inputs, u, eps, is_forward_ad=False
) -> List[torch.Tensor]:
    input = inputs[input_idx]
    input_to_perturb = _get_input_to_perturb(input)
    wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, True)
    nbhd_checks_fn = functools.partial(_check_outputs_same_dtype_and_shape, eps=eps)
    jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn)
    u = _reshape_tensor_or_tuple(u, input_to_perturb.shape)
    u = _mul_tensor_or_tuple(u, eps)
    return _compute_numerical_jvps_wrt_specific_input(
        jvp_fn, u, input.is_complex(), is_forward_ad
    )


def _get_numerical_vJu(
    fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad
):
    # Note that all_v can also be None; in that case, this function only computes Ju.
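    # For each differentiable input (index `inp_idx`) and direction `u`, compute
    # the numerical J u, one tensor per output. When `all_v` is given (fast mode),
    # each J u is further reduced against the matching `v` to the scalar v^T J u,
    # so downstream checks only need to compare scalars.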
    reduced_jacobians: List[List[torch.Tensor]] = []
    for i, (inp_idx, u) in enumerate(zip(inp_indices, all_u)):
        all_Ju = _get_numerical_jvp_wrt_specific_input(
            fn, inp_idx, inputs, u, eps, is_forward_ad
        )
        # Filter out the Ju for non floating point outputs
        filtered_Ju = []
        func_out = _as_tuple(func_out)
        assert len(all_Ju) == len(func_out)
        for Ju, output in zip(all_Ju, func_out):
            if _is_float_or_complex_tensor(output):
                filtered_Ju.append(Ju)
            else:
                # TODO: handle the other Ju
                pass
        if all_v is not None:
            jacobian_scalars: List[torch.Tensor] = []
            for v, Ju in zip(all_v, filtered_Ju):
                jacobian_scalars.append(_dot_with_type_promotion(v, Ju))
            reduced_jacobians.append(jacobian_scalars)
        else:
            reduced_jacobians.append(filtered_Ju)
    return reduced_jacobians


def _check_jacobians_equal(j1, j2, atol):
    # Check whether the max difference between two Jacobian tensors is within some
    # tolerance `atol`.
    for j1_x, j2_x in zip(j1, j2):
        if j1_x.numel() != 0 and (j1_x - j2_x).abs().max() > atol:
            return False
    return True


def _stack_and_check_tensors(
    list_of_list_of_tensors, inputs, numel_outputs
) -> Tuple[Tuple[torch.Tensor, ...], bool, bool]:
    # For the ith tensor in the inner list, checks whether it has the same size and
    # dtype as the ith differentiable input.
    out_jacobians = _allocate_jacobians_with_inputs(inputs, numel_outputs)
    diff_input_list = list(_iter_tensors(inputs, True))
    correct_grad_sizes = True
    correct_grad_types = True
    for i, tensor_list in enumerate(list_of_list_of_tensors):
        inp = diff_input_list[i]
        out_jacobian = out_jacobians[i]
        for j, tensor in enumerate(tensor_list):
            if tensor is not None and tensor.size() != inp.size():
                correct_grad_sizes = False
            elif tensor is not None and tensor.dtype != inp.dtype:
                correct_grad_types = False
            if tensor is None:
                out_jacobian[:, j].zero_()
            else:
                dense = (
                    tensor.to_dense() if not tensor.layout == torch.strided else tensor
                )
                assert out_jacobian[:, j].numel() == dense.numel()
                out_jacobian[:, j] = dense.reshape(-1)
    return out_jacobians, correct_grad_types, correct_grad_sizes


FAILED_NONDET_MSG = """\n
NOTE: If your op relies on non-deterministic operations, i.e., it is listed here:
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
this failure might be expected.

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `nondet_tol=<tol>` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `gradcheck_nondet_tol=<tol>`.
- is a Module test (e.g., in common_nn.py), then modify the corresponding
  module_test entry to have `gradcheck_nondet_tol=<tol>`
"""


def _check_analytical_jacobian_attributes(
    inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None
) -> Tuple[torch.Tensor, ...]:
    # This is used by both fast and slow mode:
    # - For slow mode, vjps[i][j] is the jth row of the Jacobian wrt the ith
    #   input.
    # - For fast mode, vjps[i][0] is a linear combination of the rows
    #   of the Jacobian wrt the ith input
    diff_input_list = list(_iter_tensors(inputs, True))

    def vjp_fn(grad_output):
        return torch.autograd.grad(
            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
        )

    # Compute everything twice to check for nondeterminism (which we call reentrancy)
    if fast_mode:
        vjps1 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
        vjps2 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
    else:
        vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
        vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())

    output_numel = output.numel() if not fast_mode else 1
    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
        vjps1, inputs, output_numel
    )
    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)

    if not types_ok and check_grad_dtypes:
        raise GradcheckError("Gradient has dtype mismatch")
    if not sizes_ok:
        raise GradcheckError("Analytical gradient has incorrect size")
    if not reentrant:
        raise GradcheckError(
            "Backward is not reentrant, i.e., running backward with "
            "same input and grad_output multiple times gives different values, "
            "although analytical gradient matches numerical gradient. "
            f"The tolerance for nondeterminism was {nondet_tol}." + FAILED_NONDET_MSG
        )
    return jacobians1


def _get_analytical_vJu_backward_mode(
    inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u
):
    reduced_jacobians: List[List[torch.Tensor]] = []
    for output, v in zip(outputs, all_v):
        all_vJ = _check_analytical_jacobian_attributes(
            inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v
        )
        jacobian_scalars: List[torch.Tensor] = []
        for vJ, u in zip(all_vJ, all_u):
            # Why do we need squeeze here? vJ is a 2-d tensor so that we can reuse
            # the error checking logic from slow mode
            vJ = vJ.T.squeeze(0)
            if vJ.is_complex():  # C -> R
                tv = torch.view_as_real(vJ.resolve_conj())
                tr = tv.select(-1, 0)
                ti = tv.select(-1, 1)
                jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
            else:  # R -> R
                jacobian_scalars.append(vJ.dot(u))
        reduced_jacobians.append(jacobian_scalars)
    return reduced_jacobians


@deprecated(
    "`get_analytical_jacobian` was part of PyTorch's private API and not "
    "meant to be exposed. We are deprecating it and it will be removed "
    "in a future version of PyTorch. If you have a specific use for "
    "this or feature request for this to be a stable API, please file "
    "us an issue at https://github.com/pytorch/pytorch/issues/new",
    category=FutureWarning,
)
def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0):
    # Replicates the behavior of the old get_analytical_jacobian before the refactor.
    # This shares much of its code with _check_analytical_jacobian_attributes.
    if (
        grad_out != 1.0
    ):  # grad_out param is only kept for backward compatibility reasons
        raise ValueError(
            "Expected grad_out to be 1.0. get_analytical_jacobian no longer "
            "supports values of grad_out != 1.0."
        )
    if output.is_complex():
        raise ValueError(
            "Expected output to be non-complex. get_analytical_jacobian no "
            "longer supports functions that return complex outputs."
        )
    diff_input_list = list(_iter_tensors(inputs, True))

    def vjp_fn(grad_output):
        return torch.autograd.grad(
            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
        )

    # Compute everything twice to check for nondeterminism (which we call reentrancy)
    vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
    vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())

    output_numel = output.numel()
    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
        vjps1, inputs, output_numel
    )
    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)

    return jacobians1, reentrant, sizes_ok, types_ok


def _get_analytical_jacobian(inputs, outputs, input_idx, output_idx):
    # Computes the analytical Jacobian in slow mode for a single input-output pair.
    # Forgoes performing checks on dtype, shape, and reentrancy.
    jacobians = _check_analytical_jacobian_attributes(
        inputs, outputs[output_idx], nondet_tol=float("inf"), check_grad_dtypes=False
    )
    return jacobians[input_idx]


def _compute_analytical_jacobian_rows(
    vjp_fn, sample_output
) -> List[List[Optional[torch.Tensor]]]:
    # Computes the Jacobian row-by-row by projecting `vjp_fn` = v^T J onto standard
    # basis vectors: vjp_fn(e) = e^T J is the corresponding row of the Jacobian.
    # NB: this function does not assume vjp_fn(v) to return tensors with the same
    # number of elements for different v. This is checked when we later combine the
    # rows into a single tensor.
    grad_out_base = torch.zeros_like(
        sample_output, memory_format=torch.legacy_contiguous_format
    )
    flat_grad_out = grad_out_base.view(-1)
    # jacobians_rows[i][j] is the Jacobian's jth row for the ith input
    jacobians_rows: List[List[Optional[torch.Tensor]]] = []
    for j in range(flat_grad_out.numel()):
        flat_grad_out.zero_()
        flat_grad_out[j] = 1.0  # projection for jth row of Jacobian
        grad_inputs = vjp_fn(grad_out_base)
        for i, d_x in enumerate(grad_inputs):
            if j == 0:
                jacobians_rows.append([])
            jacobians_rows[i] += [
                d_x.clone() if isinstance(d_x, torch.Tensor) else None
            ]
    return jacobians_rows


def _get_analytical_vjps_wrt_specific_output(
    vjp_fn, sample_output, v
) -> List[List[Optional[torch.Tensor]]]:
    vjps: List[List[Optional[torch.Tensor]]] = []
    grad_inputs = vjp_fn(v.reshape(sample_output.shape))
    for vjp in grad_inputs:
        vjps.append([vjp.clone() if isinstance(vjp, torch.Tensor) else None])
    return vjps


def _check_inputs(tupled_inputs) -> bool:
    # Make sure that gradients are saved for at least one input
    any_input_requiring_grad = False
    for idx, inp in enumerate(tupled_inputs):
        if is_tensor_like(inp) and inp.requires_grad:
            if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128):
                warnings.warn(
                    f"Input #{idx} requires gradient and "
                    "is not a double precision floating point or complex tensor. "
                    "This check will likely fail if all the inputs are "
                    "not of double precision floating point or complex dtype. "
                )
            if inp.is_sparse:
                content = inp._values()
            elif _is_sparse_compressed_tensor(inp):
                content = inp.values()
            else:
                content = inp
            # TODO: To cover more problematic cases, replace stride = 0 check with
            # "any overlap in memory" once we have a proper function to check it.
            if content.layout is not torch._mkldnn:  # type: ignore[attr-defined]
                if not all(
                    st > 0 or sz <= 1
                    for st, sz in zip(content.stride(), content.size())
                ):
                    raise RuntimeError(
                        f"The {idx}th input has a dimension with stride 0. gradcheck only "
                        "supports inputs that are non-overlapping to be able to "
                        "compute the numerical gradients correctly. You should call "
                        ".contiguous() on the input before passing it to gradcheck."
                    )
            any_input_requiring_grad = True

    if not any_input_requiring_grad:
        raise ValueError(
            "gradcheck expects at least one input tensor to require gradient, "
            "but none of them have requires_grad=True."
        )
    return True


def _check_outputs(outputs) -> None:
    if any(_is_sparse_any_tensor(t) for t in outputs if isinstance(t, torch.Tensor)):
        # it is easier to call to_dense() on the sparse output than
        # to modify the analytical jacobian
        raise ValueError(
            "Sparse output is not supported by gradcheck yet. "
            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
        )
    if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)):  # type: ignore[attr-defined]
        raise ValueError(
            "MKLDNN output is not supported by gradcheck yet. "
            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
        )


def _check_no_differentiable_outputs(
    func, inputs, func_out, eps, *, is_forward_ad
) -> bool:
    # When there are no differentiable outputs, the numerical gradient for a
    # function is expected to be zero.
    jacobians_all_inputs_outputs = _get_numerical_jacobian(
        func, inputs, func_out, eps=eps, is_forward_ad=is_forward_ad
    )
    for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs:
        for jacobian in jacobians_all_outputs_and_fixed_input:
            if torch.ne(jacobian, 0).sum() > 0:
                raise GradcheckError(
                    "Numerical gradient for function expected to be zero"
                )
    return True


def _check_no_differentiable_outputs_fast(
    func, func_out, all_inputs, inputs_indices, all_u, eps, nondet_tol
):
    for inp_idx, u in zip(inputs_indices, all_u):
        jvps = _get_numerical_jvp_wrt_specific_input(func, inp_idx, all_inputs, u, eps)
        for jvp in jvps:
            if jvp.numel() == 0:
                continue
            if (jvp - torch.zeros_like(jvp)).abs().max() > nondet_tol:
                raise GradcheckError(
                    "Numerical gradient for function expected to be zero"
                )
    return True


FAILED_BATCHED_GRAD_MSG = """
gradcheck or gradgradcheck failed while testing batched gradient computation.
This could have been invoked in a number of ways (via a test that calls
gradcheck/gradgradcheck directly or via an autogenerated test).

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `check_batched_grad=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `check_batched_grad=False` and/or `check_batched_gradgrad=False`.

If you're modifying an existing operator that supports batched grad computation,
or wish to make a new operator work with batched grad computation, please read
the following.

To compute batched grads (e.g., jacobians, hessians), we vmap over the backward
computation. The most common failure case is if there is a 'vmap-incompatible
operation' in the backward pass. Please see
NOTE: [How to write vmap-compatible backward formulas]
in the codebase for an explanation of how to fix this.
""".strip()

FAILED_BATCHED_GRAD_MSG_FWD_AD = """
gradcheck failed while testing batched gradient computation with forward-mode AD.
This test is enabled automatically when both `check_batched_grad=True`
and `check_forward_ad=True`, but can be disabled in the following ways
depending on how the test was invoked (via a test that calls gradcheck
directly or via an autogenerated test).

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `check_batched_forward_grad=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test
  to have `check_batched_forward_grad=False`
"""


def _get_failed_batched_grad_test_msg(
    output_idx, input_idx, res, exp, is_forward_ad=False
):
    return f"""
For output {output_idx} and input {input_idx}:

{FAILED_BATCHED_GRAD_MSG_FWD_AD if is_forward_ad else FAILED_BATCHED_GRAD_MSG}

Got:
{res}

Expected:
{exp}
""".strip()


def _test_batched_grad_forward_ad(func, inputs) -> bool:
    fwAD = torch.autograd.forward_ad  # To avoid early import issues (do we need this?)
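    # Strategy: for each differentiable input, compute the forward-AD jvp for two
    # random tangents one at a time, then compare against a single vmap(jvp) call
    # over the stacked tangents; a mismatch indicates a vmap-incompatible formula.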
    assert isinstance(inputs, tuple)

    for input_idx, current_input in enumerate(inputs):
        if not (is_tensor_like(current_input) and current_input.requires_grad):
            continue

        def jvp(tangent: torch.Tensor):
            with fwAD.dual_level():
                dual = fwAD.make_dual(current_input.detach(), tangent)
                inputs_with_dual = tuple(
                    dual
                    if idx == input_idx
                    else (inp.detach() if is_tensor_like(inp) else inp)
                    for idx, inp in enumerate(inputs)
                )
                dual_outputs = _as_tuple(func(*inputs_with_dual))
                ret = []
                for dual_output in dual_outputs:
                    if dual_output is None:
                        continue
                    primal_out, tangent_out = fwAD.unpack_dual(dual_output)
                    if tangent_out is not None:
                        ret.append(tangent_out)
                    else:
                        ret.append(
                            torch.zeros(
                                [], dtype=primal_out.dtype, device=primal_out.device
                            ).expand(primal_out.shape)
                        )
                return tuple(ret)

        if not _is_float_or_complex_tensor(current_input):
            continue

        tangents = [torch.randn_like(current_input) for _ in range(2)]
        expected = [jvp(t) for t in tangents]
        expected = [torch.stack(shards) for shards in zip(*expected)]

        try:
            result = _vmap(jvp)(torch.stack(tangents))
        except RuntimeError as ex:
            # Rethrow to provide a better error message
            raise GradcheckError(
                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}"
            ) from ex

        for input_idx, (res, exp) in enumerate(zip(result, expected)):
            if torch.allclose(res, exp):
                continue
            raise GradcheckError(
                _get_failed_batched_grad_test_msg(
                    input_idx, input_idx, res, exp, is_forward_ad=True
                )
            )
    return True


def _test_batched_grad(input, output, output_idx) -> bool:
    # NB: _test_batched_grad compares two autograd.grad invocations with a single
    # vmap(autograd.grad) invocation. It's not exactly a "gradcheck" in the
    # sense that we're not comparing an analytical jacobian with a numeric one,
    # but it is morally similar (we could have computed a full analytic jac
    # via vmap, but that is potentially slow)
    diff_input_list = list(_iter_tensors(input, True))
    grad = functools.partial(
        torch.autograd.grad,
        output,
        diff_input_list,
        retain_graph=True,
        allow_unused=True,
    )

    def vjp(v):
        results = grad(v)
        results = tuple(
            grad
            if grad is not None
            else torch.zeros([], dtype=inp.dtype, device=inp.device).expand(inp.shape)
            for grad, inp in zip(results, diff_input_list)
        )
        return results

    grad_outputs = [torch.randn_like(output) for _ in range(2)]

    expected = [vjp(gO) for gO in grad_outputs]
    expected = [torch.stack(shards) for shards in zip(*expected)]

    # Squash warnings since these are expected to happen in most cases
    # NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="There is a performance drop")
        warnings.filterwarnings("ignore", message="Please use torch.vmap")
        try:
            result = vmap(vjp)(torch.stack(grad_outputs))
        except RuntimeError as ex:
            # It's OK that we're not raising the error at the correct callsite.
            # That's because the callsite is always going to be inside the Python
            # autograd.grad call rather than in the C++ traceback that would show
            # which line in the backward formula failed.
            raise GradcheckError(
                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}"
            ) from ex

    for input_idx, (res, exp) in enumerate(zip(result, expected)):
        if torch.allclose(res, exp):
            continue
        raise GradcheckError(
            _get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp)
        )
    return True


def _test_backward_mul_by_grad_output(outputs, inputs, masked) -> bool:
    # Tests that backward is multiplied by grad_output
    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
    if not diff_input_list:
        raise GradcheckError("no Tensors requiring grad found in input")
    grads_input = torch.autograd.grad(
        outputs,
        diff_input_list,
        [
            torch.zeros_like(o, memory_format=torch.legacy_contiguous_format)
            for o in outputs
        ],
        allow_unused=True,
    )
    for gi, di in zip(grads_input, diff_input_list):
        if gi is None:
            continue
        if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
            if gi.layout != di.layout:
                raise GradcheckError(
                    "grad is incorrect layout ("
                    + str(gi.layout)
                    + " is not "
                    + str(di.layout)
                    + ")"
                )
            if _is_sparse_any_tensor(gi):
                sparse_kind = str(gi.layout).replace("torch.", "").replace("_coo", "")
                if gi.sparse_dim() != di.sparse_dim():
                    raise GradcheckError(
                        f"grad is {sparse_kind} tensor, but has incorrect sparse_dim"
                        f" {gi.sparse_dim()}, expected {di.sparse_dim()}"
                    )
                if gi.dense_dim() != di.dense_dim():
                    raise GradcheckError(
                        f"grad is {sparse_kind} tensor, but has incorrect dense_dim"
                        f" {gi.dense_dim()}, expected {di.dense_dim()}"
                    )
            gi = gi.to_dense()
            di = di.to_dense()
        if masked:
            if not torch.allclose(gi, torch.zeros_like(gi)):
                raise GradcheckError("backward not multiplied by grad_output")
        elif not gi.eq(0).all():
            raise GradcheckError("backward not multiplied by grad_output")
        if gi.dtype != di.dtype:
            raise GradcheckError("grad is incorrect type")
        if gi.device != di.device:
            raise GradcheckError("grad is incorrect device")
        if gi.size() != di.size():
            raise GradcheckError("grad is incorrect size")
    return True


def _test_undefined_forward_mode(func, outputs, inputs):
    fwAD = torch.autograd.forward_ad

    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
    all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=True)

    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)

    with fwAD.dual_level():
        fw_grads = []
        dual_inputs = []
        tensor_indices = set()
        for i, inp in enumerate(inputs):
            if is_tensor_like(inp) and inp.requires_grad:
                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                    raise ValueError(
                        "MKLDNN inputs are not supported for forward AD gradcheck."
                    )

                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
                # If inp is a differentiable view, the dual might not be the tangent given to
                # make_dual, so read it explicitly from the dual tensor
                fw_grads.append(fwAD.unpack_dual(inp)[1])
                tensor_indices.add(i)
                dual_inputs.append(inp)

        for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
            fw_grad.copy_(u.view_as(fw_grad))

        for idx, inp in enumerate(inputs):
            if idx not in tensor_indices:
                continue
            dual_inp_obj = dual_inputs[idx]

            # case 1 (Materialized Zero Tensor Tangent)
            dual_inputs[idx] = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
            raw_outputs = _as_tuple(func(*dual_inputs))
            dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs)

            # case 2 (Efficient Zero Tensor Tangent since we don't make a dual object and pass a regular tensor)
            dual_inputs[idx] = inp.detach()
            raw_outputs = _as_tuple(func(*dual_inputs))
            dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs)

            # reset
            dual_inputs[idx] = dual_inp_obj

            for index_o, (d_o1, d_o2) in enumerate(zip(dual_outputs1, dual_outputs2)):
                val1, res1 = fwAD.unpack_dual(d_o1)
                val2, res2 = fwAD.unpack_dual(d_o2)

                if not (res1 is None or res2 is None):
                    if not torch.allclose(res1, res2):
                        raise GradcheckError(
                            f"Mismatch in tangent values for output with index {index_o} "
                            f"when input {inp} has an undefined tangent value: "
                            f"got {res1} but expected {res2}."
                        )
    return True


def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
    if not diff_input_list:
        raise GradcheckError("no Tensors requiring grad found in input")

    def warn_bc_breaking():
        warnings.warn(
            "Backwards compatibility: New undefined gradient support checking "
            "feature is enabled by default, but it may break existing callers "
            "of this function. If this is true for you, you can call this "
            'function with "check_undefined_grad=False" to disable the feature'
        )

    def check_undefined_grad_support(output_to_check):
        grads_output = [
            torch.zeros_like(o, memory_format=torch.legacy_contiguous_format)
            for o in output_to_check
        ]
        try:
            grads_input = torch.autograd.grad(
                output_to_check, diff_input_list, grads_output, allow_unused=True
            )
        except RuntimeError as e:
            warn_bc_breaking()
            raise GradcheckError(
                "Expected backward function to handle undefined output grads. "
                'Please look at "Notes about undefined output gradients" in '
                '"tools/autograd/derivatives.yaml"'
            ) from e

        for gi, i in zip(grads_input, diff_input_list):
            if (gi is not None) and (not gi.eq(0).all()):
                warn_bc_breaking()
                raise GradcheckError(
                    "Expected all input grads to be undefined or zero when all output grads are undefined "
                    'or zero. Please look at "Notes about undefined output gradients" in '
                    '"tools/autograd/derivatives.yaml"'
                )
        return True

    # All backward functions must work properly if all output grads are undefined
    outputs_to_check = [
        [
            torch._C._functions.UndefinedGrad()(o)
            for o in _differentiable_outputs(func(*inputs))
            # This check filters out Tensor-likes that aren't instances of Tensor.
            if isinstance(o, torch.Tensor)
        ]
    ]

    # If there are multiple output grads, we should be able to undef one at a time without error
    if len(outputs_to_check[0]) > 1:
        for undef_grad_idx in range(len(outputs)):
            output_to_check = _differentiable_outputs(func(*inputs))
            outputs_to_check.append(
                [
                    torch._C._functions.UndefinedGrad()(o)
                    if idx == undef_grad_idx
                    else o
                    for idx, o in enumerate(output_to_check)
                ]
            )

    return all(check_undefined_grad_support(output) for output in outputs_to_check)


def _as_tuple(x):
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    else:
        return (x,)


def _differentiable_outputs(x):
    return tuple(o for o in _as_tuple(x) if o.requires_grad)


def _get_notallclose_msg(
    analytical,
    numerical,
    output_idx,
    input_idx,
    complex_indices,
    test_imag=False,
    is_forward_ad=False,
) -> str:
    out_is_complex = (
        (not is_forward_ad) and complex_indices and output_idx in complex_indices
    )
    inp_is_complex = is_forward_ad and complex_indices and input_idx in complex_indices
    part = "imaginary" if test_imag else "real"
    element = "inputs" if is_forward_ad else "outputs"
    prefix = (
        ""
        if not (out_is_complex or inp_is_complex)
        else f"While considering the {part} part of complex {element} only, "
    )
    mode = "computed with forward mode " if is_forward_ad else ""
    return (
        prefix + "Jacobian %smismatch for output %d with respect to input %d,\n"
        "numerical:%s\nanalytical:%s\n"
        % (mode, output_idx, input_idx, numerical, analytical)
    )


def _transpose(matrix_of_tensors):
    # returns list of tuples
    return list(zip(*matrix_of_tensors))


def _real_and_imag_output(fn):
    # returns new functions real(fn) and imag(fn) that behave the same as the
    # original fn, except torch.real or torch.imag is applied to the complex outputs
    def apply_to_c_outs(fn, fn_to_apply):
        def wrapped_fn(*inputs):
            outs = _as_tuple(fn(*inputs))
            return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)

        return wrapped_fn

    return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)


def _real_and_imag_input(fn, complex_inp_indices, tupled_inputs):
    # Returns new functions that take real inputs instead of complex inputs: viewing
    # the original fn as (x, y) -> fn(x + y * 1j), it builds inp -> fn(inp + y * 1j)
    # and inp -> fn(x + inp * 1j). In each case, the other part is held constant.
    # We do not use 0 for the constant here to make sure we always call the user function with a valid input.
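    # For an original complex input z0 = x0 + y0*1j this gives, elementwise,
    #   real_fn(x) = fn(x + y0*1j)   and   imag_fn(y) = fn(x0 + y*1j),
    # so each returned function only ever sees real inputs.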
    def apply_to_c_inps(fn, fn_to_apply):
        def wrapped_fn(*inputs):
            new_inputs = list(inputs)
            for should_be_complex in complex_inp_indices:
                new_inputs[should_be_complex] = fn_to_apply(
                    new_inputs[should_be_complex], tupled_inputs[should_be_complex]
                )
            return _as_tuple(fn(*new_inputs))

        return wrapped_fn

    real_fn = apply_to_c_inps(fn, lambda inp, orig: inp + orig.imag * 1j)
    imag_fn = apply_to_c_inps(fn, lambda inp, orig: orig.real + inp * 1j)
    return real_fn, imag_fn


def _gradcheck_real_imag(
    gradcheck_fn,
    func,
    func_out,
    tupled_inputs,
    outputs,
    eps,
    rtol,
    atol,
    check_grad_dtypes,
    check_forward_ad,
    check_backward_ad,
    nondet_tol,
    check_undefined_grad,
):
    complex_out_indices = [i for i, o in enumerate(outputs) if o.is_complex()]
    has_any_complex_output = any(o.is_complex() for o in _as_tuple(func_out))
    if check_backward_ad:
        if has_any_complex_output:
            real_fn, imag_fn = _real_and_imag_output(func)

            imag_func_out = imag_fn(*tupled_inputs)
            imag_outputs = _differentiable_outputs(imag_func_out)
            gradcheck_fn(
                imag_fn,
                imag_func_out,
                tupled_inputs,
                imag_outputs,
                eps,
                rtol,
                atol,
                check_grad_dtypes,
                nondet_tol,
                complex_indices=complex_out_indices,
                test_imag=True,
            )

            real_func_out = real_fn(*tupled_inputs)
            real_outputs = _differentiable_outputs(real_func_out)
            gradcheck_fn(
                real_fn,
                real_func_out,
                tupled_inputs,
                real_outputs,
                eps,
                rtol,
                atol,
                check_grad_dtypes,
                nondet_tol,
                complex_indices=complex_out_indices,
            )
        else:
            gradcheck_fn(
                func,
                func_out,
                tupled_inputs,
                outputs,
                eps,
                rtol,
                atol,
                check_grad_dtypes,
                nondet_tol,
            )

    if check_forward_ad:
        complex_inp_indices = [
            i
            for i, inp in enumerate(tupled_inputs)
            if is_tensor_like(inp) and inp.is_complex()
        ]
        if complex_inp_indices:
            real_fn, imag_fn = _real_and_imag_input(
                func, complex_inp_indices, tupled_inputs
            )

            imag_inputs = [
                inp.imag if is_tensor_like(inp) and inp.is_complex() else inp
                for inp in tupled_inputs
            ]
            imag_func_out = imag_fn(*imag_inputs)
            diff_imag_func_out = _differentiable_outputs(imag_func_out)
            gradcheck_fn(
                imag_fn,
                imag_func_out,
                imag_inputs,
                diff_imag_func_out,
                eps,
                rtol,
                atol,
                check_grad_dtypes,
                nondet_tol,
                complex_indices=complex_inp_indices,
                test_imag=True,
                use_forward_ad=True,
            )

            real_inputs = [
                inp.real if is_tensor_like(inp) and inp.is_complex() else inp
                for inp in tupled_inputs
            ]
            real_func_out = real_fn(*real_inputs)
            diff_real_func_out = _differentiable_outputs(real_func_out)
            gradcheck_fn(
                real_fn,
                real_func_out,
                real_inputs,
                diff_real_func_out,
                eps,
                rtol,
                atol,
                check_grad_dtypes,
                nondet_tol,
                complex_indices=complex_inp_indices,
                use_forward_ad=True,
            )
            if check_undefined_grad:
                _test_undefined_forward_mode(imag_fn, imag_func_out, imag_inputs)
                _test_undefined_forward_mode(real_fn, real_func_out, real_inputs)
        else:
            gradcheck_fn(
                func,
                func_out,
                tupled_inputs,
                outputs,
                eps,
                rtol,
                atol,
                check_grad_dtypes,
                nondet_tol,
                use_forward_ad=True,
1569 ) 1570 if check_undefined_grad: 1571 _test_undefined_forward_mode(func, outputs, tupled_inputs) 1572 1573 1574def _slow_gradcheck( 1575 func, 1576 func_out, 1577 tupled_inputs, 1578 outputs, 1579 eps, 1580 rtol, 1581 atol, 1582 check_grad_dtypes, 1583 nondet_tol, 1584 *, 1585 use_forward_ad=False, 1586 complex_indices=None, 1587 test_imag=False, 1588 masked=False, 1589): 1590 func_out = _as_tuple(func_out) 1591 if not outputs: 1592 return _check_no_differentiable_outputs( 1593 func, tupled_inputs, func_out, eps=eps, is_forward_ad=use_forward_ad 1594 ) 1595 tupled_inputs_numerical = tupled_inputs if masked else _densify(tupled_inputs) 1596 1597 numerical = _transpose( 1598 _get_numerical_jacobian( 1599 func, 1600 tupled_inputs_numerical, 1601 func_out, 1602 eps=eps, 1603 is_forward_ad=use_forward_ad, 1604 ) 1605 ) 1606 # Note: [numerical vs analytical output length] 1607 # The numerical path returns jacobian quantity for all outputs, even if requires_grad of that 1608 # output is False. This behavior is necessary for _check_no_differentiable_outputs to work. 1609 numerical = [nj for o, nj in zip(func_out, numerical) if o.requires_grad] 1610 if use_forward_ad: 1611 analytical_forward = _get_analytical_jacobian_forward_ad( 1612 func, tupled_inputs, func_out, check_grad_dtypes=check_grad_dtypes 1613 ) 1614 1615 for i, n_per_out in enumerate(numerical): 1616 for j, n in enumerate(n_per_out): 1617 a = analytical_forward[j][i] 1618 if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol): 1619 raise GradcheckError( 1620 _get_notallclose_msg( 1621 a, n, i, j, complex_indices, test_imag, is_forward_ad=True 1622 ) 1623 ) 1624 else: 1625 for i, o in enumerate(outputs): 1626 analytical = _check_analytical_jacobian_attributes( 1627 tupled_inputs, o, nondet_tol, check_grad_dtypes 1628 ) 1629 1630 for j, (a, n) in enumerate(zip(analytical, numerical[i])): 1631 if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol): 1632 raise GradcheckError( 1633 _get_notallclose_msg(a, n, i, j, complex_indices, test_imag) 1634 ) 1635 1636 return True 1637 1638 1639def _dot_with_type_promotion(u, v): 1640 assert u.dim() == 1 and v.dim() == 1 1641 return (u * v).sum() 1642 1643 1644def _allclose_with_type_promotion(a, b, rtol, atol): 1645 promoted_type = torch.promote_types(a.dtype, b.dtype) 1646 a = a.to(dtype=promoted_type) 1647 b = b.to(dtype=promoted_type) 1648 return torch.allclose(a, b, rtol, atol) 1649 1650 1651def _to_real_dtype(dtype): 1652 if dtype == torch.complex128: 1653 return torch.float64 1654 elif dtype == torch.complex64: 1655 return torch.float32 1656 else: 1657 return dtype 1658 1659 1660def _vec_from_tensor(x, generator, downcast_complex=False): 1661 # Create a random vector with the same number of elements as x and the same 1662 # dtype/device. If x is complex and downcast_complex is False, we create a 1663 # complex tensor with only real component. 1664 if x.layout == torch.sparse_coo: 1665 # For sparse, create a random sparse vec with random values in the same 1666 # indices. Make sure size is set so that it isn't inferred to be smaller. 
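        # Sketch of the construction below (comment only; the example tensor is
        # hypothetical): for a COO input with indices [[0, 2], [1, 0]] and two
        # specified values, we draw two random values, rescale them to unit norm,
        # and rebuild a sparse tensor reusing x's indices and explicit size, so
        # the probing vector shares x's sparsity pattern.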
1667 x_values = x._values() 1668 dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype 1669 values = ( 1670 torch.rand(x_values.numel(), generator=generator) 1671 .to(dtype=dtype, device=x.device) 1672 .view(x_values.shape) 1673 ) 1674 values /= values.norm() 1675 vec = torch.sparse_coo_tensor(x._indices(), values, x.size(), device=x.device) 1676 elif _is_sparse_compressed_tensor(x): 1677 if x.layout in {torch.sparse_csr, torch.sparse_bsr}: 1678 compressed_indices, plain_indices = x.crow_indices(), x.col_indices() 1679 else: 1680 compressed_indices, plain_indices = x.ccol_indices(), x.row_indices() 1681 x_values = x.values() 1682 dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype 1683 values = ( 1684 torch.rand(x_values.numel(), generator=generator) 1685 .to(dtype=dtype, device=x.device) 1686 .view(x_values.shape) 1687 ) 1688 values /= values.norm() 1689 vec = torch.sparse_compressed_tensor( 1690 compressed_indices, 1691 plain_indices, 1692 values, 1693 x.size(), 1694 layout=x.layout, 1695 device=x.device, 1696 ) 1697 else: 1698 dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype 1699 vec = torch.rand(x.numel(), generator=generator).to( 1700 dtype=dtype, device=x.device 1701 ) 1702 vec /= vec.norm() 1703 return vec 1704 1705 1706def _get_inp_tensors(tupled_inputs): 1707 inp_idx_tup = [ 1708 (i, t) 1709 for i, t in enumerate(tupled_inputs) 1710 if is_tensor_like(t) and t.requires_grad 1711 ] 1712 return [tup[0] for tup in inp_idx_tup], [tup[1] for tup in inp_idx_tup] 1713 1714 1715def _adjusted_atol(atol, u, v): 1716 # In slow gradcheck, we compare A and B element-wise, i.e., for some a, b we 1717 # allow: |a - b| < atol + rtol * b. But since we now compare q1 = v^T A u and 1718 # q2 = v^T B u, we must allow |q1 - q2| < v^T E u + rtol * v^T B u, where E is 1719 # the correctly sized matrix in which each entry is atol. 1720 # 1721 # We see that atol needs to be scaled by v^T M u (where M is an all-ones M x N 1722 # matrix): v^T M u = \sum_{i} \sum_{j} u_i * v_j = (\sum_{i} u_i)(\sum_{i} v_i) 1723 # TODO: properly handle case when u is tuple instead of only taking first element 1724 u = u[0] if isinstance(u, tuple) else u 1725 sum_u = u.sum() 1726 sum_v = 1.0 if v is None else v.sum() 1727 return atol * float(sum_u) * float(sum_v) 1728 1729 1730FAST_FAIL_SLOW_OK_MSG = """ 1731Fast gradcheck failed but element-wise differences are small. This means that the 1732test might've passed in slow_mode! 1733 1734If you are adding a new operator, please file an issue and then use one of the 1735workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck: 1736 1737If the test 1738- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck 1739 with `fast_mode=False` as a keyword argument. 
1740- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test 1741 to have `gradcheck_fast_mode=False` 1742- is a Module test (e.g., in common_nn.py), then modify the corresponding 1743 module_test entry to have `gradcheck_fast_mode=False` 1744""".strip() 1745 1746 1747def _run_slow_mode_and_get_error( 1748 func, tupled_inputs, outputs, input_idx, output_idx, rtol, atol, eps, is_forward_ad 1749): 1750 # Compute jacobians in slow mode for better error message 1751 slow_numerical = _get_numerical_jacobian( 1752 func, tupled_inputs, outputs, eps=eps, is_forward_ad=is_forward_ad 1753 )[input_idx][output_idx] 1754 if is_forward_ad: 1755 1756 def new_fn(inp): 1757 new_inputs = list(tupled_inputs) 1758 new_inputs[input_idx] = inp 1759 return _as_tuple(func(*new_inputs))[output_idx] 1760 1761 slow_analytical = _get_analytical_jacobian_forward_ad( 1762 new_fn, (tupled_inputs[input_idx],), (outputs[output_idx],) 1763 )[0][0] 1764 else: 1765 slow_analytical = _get_analytical_jacobian( 1766 tupled_inputs, outputs, input_idx, output_idx 1767 ) 1768 1769 # Assume jacobians are non-empty and have the same shape 1770 slow_max_diff = (slow_numerical - slow_analytical).abs().max() 1771 1772 slow_allclose = torch.allclose(slow_analytical, slow_numerical, rtol, atol) 1773 msg = ( 1774 "\nThe above quantities relating the numerical and analytical jacobians are computed \n" 1775 "in fast mode. See: https://github.com/pytorch/pytorch/issues/53876 for more background \n" 1776 "about fast mode. Below, we recompute numerical and analytical jacobians in slow mode:\n\n" 1777 f"Numerical:\n {slow_numerical}\n" 1778 f"Analytical:\n{slow_analytical}\n\n" 1779 f"The max per-element difference (slow mode) is: {slow_max_diff}.\n" 1780 ) 1781 if slow_allclose: 1782 # Slow gradcheck would've passed! 
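        # Fast mode compares projections of the Jacobian onto random vectors
        # rather than the full Jacobians elementwise, so rounding in those reduced
        # quantities can fail the fast check even when every Jacobian entry agrees
        # within tolerance in slow mode; in that case we point the user at the
        # fast_mode=False workarounds below.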
1783 msg += FAST_FAIL_SLOW_OK_MSG 1784 return msg 1785 1786 1787def _to_flat_dense_if_sparse(tensor): 1788 if _is_sparse_any_tensor(tensor): 1789 return tensor.to_dense().reshape(-1) 1790 else: 1791 return tensor 1792 1793 1794def _make_vectors(inp_tensors, outputs, *, use_forward_ad): 1795 # Use our own generator to avoid messing with the user's RNG state 1796 g_cpu = torch.Generator() 1797 1798 def _vec_from_tensor_cpu(*args): 1799 # Default allocate all tensors on CPU, so they are on the same device as the generator 1800 # even if the user specified a default device 1801 with torch.device("cpu"): 1802 return _vec_from_tensor(*args) 1803 1804 all_u = [] 1805 all_u_dense = [] 1806 for inp in inp_tensors: 1807 ur = _vec_from_tensor_cpu(inp, g_cpu, True) 1808 ur_dense = _to_flat_dense_if_sparse(ur) 1809 if inp.is_complex(): 1810 ui = _vec_from_tensor_cpu(inp, g_cpu, True) 1811 all_u.append((ur, ui)) 1812 ui_dense = _to_flat_dense_if_sparse(ui) 1813 all_u_dense.append((ur_dense, ui_dense)) 1814 else: 1815 all_u.append(ur) 1816 all_u_dense.append(ur_dense) 1817 all_v = ( 1818 None 1819 if use_forward_ad 1820 else [_vec_from_tensor_cpu(out, g_cpu) for out in outputs] 1821 ) 1822 return all_v, all_u, all_u_dense 1823 1824 1825def _check_analytical_numerical_equal( 1826 all_analytical, 1827 all_numerical, 1828 complex_indices, 1829 tupled_inputs, 1830 outputs, 1831 func, 1832 all_v, 1833 all_u, 1834 rtol, 1835 atol, 1836 eps, 1837 test_imag, 1838 *, 1839 is_forward_ad=False, 1840): 1841 for i, all_numerical_for_input_i in enumerate(all_numerical): 1842 for j, n in enumerate(all_numerical_for_input_i): 1843 # Forward AD generates the transpose of what this function expects 1844 if is_forward_ad: 1845 a = all_analytical[i][j] 1846 else: 1847 a = all_analytical[j][i] 1848 n = n.to(device=a.device) 1849 updated_atol = _adjusted_atol(atol, all_u[i], all_v[j] if all_v else None) 1850 if not _allclose_with_type_promotion(a, n.to(a.device), rtol, updated_atol): 1851 jacobians_str = _run_slow_mode_and_get_error( 1852 func, tupled_inputs, outputs, i, j, rtol, atol, eps, is_forward_ad 1853 ) 1854 raise GradcheckError( 1855 _get_notallclose_msg( 1856 a, n, j, i, complex_indices, test_imag, is_forward_ad 1857 ) 1858 + jacobians_str 1859 ) 1860 1861 1862def _fast_gradcheck( 1863 func, 1864 func_out, 1865 inputs, 1866 outputs, 1867 eps, 1868 rtol, 1869 atol, 1870 check_grad_dtypes, 1871 nondet_tol, 1872 *, 1873 use_forward_ad=False, 1874 complex_indices=None, 1875 test_imag=False, 1876 masked=False, 1877): 1878 # See https://github.com/pytorch/pytorch/issues/53876 for details 1879 inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs) 1880 # Backward mode computes v^T * J (VJP) 1881 # Since we computed J * u (JVP) through finite difference method, we perform an equality check 1882 # between VJP * u, v * JVP 1883 # ---- 1884 # Forward mode computes J * u (JVP) 1885 # Since we already compute JVP through finite difference method, 1886 # we don't need v for correctness check here as asserted below 1887 all_v, all_u, all_u_dense = _make_vectors( 1888 inp_tensors, outputs, use_forward_ad=use_forward_ad 1889 ) 1890 1891 inputs_numerical, all_u_numerical, all_v_numerical = ( 1892 (inputs, all_u, all_v) if masked else _densify((inputs, all_u, all_v)) 1893 ) 1894 1895 numerical_vJu = _get_numerical_vJu( 1896 func, 1897 inputs_numerical, 1898 inp_tensors_idx, 1899 func_out, 1900 all_u_numerical, 1901 all_v_numerical, 1902 eps, 1903 is_forward_ad=use_forward_ad, 1904 ) 1905 # TODO: replicate 
https://github.com/pytorch/pytorch/pull/77743 for fast gradcheck as well 1906 if use_forward_ad: 1907 assert all_v is None 1908 analytical_vJu = _get_analytical_jacobian_forward_ad( 1909 func, 1910 inputs, 1911 _as_tuple(func_out), 1912 all_u=all_u, 1913 check_grad_dtypes=check_grad_dtypes, 1914 ) 1915 else: 1916 if not outputs: 1917 _check_no_differentiable_outputs_fast( 1918 func, func_out, inputs, inp_tensors_idx, all_u, eps, nondet_tol 1919 ) 1920 1921 analytical_vJu = _get_analytical_vJu_backward_mode( 1922 inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u_dense 1923 ) 1924 1925 _check_analytical_numerical_equal( 1926 analytical_vJu, 1927 numerical_vJu, 1928 complex_indices, 1929 inputs, 1930 outputs, 1931 func, 1932 all_v, 1933 all_u, 1934 rtol, 1935 atol, 1936 eps, 1937 test_imag, 1938 is_forward_ad=use_forward_ad, 1939 ) 1940 1941 return True 1942 1943 1944# Note [VarArg of Tensors] 1945# ~~~~~~~~~~~~~~~~~~~~~~~~ 1946# 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment. 1947# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted, 1948# the '...' first argument of Callable can be replaced with VarArg(Tensor). 1949# For now, we permit any input. 1950def gradcheck( 1951 func: Callable[..., Union[_TensorOrTensors]], # See Note [VarArg of Tensors] 1952 inputs: _TensorOrTensors, 1953 *, 1954 eps: float = 1e-6, 1955 atol: float = 1e-5, 1956 rtol: float = 1e-3, 1957 raise_exception: bool = True, 1958 nondet_tol: float = 0.0, 1959 check_undefined_grad: bool = True, 1960 check_grad_dtypes: bool = False, 1961 check_batched_grad: bool = False, 1962 check_batched_forward_grad: bool = False, 1963 check_forward_ad: bool = False, 1964 check_backward_ad: bool = True, 1965 fast_mode: bool = False, 1966 masked: Optional[bool] = None, 1967) -> bool: # noqa: D400,D205 1968 r"""Check gradients computed via small finite differences against analytical 1969 gradients wrt tensors in :attr:`inputs` that are of floating point or complex type 1970 and with ``requires_grad=True``. 1971 1972 The check between numerical and analytical gradients uses :func:`~torch.allclose`. 1973 1974 For most of the complex functions we consider for optimization purposes, no notion of 1975 Jacobian exists. Instead, gradcheck verifies if the numerical and analytical values of 1976 the Wirtinger and Conjugate Wirtinger derivatives are consistent. Because the gradient 1977 computation is done under the assumption that the overall function has a real-valued 1978 output, we treat functions with complex output in a special way. For these functions, 1979 gradcheck is applied to two real-valued functions corresponding to taking the real 1980 components of the complex outputs for the first, and taking the imaginary components 1981 of the complex outputs for the second. For more details, check out 1982 :ref:`complex_autograd-doc`. 1983 1984 .. note:: 1985 The default values are designed for :attr:`input` of double precision. 1986 This check will likely fail if :attr:`input` is of less precision, e.g., 1987 ``FloatTensor``. 1988 1989 .. note:: 1990 Gradcheck may fail when evaluated on non-differentiable points 1991 because the numerically computed gradients via finite differencing may differ 1992 those computed analytically (not necessarily because either is incorrect). 1993 For more context, see :ref:`non-differentiable-func-grad`. 1994 1995 .. 
warning::
        If any checked tensor in :attr:`input` has overlapping memory, i.e.,
        different indices pointing to the same memory address (e.g., from
        :func:`torch.expand`), this check will likely fail because the numerical
        gradients computed by point perturbation at such indices will change
        values at all other indices that share the same memory address.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor or a tuple of Tensors
        inputs (tuple of Tensor or Tensor): inputs to the function
        eps (float, optional): perturbation for finite differences
        atol (float, optional): absolute tolerance
        rtol (float, optional): relative tolerance
        raise_exception (bool, optional): indicating whether to raise an exception if
            the check fails. The exception gives more information about the
            exact nature of the failure. This is helpful when debugging gradchecks.
        nondet_tol (float, optional): tolerance for non-determinism. When running
            identical inputs through the differentiation, the results must either match
            exactly (default, 0.0) or be within this tolerance.
        check_undefined_grad (bool, optional): if ``True``, check if undefined output grads
            are supported and treated as zeros, for ``Tensor`` outputs.
        check_batched_grad (bool, optional): if ``True``, check if we can compute
            batched gradients using prototype vmap support. Defaults to ``False``.
        check_batched_forward_grad (bool, optional): if ``True``, checks if we can compute
            batched forward gradients using forward AD and prototype vmap support.
            Defaults to ``False``.
        check_forward_ad (bool, optional): if ``True``, check that the gradients computed with
            forward mode AD match the numerical ones. Defaults to ``False``.
        check_backward_ad (bool, optional): if ``False``, do not perform any checks that rely on
            backward mode AD to be implemented. Defaults to ``True``.
        fast_mode (bool, optional): Fast mode for gradcheck and gradgradcheck is currently only
            implemented for R to R functions. If none of the inputs and outputs are complex,
            a faster implementation of gradcheck that no longer computes the entire jacobian
            is run; otherwise, we fall back to the slow implementation.
        masked (bool, optional): if ``True``, the gradients of unspecified elements of
            sparse tensors are ignored. Defaults to ``False``.
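
    Example (minimal usage sketch; the shape of ``x`` and the use of
    ``torch.exp`` here are illustrative only)::

        >>> from torch.autograd import gradcheck
        >>> x = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
        >>> gradcheck(torch.exp, (x,), eps=1e-6, atol=1e-4)
        True
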
2031 Returns: 2032 ``True`` if all differences satisfy allclose condition 2033 2034 """ 2035 assert ( 2036 check_forward_ad or check_backward_ad 2037 ), "Expected at least one of check_forward_ad or check_backward_ad to be True" 2038 assert not ( 2039 check_batched_grad and not check_backward_ad 2040 ), "Setting check_batched_grad=True requires check_backward_ad to be True" 2041 assert not ( 2042 check_batched_forward_grad and not check_forward_ad 2043 ), "Setting check_batched_forward_grad=True requires check_forward_ad to be True" 2044 args = locals().copy() 2045 args.pop("raise_exception") 2046 if not raise_exception: 2047 try: 2048 return _gradcheck_helper(**args) 2049 except GradcheckError as e: 2050 return False 2051 else: 2052 return _gradcheck_helper(**args) 2053 2054 2055def _gradcheck_helper( 2056 func, 2057 inputs, 2058 eps, 2059 atol, 2060 rtol, 2061 nondet_tol, 2062 check_undefined_grad, 2063 check_grad_dtypes, 2064 check_batched_grad, 2065 check_batched_forward_grad, 2066 check_forward_ad, 2067 check_backward_ad, 2068 fast_mode, 2069 masked, 2070): 2071 tupled_inputs = _as_tuple(inputs) 2072 _check_inputs(tupled_inputs) 2073 2074 func_out = func(*tupled_inputs) 2075 outputs = _differentiable_outputs(func_out) 2076 _check_outputs(outputs) 2077 2078 gradcheck_fn = functools.partial( 2079 _fast_gradcheck if fast_mode else _slow_gradcheck, masked=masked 2080 ) 2081 _gradcheck_real_imag( 2082 gradcheck_fn, 2083 func, 2084 func_out, 2085 tupled_inputs, 2086 outputs, 2087 eps, 2088 rtol, 2089 atol, 2090 check_grad_dtypes, 2091 check_forward_ad=check_forward_ad, 2092 check_backward_ad=check_backward_ad, 2093 nondet_tol=nondet_tol, 2094 check_undefined_grad=check_undefined_grad, 2095 ) 2096 2097 if check_batched_forward_grad: 2098 _test_batched_grad_forward_ad(func, tupled_inputs) 2099 2100 # Short circuit because remaining tests rely on backward AD to be implemented 2101 if not check_backward_ad: 2102 return True 2103 2104 for i, o in enumerate(outputs): 2105 if check_batched_grad: 2106 _test_batched_grad(tupled_inputs, o, i) 2107 2108 _test_backward_mul_by_grad_output(outputs, tupled_inputs, masked) 2109 2110 if check_undefined_grad and check_backward_ad: 2111 _test_undefined_backward_mode(func, outputs, tupled_inputs) 2112 return True 2113 2114 2115def gradgradcheck( 2116 func: Callable[..., _TensorOrTensors], # See Note [VarArg of Tensors] 2117 inputs: _TensorOrTensors, 2118 grad_outputs: Optional[_TensorOrTensors] = None, 2119 *, 2120 eps: float = 1e-6, 2121 atol: float = 1e-5, 2122 rtol: float = 1e-3, 2123 gen_non_contig_grad_outputs: bool = False, 2124 raise_exception: bool = True, 2125 nondet_tol: float = 0.0, 2126 check_undefined_grad: bool = True, 2127 check_grad_dtypes: bool = False, 2128 check_batched_grad: bool = False, 2129 check_fwd_over_rev: bool = False, 2130 check_rev_over_rev: bool = True, 2131 fast_mode: bool = False, 2132 masked: bool = False, 2133) -> bool: # noqa: D400,D205 2134 r"""Check gradients of gradients computed via small finite differences 2135 against analytical gradients wrt tensors in :attr:`inputs` and 2136 :attr:`grad_outputs` that are of floating point or complex type and with 2137 ``requires_grad=True``. 2138 2139 This function checks that backpropagating through the gradients computed 2140 to the given :attr:`grad_outputs` are correct. 2141 2142 The check between numerical and analytical gradients uses :func:`~torch.allclose`. 2143 2144 .. 
note:: 2145 The default values are designed for :attr:`input` and 2146 :attr:`grad_outputs` of double precision. This check will likely fail if 2147 they are of less precision, e.g., ``FloatTensor``. 2148 2149 .. warning:: 2150 If any checked tensor in :attr:`input` and :attr:`grad_outputs` has 2151 overlapping memory, i.e., different indices pointing to the same memory 2152 address (e.g., from :func:`torch.expand`), this check will likely fail 2153 because the numerical gradients computed by point perturbation at such 2154 indices will change values at all other indices that share the same 2155 memory address. 2156 2157 Args: 2158 func (function): a Python function that takes Tensor inputs and returns 2159 a Tensor or a tuple of Tensors 2160 inputs (tuple of Tensor or Tensor): inputs to the function 2161 grad_outputs (tuple of Tensor or Tensor, optional): The gradients with 2162 respect to the function's outputs. 2163 eps (float, optional): perturbation for finite differences 2164 atol (float, optional): absolute tolerance 2165 rtol (float, optional): relative tolerance 2166 gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is 2167 ``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the 2168 randomly generated gradient outputs are made to be noncontiguous 2169 raise_exception (bool, optional): indicating whether to raise an exception if 2170 the check fails. The exception gives more information about the 2171 exact nature of the failure. This is helpful when debugging gradchecks. 2172 nondet_tol (float, optional): tolerance for non-determinism. When running 2173 identical inputs through the differentiation, the results must either match 2174 exactly (default, 0.0) or be within this tolerance. Note that a small amount 2175 of nondeterminism in the gradient will lead to larger inaccuracies in 2176 the second derivative. 2177 check_undefined_grad (bool, optional): if True, check if undefined output grads 2178 are supported and treated as zeros 2179 check_batched_grad (bool, optional): if True, check if we can compute 2180 batched gradients using prototype vmap support. Defaults to False. 2181 fast_mode (bool, optional): if True, run a faster implementation of gradgradcheck that 2182 no longer computes the entire jacobian. 2183 masked (bool, optional): if True, the gradients of unspecified elements of 2184 sparse tensors are ignored (default, False). 2185 Returns: 2186 True if all differences satisfy allclose condition 2187 """ 2188 assert ( 2189 check_fwd_over_rev or check_rev_over_rev 2190 ), "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True" 2191 assert not ( 2192 check_undefined_grad and not check_rev_over_rev 2193 ), "Setting check_undefined_grad=True requires check_rev_over_rev to be True" 2194 assert not ( 2195 check_batched_grad and not check_rev_over_rev 2196 ), "Setting check_batched_grad=True requires check_rev_over_rev to be True" 2197 # TODO: do we want to test this too? 
2198 # assert not (check_batched_forward_grad and not check_fwd_over_rev), ( 2199 # "Setting check_batched_forward_grad=True requires check_fwd_over_rev to be True") 2200 tupled_inputs = _as_tuple(inputs) 2201 2202 if grad_outputs is None: 2203 # If grad_outputs is not specified, create random Tensors of the same shape, type, and device as the outputs 2204 2205 outputs = _differentiable_outputs(func(*tupled_inputs)) 2206 tupled_grad_outputs = tuple( 2207 torch.testing.make_tensor( 2208 x.shape, 2209 dtype=x.dtype 2210 if x.is_floating_point() or x.is_complex() 2211 else torch.double, 2212 device=x.device, 2213 low=-1, 2214 high=1, 2215 requires_grad=True, 2216 noncontiguous=gen_non_contig_grad_outputs, 2217 ) 2218 for x in outputs 2219 ) 2220 else: 2221 tupled_grad_outputs = _as_tuple(grad_outputs) 2222 2223 num_outputs = len(tupled_grad_outputs) 2224 2225 # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs 2226 # before running forward mode AD 2227 diff_input_args_indices = { 2228 i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad 2229 } 2230 diff_grad_output_indices = { 2231 i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad 2232 } 2233 2234 def new_func(*args): 2235 # Restore the requires_grad information 2236 input_args = tuple( 2237 x.requires_grad_() if i in diff_input_args_indices else x 2238 for i, x in enumerate(args[:-num_outputs]) 2239 ) 2240 outputs = _differentiable_outputs(func(*input_args)) 2241 grad_outputs = tuple( 2242 x.requires_grad_() if i in diff_grad_output_indices else x 2243 for i, x in enumerate(args[-num_outputs:]) 2244 ) 2245 diff_input_args = tuple( 2246 x for i, x in enumerate(input_args) if i in diff_input_args_indices 2247 ) 2248 grad_inputs = torch.autograd.grad( 2249 outputs, diff_input_args, grad_outputs, create_graph=True, allow_unused=True 2250 ) 2251 grad_inputs = tuple(g for g in grad_inputs if g is not None) 2252 return grad_inputs 2253 2254 return gradcheck( 2255 new_func, 2256 tupled_inputs + tupled_grad_outputs, 2257 eps=eps, 2258 atol=atol, 2259 rtol=rtol, 2260 raise_exception=raise_exception, 2261 nondet_tol=nondet_tol, 2262 check_undefined_grad=check_undefined_grad, 2263 check_grad_dtypes=check_grad_dtypes, 2264 check_batched_grad=check_batched_grad, 2265 fast_mode=fast_mode, 2266 check_forward_ad=check_fwd_over_rev, 2267 check_backward_ad=check_rev_over_rev, 2268 masked=masked, 2269 ) 2270
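

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the helper below is never called by this
# module and is not part of the public API; ``_demo_gradcheck_usage`` and
# ``_example_fn`` are hypothetical names added purely to show the intended call
# pattern for gradcheck/gradgradcheck on double-precision inputs.
def _demo_gradcheck_usage() -> bool:
    def _example_fn(x, y):
        # A simple function of two real tensors with well-defined second derivatives.
        return (x * y.exp()).sum()

    x = torch.randn(4, dtype=torch.double, requires_grad=True)
    y = torch.randn(4, dtype=torch.double, requires_grad=True)
    # First-order check: analytical backward-mode gradients vs. finite differences.
    ok_first = gradcheck(_example_fn, (x, y), eps=1e-6, atol=1e-4)
    # Second-order check: backpropagation through the gradients wrt random grad_outputs.
    ok_second = gradgradcheck(_example_fn, (x, y))
    return ok_first and ok_second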