# mypy: allow-untyped-defs
"""
``torch.autograd`` provides classes and functions implementing automatic differentiation of arbitrary scalar-valued functions.

It requires minimal changes to the existing code - you only need to declare :class:`Tensor` s
for which gradients should be computed with the ``requires_grad=True`` keyword.
As of now, we only support autograd for floating point :class:`Tensor` types
(half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble).
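
For example (a minimal, illustrative sketch of the basic workflow)::

    >>> x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    >>> y = (x * x).sum()
    >>> y.backward()  # populates x.grad with dy/dx = 2 * x
    >>> x.grad
    tensor([2., 4., 6.])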
9"""
10
11import warnings
12from typing import cast, List, Optional, Sequence, Tuple, Union
13
14import torch
15from torch import _vmap_internals
16from torch.overrides import handle_torch_function, has_torch_function, is_tensor_like
17from torch.types import _size, _TensorOrTensors, _TensorOrTensorsOrGradEdge
18
19from . import forward_ad, functional, graph
20from .anomaly_mode import detect_anomaly, set_detect_anomaly
21from .function import Function, NestedIOFunction
22from .grad_mode import (
23    _force_original_view_tracking,
24    _unsafe_preserve_version_counter,
25    enable_grad,
26    inference_mode,
27    no_grad,
28    set_grad_enabled,
29    set_multithreading_enabled,
30)
31from .gradcheck import gradcheck, gradgradcheck
32from .graph import _engine_run_backward
33from .variable import Variable
34
35
36__all__ = [
37    "Variable",
38    "Function",
39    "backward",
40    "grad_mode",
41    "NestedIOFunction",
42    "detect_anomaly",
43    "enable_grad",
44    "grad",
45    "gradcheck",
46    "gradgradcheck",
47    "inference_mode",
48    "no_grad",
49    "set_detect_anomaly",
50    "set_grad_enabled",
51    "set_multithreading_enabled",
52    "variable",
53]
54
55_OptionalTensor = Optional[torch.Tensor]
56_ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor]
57
58
59def _calculate_shape(
60    output: Union[torch.Tensor, graph.GradientEdge],
61    grad: torch.Tensor,
62    is_grads_batched: bool,
63) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
    # is_same_size ensures that both tensors are either nested or non-nested
    # circular import
    from torch.nested._internal.nested_tensor import NestedTensor

    if isinstance(output, graph.GradientEdge):
        # We have already checked that we are not a C++ NestedTensor
        if is_grads_batched:
            raise RuntimeError("Batched grads are not supported with GradientEdge")
        out_metadata = output.node._input_metadata[output.output_nr]
        return torch.Size(out_metadata.shape), grad.shape

    if output.is_nested and not isinstance(output, NestedTensor):
        if is_grads_batched:
            raise RuntimeError("Batched grads are not supported with Nested Tensor.")
        out_shape = output._nested_tensor_size()
        grad_shape = grad._nested_tensor_size()

        return out_shape, grad_shape

    reg_out_shape = output.shape
    reg_grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
    return reg_out_shape, reg_grad_shape


def _make_grads(
    outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]],
    grads: Sequence[_OptionalTensor],
    is_grads_batched: bool,
) -> Tuple[_OptionalTensor, ...]:
    new_grads: List[_OptionalTensor] = []
    for out, grad in zip(outputs, grads):
        out = cast(Union[torch.Tensor, graph.GradientEdge], out)
        out_size = None
        out_device = None

        if isinstance(out, graph.GradientEdge):
            out_metadata = out.node._input_metadata[out.output_nr]
            out_size = torch.Size(out_metadata.shape)
            out_dtype = out_metadata.dtype
            out_device = out_metadata.device
            out_is_nested = out_metadata.is_nested_tensor
            if out_metadata.is_cpp_nested_tensor:
                raise RuntimeError(
                    "C++ NestedTensor are not supported with GradientEdge"
                )
            out_is_cpp_nested = False
        else:
            # circular import
            from torch.nested._internal.nested_tensor import NestedTensor

            assert isinstance(out, torch.Tensor)
            out_dtype = out.dtype
            out_is_nested = out.is_nested
            out_is_cpp_nested = out_is_nested and not isinstance(out, NestedTensor)
            if not out_is_cpp_nested:
                out_size = out.shape

        if isinstance(grad, torch.Tensor):
            from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq

            first_grad = grad if not is_grads_batched else grad[0]

            # TODO: We can remove this conditional once we uniformly use
            # singleton int to represent jagged dimension, so that size() call
            # on nested tensor works.
            if out_is_cpp_nested:
                assert isinstance(out, torch.Tensor)
                shape_matches = torch.is_same_size(out, first_grad)
            else:
                # We need to do a regular size check, without going through
                # the operator, to be able to handle unbacked symints
                # (expect_true ensures we can deal with unbacked)
                assert out_size is not None
                shape_matches = expect_true(sym_eq(out_size, first_grad.size()))

            if not shape_matches:
                out = cast(Union[torch.Tensor, graph.GradientEdge], out)
                out_shape, grad_shape = _calculate_shape(
                    out, first_grad, is_grads_batched
                )
                if is_grads_batched:
                    raise RuntimeError(
                        "If `is_grads_batched=True`, we interpret the first "
                        "dimension of each grad_output as the batch dimension. "
                        "The sizes of the remaining dimensions are expected to match "
                        "the shape of corresponding output, but a mismatch "
                        "was detected: grad_output["
                        + str(grads.index(grad))
                        + "] has a shape of "
                        + str(grad_shape)
                        + " and output["
                        + str(outputs.index(out))
                        + "] has a shape of "
                        + str(out_shape)
                        + ". "
                        "If you only want some tensors in `grad_output` to be considered "
                        "batched, consider using vmap."
                    )
                else:
                    raise RuntimeError(
                        "Mismatch in shape: grad_output["
                        + str(grads.index(grad))
                        + "] has a shape of "
                        + str(grad_shape)
                        + " and output["
                        + str(outputs.index(out))
                        + "] has a shape of "
                        + str(out_shape)
                        + "."
                    )
            if out_dtype.is_complex != grad.dtype.is_complex:
                raise RuntimeError(
                    "For complex Tensors, both grad_output and output"
                    " are required to have the same dtype."
                    " Mismatch in dtype: grad_output["
                    + str(grads.index(grad))
                    + "] has a dtype of "
                    + str(grad.dtype)
                    + " and output["
                    + str(outputs.index(out))
                    + "] has a dtype of "
                    + str(out_dtype)
                    + "."
                )
            new_grads.append(grad)
        elif grad is None:
            if isinstance(out, graph.GradientEdge) or out.requires_grad:  # type: ignore[attr-defined]
                if isinstance(out, graph.GradientEdge):
                    assert out_size is not None
                    out_numel_is_1 = all(o == 1 for o in out_size)
                else:
                    assert isinstance(out, torch.Tensor)
                    out_numel_is_1 = out.numel() == 1
                if not out_numel_is_1:
                    raise RuntimeError(
                        "grad can be implicitly created only for scalar outputs"
                    )
                if not out_dtype.is_floating_point:
                    msg = (
                        "grad can be implicitly created only for real scalar outputs"
                        f" but got {out_dtype}"
                    )
                    raise RuntimeError(msg)
                if isinstance(out, graph.GradientEdge):
                    assert out_size is not None
                    assert out_device is not None
                    new_grads.append(
                        torch.ones(
                            out_size,
                            dtype=out_dtype,
                            device=out_device,
                        )
                    )
                else:
                    assert isinstance(out, torch.Tensor)
                    new_grads.append(
                        torch.ones_like(out, memory_format=torch.preserve_format)
                    )
            else:
                new_grads.append(None)
        else:
            raise TypeError(
                "gradients can be either Tensors or None, but got "
                + type(grad).__name__
            )
    return tuple(new_grads)


def _tensor_or_tensors_to_tuple(
    tensors: Optional[_TensorOrTensors], length: int
) -> Tuple[_OptionalTensor, ...]:
    if tensors is None:
        return (None,) * length
    if isinstance(tensors, torch.Tensor):
        return (tensors,)
    return tuple(tensors)


def backward(
    tensors: _TensorOrTensors,
    grad_tensors: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    grad_variables: Optional[_TensorOrTensors] = None,
    inputs: Optional[_TensorOrTensorsOrGradEdge] = None,
) -> None:
    r"""Compute the sum of gradients of given tensors with respect to graph leaves.

    The graph is differentiated using the chain rule. If any of ``tensors``
    are non-scalar (i.e. their data has more than one element) and require
    gradient, then the Jacobian-vector product is computed; in that case
    the function additionally requires specifying ``grad_tensors``.
    It should be a sequence of matching length that contains the "vector"
    in the Jacobian-vector product, usually the gradient of the differentiated
    function w.r.t. the corresponding tensors (``None`` is an acceptable value for
    all tensors that don't need gradient tensors).

    This function accumulates gradients in the leaves - you might need to zero
    ``.grad`` attributes or set them to ``None`` before calling it.
    See :ref:`Default gradient layouts<default-grad-layouts>`
    for details on the memory layout of accumulated gradients.

    .. note::
        Using this method with ``create_graph=True`` will create a reference cycle
        between the parameter and its gradient which can cause a memory leak.
        We recommend using ``autograd.grad`` when creating the graph to avoid this.
        If you have to use this function, make sure to reset the ``.grad`` fields of your
        parameters to ``None`` after use to break the cycle and avoid the leak.

    .. note::

        If you run any forward ops, create ``grad_tensors``, and/or call ``backward``
        in a user-specified CUDA stream context, see
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

    .. note::

        When ``inputs`` are provided and a given input is not a leaf,
        the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
        It is an implementation detail on which the user should not rely.
        See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

    Args:
        tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
            computed.
        grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
            the Jacobian-vector product, usually gradients w.r.t. each element of
            corresponding tensors. None values can be specified for scalar Tensors or
            ones that don't require grad. If a None value would be acceptable for all
            grad_tensors, then this argument is optional.
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
            will be freed. Note that in nearly all cases setting this option to ``True``
            is not needed and often can be worked around in a much more efficient
            way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing the computation of higher order derivative
            products. Defaults to ``False``.
        inputs (Sequence[Tensor] or Tensor or Sequence[GradientEdge], optional): Inputs w.r.t. which the gradient
            will be accumulated into ``.grad``. All other Tensors will be ignored. If
            not provided, the gradient is accumulated into all the leaf Tensors that
            were used to compute the :attr:`tensors`.
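
    Example (a minimal, illustrative sketch of the basic call)::

        >>> x = torch.tensor([2.0, 3.0], requires_grad=True)
        >>> y = x.pow(2).sum()
        >>> torch.autograd.backward(y)  # equivalent to y.backward()
        >>> x.grad
        tensor([4., 6.])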
    """
    if torch._C._are_functorch_transforms_active():
        raise RuntimeError(
            "backward() called inside a functorch transform. This is not "
            "supported, please use functorch.grad or functorch.vjp instead "
            "or call backward() outside of functorch transforms."
        )

    if grad_variables is not None:
        warnings.warn(
            "`grad_variables` is deprecated. Use `grad_tensors` instead.",
            FutureWarning,
            stacklevel=2,
        )
        if grad_tensors is None:
            grad_tensors = grad_variables
        else:
            raise RuntimeError(
                "`grad_tensors` and `grad_variables` (deprecated) "
                "arguments both passed to `backward()`. Please only "
                "use `grad_tensors`."
            )
    if inputs is not None and len(inputs) == 0:
        raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")

    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
    inputs = (
        (inputs,)
        if isinstance(inputs, (torch.Tensor, graph.GradientEdge))
        else tuple(inputs)
        if inputs is not None
        else ()
    )

    grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
    if retain_graph is None:
        retain_graph = create_graph

    # The reason we repeat the same comment below is that
    # some Python versions print out the first line of multi-line function
    # calls in the traceback and some print out the last line
    _engine_run_backward(
        tensors,
        grad_tensors_,
        retain_graph,
        create_graph,
        inputs,
        allow_unreachable=True,
        accumulate_grad=True,
    )


def grad(
    outputs: _TensorOrTensorsOrGradEdge,
    inputs: _TensorOrTensorsOrGradEdge,
    grad_outputs: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    only_inputs: bool = True,
    allow_unused: Optional[bool] = None,
    is_grads_batched: bool = False,
    materialize_grads: bool = False,
) -> Tuple[torch.Tensor, ...]:
    r"""Compute and return the sum of gradients of outputs with respect to the inputs.

    ``grad_outputs`` should be a sequence of length matching ``outputs``,
    containing the "vector" in the vector-Jacobian product, usually the pre-computed
    gradients w.r.t. each of the outputs. If an output doesn't require_grad,
    then the gradient can be ``None``.

    .. note::

        If you run any forward ops, create ``grad_outputs``, and/or call ``grad``
        in a user-specified CUDA stream context, see
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

    .. note::

        The ``only_inputs`` argument is deprecated and is now ignored (it defaults to ``True``).
        To accumulate gradient for other parts of the graph, please use
        ``torch.autograd.backward``.

    Args:
        outputs (sequence of Tensor or GradientEdge): outputs of the differentiated function.
        inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient will be
            returned (and not accumulated into ``.grad``).
        grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
            Usually gradients w.r.t. each output. None values can be specified for scalar
            Tensors or ones that don't require grad. If a None value would be acceptable
            for all grad_outputs, then this argument is optional. Default: None.
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
            will be freed. Note that in nearly all cases setting this option to ``True``
            is not needed and often can be worked around in a much more efficient
            way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing the computation of higher order derivative
            products. Default: ``False``.
        allow_unused (Optional[bool], optional): If ``False``, specifying inputs
            that were not used when computing outputs (and therefore their grad is
            always zero) is an error. Defaults to the value of ``materialize_grads``.
        is_grads_batched (bool, optional): If ``True``, the first dimension of each
            tensor in ``grad_outputs`` will be interpreted as the batch dimension.
            Instead of computing a single vector-Jacobian product, we compute a
            batch of vector-Jacobian products for each "vector" in the batch.
            We use the vmap prototype feature as the backend to vectorize calls
            to the autograd engine so that this computation can be performed in a
            single call. This should lead to performance improvements when compared
            to manually looping and performing backward multiple times. Note that
            due to this feature being experimental, there may be performance
            cliffs. Please use ``torch._C._debug_only_display_vmap_fallback_warnings(True)``
            to show any performance warnings and file an issue on GitHub if warnings exist
            for your use case. Defaults to ``False``.
        materialize_grads (bool, optional): If ``True``, set the gradient for unused inputs
            to zero instead of None. This is useful when computing higher-order derivatives.
            If ``materialize_grads`` is ``True`` and ``allow_unused`` is ``False``, an error
            will be raised. Defaults to ``False``.

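    Example (a minimal, illustrative sketch; the batched call assumes a
    non-scalar output)::

        >>> x = torch.tensor([1.0, 2.0], requires_grad=True)
        >>> y = (x * x).sum()
        >>> (gx,) = torch.autograd.grad(y, (x,))
        >>> gx
        tensor([2., 4.])
        >>> # One vector-Jacobian product per row of ``vs`` (experimental feature).
        >>> z = x * 2
        >>> vs = torch.eye(2)
        >>> (jac,) = torch.autograd.grad(z, (x,), grad_outputs=vs, is_grads_batched=True)
        >>> jac.shape
        torch.Size([2, 2])
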
    """
    if materialize_grads and allow_unused is False:
        raise ValueError(
            "Expected allow_unused to be True or not passed when materialize_grads=True, "
            "but got: allow_unused=False."
        )
    if allow_unused is None:
        allow_unused = materialize_grads
    if is_tensor_like(outputs) or isinstance(outputs, graph.GradientEdge):
        outputs = cast(
            Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
        )
    else:
        outputs = tuple(outputs)
    if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
        inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
    else:
        inputs = tuple(inputs)
    t_outputs = tuple(i for i in outputs if is_tensor_like(i))
    t_inputs = tuple(i for i in inputs if is_tensor_like(i))
    overridable_args = t_outputs + t_inputs
    if has_torch_function(overridable_args):
        return handle_torch_function(
            grad,
            overridable_args,
            outputs,
            inputs,
            grad_outputs=grad_outputs,
            retain_graph=retain_graph,
            create_graph=create_graph,
            only_inputs=only_inputs,
            allow_unused=allow_unused,
            is_grads_batched=is_grads_batched,
            materialize_grads=materialize_grads,
        )

    if not only_inputs:
        warnings.warn(
            "only_inputs argument is deprecated and is ignored now "
            "(defaults to True). To accumulate gradient for other "
            "parts of the graph, please use torch.autograd.backward.",
            FutureWarning,
            stacklevel=2,
        )

    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
    grad_outputs_ = _make_grads(
        outputs, grad_outputs_, is_grads_batched=is_grads_batched
    )

    if retain_graph is None:
        retain_graph = create_graph

    # The reason we repeat the same comment several times below is that
    # some Python versions print out the first line of multi-line function
    # calls in the traceback and some print out the last line
    if is_grads_batched:

        def vjp(gO):
            return _engine_run_backward(
                outputs,
                gO,
                retain_graph,
                create_graph,
                inputs,
                allow_unused,
                accumulate_grad=False,
            )

        result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
            grad_outputs_
        )
    else:
        result = _engine_run_backward(
            outputs,
            grad_outputs_,
            retain_graph,
            create_graph,
            inputs,
            allow_unused,
            accumulate_grad=False,
        )
    if materialize_grads:
        if any(
            result[i] is None and not is_tensor_like(inputs[i])
            for i in range(len(inputs))
        ):
            raise RuntimeError(
                "materialize_grads cannot be used when the given input is a GradientEdge"
            )
        result = tuple(
            output
            if output is not None
            else torch.zeros_like(input, requires_grad=True)
            for (output, input) in zip(result, inputs)
        )
    return result


# This function applies in case of gradient checkpointing for memory
# optimization. Currently, gradient checkpointing is supported only if the
# execution engine is invoked through torch.autograd.backward() and its
# inputs argument is not passed. It is not supported for torch.autograd.grad().
# This is because if inputs are specified, the gradient won't be calculated for
# anything else, e.g. model parameters such as weights and biases.
#
# This function returns whether checkpointing is valid, i.e. whether the engine
# was invoked through torch.autograd.backward (valid) rather than
# torch.autograd.grad (not valid). The implementation works by maintaining a
# thread-local variable in torch/csrc/autograd/engine.cpp which looks at the
# NodeTask in the stack and, before a NodeTask is executed in evaluate_function,
# checks whether reentrant backwards is imperative or not.
# See https://github.com/pytorch/pytorch/pull/4594 for more discussion/context
def _is_checkpoint_valid():
    return Variable._execution_engine.is_checkpoint_valid()


def variable(*args, **kwargs):  # noqa: D103
    raise RuntimeError(
        "torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead"
    )


# Monkey patching variable.Variable to fix FX codegen. FX generates a call roughly as
# f"{fn.__module__}.{fn.__name__}(...)". This yields torch.autograd.variable.Variable(...) in the
# output of an FX graph. Unfortunately the module name torch.autograd.variable is shadowed by the
# deprecated function variable(...) above.
variable.Variable = Variable  # type: ignore[attr-defined]

if not torch._C._autograd_init():
    raise RuntimeError("autograd initialization failed")

# Import all native methods/classes
from torch._C._autograd import (
    _add_metadata_json,
    _disable_profiler,
    _disable_profiler_legacy,
    _enable_profiler,
    _enable_profiler_legacy,
    _enable_record_function,
    _get_sequence_nr,
    _kineto_step,
    _KinetoEvent,
    _pop_saved_tensors_default_hooks,
    _prepare_profiler,
    _profiler_enabled,
    _ProfilerResult,
    _push_saved_tensors_default_hooks,
    _record_function_with_args_enter,
    _record_function_with_args_exit,
    _set_empty_test_observer,
    _supported_activities,
    _toggle_collection_dynamic,
    DeviceType,
    kineto_available,
    ProfilerEvent,
    SavedTensor,
)
from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState

from . import profiler


def _register_py_tensor_class_for_device(device, cls):
    if not isinstance(cls, type):
        raise RuntimeError("cls isn't a typeinfo object")
    torch._C._register_py_class_for_device(device, cls)


is_multithreading_enabled = torch._C._is_multithreading_enabled
torch._C._add_docstr(
    is_multithreading_enabled, "Returns True if multithreading is currently enabled."
)

is_view_replay_enabled = torch._C._is_view_replay_enabled
torch._C._add_docstr(
    is_view_replay_enabled, "Returns True if view-replay is currently enabled."
)
