1"""Functional interface."""
2
3import importlib
4import math
5import warnings
6from typing import Callable, List, Optional, Tuple, TYPE_CHECKING, Union
7
8import torch
9from torch import _VF, sym_int as _sym_int, Tensor
10from torch._C import _add_docstr, _infer_size
11from torch._jit_internal import (
12    _overload,
13    boolean_dispatch,
14    BroadcastingList1,
15    BroadcastingList2,
16    BroadcastingList3,
17)
18from torch._torch_docs import reproducibility_notes, sparse_support_notes, tf32_notes
19from torch.nn import _reduction as _Reduction, grad  # noqa: F401
20from torch.nn.modules.utils import _list_with_default, _pair, _single, _triple
21from torch.overrides import (
22    handle_torch_function,
23    has_torch_function,
24    has_torch_function_unary,
25    has_torch_function_variadic,
26)
27
28
29if TYPE_CHECKING:
30    from torch.types import _dtype as DType
31else:
32    # The JIT doesn't understand Union, nor torch.dtype here
33    DType = int
34
35try:
36    import numpy as np
37except ModuleNotFoundError:
38    np = None
39
40
conv1d = _add_docstr(
    torch.conv1d,
    r"""
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 1D convolution over an input signal composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv1d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types, i.e. ``complex32, complex64, complex128``.
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or
      a one-element tuple `(sW,)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      a single number or a one-element tuple `(padW,)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, lowering performance.
    dilation: the spacing between kernel elements. Can be a single number or
      a one-element tuple `(dW,)`. Default: 1
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by
      the number of groups. Default: 1

Examples::

    >>> inputs = torch.randn(33, 16, 30)
    >>> filters = torch.randn(20, 16, 5)
    >>> F.conv1d(inputs, filters)
""",
)

conv2d = _add_docstr(
    torch.conv2d,
    r"""
conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 2D convolution over an input image composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv2d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types, i.e. ``complex32, complex64, complex128``.
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
    bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple `(sH, sW)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      a single number or a tuple `(padH, padW)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, lowering performance.

    dilation: the spacing between kernel elements. Can be a single number or
      a tuple `(dH, dW)`. Default: 1
    groups: split input into groups, both :math:`\text{in\_channels}` and :math:`\text{out\_channels}`
      should be divisible by the number of groups. Default: 1

Examples::

    >>> # With square kernels and equal stride
    >>> filters = torch.randn(8, 4, 3, 3)
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> F.conv2d(inputs, filters, padding=1)
""",
)  # noqa: E501

conv3d = _add_docstr(
    torch.conv3d,
    r"""
conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 3D convolution over an input image composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv3d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types, i.e. ``complex32, complex64, complex128``.
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kT , kH , kW)`
    bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple `(sT, sH, sW)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      a single number or a tuple `(padT, padH, padW)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, lowering performance.

    dilation: the spacing between kernel elements. Can be a single number or
      a tuple `(dT, dH, dW)`. Default: 1
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by
      the number of groups. Default: 1

Examples::

    >>> filters = torch.randn(33, 16, 3, 3, 3)
    >>> inputs = torch.randn(20, 16, 50, 10, 20)
    >>> F.conv3d(inputs, filters)
""",
)  # noqa: E501

conv_transpose1d = _add_docstr(
    torch.conv_transpose1d,
    r"""
conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 1D transposed convolution operator over an input signal
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose1d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sW,)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padW,)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple ``(out_padW,)``. Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dW,)``. Default: 1

Examples::

    >>> inputs = torch.randn(20, 16, 50)
    >>> weights = torch.randn(16, 33, 5)
    >>> F.conv_transpose1d(inputs, weights)
""",
)

conv_transpose2d = _add_docstr(
    torch.conv_transpose2d,
    r"""
conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 2D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose2d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kH , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sH, sW)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padH, padW)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple ``(out_padH, out_padW)``.
      Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dH, dW)``. Default: 1

Examples::

    >>> # With square kernels and equal stride
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> weights = torch.randn(4, 8, 3, 3)
    >>> F.conv_transpose2d(inputs, weights, padding=1)
""",
)  # noqa: E501

conv_transpose3d = _add_docstr(
    torch.conv_transpose3d,
    r"""
conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 3D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose3d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kT , kH , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sT, sH, sW)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padT, padH, padW)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple
      ``(out_padT, out_padH, out_padW)``. Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dT, dH, dW)``. Default: 1

Examples::

    >>> inputs = torch.randn(20, 16, 50, 10, 20)
    >>> weights = torch.randn(16, 33, 3, 3, 3)
    >>> F.conv_transpose3d(inputs, weights)
""",
)  # noqa: E501

conv_tbc = _add_docstr(
    torch.conv_tbc,
    r"""
Applies a 1-dimensional sequence convolution over an input sequence.
Input and output dimensions are (Time, Batch, Channels) - hence TBC.

Args:
    input: input tensor of shape :math:`(\text{sequence length} \times \text{batch} \times \text{in\_channels})`
    weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`)
    bias: bias of shape (:math:`\text{out\_channels}`)
    pad: number of timesteps to pad. Default: 0
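
Examples::

    >>> # Illustrative example (sizes assumed, not from the original docs):
    >>> # 10 timesteps, batch of 2, 4 input channels, width-3 kernel, 6 output channels.
    >>> inputs = torch.randn(10, 2, 4)
    >>> filters = torch.randn(3, 4, 6)
    >>> bias = torch.zeros(6)
    >>> F.conv_tbc(inputs, filters, bias, 1).shape
    torch.Size([10, 2, 6])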
340""",
341)
342
343
344# Pooling
345avg_pool1d = _add_docstr(
346    torch.avg_pool1d,
347    r"""
348avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor
349
350Applies a 1D average pooling over an input signal composed of several
351input planes.
352
353See :class:`~torch.nn.AvgPool1d` for details and output shape.
354
355Args:
356    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
357    kernel_size: the size of the window. Can be a single number or a
358      tuple `(kW,)`
359    stride: the stride of the window. Can be a single number or a tuple
360      `(sW,)`. Default: :attr:`kernel_size`
361    padding: implicit zero paddings on both sides of the input. Can be a
362      single number or a tuple `(padW,)`. Default: 0
363    ceil_mode: when True, will use `ceil` instead of `floor` to compute the
364        output shape. Default: ``False``
365    count_include_pad: when True, will include the zero-padding in the
366        averaging calculation. Default: ``True``
367
368Examples::
369
370    >>> # pool of square window of size=3, stride=2
371    >>> input = torch.tensor([[[1, 2, 3, 4, 5, 6, 7]]], dtype=torch.float32)
372    >>> F.avg_pool1d(input, kernel_size=3, stride=2)
373    tensor([[[ 2.,  4.,  6.]]])
374
375""",
376)
377
378
avg_pool2d = _add_docstr(
    torch._C._nn.avg_pool2d,
    r"""
avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor

Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size
:math:`sH \times sW` steps. The number of output features is equal to the number of
input planes.

See :class:`~torch.nn.AvgPool2d` for details and output shape.

Args:
    input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    kernel_size: size of the pooling region. Can be a single number or a
      tuple `(kH, kW)`
    stride: stride of the pooling operation. Can be a single number or a
      tuple `(sH, sW)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padH, padW)`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` in the formula
        to compute the output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
        averaging calculation. Default: ``True``
    divisor_override: if specified, it will be used as divisor, otherwise
        size of the pooling region will be used. Default: None
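
Examples::

    >>> # Illustrative sizes (assumed): a 1x3x8x8 input pooled with a 3x3 window, stride 2.
    >>> input = torch.randn(1, 3, 8, 8)
    >>> F.avg_pool2d(input, kernel_size=3, stride=2).shape
    torch.Size([1, 3, 3, 3])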
404""",
405)
406
avg_pool3d = _add_docstr(
    torch._C._nn.avg_pool3d,
    r"""
avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor

Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step
size :math:`sT \times sH \times sW` steps. The number of output features is equal to the
number of input planes.

See :class:`~torch.nn.AvgPool3d` for details and output shape.

Args:
    input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    kernel_size: size of the pooling region. Can be a single number or a
      tuple `(kT, kH, kW)`
    stride: stride of the pooling operation. Can be a single number or a
      tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padT, padH, padW)`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` in the formula
        to compute the output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
        averaging calculation. Default: ``True``
    divisor_override: if specified, it will be used as divisor, otherwise
        size of the pooling region will be used. Default: None
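
Examples::

    >>> # Illustrative sizes (assumed): pool a 1x2x4x4x4 volume with a cubic window of size 2.
    >>> input = torch.randn(1, 2, 4, 4, 4)
    >>> F.avg_pool3d(input, kernel_size=2).shape
    torch.Size([1, 2, 2, 2, 2])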
432""",
433)
434
435
def fractional_max_pool2d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    output_size: Optional[BroadcastingList2[int]] = None,
    output_ratio: Optional[BroadcastingList2[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

    Applies 2D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.

    The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    Args:
        kernel_size: the size of the window to take a max over.
                     Can be a single number :math:`k` (for a square kernel of :math:`k \times k`)
                     or a tuple `(kH, kW)`
        output_size: the target output size of the image of the form :math:`oH \times oW`.
                     Can be a tuple `(oH, oW)` or a single number :math:`oH` for a square image :math:`oH \times oH`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
                      This has to be a number or tuple in the range (0, 1)
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to :func:`~torch.nn.functional.max_unpool2d`.

    Examples::

        >>> input = torch.randn(20, 16, 50, 32)
        >>> # pool of square window of size=3, and target output size 13x12
        >>> F.fractional_max_pool2d(input, 3, output_size=(13, 12))
        >>> # pool of square window and target output size being half of input image size
        >>> F.fractional_max_pool2d(input, 3, output_ratio=(0.5, 0.5))

    .. _Fractional MaxPooling:
        http://arxiv.org/abs/1412.6071
    """
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool2d_with_indices,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    if output_size is None and output_ratio is None:
        raise ValueError(
            "fractional_max_pool2d requires specifying either an output_size or an output_ratio"
        )
    if output_size is None:
        assert output_ratio is not None
        if len(output_ratio) > 2:
            raise ValueError(
                "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."
            )
        _output_ratio = _pair(output_ratio)
        output_size = [
            int(input.size(-2) * _output_ratio[0]),
            int(input.size(-1) * _output_ratio[1]),
        ]

    if _random_samples is None:
        n_batch = 1 if input.dim() == 3 else input.size(0)
        _random_samples = torch.rand(
            n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device
        )
    return torch._C._nn.fractional_max_pool2d(
        input, kernel_size, output_size, _random_samples
    )


def _fractional_max_pool2d(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    output_size: Optional[BroadcastingList2[int]] = None,
    output_ratio: Optional[BroadcastingList2[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tensor:
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool2d,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    return fractional_max_pool2d_with_indices(
        input, kernel_size, output_size, output_ratio, return_indices, _random_samples
    )[0]


fractional_max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=4,
    default=False,
    if_true=fractional_max_pool2d_with_indices,
    if_false=_fractional_max_pool2d,
    module_name=__name__,
    func_name="fractional_max_pool2d",
)


def fractional_max_pool3d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    output_size: Optional[BroadcastingList3[int]] = None,
    output_ratio: Optional[BroadcastingList3[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

    Applies 3D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.

    The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    Args:
        kernel_size: the size of the window to take a max over.
                     Can be a single number :math:`k` (for a cubic kernel of :math:`k \times k \times k`)
                     or a tuple `(kT, kH, kW)`
        output_size: the target output size of the form :math:`oT \times oH \times oW`.
                     Can be a tuple `(oT, oH, oW)` or a single number :math:`oH` for a cubic output
                     :math:`oH \times oH \times oH`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
                      This has to be a number or tuple in the range (0, 1)
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to :func:`~torch.nn.functional.max_unpool3d`.

    Shape:
        - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})`

    Examples::

        >>> input = torch.randn(20, 16, 50, 32, 16)
        >>> # pool of cubic window of size=3, and target output size 13x12x11
        >>> F.fractional_max_pool3d(input, 3, output_size=(13, 12, 11))
        >>> # pool of cubic window and target output size being half of input size
        >>> F.fractional_max_pool3d(input, 3, output_ratio=(0.5, 0.5, 0.5))

    .. _Fractional MaxPooling:
        http://arxiv.org/abs/1412.6071
    """
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool3d_with_indices,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    if output_size is None and output_ratio is None:
        raise ValueError(
            "fractional_max_pool3d requires specifying either an output_size or an output_ratio"
        )
    if output_size is None:
        assert output_ratio is not None
        _output_ratio = _triple(output_ratio)
        output_size = [
            int(input.size(-3) * _output_ratio[0]),
            int(input.size(-2) * _output_ratio[1]),
            int(input.size(-1) * _output_ratio[2]),
        ]

    if _random_samples is None:
        n_batch = 1 if input.dim() == 4 else input.size(0)
        _random_samples = torch.rand(
            n_batch, input.size(-4), 3, dtype=input.dtype, device=input.device
        )
    return torch._C._nn.fractional_max_pool3d(
        input, kernel_size, output_size, _random_samples
    )


def _fractional_max_pool3d(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    output_size: Optional[BroadcastingList3[int]] = None,
    output_ratio: Optional[BroadcastingList3[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tensor:
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool3d,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    return fractional_max_pool3d_with_indices(
        input, kernel_size, output_size, output_ratio, return_indices, _random_samples
    )[0]


fractional_max_pool3d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=4,
    default=False,
    if_true=fractional_max_pool3d_with_indices,
    if_false=_fractional_max_pool3d,
    module_name=__name__,
    func_name="fractional_max_pool3d",
)


def max_pool1d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    dilation: BroadcastingList1[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 1D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool1d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool1d` for details.

    Args:
        input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`, minibatch dim optional.
        kernel_size: the size of the window. Can be a single number or a
            tuple `(kW,)`
        stride: the stride of the window. Can be a single number or a tuple
            `(sW,)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
                   ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
                        Useful for :func:`torch.nn.functional.max_unpool1d` later.
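
    Examples::

        >>> # Illustrative values (assumed): max over windows of size 3, stride 2.
        >>> input = torch.tensor([[[1., 2., 3., 4., 5., 6., 7.]]])
        >>> F.max_pool1d(input, kernel_size=3, stride=2)
        tensor([[[3., 5., 7.]]])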
697    """
698    if has_torch_function_unary(input):
699        return handle_torch_function(
700            max_pool1d_with_indices,
701            (input,),
702            input,
703            kernel_size,
704            stride=stride,
705            padding=padding,
706            dilation=dilation,
707            ceil_mode=ceil_mode,
708            return_indices=return_indices,
709        )
710    if stride is None:
711        stride = torch.jit.annotate(List[int], [])
712    return torch.max_pool1d_with_indices(
713        input, kernel_size, stride, padding, dilation, ceil_mode
714    )
715
716
def _max_pool1d(
    input: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    dilation: BroadcastingList1[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool1d,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


max_pool1d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=6,
    default=False,
    if_true=max_pool1d_with_indices,
    if_false=_max_pool1d,
    module_name=__name__,
    func_name="max_pool1d",
)


def max_pool2d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    padding: BroadcastingList2[int] = 0,
    dilation: BroadcastingList2[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 2D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool2d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool2d` for details.

    Args:
        input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`, minibatch dim optional.
        kernel_size: size of the pooling region. Can be a single number or a
            tuple `(kH, kW)`
        stride: stride of the pooling operation. Can be a single number or a
            tuple `(sH, sW)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
                   ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
                        Useful for :func:`torch.nn.functional.max_unpool2d` later.
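
    Examples::

        >>> # Illustrative values (assumed): 2x2 max pooling over a 4x4 map.
        >>> input = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
        >>> F.max_pool2d(input, kernel_size=2)
        tensor([[[[ 5.,  7.],
                  [13., 15.]]]])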
787    """
788    if has_torch_function_unary(input):
789        return handle_torch_function(
790            max_pool2d_with_indices,
791            (input,),
792            input,
793            kernel_size,
794            stride=stride,
795            padding=padding,
796            dilation=dilation,
797            ceil_mode=ceil_mode,
798            return_indices=return_indices,
799        )
800    if stride is None:
801        stride = torch.jit.annotate(List[int], [])
802    return torch._C._nn.max_pool2d_with_indices(
803        input, kernel_size, stride, padding, dilation, ceil_mode
804    )
805
806
def _max_pool2d(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    padding: BroadcastingList2[int] = 0,
    dilation: BroadcastingList2[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool2d,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=6,
    default=False,
    if_true=max_pool2d_with_indices,
    if_false=_max_pool2d,
    module_name=__name__,
    func_name="max_pool2d",
)


def max_pool3d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    padding: BroadcastingList3[int] = 0,
    dilation: BroadcastingList3[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 3D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool3d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool3d` for details.

    Args:
        input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`, minibatch dim optional.
        kernel_size: size of the pooling region. Can be a single number or a
                     tuple `(kT, kH, kW)`
        stride: stride of the pooling operation. Can be a single number or a
                tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
                   ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
                        Useful for :func:`torch.nn.functional.max_unpool3d` later.
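
    Examples::

        >>> # Illustrative sizes (assumed): cubic window of size 2 over an 8x8x8 volume.
        >>> input = torch.randn(1, 3, 8, 8, 8)
        >>> F.max_pool3d(input, kernel_size=2).shape
        torch.Size([1, 3, 4, 4, 4])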
877    """
878    if has_torch_function_unary(input):
879        return handle_torch_function(
880            max_pool3d_with_indices,
881            (input,),
882            input,
883            kernel_size,
884            stride=stride,
885            padding=padding,
886            dilation=dilation,
887            ceil_mode=ceil_mode,
888            return_indices=return_indices,
889        )
890    if stride is None:
891        stride = torch.jit.annotate(List[int], [])
892    return torch._C._nn.max_pool3d_with_indices(
893        input, kernel_size, stride, padding, dilation, ceil_mode
894    )
895
896
def _max_pool3d(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    padding: BroadcastingList3[int] = 0,
    dilation: BroadcastingList3[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool3d,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)


max_pool3d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=6,
    default=False,
    if_true=max_pool3d_with_indices,
    if_false=_max_pool3d,
    module_name=__name__,
    func_name="max_pool3d",
)


def _unpool_output_size(
    input: Tensor,
    kernel_size: List[int],
    stride: List[int],
    padding: List[int],
    output_size: Optional[List[int]],
) -> List[int]:
    input_size = input.size()
    default_size = torch.jit.annotate(List[int], [])
    for d in range(len(kernel_size)):
        default_size.append(
            (input_size[-len(kernel_size) + d] - 1) * stride[d]
            + kernel_size[d]
            - 2 * padding[d]
        )
    if output_size is None:
        ret = default_size
    else:
        if len(output_size) == len(kernel_size) + 2:
            output_size = output_size[2:]
        if len(output_size) != len(kernel_size):
            raise ValueError(
                "output_size should be a sequence containing "
                f"{len(kernel_size)} or {len(kernel_size) + 2} elements, but it has a length of '{len(output_size)}'"
            )
        for d in range(len(kernel_size)):
            min_size = default_size[d] - stride[d]
            max_size = default_size[d] + stride[d]
            if not (min_size < output_size[d] < max_size):
                raise ValueError(
                    f'invalid output_size "{output_size}" (dim {d} must be between {min_size} and {max_size})'
                )

        ret = output_size
    return ret


def max_unpool1d(
    input: Tensor,
    indices: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    output_size: Optional[BroadcastingList1[int]] = None,
) -> Tensor:
    r"""Compute a partial inverse of :class:`MaxPool1d`.

    See :class:`~torch.nn.MaxUnpool1d` for details.
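
    Examples::

        >>> # Illustrative values (assumed): unpool restores pooled maxima to their indices.
        >>> input = torch.tensor([[[1., 2., 3., 4., 5., 6., 7., 8.]]])
        >>> output, indices = F.max_pool1d(input, 2, stride=2, return_indices=True)
        >>> F.max_unpool1d(output, indices, 2, stride=2)
        tensor([[[0., 2., 0., 4., 0., 6., 0., 8.]]])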
982    """
983    if has_torch_function_unary(input):
984        return handle_torch_function(
985            max_unpool1d,
986            (input,),
987            input,
988            indices,
989            kernel_size,
990            stride=stride,
991            padding=padding,
992            output_size=output_size,
993        )
994    kernel_size = _single(kernel_size)
995    if stride is not None:
996        _stride = _single(stride)
997    else:
998        _stride = kernel_size
999    padding = _single(padding)
1000    output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
1001    if isinstance(output_size, list):
1002        output_size = output_size + [1]
1003    else:
1004        output_size = output_size + (1,)
1005    return torch._C._nn.max_unpool2d(
1006        input.unsqueeze(-1), indices.unsqueeze(-1), output_size
1007    ).squeeze(-1)
1008
1009
def max_unpool2d(
    input: Tensor,
    indices: Tensor,
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    padding: BroadcastingList2[int] = 0,
    output_size: Optional[BroadcastingList2[int]] = None,
) -> Tensor:
    r"""Compute a partial inverse of :class:`MaxPool2d`.

    See :class:`~torch.nn.MaxUnpool2d` for details.
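
    Examples::

        >>> # Illustrative values (assumed): a 4x4 map pooled 2x2, then unpooled.
        >>> input = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
        >>> output, indices = F.max_pool2d(input, 2, stride=2, return_indices=True)
        >>> F.max_unpool2d(output, indices, 2, stride=2)
        tensor([[[[ 0.,  0.,  0.,  0.],
                  [ 0.,  5.,  0.,  7.],
                  [ 0.,  0.,  0.,  0.],
                  [ 0., 13.,  0., 15.]]]])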
1021    """
1022    if has_torch_function_unary(input):
1023        return handle_torch_function(
1024            max_unpool2d,
1025            (input,),
1026            input,
1027            indices,
1028            kernel_size,
1029            stride=stride,
1030            padding=padding,
1031            output_size=output_size,
1032        )
1033    kernel_size = _pair(kernel_size)
1034    if stride is not None:
1035        _stride = _pair(stride)
1036    else:
1037        _stride = kernel_size
1038    padding = _pair(padding)
1039    output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
1040    return torch._C._nn.max_unpool2d(input, indices, output_size)
1041
1042
def max_unpool3d(
    input: Tensor,
    indices: Tensor,
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    padding: BroadcastingList3[int] = 0,
    output_size: Optional[BroadcastingList3[int]] = None,
) -> Tensor:
    r"""Compute a partial inverse of :class:`MaxPool3d`.

    See :class:`~torch.nn.MaxUnpool3d` for details.
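
    Examples::

        >>> # Illustrative sizes (assumed): round-trip shape through pool/unpool.
        >>> input = torch.randn(1, 3, 8, 8, 8)
        >>> output, indices = F.max_pool3d(input, 2, stride=2, return_indices=True)
        >>> F.max_unpool3d(output, indices, 2, stride=2).shape
        torch.Size([1, 3, 8, 8, 8])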
1054    """
1055    if has_torch_function_unary(input):
1056        return handle_torch_function(
1057            max_unpool3d,
1058            (input,),
1059            input,
1060            indices,
1061            kernel_size,
1062            stride=stride,
1063            padding=padding,
1064            output_size=output_size,
1065        )
1066    kernel_size = _triple(kernel_size)
1067    if stride is not None:
1068        _stride = _triple(stride)
1069    else:
1070        _stride = kernel_size
1071    padding = _triple(padding)
1072    output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
1073    return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding)
1074
1075
def lp_pool3d(
    input: Tensor,
    norm_type: Union[int, float],
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    ceil_mode: bool = False,
) -> Tensor:
    r"""
    Apply a 3D power-average pooling over an input signal composed of several input planes.

    If the sum of all inputs to the power of `p` is
    zero, the gradient is set to zero as well.

    See :class:`~torch.nn.LPPool3d` for details.
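
    Examples::

        >>> # Illustrative sizes (assumed): p=2 pooling with a cubic window of size 2.
        >>> input = torch.randn(1, 4, 8, 8, 8)
        >>> F.lp_pool3d(input, norm_type=2, kernel_size=2).shape
        torch.Size([1, 4, 4, 4, 4])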
1090    """
1091    if has_torch_function_unary(input):
1092        return handle_torch_function(
1093            lp_pool3d,
1094            (input,),
1095            input,
1096            norm_type,
1097            kernel_size,
1098            stride=stride,
1099            ceil_mode=ceil_mode,
1100        )
1101    kd, kw, kh = _triple(kernel_size)
1102    if stride is not None:
1103        out = avg_pool3d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
1104    else:
1105        out = avg_pool3d(
1106            input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode
1107        )
1108
1109    return (
1110        (torch.sign(out) * relu(torch.abs(out))).mul(kd * kw * kh).pow(1.0 / norm_type)
1111    )
1112
1113
def lp_pool2d(
    input: Tensor,
    norm_type: Union[int, float],
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    ceil_mode: bool = False,
) -> Tensor:
    r"""
    Apply a 2D power-average pooling over an input signal composed of several input planes.

    If the sum of all inputs to the power of `p` is
    zero, the gradient is set to zero as well.

    See :class:`~torch.nn.LPPool2d` for details.
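
    Examples::

        >>> # Illustrative values (assumed): with p=1 and positive inputs this is a windowed sum.
        >>> input = torch.tensor([[[[1., 2.], [3., 4.]]]])
        >>> F.lp_pool2d(input, norm_type=1, kernel_size=2)
        tensor([[[[10.]]]])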
1128    """
1129    if has_torch_function_unary(input):
1130        return handle_torch_function(
1131            lp_pool2d,
1132            (input,),
1133            input,
1134            norm_type,
1135            kernel_size,
1136            stride=stride,
1137            ceil_mode=ceil_mode,
1138        )
1139    kw, kh = _pair(kernel_size)
1140    if stride is not None:
1141        out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
1142    else:
1143        out = avg_pool2d(
1144            input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode
1145        )
1146
1147    return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1.0 / norm_type)
1148
1149
def lp_pool1d(
    input: Tensor,
    norm_type: Union[int, float],
    kernel_size: int,
    stride: Optional[BroadcastingList1[int]] = None,
    ceil_mode: bool = False,
) -> Tensor:
    r"""Apply a 1D power-average pooling over an input signal composed of several input planes.

    If the sum of all inputs to the power of `p` is
    zero, the gradient is set to zero as well.

    See :class:`~torch.nn.LPPool1d` for details.
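
    Examples::

        >>> # Illustrative values (assumed): with p=1 and positive inputs this is a windowed sum.
        >>> input = torch.tensor([[[1., 2., 3., 4.]]])
        >>> F.lp_pool1d(input, norm_type=1, kernel_size=2)
        tensor([[[3., 7.]]])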
1163    """
1164    if has_torch_function_unary(input):
1165        return handle_torch_function(
1166            lp_pool1d,
1167            (input,),
1168            input,
1169            norm_type,
1170            kernel_size,
1171            stride=stride,
1172            ceil_mode=ceil_mode,
1173        )
1174    if stride is not None:
1175        out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
1176    else:
1177        out = avg_pool1d(
1178            input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode
1179        )
1180
1181    return (
1182        (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type)
1183    )
1184
1185
def adaptive_max_pool1d_with_indices(
    input: Tensor,
    output_size: BroadcastingList1[int],
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    adaptive_max_pool1d(input, output_size, return_indices=False)

    Applies a 1D adaptive max pooling over an input signal composed of
    several input planes.

    See :class:`~torch.nn.AdaptiveMaxPool1d` for details and output shape.

    Args:
        output_size: the target output size (single integer)
        return_indices: whether to return pooling indices. Default: ``False``
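
    Examples::

        >>> # Illustrative sizes (assumed): adaptively pool a length-32 signal to 8 bins.
        >>> input = torch.randn(1, 16, 32)
        >>> F.adaptive_max_pool1d(input, output_size=8).shape
        torch.Size([1, 16, 8])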
1202    """
1203    if has_torch_function_unary(input):
1204        return handle_torch_function(
1205            adaptive_max_pool1d_with_indices,
1206            (input,),
1207            input,
1208            output_size,
1209            return_indices=return_indices,
1210        )
1211    return torch.adaptive_max_pool1d(input, output_size)
1212
1213
1214def _adaptive_max_pool1d(
1215    input: Tensor,
1216    output_size: BroadcastingList1[int],
1217    return_indices: bool = False,
1218) -> Tensor:
1219    if has_torch_function_unary(input):
1220        return handle_torch_function(
1221            adaptive_max_pool1d,
1222            (input,),
1223            input,
1224            output_size,
1225            return_indices=return_indices,
1226        )
1227    return adaptive_max_pool1d_with_indices(input, output_size)[0]
1228
1229
1230adaptive_max_pool1d = boolean_dispatch(
1231    arg_name="return_indices",
1232    arg_index=2,
1233    default=False,
1234    if_true=adaptive_max_pool1d_with_indices,
1235    if_false=_adaptive_max_pool1d,
1236    module_name=__name__,
1237    func_name="adaptive_max_pool1d",
1238)
1239
1240
def adaptive_max_pool2d_with_indices(
    input: Tensor,
    output_size: BroadcastingList2[int],
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    adaptive_max_pool2d(input, output_size, return_indices=False)

    Applies a 2D adaptive max pooling over an input signal composed of
    several input planes.

    See :class:`~torch.nn.AdaptiveMaxPool2d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
            double-integer tuple)
        return_indices: whether to return pooling indices. Default: ``False``
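
    Examples::

        >>> # Illustrative sizes (assumed): adaptively pool a 32x32 map to 8x8.
        >>> input = torch.randn(1, 16, 32, 32)
        >>> F.adaptive_max_pool2d(input, output_size=(8, 8)).shape
        torch.Size([1, 16, 8, 8])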
1257    """
1258    if has_torch_function_unary(input):
1259        return handle_torch_function(
1260            adaptive_max_pool2d_with_indices,
1261            (input,),
1262            input,
1263            output_size,
1264            return_indices=return_indices,
1265        )
1266    output_size = _list_with_default(output_size, input.size())
1267    return torch._C._nn.adaptive_max_pool2d(input, output_size)
1268
1269
1270def _adaptive_max_pool2d(
1271    input: Tensor,
1272    output_size: BroadcastingList2[int],
1273    return_indices: bool = False,
1274) -> Tensor:
1275    if has_torch_function_unary(input):
1276        return handle_torch_function(
1277            adaptive_max_pool2d,
1278            (input,),
1279            input,
1280            output_size,
1281            return_indices=return_indices,
1282        )
1283    return adaptive_max_pool2d_with_indices(input, output_size)[0]
1284
1285
1286adaptive_max_pool2d = boolean_dispatch(
1287    arg_name="return_indices",
1288    arg_index=2,
1289    default=False,
1290    if_true=adaptive_max_pool2d_with_indices,
1291    if_false=_adaptive_max_pool2d,
1292    module_name=__name__,
1293    func_name="adaptive_max_pool2d",
1294)
1295
1296
def adaptive_max_pool3d_with_indices(
    input: Tensor,
    output_size: BroadcastingList3[int],
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    adaptive_max_pool3d(input, output_size, return_indices=False)

    Applies a 3D adaptive max pooling over an input signal composed of
    several input planes.

    See :class:`~torch.nn.AdaptiveMaxPool3d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
            triple-integer tuple)
        return_indices: whether to return pooling indices. Default: ``False``
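
    Examples::

        >>> # Illustrative sizes (assumed): adaptively pool an 8x8x8 volume to 4x4x4.
        >>> input = torch.randn(1, 16, 8, 8, 8)
        >>> F.adaptive_max_pool3d(input, output_size=4).shape
        torch.Size([1, 16, 4, 4, 4])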
1314    """
1315    if has_torch_function_unary(input):
1316        return handle_torch_function(
1317            adaptive_max_pool3d_with_indices,
1318            (input,),
1319            input,
1320            output_size,
1321            return_indices=return_indices,
1322        )
1323    output_size = _list_with_default(output_size, input.size())
1324    return torch._C._nn.adaptive_max_pool3d(input, output_size)
1325
1326
1327def _adaptive_max_pool3d(
1328    input: Tensor,
1329    output_size: BroadcastingList3[int],
1330    return_indices: bool = False,
1331) -> Tensor:
1332    if has_torch_function_unary(input):
1333        return handle_torch_function(
1334            adaptive_max_pool3d,
1335            (input,),
1336            input,
1337            output_size,
1338            return_indices=return_indices,
1339        )
1340    return adaptive_max_pool3d_with_indices(input, output_size)[0]
1341
1342
1343adaptive_max_pool3d = boolean_dispatch(
1344    arg_name="return_indices",
1345    arg_index=2,
1346    default=False,
1347    if_true=adaptive_max_pool3d_with_indices,
1348    if_false=_adaptive_max_pool3d,
1349    module_name=__name__,
1350    func_name="adaptive_max_pool3d",
1351)
1352
1353
adaptive_avg_pool1d = _add_docstr(
    torch.adaptive_avg_pool1d,
    r"""
adaptive_avg_pool1d(input, output_size) -> Tensor

Applies a 1D adaptive average pooling over an input signal composed of
several input planes.

See :class:`~torch.nn.AdaptiveAvgPool1d` for details and output shape.

Args:
    output_size: the target output size (single integer)
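
Examples::

    >>> # Illustrative sizes (assumed): average a length-64 signal down to 16 bins.
    >>> input = torch.randn(1, 32, 64)
    >>> F.adaptive_avg_pool1d(input, 16).shape
    torch.Size([1, 32, 16])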
1366""",
1367)
1368
1369
def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
    r"""Apply a 2D adaptive average pooling over an input signal composed of several input planes.

    See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
            double-integer tuple)
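
    Examples::

        >>> # Illustrative sizes (assumed): average a 32x32 map down to 7x7.
        >>> input = torch.randn(1, 16, 32, 32)
        >>> F.adaptive_avg_pool2d(input, (7, 7)).shape
        torch.Size([1, 16, 7, 7])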
1378    """
1379    if has_torch_function_unary(input):
1380        return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
1381    _output_size = _list_with_default(output_size, input.size())
1382    return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
1383
1384
def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> Tensor:
    r"""Apply a 3D adaptive average pooling over an input signal composed of several input planes.

    See :class:`~torch.nn.AdaptiveAvgPool3d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
            triple-integer tuple)
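
    Examples::

        >>> # Illustrative sizes (assumed): average an 8x8x8 volume down to 2x2x2.
        >>> input = torch.randn(1, 16, 8, 8, 8)
        >>> F.adaptive_avg_pool3d(input, 2).shape
        torch.Size([1, 16, 2, 2, 2])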
1393    """
1394    if has_torch_function_unary(input):
1395        return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
1396    _output_size = _list_with_default(output_size, input.size())
1397    return torch._C._nn.adaptive_avg_pool3d(input, _output_size)
1398
1399
1400# Activation functions
def dropout(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""During training, randomly zeroes some elements of the input tensor with probability :attr:`p`.

    Uses samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout` for details.

    Args:
        p: probability of an element to be zeroed. Default: 0.5
        training: apply dropout if ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
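
    Examples::

        >>> # Illustrative check: with training=False, dropout is the identity.
        >>> input = torch.randn(4, 4)
        >>> torch.equal(F.dropout(input, p=0.5, training=False), input)
        True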
1417    """
1418    if has_torch_function_unary(input):
1419        return handle_torch_function(
1420            dropout, (input,), input, p=p, training=training, inplace=inplace
1421        )
1422    if p < 0.0 or p > 1.0:
1423        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1424    return (
1425        _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
1426    )
1427
1428
def alpha_dropout(
    input: Tensor,
    p: float = 0.5,
    training: bool = False,
    inplace: bool = False,
) -> Tensor:
    r"""Apply alpha dropout to the input.

    See :class:`~torch.nn.AlphaDropout` for details.
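
    Examples::

        >>> # Illustrative shape (assumed); alpha dropout preserves the input's shape.
        >>> input = torch.randn(4, 8)
        >>> F.alpha_dropout(input, p=0.2, training=True).shape
        torch.Size([4, 8])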
1438    """
1439    if has_torch_function_unary(input):
1440        return handle_torch_function(
1441            alpha_dropout, (input,), input, p=p, training=training, inplace=inplace
1442        )
1443    if p < 0.0 or p > 1.0:
1444        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1445    return (
1446        _VF.alpha_dropout_(input, p, training)
1447        if inplace
1448        else _VF.alpha_dropout(input, p, training)
1449    )
1450
1451
1452def dropout1d(
1453    input: Tensor,
1454    p: float = 0.5,
1455    training: bool = True,
1456    inplace: bool = False,
1457) -> Tensor:
1458    r"""Randomly zero out entire channels (a channel is a 1D feature map).
1459
1460    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
1461    batched input is a 1D tensor :math:`\text{input}[i, j]` of the input tensor.
1462    Each channel will be zeroed out independently on every forward call with
1463    probability :attr:`p` using samples from a Bernoulli distribution.
1464
1465    See :class:`~torch.nn.Dropout1d` for details.
1466
1467    Args:
1468        p: probability of a channel to be zeroed. Default: 0.5
1469        training: apply dropout if ``training`` is ``True``. Default: ``True``
1470        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
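
    A minimal illustration on a batched ``(N, C, L)`` input, where whole
    channels are zeroed rather than individual elements::

        >>> x = torch.randn(2, 4, 8)  # (N, C, L)
        >>> out = F.dropout1d(x, p=0.5, training=True)
        >>> out.shape
        torch.Size([2, 4, 8])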
1471    """
1472    if has_torch_function_unary(input):
1473        return handle_torch_function(
1474            dropout1d, (input,), input, p=p, training=training, inplace=inplace
1475        )
1476    if p < 0.0 or p > 1.0:
1477        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1478    inp_dim = input.dim()
1479    if inp_dim not in (2, 3):
1480        raise RuntimeError(
1481            f"dropout1d: Expected 2D or 3D input, but received a {inp_dim}D input. "
1482            "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
1483            "spatial dimension, a channel dimension, and an optional batch dimension "
1484            "(i.e. 2D or 3D inputs)."
1485        )
1486
1487    is_batched = inp_dim == 3
1488    if not is_batched:
1489        input = input.unsqueeze_(0) if inplace else input.unsqueeze(0)
1490
1491    result = (
1492        _VF.feature_dropout_(input, p, training)
1493        if inplace
1494        else _VF.feature_dropout(input, p, training)
1495    )
1496
1497    if not is_batched:
1498        result = result.squeeze_(0) if inplace else result.squeeze(0)
1499
1500    return result
1501
1502
1503def dropout2d(
1504    input: Tensor,
1505    p: float = 0.5,
1506    training: bool = True,
1507    inplace: bool = False,
1508) -> Tensor:
1509    r"""Randomly zero out entire channels (a channel is a 2D feature map).
1510
1511    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
1512    batched input is a 2D tensor :math:`\text{input}[i, j]` of the input tensor.
1513    Each channel will be zeroed out independently on every forward call with
1514    probability :attr:`p` using samples from a Bernoulli distribution.
1515
1516    See :class:`~torch.nn.Dropout2d` for details.
1517
1518    Args:
1519        p: probability of a channel to be zeroed. Default: 0.5
1520        training: apply dropout if ``training`` is ``True``. Default: ``True``
1521        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
1522    """
1523    if has_torch_function_unary(input):
1524        return handle_torch_function(
1525            dropout2d, (input,), input, p=p, training=training, inplace=inplace
1526        )
1527    if p < 0.0 or p > 1.0:
1528        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1529    inp_dim = input.dim()
1530    if inp_dim not in (3, 4):
1531        warn_msg = (
1532            f"dropout2d: Received a {inp_dim}-D input to dropout2d, which is deprecated "
1533            "and will result in an error in a future release. To retain the behavior "
1534            "and silence this warning, please use dropout instead. Note that dropout2d "
1535            "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, "
1536            "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs)."
1537        )
1538        warnings.warn(warn_msg)
1539
1540    # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing
1541    # a 3D input will perform dropout1d behavior instead. This was done historically and the
1542    # behavior is maintained here for now.
1543    # See https://github.com/pytorch/pytorch/issues/77081
1544    if inp_dim == 3:
1545        warnings.warn(
1546            "dropout2d: Received a 3D input to dropout2d and assuming that channel-wise "
1547            "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C "
1548            "is the channel dim. This behavior will change in a future release to interpret the "
1549            "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D "
1550            "channel-wise dropout behavior, please switch to using dropout1d instead."
1551        )
1552
1553    result = (
1554        _VF.feature_dropout_(input, p, training)
1555        if inplace
1556        else _VF.feature_dropout(input, p, training)
1557    )
1558
1559    return result
1560
1561
1562def dropout3d(
1563    input: Tensor,
1564    p: float = 0.5,
1565    training: bool = True,
1566    inplace: bool = False,
1567) -> Tensor:
1568    r"""Randomly zero out entire channels (a channel is a 3D feature map).
1569
1570    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
1571    batched input is a 3D tensor :math:`\text{input}[i, j]` of the input tensor.
1572    Each channel will be zeroed out independently on every forward call with
1573    probability :attr:`p` using samples from a Bernoulli distribution.
1574
1575    See :class:`~torch.nn.Dropout3d` for details.
1576
1577    Args:
1578        p: probability of a channel to be zeroed. Default: 0.5
1579        training: apply dropout if ``training`` is ``True``. Default: ``True``
1580        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
1581    """
1582    if has_torch_function_unary(input):
1583        return handle_torch_function(
1584            dropout3d, (input,), input, p=p, training=training, inplace=inplace
1585        )
1586    if p < 0.0 or p > 1.0:
1587        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1588    inp_dim = input.dim()
1589    if inp_dim not in (4, 5):
1590        warn_msg = (
1591            f"dropout3d: Received a {inp_dim}-D input to dropout3d, which is deprecated "
1592            "and will result in an error in a future release. To retain the behavior "
1593            "and silence this warning, please use dropout instead. Note that dropout3d "
1594            "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, "
1595            "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs)."
1596        )
1597        warnings.warn(warn_msg)
1598
1599    is_batched = inp_dim == 5
1600    if not is_batched:
1601        input = input.unsqueeze_(0) if inplace else input.unsqueeze(0)
1602
1603    result = (
1604        _VF.feature_dropout_(input, p, training)
1605        if inplace
1606        else _VF.feature_dropout(input, p, training)
1607    )
1608
1609    if not is_batched:
1610        result = result.squeeze_(0) if inplace else result.squeeze(0)
1611    return result
1612
1613
1614def feature_alpha_dropout(
1615    input: Tensor,
1616    p: float = 0.5,
1617    training: bool = False,
1618    inplace: bool = False,
1619) -> Tensor:
1620    r"""Randomly masks out entire channels (a channel is a feature map).
1621
1622    For example, the :math:`j`-th channel of the :math:`i`-th sample in the batch input
1623    is a tensor :math:`\text{input}[i, j]` of the input tensor. Instead of
1624    setting activations to zero, as in regular Dropout, the activations are set
1625    to the negative saturation value of the SELU activation function.
1626
1627    Each channel will be masked independently on every forward call with
1628    probability :attr:`p` using samples from a Bernoulli distribution.
1629    The channels to be masked are randomized on every forward call, and scaled
1630    and shifted to maintain zero mean and unit variance.
1631
1632    See :class:`~torch.nn.FeatureAlphaDropout` for details.
1633
1634    Args:
1635        p: probability of a channel to be masked. Default: 0.5
1636        training: apply dropout if ``training`` is ``True``. Default: ``False``
1637        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
1638    """
1639    if has_torch_function_unary(input):
1640        return handle_torch_function(
1641            feature_alpha_dropout,
1642            (input,),
1643            input,
1644            p=p,
1645            training=training,
1646            inplace=inplace,
1647        )
1648    if p < 0.0 or p > 1.0:
1649        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
1650    return (
1651        _VF.feature_alpha_dropout_(input, p, training)
1652        if inplace
1653        else _VF.feature_alpha_dropout(input, p, training)
1654    )
1655
1656
1657def _threshold(
1658    input: Tensor,
1659    threshold: float,
1660    value: float,
1661    inplace: bool = False,
1662) -> Tensor:
1663    r"""Apply a threshold to each element of the input Tensor.
1664
1665    See :class:`~torch.nn.Threshold` for more details.
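
    Example::

        >>> x = torch.tensor([0.5, 1.5, 2.5])
        >>> # elements not exceeding the threshold are replaced by ``value``
        >>> F.threshold(x, threshold=1.0, value=0.0)
        tensor([0.0000, 1.5000, 2.5000])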
1666    """
1667    if has_torch_function_unary(input):
1668        return handle_torch_function(
1669            _threshold, (input,), input, threshold, value, inplace=inplace
1670        )
1671    if inplace:
1672        result = _VF.threshold_(input, threshold, value)
1673    else:
1674        result = _VF.threshold(input, threshold, value)
1675    return result
1676
1677
1678# We define this function as _threshold because it takes an argument
1679# named threshold, which clobbers the recursive reference to the
1680# function needed for __torch_function__ support
1681threshold = _threshold
1682
1683threshold_ = _add_docstr(
1684    _VF.threshold_,
1685    r"""
1686threshold_(input, threshold, value) -> Tensor
1687
1688In-place version of :func:`~threshold`.
1689""",
1690)
1691
1692
1693def relu(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
1694    r"""relu(input, inplace=False) -> Tensor
1695
1696    Applies the rectified linear unit function element-wise. See
1697    :class:`~torch.nn.ReLU` for more details.
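
    Example::

        >>> F.relu(torch.tensor([-1.0, 0.0, 2.0]))
        tensor([0., 0., 2.])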
1698    """
1699    if has_torch_function_unary(input):
1700        return handle_torch_function(relu, (input,), input, inplace=inplace)
1701    if inplace:
1702        result = torch.relu_(input)
1703    else:
1704        result = torch.relu(input)
1705    return result
1706
1707
1708relu_ = _add_docstr(
1709    torch.relu_,
1710    r"""
1711relu_(input) -> Tensor
1712
1713In-place version of :func:`~relu`.
1714""",
1715)
1716
1717
1718def glu(input: Tensor, dim: int = -1) -> Tensor:  # noqa: D400,D402
1719    r"""
1720    glu(input, dim=-1) -> Tensor
1721
1722    The gated linear unit. Computes:
1723
1724    .. math ::
1725        \text{GLU}(a, b) = a \otimes \sigma(b)
1726
1727    where `input` is split in half along `dim` to form `a` and `b`, :math:`\sigma`
1728    is the sigmoid function and :math:`\otimes` is the element-wise product between matrices.
1729
1730    See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`_.
1731
1732    Args:
1733        input (Tensor): input tensor
1734        dim (int): dimension on which to split the input. Default: -1
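
    Example (the chosen dimension is halved)::

        >>> x = torch.randn(4, 6)
        >>> F.glu(x, dim=-1).shape
        torch.Size([4, 3])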
1735    """
1736    if has_torch_function_unary(input):
1737        return handle_torch_function(glu, (input,), input, dim=dim)
1738    if input.dim() == 0:
1739        raise RuntimeError(
1740            "glu does not support scalars because halving size must be even"
1741        )
1742    return torch._C._nn.glu(input, dim)
1743
1744
1745def hardtanh(
1746    input: Tensor,
1747    min_val: float = -1.0,
1748    max_val: float = 1.0,
1749    inplace: bool = False,
1750) -> Tensor:  # noqa: D400,D402
1751    r"""
1752    hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor
1753
1754    Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more
1755    details.
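
    Example::

        >>> x = torch.tensor([-2.0, -0.5, 0.5, 2.0])
        >>> F.hardtanh(x)  # clamps to [-1, 1] by default
        tensor([-1.0000, -0.5000,  0.5000,  1.0000])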
1756    """
1757    if has_torch_function_unary(input):
1758        return handle_torch_function(
1759            hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace
1760        )
1761    if min_val > max_val:
1762        raise ValueError("min_val cannot be greater than max_val")
1763    if inplace:
1764        result = torch._C._nn.hardtanh_(input, min_val, max_val)
1765    else:
1766        result = torch._C._nn.hardtanh(input, min_val, max_val)
1767    return result
1768
1769
1770hardtanh_ = _add_docstr(
1771    torch._C._nn.hardtanh_,
1772    r"""
1773hardtanh_(input, min_val=-1., max_val=1.) -> Tensor
1774
1775In-place version of :func:`~hardtanh`.
1776""",
1777)
1778
1779
1780def relu6(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
1781    r"""relu6(input, inplace=False) -> Tensor
1782
1783    Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`.
1784
1785    See :class:`~torch.nn.ReLU6` for more details.
1786    """
1787    if has_torch_function_unary(input):
1788        return handle_torch_function(relu6, (input,), input, inplace=inplace)
1789    if inplace:
1790        result = torch._C._nn.relu6_(input)
1791    else:
1792        result = torch._C._nn.relu6(input)
1793    return result
1794
1795
1796def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor:
1797    r"""Apply the Exponential Linear Unit (ELU) function element-wise.
1798
1799    See :class:`~torch.nn.ELU` for more details.
1800    """
1801    if has_torch_function_unary(input):
1802        return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace)
1803    if inplace:
1804        result = torch._C._nn.elu_(input, alpha)
1805    else:
1806        result = torch._C._nn.elu(input, alpha)
1807    return result
1808
1809
1810elu_ = _add_docstr(
1811    torch._C._nn.elu_,
1812    r"""
1813elu_(input, alpha=1.) -> Tensor
1814
1815In-place version of :func:`~elu`.
1816""",
1817)
1818
1819
1820def selu(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
1821    r"""selu(input, inplace=False) -> Tensor
1822
1823    Applies element-wise,
1824    :math:`\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))`,
1825    with :math:`\alpha=1.6732632423543772848170429916717` and
1826    :math:`scale=1.0507009873554804934193349852946`.
1827
1828    See :class:`~torch.nn.SELU` for more details.
1829    """
1830    if has_torch_function_unary(input):
1831        return handle_torch_function(selu, (input,), input, inplace=inplace)
1832    if inplace:
1833        result = torch.selu_(input)
1834    else:
1835        result = torch.selu(input)
1836    return result
1837
1838
1839selu_ = _add_docstr(
1840    torch.selu_,
1841    r"""
1842selu_(input) -> Tensor
1843
1844In-place version of :func:`~selu`.
1845""",
1846)
1847
1848
1849def celu(
1850    input: Tensor,
1851    alpha: float = 1.0,
1852    inplace: bool = False,
1853) -> Tensor:  # noqa: D400,D402
1854    r"""celu(input, alpha=1., inplace=False) -> Tensor
1855
1856    Applies element-wise,
1857    :math:`\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))`.
1858
1859    See :class:`~torch.nn.CELU` for more details.
1860    """
1861    if has_torch_function_unary(input):
1862        return handle_torch_function(
1863            celu, (input,), input, alpha=alpha, inplace=inplace
1864        )
1865    if inplace:
1866        result = torch.celu_(input, alpha)
1867    else:
1868        result = torch.celu(input, alpha)
1869    return result
1870
1871
1872celu_ = _add_docstr(
1873    torch.celu_,
1874    r"""
1875celu_(input, alpha=1.) -> Tensor
1876
1877In-place version of :func:`~celu`.
1878""",
1879)
1880
1881
1882def leaky_relu(
1883    input: Tensor,
1884    negative_slope: float = 0.01,
1885    inplace: bool = False,
1886) -> Tensor:  # noqa: D400,D402
1887    r"""
1888    leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor
1889
1890    Applies element-wise,
1891    :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)`
1892
1893    See :class:`~torch.nn.LeakyReLU` for more details.
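
    Example::

        >>> x = torch.tensor([-2.0, 0.0, 3.0])
        >>> F.leaky_relu(x, negative_slope=0.1)
        tensor([-0.2000,  0.0000,  3.0000])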
1894    """
1895    if has_torch_function_unary(input):
1896        return handle_torch_function(
1897            leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace
1898        )
1899    if inplace:
1900        result = torch._C._nn.leaky_relu_(input, negative_slope)
1901    else:
1902        result = torch._C._nn.leaky_relu(input, negative_slope)
1903    return result
1904
1905
1906leaky_relu_ = _add_docstr(
1907    torch._C._nn.leaky_relu_,
1908    r"""
1909leaky_relu_(input, negative_slope=0.01) -> Tensor
1910
1911In-place version of :func:`~leaky_relu`.
1912""",
1913)
1914
1915
1916prelu = _add_docstr(
1917    torch.prelu,
1918    r"""prelu(input, weight) -> Tensor
1919
1920Applies element-wise the function
1921:math:`\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x)` where weight is a
1922learnable parameter.
1923
1924.. note::
1925    `weight` is expected to be a scalar or 1-D tensor. If `weight` is 1-D,
1926    its size must match the number of input channels, determined by
1927    `input.size(1)` when `input.dim() >= 2`, otherwise 1.
1928    In the 1-D case, note that when `input` has dim > 2, `weight` can be expanded
1929    to the shape of `input` in a way that is not possible using normal
1930    :ref:`broadcasting semantics<broadcasting-semantics>`.
1931
1932See :class:`~torch.nn.PReLU` for more details.
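
Examples::

    >>> x = torch.tensor([-1.0, 2.0])
    >>> weight = torch.tensor([0.25])
    >>> F.prelu(x, weight)
    tensor([-0.2500,  2.0000])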
1933""",
1934)
1935
1936
1937def rrelu(
1938    input: Tensor,
1939    lower: float = 1.0 / 8,
1940    upper: float = 1.0 / 3,
1941    training: bool = False,
1942    inplace: bool = False,
1943) -> Tensor:  # noqa: D400,D402
1944    r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor
1945
1946    Randomized leaky ReLU.
1947
1948    See :class:`~torch.nn.RReLU` for more details.
1949    """
1950    if has_torch_function_unary(input):
1951        return handle_torch_function(
1952            rrelu,
1953            (input,),
1954            input,
1955            lower=lower,
1956            upper=upper,
1957            training=training,
1958            inplace=inplace,
1959        )
1960    if inplace:
1961        result = torch.rrelu_(input, lower, upper, training)
1962    else:
1963        result = torch.rrelu(input, lower, upper, training)
1964    return result
1965
1966
1967rrelu_ = _add_docstr(
1968    torch.rrelu_,
1969    r"""
1970rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor
1971
1972In-place version of :func:`~rrelu`.
1973""",
1974)
1975
1976logsigmoid = _add_docstr(
1977    torch._C._nn.log_sigmoid,
1978    r"""
1979logsigmoid(input) -> Tensor
1980
1981Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)`
1982
1983See :class:`~torch.nn.LogSigmoid` for more details.
1984""",
1985)
1986
1987gelu = _add_docstr(
1988    torch._C._nn.gelu,
1989    r"""
1990gelu(input, approximate = 'none') -> Tensor
1991
1992When the approximate argument is 'none', it applies element-wise the function
1993:math:`\text{GELU}(x) = x * \Phi(x)`
1994
1995where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
1996
1997When the approximate argument is 'tanh', GELU is estimated with
1998
1999.. math::
2000    \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
2001
2002See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
2003""",
2004)
2005
2006hardshrink = _add_docstr(
2007    torch.hardshrink,
2008    r"""
2009hardshrink(input, lambd=0.5) -> Tensor
2010
2011Applies the hard shrinkage function element-wise
2012
2013See :class:`~torch.nn.Hardshrink` for more details.
2014""",
2015)
2016
2017
2018def tanhshrink(input):  # noqa: D400,D402
2019    r"""tanhshrink(input) -> Tensor
2020
2021    Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)`
2022
2023    See :class:`~torch.nn.Tanhshrink` for more details.
2024    """
2025    if has_torch_function_unary(input):
2026        return handle_torch_function(tanhshrink, (input,), input)
2027    return input - input.tanh()
2028
2029
2030def softsign(input):  # noqa: D400,D402
2031    r"""softsign(input) -> Tensor
2032
2033    Applies element-wise, the function :math:`\text{SoftSign}(x) = \frac{x}{1 + |x|}`
2034
2035    See :class:`~torch.nn.Softsign` for more details.
2036    """
2037    if has_torch_function_unary(input):
2038        return handle_torch_function(softsign, (input,), input)
2039    return input / (input.abs() + 1)
2040
2041
2042softplus = _add_docstr(
2043    torch._C._nn.softplus,
2044    r"""
2045softplus(input, beta=1, threshold=20) -> Tensor
2046
2047Applies element-wise, the function :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`.
2048
2049For numerical stability the implementation reverts to the linear function
2050when :math:`input \times \beta > threshold`.
2051
2052See :class:`~torch.nn.Softplus` for more details.
2053""",
2054)
2055
2056
2057def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int:
2058    warnings.warn(
2059        f"Implicit dimension choice for {name} has been deprecated. "
2060        "Change the call to include dim=X as an argument.",
2061        stacklevel=stacklevel,
2062    )
2063    if ndim == 0 or ndim == 1 or ndim == 3:
2064        ret = 0
2065    else:
2066        ret = 1
2067    return ret
2068
2069
2070def softmin(
2071    input: Tensor,
2072    dim: Optional[int] = None,
2073    _stacklevel: int = 3,
2074    dtype: Optional[DType] = None,
2075) -> Tensor:
2076    r"""Apply a softmin function.
2077
2078    Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula.
2079
2080    See :class:`~torch.nn.Softmin` for more details.
2081
2082    Args:
2083        input (Tensor): input
2084        dim (int): A dimension along which softmin will be computed (so every slice
2085            along dim will sum to 1).
2086        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
2087          If specified, the input tensor is cast to :attr:`dtype` before the operation
2088          is performed. This is useful for preventing data type overflows. Default: None.
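
    Example, checking the identity :math:`\text{Softmin}(x) = \text{Softmax}(-x)`::

        >>> x = torch.randn(2, 3)
        >>> torch.allclose(F.softmin(x, dim=1), F.softmax(-x, dim=1))
        True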
2089    """
2090    if has_torch_function_unary(input):
2091        return handle_torch_function(
2092            softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
2093        )
2094    if dim is None:
2095        dim = _get_softmax_dim("softmin", input.dim(), _stacklevel)
2096    if dtype is None:
2097        ret = (-input).softmax(dim)
2098    else:
2099        ret = (-input).softmax(dim, dtype=dtype)
2100    return ret
2101
2102
2103def softmax(
2104    input: Tensor,
2105    dim: Optional[int] = None,
2106    _stacklevel: int = 3,
2107    dtype: Optional[DType] = None,
2108) -> Tensor:
2109    r"""Apply a softmax function.
2110
2111    Softmax is defined as:
2112
2113    :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`
2114
2115    It is applied to all slices along dim, and will re-scale them so that the elements
2116    lie in the range `[0, 1]` and sum to 1.
2117
2118    See :class:`~torch.nn.Softmax` for more details.
2119
2120    Args:
2121        input (Tensor): input
2122        dim (int): A dimension along which softmax will be computed.
2123        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
2124          If specified, the input tensor is cast to :attr:`dtype` before the operation
2125          is performed. This is useful for preventing data type overflows. Default: None.
2126
2127    .. note::
2128        This function doesn't work directly with NLLLoss,
2129        which expects log-probabilities as input.
2130        Use log_softmax instead (it's faster and has better numerical properties).
2131
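    Example (each slice along :attr:`dim` sums to 1)::

        >>> x = torch.randn(2, 3)
        >>> probs = F.softmax(x, dim=1)
        >>> torch.allclose(probs.sum(dim=1), torch.ones(2))
        True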
2132    """
2133    if has_torch_function_unary(input):
2134        return handle_torch_function(
2135            softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
2136        )
2137    if dim is None:
2138        dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
2139    if dtype is None:
2140        ret = input.softmax(dim)
2141    else:
2142        ret = input.softmax(dim, dtype=dtype)
2143    return ret
2144
2145
2146def gumbel_softmax(
2147    logits: Tensor,
2148    tau: float = 1,
2149    hard: bool = False,
2150    eps: float = 1e-10,
2151    dim: int = -1,
2152) -> Tensor:
2153    r"""
2154    Sample from the Gumbel-Softmax distribution (`Link 1`_  `Link 2`_) and optionally discretize.
2155
2156    Args:
2157      logits: `[..., num_features]` unnormalized log probabilities
2158      tau: non-negative scalar temperature
2159      hard: if ``True``, the returned samples will be discretized as one-hot vectors,
2160            but will be differentiated as if they were the soft samples in autograd
2161      dim (int): A dimension along which softmax will be computed. Default: -1.
2162
2163    Returns:
2164      Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
2165      If ``hard=True``, the returned samples will be one-hot, otherwise they will
2166      be probability distributions that sum to 1 across `dim`.
2167
2168    .. note::
2169      This function is here for legacy reasons and may be removed from ``nn.functional`` in the future.
2170
2171    .. note::
2172      The main trick for `hard` is to do ``y_hard - y_soft.detach() + y_soft``.
2173
2174      It achieves two things:
2175      - makes the output value exactly one-hot
2176        (since we add and then subtract the ``y_soft`` value)
2177      - makes the gradient equal to the ``y_soft`` gradient
2178        (since we strip all other gradients)
2179
2180    Examples::
2181        >>> logits = torch.randn(20, 32)
2182        >>> # Sample soft categorical using reparametrization trick:
2183        >>> F.gumbel_softmax(logits, tau=1, hard=False)
2184        >>> # Sample hard categorical using "Straight-through" trick:
2185        >>> F.gumbel_softmax(logits, tau=1, hard=True)
2186
2187    .. _Link 1:
2188        https://arxiv.org/abs/1611.00712
2189    .. _Link 2:
2190        https://arxiv.org/abs/1611.01144
2191    """
2192    if has_torch_function_unary(logits):
2193        return handle_torch_function(
2194            gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim
2195        )
2196    if eps != 1e-10:
2197        warnings.warn("`eps` parameter is deprecated and has no effect.")
2198
2199    gumbels = (
2200        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
2201        .exponential_()
2202        .log()
2203    )  # ~Gumbel(0,1)
2204    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
2205    y_soft = gumbels.softmax(dim)
2206
2207    if hard:
2208        # Straight through.
2209        index = y_soft.max(dim, keepdim=True)[1]
2210        y_hard = torch.zeros_like(
2211            logits, memory_format=torch.legacy_contiguous_format
2212        ).scatter_(dim, index, 1.0)
2213        ret = y_hard - y_soft.detach() + y_soft
2214    else:
2215        # Reparametrization trick.
2216        ret = y_soft
2217    return ret
2218
2219
2220def log_softmax(
2221    input: Tensor,
2222    dim: Optional[int] = None,
2223    _stacklevel: int = 3,
2224    dtype: Optional[DType] = None,
2225) -> Tensor:
2226    r"""Apply a softmax followed by a logarithm.
2227
2228    While mathematically equivalent to log(softmax(x)), doing these two
2229    operations separately is slower and numerically unstable. This function
2230    uses an alternative formulation to compute the output and gradient correctly.
2231
2232    See :class:`~torch.nn.LogSoftmax` for more details.
2233
2234    Args:
2235        input (Tensor): input
2236        dim (int): A dimension along which log_softmax will be computed.
2237        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
2238          If specified, the input tensor is cast to :attr:`dtype` before the operation
2239          is performed. This is useful for preventing data type overflows. Default: None.
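
    Example (agrees with ``softmax(...).log()`` up to numerical error)::

        >>> x = torch.randn(2, 3)
        >>> torch.allclose(F.log_softmax(x, dim=1), F.softmax(x, dim=1).log(), atol=1e-6)
        True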
2240    """
2241    if has_torch_function_unary(input):
2242        return handle_torch_function(
2243            log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
2244        )
2245    if dim is None:
2246        dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel)
2247    if dtype is None:
2248        ret = input.log_softmax(dim)
2249    else:
2250        ret = input.log_softmax(dim, dtype=dtype)
2251    return ret
2252
2253
2254softshrink = _add_docstr(
2255    torch._C._nn.softshrink,
2256    r"""
2257softshrink(input, lambd=0.5) -> Tensor
2258
2259Applies the soft shrinkage function elementwise
2260
2261See :class:`~torch.nn.Softshrink` for more details.
2262""",
2263)
2264
2265
2266def tanh(input):  # noqa: D400,D402
2267    r"""tanh(input) -> Tensor
2268
2269    Applies element-wise,
2270    :math:`\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}`
2271
2272    See :class:`~torch.nn.Tanh` for more details.
2273    """
2274    return input.tanh()
2275
2276
2277def sigmoid(input):  # noqa: D400,D402
2278    r"""sigmoid(input) -> Tensor
2279
2280    Applies the element-wise function :math:`\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}`
2281
2282    See :class:`~torch.nn.Sigmoid` for more details.
2283    """
2284    return input.sigmoid()
2285
2286
2287def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor:
2288    r"""Apply the Hardsigmoid function element-wise.
2289
2290    .. math::
2291        \text{Hardsigmoid}(x) = \begin{cases}
2292            0 & \text{if~} x \le -3, \\
2293            1 & \text{if~} x \ge +3, \\
2294            x / 6 + 1 / 2 & \text{otherwise}
2295        \end{cases}
2296
2297    Args:
2298        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
2299
2300    See :class:`~torch.nn.Hardsigmoid` for more details.
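
    Examples::

        >>> x = torch.tensor([-4.0, 0.0, 4.0])
        >>> F.hardsigmoid(x)
        tensor([0.0000, 0.5000, 1.0000])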
2301    """
2302    if has_torch_function_unary(input):
2303        return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace)
2304    if inplace:
2305        return torch._C._nn.hardsigmoid_(input)
2306    return torch._C._nn.hardsigmoid(input)
2307
2308
2309linear = _add_docstr(
2310    torch._C._nn.linear,
2311    r"""
2312linear(input, weight, bias=None) -> Tensor
2313
2314Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.
2315
2316This operation supports 2-D :attr:`weight` with :ref:`sparse layout<sparse-docs>`
2317
2318{sparse_beta_warning}
2319
2320This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
2321
2322Shape:
2323
2324    - Input: :math:`(*, in\_features)` where `*` means any number of
2325      additional dimensions, including none
2326    - Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)`
2327    - Bias: :math:`(out\_features)` or :math:`()`
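
Examples::

    >>> x = torch.randn(3, 4)
    >>> weight = torch.randn(2, 4)
    >>> F.linear(x, weight).shape
    torch.Size([3, 2])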
2328    - Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight
2329""".format(
2330        **sparse_support_notes
2331    ),
2332)
2333
2334
2335bilinear = _add_docstr(
2336    torch.bilinear,
2337    r"""
2338bilinear(input1, input2, weight, bias=None) -> Tensor
2339
2340Applies a bilinear transformation to the incoming data:
2341:math:`y = x_1^T A x_2 + b`
2342
2343Shape:
2344
2345    - input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}`
2346      and :math:`*` means any number of additional dimensions.
2347      All but the last dimension of the inputs should be the same.
2348    - input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`
2349    - weight: :math:`(\text{out\_features}, \text{in1\_features},
2350      \text{in2\_features})`
2351    - bias: :math:`(\text{out\_features})`
2352    - output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}`
2353      and all but the last dimension are the same shape as the input.
2354""",
2355)
2356
2357
2358def silu(input: Tensor, inplace: bool = False) -> Tensor:
2359    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise.
2360
2361    The SiLU function is also known as the swish function.
2362
2363    .. math::
2364        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
2365
2366    .. note::
2367        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
2368        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
2369        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
2370        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
2371        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
2372        where the SiLU was experimented with later.
2373
2374    See :class:`~torch.nn.SiLU` for more details.
2375    """
2376    if has_torch_function_unary(input):
2377        return handle_torch_function(silu, (input,), input, inplace=inplace)
2378    if inplace:
2379        return torch._C._nn.silu_(input)
2380    return torch._C._nn.silu(input)
2381
2382
2383def mish(input: Tensor, inplace: bool = False) -> Tensor:
2384    r"""Apply the Mish function, element-wise.
2385
2386    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
2387
2388    .. math::
2389        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
2390
2391    .. note::
2392        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
2393
2394    See :class:`~torch.nn.Mish` for more details.
2395    """
2396    if has_torch_function_unary(input):
2397        return handle_torch_function(mish, (input,), input, inplace=inplace)
2398    if inplace:
2399        return torch._C._nn.mish_(input)
2400    return torch._C._nn.mish(input)
2401
2402
2403def hardswish(input: Tensor, inplace: bool = False) -> Tensor:
2404    r"""Apply hardswish function, element-wise.
2405
2406    Follows implementation as described in the paper:
2407    `Searching for MobileNetV3`_.
2408
2409    .. math::
2410        \text{Hardswish}(x) = \begin{cases}
2411            0 & \text{if~} x \le -3, \\
2412            x & \text{if~} x \ge +3, \\
2413            x \cdot (x + 3) /6 & \text{otherwise}
2414        \end{cases}
2415
2416    See :class:`~torch.nn.Hardswish` for more details.
2417
2418    .. _`Searching for MobileNetV3`:
2419        https://arxiv.org/abs/1905.02244
2420    """
2421    if has_torch_function_unary(input):
2422        return handle_torch_function(hardswish, (input,), input, inplace=inplace)
2423    if inplace:
2424        return torch._C._nn.hardswish_(input)
2425    return torch._C._nn.hardswish(input)
2426
2427
2428def _no_grad_embedding_renorm_(
2429    weight: Tensor,
2430    input: Tensor,
2431    max_norm: float,
2432    norm_type: float,
2433) -> None:
2434    torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)
2435
2436
2437def embedding(
2438    input: Tensor,
2439    weight: Tensor,
2440    padding_idx: Optional[int] = None,
2441    max_norm: Optional[float] = None,
2442    norm_type: float = 2.0,
2443    scale_grad_by_freq: bool = False,
2444    sparse: bool = False,
2445) -> Tensor:
2446    r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary and size.
2447
2448    This module is often used to retrieve word embeddings using indices.
2449    The input to the module is a list of indices, and the embedding matrix,
2450    and the output is the corresponding word embeddings.
2451
2452    See :class:`torch.nn.Embedding` for more details.
2453
2454    .. note::
2455        Note that the analytical gradients of this function with respect to
2456        entries in :attr:`weight` at the row specified by :attr:`padding_idx`
2457        are expected to differ from the numerical ones.
2458
2459    .. note::
2460        Note that :class:`torch.nn.Embedding` differs from this function in
2461        that it initializes the row of :attr:`weight` specified by
2462        :attr:`padding_idx` to all zeros on construction.
2463
2464    Args:
2465        input (LongTensor): Tensor containing indices into the embedding matrix
2466        weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
2467            and number of columns equal to the embedding size
2468        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
2469                                     therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
2470                                     i.e. it remains as a fixed "pad".
2471        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
2472                                    is renormalized to have norm :attr:`max_norm`.
2473                                    Note: this will modify :attr:`weight` in-place.
2474        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
2475        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
2476                                                the words in the mini-batch. Default ``False``.
2477        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
2478                                 :class:`torch.nn.Embedding` for more details regarding sparse gradients.
2479
2480    Shape:
2481        - Input: LongTensor of arbitrary shape containing the indices to extract
2482        - Weight: Embedding matrix of floating point type with shape `(V, embedding_dim)`,
2483          where V = maximum index + 1 and embedding_dim = the embedding size
2484        - Output: `(*, embedding_dim)`, where `*` is the input shape
2485
2486    Examples::
2487
2488        >>> # a batch of 2 samples of 4 indices each
2489        >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
2490        >>> # an embedding matrix containing 10 tensors of size 3
2491        >>> embedding_matrix = torch.rand(10, 3)
2492        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
2493        >>> F.embedding(input, embedding_matrix)
2494        tensor([[[ 0.8490,  0.9625,  0.6753],
2495                 [ 0.9666,  0.7761,  0.6108],
2496                 [ 0.6246,  0.9751,  0.3618],
2497                 [ 0.4161,  0.2419,  0.7383]],
2498
2499                [[ 0.6246,  0.9751,  0.3618],
2500                 [ 0.0237,  0.7794,  0.0528],
2501                 [ 0.9666,  0.7761,  0.6108],
2502                 [ 0.3385,  0.8612,  0.1867]]])
2503
2504        >>> # example with padding_idx
2505        >>> weights = torch.rand(10, 3)
2506        >>> weights[0, :].zero_()
2507        >>> embedding_matrix = weights
2508        >>> input = torch.tensor([[0, 2, 0, 5]])
2509        >>> F.embedding(input, embedding_matrix, padding_idx=0)
2510        tensor([[[ 0.0000,  0.0000,  0.0000],
2511                 [ 0.5609,  0.5384,  0.8720],
2512                 [ 0.0000,  0.0000,  0.0000],
2513                 [ 0.6262,  0.2438,  0.7471]]])
2514    """
2515    if has_torch_function_variadic(input, weight):
2516        return handle_torch_function(
2517            embedding,
2518            (input, weight),
2519            input,
2520            weight,
2521            padding_idx=padding_idx,
2522            max_norm=max_norm,
2523            norm_type=norm_type,
2524            scale_grad_by_freq=scale_grad_by_freq,
2525            sparse=sparse,
2526        )
2527    if padding_idx is not None:
2528        if padding_idx > 0:
2529            assert padding_idx < weight.size(
2530                0
2531            ), "Padding_idx must be within num_embeddings"
2532        elif padding_idx < 0:
2533            assert padding_idx >= -weight.size(
2534                0
2535            ), "Padding_idx must be within num_embeddings"
2536            padding_idx = weight.size(0) + padding_idx
2537    else:
2538        padding_idx = -1
2539    if max_norm is not None:
2540        # Note [embedding_renorm contiguous]
2541        # `embedding_renorm_` will call .contiguous() on input anyways, so we
2542        # call it here and take advantage of the improved locality in the
2543        # `embedding` call below too.
2544        input = input.contiguous()
2545        # Note [embedding_renorm set_grad_enabled]
2546        # XXX: equivalent to
2547        # with torch.no_grad():
2548        #   torch.embedding_renorm_
2549        # remove once script supports set_grad_enabled
2550        _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
2551    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
2552
2553
2554def embedding_bag(
2555    input: Tensor,
2556    weight: Tensor,
2557    offsets: Optional[Tensor] = None,
2558    max_norm: Optional[float] = None,
2559    norm_type: float = 2,
2560    scale_grad_by_freq: bool = False,
2561    mode: str = "mean",
2562    sparse: bool = False,
2563    per_sample_weights: Optional[Tensor] = None,
2564    include_last_offset: bool = False,
2565    padding_idx: Optional[int] = None,
2566) -> Tensor:
2567    r"""Compute sums, means or maxes of `bags` of embeddings.
2568
2569    Calculation is done without instantiating the intermediate embeddings.
2570    See :class:`torch.nn.EmbeddingBag` for more details.
2571
2572    Note:
2573        {backward_reproducibility_note}
2574
2575    Args:
2576        input (LongTensor): Tensor containing bags of indices into the embedding matrix
2577        weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
2578            and number of columns equal to the embedding size
2579        offsets (LongTensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
2580                             the starting index position of each bag (sequence) in :attr:`input`.
2581        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
2582                                    is renormalized to have norm :attr:`max_norm`.
2583                                    Note: this will modify :attr:`weight` in-place.
2584        norm_type (float, optional): The ``p`` in the ``p``-norm to compute for the :attr:`max_norm` option.
2585                                     Default ``2``.
2586        scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of
2587                                                the words in the mini-batch. Default ``False``.
2588                                                Note: this option is not supported when ``mode="max"``.
2589        mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
2590                                 Default: ``"mean"``
2591        sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
2592                                 :class:`torch.nn.Embedding` for more details regarding sparse gradients.
2593                                 Note: this option is not supported when ``mode="max"``.
2594        per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
2595            to indicate all weights should be taken to be 1. If specified, :attr:`per_sample_weights`
2596            must have exactly the same shape as input and is treated as having the same
2597            :attr:`offsets`, if those are not None.
2598
2599        include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1.
2600            The last element is the size of the input, or the ending index position of the last bag (sequence).
2601
2602        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the
2603                                     gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated
2604                                     during training, i.e. it remains as a fixed "pad". Note that the embedding
2605                                     vector at :attr:`padding_idx` is excluded from the reduction.
2606
2607    Shape:
2608        - :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional)
2609
2610          - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences)
2611            each of fixed length ``N``, and this will return ``B`` values aggregated in a way
2612            depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case.
2613
2614          - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of
2615            multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing
2616            the starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets`
2617            of shape `(B)`, :attr:`input` will be viewed as having ``B`` bags.
2618            Empty bags (i.e., having 0-length) will have returned vectors filled by zeros.
2619
2620        - :attr:`weight` (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
2621
2622        - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as :attr:`input`.
2623
2624        - :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)`
2625
2626    Examples::
2627
2628        >>> # an Embedding module containing 10 tensors of size 3
2629        >>> embedding_matrix = torch.rand(10, 3)
2630        >>> # a batch of 2 samples of 4 indices each
2631        >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
2632        >>> offsets = torch.tensor([0, 4])
2633        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
2634        >>> F.embedding_bag(input, embedding_matrix, offsets)
2635        tensor([[ 0.3397,  0.3552,  0.5545],
2636                [ 0.5893,  0.4386,  0.5882]])
2637
2638        >>> # example with padding_idx
2639        >>> embedding_matrix = torch.rand(10, 3)
2640        >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9])
2641        >>> offsets = torch.tensor([0, 4])
2642        >>> F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum')
2643        tensor([[ 0.0000,  0.0000,  0.0000],
2644                [-0.7082,  3.2145, -2.6251]])
2645    """
2646    if has_torch_function_variadic(input, weight, offsets, per_sample_weights):
2647        return handle_torch_function(
2648            embedding_bag,
2649            (input, weight, offsets, per_sample_weights),
2650            input,
2651            weight,
2652            offsets=offsets,
2653            max_norm=max_norm,
2654            norm_type=norm_type,
2655            scale_grad_by_freq=scale_grad_by_freq,
2656            mode=mode,
2657            sparse=sparse,
2658            per_sample_weights=per_sample_weights,
2659            include_last_offset=include_last_offset,
2660            padding_idx=padding_idx,
2661        )
2662    # Check for backward compatibility.
2663    # Used to be embedding_bag(weight, input, ...)
2664    # Now is     embedding_bag(input, weight, ...)
2665    if weight.dtype == torch.long and input.is_floating_point():
2666        warnings.warn(
2667            "Argument order of nn.functional.embedding_bag was changed. "
2668            "Usage `embedding_bag(weight, input, ...)` is deprecated, "
2669            "and should now be `embedding_bag(input, weight, ...)`."
2670        )
2671        weight, input = input, weight
2672
2673    if per_sample_weights is not None and input.size() != per_sample_weights.size():
2674        raise ValueError(
2675            f"embedding_bag: If per_sample_weights ({per_sample_weights.shape}) is not None, "
2676            f"then it must have the same shape as the input ({input.shape})"
2677        )
2678
2679    if weight.dim() != 2:
2680        raise ValueError(
2681            f"weight has to be a 2D Tensor, but got Tensor of dimension {weight.dim()}"
2682        )
2683
2684    if input.dim() == 2:
2685        if offsets is not None:
2686            type_str = "<unknown>"
2687            # TODO: Remove this once script supports type() calls
2688            if not torch.jit.is_scripting():
2689                type_str = str(type(offsets))
2690            raise ValueError(
2691                "if input is 2D, then offsets has to be None"
2692                ", as input is treated as a mini-batch of"
2693                " fixed length sequences. However, found "
2694                f"offsets of type {type_str}"
2695            )
2696        offsets = torch.arange(
2697            0, input.numel(), input.size(1), dtype=input.dtype, device=input.device
2698        )
2699
2700        input = input.reshape(-1)
2701        if per_sample_weights is not None:
2702            per_sample_weights = per_sample_weights.reshape(-1)
2703    elif input.dim() == 1:
2704        if offsets is None:
2705            raise ValueError("offsets has to be a 1D Tensor but got None")
2706        if offsets.dim() != 1:
2707            raise ValueError("offsets has to be a 1D Tensor")
2708    else:
2709        raise ValueError(
2710            f"input has to be 1D or 2D Tensor, but got Tensor of dimension {input.dim()}"
2711        )
2712    if mode == "sum":
2713        mode_enum = 0
2714    elif mode == "mean":
2715        mode_enum = 1
2716    elif mode == "max":
2717        mode_enum = 2
2718
2719        if scale_grad_by_freq:
2720            raise ValueError(
2721                "max mode does not support scaling the gradient by the frequency"
2722            )
2723
2724        if sparse:
2725            raise ValueError("max mode does not support sparse weights")
2726
2727    else:
2728        raise ValueError("mode has to be one of sum, mean or max")
2729
2730    if max_norm is not None:
2731        # XXX: equivalent to
2732        # with torch.no_grad():
2733        #   torch.embedding_renorm_
2734        # remove once script supports set_grad_enabled
2735        _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
2736
2737    if per_sample_weights is not None and mode != "sum":
2738        raise NotImplementedError(
2739            "embedding_bag: per_sample_weights was not None. "
2740            "per_sample_weights is only supported for mode='sum' "
2741            f"(got mode='{mode}'). Please open a feature request on GitHub."
2742        )
2743
2744    ret, _, _, _ = torch.embedding_bag(
2745        weight,
2746        input,
2747        offsets,
2748        scale_grad_by_freq,
2749        mode_enum,
2750        sparse,
2751        per_sample_weights,
2752        include_last_offset,
2753        padding_idx,
2754    )
2755    return ret
2756
2757
2758if embedding_bag.__doc__:
2759    embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes)
2760
2761
2762def _verify_batch_size(size: List[int]) -> None:
2763    # XXX: JIT script does not support the reduce from functools, and mul op is a
2764    # builtin, which cannot be used as a value to a func yet, so rewrite this size
2765    # check to a simple equivalent for loop
2766    #
2767    # TODO: make use of reduce like below when JIT is ready with the missing features:
2768    # from operator import mul
2769    # from functools import reduce
2770    #
2771    #   if reduce(mul, size[2:], size[0]) == 1
2772    size_prods = size[0]
2773    for i in range(len(size) - 2):
2774        size_prods *= size[i + 2]
2775    if size_prods == 1:
2776        raise ValueError(
2777            f"Expected more than 1 value per channel when training, got input size {size}"
2778        )
2779
2780
2781def batch_norm(
2782    input: Tensor,
2783    running_mean: Optional[Tensor],
2784    running_var: Optional[Tensor],
2785    weight: Optional[Tensor] = None,
2786    bias: Optional[Tensor] = None,
2787    training: bool = False,
2788    momentum: float = 0.1,
2789    eps: float = 1e-5,
2790) -> Tensor:
2791    r"""Apply Batch Normalization for each channel across a batch of data.
2792
2793    See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
2794    :class:`~torch.nn.BatchNorm3d` for details.
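
    A minimal example with freshly initialized running statistics::

        >>> x = torch.randn(20, 5)
        >>> running_mean, running_var = torch.zeros(5), torch.ones(5)
        >>> out = F.batch_norm(x, running_mean, running_var, training=True)
        >>> out.shape
        torch.Size([20, 5])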
2795    """
2796    if has_torch_function_variadic(input, running_mean, running_var, weight, bias):
2797        return handle_torch_function(
2798            batch_norm,
2799            (input, running_mean, running_var, weight, bias),
2800            input,
2801            running_mean,
2802            running_var,
2803            weight=weight,
2804            bias=bias,
2805            training=training,
2806            momentum=momentum,
2807            eps=eps,
2808        )
2809    if training:
2810        _verify_batch_size(input.size())
2811
2812    return torch.batch_norm(
2813        input,
2814        weight,
2815        bias,
2816        running_mean,
2817        running_var,
2818        training,
2819        momentum,
2820        eps,
2821        torch.backends.cudnn.enabled,
2822    )
2823
2824
2825def _verify_spatial_size(size: List[int]) -> None:
2826    # Verify that there is > 1 spatial element for instance norm calculation.
2827    size_prods = 1
2828    for i in range(2, len(size)):
2829        size_prods *= size[i]
2830    if size_prods == 1:
2831        raise ValueError(
2832            f"Expected more than 1 spatial element when training, got input size {size}"
2833        )
2834
2835
2836def instance_norm(
2837    input: Tensor,
2838    running_mean: Optional[Tensor] = None,
2839    running_var: Optional[Tensor] = None,
2840    weight: Optional[Tensor] = None,
2841    bias: Optional[Tensor] = None,
2842    use_input_stats: bool = True,
2843    momentum: float = 0.1,
2844    eps: float = 1e-5,
2845) -> Tensor:
2846    r"""Apply Instance Normalization independently for each channel in every data sample within a batch.
2847
2848    See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`,
2849    :class:`~torch.nn.InstanceNorm3d` for details.
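
    Example::

        >>> x = torch.randn(2, 3, 8, 8)
        >>> F.instance_norm(x).shape
        torch.Size([2, 3, 8, 8])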
2850    """
2851    if has_torch_function_variadic(input, running_mean, running_var, weight, bias):
2852        return handle_torch_function(
2853            instance_norm,
2854            (input, running_mean, running_var, weight, bias),
2855            input,
2856            running_mean=running_mean,
2857            running_var=running_var,
2858            weight=weight,
2859            bias=bias,
2860            use_input_stats=use_input_stats,
2861            momentum=momentum,
2862            eps=eps,
2863        )
2864    if use_input_stats:
2865        _verify_spatial_size(input.size())
2866    return torch.instance_norm(
2867        input,
2868        weight,
2869        bias,
2870        running_mean,
2871        running_var,
2872        use_input_stats,
2873        momentum,
2874        eps,
2875        torch.backends.cudnn.enabled,
2876    )
2877
2878
2879def layer_norm(
2880    input: Tensor,
2881    normalized_shape: List[int],
2882    weight: Optional[Tensor] = None,
2883    bias: Optional[Tensor] = None,
2884    eps: float = 1e-5,
2885) -> Tensor:
2886    r"""Apply Layer Normalization for last certain number of dimensions.
2887
2888    See :class:`~torch.nn.LayerNorm` for details.
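
    Example, normalizing over the last dimension::

        >>> x = torch.randn(2, 5, 10)
        >>> F.layer_norm(x, (10,)).shape
        torch.Size([2, 5, 10])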
2889    """
2890    if has_torch_function_variadic(input, weight, bias):
2891        return handle_torch_function(
2892            layer_norm,
2893            (input, weight, bias),
2894            input,
2895            normalized_shape,
2896            weight=weight,
2897            bias=bias,
2898            eps=eps,
2899        )
2900    return torch.layer_norm(
2901        input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled
2902    )
2903
2904
2905def rms_norm(
2906    input: Tensor,
2907    normalized_shape: List[int],
2908    weight: Optional[Tensor] = None,
2909    eps: Optional[float] = None,
2910) -> Tensor:
2911    r"""Apply Root Mean Square Layer Normalization.
2912
2913    See :class:`~torch.nn.RMSNorm` for details.
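
    Example::

        >>> x = torch.randn(2, 5, 10)
        >>> F.rms_norm(x, (10,)).shape
        torch.Size([2, 5, 10])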
2914    """
2915    if has_torch_function_variadic(input, weight):
2916        return handle_torch_function(
2917            rms_norm, (input, weight), input, normalized_shape, weight=weight, eps=eps
2918        )
2919    return torch.rms_norm(input, normalized_shape, weight, eps)
2920
2921
2922def group_norm(
2923    input: Tensor,
2924    num_groups: int,
2925    weight: Optional[Tensor] = None,
2926    bias: Optional[Tensor] = None,
2927    eps: float = 1e-5,
2928) -> Tensor:
2929    r"""Apply Group Normalization for last certain number of dimensions.
2930
2931    See :class:`~torch.nn.GroupNorm` for details.
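
    Example, with 6 channels split into 3 groups::

        >>> x = torch.randn(2, 6, 4, 4)
        >>> F.group_norm(x, num_groups=3).shape
        torch.Size([2, 6, 4, 4])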
2932    """
2933    if has_torch_function_variadic(input, weight, bias):
2934        return handle_torch_function(
2935            group_norm,
2936            (
2937                input,
2938                weight,
2939                bias,
2940            ),
2941            input,
2942            num_groups,
2943            weight=weight,
2944            bias=bias,
2945            eps=eps,
2946        )
2947    if input.dim() < 2:
2948        raise RuntimeError(
2949            f"Expected at least 2 dimensions for input tensor but received {input.dim()}"
2950        )
2951    _verify_batch_size(
2952        [input.size(0) * input.size(1) // num_groups, num_groups]
2953        + list(input.size()[2:])
2954    )
2955    return torch.group_norm(
2956        input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled
2957    )
2958
2959
2960def local_response_norm(
2961    input: Tensor,
2962    size: int,
2963    alpha: float = 1e-4,
2964    beta: float = 0.75,
2965    k: float = 1.0,
2966) -> Tensor:
2967    r"""Apply local response normalization over an input signal.
2968
2969    The input signal is composed of several input planes, where channels occupy the second dimension.
2970    Normalization is applied across channels.
2971
2972    See :class:`~torch.nn.LocalResponseNorm` for details.
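
    Example (a minimal sketch; the window size is illustrative)::

        >>> signal = torch.randn(16, 8, 32, 32)
        >>> output = F.local_response_norm(signal, size=5)  # normalize over 5 adjacent channels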
2973    """
2974    if has_torch_function_unary(input):
2975        return handle_torch_function(
2976            local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k
2977        )
2978    dim = input.dim()
2979    if dim < 3:
2980        raise ValueError(
2981            f"Expected 3D or higher dimensionality input (got {dim} dimensions)"
2982        )
2983
2984    if input.numel() == 0:
2985        return input
2986
2987    div = input.mul(input)
2988    if dim == 3:
2989        div = div.unsqueeze(1)
2990        div = pad(div, (0, 0, size // 2, (size - 1) // 2))
2991        div = avg_pool2d(div, (size, 1), stride=1).squeeze(1)
2992    else:
2993        sizes = input.size()
2994        div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
2995        div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
2996        div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
2997        div = div.view(sizes)
2998    div = div.mul(alpha).add(k).pow(beta)
2999    return input / div
3000
3001
3002# loss
3003
3004
3005def ctc_loss(
3006    log_probs: Tensor,
3007    targets: Tensor,
3008    input_lengths: Tensor,
3009    target_lengths: Tensor,
3010    blank: int = 0,
3011    reduction: str = "mean",
3012    zero_infinity: bool = False,
3013) -> Tensor:
3014    r"""Apply the Connectionist Temporal Classification loss.
3015
3016    See :class:`~torch.nn.CTCLoss` for details.
3017
3018    Note:
3019        {cudnn_reproducibility_note}
3020
3021    Note:
3022        {backward_reproducibility_note}
3023
3024    Args:
3025        log_probs: :math:`(T, N, C)` or :math:`(T, C)` where `C = number of characters in alphabet including blank`,
3026            `T = input length`, and `N = batch size`.
3027            The logarithmized probabilities of the outputs
3028            (e.g. obtained with :func:`torch.nn.functional.log_softmax`).
3029        targets: :math:`(N, S)` or :math:`(\text{sum}(\text{target\_lengths}))`.
3030            Targets cannot be blank. In the second form, the targets are assumed to be concatenated.
3031        input_lengths: :math:`(N)` or :math:`()`.
3032            Lengths of the inputs (must each be :math:`\leq T`)
3033        target_lengths: :math:`(N)` or :math:`()`.
3034            Lengths of the targets
3035        blank (int, optional):
3036            Blank label. Default: ``0``.
3037        reduction (str, optional): Specifies the reduction to apply to the output:
3038            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3039            ``'mean'``: the output losses will be divided by the target lengths and
3040            then the mean over the batch is taken, ``'sum'``: the output will be
3041            summed. Default: ``'mean'``
3042        zero_infinity (bool, optional):
3043            Whether to zero infinite losses and the associated gradients.
3044            Infinite losses mainly occur when the inputs are too short
3045            to be aligned to the targets.
3046            Default: ``False``
3047
3048    Example::
3049
3050        >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
3051        >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
3052        >>> input_lengths = torch.full((16,), 50, dtype=torch.long)
3053        >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
3054        >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
3055        >>> loss.backward()
3056    """
3057    if has_torch_function_variadic(log_probs, targets, input_lengths, target_lengths):
3058        return handle_torch_function(
3059            ctc_loss,
3060            (log_probs, targets, input_lengths, target_lengths),
3061            log_probs,
3062            targets,
3063            input_lengths,
3064            target_lengths,
3065            blank=blank,
3066            reduction=reduction,
3067            zero_infinity=zero_infinity,
3068        )
3069    return torch.ctc_loss(
3070        log_probs,
3071        targets,
3072        input_lengths,
3073        target_lengths,
3074        blank,
3075        _Reduction.get_enum(reduction),
3076        zero_infinity,
3077    )
3078
3079
3080if ctc_loss.__doc__:
3081    ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes)
3082
3083
3084def nll_loss(
3085    input: Tensor,
3086    target: Tensor,
3087    weight: Optional[Tensor] = None,
3088    size_average: Optional[bool] = None,
3089    ignore_index: int = -100,
3090    reduce: Optional[bool] = None,
3091    reduction: str = "mean",
3092) -> Tensor:
3093    r"""Compute the negative log likelihood loss.
3094
3095    See :class:`~torch.nn.NLLLoss` for details.
3096
3097    Args:
3098        input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
3099            in the case of 2D loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \geq 1`
3100            in the case of K-dimensional loss. `input` is expected to be log-probabilities.
3101        target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`,
3102            or :math:`(N, d_1, d_2, ..., d_K)` where :math:`K \geq 1` for
3103            K-dimensional loss.
3104        weight (Tensor, optional): a manual rescaling weight given to each
3105            class. If given, has to be a Tensor of size `C`
3106        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
3107            the losses are averaged over each loss element in the batch. Note that for
3108            some losses, there are multiple elements per sample. If the field :attr:`size_average`
3109            is set to ``False``, the losses are instead summed for each minibatch. Ignored
3110            when reduce is ``False``. Default: ``True``
3111        ignore_index (int, optional): Specifies a target value that is ignored
3112            and does not contribute to the input gradient. When :attr:`size_average` is
3113            ``True``, the loss is averaged over non-ignored targets. Default: -100
3114        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
3115            losses are averaged or summed over observations for each minibatch depending
3116            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
3117            batch element instead and ignores :attr:`size_average`. Default: ``True``
3118        reduction (str, optional): Specifies the reduction to apply to the output:
3119            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3120            ``'mean'``: the sum of the output will be divided by the number of
3121            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
3122            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
3123            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
3124
3125    Example::
3126
3127        >>> # input is of size N x C = 3 x 5
3128        >>> input = torch.randn(3, 5, requires_grad=True)
3129        >>> # each element in target has to have 0 <= value < C
3130        >>> target = torch.tensor([1, 0, 4])
3131        >>> output = F.nll_loss(F.log_softmax(input, dim=1), target)
3132        >>> output.backward()
3133    """
3134    if has_torch_function_variadic(input, target, weight):
3135        return handle_torch_function(
3136            nll_loss,
3137            (input, target, weight),
3138            input,
3139            target,
3140            weight=weight,
3141            size_average=size_average,
3142            ignore_index=ignore_index,
3143            reduce=reduce,
3144            reduction=reduction,
3145        )
3146    if size_average is not None or reduce is not None:
3147        reduction = _Reduction.legacy_get_string(size_average, reduce)
3148    return torch._C._nn.nll_loss_nd(
3149        input, target, weight, _Reduction.get_enum(reduction), ignore_index
3150    )
3151
3152
3153def poisson_nll_loss(
3154    input: Tensor,
3155    target: Tensor,
3156    log_input: bool = True,
3157    full: bool = False,
3158    size_average: Optional[bool] = None,
3159    eps: float = 1e-8,
3160    reduce: Optional[bool] = None,
3161    reduction: str = "mean",
3162) -> Tensor:
3163    r"""Poisson negative log likelihood loss.
3164
3165    See :class:`~torch.nn.PoissonNLLLoss` for details.
3166
3167    Args:
3168        input: expectation of underlying Poisson distribution.
3169        target: random sample :math:`target \sim \text{Poisson}(input)`.
3170        log_input: if ``True`` the loss is computed as
3171            :math:`\exp(\text{input}) - \text{target} * \text{input}`, if ``False`` then loss is
3172            :math:`\text{input} - \text{target} * \log(\text{input}+\text{eps})`. Default: ``True``
3173        full: whether to compute the full loss, i.e. to add the Stirling approximation term
3174            :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`.
3175            Default: ``False``
3176        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
3177            the losses are averaged over each loss element in the batch. Note that for
3178            some losses, there are multiple elements per sample. If the field :attr:`size_average`
3179            is set to ``False``, the losses are instead summed for each minibatch. Ignored
3180            when reduce is ``False``. Default: ``True``
3181        eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when
3182            :attr:`log_input`\ =\ ``False``. Default: 1e-8
3183        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
3184            losses are averaged or summed over observations for each minibatch depending
3185            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
3186            batch element instead and ignores :attr:`size_average`. Default: ``True``
3187        reduction (str, optional): Specifies the reduction to apply to the output:
3188            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3189            ``'mean'``: the sum of the output will be divided by the number of
3190            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
3191            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
3192            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
3193
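    Example (a minimal sketch; shapes and rates are illustrative)::

        >>> input = torch.randn(5, 2, requires_grad=True)  # log-rates, since log_input=True
        >>> target = torch.poisson(torch.rand(5, 2) * 4)
        >>> loss = F.poisson_nll_loss(input, target)
        >>> loss.backward()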
3194    """
3195    if has_torch_function_variadic(input, target):
3196        return handle_torch_function(
3197            poisson_nll_loss,
3198            (input, target),
3199            input,
3200            target,
3201            log_input=log_input,
3202            full=full,
3203            size_average=size_average,
3204            eps=eps,
3205            reduce=reduce,
3206            reduction=reduction,
3207        )
3208    if size_average is not None or reduce is not None:
3209        reduction = _Reduction.legacy_get_string(size_average, reduce)
3210    if reduction != "none" and reduction != "mean" and reduction != "sum":
3211        ret = input  # dummy assignment; keeps `ret` defined on all paths for TorchScript
3212        raise ValueError(reduction + " is not a valid value for reduction")
3213
3214    ret = torch.poisson_nll_loss(
3215        input, target, log_input, full, eps, _Reduction.get_enum(reduction)
3216    )
3217    return ret
3218
3219
3220def gaussian_nll_loss(
3221    input: Tensor,
3222    target: Tensor,
3223    var: Tensor,
3224    full: bool = False,
3225    eps: float = 1e-6,
3226    reduction: str = "mean",
3227) -> Tensor:
3228    r"""Gaussian negative log likelihood loss.
3229
3230    See :class:`~torch.nn.GaussianNLLLoss` for details.
3231
3232    Args:
3233        input: expectation of the Gaussian distribution.
3234        target: sample from the Gaussian distribution.
3235        var: tensor of positive variance(s), one for each of the expectations
3236            in the input (heteroscedastic), or a single one (homoscedastic).
3237        full (bool, optional): include the constant term in the loss calculation. Default: ``False``.
3238        eps (float, optional): value added to var, for stability. Default: 1e-6.
3239        reduction (str, optional): specifies the reduction to apply to the output:
3240            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3241            ``'mean'``: the output is the average of all batch member losses,
3242            ``'sum'``: the output is the sum of all batch member losses.
3243            Default: ``'mean'``.
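
    Example (a minimal sketch; shapes are illustrative)::

        >>> input = torch.randn(5, 2, requires_grad=True)  # predicted means
        >>> target = torch.randn(5, 2)
        >>> var = torch.ones(5, 2)  # one positive variance per prediction (heteroscedastic)
        >>> loss = F.gaussian_nll_loss(input, target, var)
        >>> loss.backward()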
3244    """
3245    if has_torch_function_variadic(input, target, var):
3246        return handle_torch_function(
3247            gaussian_nll_loss,
3248            (input, target, var),
3249            input,
3250            target,
3251            var,
3252            full=full,
3253            eps=eps,
3254            reduction=reduction,
3255        )
3256
3257    # Check var size
3258    # If var.size == input.size, the case is heteroscedastic and no further checks are needed.
3259    # Otherwise:
3260    if var.size() != input.size():
3261        # If var is one dimension short of input, but the sizes match otherwise, then this is a homoscedastic case.
3262        # e.g. input.size = (10, 2, 3), var.size = (10, 2)
3263        # -> unsqueeze var so that var.shape = (10, 2, 1)
3264        # this is done so that broadcasting can happen in the loss calculation
3265        if input.size()[:-1] == var.size():
3266            var = torch.unsqueeze(var, -1)
3267
3268        # This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1.
3269        # This is also a homoscedastic case.
3270        # e.g. input.size = (10, 2, 3), var.size = (10, 2, 1)
3271        elif (
3272            input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1
3273        ):  # Homoscedastic case
3274            pass
3275
3276        # If none of the above pass, then the size of var is incorrect.
3277        else:
3278            raise ValueError("var is of incorrect size")
3279
3280    # Check validity of reduction mode
3281    if reduction != "none" and reduction != "mean" and reduction != "sum":
3282        raise ValueError(reduction + " is not a valid value for reduction")
3283
3284    # Entries of var must be non-negative
3285    if torch.any(var < 0):
3286        raise ValueError("var has negative entry/entries")
3287
3288    # Clamp for stability
3289    var = var.clone()
3290    with torch.no_grad():
3291        var.clamp_(min=eps)
3292
3293    # Calculate the loss
3294    loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
3295    if full:
3296        loss += 0.5 * math.log(2 * math.pi)
3297
3298    if reduction == "mean":
3299        return loss.mean()
3300    elif reduction == "sum":
3301        return loss.sum()
3302    else:
3303        return loss
3304
3305
3306def kl_div(
3307    input: Tensor,
3308    target: Tensor,
3309    size_average: Optional[bool] = None,
3310    reduce: Optional[bool] = None,
3311    reduction: str = "mean",
3312    log_target: bool = False,
3313) -> Tensor:
3314    r"""Compute the KL Divergence loss.
3315
3316    See `Kullback-Leibler divergence
3317    <https://en.wikipedia.org/wiki/Kullback-Leibler_divergence>`__ for background.
3318
3319    See :class:`~torch.nn.KLDivLoss` for details.
3320
3321    Args:
3322        input: Tensor of arbitrary shape in log-probabilities.
3323        target: Tensor of the same shape as input. See :attr:`log_target` for
3324            the target's interpretation.
3325        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
3326            the losses are averaged over each loss element in the batch. Note that for
3327            some losses, there are multiple elements per sample. If the field :attr:`size_average`
3328            is set to ``False``, the losses are instead summed for each minibatch. Ignored
3329            when reduce is ``False``. Default: ``True``
3330        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
3331            losses are averaged or summed over observations for each minibatch depending
3332            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
3333            batch element instead and ignores :attr:`size_average`. Default: ``True``
3334        reduction (str, optional): Specifies the reduction to apply to the output:
3335            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
3336            ``'none'``: no reduction will be applied
3337            ``'batchmean'``: the sum of the output will be divided by the batchsize
3338            ``'sum'``: the output will be summed
3339            ``'mean'``: the output will be divided by the number of elements in the output
3340            Default: ``'mean'``
3341        log_target (bool): A flag indicating whether ``target`` is passed in the log space.
3342            It is recommended to pass probability distributions (e.g. the output of ``softmax``)
3343            in the log space to avoid numerical issues caused by an explicit ``log``.
3344            Default: ``False``
3345
3346    .. note::
3347        :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
3348        and in the meantime, specifying either of those two args will override :attr:`reduction`.
3349
3350    .. warning::
3351        :attr:`reduction` = ``'mean'`` doesn't return the true KL divergence value; please use
3352        :attr:`reduction` = ``'batchmean'``, which aligns with the mathematical definition of KL divergence.
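
    Example (a minimal sketch; shapes are illustrative)::

        >>> input = torch.randn(3, 5).log_softmax(dim=1)  # log-probabilities
        >>> target = torch.randn(3, 5).softmax(dim=1)  # probabilities
        >>> loss = F.kl_div(input, target, reduction='batchmean')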
3353    """
3354    if has_torch_function_variadic(input, target):
3355        return handle_torch_function(
3356            kl_div,
3357            (input, target),
3358            input,
3359            target,
3360            size_average=size_average,
3361            reduce=reduce,
3362            reduction=reduction,
3363            log_target=log_target,
3364        )
3365    if size_average is not None or reduce is not None:
3366        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
3367    else:
3368        if reduction == "mean":
3369            warnings.warn(
3370                "reduction: 'mean' divides the total loss by both the batch size and the support size. "
3371                "'batchmean' divides only by the batch size, and aligns with the KL div math definition. "
3372                "'mean' will be changed to behave the same as 'batchmean' in the next major release."
3373            )
3374
3375        # special case for batchmean
3376        if reduction == "batchmean":
3377            reduction_enum = _Reduction.get_enum("sum")
3378        else:
3379            reduction_enum = _Reduction.get_enum(reduction)
3380
3381    reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target)
3382
3383    if reduction == "batchmean" and input.dim() != 0:
3384        reduced = reduced / input.size()[0]
3385
3386    return reduced
3387
3388
3389def cross_entropy(
3390    input: Tensor,
3391    target: Tensor,
3392    weight: Optional[Tensor] = None,
3393    size_average: Optional[bool] = None,
3394    ignore_index: int = -100,
3395    reduce: Optional[bool] = None,
3396    reduction: str = "mean",
3397    label_smoothing: float = 0.0,
3398) -> Tensor:
3399    r"""Compute the cross entropy loss between input logits and target.
3400
3401    See :class:`~torch.nn.CrossEntropyLoss` for details.
3402
3403    Args:
3404        input (Tensor) : Predicted unnormalized logits;
3405            see Shape section below for supported shapes.
3406        target (Tensor) : Ground truth class indices or class probabilities;
3407            see Shape section below for supported shapes.
3408        weight (Tensor, optional): a manual rescaling weight given to each
3409            class. If given, has to be a Tensor of size `C`
3410        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
3411            the losses are averaged over each loss element in the batch. Note that for
3412            some losses, there are multiple elements per sample. If the field :attr:`size_average`
3413            is set to ``False``, the losses are instead summed for each minibatch. Ignored
3414            when reduce is ``False``. Default: ``True``
3415        ignore_index (int, optional): Specifies a target value that is ignored
3416            and does not contribute to the input gradient. When :attr:`size_average` is
3417            ``True``, the loss is averaged over non-ignored targets. Note that
3418            :attr:`ignore_index` is only applicable when the target contains class indices.
3419            Default: -100
3420        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
3421            losses are averaged or summed over observations for each minibatch depending
3422            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
3423            batch element instead and ignores :attr:`size_average`. Default: ``True``
3424        reduction (str, optional): Specifies the reduction to apply to the output:
3425            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3426            ``'mean'``: the sum of the output will be divided by the number of
3427            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
3428            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
3429            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
3430        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
3431            of smoothing when computing the loss, where 0.0 means no smoothing. The targets
3432            become a mixture of the original ground truth and a uniform distribution as described in
3433            `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.
3434
3435    Shape:
3436        - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
3437          in the case of `K`-dimensional loss.
3438        - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with
3439          :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`.
3440          If containing class probabilities, same shape as the input and each value should be between :math:`[0, 1]`.
3441
3442        where:
3443
3444        .. math::
3445            \begin{aligned}
3446                C ={} & \text{number of classes} \\
3447                N ={} & \text{batch size} \\
3448            \end{aligned}
3449
3450    Examples::
3451
3452        >>> # Example of target with class indices
3453        >>> input = torch.randn(3, 5, requires_grad=True)
3454        >>> target = torch.randint(5, (3,), dtype=torch.int64)
3455        >>> loss = F.cross_entropy(input, target)
3456        >>> loss.backward()
3457        >>>
3458        >>> # Example of target with class probabilities
3459        >>> input = torch.randn(3, 5, requires_grad=True)
3460        >>> target = torch.randn(3, 5).softmax(dim=1)
3461        >>> loss = F.cross_entropy(input, target)
3462        >>> loss.backward()
3463    """
3464    if has_torch_function_variadic(input, target, weight):
3465        return handle_torch_function(
3466            cross_entropy,
3467            (input, target, weight),
3468            input,
3469            target,
3470            weight=weight,
3471            size_average=size_average,
3472            ignore_index=ignore_index,
3473            reduce=reduce,
3474            reduction=reduction,
3475            label_smoothing=label_smoothing,
3476        )
3477    if size_average is not None or reduce is not None:
3478        reduction = _Reduction.legacy_get_string(size_average, reduce)
3479    return torch._C._nn.cross_entropy_loss(
3480        input,
3481        target,
3482        weight,
3483        _Reduction.get_enum(reduction),
3484        ignore_index,
3485        label_smoothing,
3486    )
3487
3488
3489def binary_cross_entropy(
3490    input: Tensor,
3491    target: Tensor,
3492    weight: Optional[Tensor] = None,
3493    size_average: Optional[bool] = None,
3494    reduce: Optional[bool] = None,
3495    reduction: str = "mean",
3496) -> Tensor:
3497    r"""Measure Binary Cross Entropy between the target and input probabilities.
3498
3499    See :class:`~torch.nn.BCELoss` for details.
3500
3501    Args:
3502        input: Tensor of arbitrary shape as probabilities.
3503        target: Tensor of the same shape as input with values between 0 and 1.
3504        weight (Tensor, optional): a manual rescaling weight
3505                if provided it's repeated to match input tensor shape
3506        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
3507            the losses are averaged over each loss element in the batch. Note that for
3508            some losses, there are multiple elements per sample. If the field :attr:`size_average`
3509            is set to ``False``, the losses are instead summed for each minibatch. Ignored
3510            when reduce is ``False``. Default: ``True``
3511        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
3512            losses are averaged or summed over observations for each minibatch depending
3513            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
3514            batch element instead and ignores :attr:`size_average`. Default: ``True``
3515        reduction (str, optional): Specifies the reduction to apply to the output:
3516            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3517            ``'mean'``: the sum of the output will be divided by the number of
3518            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
3519            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
3520            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
3521
3522    Examples::
3523
3524        >>> input = torch.randn(3, 2, requires_grad=True)
3525        >>> target = torch.rand(3, 2, requires_grad=False)
3526        >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target)
3527        >>> loss.backward()
3528    """
3529    if has_torch_function_variadic(input, target, weight):
3530        return handle_torch_function(
3531            binary_cross_entropy,
3532            (input, target, weight),
3533            input,
3534            target,
3535            weight=weight,
3536            size_average=size_average,
3537            reduce=reduce,
3538            reduction=reduction,
3539        )
3540    if size_average is not None or reduce is not None:
3541        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
3542    else:
3543        reduction_enum = _Reduction.get_enum(reduction)
3544    if target.size() != input.size():
3545        raise ValueError(
3546            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}) is not supported. "
3547            "Please ensure they have the same size."
3548        )
3549
3550    if weight is not None:
3551        new_size = _infer_size(target.size(), weight.size())
3552        weight = weight.expand(new_size)
3553
3554    return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
3555
3556
3557def binary_cross_entropy_with_logits(
3558    input: Tensor,
3559    target: Tensor,
3560    weight: Optional[Tensor] = None,
3561    size_average: Optional[bool] = None,
3562    reduce: Optional[bool] = None,
3563    reduction: str = "mean",
3564    pos_weight: Optional[Tensor] = None,
3565) -> Tensor:
3566    r"""Calculate Binary Cross Entropy between target and input logits.
3567
3568    See :class:`~torch.nn.BCEWithLogitsLoss` for details.
3569
3570    Args:
3571        input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits).
3572        target: Tensor of the same shape as input with values between 0 and 1
3573        weight (Tensor, optional): a manual rescaling weight
3574            if provided it's repeated to match input tensor shape
3575        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
3576            the losses are averaged over each loss element in the batch. Note that for
3577            some losses, there are multiple elements per sample. If the field :attr:`size_average`
3578            is set to ``False``, the losses are instead summed for each minibatch. Ignored
3579            when reduce is ``False``. Default: ``True``
3580        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
3581            losses are averaged or summed over observations for each minibatch depending
3582            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
3583            batch element instead and ignores :attr:`size_average`. Default: ``True``
3584        reduction (str, optional): Specifies the reduction to apply to the output:
3585            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
3586            ``'mean'``: the sum of the output will be divided by the number of
3587            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
3588            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
3589            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
3590        pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target.
3591            Must be a tensor with equal size along the class dimension to the number of classes.
3592            Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired
3593            operations. For a target of size [B, C, H, W] (where B is batch size), a pos_weight of
3594            size [B, C, H, W] will apply a different pos_weight to each element of the batch, while one of
3595            size [C, H, W] will apply the same pos_weight across the batch. To apply the same positive weight
3596            along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1].
3597            Default: ``None``
3598
3599    Examples::
3600
3601         >>> input = torch.randn(3, requires_grad=True)
3602         >>> target = torch.empty(3).random_(2)
3603         >>> loss = F.binary_cross_entropy_with_logits(input, target)
3604         >>> loss.backward()
3605    """
3606    if has_torch_function_variadic(input, target, weight, pos_weight):
3607        return handle_torch_function(
3608            binary_cross_entropy_with_logits,
3609            (input, target, weight, pos_weight),
3610            input,
3611            target,
3612            weight=weight,
3613            size_average=size_average,
3614            reduce=reduce,
3615            reduction=reduction,
3616            pos_weight=pos_weight,
3617        )
3618    if size_average is not None or reduce is not None:
3619        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
3620    else:
3621        reduction_enum = _Reduction.get_enum(reduction)
3622
3623    if not (target.size() == input.size()):
3624        raise ValueError(
3625            f"Target size ({target.size()}) must be the same as input size ({input.size()})"
3626        )
3627
3628    return torch.binary_cross_entropy_with_logits(
3629        input, target, weight, pos_weight, reduction_enum
3630    )
3631
3632
3633def smooth_l1_loss(
3634    input: Tensor,
3635    target: Tensor,
3636    size_average: Optional[bool] = None,
3637    reduce: Optional[bool] = None,
3638    reduction: str = "mean",
3639    beta: float = 1.0,
3640) -> Tensor:
3641    r"""Compute the Smooth L1 loss.
3642
3643    Function uses a squared term if the absolute
3644    element-wise error falls below beta and an L1 term otherwise.
3645
3646    See :class:`~torch.nn.SmoothL1Loss` for details.
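
    Example (a minimal sketch; shapes are illustrative)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> loss = F.smooth_l1_loss(input, target)
        >>> loss.backward()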
3647    """
3648    if has_torch_function_variadic(input, target):
3649        return handle_torch_function(
3650            smooth_l1_loss,
3651            (input, target),
3652            input,
3653            target,
3654            size_average=size_average,
3655            reduce=reduce,
3656            reduction=reduction,
3657            beta=beta,
3658        )
3659    if not (target.size() == input.size()):
3660        warnings.warn(
3661            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
3662            "This will likely lead to incorrect results due to broadcasting. "
3663            "Please ensure they have the same size.",
3664            stacklevel=2,
3665        )
3666    if size_average is not None or reduce is not None:
3667        reduction = _Reduction.legacy_get_string(size_average, reduce)
3668
3669    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
3670
3671    if beta == 0.0:
3672        return torch._C._nn.l1_loss(
3673            expanded_input, expanded_target, _Reduction.get_enum(reduction)
3674        )
3675    else:
3676        return torch._C._nn.smooth_l1_loss(
3677            expanded_input, expanded_target, _Reduction.get_enum(reduction), beta
3678        )
3679
3680
3681def huber_loss(
3682    input: Tensor,
3683    target: Tensor,
3684    reduction: str = "mean",
3685    delta: float = 1.0,
3686) -> Tensor:
3687    r"""Compute the Huber loss.
3688
3689    Function uses a squared term if the absolute
3690    element-wise error falls below delta and a delta-scaled L1 term otherwise.
3691
3692    When delta equals 1, this loss is equivalent to SmoothL1Loss.
3693    In general, Huber loss differs from SmoothL1Loss by a factor of delta (a.k.a. beta in Smooth L1).
3694
3695    See :class:`~torch.nn.HuberLoss` for details.
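
    Example (a minimal sketch; shapes and ``delta`` are illustrative)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> loss = F.huber_loss(input, target, delta=0.5)
        >>> loss.backward()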
3696    """
3697    if has_torch_function_variadic(input, target):
3698        return handle_torch_function(
3699            huber_loss,
3700            (input, target),
3701            input,
3702            target,
3703            reduction=reduction,
3704            delta=delta,
3705        )
3706    if not (target.size() == input.size()):
3707        warnings.warn(
3708            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
3709            "This will likely lead to incorrect results due to broadcasting. "
3710            "Please ensure they have the same size.",
3711            stacklevel=2,
3712        )
3713
3714    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
3715    return torch._C._nn.huber_loss(
3716        expanded_input, expanded_target, _Reduction.get_enum(reduction), delta
3717    )
3718
3719
3720def l1_loss(
3721    input: Tensor,
3722    target: Tensor,
3723    size_average: Optional[bool] = None,
3724    reduce: Optional[bool] = None,
3725    reduction: str = "mean",
3726) -> Tensor:  # noqa: D400,D402
3727    r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
3728
3729    Function that takes the mean element-wise absolute value difference.
3730
3731    See :class:`~torch.nn.L1Loss` for details.
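
    Example (a minimal sketch; shapes are illustrative)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> loss = F.l1_loss(input, target)
        >>> loss.backward()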
3732    """
3733    if has_torch_function_variadic(input, target):
3734        return handle_torch_function(
3735            l1_loss,
3736            (input, target),
3737            input,
3738            target,
3739            size_average=size_average,
3740            reduce=reduce,
3741            reduction=reduction,
3742        )
3743    if not (target.size() == input.size()):
3744        warnings.warn(
3745            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
3746            "This will likely lead to incorrect results due to broadcasting. "
3747            "Please ensure they have the same size.",
3748            stacklevel=2,
3749        )
3750    if size_average is not None or reduce is not None:
3751        reduction = _Reduction.legacy_get_string(size_average, reduce)
3752
3753    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
3754    return torch._C._nn.l1_loss(
3755        expanded_input, expanded_target, _Reduction.get_enum(reduction)
3756    )
3757
3758
3759def mse_loss(
3760    input: Tensor,
3761    target: Tensor,
3762    size_average: Optional[bool] = None,
3763    reduce: Optional[bool] = None,
3764    reduction: str = "mean",
3765) -> Tensor:  # noqa: D400,D402
3766    r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
3767
3768    Measures the element-wise mean squared error.
3769    See :class:`~torch.nn.MSELoss` for details.
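
    Example (a minimal sketch; shapes are illustrative)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5)
        >>> loss = F.mse_loss(input, target)
        >>> loss.backward()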
3770    """
3771    if has_torch_function_variadic(input, target):
3772        return handle_torch_function(
3773            mse_loss,
3774            (input, target),
3775            input,
3776            target,
3777            size_average=size_average,
3778            reduce=reduce,
3779            reduction=reduction,
3780        )
3781    if not (target.size() == input.size()):
3782        warnings.warn(
3783            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
3784            "This will likely lead to incorrect results due to broadcasting. "
3785            "Please ensure they have the same size.",
3786            stacklevel=2,
3787        )
3788    if size_average is not None or reduce is not None:
3789        reduction = _Reduction.legacy_get_string(size_average, reduce)
3790
3791    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
3792    return torch._C._nn.mse_loss(
3793        expanded_input, expanded_target, _Reduction.get_enum(reduction)
3794    )
3795
3796
3797def margin_ranking_loss(
3798    input1: Tensor,
3799    input2: Tensor,
3800    target: Tensor,
3801    margin: float = 0,
3802    size_average: Optional[bool] = None,
3803    reduce: Optional[bool] = None,
3804    reduction: str = "mean",
3805) -> Tensor:  # noqa: D400,D402
3806    r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor
3807
3808    See :class:`~torch.nn.MarginRankingLoss` for details.
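
    Example (a minimal sketch; shapes are illustrative)::

        >>> input1 = torch.randn(4, requires_grad=True)
        >>> input2 = torch.randn(4, requires_grad=True)
        >>> # target is 1 if input1 should be ranked higher than input2, -1 otherwise
        >>> target = torch.randint(0, 2, (4,)).float() * 2 - 1
        >>> loss = F.margin_ranking_loss(input1, input2, target)
        >>> loss.backward()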
3809    """
3810    if has_torch_function_variadic(input1, input2, target):
3811        return handle_torch_function(
3812            margin_ranking_loss,
3813            (input1, input2, target),
3814            input1,
3815            input2,
3816            target,
3817            margin=margin,
3818            size_average=size_average,
3819            reduce=reduce,
3820            reduction=reduction,
3821        )
3822    if size_average is not None or reduce is not None:
3823        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
3824    else:
3825        reduction_enum = _Reduction.get_enum(reduction)
3826    if input1.dim() != input2.dim() or input1.dim() != target.dim():
3827        raise RuntimeError(
3828            f"margin_ranking_loss: All input tensors should have the same number of dimensions, but got sizes: "
3829            f"input1: {input1.size()}, input2: {input2.size()}, target: {target.size()}"
3830        )
3831    return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)
3832
3833
3834def hinge_embedding_loss(
3835    input: Tensor,
3836    target: Tensor,
3837    margin: float = 1.0,
3838    size_average: Optional[bool] = None,
3839    reduce: Optional[bool] = None,
3840    reduction: str = "mean",
3841) -> Tensor:  # noqa: D400,D402
3842    r"""hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean') -> Tensor
3843
3844    See :class:`~torch.nn.HingeEmbeddingLoss` for details.
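
    Example (a minimal sketch; shapes are illustrative)::

        >>> input = torch.randn(4, requires_grad=True)
        >>> target = torch.tensor([1.0, -1.0, 1.0, -1.0])  # labels must be 1 or -1
        >>> loss = F.hinge_embedding_loss(input, target)
        >>> loss.backward()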
3845    """
3846    if has_torch_function_variadic(input, target):
3847        return handle_torch_function(
3848            hinge_embedding_loss,
3849            (input, target),
3850            input,
3851            target,
3852            margin=margin,
3853            size_average=size_average,
3854            reduce=reduce,
3855            reduction=reduction,
3856        )
3857    if size_average is not None or reduce is not None:
3858        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
3859    else:
3860        reduction_enum = _Reduction.get_enum(reduction)
3861    return torch.hinge_embedding_loss(input, target, margin, reduction_enum)
3862
3863
3864def multilabel_margin_loss(
3865    input: Tensor,
3866    target: Tensor,
3867    size_average: Optional[bool] = None,
3868    reduce: Optional[bool] = None,
3869    reduction: str = "mean",
3870) -> Tensor:  # noqa: D400,D402
3871    r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
3872
3873    See :class:`~torch.nn.MultiLabelMarginLoss` for details.
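
    Example (a minimal sketch; shapes and indices are illustrative)::

        >>> input = torch.randn(2, 4, requires_grad=True)
        >>> # per-sample target class indices, padded with -1
        >>> target = torch.tensor([[3, 0, -1, -1], [1, 2, 3, -1]])
        >>> loss = F.multilabel_margin_loss(input, target)
        >>> loss.backward()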
3874    """
3875    if has_torch_function_variadic(input, target):
3876        return handle_torch_function(
3877            multilabel_margin_loss,
3878            (input, target),
3879            input,
3880            target,
3881            size_average=size_average,
3882            reduce=reduce,
3883            reduction=reduction,
3884        )
3885    if size_average is not None or reduce is not None:
3886        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
3887    else:
3888        reduction_enum = _Reduction.get_enum(reduction)
3889    return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
3890
3891
3892def soft_margin_loss(
3893    input: Tensor,
3894    target: Tensor,
3895    size_average: Optional[bool] = None,
3896    reduce: Optional[bool] = None,
3897    reduction: str = "mean",
3898) -> Tensor:  # noqa: D400,D402
3899    r"""
3900    soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
3901
3902    See :class:`~torch.nn.SoftMarginLoss` for details.
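
    Example (a minimal sketch; shapes are illustrative)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randint(0, 2, (3, 5)).float() * 2 - 1  # targets in {-1, 1}
        >>> loss = F.soft_margin_loss(input, target)
        >>> loss.backward()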
3903    """
3904    if has_torch_function_variadic(input, target):
3905        return handle_torch_function(
3906            soft_margin_loss,
3907            (input, target),
3908            input,
3909            target,
3910            size_average=size_average,
3911            reduce=reduce,
3912            reduction=reduction,
3913        )
3914    if size_average is not None or reduce is not None:
3915        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
3916    else:
3917        reduction_enum = _Reduction.get_enum(reduction)
3918    return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
3919
3920
3921def multilabel_soft_margin_loss(
3922    input: Tensor,
3923    target: Tensor,
3924    weight: Optional[Tensor] = None,
3925    size_average: Optional[bool] = None,
3926    reduce: Optional[bool] = None,
3927    reduction: str = "mean",
3928) -> Tensor:  # noqa: D400,D402
3929    r"""multilabel_soft_margin_loss(input, target, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor
3930
3931    See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details.
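
    Example (a minimal sketch; shapes are illustrative)::

        >>> input = torch.randn(3, 4, requires_grad=True)
        >>> target = torch.randint(0, 2, (3, 4)).float()  # multi-hot labels in {0, 1}
        >>> loss = F.multilabel_soft_margin_loss(input, target)
        >>> loss.backward()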
3932    """
3933    if has_torch_function_variadic(input, target, weight):
3934        return handle_torch_function(
3935            multilabel_soft_margin_loss,
3936            (input, target, weight),
3937            input,
3938            target,
3939            weight=weight,
3940            size_average=size_average,
3941            reduce=reduce,
3942            reduction=reduction,
3943        )
3944    if size_average is not None or reduce is not None:
3945        reduction = _Reduction.legacy_get_string(size_average, reduce)
3946
3947    loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input))
3948
3949    if weight is not None:
3950        loss = loss * weight
3951
3952    class_dim = input.dim() - 1
3953    C = input.size(class_dim)
3954    loss = loss.sum(dim=class_dim) / C  # only return N loss values
3955
3956    if reduction == "none":
3957        ret = loss
3958    elif reduction == "mean":
3959        ret = loss.mean()
3960    elif reduction == "sum":
3961        ret = loss.sum()
3962    else:
3963        ret = input
3964        raise ValueError(reduction + " is not valid")
3965    return ret
3966
3967
3968def cosine_embedding_loss(
3969    input1: Tensor,
3970    input2: Tensor,
3971    target: Tensor,
3972    margin: float = 0,
3973    size_average: Optional[bool] = None,
3974    reduce: Optional[bool] = None,
3975    reduction: str = "mean",
3976) -> Tensor:  # noqa: D400,D402
3977    r"""cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor
3978
3979    See :class:`~torch.nn.CosineEmbeddingLoss` for details.
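
    Example (a minimal sketch; shapes are illustrative)::

        >>> input1 = torch.randn(3, 8, requires_grad=True)
        >>> input2 = torch.randn(3, 8, requires_grad=True)
        >>> target = torch.tensor([1.0, -1.0, 1.0])  # 1: similar pair, -1: dissimilar pair
        >>> loss = F.cosine_embedding_loss(input1, input2, target)
        >>> loss.backward()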
3980    """
3981    if has_torch_function_variadic(input1, input2, target):
3982        return handle_torch_function(
3983            cosine_embedding_loss,
3984            (input1, input2, target),
3985            input1,
3986            input2,
3987            target,
3988            margin=margin,
3989            size_average=size_average,
3990            reduce=reduce,
3991            reduction=reduction,
3992        )
3993    if size_average is not None or reduce is not None:
3994        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
3995    else:
3996        reduction_enum = _Reduction.get_enum(reduction)
3997    return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)
3998
3999
4000def multi_margin_loss(
4001    input: Tensor,
4002    target: Tensor,
4003    p: int = 1,
4004    margin: float = 1.0,
4005    weight: Optional[Tensor] = None,
4006    size_average: Optional[bool] = None,
4007    reduce: Optional[bool] = None,
4008    reduction: str = "mean",
4009) -> Tensor:  # noqa: D400,D402
4010    r"""multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor
4011
4012    See :class:`~torch.nn.MultiMarginLoss` for details.
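
    Example (a minimal sketch; shapes are illustrative)::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.tensor([1, 0, 4])  # one class index per sample
        >>> loss = F.multi_margin_loss(input, target)
        >>> loss.backward()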
4013    """
4014    if has_torch_function_variadic(input, target, weight):
4015        return handle_torch_function(
4016            multi_margin_loss,
4017            (input, target, weight),
4018            input,
4019            target,
4020            p=p,
4021            margin=margin,
4022            weight=weight,
4023            size_average=size_average,
4024            reduce=reduce,
4025            reduction=reduction,
4026        )
4027    if size_average is not None or reduce is not None:
4028        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
4029    else:
4030        reduction_enum = _Reduction.get_enum(reduction)
4031    if p != 1 and p != 2:
4032        raise ValueError("only p == 1 and p == 2 supported")
4033    if weight is not None:
4034        if weight.dim() != 1:
4035            raise ValueError("weight must be one-dimensional")
4036
4037    return torch._C._nn.multi_margin_loss(
4038        input, target, p, margin, weight, reduction_enum
4039    )
4040
4041
4042pixel_shuffle = _add_docstr(
4043    torch.pixel_shuffle,
4044    r"""
4045pixel_shuffle(input, upscale_factor) -> Tensor
4046
4047Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a
4048tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`.
4049
4050See :class:`~torch.nn.PixelShuffle` for details.
4051
4052Args:
4053    input (Tensor): the input tensor
4054    upscale_factor (int): factor to increase spatial resolution by
4055
4056Examples::
4057
4058    >>> input = torch.randn(1, 9, 4, 4)
4059    >>> output = torch.nn.functional.pixel_shuffle(input, 3)
4060    >>> print(output.size())
4061    torch.Size([1, 1, 12, 12])
4062""",
4063)
4064
4065pixel_unshuffle = _add_docstr(
4066    torch.pixel_unshuffle,
4067    r"""
4068pixel_unshuffle(input, downscale_factor) -> Tensor
4069
4070Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a
4071tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
4072:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`.
4073
4074See :class:`~torch.nn.PixelUnshuffle` for details.
4075
4076Args:
4077    input (Tensor): the input tensor
4078    downscale_factor (int): factor to decrease spatial resolution by
4079
4080Examples::
4081
4082    >>> input = torch.randn(1, 1, 12, 12)
4083    >>> output = torch.nn.functional.pixel_unshuffle(input, 3)
4084    >>> print(output.size())
4085    torch.Size([1, 9, 4, 4])
4086""",
4087)
4088
4089channel_shuffle = _add_docstr(
4090    torch.channel_shuffle,
4091    r"""
4092channel_shuffle(input, groups) -> Tensor
4093
4094Divide the channels in a tensor of shape :math:`(*, C, H, W)`
4095into g groups and rearrange them as :math:`(*, \frac{C}{g}, g, H, W)`,
4096while keeping the original tensor shape.
4097
4098See :class:`~torch.nn.ChannelShuffle` for details.
4099
4100Args:
4101    input (Tensor): the input tensor
4102    groups (int): number of groups to divide channels in and rearrange.
4103
4104Examples::
4105
4106    >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
4107    >>> print(input)
4108    [[[[1, 2],
4109       [3, 4]],
4110      [[5, 6],
4111       [7, 8]],
4112      [[9, 10],
4113       [11, 12]],
4114      [[13, 14],
4115       [15, 16]],
4116     ]]
4117    >>> output = torch.nn.functional.channel_shuffle(input, 2)
4118    >>> print(output)
4119    [[[[1, 2],
4120       [3, 4]],
4121      [[9, 10],
4122       [11, 12]],
4123      [[5, 6],
4124       [7, 8]],
4125      [[13, 14],
4126       [15, 16]],
4127     ]]
4128""",
4129)
4130
4131native_channel_shuffle = _add_docstr(
4132    torch.native_channel_shuffle,
4133    r"""
4134native_channel_shuffle(input, groups) -> Tensor
4135
4136Native, kernel-level implementation of ``channel_shuffle``.
4137This function might become private in future releases; use with caution.
4138
4139Divide the channels in a tensor of shape :math:`(*, C, H, W)`
4140into g groups and rearrange them as :math:`(*, \frac{C}{g}, g, H, W)`,
4141while keeping the original tensor shape.
4142
4143See :class:`~torch.nn.ChannelShuffle` for details.
4144
4145Args:
4146    input (Tensor): the input tensor
4147    groups (int): number of groups to divide channels in and rearrange.
4148
4149Examples::
4150
4151    >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
4152    >>> print(input)
4153    [[[[1, 2],
4154       [3, 4]],
4155      [[5, 6],
4156       [7, 8]],
4157      [[9, 10],
4158       [11, 12]],
4159      [[13, 14],
4160       [15, 16]],
4161     ]]
4162    >>> output = torch.nn.functional.native_channel_shuffle(input, 2)
4163    >>> print(output)
4164    [[[[1, 2],
4165       [3, 4]],
4166      [[9, 10],
4167       [11, 12]],
4168      [[5, 6],
4169       [7, 8]],
4170      [[13, 14],
4171       [15, 16]],
4172     ]]
4173""",
4174)
4175
4176
4177@_overload
4178def upsample(  # noqa: F811
4179    input: Tensor,
4180    size: Optional[int] = None,
4181    scale_factor: Optional[float] = None,
4182    mode: str = "nearest",
4183    align_corners: Optional[bool] = None,
4184) -> Tensor:  # noqa: B950
4185    pass
4186
4187
4188@_overload
4189def upsample(  # noqa: F811
4190    input: Tensor,
4191    size: Optional[List[int]] = None,
4192    scale_factor: Optional[float] = None,
4193    mode: str = "nearest",
4194    align_corners: Optional[bool] = None,
4195) -> Tensor:  # noqa: B950
4196    pass
4197
4198
4199def upsample(  # noqa: F811
4200    input,
4201    size=None,
4202    scale_factor=None,
4203    mode="nearest",
4204    align_corners=None,
4205):
4206    r"""Upsample input.
4207
4208    Provided tensor is upsampled to either the given :attr:`size` or the given
4209    :attr:`scale_factor`.
4210
4211    .. warning::
4212        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
4213        It is equivalent to ``nn.functional.interpolate(...)``.
4214
4215    Note:
4216        {backward_reproducibility_note}
4217
4218    The algorithm used for upsampling is determined by :attr:`mode`.
4219
4220    Currently temporal, spatial and volumetric upsampling are supported, i.e.
4221    expected inputs are 3-D, 4-D or 5-D in shape.
4222
4223    The input dimensions are interpreted in the form:
4224    `mini-batch x channels x [optional depth] x [optional height] x width`.
4225
4226    The modes available for upsampling are: `nearest`, `linear` (3D-only),
4227    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only)
4228
4229    Args:
4230        input (Tensor): the input tensor
4231        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
4232            output spatial size.
4233        scale_factor (float or Tuple[float]): multiplier for spatial size. If a tuple, its length has to match the number of spatial dimensions.
4234        mode (str): algorithm used for upsampling:
4235            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
4236            ``'trilinear'``. Default: ``'nearest'``
4237        align_corners (bool, optional): Geometrically, we consider the pixels of the
4238            input and output as squares rather than points.
4239            If set to ``True``, the input and output tensors are aligned by the
4240            center points of their corner pixels, preserving the values at the corner pixels.
4241            If set to ``False``, the input and output tensors are aligned by the corner
4242            points of their corner pixels, and the interpolation uses edge value padding
4243            for out-of-boundary values, making this operation *independent* of input size
4244            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
4245            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
4246            Default: ``False``
4247
4248    .. note::
4249        With ``mode='bicubic'``, the interpolation can overshoot, i.e. produce values outside
4250        the original value range (for images, e.g., negative values or values greater than 255).
4251        Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot
4252        when displaying the image.
4253
4254    .. warning::
4255        With ``align_corners = True``, the linearly interpolating modes
4256        (`linear`, `bilinear`, and `trilinear`) don't proportionally align the
4257        output and input pixels, and thus the output values can depend on the
4258        input size. This was the default behavior for these modes up to version
4259        0.3.1. Since then, the default behavior is ``align_corners = False``.
4260        See :class:`~torch.nn.Upsample` for concrete examples on how this
4261        affects the outputs.
4262
4263    """
4264    warnings.warn(
4265        "`nn.functional.upsample` is deprecated. "
4266        "Use `nn.functional.interpolate` instead.",
4267        stacklevel=2,
4268    )
4269    return interpolate(input, size, scale_factor, mode, align_corners)
4270
4271
4272if upsample.__doc__:
4273    upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes)
4274
4275
4276def _is_integer(x) -> bool:
4277    r"""Type check the input number is an integer.
4278
4279    Will return True for int, SymInt, Numpy integers and Tensors with integer elements.
4280    """
4281    if isinstance(x, (int, torch.SymInt)):
4282        return True
4283    if np is not None and isinstance(x, np.integer):
4284        return True
4285    return isinstance(x, Tensor) and not x.is_floating_point()
4286
4287
4288@_overload
4289def interpolate(  # noqa: F811
4290    input: Tensor,
4291    size: Optional[int] = None,
4292    scale_factor: Optional[List[float]] = None,
4293    mode: str = "nearest",
4294    align_corners: Optional[bool] = None,
4295    recompute_scale_factor: Optional[bool] = None,
4296    antialias: bool = False,
4297) -> Tensor:  # noqa: B950
4298    pass
4299
4300
4301@_overload
4302def interpolate(  # noqa: F811
4303    input: Tensor,
4304    size: Optional[List[int]] = None,
4305    scale_factor: Optional[List[float]] = None,
4306    mode: str = "nearest",
4307    align_corners: Optional[bool] = None,
4308    recompute_scale_factor: Optional[bool] = None,
4309    antialias: bool = False,
4310) -> Tensor:  # noqa: B950
4311    pass
4312
4313
4314@_overload
4315def interpolate(  # noqa: F811
4316    input: Tensor,
4317    size: Optional[int] = None,
4318    scale_factor: Optional[float] = None,
4319    mode: str = "nearest",
4320    align_corners: Optional[bool] = None,
4321    recompute_scale_factor: Optional[bool] = None,
4322    antialias: bool = False,
4323) -> Tensor:  # noqa: B950
4324    pass
4325
4326
4327@_overload
4328def interpolate(  # noqa: F811
4329    input: Tensor,
4330    size: Optional[List[int]] = None,
4331    scale_factor: Optional[float] = None,
4332    mode: str = "nearest",
4333    align_corners: Optional[bool] = None,
4334    recompute_scale_factor: Optional[bool] = None,
4335    antialias: bool = False,
4336) -> Tensor:
4337    pass
4338
4339
4340def interpolate(  # noqa: F811
4341    input: Tensor,
4342    size: Optional[int] = None,
4343    scale_factor: Optional[List[float]] = None,
4344    mode: str = "nearest",
4345    align_corners: Optional[bool] = None,
4346    recompute_scale_factor: Optional[bool] = None,
4347    antialias: bool = False,
4348) -> Tensor:  # noqa: B950
4349    r"""Down/up samples the input.
4350
4351    The input is interpolated to either the given :attr:`size` or the given
4352    :attr:`scale_factor`.
4353
4354    The algorithm used for interpolation is determined by :attr:`mode`.
4355
4356    Currently temporal, spatial and volumetric sampling are supported, i.e.
4357    expected inputs are 3-D, 4-D or 5-D in shape.
4358
4359    The input dimensions are interpreted in the form:
4360    `mini-batch x channels x [optional depth] x [optional height] x width`.
4361
4362    The modes available for resizing are: `nearest`, `linear` (3D-only),
4363    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`, `nearest-exact`
4364
4365    Args:
4366        input (Tensor): the input tensor
4367        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
4368            output spatial size.
4369        scale_factor (float or Tuple[float]): multiplier for spatial size. If `scale_factor` is a tuple,
4370            its length has to match the number of spatial dimensions, i.e. `input.dim() - 2`.
4371        mode (str): algorithm used for upsampling:
4372            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
4373            ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'``
4374        align_corners (bool, optional): Geometrically, we consider the pixels of the
4375            input and output as squares rather than points.
4376            If set to ``True``, the input and output tensors are aligned by the
4377            center points of their corner pixels, preserving the values at the corner pixels.
4378            If set to ``False``, the input and output tensors are aligned by the corner
4379            points of their corner pixels, and the interpolation uses edge value padding
4380            for out-of-boundary values, making this operation *independent* of input size
4381            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
4382            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
4383            Default: ``False``
4384        recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
4385            interpolation calculation. If `recompute_scale_factor` is ``True``, then
4386            `scale_factor` must be passed in and `scale_factor` is used to compute the
4387            output `size`. The computed output `size` will be used to infer new scales for
4388            the interpolation. Note that when `scale_factor` is floating-point, it may differ
4389            from the recomputed `scale_factor` due to rounding and precision issues.
4390            If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
4391            be used directly for interpolation. Default: ``None``.
4392        antialias (bool, optional): flag to apply anti-aliasing. Default: ``False``. When the
4393            anti-alias option is used together with ``align_corners=False``, the interpolation
4394            result matches Pillow's result for downsampling. Supported modes: ``'bilinear'``, ``'bicubic'``.
4395
4396    .. note::
4397        With ``mode='bicubic'``, the interpolation can overshoot, in other words it can produce
4398        negative values or values greater than 255 for images.
4399        Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot
4400        when displaying the image.
4401
4402    .. note::
4403        Mode ``mode='nearest-exact'`` matches the Scikit-Image and PIL nearest-neighbour interpolation
4404        algorithms and fixes the known issues of ``mode='nearest'``. Mode ``mode='nearest'`` matches
4405        the buggy OpenCV ``INTER_NEAREST`` interpolation algorithm and is kept
4406        for backward compatibility.
4407
4408    .. note::
4409        The gradients for the dtype ``float16`` on CUDA may be inaccurate in the upsample operation
4410        when using modes ``['linear', 'bilinear', 'bicubic', 'trilinear', 'area']``.
4411        For more details, please refer to the discussion in
4412        `issue#104157 <https://github.com/pytorch/pytorch/issues/104157>`_.
4413
4414    Note:
4415        {backward_reproducibility_note}
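
    Example (a minimal illustrative sketch; tensor names and shapes are arbitrary)::

        >>> inp = torch.randn(1, 3, 32, 32)
        >>> F.interpolate(inp, scale_factor=2.0, mode="bilinear").shape
        torch.Size([1, 3, 64, 64])
        >>> F.interpolate(inp, size=(16, 24), mode="nearest").shape
        torch.Size([1, 3, 16, 24])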
4416    """
4417    if has_torch_function_unary(input):
4418        return handle_torch_function(
4419            interpolate,
4420            (input,),
4421            input,
4422            size=size,
4423            scale_factor=scale_factor,
4424            mode=mode,
4425            align_corners=align_corners,
4426            recompute_scale_factor=recompute_scale_factor,
4427            antialias=antialias,
4428        )
4429
4430    if mode in ("nearest", "area", "nearest-exact"):
4431        if align_corners is not None:
4432            raise ValueError(
4433                "align_corners option can only be set with the "
4434                "interpolating modes: linear | bilinear | bicubic | trilinear"
4435            )
4436    else:
4437        if align_corners is None:
4438            align_corners = False
4439
4440    dim = input.dim() - 2  # Number of spatial dimensions.
4441
4442    # Process size and scale_factor.  Validate that exactly one is set.
4443    # Validate its length if it is a list, or expand it if it is a scalar.
4444    # After this block, exactly one of output_size and scale_factors will
4445    # be non-None, and it will be a list (or tuple).
4446    if size is not None and scale_factor is not None:
4447        raise ValueError("only one of size or scale_factor should be defined")
4448    elif size is not None:
4449        assert scale_factor is None
4450        scale_factors = None
4451        if isinstance(size, (list, tuple)):
4452            if len(size) != dim:
4453                raise ValueError(
4454                    "Input and output must have the same number of spatial dimensions, but got "
4455                    f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. "
4456                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
4457                    "output size in (o1, o2, ...,oK) format."
4458                )
4459            if not torch.jit.is_scripting():
4460                if not all(_is_integer(x) for x in size):
4461                    raise TypeError(
4462                        "expected size to be one of int or Tuple[int] or Tuple[int, int] or "
4463                        f"Tuple[int, int, int], but got size with types {[type(x) for x in size]}"
4464                    )
4465            output_size = size
4466        else:
4467            output_size = [size for _ in range(dim)]
4468    elif scale_factor is not None:
4469        assert size is None
4470        output_size = None
4471        if isinstance(scale_factor, (list, tuple)):
4472            if len(scale_factor) != dim:
4473                raise ValueError(
4474                    "Input and scale_factor must have the same number of spatial dimensions, but "
4475                    f"got input with spatial dimensions of {list(input.shape[2:])} and "
4476                    f"scale_factor of shape {scale_factor}. "
4477                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
4478                    "scale_factor in (s1, s2, ...,sK) format."
4479                )
4480            scale_factors = scale_factor
4481        else:
4482            scale_factors = [scale_factor for _ in range(dim)]
4483    else:
4484        raise ValueError("either size or scale_factor should be defined")
4485
4486    if (
4487        recompute_scale_factor is not None
4488        and recompute_scale_factor
4489        and size is not None
4490    ):
4491        raise ValueError(
4492            "recompute_scale_factor is not meaningful with an explicit size."
4493        )
4494
4495    # "area" mode always requires an explicit size rather than scale factor.
4496    # Re-use the recompute_scale_factor code path.
4497    if mode == "area" and output_size is None:
4498        recompute_scale_factor = True
4499
4500    if recompute_scale_factor is not None and recompute_scale_factor:
4501        # We compute output_size here, then un-set scale_factors.
4502        # The C++ code will recompute it based on the (integer) output size.
4503        assert scale_factors is not None
4504        if not torch.jit.is_scripting() and torch._C._get_tracing_state():
4505            # make scale_factor a tensor in tracing so constant doesn't get baked in
4506            output_size = [
4507                (
4508                    torch.floor(
4509                        (
4510                            input.size(i + 2).float()
4511                            * torch.tensor(scale_factors[i], dtype=torch.float32)
4512                        ).float()
4513                    )
4514                )
4515                for i in range(dim)
4516            ]
4517        elif torch.jit.is_scripting():
4518            output_size = [
4519                int(math.floor(float(input.size(i + 2)) * scale_factors[i]))
4520                for i in range(dim)
4521            ]
4522        else:
4523            output_size = [
4524                _sym_int(input.size(i + 2) * scale_factors[i]) for i in range(dim)
4525            ]
4526        scale_factors = None
4527
4528    if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4):
4529        raise ValueError(
4530            "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input"
4531        )
4532
4533    if input.dim() == 3 and mode == "nearest":
4534        return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
4535    if input.dim() == 4 and mode == "nearest":
4536        return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
4537    if input.dim() == 5 and mode == "nearest":
4538        return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)
4539
4540    if input.dim() == 3 and mode == "nearest-exact":
4541        return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
4542    if input.dim() == 4 and mode == "nearest-exact":
4543        return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
4544    if input.dim() == 5 and mode == "nearest-exact":
4545        return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)
4546
4547    if input.dim() == 3 and mode == "area":
4548        assert output_size is not None
4549        return adaptive_avg_pool1d(input, output_size)
4550    if input.dim() == 4 and mode == "area":
4551        assert output_size is not None
4552        return adaptive_avg_pool2d(input, output_size)
4553    if input.dim() == 5 and mode == "area":
4554        assert output_size is not None
4555        return adaptive_avg_pool3d(input, output_size)
4556
4557    if input.dim() == 3 and mode == "linear":
4558        assert align_corners is not None
4559        return torch._C._nn.upsample_linear1d(
4560            input, output_size, align_corners, scale_factors
4561        )
4562    if input.dim() == 4 and mode == "bilinear":
4563        assert align_corners is not None
4564        if antialias:
4565            return torch._C._nn._upsample_bilinear2d_aa(
4566                input, output_size, align_corners, scale_factors
4567            )
4568        # Two levels are necessary to prevent TorchScript from touching
4569        # are_deterministic_algorithms_enabled.
4570        if not torch.jit.is_scripting():
4571            if torch.are_deterministic_algorithms_enabled() and (
4572                input.is_cuda or input.is_xpu
4573            ):
4574                # Use slow decomp whose backward will be in terms of index_put
4575                # importlib is required because the import cannot be top level
4576                # (cycle) and cannot be nested (TS doesn't support)
4577                return importlib.import_module(
4578                    "torch._decomp.decompositions"
4579                )._upsample_linear_vec(input, output_size, align_corners, scale_factors)
4580        return torch._C._nn.upsample_bilinear2d(
4581            input, output_size, align_corners, scale_factors
4582        )
4583    if input.dim() == 5 and mode == "trilinear":
4584        assert align_corners is not None
4585        return torch._C._nn.upsample_trilinear3d(
4586            input, output_size, align_corners, scale_factors
4587        )
4588    if input.dim() == 4 and mode == "bicubic":
4589        assert align_corners is not None
4590        if antialias:
4591            return torch._C._nn._upsample_bicubic2d_aa(
4592                input, output_size, align_corners, scale_factors
4593            )
4594        return torch._C._nn.upsample_bicubic2d(
4595            input, output_size, align_corners, scale_factors
4596        )
4597
4598    if input.dim() == 3 and mode == "bilinear":
4599        raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
4600    if input.dim() == 3 and mode == "trilinear":
4601        raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
4602    if input.dim() == 4 and mode == "linear":
4603        raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
4604    if input.dim() == 4 and mode == "trilinear":
4605        raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
4606    if input.dim() == 5 and mode == "linear":
4607        raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
4608    if input.dim() == 5 and mode == "bilinear":
4609        raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
4610
4611    raise NotImplementedError(
4612        "Input Error: Only 3D, 4D and 5D input Tensors supported"
4613        f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact"
4614        f" (got {mode})"
4615    )
4616
4617
4618if interpolate.__doc__:
4619    interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes)
4620
4621
4622@_overload
4623def upsample_nearest(  # noqa: F811
4624    input: Tensor,
4625    size: Optional[int] = None,
4626    scale_factor: Optional[float] = None,
4627) -> Tensor:
4628    pass
4629
4630
4631@_overload
4632def upsample_nearest(  # noqa: F811
4633    input: Tensor,
4634    size: Optional[List[int]] = None,
4635    scale_factor: Optional[float] = None,
4636) -> Tensor:
4637    pass
4638
4639
4640def upsample_nearest(input, size=None, scale_factor=None):  # noqa: F811
4641    r"""Upsamples the input, using nearest neighbours' pixel values.
4642
4643    .. warning::
4644        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
4645        This is equivalent to ``nn.functional.interpolate(..., mode='nearest')``.
4646
4647    Currently spatial and volumetric upsampling are supported (i.e. expected
4648    inputs are 4 or 5 dimensional).
4649
4650    Args:
4651        input (Tensor): input
4652        size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
4653            size.
4654        scale_factor (int): multiplier for spatial size. Has to be an integer.
4655
4656    Note:
4657        {backward_reproducibility_note}
4658    """
4659    # DeprecationWarning is ignored by default
4660    warnings.warn(
4661        "`nn.functional.upsample_nearest` is deprecated. "
4662        "Use `nn.functional.interpolate` instead.",
4663        stacklevel=2,
4664    )
4665    return interpolate(input, size, scale_factor, mode="nearest")
4666
4667
4668if upsample_nearest.__doc__:
4669    upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes)
4670
4671
4672@_overload
4673def upsample_bilinear(  # noqa: F811
4674    input: Tensor,
4675    size: Optional[int] = None,
4676    scale_factor: Optional[float] = None,
4677) -> Tensor:
4678    pass
4679
4680
4681@_overload
4682def upsample_bilinear(  # noqa: F811
4683    input: Tensor,
4684    size: Optional[List[int]] = None,
4685    scale_factor: Optional[float] = None,
4686) -> Tensor:
4687    pass
4688
4689
4690@_overload
4691def upsample_bilinear(  # noqa: F811
4692    input: Tensor,
4693    size: Optional[int] = None,
4694    scale_factor: Optional[List[float]] = None,
4695) -> Tensor:
4696    pass
4697
4698
4699@_overload
4700def upsample_bilinear(  # noqa: F811
4701    input: Tensor,
4702    size: Optional[List[int]] = None,
4703    scale_factor: Optional[List[float]] = None,
4704) -> Tensor:
4705    pass
4706
4707
4708def upsample_bilinear(input, size=None, scale_factor=None):  # noqa: F811
4709    r"""Upsamples the input, using bilinear upsampling.
4710
4711    .. warning::
4712        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
4713        This is equivalent to
4714        ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.
4715
4716    Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` for
4717    volumetric (5 dimensional) inputs.
4718
4719    Args:
4720        input (Tensor): input
4721        size (int or Tuple[int, int]): output spatial size.
4722        scale_factor (int or Tuple[int, int]): multiplier for spatial size
4723
4724    Note:
4725        {backward_reproducibility_note}
4726    """
4727    # DeprecationWarning is ignored by default
4728    warnings.warn(
4729        "`nn.functional.upsample_bilinear` is deprecated. "
4730        "Use `nn.functional.interpolate` instead.",
4731        stacklevel=2,
4732    )
4733    return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True)
4734
4735
4736if upsample_bilinear.__doc__:
4737    upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format(
4738        **reproducibility_notes
4739    )
4740
4741GRID_SAMPLE_INTERPOLATION_MODES = {
4742    "bilinear": 0,
4743    "nearest": 1,
4744    "bicubic": 2,
4745}
4746
4747GRID_SAMPLE_PADDING_MODES = {
4748    "zeros": 0,
4749    "border": 1,
4750    "reflection": 2,
4751}
4752
4753
4754def grid_sample(
4755    input: Tensor,
4756    grid: Tensor,
4757    mode: str = "bilinear",
4758    padding_mode: str = "zeros",
4759    align_corners: Optional[bool] = None,
4760) -> Tensor:
4761    r"""Compute grid sample.
4762
4763    Given an :attr:`input` and a flow-field :attr:`grid`, computes the
4764    ``output`` using :attr:`input` values and pixel locations from :attr:`grid`.
4765
4766    Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are
4767    supported.
4768
4769    In the spatial (4-D) case, for :attr:`input` with shape
4770    :math:`(N, C, H_\text{in}, W_\text{in})` and :attr:`grid` with shape
4771    :math:`(N, H_\text{out}, W_\text{out}, 2)`, the output will have shape
4772    :math:`(N, C, H_\text{out}, W_\text{out})`.
4773
4774    For each output location ``output[n, :, h, w]``, the size-2 vector
4775    ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``,
4776    which are used to interpolate the output value ``output[n, :, h, w]``.
4777    In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the
4778    ``x``, ``y``, ``z`` pixel locations for interpolating
4779    ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or
4780    ``bilinear`` interpolation method to sample the input pixels.
4781
4782    :attr:`grid` specifies the sampling pixel locations normalized by the
4783    :attr:`input` spatial dimensions. Therefore, it should have most values in
4784    the range of ``[-1, 1]``. For example, values ``x = -1, y = -1`` correspond to the
4785    left-top pixel of :attr:`input`, and values ``x = 1, y = 1`` correspond to the
4786    right-bottom pixel of :attr:`input`.
4787
4788    If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding
4789    outputs are handled as defined by :attr:`padding_mode`. Options are
4790
4791        * ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations,
4792        * ``padding_mode="border"``: use border values for out-of-bound grid locations,
4793        * ``padding_mode="reflection"``: use values at locations reflected by
4794          the border for out-of-bound grid locations. For locations far away
4795          from the border, the location keeps being reflected until it is in bounds,
4796          e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1``
4797          and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes
4798          ``x'' = -0.5``.
4799
4800    Note:
4801        This function is often used in conjunction with :func:`affine_grid`
4802        to build `Spatial Transformer Networks`_ .
4803
4804    Note:
4805        When using the CUDA backend, this operation may induce nondeterministic
4806        behaviour in its backward pass that is not easily switched off.
4807        Please see the notes on :doc:`/notes/randomness` for background.
4808
4809    Note:
4810        NaN values in :attr:`grid` are interpreted as ``-1``.
4811
4812    Args:
4813        input (Tensor): input of shape :math:`(N, C, H_\text{in}, W_\text{in})` (4-D case)
4814                        or :math:`(N, C, D_\text{in}, H_\text{in}, W_\text{in})` (5-D case)
4815        grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case)
4816                       or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case)
4817        mode (str): interpolation mode to calculate output values
4818            ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'``
4819            Note: ``mode='bicubic'`` supports only 4-D input.
4820            When ``mode='bilinear'`` and the input is 5-D, the interpolation mode
4821            used internally will actually be trilinear. However, when the input is 4-D,
4822            the interpolation mode will legitimately be bilinear.
4823        padding_mode (str): padding mode for outside grid values
4824            ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'``
4825        align_corners (bool, optional): Geometrically, we consider the pixels of the
4826            input  as squares rather than points.
4827            If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring
4828            to the center points of the input's corner pixels. If set to ``False``, they
4829            are instead considered as referring to the corner points of the input's corner
4830            pixels, making the sampling more resolution agnostic.
4831            This option parallels the ``align_corners`` option in
4832            :func:`interpolate`, and so whichever option is used here
4833            should also be used there to resize the input image before grid sampling.
4834            Default: ``False``
4835
4836    Returns:
4837        output (Tensor): output Tensor
4838
4839    .. _`Spatial Transformer Networks`:
4840        https://arxiv.org/abs/1506.02025
4841
4842    .. warning::
4843        When ``align_corners = True``, the grid positions depend on the pixel
4844        size relative to the input image size, and so the locations sampled by
4845        :func:`grid_sample` will differ for the same input given at different
4846        resolutions (that is, after being upsampled or downsampled).
4847        The default behavior up to version 1.2.0 was ``align_corners = True``.
4848        Since then, the default behavior has been changed to ``align_corners = False``,
4849        in order to bring it in line with the default for :func:`interpolate`.
4850
4851    .. note::
4852        ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`.
4853        The constant :math:`\alpha` may differ from package to package.
4854        For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively.
4855        This algorithm may "overshoot" the range of values it's interpolating.
4856        For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255].
4857        Clamp the results with :func:`torch.clamp` to ensure they are within the valid range.
4858    .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation
4859    .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51
4860    .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908
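
    Example (an illustrative identity-warp sketch; names and shapes are arbitrary)::

        >>> inp = torch.arange(16.0).reshape(1, 1, 4, 4)
        >>> # an identity affine transform yields a grid that samples every pixel from itself
        >>> theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
        >>> grid = F.affine_grid(theta, (1, 1, 4, 4), align_corners=False)
        >>> torch.allclose(F.grid_sample(inp, grid, align_corners=False), inp)
        True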
4861    """
4862    if has_torch_function_variadic(input, grid):
4863        return handle_torch_function(
4864            grid_sample,
4865            (input, grid),
4866            input,
4867            grid,
4868            mode=mode,
4869            padding_mode=padding_mode,
4870            align_corners=align_corners,
4871        )
4872    if mode != "bilinear" and mode != "nearest" and mode != "bicubic":
4873        raise ValueError(
4874            f"nn.functional.grid_sample(): expected mode to be 'bilinear', 'nearest' or 'bicubic', but got: '{mode}'"
4875        )
4876    if (
4877        padding_mode != "zeros"
4878        and padding_mode != "border"
4879        and padding_mode != "reflection"
4880    ):
4881        raise ValueError(
4882            "nn.functional.grid_sample(): expected padding_mode "
4883            "to be 'zeros', 'border', or 'reflection', "
4884            f"but got: '{padding_mode}'"
4885        )
4886
4887    if mode == "bilinear":
4888        mode_enum = 0
4889    elif mode == "nearest":
4890        mode_enum = 1
4891    else:  # mode == 'bicubic'
4892        mode_enum = 2
4893
4894    if padding_mode == "zeros":
4895        padding_mode_enum = 0
4896    elif padding_mode == "border":
4897        padding_mode_enum = 1
4898    else:  # padding_mode == 'reflection'
4899        padding_mode_enum = 2
4900
4901    if align_corners is None:
4902        warnings.warn(
4903            "Default grid_sample and affine_grid behavior has changed "
4904            "to align_corners=False since 1.3.0. Please specify "
4905            "align_corners=True if the old behavior is desired. "
4906            "See the documentation of grid_sample for details."
4907        )
4908        align_corners = False
4909
4910    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
4911
4912
4913def affine_grid(
4914    theta: Tensor,
4915    size: List[int],
4916    align_corners: Optional[bool] = None,
4917) -> Tensor:
4918    r"""Generate 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`.
4919
4920    .. note::
4921        This function is often used in conjunction with :func:`grid_sample`
4922        to build `Spatial Transformer Networks`_ .
4923
4924    Args:
4925        theta (Tensor): input batch of affine matrices with shape
4926            (:math:`N \times 2 \times 3`) for 2D or
4927            (:math:`N \times 3 \times 4`) for 3D
4928        size (torch.Size): the target output image size.
4929            (:math:`N \times C \times H \times W` for 2D or
4930            :math:`N \times C \times D \times H \times W` for 3D)
4931            Example: torch.Size((32, 3, 24, 24))
4932        align_corners (bool, optional): if ``True``, consider ``-1`` and ``1``
4933            to refer to the centers of the corner pixels rather than the image corners.
4934            Refer to :func:`grid_sample` for a more complete description.
4935            A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample`
4936            with the same setting for this option.
4937            Default: ``False``
4938
4939    Returns:
4940        output (Tensor): output Tensor of size (:math:`N \times H \times W \times 2`)
4941
4942    .. _`Spatial Transformer Networks`:
4943        https://arxiv.org/abs/1506.02025
4944
4945    .. warning::
4946        When ``align_corners = True``, the grid positions depend on the pixel
4947        size relative to the input image size, and so the locations sampled by
4948        :func:`grid_sample` will differ for the same input given at different
4949        resolutions (that is, after being upsampled or downsampled).
4950        The default behavior up to version 1.2.0 was ``align_corners = True``.
4951        Since then, the default behavior has been changed to ``align_corners = False``,
4952        in order to bring it in line with the default for :func:`interpolate`.
4953    .. warning::
4954        When ``align_corners = True``, 2D affine transforms on 1D data and
4955        3D affine transforms on 2D data (that is, when one of the spatial
4956        dimensions has unit size) are ill-defined, and not an intended use case.
4957        This is not a problem when ``align_corners = False``.
4958        Up to version 1.2.0, all grid points along a unit dimension were
4959        considered arbitrarily to be at ``-1``.
4960        From version 1.3.0, under ``align_corners = True`` all grid points
4961        along a unit dimension are considered to be at ``0``
4962        (the center of the input image).
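
    Example (illustrative; a batch of one 2D identity transform)::

        >>> theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
        >>> F.affine_grid(theta, (1, 1, 2, 2), align_corners=False).shape
        torch.Size([1, 2, 2, 2])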
4963    """
4964    if has_torch_function_unary(theta):
4965        return handle_torch_function(
4966            affine_grid, (theta,), theta, size, align_corners=align_corners
4967        )
4968    if align_corners is None:
4969        warnings.warn(
4970            "Default grid_sample and affine_grid behavior has changed "
4971            "to align_corners=False since 1.3.0. Please specify "
4972            "align_corners=True if the old behavior is desired. "
4973            "See the documentation of grid_sample for details."
4974        )
4975        align_corners = False
4976
4977    # enforce floating point dtype on theta
4978    if not theta.is_floating_point():
4979        raise ValueError(
4980            f"Expected theta to have floating point type, but got {theta.dtype}"
4981        )
4982    # check that shapes and sizes match
4983    if len(size) == 4:
4984        if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3:
4985            raise ValueError(
4986                f"Expected a batch of 2D affine matrices of shape Nx2x3 for size {size}. Got {theta.shape}."
4987            )
4988        spatial_size = size[-2:]  # spatial dimension sizes
4989    elif len(size) == 5:
4990        if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4:
4991            raise ValueError(
4992                f"Expected a batch of 3D affine matrices of shape Nx3x4 for size {size}. Got {theta.shape}."
4993            )
4994        spatial_size = size[-3:]  # spatial dimension sizes
4995    else:
4996        raise NotImplementedError(
4997            "affine_grid only supports 4D and 5D sizes, "
4998            "for 2D and 3D affine transforms, respectively. "
4999            f"Got size {size}."
5000        )
5001    # check for empty span
5002    if align_corners and min(spatial_size) == 1:
5003        warnings.warn(
5004            "Since version 1.3.0, affine_grid behavior has changed "
5005            "for unit-size grids when align_corners=True. "
5006            "This is not an intended use case of affine_grid. "
5007            "See the documentation of affine_grid for details."
5008        )
5009    elif min(size) <= 0:
5010        raise ValueError(f"Expected non-zero, positive output size. Got {size}")
5011
5012    return torch.affine_grid_generator(theta, size, align_corners)
5013
5014
5015def pad(
5016    input: Tensor,
5017    pad: List[int],
5018    mode: str = "constant",
5019    value: Optional[float] = None,
5020) -> Tensor:
5021    r"""
5022    pad(input, pad, mode="constant", value=None) -> Tensor
5023
5024    Pads tensor.
5025
5026    Padding size:
5027        The padding size by which to pad some dimensions of :attr:`input`
5028        is described starting from the last dimension and moving forward.
5029        :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions
5030        of ``input`` will be padded.
5031        For example, to pad only the last dimension of the input tensor, then
5032        :attr:`pad` has the form
5033        :math:`(\text{padding\_left}, \text{padding\_right})`;
5034        to pad the last 2 dimensions of the input tensor, then use
5035        :math:`(\text{padding\_left}, \text{padding\_right},`
5036        :math:`\text{padding\_top}, \text{padding\_bottom})`;
5037        to pad the last 3 dimensions, use
5038        :math:`(\text{padding\_left}, \text{padding\_right},`
5039        :math:`\text{padding\_top}, \text{padding\_bottom},`
5040        :math:`\text{padding\_front}, \text{padding\_back})`.
5041
5042    Padding mode:
5043        See :class:`torch.nn.CircularPad2d`, :class:`torch.nn.ConstantPad2d`,
5044        :class:`torch.nn.ReflectionPad2d`, and :class:`torch.nn.ReplicationPad2d`
5045        for concrete examples on how each of the padding modes works. Constant
5046        padding is implemented for arbitrary dimensions. Circular, replicate and
5047        reflection padding are implemented for padding the last 3 dimensions of a
5048        4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor,
5049        or the last dimension of a 2D or 3D input tensor.
5050
5051    Note:
5052        When using the CUDA backend, this operation may induce nondeterministic
5053        behaviour in its backward pass that is not easily switched off.
5054        Please see the notes on :doc:`/notes/randomness` for background.
5055
5056    Args:
5057        input (Tensor): N-dimensional tensor
5058        pad (tuple): m-element tuple, where
5059            :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even.
5060        mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
5061            Default: ``'constant'``
5062        value: fill value for ``'constant'`` padding. Default: ``0``
5063
5064    Examples::
5065
5066        >>> t4d = torch.empty(3, 3, 4, 2)
5067        >>> p1d = (1, 1) # pad last dim by 1 on each side
5068        >>> out = F.pad(t4d, p1d, "constant", 0)  # effectively zero padding
5069        >>> print(out.size())
5070        torch.Size([3, 3, 4, 4])
5071        >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2)
5072        >>> out = F.pad(t4d, p2d, "constant", 0)
5073        >>> print(out.size())
5074        torch.Size([3, 3, 8, 4])
5075        >>> t4d = torch.empty(3, 3, 4, 2)
5076        >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3)
5077        >>> out = F.pad(t4d, p3d, "constant", 0)
5078        >>> print(out.size())
5079        torch.Size([3, 9, 7, 3])
5080    """
5081    if has_torch_function_unary(input):
5082        return handle_torch_function(
5083            torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value
5084        )
5085    if not torch.jit.is_scripting():
5086        if torch.are_deterministic_algorithms_enabled() and (
5087            input.is_cuda or input.is_xpu
5088        ):
5089            if mode == "replicate":
5090                # Use slow decomp whose backward will be in terms of index_put.
5091                # importlib is required because the import cannot be top level
5092                # (cycle) and cannot be nested (TS doesn't support)
5093                return importlib.import_module(
5094                    "torch._decomp.decompositions"
5095                )._replication_pad(input, pad)
5096    return torch._C._nn.pad(input, pad, mode, value)
5097
5098
5099# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
5100pad.__module__ = "torch.nn.functional"
5101
5102# distance
5103
5104
5105pairwise_distance = _add_docstr(
5106    torch.pairwise_distance,
5107    r"""
5108pairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False) -> Tensor
5109
5110See :class:`torch.nn.PairwiseDistance` for details
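
Example (illustrative; input shapes are arbitrary)::

    >>> x1 = torch.randn(100, 128)
    >>> x2 = torch.randn(100, 128)
    >>> F.pairwise_distance(x1, x2, p=2.0).shape
    torch.Size([100])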
5111""",
5112)
5113
5114
5115pdist = _add_docstr(
5116    torch.pdist,
5117    r"""
5118pdist(input, p=2) -> Tensor
5119
5120Computes the p-norm distance between every pair of row vectors in the input.
5121This is identical to the upper triangular portion, excluding the diagonal, of
5122`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster
5123if the rows are contiguous.
5124
5125If input has shape :math:`N \times M` then the output will have shape
5126:math:`\frac{1}{2} N (N - 1)`.
5127
5128This function is equivalent to ``scipy.spatial.distance.pdist(input,
5129'minkowski', p=p)`` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is
5130equivalent to ``scipy.spatial.distance.pdist(input, 'hamming') * M``.
5131When :math:`p = \infty`, the closest scipy function is
5132``scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())``.
5133
5134Args:
5135    input: input tensor of shape :math:`N \times M`.
5136    p: p value for the p-norm distance to calculate between each vector pair
5137        :math:`\in [0, \infty]`.
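
Example (illustrative; input shape is arbitrary)::

    >>> a = torch.randn(4, 5)
    >>> # 4 rows give 4 * 3 / 2 = 6 pairwise distances
    >>> F.pdist(a, p=2).shape
    torch.Size([6])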
5138""",
5139)
5140
5141
5142cosine_similarity = _add_docstr(
5143    torch.cosine_similarity,
5144    r"""
5145cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor
5146
5147Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable
5148to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is
5149squeezed (see :func:`torch.squeeze`), resulting in the
5150output tensor having 1 fewer dimension.
5151
5152.. math ::
5153    \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2, \epsilon) \cdot \max(\Vert x_2 \Vert _2, \epsilon)}
5154
5155Supports :ref:`type promotion <type-promotion-doc>`.
5156
5157Args:
5158    x1 (Tensor): First input.
5159    x2 (Tensor): Second input.
5160    dim (int, optional): Dimension along which cosine similarity is computed. Default: 1
5161    eps (float, optional): Small value to avoid division by zero.
5162        Default: 1e-8
5163
5164Example::
5165
5166    >>> input1 = torch.randn(100, 128)
5167    >>> input2 = torch.randn(100, 128)
5168    >>> output = F.cosine_similarity(input1, input2)
5169    >>> print(output)
5170""",
5171)
5172
5173
5174one_hot = _add_docstr(
5175    torch._C._nn.one_hot,
5176    r"""
5177one_hot(tensor, num_classes=-1) -> LongTensor
5178
5179Takes LongTensor with index values of shape ``(*)`` and returns a tensor
5180of shape ``(*, num_classes)`` that has zeros everywhere except where the
5181index of last dimension matches the corresponding value of the input tensor,
5182in which case it will be 1.
5183
5184See also `One-hot on Wikipedia`_ .
5185
5186.. _One-hot on Wikipedia:
5187    https://en.wikipedia.org/wiki/One-hot
5188
5189Arguments:
5190    tensor (LongTensor): class values of any shape.
5191    num_classes (int):  Total number of classes. If set to -1, the number
5192        of classes will be inferred as one greater than the largest class
5193        value in the input tensor.
5194
5195Returns:
5196    LongTensor that has one more dimension than the input, with a 1 at the
5197    index of the last dimension indicated by the input, and 0 everywhere
5198    else.
5199
5200Examples:
5201    >>> F.one_hot(torch.arange(0, 5) % 3)
5202    tensor([[1, 0, 0],
5203            [0, 1, 0],
5204            [0, 0, 1],
5205            [1, 0, 0],
5206            [0, 1, 0]])
5207    >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
5208    tensor([[1, 0, 0, 0, 0],
5209            [0, 1, 0, 0, 0],
5210            [0, 0, 1, 0, 0],
5211            [1, 0, 0, 0, 0],
5212            [0, 1, 0, 0, 0]])
5213    >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)
5214    tensor([[[1, 0, 0],
5215             [0, 1, 0]],
5216            [[0, 0, 1],
5217             [1, 0, 0]],
5218            [[0, 1, 0],
5219             [0, 0, 1]]])
5220""",
5221)
5222
5223
5224def triplet_margin_loss(
5225    anchor: Tensor,
5226    positive: Tensor,
5227    negative: Tensor,
5228    margin: float = 1.0,
5229    p: float = 2,
5230    eps: float = 1e-6,
5231    swap: bool = False,
5232    size_average: Optional[bool] = None,
5233    reduce: Optional[bool] = None,
5234    reduction: str = "mean",
5235) -> Tensor:
5236    r"""Compute the triplet loss between given input tensors and a margin greater than 0.
5237
5238    See :class:`~torch.nn.TripletMarginLoss` for details.
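
    Example (illustrative; random embeddings of arbitrary shape)::

        >>> anchor = torch.randn(100, 128, requires_grad=True)
        >>> positive = torch.randn(100, 128, requires_grad=True)
        >>> negative = torch.randn(100, 128, requires_grad=True)
        >>> loss = F.triplet_margin_loss(anchor, positive, negative, margin=1.0)
        >>> loss.backward()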
5239    """
5240    if has_torch_function_variadic(anchor, positive, negative):
5241        return handle_torch_function(
5242            triplet_margin_loss,
5243            (anchor, positive, negative),
5244            anchor,
5245            positive,
5246            negative,
5247            margin=margin,
5248            p=p,
5249            eps=eps,
5250            swap=swap,
5251            size_average=size_average,
5252            reduce=reduce,
5253            reduction=reduction,
5254        )
5255    if size_average is not None or reduce is not None:
5256        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
5257    else:
5258        reduction_enum = _Reduction.get_enum(reduction)
5259    if margin <= 0:
5260        raise ValueError(f"margin must be greater than 0, got {margin}")
5261    return torch.triplet_margin_loss(
5262        anchor, positive, negative, margin, p, eps, swap, reduction_enum
5263    )
5264
5265
5266def triplet_margin_with_distance_loss(
5267    anchor: Tensor,
5268    positive: Tensor,
5269    negative: Tensor,
5270    *,
5271    distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
5272    margin: float = 1.0,
5273    swap: bool = False,
5274    reduction: str = "mean",
5275) -> Tensor:
5276    r"""Compute the triplet margin loss for input tensors using a custom distance function.
5277
5278    See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details.
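
    Example (illustrative; cosine distance as a custom distance function)::

        >>> anchor = torch.randn(100, 128, requires_grad=True)
        >>> positive = torch.randn(100, 128, requires_grad=True)
        >>> negative = torch.randn(100, 128, requires_grad=True)
        >>> def cosine_distance(x, y):
        ...     return 1.0 - F.cosine_similarity(x, y)
        >>> loss = F.triplet_margin_with_distance_loss(
        ...     anchor, positive, negative, distance_function=cosine_distance)
        >>> loss.backward()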
5279    """
5280    if torch.jit.is_scripting():
5281        raise NotImplementedError(
5282            "F.triplet_margin_with_distance_loss does not support JIT scripting: "
5283            "functions requiring Callables cannot be scripted."
5284        )
5285
5286    if has_torch_function_variadic(anchor, positive, negative):
5287        return handle_torch_function(
5288            triplet_margin_with_distance_loss,
5289            (anchor, positive, negative),
5290            anchor,
5291            positive,
5292            negative,
5293            distance_function=distance_function,
5294            margin=margin,
5295            swap=swap,
5296            reduction=reduction,
5297        )
5298
5299    # Check validity of reduction mode
5300    if reduction not in ("mean", "sum", "none"):
5301        raise ValueError(f"{reduction} is not a valid value for reduction")
5302
5303    # Check validity of margin
5304    if margin <= 0:
5305        raise ValueError(f"margin must be greater than 0, got {margin}")
5306
5307    # Check dimensions
5308    a_dim = anchor.ndim
5309    p_dim = positive.ndim
5310    n_dim = negative.ndim
5311    if not (a_dim == p_dim and p_dim == n_dim):
5312        raise RuntimeError(
5313            f"The anchor, positive, and negative tensors are expected to have "
5314            f"the same number of dimensions, but got: anchor {a_dim}D, "
5315            f"positive {p_dim}D, and negative {n_dim}D inputs"
5316        )
5317
5318    # Calculate loss
5319    if distance_function is None:
5320        distance_function = torch.pairwise_distance
5321
5322    dist_pos = distance_function(anchor, positive)
5323    dist_neg = distance_function(anchor, negative)
5324    # The distance swap is described in the paper "Learning shallow
5325    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
5326    # Riba et al.  If True, and if the positive example is closer to the
5327    # negative example than the anchor is, swaps the positive example and the
5328    # anchor in the loss computation.
5329    if swap:
5330        dist_swap = distance_function(positive, negative)
5331        dist_neg = torch.minimum(dist_neg, dist_swap)
5332    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
5333
5334    # Apply reduction
5335    if reduction == "sum":
5336        return torch.sum(loss)
5337    elif reduction == "mean":
5338        return torch.mean(loss)
5339    else:  # reduction == "none"
5340        return loss
5341
5342
5343def normalize(
5344    input: Tensor,
5345    p: float = 2.0,
5346    dim: int = 1,
5347    eps: float = 1e-12,
5348    out: Optional[Tensor] = None,
5349) -> Tensor:
5350    r"""Perform :math:`L_p` normalization of inputs over specified dimension.
5351
5352    For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
5353    :math:`n_{dim}`-element vector :math:`v` along dimension :attr:`dim` is transformed as
5354
5355    .. math::
5356        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
5357
5358    With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization.
5359
5360    Args:
5361        input: input tensor of any shape
5362        p (float): the exponent value in the norm formulation. Default: 2
5363        dim (int or tuple of ints): the dimension to reduce. Default: 1
5364        eps (float): small value to avoid division by zero. Default: 1e-12
5365        out (Tensor, optional): the output tensor. If :attr:`out` is used, this
5366                                operation won't be differentiable.
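
    Example (a small illustrative sketch)::

        >>> x = torch.tensor([[3.0, 4.0]])
        >>> F.normalize(x, p=2.0, dim=1)
        tensor([[0.6000, 0.8000]])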
5367    """
5368    if has_torch_function_variadic(input, out):
5369        return handle_torch_function(
5370            normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out
5371        )
5372    if out is None:
5373        denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input)
5374        return input / denom
5375    else:
5376        denom = input.norm(p, dim, keepdim=True).clamp_min_(eps).expand_as(input)
5377        return torch.div(input, denom, out=out)
5378
5379
5380def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None:
5381    assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
5382
5383
5384def unfold(
5385    input: Tensor,
5386    kernel_size: BroadcastingList2[int],
5387    dilation: BroadcastingList2[int] = 1,
5388    padding: BroadcastingList2[int] = 0,
5389    stride: BroadcastingList2[int] = 1,
5390) -> Tensor:
5391    r"""Extract sliding local blocks from a batched input tensor.
5392
5393    .. warning::
5394        Currently, only 4-D input tensors (batched image-like tensors) are
5395        supported.
5396
5397    .. warning::
5398
5399        More than one element of the unfolded tensor may refer to a single
5400        memory location. As a result, in-place operations (especially ones that
5401        are vectorized) may result in incorrect behavior. If you need to write
5402        to the tensor, please clone it first.
5403
5404
5405    See :class:`torch.nn.Unfold` for details
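
    Example (illustrative; see :class:`torch.nn.Unfold` for the output-shape formula)::

        >>> inp = torch.randn(1, 3, 10, 12)
        >>> # each of the 8 * 10 = 80 sliding 3x3 blocks is flattened to 3 * 3 * 3 = 27 values
        >>> F.unfold(inp, kernel_size=3).shape
        torch.Size([1, 27, 80])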
5406    """
5407    if has_torch_function_unary(input):
5408        return handle_torch_function(
5409            unfold,
5410            (input,),
5411            input,
5412            kernel_size,
5413            dilation=dilation,
5414            padding=padding,
5415            stride=stride,
5416        )
5417    return torch._C._nn.im2col(
5418        input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)
5419    )
5420
5421
5422def fold(
5423    input: Tensor,
5424    output_size: BroadcastingList2[int],
5425    kernel_size: BroadcastingList2[int],
5426    dilation: BroadcastingList2[int] = 1,
5427    padding: BroadcastingList2[int] = 0,
5428    stride: BroadcastingList2[int] = 1,
5429) -> Tensor:
5430    r"""Combine an array of sliding local blocks into a large containing tensor.
5431
5432    .. warning::
5433        Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported.
5434
5435    See :class:`torch.nn.Fold` for details
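
    Example (illustrative round trip with :func:`unfold`; with non-overlapping
    blocks, :func:`fold` reassembles the input exactly)::

        >>> inp = torch.randn(1, 3, 10, 12)
        >>> blocks = F.unfold(inp, kernel_size=2, stride=2)
        >>> torch.equal(F.fold(blocks, (10, 12), kernel_size=2, stride=2), inp)
        True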
5436    """
5437    if has_torch_function_unary(input):
5438        return handle_torch_function(
5439            fold,
5440            (input,),
5441            input,
5442            output_size,
5443            kernel_size,
5444            dilation=dilation,
5445            padding=padding,
5446            stride=stride,
5447        )
5448    return torch._C._nn.col2im(
5449        input,
5450        _pair(output_size),
5451        _pair(kernel_size),
5452        _pair(dilation),
5453        _pair(padding),
5454        _pair(stride),
5455    )
5456
5457
5458#
5459# multihead attention
5460#
5461
5462
5463def _in_projection_packed(
5464    q: Tensor,
5465    k: Tensor,
5466    v: Tensor,
5467    w: Tensor,
5468    b: Optional[Tensor] = None,
5469) -> List[Tensor]:
5470    r"""Perform the in-projection step of the attention operation, using packed weights.
5471
5472    Output is a triple containing projection tensors for query, key and value.
5473
5474    Args:
5475        q, k, v: query, key and value tensors to be projected. For self-attention,
5476            these are typically the same tensor; for encoder-decoder attention,
5477            k and v are typically the same tensor. (We take advantage of these
5478            identities for performance if they are present.) Regardless, q, k and v
5479            must share a common embedding dimension; otherwise their shapes may vary.
5480        w: projection weights for q, k and v, packed into a single tensor. Weights
5481            are packed along dimension 0, in q, k, v order.
5482        b: optional projection biases for q, k and v, packed into a single tensor
5483            in q, k, v order.
5484
5485    Shape:
5486        Inputs:
5487        - q: :math:`(..., E)` where E is the embedding dimension
5488        - k: :math:`(..., E)` where E is the embedding dimension
5489        - v: :math:`(..., E)` where E is the embedding dimension
5490        - w: :math:`(E * 3, E)` where E is the embedding dimension
5491        - b: :math:`(E * 3)` where E is the embedding dimension
5492
5493        Output:
5494        - in output list :math:`[q', k', v']`, each output tensor will have the
5495            same shape as the corresponding input tensor.
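
    Example (illustrative self-attention shapes for this internal helper)::

        >>> E = 8
        >>> q = torch.randn(2, 4, E)
        >>> w = torch.randn(3 * E, E)
        >>> q_proj, k_proj, v_proj = _in_projection_packed(q, q, q, w)
        >>> q_proj.shape
        torch.Size([2, 4, 8])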
5496    """
5497    E = q.size(-1)
5498    if k is v:
5499        if q is k:
5500            # self-attention
5501            proj = linear(q, w, b)
5502            # reshaping to (3, E) rather than (E, 3) is deliberate for better memory coalescing and for keeping the same order as chunk()
5503            proj = (
5504                proj.unflatten(-1, (3, E))
5505                .unsqueeze(0)
5506                .transpose(0, -2)
5507                .squeeze(-2)
5508                .contiguous()
5509            )
5510            return proj[0], proj[1], proj[2]
5511        else:
5512            # encoder-decoder attention
5513            w_q, w_kv = w.split([E, E * 2])
5514            if b is None:
5515                b_q = b_kv = None
5516            else:
5517                b_q, b_kv = b.split([E, E * 2])
5518            q_proj = linear(q, w_q, b_q)
5519            kv_proj = linear(k, w_kv, b_kv)
5520            # reshaping to (2, E) rather than (E, 2) is deliberate for better memory coalescing and for keeping the same order as chunk()
5521            kv_proj = (
5522                kv_proj.unflatten(-1, (2, E))
5523                .unsqueeze(0)
5524                .transpose(0, -2)
5525                .squeeze(-2)
5526                .contiguous()
5527            )
5528            return (q_proj, kv_proj[0], kv_proj[1])
5529    else:
5530        w_q, w_k, w_v = w.chunk(3)
5531        if b is None:
5532            b_q = b_k = b_v = None
5533        else:
5534            b_q, b_k, b_v = b.chunk(3)
5535        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
5536
5537
5538def _in_projection(
5539    q: Tensor,
5540    k: Tensor,
5541    v: Tensor,
5542    w_q: Tensor,
5543    w_k: Tensor,
5544    w_v: Tensor,
5545    b_q: Optional[Tensor] = None,
5546    b_k: Optional[Tensor] = None,
5547    b_v: Optional[Tensor] = None,
5548) -> Tuple[Tensor, Tensor, Tensor]:
5549    r"""Perform the in-projection step of the attention operation.
5550
5551    This is simply a triple of linear projections,
5552    with shape constraints on the weights which
5553    ensure embedding dimension uniformity in the projected outputs.
5554    Output is a triple containing projection tensors for query, key and value.
5555
5556    Args:
5557        q, k, v: query, key and value tensors to be projected.
5558        w_q, w_k, w_v: weights for q, k and v, respectively.
5559        b_q, b_k, b_v: optional biases for q, k and v, respectively.
5560
5561    Shape:
5562        Inputs:
5563        - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
5564            number of leading dimensions.
5565        - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
5566            number of leading dimensions.
5567        - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
5568            number of leading dimensions.
5569        - w_q: :math:`(Eq, Eq)`
5570        - w_k: :math:`(Eq, Ek)`
5571        - w_v: :math:`(Eq, Ev)`
5572        - b_q: :math:`(Eq)`
5573        - b_k: :math:`(Eq)`
5574        - b_v: :math:`(Eq)`
5575
5576        Output: in output triple :math:`(q', k', v')`,
5577         - q': :math:`[Qdims..., Eq]`
5578         - k': :math:`[Kdims..., Eq]`
5579         - v': :math:`[Vdims..., Eq]`
5580
5581    """
5582    Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
5583    assert w_q.shape == (
5584        Eq,
5585        Eq,
5586    ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
5587    assert w_k.shape == (
5588        Eq,
5589        Ek,
5590    ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
5591    assert w_v.shape == (
5592        Eq,
5593        Ev,
5594    ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
5595    assert b_q is None or b_q.shape == (
5596        Eq,
5597    ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
5598    assert b_k is None or b_k.shape == (
5599        Eq,
5600    ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
5601    assert b_v is None or b_v.shape == (
5602        Eq,
5603    ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
5604    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
5605
5606
5607scaled_dot_product_attention = _add_docstr(
5608    torch._C._nn.scaled_dot_product_attention,
5609    r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
5610        is_causal=False, scale=None, enable_gqa=False) -> Tensor:
5611
5612    Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed,
5613    and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be
5614    specified as a keyword argument.
5615
5616    .. code-block:: python
5617
5618        # Efficient implementation equivalent to the following:
5619        def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
5620                is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
5621            L, S = query.size(-2), key.size(-2)
5622            scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
5623            attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
5624            if is_causal:
5625                assert attn_mask is None
5626                temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
5627                attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
5628                attn_bias = attn_bias.to(query.dtype)
5629
5630            if attn_mask is not None:
5631                if attn_mask.dtype == torch.bool:
5632                    attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
5633                else:
5634                    attn_bias += attn_mask
5635
5636            if enable_gqa:
5637                key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
5638                value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
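                    # e.g. with Hq = 32 query heads and H = 8 key/value heads, each
                    # key/value head is repeated 32 // 8 = 4 times so the matmuls
                    # below see matching head counts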
5639
5640            attn_weight = query @ key.transpose(-2, -1) * scale_factor
5641            attn_weight += attn_bias
5642            attn_weight = torch.softmax(attn_weight, dim=-1)
5643            attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
5644            return attn_weight @ value
5645
5646    .. warning::
5647        This function is beta and subject to change.
5648
5649    .. warning::
5650        This function always applies dropout according to the specified ``dropout_p`` argument.
5651        To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module
5652        that makes the function call is not in training mode.
5653
5654        For example:
5655
5656        .. code-block:: python
5657
5658            class MyModel(nn.Module):
5659                def __init__(self, p=0.5):
5660                    super().__init__()
5661                    self.p = p
5662
5663                def forward(self, ...):
5664                    return F.scaled_dot_product_attention(...,
5665                        dropout_p=(self.p if self.training else 0.0))
5666
5667    Note:
5668
5669        There are currently three supported implementations of scaled dot product attention:
5670
5671            - `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_
5672            - `Memory-Efficient Attention`_
5673            - A PyTorch implementation defined in C++ matching the above formulation
5674
5675        The function may call optimized kernels for improved performance when using the CUDA backend.
5676        For all other backends, the PyTorch implementation will be used.
5677
5678        All implementations are enabled by default. Scaled dot product attention attempts to automatically select the
5679        best implementation based on the inputs. To provide more fine-grained control over which implementation
5680        is used, the following functions are provided for enabling and disabling implementations.
5681        The context manager is the preferred mechanism:
5682
5683            - :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations.
5684            - :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention.
5685            - :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables Memory-Efficient Attention.
5686            - :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables the PyTorch C++ implementation.
5687
5688        Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation,
5689        disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`.
5690        In the event that a fused implementation is not available, a warning will be raised with the
5691        reasons why the fused implementation cannot run.
5692
5693        Due to the nature of fusing floating point operations, the output of this function may be different
5694        depending on what backend kernel is chosen.
5695        The C++ implementation supports torch.float64 and can be used when higher precision is required.
5696        For the math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16.
5697        For more information please see :doc:`/notes/numerical_accuracy`.
5698
5699        Grouped Query Attention (GQA) is an experimental feature. It currently works only for the FlashAttention
5700        and math kernels on CUDA tensors, and does not support nested tensors.
5701        Constraints for GQA:
5702
5703            - number_of_heads_query % number_of_heads_key_value == 0, and
5704            - number_of_heads_key == number_of_heads_value
5705
5706    Note:
5707
5708        {cudnn_reproducibility_note}
5709    """.format(
5710        **reproducibility_notes
5711    )
5712    + r"""
5713    Args:
5714        query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
5715        key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`.
5716        value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`.
5717        attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
5718            which is :math:`(N,..., L, S)`. Two types of masks are supported.
5719            A boolean mask where a value of True indicates that the element *should* take part in attention.
5720            A float mask of the same type as query, key, value that is added to the attention score.
5721        dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
5722        is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a
5723            square matrix. The attention masking has the form of the upper left causal bias due to the alignment
5724            (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix.
5725            An error is thrown if both attn_mask and is_causal are set.
5726        scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
5727            to :math:`\frac{1}{\sqrt{E}}`.
5728        enable_gqa (bool): If set to True, Grouped Query Attention (GQA) is enabled. By default it is set to False.
5729
5730    Returns:
5731        output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.
5732
5733    Shape legend:
5734        - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
5735        - :math:`S: \text{Source sequence length}`
5736        - :math:`L: \text{Target sequence length}`
5737        - :math:`E: \text{Embedding dimension of the query and key}`
5738        - :math:`Ev: \text{Embedding dimension of the value}`
5739        - :math:`Hq: \text{Number of heads of query}`
5740        - :math:`H: \text{Number of heads of key and value}`
5741
5742    Examples:
5743
5744        >>> # Optionally use the context manager to ensure one of the fused kernels is run
5745        >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
5746        >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
5747        >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
5748        >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
5749        >>>     F.scaled_dot_product_attention(query, key, value)
5750
5751
5752        >>> # Sample for GQA for llama3
5753        >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
5754        >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
5755        >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
5756        >>> with sdpa_kernel(backends=[SDPBackend.MATH]):
5757        >>>     F.scaled_dot_product_attention(query, key, value, enable_gqa=True)
5758
5759
5760    .. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning:
5761        https://arxiv.org/abs/2307.08691
5762    .. _Memory-Efficient Attention:
5763        https://github.com/facebookresearch/xformers
5764    .. _Grouped-Query Attention:
5765        https://arxiv.org/pdf/2305.13245
5766    """,
5767)
5768
5769
5770def _mha_shape_check(
5771    query: Tensor,
5772    key: Tensor,
5773    value: Tensor,
5774    key_padding_mask: Optional[Tensor],
5775    attn_mask: Optional[Tensor],
5776    num_heads: int,
5777):
5778    # Verifies the expected shapes for `query`, `key`, `value`, `key_padding_mask` and `attn_mask`
5779    # and returns whether the input is batched or not.
5780    # Raises an error if `query` is not a 2-D (unbatched) or 3-D (batched) tensor.
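    # Expected layouts (seq-first, as documented in multi_head_attention_forward):
    #   batched:   query (L, N, E), key/value (S, N, E), key_padding_mask (N, S)
    #   unbatched: query (L, E),    key/value (S, E),    key_padding_mask (S,)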
5781
5782    # Shape check.
5783    if query.dim() == 3:
5784        # Batched Inputs
5785        is_batched = True
5786        assert key.dim() == 3 and value.dim() == 3, (
5787            "For batched (3-D) `query`, expected `key` and `value` to be 3-D"
5788            f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
5789        )
5790        if key_padding_mask is not None:
5791            assert key_padding_mask.dim() == 2, (
5792                "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
5793                f" but found {key_padding_mask.dim()}-D tensor instead"
5794            )
5795        if attn_mask is not None:
5796            assert attn_mask.dim() in (2, 3), (
5797                "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
5798                f" but found {attn_mask.dim()}-D tensor instead"
5799            )
5800    elif query.dim() == 2:
5801        # Unbatched Inputs
5802        is_batched = False
5803        assert key.dim() == 2 and value.dim() == 2, (
5804            "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
5805            f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
5806        )
5807
5808        if key_padding_mask is not None:
5809            assert key_padding_mask.dim() == 1, (
5810                "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
5811                f" but found {key_padding_mask.dim()}-D tensor instead"
5812            )
5813
5814        if attn_mask is not None:
5815            assert attn_mask.dim() in (2, 3), (
5816                "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
5817                f" but found {attn_mask.dim()}-D tensor instead"
5818            )
5819            if attn_mask.dim() == 3:
5820                expected_shape = (num_heads, query.shape[0], key.shape[0])
5821                assert (
5822                    attn_mask.shape == expected_shape
5823                ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
5824    else:
5825        raise AssertionError(
5826            f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor"
5827        )
5828
5829    return is_batched
5830
5831
5832def _canonical_mask(
5833    mask: Optional[Tensor],
5834    mask_name: str,
5835    other_type: Optional[DType],
5836    other_name: str,
5837    target_type: DType,
5838    check_other: bool = True,
5839) -> Optional[Tensor]:
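    # Canonicalizes `mask` to an additive float mask of `target_type`: boolean
    # masks are converted so that True positions become -inf (masked out) and
    # False positions become 0.0, e.g. tensor([False, True]) -> tensor([0., -inf]).
    # Float masks are returned unchanged; any other dtype is rejected below.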
5840    if mask is not None:
5841        _mask_dtype = mask.dtype
5842        _mask_is_float = torch.is_floating_point(mask)
5843        if _mask_dtype != torch.bool and not _mask_is_float:
5844            raise AssertionError(
5845                f"only bool and floating types of {mask_name} are supported"
5846            )
5847        if check_other and other_type is not None:
5848            if _mask_dtype != other_type:
5849                warnings.warn(
5850                    f"Support for mismatched {mask_name} and {other_name} "
5851                    "is deprecated. Use same type for both instead."
5852                )
5853        if not _mask_is_float:
5854            mask = torch.zeros_like(mask, dtype=target_type).masked_fill_(
5855                mask, float("-inf")
5856            )
5857    return mask
5858
5859
5860def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
5861    if input is None:
5862        return None
5863    elif isinstance(input, torch.Tensor):
5864        return input.dtype
5865    raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
5866
5867
5868def multi_head_attention_forward(
5869    query: Tensor,
5870    key: Tensor,
5871    value: Tensor,
5872    embed_dim_to_check: int,
5873    num_heads: int,
5874    in_proj_weight: Optional[Tensor],
5875    in_proj_bias: Optional[Tensor],
5876    bias_k: Optional[Tensor],
5877    bias_v: Optional[Tensor],
5878    add_zero_attn: bool,
5879    dropout_p: float,
5880    out_proj_weight: Tensor,
5881    out_proj_bias: Optional[Tensor],
5882    training: bool = True,
5883    key_padding_mask: Optional[Tensor] = None,
5884    need_weights: bool = True,
5885    attn_mask: Optional[Tensor] = None,
5886    use_separate_proj_weight: bool = False,
5887    q_proj_weight: Optional[Tensor] = None,
5888    k_proj_weight: Optional[Tensor] = None,
5889    v_proj_weight: Optional[Tensor] = None,
5890    static_k: Optional[Tensor] = None,
5891    static_v: Optional[Tensor] = None,
5892    average_attn_weights: bool = True,
5893    is_causal: bool = False,
5894) -> Tuple[Tensor, Optional[Tensor]]:
5895    r"""Forward method for MultiHeadAttention.
5896
5897    See :class:`torch.nn.MultiheadAttention` for details.
5898
5899    Args:
5900        query, key, value: map a query and a set of key-value pairs to an output.
5901            See "Attention Is All You Need" for more details.
5902        embed_dim_to_check: total dimension of the model.
5903        num_heads: parallel attention heads.
5904        in_proj_weight, in_proj_bias: input projection weight and bias.
5905        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
5906        add_zero_attn: add a new batch of zeros to the key and
5907                       value sequences at dim=1.
5908        dropout_p: probability of an element to be zeroed.
5909        out_proj_weight, out_proj_bias: the output projection weight and bias.
5910        training: apply dropout if it is ``True``.
5911        key_padding_mask: if provided, specified padding elements in the key will
5912            be ignored by the attention. This is a binary mask. When the value is True,
5913            the corresponding value on the attention layer will be filled with -inf.
5914        need_weights: output attn_output_weights.
5915            Default: ``True``
5916            Note: ``need_weights`` defaults to ``True``, but should be set to ``False``
5917            for best performance when attention weights are not needed.
5918            *Setting need_weights to ``True``
5919            leads to a significant performance degradation.*
5920        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
5921            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
5922        is_causal: If specified, applies a causal mask as attention mask, and ignores
5923            attn_mask for computing scaled dot product attention.
5924            Default: ``False``.
5925            .. warning::
5926                is_causal provides a hint that the attn_mask is the
5927                causal mask. Providing incorrect hints can result in
5928                incorrect execution, as well as forward and backward
5929                compatibility issues.
5930        use_separate_proj_weight: the function accepts the projection weights for query, key,
5931            and value in different forms. If false, in_proj_weight will be used, which is
5932            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
5933        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
5934        static_k, static_v: static key and value used for attention operators.
5935        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
5936            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
5937            when ``need_weights=True``. Default: True
5938
5939
5940    Shape:
5941        Inputs:
5942        - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
5943          the embedding dimension.
5944        - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
5945          the embedding dimension.
5946        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
5947          the embedding dimension.
5948        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
5949          If a FloatTensor is provided, it will be directly added to the attention weight.
5950          If a BoolTensor is provided, the positions with the
5951          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
5952        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
5953          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
5954          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
5955          positions. If a BoolTensor is provided, positions with ``True``
5956          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
5957          is provided, it will be added to the attention weight.
5958        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
5959          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
5960        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
5961          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
5962
5963        Outputs:
5964        - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
5965          E is the embedding dimension.
5966        - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
5967          attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
5968          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
5969          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
5970          head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
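
        Example:
            A minimal self-attention sketch with illustrative sizes (L = S = 5,
            N = 2, E = 8, 2 heads); the projection weights below are random
            placeholders, not trained parameters::

                >>> L, N, E, num_heads = 5, 2, 8, 2
                >>> x = torch.randn(L, N, E)
                >>> in_proj_weight = torch.randn(3 * E, E)
                >>> out_proj_weight = torch.randn(E, E)
                >>> out, weights = multi_head_attention_forward(
                ...     x, x, x, E, num_heads, in_proj_weight, None, None, None,
                ...     False, 0.0, out_proj_weight, None)
                >>> out.shape, weights.shape
                (torch.Size([5, 2, 8]), torch.Size([2, 5, 5]))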
5971    """
5972    tens_ops = (
5973        query,
5974        key,
5975        value,
5976        in_proj_weight,
5977        in_proj_bias,
5978        bias_k,
5979        bias_v,
5980        out_proj_weight,
5981        out_proj_bias,
5982    )
5983    if has_torch_function(tens_ops):
5984        return handle_torch_function(
5985            multi_head_attention_forward,
5986            tens_ops,
5987            query,
5988            key,
5989            value,
5990            embed_dim_to_check,
5991            num_heads,
5992            in_proj_weight,
5993            in_proj_bias,
5994            bias_k,
5995            bias_v,
5996            add_zero_attn,
5997            dropout_p,
5998            out_proj_weight,
5999            out_proj_bias,
6000            training=training,
6001            key_padding_mask=key_padding_mask,
6002            need_weights=need_weights,
6003            attn_mask=attn_mask,
6004            is_causal=is_causal,
6005            use_separate_proj_weight=use_separate_proj_weight,
6006            q_proj_weight=q_proj_weight,
6007            k_proj_weight=k_proj_weight,
6008            v_proj_weight=v_proj_weight,
6009            static_k=static_k,
6010            static_v=static_v,
6011            average_attn_weights=average_attn_weights,
6012        )
6013
6014    is_batched = _mha_shape_check(
6015        query, key, value, key_padding_mask, attn_mask, num_heads
6016    )
6017
6018    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
6019    # is batched, run the computation, and squeeze the batch dimension before
6020    # returning so that the output doesn't carry this temporary batch dimension.
6021    if not is_batched:
6022        # unsqueeze if the input is unbatched
6023        query = query.unsqueeze(1)
6024        key = key.unsqueeze(1)
6025        value = value.unsqueeze(1)
6026        if key_padding_mask is not None:
6027            key_padding_mask = key_padding_mask.unsqueeze(0)
6028
6029    # set up shape vars
6030    tgt_len, bsz, embed_dim = query.shape
6031    src_len, _, _ = key.shape
6032
6033    key_padding_mask = _canonical_mask(
6034        mask=key_padding_mask,
6035        mask_name="key_padding_mask",
6036        other_type=_none_or_dtype(attn_mask),
6037        other_name="attn_mask",
6038        target_type=query.dtype,
6039    )
6040
6041    if is_causal and attn_mask is None:
6042        raise RuntimeError(
6043            "Need attn_mask if specifying the is_causal hint. "
6044            "You may use the Transformer module method "
6045            "`generate_square_subsequent_mask` to create this mask."
6046        )
6047
6048    if is_causal and key_padding_mask is None and not need_weights:
6049        # When we have a key_padding_mask or need weights, we need an attn_mask.
6050        # Otherwise, we pass the is_causal hint straight through as the
6051        # is_causal indicator to SDPA.
6052        attn_mask = None
6053    else:
6054        attn_mask = _canonical_mask(
6055            mask=attn_mask,
6056            mask_name="attn_mask",
6057            other_type=None,
6058            other_name="",
6059            target_type=query.dtype,
6060            check_other=False,
6061        )
6062
6063        if key_padding_mask is not None:
6064            # We have the attn_mask; the key_padding_mask will be merged into it below.
6065            # Turn off use of is_causal hint, as the merged mask is no
6066            # longer causal.
6067            is_causal = False
6068
6069    assert (
6070        embed_dim == embed_dim_to_check
6071    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
6072    if isinstance(embed_dim, torch.Tensor):
6073        # embed_dim can be a tensor when JIT tracing
6074        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
6075    else:
6076        head_dim = embed_dim // num_heads
6077    assert (
6078        head_dim * num_heads == embed_dim
6079    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
6080    if use_separate_proj_weight:
6081        # allow MHA to have different embedding dimensions when separate projection weights are used
6082        assert (
6083            key.shape[:2] == value.shape[:2]
6084        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
6085    else:
6086        assert (
6087            key.shape == value.shape
6088        ), f"key shape {key.shape} does not match value shape {value.shape}"
6089
6090    #
6091    # compute in-projection
6092    #
6093    if not use_separate_proj_weight:
6094        assert (
6095            in_proj_weight is not None
6096        ), "use_separate_proj_weight is False but in_proj_weight is None"
6097        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
6098    else:
6099        assert (
6100            q_proj_weight is not None
6101        ), "use_separate_proj_weight is True but q_proj_weight is None"
6102        assert (
6103            k_proj_weight is not None
6104        ), "use_separate_proj_weight is True but k_proj_weight is None"
6105        assert (
6106            v_proj_weight is not None
6107        ), "use_separate_proj_weight is True but v_proj_weight is None"
6108        if in_proj_bias is None:
6109            b_q = b_k = b_v = None
6110        else:
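                # the packed (3 * embed_dim,) bias splits into three (embed_dim,) chunks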
6111            b_q, b_k, b_v = in_proj_bias.chunk(3)
6112        q, k, v = _in_projection(
6113            query,
6114            key,
6115            value,
6116            q_proj_weight,
6117            k_proj_weight,
6118            v_proj_weight,
6119            b_q,
6120            b_k,
6121            b_v,
6122        )
6123
6124    # prep attention mask
6125
6126    if attn_mask is not None:
6127        # ensure attn_mask's dim is 3
6128        if attn_mask.dim() == 2:
6129            correct_2d_size = (tgt_len, src_len)
6130            if attn_mask.shape != correct_2d_size:
6131                raise RuntimeError(
6132                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
6133                )
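                # (L, S) -> (1, L, S): the same mask is broadcast across all batch-heads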
6134            attn_mask = attn_mask.unsqueeze(0)
6135        elif attn_mask.dim() == 3:
6136            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
6137            if attn_mask.shape != correct_3d_size:
6138                raise RuntimeError(
6139                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
6140                )
6141        else:
6142            raise RuntimeError(
6143                f"attn_mask's dimension {attn_mask.dim()} is not supported"
6144            )
6145
6146    # add bias along batch dimension (currently second)
6147    if bias_k is not None and bias_v is not None:
6148        assert static_k is None, "bias cannot be added to static key."
6149        assert static_v is None, "bias cannot be added to static value."
6150        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
6151        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
6152        if attn_mask is not None:
6153            attn_mask = pad(attn_mask, (0, 1))
6154        if key_padding_mask is not None:
6155            key_padding_mask = pad(key_padding_mask, (0, 1))
6156    else:
6157        assert bias_k is None
6158        assert bias_v is None
6159
6160    #
6161    # reshape q, k, v for multihead attention and make them batch first
6162    #
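        # e.g. q: (tgt_len, bsz, embed_dim) -> (bsz * num_heads, tgt_len, head_dim)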
6163    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
6164    if static_k is None:
6165        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
6166    else:
6167        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
6168        assert (
6169            static_k.size(0) == bsz * num_heads
6170        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
6171        assert (
6172            static_k.size(2) == head_dim
6173        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
6174        k = static_k
6175    if static_v is None:
6176        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
6177    else:
6178        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
6179        assert (
6180            static_v.size(0) == bsz * num_heads
6181        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
6182        assert (
6183            static_v.size(2) == head_dim
6184        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
6185        v = static_v
6186
6187    # add zero attention along batch dimension (now first)
6188    if add_zero_attn:
6189        zero_attn_shape = (bsz * num_heads, 1, head_dim)
6190        k = torch.cat(
6191            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
6192        )
6193        v = torch.cat(
6194            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
6195        )
6196        if attn_mask is not None:
6197            attn_mask = pad(attn_mask, (0, 1))
6198        if key_padding_mask is not None:
6199            key_padding_mask = pad(key_padding_mask, (0, 1))
6200
6201    # update source sequence length after adjustments
6202    src_len = k.size(1)
6203
6204    # merge key padding and attention masks
6205    if key_padding_mask is not None:
6206        assert key_padding_mask.shape == (
6207            bsz,
6208            src_len,
6209        ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
6210        key_padding_mask = (
6211            key_padding_mask.view(bsz, 1, 1, src_len)
6212            .expand(-1, num_heads, -1, -1)
6213            .reshape(bsz * num_heads, 1, src_len)
6214        )
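            # key_padding_mask is now (bsz * num_heads, 1, src_len) and broadcasts
            # over the tgt_len dimension when combined with attn_mask below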
6215        if attn_mask is None:
6216            attn_mask = key_padding_mask
6217        else:
6218            attn_mask = attn_mask + key_padding_mask
6219
6220    # adjust dropout probability
6221    if not training:
6222        dropout_p = 0.0
6223
6224    #
6225    # (deep breath) calculate attention and out projection
6226    #
6227
6228    if need_weights:
6229        B, Nt, E = q.shape
6230        q_scaled = q * math.sqrt(1.0 / float(E))
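            # q is pre-scaled by 1/sqrt(E) so the matmul against k below yields the
            # pre-softmax scores q @ k^T / sqrt(E)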
6231
6232        assert not (
6233            is_causal and attn_mask is None
6234        ), "FIXME: is_causal not implemented for need_weights"
6235
6236        if attn_mask is not None:
6237            attn_output_weights = torch.baddbmm(
6238                attn_mask, q_scaled, k.transpose(-2, -1)
6239            )
6240        else:
6241            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
6242        attn_output_weights = softmax(attn_output_weights, dim=-1)
6243        if dropout_p > 0.0:
6244            attn_output_weights = dropout(attn_output_weights, p=dropout_p)
6245
6246        attn_output = torch.bmm(attn_output_weights, v)
6247
6248        attn_output = (
6249            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
6250        )
6251        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
6252        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
6253
6254        # optionally average attention weights over heads
6255        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
6256        if average_attn_weights:
6257            attn_output_weights = attn_output_weights.mean(dim=1)
6258
6259        if not is_batched:
6260            # squeeze the output if input was unbatched
6261            attn_output = attn_output.squeeze(1)
6262            attn_output_weights = attn_output_weights.squeeze(0)
6263        return attn_output, attn_output_weights
6264    else:
6265        # attn_mask can be either (L,S) or (N*num_heads, L, S)
6266        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
6267        # in order to match the input for SDPA of (N, num_heads, L, S)
6268        if attn_mask is not None:
6269            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
6270                attn_mask = attn_mask.unsqueeze(0)
6271            else:
6272                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
6273
6274        q = q.view(bsz, num_heads, tgt_len, head_dim)
6275        k = k.view(bsz, num_heads, src_len, head_dim)
6276        v = v.view(bsz, num_heads, src_len, head_dim)
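            # scaled_dot_product_attention expects (N, num_heads, L, head_dim) inputs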
6277
6278        attn_output = scaled_dot_product_attention(
6279            q, k, v, attn_mask, dropout_p, is_causal
6280        )
6281        attn_output = (
6282            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
6283        )
6284
6285        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
6286        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
6287        if not is_batched:
6288            # squeeze the output if input was unbatched
6289            attn_output = attn_output.squeeze(1)
6290        return attn_output, None
6291