1"""Functional interface.""" 2 3import importlib 4import math 5import warnings 6from typing import Callable, List, Optional, Tuple, TYPE_CHECKING, Union 7 8import torch 9from torch import _VF, sym_int as _sym_int, Tensor 10from torch._C import _add_docstr, _infer_size 11from torch._jit_internal import ( 12 _overload, 13 boolean_dispatch, 14 BroadcastingList1, 15 BroadcastingList2, 16 BroadcastingList3, 17) 18from torch._torch_docs import reproducibility_notes, sparse_support_notes, tf32_notes 19from torch.nn import _reduction as _Reduction, grad # noqa: F401 20from torch.nn.modules.utils import _list_with_default, _pair, _single, _triple 21from torch.overrides import ( 22 handle_torch_function, 23 has_torch_function, 24 has_torch_function_unary, 25 has_torch_function_variadic, 26) 27 28 29if TYPE_CHECKING: 30 from torch.types import _dtype as DType 31else: 32 # The JIT doesn't understand Union, nor torch.dtype here 33 DType = int 34 35try: 36 import numpy as np 37except ModuleNotFoundError: 38 np = None 39 40 41conv1d = _add_docstr( 42 torch.conv1d, 43 r""" 44conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor 45 46Applies a 1D convolution over an input signal composed of several input 47planes. 48 49{tf32_note} 50 51See :class:`~torch.nn.Conv1d` for details and output shape. 52 53Note: 54 {cudnn_reproducibility_note} 55 56Note: 57 This operator supports complex data types i.e. ``complex32, complex64, complex128``. 58""".format( 59 **reproducibility_notes, **tf32_notes 60 ) 61 + r""" 62 63Args: 64 input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` 65 weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)` 66 bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None`` 67 stride: the stride of the convolving kernel. Can be a single number or 68 a one-element tuple `(sW,)`. Default: 1 69 padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, 70 single number or a one-element tuple `(padW,)`. Default: 0 71 ``padding='valid'`` is the same as no padding. ``padding='same'`` pads 72 the input so the output has the same shape as the input. However, this mode 73 doesn't support any stride values other than 1. 74 75 .. warning:: 76 For ``padding='same'``, if the ``weight`` is even-length and 77 ``dilation`` is odd in any dimension, a full :func:`pad` operation 78 may be needed internally. Lowering performance. 79 dilation: the spacing between kernel elements. Can be a single number or 80 a one-element tuple `(dW,)`. Default: 1 81 groups: split input into groups, :math:`\text{in\_channels}` should be divisible by 82 the number of groups. Default: 1 83 84Examples:: 85 86 >>> inputs = torch.randn(33, 16, 30) 87 >>> filters = torch.randn(20, 16, 5) 88 >>> F.conv1d(inputs, filters) 89""", 90) 91 92conv2d = _add_docstr( 93 torch.conv2d, 94 r""" 95conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor 96 97Applies a 2D convolution over an input image composed of several input 98planes. 99 100{tf32_note} 101 102See :class:`~torch.nn.Conv2d` for details and output shape. 103 104Note: 105 {cudnn_reproducibility_note} 106 107Note: 108 This operator supports complex data types i.e. ``complex32, complex64, complex128``. 
109""".format( 110 **reproducibility_notes, **tf32_notes 111 ) 112 + r""" 113 114Args: 115 input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` 116 weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` 117 bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None`` 118 stride: the stride of the convolving kernel. Can be a single number or a 119 tuple `(sH, sW)`. Default: 1 120 padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, 121 single number or a tuple `(padH, padW)`. Default: 0 122 ``padding='valid'`` is the same as no padding. ``padding='same'`` pads 123 the input so the output has the same shape as the input. However, this mode 124 doesn't support any stride values other than 1. 125 126 .. warning:: 127 For ``padding='same'``, if the ``weight`` is even-length and 128 ``dilation`` is odd in any dimension, a full :func:`pad` operation 129 may be needed internally. Lowering performance. 130 131 dilation: the spacing between kernel elements. Can be a single number or 132 a tuple `(dH, dW)`. Default: 1 133 groups: split input into groups, both :math:`\text{in\_channels}` and :math:`\text{out\_channels}` 134 should be divisible by the number of groups. Default: 1 135 136Examples:: 137 138 >>> # With square kernels and equal stride 139 >>> filters = torch.randn(8, 4, 3, 3) 140 >>> inputs = torch.randn(1, 4, 5, 5) 141 >>> F.conv2d(inputs, filters, padding=1) 142""", 143) # noqa: E501 144 145conv3d = _add_docstr( 146 torch.conv3d, 147 r""" 148conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor 149 150Applies a 3D convolution over an input image composed of several input 151planes. 152 153{tf32_note} 154 155See :class:`~torch.nn.Conv3d` for details and output shape. 156 157Note: 158 {cudnn_reproducibility_note} 159 160Note: 161 This operator supports complex data types i.e. ``complex32, complex64, complex128``. 162""".format( 163 **reproducibility_notes, **tf32_notes 164 ) 165 + r""" 166 167Args: 168 input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` 169 weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kT , kH , kW)` 170 bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: None 171 stride: the stride of the convolving kernel. Can be a single number or a 172 tuple `(sT, sH, sW)`. Default: 1 173 padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, 174 single number or a tuple `(padT, padH, padW)`. Default: 0 175 ``padding='valid'`` is the same as no padding. ``padding='same'`` pads 176 the input so the output has the same shape as the input. However, this mode 177 doesn't support any stride values other than 1. 178 179 .. warning:: 180 For ``padding='same'``, if the ``weight`` is even-length and 181 ``dilation`` is odd in any dimension, a full :func:`pad` operation 182 may be needed internally. Lowering performance. 183 184 dilation: the spacing between kernel elements. Can be a single number or 185 a tuple `(dT, dH, dW)`. Default: 1 186 groups: split input into groups, :math:`\text{in\_channels}` should be divisible by 187 the number of groups. 
""",
)  # noqa: E501

conv3d = _add_docstr(
    torch.conv3d,
    r"""
conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 3D convolution over an input image composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv3d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types i.e. ``complex32, complex64, complex128``.
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kT , kH , kW)`
    bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: None
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple `(sT, sH, sW)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      single number or a tuple `(padT, padH, padW)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, lowering performance.

    dilation: the spacing between kernel elements. Can be a single number or
      a tuple `(dT, dH, dW)`. Default: 1
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by
      the number of groups. Default: 1

Examples::

    >>> filters = torch.randn(33, 16, 3, 3, 3)
    >>> inputs = torch.randn(20, 16, 50, 10, 20)
    >>> F.conv3d(inputs, filters)
""",
)  # noqa: E501

conv_transpose1d = _add_docstr(
    torch.conv_transpose1d,
    r"""
conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 1D transposed convolution operator over an input signal
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose1d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sW,)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padW,)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple ``(out_padW,)``. Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dW,)``. Default: 1

Examples::

    >>> inputs = torch.randn(20, 16, 50)
    >>> weights = torch.randn(16, 33, 5)
    >>> F.conv_transpose1d(inputs, weights)
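    >>> # illustrative sanity check: a transposed conv expands the length,
    >>> # L_out = (50 - 1) * 1 + (5 - 1) + 1 = 54
    >>> F.conv_transpose1d(inputs, weights).shape
    torch.Size([20, 33, 54])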
""",
)

conv_transpose2d = _add_docstr(
    torch.conv_transpose2d,
    r"""
conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 2D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose2d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kH , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sH, sW)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padH, padW)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple ``(out_padH, out_padW)``.
      Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dH, dW)``. Default: 1

Examples::

    >>> # With square kernels and equal stride
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> weights = torch.randn(4, 8, 3, 3)
    >>> F.conv_transpose2d(inputs, weights, padding=1)
""",
)  # noqa: E501

conv_transpose3d = _add_docstr(
    torch.conv_transpose3d,
    r"""
conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 3D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose3d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(
        **reproducibility_notes, **tf32_notes
    )
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kT , kH , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sT, sH, sW)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padT, padH, padW)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple
      ``(out_padT, out_padH, out_padW)``. Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple `(dT, dH, dW)`. Default: 1

Examples::

    >>> inputs = torch.randn(20, 16, 50, 10, 20)
    >>> weights = torch.randn(16, 33, 3, 3, 3)
    >>> F.conv_transpose3d(inputs, weights)
""",
)  # noqa: E501

conv_tbc = _add_docstr(
    torch.conv_tbc,
    r"""
Applies a 1-dimensional sequence convolution over an input sequence.
Input and output dimensions are (Time, Batch, Channels) - hence TBC.

Args:
    input: input tensor of shape :math:`(\text{sequence length} \times batch \times \text{in\_channels})`
    weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`)
    bias: bias of shape (:math:`\text{out\_channels}`)
    pad: number of timesteps to pad. Default: 0
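
Examples::

    >>> # a minimal usage sketch; shapes are illustrative:
    >>> # time=10, batch=2, in_channels=4, kernel width=3, out_channels=6
    >>> inputs = torch.randn(10, 2, 4)
    >>> weight = torch.randn(3, 4, 6)
    >>> bias = torch.zeros(6)
    >>> F.conv_tbc(inputs, weight, bias)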
""",
)


# Pooling
avg_pool1d = _add_docstr(
    torch.avg_pool1d,
    r"""
avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor

Applies a 1D average pooling over an input signal composed of several
input planes.

See :class:`~torch.nn.AvgPool1d` for details and output shape.

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    kernel_size: the size of the window. Can be a single number or a
      tuple `(kW,)`
    stride: the stride of the window. Can be a single number or a tuple
      `(sW,)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padW,)`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` to compute the
      output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
      averaging calculation. Default: ``True``

Examples::

    >>> # pool of square window of size=3, stride=2
    >>> input = torch.tensor([[[1, 2, 3, 4, 5, 6, 7]]], dtype=torch.float32)
    >>> F.avg_pool1d(input, kernel_size=3, stride=2)
    tensor([[[ 2., 4., 6.]]])

""",
)


avg_pool2d = _add_docstr(
    torch._C._nn.avg_pool2d,
    r"""
avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor

Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size
:math:`sH \times sW` steps. The number of output features is equal to the number of
input planes.

See :class:`~torch.nn.AvgPool2d` for details and output shape.

Args:
    input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    kernel_size: size of the pooling region. Can be a single number or a
      tuple `(kH, kW)`
    stride: stride of the pooling operation. Can be a single number or a
      tuple `(sH, sW)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padH, padW)`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` in the formula
      to compute the output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
      averaging calculation. Default: ``True``
    divisor_override: if specified, it will be used as divisor, otherwise
      size of the pooling region will be used. Default: None
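
Examples::

    >>> # a minimal sketch: 3x3 window with stride 2 over an 8x8 input -> 3x3 output
    >>> input = torch.randn(1, 3, 8, 8)
    >>> F.avg_pool2d(input, kernel_size=3, stride=2).shape
    torch.Size([1, 3, 3, 3])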
""",
)

avg_pool3d = _add_docstr(
    torch._C._nn.avg_pool3d,
    r"""
avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor

Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step
size :math:`sT \times sH \times sW` steps. The number of output features is equal to
:math:`\lfloor\frac{\text{input planes}}{sT}\rfloor`.

See :class:`~torch.nn.AvgPool3d` for details and output shape.

Args:
    input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    kernel_size: size of the pooling region. Can be a single number or a
      tuple `(kT, kH, kW)`
    stride: stride of the pooling operation. Can be a single number or a
      tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padT, padH, padW)`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` in the formula
      to compute the output shape
    count_include_pad: when True, will include the zero-padding in the
      averaging calculation
    divisor_override: if specified, it will be used as divisor, otherwise
      size of the pooling region will be used. Default: None
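
Examples::

    >>> # a minimal sketch: cubic 2x2x2 window with the default stride (= kernel_size)
    >>> input = torch.randn(1, 3, 4, 4, 4)
    >>> F.avg_pool3d(input, kernel_size=2).shape
    torch.Size([1, 3, 2, 2, 2])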
""",
)


def fractional_max_pool2d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    output_size: Optional[BroadcastingList2[int]] = None,
    output_ratio: Optional[BroadcastingList2[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

    Applies 2D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.

    The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    Args:
        kernel_size: the size of the window to take a max over.
          Can be a single number :math:`k` (for a square kernel of :math:`k \times k`)
          or a tuple `(kH, kW)`
        output_size: the target output size of the image of the form :math:`oH \times oW`.
          Can be a tuple `(oH, oW)` or a single number :math:`oH` for a square image :math:`oH \times oH`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
          This has to be a number or tuple in the range (0, 1)
        return_indices: if ``True``, will return the indices along with the outputs.
          Useful to pass to :func:`~torch.nn.functional.max_unpool2d`.

    Examples::

        >>> input = torch.randn(20, 16, 50, 32)
        >>> # pool of square window of size=3, and target output size 13x12
        >>> F.fractional_max_pool2d(input, 3, output_size=(13, 12))
        >>> # pool of square window and target output size being half of input image size
        >>> F.fractional_max_pool2d(input, 3, output_ratio=(0.5, 0.5))

    .. _Fractional MaxPooling:
        http://arxiv.org/abs/1412.6071
    """
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool2d_with_indices,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    if output_size is None and output_ratio is None:
        raise ValueError(
            "fractional_max_pool2d requires specifying either an output_size or an output_ratio"
        )
    if output_size is None:
        assert output_ratio is not None
        if len(output_ratio) > 2:
            raise ValueError(
                "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."
            )
        _output_ratio = _pair(output_ratio)
        output_size = [
            int(input.size(-2) * _output_ratio[0]),
            int(input.size(-1) * _output_ratio[1]),
        ]

    if _random_samples is None:
        n_batch = 1 if input.dim() == 3 else input.size(0)
        _random_samples = torch.rand(
            n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device
        )
    return torch._C._nn.fractional_max_pool2d(
        input, kernel_size, output_size, _random_samples
    )


def _fractional_max_pool2d(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    output_size: Optional[BroadcastingList2[int]] = None,
    output_ratio: Optional[BroadcastingList2[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tensor:
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool2d,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    return fractional_max_pool2d_with_indices(
        input, kernel_size, output_size, output_ratio, return_indices, _random_samples
    )[0]


fractional_max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=4,
    default=False,
    if_true=fractional_max_pool2d_with_indices,
    if_false=_fractional_max_pool2d,
    module_name=__name__,
    func_name="fractional_max_pool2d",
)


def fractional_max_pool3d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    output_size: Optional[BroadcastingList3[int]] = None,
    output_ratio: Optional[BroadcastingList3[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

    Applies 3D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.

    The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    Args:
        kernel_size: the size of the window to take a max over.
          Can be a single number :math:`k` (for a square kernel of :math:`k \times k \times k`)
          or a tuple `(kT, kH, kW)`
        output_size: the target output size of the form :math:`oT \times oH \times oW`.
          Can be a tuple `(oT, oH, oW)` or a single number :math:`oH` for a cubic output
          :math:`oH \times oH \times oH`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
          This has to be a number or tuple in the range (0, 1)
        return_indices: if ``True``, will return the indices along with the outputs.
          Useful to pass to :func:`~torch.nn.functional.max_unpool3d`.

    Shape:
        - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})`

    Examples::

        >>> input = torch.randn(20, 16, 50, 32, 16)
        >>> # pool of cubic window of size=3, and target output size 13x12x11
        >>> F.fractional_max_pool3d(input, 3, output_size=(13, 12, 11))
        >>> # pool of cubic window and target output size being half of input size
        >>> F.fractional_max_pool3d(input, 3, output_ratio=(0.5, 0.5, 0.5))

    .. _Fractional MaxPooling:
        http://arxiv.org/abs/1412.6071
    """
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool3d_with_indices,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    if output_size is None and output_ratio is None:
        raise ValueError(
            "fractional_max_pool3d requires specifying either an output_size or an output_ratio"
        )
    if output_size is None:
        assert output_ratio is not None
        _output_ratio = _triple(output_ratio)
        output_size = [
            int(input.size(-3) * _output_ratio[0]),
            int(input.size(-2) * _output_ratio[1]),
            int(input.size(-1) * _output_ratio[2]),
        ]

    if _random_samples is None:
        n_batch = 1 if input.dim() == 4 else input.size(0)
        _random_samples = torch.rand(
            n_batch, input.size(-4), 3, dtype=input.dtype, device=input.device
        )
    return torch._C._nn.fractional_max_pool3d(
        input, kernel_size, output_size, _random_samples
    )


def _fractional_max_pool3d(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    output_size: Optional[BroadcastingList3[int]] = None,
    output_ratio: Optional[BroadcastingList3[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> Tensor:
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool3d,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    return fractional_max_pool3d_with_indices(
        input, kernel_size, output_size, output_ratio, return_indices, _random_samples
    )[0]


fractional_max_pool3d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=4,
    default=False,
    if_true=fractional_max_pool3d_with_indices,
    if_false=_fractional_max_pool3d,
    module_name=__name__,
    func_name="fractional_max_pool3d",
)


def max_pool1d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    dilation: BroadcastingList1[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 1D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool1d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool1d` for details.

    Args:
        input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`, minibatch dim optional.
        kernel_size: the size of the window. Can be a single number or a
          tuple `(kW,)`
        stride: the stride of the window. Can be a single number or a tuple
          `(sW,)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
          ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
          Useful for :class:`torch.nn.functional.max_unpool1d` later
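
    Examples::

        >>> # a minimal sketch: window of size=3, stride=2 over a length-9 signal
        >>> input = torch.randn(1, 4, 9)
        >>> F.max_pool1d(input, kernel_size=3, stride=2).shape
        torch.Size([1, 4, 4])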
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool1d_with_indices,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.max_pool1d_with_indices(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )


def _max_pool1d(
    input: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    dilation: BroadcastingList1[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool1d,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


max_pool1d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=6,
    default=False,
    if_true=max_pool1d_with_indices,
    if_false=_max_pool1d,
    module_name=__name__,
    func_name="max_pool1d",
)


def max_pool2d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    padding: BroadcastingList2[int] = 0,
    dilation: BroadcastingList2[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 2D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool2d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool2d` for details.

    Args:
        input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`, minibatch dim optional.
        kernel_size: size of the pooling region. Can be a single number or a
          tuple `(kH, kW)`
        stride: stride of the pooling operation. Can be a single number or a
          tuple `(sH, sW)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
          ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
          Useful for :class:`torch.nn.functional.max_unpool2d` later
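
    Examples::

        >>> # a minimal sketch: 3x3 window, stride 2 over an 8x8 input -> 3x3 output
        >>> input = torch.randn(1, 3, 8, 8)
        >>> F.max_pool2d(input, kernel_size=3, stride=2).shape
        torch.Size([1, 3, 3, 3])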
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool2d_with_indices,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch._C._nn.max_pool2d_with_indices(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )


def _max_pool2d(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    padding: BroadcastingList2[int] = 0,
    dilation: BroadcastingList2[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool2d,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=6,
    default=False,
    if_true=max_pool2d_with_indices,
    if_false=_max_pool2d,
    module_name=__name__,
    func_name="max_pool2d",
)


def max_pool3d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    padding: BroadcastingList3[int] = 0,
    dilation: BroadcastingList3[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 3D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool3d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool3d` for details.

    Args:
        input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`, minibatch dim optional.
        kernel_size: size of the pooling region. Can be a single number or a
          tuple `(kT, kH, kW)`
        stride: stride of the pooling operation. Can be a single number or a
          tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
          ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
          Useful for :class:`torch.nn.functional.max_unpool3d` later
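
    Examples::

        >>> # a minimal sketch: cubic window of size=3, stride=2
        >>> input = torch.randn(1, 3, 8, 8, 8)
        >>> F.max_pool3d(input, kernel_size=3, stride=2).shape
        torch.Size([1, 3, 3, 3, 3])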
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool3d_with_indices,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch._C._nn.max_pool3d_with_indices(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )


def _max_pool3d(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    padding: BroadcastingList3[int] = 0,
    dilation: BroadcastingList3[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool3d,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)


max_pool3d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=6,
    default=False,
    if_true=max_pool3d_with_indices,
    if_false=_max_pool3d,
    module_name=__name__,
    func_name="max_pool3d",
)


def _unpool_output_size(
    input: Tensor,
    kernel_size: List[int],
    stride: List[int],
    padding: List[int],
    output_size: Optional[List[int]],
) -> List[int]:
    input_size = input.size()
    default_size = torch.jit.annotate(List[int], [])
    for d in range(len(kernel_size)):
        default_size.append(
            (input_size[-len(kernel_size) + d] - 1) * stride[d]
            + kernel_size[d]
            - 2 * padding[d]
        )
    if output_size is None:
        ret = default_size
    else:
        if len(output_size) == len(kernel_size) + 2:
            output_size = output_size[2:]
        if len(output_size) != len(kernel_size):
            raise ValueError(
                "output_size should be a sequence containing "
                f"{len(kernel_size)} or {len(kernel_size) + 2} elements, but it has a length of '{len(output_size)}'"
            )
        for d in range(len(kernel_size)):
            min_size = default_size[d] - stride[d]
            max_size = default_size[d] + stride[d]
            if not (min_size < output_size[d] < max_size):
                raise ValueError(
                    f'invalid output_size "{output_size}" (dim {d} must be between {min_size} and {max_size})'
                )

        ret = output_size
    return ret


def max_unpool1d(
    input: Tensor,
    indices: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    output_size: Optional[BroadcastingList1[int]] = None,
) -> Tensor:
    r"""Compute a partial inverse of :class:`MaxPool1d`.

    See :class:`~torch.nn.MaxUnpool1d` for details.
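
    Examples::

        >>> # an illustrative round trip: unpooling restores max values at their
        >>> # original positions and fills the remaining positions with zeros
        >>> input = torch.tensor([[[1., 2., 3., 4.]]])
        >>> output, indices = F.max_pool1d(input, kernel_size=2, return_indices=True)
        >>> F.max_unpool1d(output, indices, kernel_size=2)
        tensor([[[0., 2., 0., 4.]]])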
982 """ 983 if has_torch_function_unary(input): 984 return handle_torch_function( 985 max_unpool1d, 986 (input,), 987 input, 988 indices, 989 kernel_size, 990 stride=stride, 991 padding=padding, 992 output_size=output_size, 993 ) 994 kernel_size = _single(kernel_size) 995 if stride is not None: 996 _stride = _single(stride) 997 else: 998 _stride = kernel_size 999 padding = _single(padding) 1000 output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) 1001 if isinstance(output_size, list): 1002 output_size = output_size + [1] 1003 else: 1004 output_size = output_size + (1,) 1005 return torch._C._nn.max_unpool2d( 1006 input.unsqueeze(-1), indices.unsqueeze(-1), output_size 1007 ).squeeze(-1) 1008 1009 1010def max_unpool2d( 1011 input: Tensor, 1012 indices: Tensor, 1013 kernel_size: BroadcastingList2[int], 1014 stride: Optional[BroadcastingList2[int]] = None, 1015 padding: BroadcastingList2[int] = 0, 1016 output_size: Optional[BroadcastingList2[int]] = None, 1017) -> Tensor: 1018 r"""Compute a partial inverse of :class:`MaxPool2d`. 1019 1020 See :class:`~torch.nn.MaxUnpool2d` for details. 1021 """ 1022 if has_torch_function_unary(input): 1023 return handle_torch_function( 1024 max_unpool2d, 1025 (input,), 1026 input, 1027 indices, 1028 kernel_size, 1029 stride=stride, 1030 padding=padding, 1031 output_size=output_size, 1032 ) 1033 kernel_size = _pair(kernel_size) 1034 if stride is not None: 1035 _stride = _pair(stride) 1036 else: 1037 _stride = kernel_size 1038 padding = _pair(padding) 1039 output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) 1040 return torch._C._nn.max_unpool2d(input, indices, output_size) 1041 1042 1043def max_unpool3d( 1044 input: Tensor, 1045 indices: Tensor, 1046 kernel_size: BroadcastingList3[int], 1047 stride: Optional[BroadcastingList3[int]] = None, 1048 padding: BroadcastingList3[int] = 0, 1049 output_size: Optional[BroadcastingList3[int]] = None, 1050) -> Tensor: 1051 r"""Compute a partial inverse of :class:`MaxPool3d`. 1052 1053 See :class:`~torch.nn.MaxUnpool3d` for details. 1054 """ 1055 if has_torch_function_unary(input): 1056 return handle_torch_function( 1057 max_unpool3d, 1058 (input,), 1059 input, 1060 indices, 1061 kernel_size, 1062 stride=stride, 1063 padding=padding, 1064 output_size=output_size, 1065 ) 1066 kernel_size = _triple(kernel_size) 1067 if stride is not None: 1068 _stride = _triple(stride) 1069 else: 1070 _stride = kernel_size 1071 padding = _triple(padding) 1072 output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) 1073 return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding) 1074 1075 1076def lp_pool3d( 1077 input: Tensor, 1078 norm_type: Union[int, float], 1079 kernel_size: BroadcastingList3[int], 1080 stride: Optional[BroadcastingList3[int]] = None, 1081 ceil_mode: bool = False, 1082) -> Tensor: 1083 r""" 1084 Apply a 3D power-average pooling over an input signal composed of several input planes. 1085 1086 If the sum of all inputs to the power of `p` is 1087 zero, the gradient is set to zero as well. 1088 1089 See :class:`~torch.nn.LPPool3d` for details. 
1090 """ 1091 if has_torch_function_unary(input): 1092 return handle_torch_function( 1093 lp_pool3d, 1094 (input,), 1095 input, 1096 norm_type, 1097 kernel_size, 1098 stride=stride, 1099 ceil_mode=ceil_mode, 1100 ) 1101 kd, kw, kh = _triple(kernel_size) 1102 if stride is not None: 1103 out = avg_pool3d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) 1104 else: 1105 out = avg_pool3d( 1106 input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode 1107 ) 1108 1109 return ( 1110 (torch.sign(out) * relu(torch.abs(out))).mul(kd * kw * kh).pow(1.0 / norm_type) 1111 ) 1112 1113 1114def lp_pool2d( 1115 input: Tensor, 1116 norm_type: Union[int, float], 1117 kernel_size: BroadcastingList2[int], 1118 stride: Optional[BroadcastingList2[int]] = None, 1119 ceil_mode: bool = False, 1120) -> Tensor: 1121 r""" 1122 Apply a 2D power-average pooling over an input signal composed of several input planes. 1123 1124 If the sum of all inputs to the power of `p` is 1125 zero, the gradient is set to zero as well. 1126 1127 See :class:`~torch.nn.LPPool2d` for details. 1128 """ 1129 if has_torch_function_unary(input): 1130 return handle_torch_function( 1131 lp_pool2d, 1132 (input,), 1133 input, 1134 norm_type, 1135 kernel_size, 1136 stride=stride, 1137 ceil_mode=ceil_mode, 1138 ) 1139 kw, kh = _pair(kernel_size) 1140 if stride is not None: 1141 out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) 1142 else: 1143 out = avg_pool2d( 1144 input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode 1145 ) 1146 1147 return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1.0 / norm_type) 1148 1149 1150def lp_pool1d( 1151 input: Tensor, 1152 norm_type: Union[int, float], 1153 kernel_size: int, 1154 stride: Optional[BroadcastingList1[int]] = None, 1155 ceil_mode: bool = False, 1156) -> Tensor: 1157 r"""Apply a 1D power-average pooling over an input signal composed of several input planes. 1158 1159 If the sum of all inputs to the power of `p` is 1160 zero, the gradient is set to zero as well. 1161 1162 See :class:`~torch.nn.LPPool1d` for details. 1163 """ 1164 if has_torch_function_unary(input): 1165 return handle_torch_function( 1166 lp_pool1d, 1167 (input,), 1168 input, 1169 norm_type, 1170 kernel_size, 1171 stride=stride, 1172 ceil_mode=ceil_mode, 1173 ) 1174 if stride is not None: 1175 out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) 1176 else: 1177 out = avg_pool1d( 1178 input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode 1179 ) 1180 1181 return ( 1182 (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type) 1183 ) 1184 1185 1186def adaptive_max_pool1d_with_indices( 1187 input: Tensor, 1188 output_size: BroadcastingList1[int], 1189 return_indices: bool = False, 1190) -> Tuple[Tensor, Tensor]: # noqa: D400 1191 r""" 1192 adaptive_max_pool1d(input, output_size, return_indices=False) 1193 1194 Applies a 1D adaptive max pooling over an input signal composed of 1195 several input planes. 1196 1197 See :class:`~torch.nn.AdaptiveMaxPool1d` for details and output shape. 1198 1199 Args: 1200 output_size: the target output size (single integer) 1201 return_indices: whether to return pooling indices. 
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool1d_with_indices,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    return torch.adaptive_max_pool1d(input, output_size)


def _adaptive_max_pool1d(
    input: Tensor,
    output_size: BroadcastingList1[int],
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool1d,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    return adaptive_max_pool1d_with_indices(input, output_size)[0]


adaptive_max_pool1d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=2,
    default=False,
    if_true=adaptive_max_pool1d_with_indices,
    if_false=_adaptive_max_pool1d,
    module_name=__name__,
    func_name="adaptive_max_pool1d",
)


def adaptive_max_pool2d_with_indices(
    input: Tensor,
    output_size: BroadcastingList2[int],
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""adaptive_max_pool2d(input, output_size, return_indices=False)

    Applies a 2D adaptive max pooling over an input signal composed of
    several input planes.

    See :class:`~torch.nn.AdaptiveMaxPool2d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
          double-integer tuple)
        return_indices: whether to return pooling indices. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool2d_with_indices,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    output_size = _list_with_default(output_size, input.size())
    return torch._C._nn.adaptive_max_pool2d(input, output_size)


def _adaptive_max_pool2d(
    input: Tensor,
    output_size: BroadcastingList2[int],
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool2d,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    return adaptive_max_pool2d_with_indices(input, output_size)[0]


adaptive_max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=2,
    default=False,
    if_true=adaptive_max_pool2d_with_indices,
    if_false=_adaptive_max_pool2d,
    module_name=__name__,
    func_name="adaptive_max_pool2d",
)


def adaptive_max_pool3d_with_indices(
    input: Tensor,
    output_size: BroadcastingList3[int],
    return_indices: bool = False,
) -> Tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    adaptive_max_pool3d(input, output_size, return_indices=False)

    Applies a 3D adaptive max pooling over an input signal composed of
    several input planes.

    See :class:`~torch.nn.AdaptiveMaxPool3d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
          triple-integer tuple)
        return_indices: whether to return pooling indices. Default: ``False``
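
    Examples::

        >>> # a minimal sketch: adapt an 8x8x8 volume to 4x4x4
        >>> input = torch.randn(1, 3, 8, 8, 8)
        >>> F.adaptive_max_pool3d(input, output_size=4).shape
        torch.Size([1, 3, 4, 4, 4])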
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool3d_with_indices,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    output_size = _list_with_default(output_size, input.size())
    return torch._C._nn.adaptive_max_pool3d(input, output_size)


def _adaptive_max_pool3d(
    input: Tensor,
    output_size: BroadcastingList3[int],
    return_indices: bool = False,
) -> Tensor:
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool3d,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    return adaptive_max_pool3d_with_indices(input, output_size)[0]


adaptive_max_pool3d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=2,
    default=False,
    if_true=adaptive_max_pool3d_with_indices,
    if_false=_adaptive_max_pool3d,
    module_name=__name__,
    func_name="adaptive_max_pool3d",
)


adaptive_avg_pool1d = _add_docstr(
    torch.adaptive_avg_pool1d,
    r"""
adaptive_avg_pool1d(input, output_size) -> Tensor

Applies a 1D adaptive average pooling over an input signal composed of
several input planes.

See :class:`~torch.nn.AdaptiveAvgPool1d` for details and output shape.

Args:
    output_size: the target output size (single integer)
""",
)


def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
    r"""Apply a 2D adaptive average pooling over an input signal composed of several input planes.

    See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
          double-integer tuple)
    """
    if has_torch_function_unary(input):
        return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
    _output_size = _list_with_default(output_size, input.size())
    return torch._C._nn.adaptive_avg_pool2d(input, _output_size)


def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> Tensor:
    r"""Apply a 3D adaptive average pooling over an input signal composed of several input planes.

    See :class:`~torch.nn.AdaptiveAvgPool3d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
          triple-integer tuple)
    """
    if has_torch_function_unary(input):
        return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
    _output_size = _list_with_default(output_size, input.size())
    return torch._C._nn.adaptive_avg_pool3d(input, _output_size)


# Activation functions
def dropout(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""During training, randomly zeroes some elements of the input tensor with probability :attr:`p`.

    Uses samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout` for details.

    Args:
        p: probability of an element to be zeroed. Default: 0.5
        training: apply dropout if ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
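
    Examples::

        >>> # a minimal sketch: each element is zeroed with probability p and the
        >>> # survivors are rescaled by 1 / (1 - p) = 2 here; the mask is random
        >>> output = F.dropout(torch.ones(2, 4), p=0.5)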
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    return (
        _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
    )


def alpha_dropout(
    input: Tensor,
    p: float = 0.5,
    training: bool = False,
    inplace: bool = False,
) -> Tensor:
    r"""Apply alpha dropout to the input.

    See :class:`~torch.nn.AlphaDropout` for details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            alpha_dropout, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    return (
        _VF.alpha_dropout_(input, p, training)
        if inplace
        else _VF.alpha_dropout(input, p, training)
    )


def dropout1d(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly zero out entire channels (a channel is a 1D feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 1D tensor :math:`\text{input}[i, j]` of the input tensor.
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout1d` for details.

    Args:
        p: probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout1d, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    inp_dim = input.dim()
    if inp_dim not in (2, 3):
        raise RuntimeError(
            f"dropout1d: Expected 2D or 3D input, but received a {inp_dim}D input. "
            "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
            "spatial dimension, a channel dimension, and an optional batch dimension "
            "(i.e. 2D or 3D inputs)."
        )

    is_batched = inp_dim == 3
    if not is_batched:
        input = input.unsqueeze_(0) if inplace else input.unsqueeze(0)

    result = (
        _VF.feature_dropout_(input, p, training)
        if inplace
        else _VF.feature_dropout(input, p, training)
    )

    if not is_batched:
        result = result.squeeze_(0) if inplace else result.squeeze(0)

    return result


def dropout2d(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly zero out entire channels (a channel is a 2D feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 2D tensor :math:`\text{input}[i, j]` of the input tensor.
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout2d` for details.

    Args:
        p: probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout2d, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    inp_dim = input.dim()
    if inp_dim not in (3, 4):
        warn_msg = (
            f"dropout2d: Received a {inp_dim}-D input to dropout2d, which is deprecated "
            "and will result in an error in a future release. To retain the behavior "
            "and silence this warning, please use dropout instead. Note that dropout2d "
            "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, "
            "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs)."
        )
        warnings.warn(warn_msg)

    # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing
    # a 3D input will perform dropout1d behavior instead. This was done historically and the
    # behavior is maintained here for now.
    # See https://github.com/pytorch/pytorch/issues/77081
    if inp_dim == 3:
        warnings.warn(
            "dropout2d: Received a 3D input to dropout2d and assuming that channel-wise "
            "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C "
            "is the channel dim. This behavior will change in a future release to interpret the "
            "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D "
            "channel-wise dropout behavior, please switch to using dropout1d instead."
        )

    result = (
        _VF.feature_dropout_(input, p, training)
        if inplace
        else _VF.feature_dropout(input, p, training)
    )

    return result


def dropout3d(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly zero out entire channels (a channel is a 3D feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 3D tensor :math:`\text{input}[i, j]` of the input tensor.
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout3d` for details.

    Args:
        p: probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout3d, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    inp_dim = input.dim()
    if inp_dim not in (4, 5):
        warn_msg = (
            f"dropout3d: Received a {inp_dim}-D input to dropout3d, which is deprecated "
            "and will result in an error in a future release. To retain the behavior "
            "and silence this warning, please use dropout instead. Note that dropout3d "
            "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, "
            "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs)."
        )
        warnings.warn(warn_msg)

    is_batched = inp_dim == 5
    if not is_batched:
        input = input.unsqueeze_(0) if inplace else input.unsqueeze(0)

    result = (
        _VF.feature_dropout_(input, p, training)
        if inplace
        else _VF.feature_dropout(input, p, training)
    )

    if not is_batched:
        result = result.squeeze_(0) if inplace else result.squeeze(0)
    return result


def feature_alpha_dropout(
    input: Tensor,
    p: float = 0.5,
    training: bool = False,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly masks out entire channels (a channel is a feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the batched input
    is a tensor :math:`\text{input}[i, j]` of the input tensor. Instead of
    setting activations to zero, as in regular Dropout, the activations are set
    to the negative saturation value of the SELU activation function.

    Each element will be masked independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.
    The elements to be masked are randomized on every forward call, and scaled
    and shifted to maintain zero mean and unit variance.

    See :class:`~torch.nn.FeatureAlphaDropout` for details.

    Args:
        p: dropout probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if ``True``. Default: ``False``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            feature_alpha_dropout,
            (input,),
            input,
            p=p,
            training=training,
            inplace=inplace,
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    return (
        _VF.feature_alpha_dropout_(input, p, training)
        if inplace
        else _VF.feature_alpha_dropout(input, p, training)
    )


def _threshold(
    input: Tensor,
    threshold: float,
    value: float,
    inplace: bool = False,
) -> Tensor:
    r"""Apply a threshold to each element of the input Tensor.

    See :class:`~torch.nn.Threshold` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            _threshold, (input,), input, threshold, value, inplace=inplace
        )
    if inplace:
        result = _VF.threshold_(input, threshold, value)
    else:
        result = _VF.threshold(input, threshold, value)
    return result


# We define this function as _threshold because it takes an argument
# named threshold, which clobbers the recursive reference to the
# function needed for __torch_function__ support
threshold = _threshold

threshold_ = _add_docstr(
    _VF.threshold_,
    r"""
threshold_(input, threshold, value) -> Tensor

In-place version of :func:`~threshold`.
""",
)


def relu(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
    r"""relu(input, inplace=False) -> Tensor

    Applies the rectified linear unit function element-wise. See
    :class:`~torch.nn.ReLU` for more details.
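
    Examples::

        >>> # a minimal sketch: negative entries are clamped to zero
        >>> F.relu(torch.tensor([-1.0, 0.0, 2.0]))
        tensor([0., 0., 2.])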
1698 """ 1699 if has_torch_function_unary(input): 1700 return handle_torch_function(relu, (input,), input, inplace=inplace) 1701 if inplace: 1702 result = torch.relu_(input) 1703 else: 1704 result = torch.relu(input) 1705 return result 1706 1707 1708relu_ = _add_docstr( 1709 torch.relu_, 1710 r""" 1711relu_(input) -> Tensor 1712 1713In-place version of :func:`~relu`. 1714""", 1715) 1716 1717 1718def glu(input: Tensor, dim: int = -1) -> Tensor: # noqa: D400,D402 1719 r""" 1720 glu(input, dim=-1) -> Tensor 1721 1722 The gated linear unit. Computes: 1723 1724 .. math :: 1725 \text{GLU}(a, b) = a \otimes \sigma(b) 1726 1727 where `input` is split in half along `dim` to form `a` and `b`, :math:`\sigma` 1728 is the sigmoid function and :math:`\otimes` is the element-wise product between matrices. 1729 1730 See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`_. 1731 1732 Args: 1733 input (Tensor): input tensor 1734 dim (int): dimension on which to split the input. Default: -1 1735 """ 1736 if has_torch_function_unary(input): 1737 return handle_torch_function(glu, (input,), input, dim=dim) 1738 if input.dim() == 0: 1739 raise RuntimeError( 1740 "glu does not support scalars because halving size must be even" 1741 ) 1742 return torch._C._nn.glu(input, dim) 1743 1744 1745def hardtanh( 1746 input: Tensor, 1747 min_val: float = -1.0, 1748 max_val: float = 1.0, 1749 inplace: bool = False, 1750) -> Tensor: # noqa: D400,D402 1751 r""" 1752 hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor 1753 1754 Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more 1755 details. 1756 """ 1757 if has_torch_function_unary(input): 1758 return handle_torch_function( 1759 hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace 1760 ) 1761 if min_val > max_val: 1762 raise ValueError("min_val cannot be greater than max_val") 1763 if inplace: 1764 result = torch._C._nn.hardtanh_(input, min_val, max_val) 1765 else: 1766 result = torch._C._nn.hardtanh(input, min_val, max_val) 1767 return result 1768 1769 1770hardtanh_ = _add_docstr( 1771 torch._C._nn.hardtanh_, 1772 r""" 1773hardtanh_(input, min_val=-1., max_val=1.) -> Tensor 1774 1775In-place version of :func:`~hardtanh`. 1776""", 1777) 1778 1779 1780def relu6(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 1781 r"""relu6(input, inplace=False) -> Tensor 1782 1783 Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`. 1784 1785 See :class:`~torch.nn.ReLU6` for more details. 1786 """ 1787 if has_torch_function_unary(input): 1788 return handle_torch_function(relu6, (input,), input, inplace=inplace) 1789 if inplace: 1790 result = torch._C._nn.relu6_(input) 1791 else: 1792 result = torch._C._nn.relu6(input) 1793 return result 1794 1795 1796def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: 1797 r"""Apply the Exponential Linear Unit (ELU) function element-wise. 1798 1799 See :class:`~torch.nn.ELU` for more details. 1800 """ 1801 if has_torch_function_unary(input): 1802 return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace) 1803 if inplace: 1804 result = torch._C._nn.elu_(input, alpha) 1805 else: 1806 result = torch._C._nn.elu(input, alpha) 1807 return result 1808 1809 1810elu_ = _add_docstr( 1811 torch._C._nn.elu_, 1812 r""" 1813elu_(input, alpha=1.) -> Tensor 1814 1815In-place version of :func:`~elu`. 
1816""", 1817) 1818 1819 1820def selu(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 1821 r"""selu(input, inplace=False) -> Tensor 1822 1823 Applies element-wise, 1824 :math:`\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))`, 1825 with :math:`\alpha=1.6732632423543772848170429916717` and 1826 :math:`scale=1.0507009873554804934193349852946`. 1827 1828 See :class:`~torch.nn.SELU` for more details. 1829 """ 1830 if has_torch_function_unary(input): 1831 return handle_torch_function(selu, (input,), input, inplace=inplace) 1832 if inplace: 1833 result = torch.selu_(input) 1834 else: 1835 result = torch.selu(input) 1836 return result 1837 1838 1839selu_ = _add_docstr( 1840 torch.selu_, 1841 r""" 1842selu_(input) -> Tensor 1843 1844In-place version of :func:`~selu`. 1845""", 1846) 1847 1848 1849def celu( 1850 input: Tensor, 1851 alpha: float = 1.0, 1852 inplace: bool = False, 1853) -> Tensor: # noqa: D400,D402 1854 r"""celu(input, alpha=1., inplace=False) -> Tensor 1855 1856 Applies element-wise, 1857 :math:`\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))`. 1858 1859 See :class:`~torch.nn.CELU` for more details. 1860 """ 1861 if has_torch_function_unary(input): 1862 return handle_torch_function( 1863 celu, (input,), input, alpha=alpha, inplace=inplace 1864 ) 1865 if inplace: 1866 result = torch.celu_(input, alpha) 1867 else: 1868 result = torch.celu(input, alpha) 1869 return result 1870 1871 1872celu_ = _add_docstr( 1873 torch.celu_, 1874 r""" 1875celu_(input, alpha=1.) -> Tensor 1876 1877In-place version of :func:`~celu`. 1878""", 1879) 1880 1881 1882def leaky_relu( 1883 input: Tensor, 1884 negative_slope: float = 0.01, 1885 inplace: bool = False, 1886) -> Tensor: # noqa: D400,D402 1887 r""" 1888 leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor 1889 1890 Applies element-wise, 1891 :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)` 1892 1893 See :class:`~torch.nn.LeakyReLU` for more details. 1894 """ 1895 if has_torch_function_unary(input): 1896 return handle_torch_function( 1897 leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace 1898 ) 1899 if inplace: 1900 result = torch._C._nn.leaky_relu_(input, negative_slope) 1901 else: 1902 result = torch._C._nn.leaky_relu(input, negative_slope) 1903 return result 1904 1905 1906leaky_relu_ = _add_docstr( 1907 torch._C._nn.leaky_relu_, 1908 r""" 1909leaky_relu_(input, negative_slope=0.01) -> Tensor 1910 1911In-place version of :func:`~leaky_relu`. 1912""", 1913) 1914 1915 1916prelu = _add_docstr( 1917 torch.prelu, 1918 r"""prelu(input, weight) -> Tensor 1919 1920Applies element-wise the function 1921:math:`\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x)` where weight is a 1922learnable parameter. 1923 1924.. note:: 1925 `weight` is expected to be a scalar or 1-D tensor. If `weight` is 1-D, 1926 its size must match the number of input channels, determined by 1927 `input.size(1)` when `input.dim() >= 2`, otherwise 1. 1928 In the 1-D case, note that when `input` has dim > 2, `weight` can be expanded 1929 to the shape of `input` in a way that is not possible using normal 1930 :ref:`broadcasting semantics<broadcasting-semantics>`. 1931 1932See :class:`~torch.nn.PReLU` for more details. 
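
Example (an illustrative sketch; the slope values are arbitrary)::

    >>> input = torch.tensor([[-2.0, 0.0, 2.0]])
    >>> weight = torch.full((3,), 0.25)
    >>> F.prelu(input, weight)
    tensor([[-0.5000,  0.0000,  2.0000]])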
1933""", 1934) 1935 1936 1937def rrelu( 1938 input: Tensor, 1939 lower: float = 1.0 / 8, 1940 upper: float = 1.0 / 3, 1941 training: bool = False, 1942 inplace: bool = False, 1943) -> Tensor: # noqa: D400,D402 1944 r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor 1945 1946 Randomized leaky ReLU. 1947 1948 See :class:`~torch.nn.RReLU` for more details. 1949 """ 1950 if has_torch_function_unary(input): 1951 return handle_torch_function( 1952 rrelu, 1953 (input,), 1954 input, 1955 lower=lower, 1956 upper=upper, 1957 training=training, 1958 inplace=inplace, 1959 ) 1960 if inplace: 1961 result = torch.rrelu_(input, lower, upper, training) 1962 else: 1963 result = torch.rrelu(input, lower, upper, training) 1964 return result 1965 1966 1967rrelu_ = _add_docstr( 1968 torch.rrelu_, 1969 r""" 1970rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor 1971 1972In-place version of :func:`~rrelu`. 1973""", 1974) 1975 1976logsigmoid = _add_docstr( 1977 torch._C._nn.log_sigmoid, 1978 r""" 1979logsigmoid(input) -> Tensor 1980 1981Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)` 1982 1983See :class:`~torch.nn.LogSigmoid` for more details. 1984""", 1985) 1986 1987gelu = _add_docstr( 1988 torch._C._nn.gelu, 1989 r""" 1990gelu(input, approximate = 'none') -> Tensor 1991 1992When the approximate argument is 'none', it applies element-wise the function 1993:math:`\text{GELU}(x) = x * \Phi(x)` 1994 1995where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. 1996 1997When the approximate argument is 'tanh', Gelu is estimated with 1998 1999.. math:: 2000 \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3))) 2001 2002See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_. 2003""", 2004) 2005 2006hardshrink = _add_docstr( 2007 torch.hardshrink, 2008 r""" 2009hardshrink(input, lambd=0.5) -> Tensor 2010 2011Applies the hard shrinkage function element-wise 2012 2013See :class:`~torch.nn.Hardshrink` for more details. 2014""", 2015) 2016 2017 2018def tanhshrink(input): # noqa: D400,D402 2019 r"""tanhshrink(input) -> Tensor 2020 2021 Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)` 2022 2023 See :class:`~torch.nn.Tanhshrink` for more details. 2024 """ 2025 if has_torch_function_unary(input): 2026 return handle_torch_function(tanhshrink, (input,), input) 2027 return input - input.tanh() 2028 2029 2030def softsign(input): # noqa: D400,D402 2031 r"""softsign(input) -> Tensor 2032 2033 Applies element-wise, the function :math:`\text{SoftSign}(x) = \frac{x}{1 + |x|}` 2034 2035 See :class:`~torch.nn.Softsign` for more details. 2036 """ 2037 if has_torch_function_unary(input): 2038 return handle_torch_function(softsign, (input,), input) 2039 return input / (input.abs() + 1) 2040 2041 2042softplus = _add_docstr( 2043 torch._C._nn.softplus, 2044 r""" 2045softplus(input, beta=1, threshold=20) -> Tensor 2046 2047Applies element-wise, the function :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`. 2048 2049For numerical stability the implementation reverts to the linear function 2050when :math:`input \times \beta > threshold`. 2051 2052See :class:`~torch.nn.Softplus` for more details. 2053""", 2054) 2055 2056 2057def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int: 2058 warnings.warn( 2059 f"Implicit dimension choice for {name} has been deprecated. 
" 2060 "Change the call to include dim=X as an argument.", 2061 stacklevel=stacklevel, 2062 ) 2063 if ndim == 0 or ndim == 1 or ndim == 3: 2064 ret = 0 2065 else: 2066 ret = 1 2067 return ret 2068 2069 2070def softmin( 2071 input: Tensor, 2072 dim: Optional[int] = None, 2073 _stacklevel: int = 3, 2074 dtype: Optional[DType] = None, 2075) -> Tensor: 2076 r"""Apply a softmin function. 2077 2078 Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula. 2079 2080 See :class:`~torch.nn.Softmin` for more details. 2081 2082 Args: 2083 input (Tensor): input 2084 dim (int): A dimension along which softmin will be computed (so every slice 2085 along dim will sum to 1). 2086 dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 2087 If specified, the input tensor is casted to :attr:`dtype` before the operation 2088 is performed. This is useful for preventing data type overflows. Default: None. 2089 """ 2090 if has_torch_function_unary(input): 2091 return handle_torch_function( 2092 softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype 2093 ) 2094 if dim is None: 2095 dim = _get_softmax_dim("softmin", input.dim(), _stacklevel) 2096 if dtype is None: 2097 ret = (-input).softmax(dim) 2098 else: 2099 ret = (-input).softmax(dim, dtype=dtype) 2100 return ret 2101 2102 2103def softmax( 2104 input: Tensor, 2105 dim: Optional[int] = None, 2106 _stacklevel: int = 3, 2107 dtype: Optional[DType] = None, 2108) -> Tensor: 2109 r"""Apply a softmax function. 2110 2111 Softmax is defined as: 2112 2113 :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}` 2114 2115 It is applied to all slices along dim, and will re-scale them so that the elements 2116 lie in the range `[0, 1]` and sum to 1. 2117 2118 See :class:`~torch.nn.Softmax` for more details. 2119 2120 Args: 2121 input (Tensor): input 2122 dim (int): A dimension along which softmax will be computed. 2123 dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 2124 If specified, the input tensor is casted to :attr:`dtype` before the operation 2125 is performed. This is useful for preventing data type overflows. Default: None. 2126 2127 .. note:: 2128 This function doesn't work directly with NLLLoss, 2129 which expects the Log to be computed between the Softmax and itself. 2130 Use log_softmax instead (it's faster and has better numerical properties). 2131 2132 """ 2133 if has_torch_function_unary(input): 2134 return handle_torch_function( 2135 softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype 2136 ) 2137 if dim is None: 2138 dim = _get_softmax_dim("softmax", input.dim(), _stacklevel) 2139 if dtype is None: 2140 ret = input.softmax(dim) 2141 else: 2142 ret = input.softmax(dim, dtype=dtype) 2143 return ret 2144 2145 2146def gumbel_softmax( 2147 logits: Tensor, 2148 tau: float = 1, 2149 hard: bool = False, 2150 eps: float = 1e-10, 2151 dim: int = -1, 2152) -> Tensor: 2153 r""" 2154 Sample from the Gumbel-Softmax distribution (`Link 1`_ `Link 2`_) and optionally discretize. 2155 2156 Args: 2157 logits: `[..., num_features]` unnormalized log probabilities 2158 tau: non-negative scalar temperature 2159 hard: if ``True``, the returned samples will be discretized as one-hot vectors, 2160 but will be differentiated as if it is the soft sample in autograd 2161 dim (int): A dimension along which softmax will be computed. Default: -1. 
2162 2163 Returns: 2164 Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution. 2165 If ``hard=True``, the returned samples will be one-hot, otherwise they will 2166 be probability distributions that sum to 1 across `dim`. 2167 2168 .. note:: 2169 This function is here for legacy reasons, may be removed from nn.Functional in the future. 2170 2171 .. note:: 2172 The main trick for `hard` is to do `y_hard - y_soft.detach() + y_soft` 2173 2174 It achieves two things: 2175 - makes the output value exactly one-hot 2176 (since we add then subtract y_soft value) 2177 - makes the gradient equal to y_soft gradient 2178 (since we strip all other gradients) 2179 2180 Examples:: 2181 >>> logits = torch.randn(20, 32) 2182 >>> # Sample soft categorical using reparametrization trick: 2183 >>> F.gumbel_softmax(logits, tau=1, hard=False) 2184 >>> # Sample hard categorical using "Straight-through" trick: 2185 >>> F.gumbel_softmax(logits, tau=1, hard=True) 2186 2187 .. _Link 1: 2188 https://arxiv.org/abs/1611.00712 2189 .. _Link 2: 2190 https://arxiv.org/abs/1611.01144 2191 """ 2192 if has_torch_function_unary(logits): 2193 return handle_torch_function( 2194 gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim 2195 ) 2196 if eps != 1e-10: 2197 warnings.warn("`eps` parameter is deprecated and has no effect.") 2198 2199 gumbels = ( 2200 -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format) 2201 .exponential_() 2202 .log() 2203 ) # ~Gumbel(0,1) 2204 gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) 2205 y_soft = gumbels.softmax(dim) 2206 2207 if hard: 2208 # Straight through. 2209 index = y_soft.max(dim, keepdim=True)[1] 2210 y_hard = torch.zeros_like( 2211 logits, memory_format=torch.legacy_contiguous_format 2212 ).scatter_(dim, index, 1.0) 2213 ret = y_hard - y_soft.detach() + y_soft 2214 else: 2215 # Reparametrization trick. 2216 ret = y_soft 2217 return ret 2218 2219 2220def log_softmax( 2221 input: Tensor, 2222 dim: Optional[int] = None, 2223 _stacklevel: int = 3, 2224 dtype: Optional[DType] = None, 2225) -> Tensor: 2226 r"""Apply a softmax followed by a logarithm. 2227 2228 While mathematically equivalent to log(softmax(x)), doing these two 2229 operations separately is slower and numerically unstable. This function 2230 uses an alternative formulation to compute the output and gradient correctly. 2231 2232 See :class:`~torch.nn.LogSoftmax` for more details. 2233 2234 Args: 2235 input (Tensor): input 2236 dim (int): A dimension along which log_softmax will be computed. 2237 dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 2238 If specified, the input tensor is cast to :attr:`dtype` before the operation 2239 is performed. This is useful for preventing data type overflows. Default: None. 2240 """ 2241 if has_torch_function_unary(input): 2242 return handle_torch_function( 2243 log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype 2244 ) 2245 if dim is None: 2246 dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel) 2247 if dtype is None: 2248 ret = input.log_softmax(dim) 2249 else: 2250 ret = input.log_softmax(dim, dtype=dtype) 2251 return ret 2252 2253 2254softshrink = _add_docstr( 2255 torch._C._nn.softshrink, 2256 r""" 2257softshrink(input, lambd=0.5) -> Tensor 2258 2259Applies the soft shrinkage function elementwise 2260 2261See :class:`~torch.nn.Softshrink` for more details. 
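
Concretely (restating the :class:`~torch.nn.Softshrink` definition, with
:math:`\lambda` given by ``lambd``):

.. math::
    \text{SoftShrinkage}(x) =
    \begin{cases}
    x - \lambda, & \text{ if } x > \lambda \\
    x + \lambda, & \text{ if } x < -\lambda \\
    0, & \text{ otherwise }
    \end{cases}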
2262""", 2263) 2264 2265 2266def tanh(input): # noqa: D400,D402 2267 r"""tanh(input) -> Tensor 2268 2269 Applies element-wise, 2270 :math:`\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}` 2271 2272 See :class:`~torch.nn.Tanh` for more details. 2273 """ 2274 return input.tanh() 2275 2276 2277def sigmoid(input): # noqa: D400,D402 2278 r"""sigmoid(input) -> Tensor 2279 2280 Applies the element-wise function :math:`\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}` 2281 2282 See :class:`~torch.nn.Sigmoid` for more details. 2283 """ 2284 return input.sigmoid() 2285 2286 2287def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: 2288 r"""Apply the Hardsigmoid function element-wise. 2289 2290 .. math:: 2291 \text{Hardsigmoid}(x) = \begin{cases} 2292 0 & \text{if~} x \le -3, \\ 2293 1 & \text{if~} x \ge +3, \\ 2294 x / 6 + 1 / 2 & \text{otherwise} 2295 \end{cases} 2296 2297 Args: 2298 inplace: If set to ``True``, will do this operation in-place. Default: ``False`` 2299 2300 See :class:`~torch.nn.Hardsigmoid` for more details. 2301 """ 2302 if has_torch_function_unary(input): 2303 return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace) 2304 if inplace: 2305 return torch._C._nn.hardsigmoid_(input) 2306 return torch._C._nn.hardsigmoid(input) 2307 2308 2309linear = _add_docstr( 2310 torch._C._nn.linear, 2311 r""" 2312linear(input, weight, bias=None) -> Tensor 2313 2314Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. 2315 2316This operation supports 2-D :attr:`weight` with :ref:`sparse layout<sparse-docs>` 2317 2318{sparse_beta_warning} 2319 2320This operator supports :ref:`TensorFloat32<tf32_on_ampere>`. 2321 2322Shape: 2323 2324 - Input: :math:`(*, in\_features)` where `*` means any number of 2325 additional dimensions, including none 2326 - Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)` 2327 - Bias: :math:`(out\_features)` or :math:`()` 2328 - Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight 2329""".format( 2330 **sparse_support_notes 2331 ), 2332) 2333 2334 2335bilinear = _add_docstr( 2336 torch.bilinear, 2337 r""" 2338bilinear(input1, input2, weight, bias=None) -> Tensor 2339 2340Applies a bilinear transformation to the incoming data: 2341:math:`y = x_1^T A x_2 + b` 2342 2343Shape: 2344 2345 - input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` 2346 and :math:`*` means any number of additional dimensions. 2347 All but the last dimension of the inputs should be the same. 2348 - input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}` 2349 - weight: :math:`(\text{out\_features}, \text{in1\_features}, 2350 \text{in2\_features})` 2351 - bias: :math:`(\text{out\_features})` 2352 - output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}` 2353 and all but the last dimension are the same shape as the input. 2354""", 2355) 2356 2357 2358def silu(input: Tensor, inplace: bool = False) -> Tensor: 2359 r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise. 2360 2361 The SiLU function is also known as the swish function. 2362 2363 .. math:: 2364 \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} 2365 2366 .. 
note:: 2367 See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_ 2368 where the SiLU (Sigmoid Linear Unit) was originally coined, and see 2369 `Sigmoid-Weighted Linear Units for Neural Network Function Approximation 2370 in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish: 2371 a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_ 2372 where the SiLU was experimented with later. 2373 2374 See :class:`~torch.nn.SiLU` for more details. 2375 """ 2376 if has_torch_function_unary(input): 2377 return handle_torch_function(silu, (input,), input, inplace=inplace) 2378 if inplace: 2379 return torch._C._nn.silu_(input) 2380 return torch._C._nn.silu(input) 2381 2382 2383def mish(input: Tensor, inplace: bool = False) -> Tensor: 2384 r"""Apply the Mish function, element-wise. 2385 2386 Mish: A Self Regularized Non-Monotonic Neural Activation Function. 2387 2388 .. math:: 2389 \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) 2390 2391 .. note:: 2392 See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_ 2393 2394 See :class:`~torch.nn.Mish` for more details. 2395 """ 2396 if has_torch_function_unary(input): 2397 return handle_torch_function(mish, (input,), input, inplace=inplace) 2398 if inplace: 2399 return torch._C._nn.mish_(input) 2400 return torch._C._nn.mish(input) 2401 2402 2403def hardswish(input: Tensor, inplace: bool = False) -> Tensor: 2404 r"""Apply hardswish function, element-wise. 2405 2406 Follows implementation as described in the paper: 2407 `Searching for MobileNetV3`_. 2408 2409 .. math:: 2410 \text{Hardswish}(x) = \begin{cases} 2411 0 & \text{if~} x \le -3, \\ 2412 x & \text{if~} x \ge +3, \\ 2413 x \cdot (x + 3) /6 & \text{otherwise} 2414 \end{cases} 2415 2416 See :class:`~torch.nn.Hardswish` for more details. 2417 2418 .. _`Searching for MobileNetV3`: 2419 https://arxiv.org/abs/1905.02244 2420 """ 2421 if has_torch_function_unary(input): 2422 return handle_torch_function(hardswish, (input,), input, inplace=inplace) 2423 if inplace: 2424 return torch._C._nn.hardswish_(input) 2425 return torch._C._nn.hardswish(input) 2426 2427 2428def _no_grad_embedding_renorm_( 2429 weight: Tensor, 2430 input: Tensor, 2431 max_norm: float, 2432 norm_type: float, 2433) -> Tuple[Tensor, Tensor]: 2434 torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type) 2435 2436 2437def embedding( 2438 input: Tensor, 2439 weight: Tensor, 2440 padding_idx: Optional[int] = None, 2441 max_norm: Optional[float] = None, 2442 norm_type: float = 2.0, 2443 scale_grad_by_freq: bool = False, 2444 sparse: bool = False, 2445) -> Tensor: 2446 r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary and size. 2447 2448 This module is often used to retrieve word embeddings using indices. 2449 The input to the module is a list of indices, and the embedding matrix, 2450 and the output is the corresponding word embeddings. 2451 2452 See :class:`torch.nn.Embedding` for more details. 2453 2454 .. note:: 2455 Note that the analytical gradients of this function with respect to 2456 entries in :attr:`weight` at the row specified by :attr:`padding_idx` 2457 are expected to differ from the numerical ones. 2458 2459 .. note:: 2460 Note that `:class:`torch.nn.Embedding` differs from this function in 2461 that it initializes the row of :attr:`weight` specified by 2462 :attr:`padding_idx` to all zeros on construction. 
2463 2464 Args: 2465 input (LongTensor): Tensor containing indices into the embedding matrix 2466 weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1, 2467 and number of columns equal to the embedding size 2468 padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; 2469 therefore, the embedding vector at :attr:`padding_idx` is not updated during training, 2470 i.e. it remains as a fixed "pad". 2471 max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` 2472 is renormalized to have norm :attr:`max_norm`. 2473 Note: this will modify :attr:`weight` in-place. 2474 norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. 2475 scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of 2476 the words in the mini-batch. Default ``False``. 2477 sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under 2478 :class:`torch.nn.Embedding` for more details regarding sparse gradients. 2479 2480 Shape: 2481 - Input: LongTensor of arbitrary shape containing the indices to extract 2482 - Weight: Embedding matrix of floating point type with shape `(V, embedding_dim)`, 2483 where V = maximum index + 1 and embedding_dim = the embedding size 2484 - Output: `(*, embedding_dim)`, where `*` is the input shape 2485 2486 Examples:: 2487 2488 >>> # a batch of 2 samples of 4 indices each 2489 >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]) 2490 >>> # an embedding matrix containing 10 tensors of size 3 2491 >>> embedding_matrix = torch.rand(10, 3) 2492 >>> # xdoctest: +IGNORE_WANT("non-deterministic") 2493 >>> F.embedding(input, embedding_matrix) 2494 tensor([[[ 0.8490, 0.9625, 0.6753], 2495 [ 0.9666, 0.7761, 0.6108], 2496 [ 0.6246, 0.9751, 0.3618], 2497 [ 0.4161, 0.2419, 0.7383]], 2498 2499 [[ 0.6246, 0.9751, 0.3618], 2500 [ 0.0237, 0.7794, 0.0528], 2501 [ 0.9666, 0.7761, 0.6108], 2502 [ 0.3385, 0.8612, 0.1867]]]) 2503 2504 >>> # example with padding_idx 2505 >>> weights = torch.rand(10, 3) 2506 >>> weights[0, :].zero_() 2507 >>> embedding_matrix = weights 2508 >>> input = torch.tensor([[0, 2, 0, 5]]) 2509 >>> F.embedding(input, embedding_matrix, padding_idx=0) 2510 tensor([[[ 0.0000, 0.0000, 0.0000], 2511 [ 0.5609, 0.5384, 0.8720], 2512 [ 0.0000, 0.0000, 0.0000], 2513 [ 0.6262, 0.2438, 0.7471]]]) 2514 """ 2515 if has_torch_function_variadic(input, weight): 2516 return handle_torch_function( 2517 embedding, 2518 (input, weight), 2519 input, 2520 weight, 2521 padding_idx=padding_idx, 2522 max_norm=max_norm, 2523 norm_type=norm_type, 2524 scale_grad_by_freq=scale_grad_by_freq, 2525 sparse=sparse, 2526 ) 2527 if padding_idx is not None: 2528 if padding_idx > 0: 2529 assert padding_idx < weight.size( 2530 0 2531 ), "Padding_idx must be within num_embeddings" 2532 elif padding_idx < 0: 2533 assert padding_idx >= -weight.size( 2534 0 2535 ), "Padding_idx must be within num_embeddings" 2536 padding_idx = weight.size(0) + padding_idx 2537 else: 2538 padding_idx = -1 2539 if max_norm is not None: 2540 # Note [embedding_renorm contiguous] 2541 # `embedding_renorm_` will call .contiguous() on input anyways, so we 2542 # call it here and take advantage of the improved locality in the 2543 # `embedding` call below too. 
2544 input = input.contiguous() 2545 # Note [embedding_renorm set_grad_enabled] 2546 # XXX: equivalent to 2547 # with torch.no_grad(): 2548 # torch.embedding_renorm_ 2549 # remove once script supports set_grad_enabled 2550 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) 2551 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) 2552 2553 2554def embedding_bag( 2555 input: Tensor, 2556 weight: Tensor, 2557 offsets: Optional[Tensor] = None, 2558 max_norm: Optional[float] = None, 2559 norm_type: float = 2, 2560 scale_grad_by_freq: bool = False, 2561 mode: str = "mean", 2562 sparse: bool = False, 2563 per_sample_weights: Optional[Tensor] = None, 2564 include_last_offset: bool = False, 2565 padding_idx: Optional[int] = None, 2566) -> Tensor: 2567 r"""Compute sums, means or maxes of `bags` of embeddings. 2568 2569 Calculation is done without instantiating the intermediate embeddings. 2570 See :class:`torch.nn.EmbeddingBag` for more details. 2571 2572 Note: 2573 {backward_reproducibility_note} 2574 2575 Args: 2576 input (LongTensor): Tensor containing bags of indices into the embedding matrix 2577 weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1, 2578 and number of columns equal to the embedding size 2579 offsets (LongTensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines 2580 the starting index position of each bag (sequence) in :attr:`input`. 2581 max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` 2582 is renormalized to have norm :attr:`max_norm`. 2583 Note: this will modify :attr:`weight` in-place. 2584 norm_type (float, optional): The ``p`` in the ``p``-norm to compute for the :attr:`max_norm` option. 2585 Default ``2``. 2586 scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of 2587 the words in the mini-batch. Default ``False``. 2588 Note: this option is not supported when ``mode="max"``. 2589 mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. 2590 Default: ``"mean"`` 2591 sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under 2592 :class:`torch.nn.Embedding` for more details regarding sparse gradients. 2593 Note: this option is not supported when ``mode="max"``. 2594 per_sample_weights (Tensor, optional): a tensor of float / double weights, or None 2595 to indicate all weights should be taken to be 1. If specified, :attr:`per_sample_weights` 2596 must have exactly the same shape as input and is treated as having the same 2597 :attr:`offsets`, if those are not None. 2598 2599 include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1. 2600 The last element is the size of the input, or the ending index position of the last bag (sequence). 2601 2602 padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the 2603 gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated 2604 during training, i.e. it remains as a fixed "pad". Note that the embedding 2605 vector at :attr:`padding_idx` is excluded from the reduction. 
2606 2607 Shape: 2608 - :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional) 2609 2610 - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences) 2611 each of fixed length ``N``, and this will return ``B`` values aggregated in a way 2612 depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case. 2613 2614 - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of 2615 multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing 2616 the starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` 2617 of shape `(B)`, :attr:`input` will be viewed as having ``B`` bags. 2618 Empty bags (i.e., having 0-length) will have returned vectors filled by zeros. 2619 2620 - :attr:`weight` (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)` 2621 2622 - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as :attr:`input`. 2623 2624 - :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)` 2625 2626 Examples:: 2627 2628 >>> # an Embedding module containing 10 tensors of size 3 2629 >>> embedding_matrix = torch.rand(10, 3) 2630 >>> # a batch of 2 samples of 4 indices each 2631 >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]) 2632 >>> offsets = torch.tensor([0, 4]) 2633 >>> # xdoctest: +IGNORE_WANT("non-deterministic") 2634 >>> F.embedding_bag(input, embedding_matrix, offsets) 2635 tensor([[ 0.3397, 0.3552, 0.5545], 2636 [ 0.5893, 0.4386, 0.5882]]) 2637 2638 >>> # example with padding_idx 2639 >>> embedding_matrix = torch.rand(10, 3) 2640 >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9]) 2641 >>> offsets = torch.tensor([0, 4]) 2642 >>> F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum') 2643 tensor([[ 0.0000, 0.0000, 0.0000], 2644 [-0.7082, 3.2145, -2.6251]]) 2645 """ 2646 if has_torch_function_variadic(input, weight, offsets, per_sample_weights): 2647 return handle_torch_function( 2648 embedding_bag, 2649 (input, weight, offsets, per_sample_weights), 2650 input, 2651 weight, 2652 offsets=offsets, 2653 max_norm=max_norm, 2654 norm_type=norm_type, 2655 scale_grad_by_freq=scale_grad_by_freq, 2656 mode=mode, 2657 sparse=sparse, 2658 per_sample_weights=per_sample_weights, 2659 include_last_offset=include_last_offset, 2660 padding_idx=padding_idx, 2661 ) 2662 # Check for backward compatibility. 2663 # Used to be embedding_bag(weight, input, ...) 2664 # Now is embedding_bag(input, weight, ...) 2665 if weight.dtype == torch.long and input.is_floating_point(): 2666 warnings.warn( 2667 "Argument order of nn.functional.embedding_bag was changed. " 2668 "Usage `embedding_bag(weight, input, ...)` is deprecated, " 2669 "and should now be `embedding_bag(input, weight, ...)`." 
        )
        weight, input = input, weight

    if per_sample_weights is not None and input.size() != per_sample_weights.size():
        raise ValueError(
            f"embedding_bag: If per_sample_weights ({per_sample_weights.shape}) is not None, "
            f"then it must have the same shape as the input ({input.shape})"
        )

    if weight.dim() != 2:
        raise ValueError(
            f"weight has to be a 2D Tensor, but got Tensor of dimension {weight.dim()}"
        )

    if input.dim() == 2:
        if offsets is not None:
            type_str = "<unknown>"
            # TODO: Remove this once script supports type() calls
            if not torch.jit.is_scripting():
                type_str = str(type(offsets))
            raise ValueError(
                "if input is 2D, then offsets has to be None"
                ", as input is treated as a mini-batch of"
                " fixed length sequences. However, found "
                f"offsets of type {type_str}"
            )
        offsets = torch.arange(
            0, input.numel(), input.size(1), dtype=input.dtype, device=input.device
        )

        input = input.reshape(-1)
        if per_sample_weights is not None:
            per_sample_weights = per_sample_weights.reshape(-1)
    elif input.dim() == 1:
        if offsets is None:
            raise ValueError("offsets has to be a 1D Tensor but got None")
        if offsets.dim() != 1:
            raise ValueError("offsets has to be a 1D Tensor")
    else:
        raise ValueError(
            f"input has to be 1D or 2D Tensor, but got Tensor of dimension {input.dim()}"
        )
    if mode == "sum":
        mode_enum = 0
    elif mode == "mean":
        mode_enum = 1
    elif mode == "max":
        mode_enum = 2

        if scale_grad_by_freq:
            raise ValueError(
                "max mode does not support scaling the gradient by the frequency"
            )

        if sparse:
            raise ValueError("max mode does not support sparse weights")

    else:
        raise ValueError("mode has to be one of sum, mean or max")

    if max_norm is not None:
        # XXX: equivalent to
        # with torch.no_grad():
        #     torch.embedding_renorm_
        # remove once script supports set_grad_enabled
        _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)

    if per_sample_weights is not None and mode != "sum":
        raise NotImplementedError(
            "embedding_bag: per_sample_weights was not None. "
            "per_sample_weights is only supported for mode='sum' "
            f"(got mode='{mode}'). Please open a feature request on GitHub."
2742 ) 2743 2744 ret, _, _, _ = torch.embedding_bag( 2745 weight, 2746 input, 2747 offsets, 2748 scale_grad_by_freq, 2749 mode_enum, 2750 sparse, 2751 per_sample_weights, 2752 include_last_offset, 2753 padding_idx, 2754 ) 2755 return ret 2756 2757 2758if embedding_bag.__doc__: 2759 embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes) 2760 2761 2762def _verify_batch_size(size: List[int]) -> None: 2763 # XXX: JIT script does not support the reduce from functools, and mul op is a 2764 # builtin, which cannot be used as a value to a func yet, so rewrite this size 2765 # check to a simple equivalent for loop 2766 # 2767 # TODO: make use of reduce like below when JIT is ready with the missing features: 2768 # from operator import mul 2769 # from functools import reduce 2770 # 2771 # if reduce(mul, size[2:], size[0]) == 1 2772 size_prods = size[0] 2773 for i in range(len(size) - 2): 2774 size_prods *= size[i + 2] 2775 if size_prods == 1: 2776 raise ValueError( 2777 f"Expected more than 1 value per channel when training, got input size {size}" 2778 ) 2779 2780 2781def batch_norm( 2782 input: Tensor, 2783 running_mean: Optional[Tensor], 2784 running_var: Optional[Tensor], 2785 weight: Optional[Tensor] = None, 2786 bias: Optional[Tensor] = None, 2787 training: bool = False, 2788 momentum: float = 0.1, 2789 eps: float = 1e-5, 2790) -> Tensor: 2791 r"""Apply Batch Normalization for each channel across a batch of data. 2792 2793 See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`, 2794 :class:`~torch.nn.BatchNorm3d` for details. 2795 """ 2796 if has_torch_function_variadic(input, running_mean, running_var, weight, bias): 2797 return handle_torch_function( 2798 batch_norm, 2799 (input, running_mean, running_var, weight, bias), 2800 input, 2801 running_mean, 2802 running_var, 2803 weight=weight, 2804 bias=bias, 2805 training=training, 2806 momentum=momentum, 2807 eps=eps, 2808 ) 2809 if training: 2810 _verify_batch_size(input.size()) 2811 2812 return torch.batch_norm( 2813 input, 2814 weight, 2815 bias, 2816 running_mean, 2817 running_var, 2818 training, 2819 momentum, 2820 eps, 2821 torch.backends.cudnn.enabled, 2822 ) 2823 2824 2825def _verify_spatial_size(size: List[int]) -> None: 2826 # Verify that there is > 1 spatial element for instance norm calculation. 2827 size_prods = 1 2828 for i in range(2, len(size)): 2829 size_prods *= size[i] 2830 if size_prods == 1: 2831 raise ValueError( 2832 f"Expected more than 1 spatial element when training, got input size {size}" 2833 ) 2834 2835 2836def instance_norm( 2837 input: Tensor, 2838 running_mean: Optional[Tensor] = None, 2839 running_var: Optional[Tensor] = None, 2840 weight: Optional[Tensor] = None, 2841 bias: Optional[Tensor] = None, 2842 use_input_stats: bool = True, 2843 momentum: float = 0.1, 2844 eps: float = 1e-5, 2845) -> Tensor: 2846 r"""Apply Instance Normalization independently for each channel in every data sample within a batch. 2847 2848 See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`, 2849 :class:`~torch.nn.InstanceNorm3d` for details. 
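
    Example (an illustrative sketch; shapes are arbitrary)::

        >>> # a 4D input is treated as a batch of 2D images: (N, C, H, W)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = F.instance_norm(input)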
2850 """ 2851 if has_torch_function_variadic(input, running_mean, running_var, weight, bias): 2852 return handle_torch_function( 2853 instance_norm, 2854 (input, running_mean, running_var, weight, bias), 2855 input, 2856 running_mean=running_mean, 2857 running_var=running_var, 2858 weight=weight, 2859 bias=bias, 2860 use_input_stats=use_input_stats, 2861 momentum=momentum, 2862 eps=eps, 2863 ) 2864 if use_input_stats: 2865 _verify_spatial_size(input.size()) 2866 return torch.instance_norm( 2867 input, 2868 weight, 2869 bias, 2870 running_mean, 2871 running_var, 2872 use_input_stats, 2873 momentum, 2874 eps, 2875 torch.backends.cudnn.enabled, 2876 ) 2877 2878 2879def layer_norm( 2880 input: Tensor, 2881 normalized_shape: List[int], 2882 weight: Optional[Tensor] = None, 2883 bias: Optional[Tensor] = None, 2884 eps: float = 1e-5, 2885) -> Tensor: 2886 r"""Apply Layer Normalization for last certain number of dimensions. 2887 2888 See :class:`~torch.nn.LayerNorm` for details. 2889 """ 2890 if has_torch_function_variadic(input, weight, bias): 2891 return handle_torch_function( 2892 layer_norm, 2893 (input, weight, bias), 2894 input, 2895 normalized_shape, 2896 weight=weight, 2897 bias=bias, 2898 eps=eps, 2899 ) 2900 return torch.layer_norm( 2901 input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled 2902 ) 2903 2904 2905def rms_norm( 2906 input: Tensor, 2907 normalized_shape: List[int], 2908 weight: Optional[Tensor] = None, 2909 eps: Optional[float] = None, 2910) -> Tensor: 2911 r"""Apply Root Mean Square Layer Normalization. 2912 2913 See :class:`~torch.nn.RMSNorm` for details. 2914 """ 2915 if has_torch_function_variadic(input, weight): 2916 return handle_torch_function( 2917 rms_norm, (input, weight), input, normalized_shape, weight=weight, eps=eps 2918 ) 2919 return torch.rms_norm(input, normalized_shape, weight, eps) 2920 2921 2922def group_norm( 2923 input: Tensor, 2924 num_groups: int, 2925 weight: Optional[Tensor] = None, 2926 bias: Optional[Tensor] = None, 2927 eps: float = 1e-5, 2928) -> Tensor: 2929 r"""Apply Group Normalization for last certain number of dimensions. 2930 2931 See :class:`~torch.nn.GroupNorm` for details. 2932 """ 2933 if has_torch_function_variadic(input, weight, bias): 2934 return handle_torch_function( 2935 group_norm, 2936 ( 2937 input, 2938 weight, 2939 bias, 2940 ), 2941 input, 2942 num_groups, 2943 weight=weight, 2944 bias=bias, 2945 eps=eps, 2946 ) 2947 if input.dim() < 2: 2948 raise RuntimeError( 2949 f"Expected at least 2 dimensions for input tensor but received {input.dim()}" 2950 ) 2951 _verify_batch_size( 2952 [input.size(0) * input.size(1) // num_groups, num_groups] 2953 + list(input.size()[2:]) 2954 ) 2955 return torch.group_norm( 2956 input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled 2957 ) 2958 2959 2960def local_response_norm( 2961 input: Tensor, 2962 size: int, 2963 alpha: float = 1e-4, 2964 beta: float = 0.75, 2965 k: float = 1.0, 2966) -> Tensor: 2967 r"""Apply local response normalization over an input signal. 2968 2969 The input signal is composed of several input planes, where channels occupy the second dimension. 2970 Normalization is applied across channels. 2971 2972 See :class:`~torch.nn.LocalResponseNorm` for details. 
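
    Example (an illustrative sketch; shapes and ``size`` are arbitrary)::

        >>> # channels occupy dim 1; each position is normalized over
        >>> # a window of 3 neighboring channels
        >>> signal_2d = torch.randn(32, 5, 24, 24)
        >>> output_2d = F.local_response_norm(signal_2d, size=3)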
2973 """ 2974 if has_torch_function_unary(input): 2975 return handle_torch_function( 2976 local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k 2977 ) 2978 dim = input.dim() 2979 if dim < 3: 2980 raise ValueError( 2981 f"Expected 3D or higher dimensionality input (got {dim} dimensions)" 2982 ) 2983 2984 if input.numel() == 0: 2985 return input 2986 2987 div = input.mul(input) 2988 if dim == 3: 2989 div = div.unsqueeze(1) 2990 div = pad(div, (0, 0, size // 2, (size - 1) // 2)) 2991 div = avg_pool2d(div, (size, 1), stride=1).squeeze(1) 2992 else: 2993 sizes = input.size() 2994 div = div.view(sizes[0], 1, sizes[1], sizes[2], -1) 2995 div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2)) 2996 div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1) 2997 div = div.view(sizes) 2998 div = div.mul(alpha).add(k).pow(beta) 2999 return input / div 3000 3001 3002# loss 3003 3004 3005def ctc_loss( 3006 log_probs: Tensor, 3007 targets: Tensor, 3008 input_lengths: Tensor, 3009 target_lengths: Tensor, 3010 blank: int = 0, 3011 reduction: str = "mean", 3012 zero_infinity: bool = False, 3013) -> Tensor: 3014 r"""Apply the Connectionist Temporal Classification loss. 3015 3016 See :class:`~torch.nn.CTCLoss` for details. 3017 3018 Note: 3019 {cudnn_reproducibility_note} 3020 3021 Note: 3022 {backward_reproducibility_note} 3023 3024 Args: 3025 log_probs: :math:`(T, N, C)` or :math:`(T, C)` where `C = number of characters in alphabet including blank`, 3026 `T = input length`, and `N = batch size`. 3027 The logarithmized probabilities of the outputs 3028 (e.g. obtained with :func:`torch.nn.functional.log_softmax`). 3029 targets: :math:`(N, S)` or `(sum(target_lengths))`. 3030 Targets cannot be blank. In the second form, the targets are assumed to be concatenated. 3031 input_lengths: :math:`(N)` or :math:`()`. 3032 Lengths of the inputs (must each be :math:`\leq T`) 3033 target_lengths: :math:`(N)` or :math:`()`. 3034 Lengths of the targets 3035 blank (int, optional): 3036 Blank label. Default :math:`0`. 3037 reduction (str, optional): Specifies the reduction to apply to the output: 3038 ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, 3039 ``'mean'``: the output losses will be divided by the target lengths and 3040 then the mean over the batch is taken, ``'sum'``: the output will be 3041 summed. Default: ``'mean'`` 3042 zero_infinity (bool, optional): 3043 Whether to zero infinite losses and the associated gradients. 3044 Default: ``False`` 3045 Infinite losses mainly occur when the inputs are too short 3046 to be aligned to the targets. 
3047 3048 Example:: 3049 3050 >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_() 3051 >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long) 3052 >>> input_lengths = torch.full((16,), 50, dtype=torch.long) 3053 >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) 3054 >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) 3055 >>> loss.backward() 3056 """ 3057 if has_torch_function_variadic(log_probs, targets, input_lengths, target_lengths): 3058 return handle_torch_function( 3059 ctc_loss, 3060 (log_probs, targets, input_lengths, target_lengths), 3061 log_probs, 3062 targets, 3063 input_lengths, 3064 target_lengths, 3065 blank=blank, 3066 reduction=reduction, 3067 zero_infinity=zero_infinity, 3068 ) 3069 return torch.ctc_loss( 3070 log_probs, 3071 targets, 3072 input_lengths, 3073 target_lengths, 3074 blank, 3075 _Reduction.get_enum(reduction), 3076 zero_infinity, 3077 ) 3078 3079 3080if ctc_loss.__doc__: 3081 ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes) 3082 3083 3084def nll_loss( 3085 input: Tensor, 3086 target: Tensor, 3087 weight: Optional[Tensor] = None, 3088 size_average: Optional[bool] = None, 3089 ignore_index: int = -100, 3090 reduce: Optional[bool] = None, 3091 reduction: str = "mean", 3092) -> Tensor: 3093 r"""Compute the negative log likelihood loss. 3094 3095 See :class:`~torch.nn.NLLLoss` for details. 3096 3097 Args: 3098 input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)` 3099 in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \geq 1` 3100 in the case of K-dimensional loss. `input` is expected to be log-probabilities. 3101 target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, 3102 or :math:`(N, d_1, d_2, ..., d_K)` where :math:`K \geq 1` for 3103 K-dimensional loss. 3104 weight (Tensor, optional): a manual rescaling weight given to each 3105 class. If given, has to be a Tensor of size `C` 3106 size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, 3107 the losses are averaged over each loss element in the batch. Note that for 3108 some losses, there multiple elements per sample. If the field :attr:`size_average` 3109 is set to ``False``, the losses are instead summed for each minibatch. Ignored 3110 when reduce is ``False``. Default: ``True`` 3111 ignore_index (int, optional): Specifies a target value that is ignored 3112 and does not contribute to the input gradient. When :attr:`size_average` is 3113 ``True``, the loss is averaged over non-ignored targets. Default: -100 3114 reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the 3115 losses are averaged or summed over observations for each minibatch depending 3116 on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per 3117 batch element instead and ignores :attr:`size_average`. Default: ``True`` 3118 reduction (str, optional): Specifies the reduction to apply to the output: 3119 ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, 3120 ``'mean'``: the sum of the output will be divided by the number of 3121 elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` 3122 and :attr:`reduce` are in the process of being deprecated, and in the meantime, 3123 specifying either of those two args will override :attr:`reduction`. 
Default: ``'mean'`` 3124 3125 Example:: 3126 3127 >>> # input is of size N x C = 3 x 5 3128 >>> input = torch.randn(3, 5, requires_grad=True) 3129 >>> # each element in target has to have 0 <= value < C 3130 >>> target = torch.tensor([1, 0, 4]) 3131 >>> output = F.nll_loss(F.log_softmax(input, dim=1), target) 3132 >>> output.backward() 3133 """ 3134 if has_torch_function_variadic(input, target, weight): 3135 return handle_torch_function( 3136 nll_loss, 3137 (input, target, weight), 3138 input, 3139 target, 3140 weight=weight, 3141 size_average=size_average, 3142 ignore_index=ignore_index, 3143 reduce=reduce, 3144 reduction=reduction, 3145 ) 3146 if size_average is not None or reduce is not None: 3147 reduction = _Reduction.legacy_get_string(size_average, reduce) 3148 return torch._C._nn.nll_loss_nd( 3149 input, target, weight, _Reduction.get_enum(reduction), ignore_index 3150 ) 3151 3152 3153def poisson_nll_loss( 3154 input: Tensor, 3155 target: Tensor, 3156 log_input: bool = True, 3157 full: bool = False, 3158 size_average: Optional[bool] = None, 3159 eps: float = 1e-8, 3160 reduce: Optional[bool] = None, 3161 reduction: str = "mean", 3162) -> Tensor: 3163 r"""Poisson negative log likelihood loss. 3164 3165 See :class:`~torch.nn.PoissonNLLLoss` for details. 3166 3167 Args: 3168 input: expectation of underlying Poisson distribution. 3169 target: random sample :math:`target \sim \text{Poisson}(input)`. 3170 log_input: if ``True`` the loss is computed as 3171 :math:`\exp(\text{input}) - \text{target} * \text{input}`, if ``False`` then loss is 3172 :math:`\text{input} - \text{target} * \log(\text{input}+\text{eps})`. Default: ``True`` 3173 full: whether to compute full loss, i. e. to add the Stirling 3174 approximation term. Default: ``False`` 3175 :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`. 3176 size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, 3177 the losses are averaged over each loss element in the batch. Note that for 3178 some losses, there multiple elements per sample. If the field :attr:`size_average` 3179 is set to ``False``, the losses are instead summed for each minibatch. Ignored 3180 when reduce is ``False``. Default: ``True`` 3181 eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when 3182 :attr:`log_input`\ =\ ``False``. Default: 1e-8 3183 reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the 3184 losses are averaged or summed over observations for each minibatch depending 3185 on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per 3186 batch element instead and ignores :attr:`size_average`. Default: ``True`` 3187 reduction (str, optional): Specifies the reduction to apply to the output: 3188 ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, 3189 ``'mean'``: the sum of the output will be divided by the number of 3190 elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` 3191 and :attr:`reduce` are in the process of being deprecated, and in the meantime, 3192 specifying either of those two args will override :attr:`reduction`. 
Default: ``'mean'``

    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            poisson_nll_loss,
            (input, target),
            input,
            target,
            log_input=log_input,
            full=full,
            size_average=size_average,
            eps=eps,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    if reduction != "none" and reduction != "mean" and reduction != "sum":
        ret = input
        raise ValueError(reduction + " is not a valid value for reduction")

    ret = torch.poisson_nll_loss(
        input, target, log_input, full, eps, _Reduction.get_enum(reduction)
    )
    return ret


def gaussian_nll_loss(
    input: Tensor,
    target: Tensor,
    var: Tensor,
    full: bool = False,
    eps: float = 1e-6,
    reduction: str = "mean",
) -> Tensor:
    r"""Gaussian negative log likelihood loss.

    See :class:`~torch.nn.GaussianNLLLoss` for details.

    Args:
        input: expectation of the Gaussian distribution.
        target: sample from the Gaussian distribution.
        var: tensor of positive variance(s), one for each of the expectations
            in the input (heteroscedastic), or a single one (homoscedastic).
        full (bool, optional): include the constant term in the loss calculation. Default: ``False``.
        eps (float, optional): value added to var, for stability. Default: 1e-6.
        reduction (str, optional): specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the output is the average of all batch member losses,
            ``'sum'``: the output is the sum of all batch member losses.
            Default: ``'mean'``.
    """
    if has_torch_function_variadic(input, target, var):
        return handle_torch_function(
            gaussian_nll_loss,
            (input, target, var),
            input,
            target,
            var,
            full=full,
            eps=eps,
            reduction=reduction,
        )

    # Check var size
    # If var.size == input.size, the case is heteroscedastic and no further checks are needed.
    # Otherwise:
    if var.size() != input.size():
        # If var is one dimension short of input, but the sizes match otherwise, then this is a homoscedastic case.
        # e.g. input.size = (10, 2, 3), var.size = (10, 2)
        # -> unsqueeze var so that var.shape = (10, 2, 1)
        # this is done so that broadcasting can happen in the loss calculation
        if input.size()[:-1] == var.size():
            var = torch.unsqueeze(var, -1)

        # This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1.
        # This is also a homoscedastic case.
        # e.g. input.size = (10, 2, 3), var.size = (10, 2, 1)
        elif (
            input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1
        ):  # Homoscedastic case
            pass

        # If none of the above pass, then the size of var is incorrect.
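        # e.g. input.size = (10, 2, 3) with var.size = (10, 3) matches neither
        # homoscedastic pattern above and is rejected below.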
3277 else: 3278 raise ValueError("var is of incorrect size") 3279 3280 # Check validity of reduction mode 3281 if reduction != "none" and reduction != "mean" and reduction != "sum": 3282 raise ValueError(reduction + " is not valid") 3283 3284 # Entries of var must be non-negative 3285 if torch.any(var < 0): 3286 raise ValueError("var has negative entry/entries") 3287 3288 # Clamp for stability 3289 var = var.clone() 3290 with torch.no_grad(): 3291 var.clamp_(min=eps) 3292 3293 # Calculate the loss 3294 loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var) 3295 if full: 3296 loss += 0.5 * math.log(2 * math.pi) 3297 3298 if reduction == "mean": 3299 return loss.mean() 3300 elif reduction == "sum": 3301 return loss.sum() 3302 else: 3303 return loss 3304 3305 3306def kl_div( 3307 input: Tensor, 3308 target: Tensor, 3309 size_average: Optional[bool] = None, 3310 reduce: Optional[bool] = None, 3311 reduction: str = "mean", 3312 log_target: bool = False, 3313) -> Tensor: 3314 r"""Compute the KL Divergence loss. 3315 3316 Refer - The `Kullback-Leibler divergence Loss 3317 <https://en.wikipedia.org/wiki/Kullback-Leibler_divergence>`__ 3318 3319 See :class:`~torch.nn.KLDivLoss` for details. 3320 3321 Args: 3322 input: Tensor of arbitrary shape in log-probabilities. 3323 target: Tensor of the same shape as input. See :attr:`log_target` for 3324 the target's interpretation. 3325 size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, 3326 the losses are averaged over each loss element in the batch. Note that for 3327 some losses, there multiple elements per sample. If the field :attr:`size_average` 3328 is set to ``False``, the losses are instead summed for each minibatch. Ignored 3329 when reduce is ``False``. Default: ``True`` 3330 reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the 3331 losses are averaged or summed over observations for each minibatch depending 3332 on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per 3333 batch element instead and ignores :attr:`size_average`. Default: ``True`` 3334 reduction (str, optional): Specifies the reduction to apply to the output: 3335 ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``. 3336 ``'none'``: no reduction will be applied 3337 ``'batchmean'``: the sum of the output will be divided by the batchsize 3338 ``'sum'``: the output will be summed 3339 ``'mean'``: the output will be divided by the number of elements in the output 3340 Default: ``'mean'`` 3341 log_target (bool): A flag indicating whether ``target`` is passed in the log space. 3342 It is recommended to pass certain distributions (like ``softmax``) 3343 in the log space to avoid numerical issues caused by explicit ``log``. 3344 Default: ``False`` 3345 3346 .. note:: 3347 :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, 3348 and in the meantime, specifying either of those two args will override :attr:`reduction`. 3349 3350 .. warning:: 3351 :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use 3352 :attr:`reduction` = ``'batchmean'`` which aligns with KL math definition. 
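
    Example (an illustrative sketch; both distributions are random here)::

        >>> # input must hold log-probabilities; target holds probabilities
        >>> # (or log-probabilities when log_target=True)
        >>> input = F.log_softmax(torch.randn(3, 5), dim=1)
        >>> target = F.softmax(torch.randn(3, 5), dim=1)
        >>> loss = F.kl_div(input, target, reduction='batchmean')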
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            kl_div,
            (input, target),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            log_target=log_target,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        if reduction == "mean":
            warnings.warn(
                "reduction: 'mean' divides the total loss by both the batch size and the support size. "
                "'batchmean' divides only by the batch size, and aligns with the KL div math definition. "
                "'mean' will be changed to behave the same as 'batchmean' in the next major release."
            )

        # special case for batchmean
        if reduction == "batchmean":
            reduction_enum = _Reduction.get_enum("sum")
        else:
            reduction_enum = _Reduction.get_enum(reduction)

    reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target)

    if reduction == "batchmean" and input.dim() != 0:
        reduced = reduced / input.size()[0]

    return reduced


def cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
) -> Tensor:
    r"""Compute the cross entropy loss between input logits and target.

    See :class:`~torch.nn.CrossEntropyLoss` for details.

    Args:
        input (Tensor) : Predicted unnormalized logits;
            see Shape section below for supported shapes.
        target (Tensor) : Ground truth class indices or class probabilities;
            see Shape section below for supported shapes.
        weight (Tensor, optional): a manual rescaling weight given to each
            class. If given, has to be a Tensor of size `C`
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when reduce is ``False``. Default: ``True``
        ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. When :attr:`size_average` is
            ``True``, the loss is averaged over non-ignored targets. Note that
            :attr:`ignore_index` is only applicable when the target contains class indices.
            Default: -100
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`.
            Default: ``'mean'``
        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
            of smoothing when computing the loss, where 0.0 means no smoothing. The targets
            become a mixture of the original ground truth and a uniform distribution as described in
            `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.

    Shape:
        - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
          in the case of `K`-dimensional loss.
        - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with
          :math:`K \geq 1` in the case of K-dimensional loss where each value should be in :math:`[0, C)`.
          If containing class probabilities, same shape as the input and each value should be in :math:`[0, 1]`.

        where:

        .. math::
            \begin{aligned}
                C ={} & \text{number of classes} \\
                N ={} & \text{batch size} \\
            \end{aligned}

    Examples::

        >>> # Example of target with class indices
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randint(5, (3,), dtype=torch.int64)
        >>> loss = F.cross_entropy(input, target)
        >>> loss.backward()
        >>>
        >>> # Example of target with class probabilities
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5).softmax(dim=1)
        >>> loss = F.cross_entropy(input, target)
        >>> loss.backward()
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            cross_entropy,
            (input, target, weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    return torch._C._nn.cross_entropy_loss(
        input,
        target,
        weight,
        _Reduction.get_enum(reduction),
        ignore_index,
        label_smoothing,
    )


def binary_cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:
    r"""Measure Binary Cross Entropy between the target and input probabilities.

    See :class:`~torch.nn.BCELoss` for details.

    Args:
        input: Tensor of arbitrary shape as probabilities.
        target: Tensor of the same shape as input with values between 0 and 1.
        weight (Tensor, optional): a manual rescaling weight. If provided, it is
            repeated to match the input tensor shape.
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when reduce is ``False``. Default: ``True``
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`.
            Default: ``True``
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``

    Examples::

        >>> input = torch.randn(3, 2, requires_grad=True)
        >>> target = torch.rand(3, 2, requires_grad=False)
        >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target)
        >>> loss.backward()
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            binary_cross_entropy,
            (input, target, weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    if target.size() != input.size():
        raise ValueError(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}) is deprecated. "
            "Please ensure they have the same size."
        )

    if weight is not None:
        new_size = _infer_size(target.size(), weight.size())
        weight = weight.expand(new_size)

    return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)


def binary_cross_entropy_with_logits(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    pos_weight: Optional[Tensor] = None,
) -> Tensor:
    r"""Calculate Binary Cross Entropy between target and input logits.

    See :class:`~torch.nn.BCEWithLogitsLoss` for details.

    Args:
        input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits).
        target: Tensor of the same shape as input with values between 0 and 1.
        weight (Tensor, optional): a manual rescaling weight. If provided, it is
            repeated to match the input tensor shape.
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when reduce is ``False``. Default: ``True``
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
            Note: :attr:`size_average` and :attr:`reduce` are in the process of being
            deprecated, and in the meantime, specifying either of those two args will
            override :attr:`reduction`. Default: ``'mean'``
        pos_weight (Tensor, optional): a weight of positive examples to be broadcast
            with the target. Its size along the class dimension must equal the number
            of classes. Pay close attention to PyTorch's broadcasting semantics in
            order to achieve the desired operations. For a target of size [B, C, H, W]
            (where B is batch size), a pos_weight of size [B, C, H, W] applies a
            different positive weight to each element of the batch, while one of size
            [C, H, W] applies the same positive weights across the batch. To apply the
            same positive weight along all spatial dimensions for a 2D multi-class
            target [C, H, W], use a pos_weight of size [C, 1, 1].
            Default: ``None``

    Examples::

        >>> input = torch.randn(3, requires_grad=True)
        >>> target = torch.empty(3).random_(2)
        >>> loss = F.binary_cross_entropy_with_logits(input, target)
        >>> loss.backward()
    """
    if has_torch_function_variadic(input, target, weight, pos_weight):
        return handle_torch_function(
            binary_cross_entropy_with_logits,
            (input, target, weight, pos_weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            pos_weight=pos_weight,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)

    if not (target.size() == input.size()):
        raise ValueError(
            f"Target size ({target.size()}) must be the same as input size ({input.size()})"
        )

    return torch.binary_cross_entropy_with_logits(
        input, target, weight, pos_weight, reduction_enum
    )


def smooth_l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    beta: float = 1.0,
) -> Tensor:
    r"""Compute the Smooth L1 loss.

    The function uses a squared term if the absolute
    element-wise error falls below beta and an L1 term otherwise.

    See :class:`~torch.nn.SmoothL1Loss` for details.
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            smooth_l1_loss,
            (input, target),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            beta=beta,
        )
    if not (target.size() == input.size()):
        warnings.warn(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
            "This will likely lead to incorrect results due to broadcasting. 
" 3663 "Please ensure they have the same size.", 3664 stacklevel=2, 3665 ) 3666 if size_average is not None or reduce is not None: 3667 reduction = _Reduction.legacy_get_string(size_average, reduce) 3668 3669 expanded_input, expanded_target = torch.broadcast_tensors(input, target) 3670 3671 if beta == 0.0: 3672 return torch._C._nn.l1_loss( 3673 expanded_input, expanded_target, _Reduction.get_enum(reduction) 3674 ) 3675 else: 3676 return torch._C._nn.smooth_l1_loss( 3677 expanded_input, expanded_target, _Reduction.get_enum(reduction), beta 3678 ) 3679 3680 3681def huber_loss( 3682 input: Tensor, 3683 target: Tensor, 3684 reduction: str = "mean", 3685 delta: float = 1.0, 3686) -> Tensor: 3687 r"""Compute the Huber loss. 3688 3689 Function uses a squared term if the absolute 3690 element-wise error falls below delta and a delta-scaled L1 term otherwise. 3691 3692 When delta equals 1, this loss is equivalent to SmoothL1Loss. 3693 In general, Huber loss differs from SmoothL1Loss by a factor of delta (AKA beta in Smooth L1). 3694 3695 See :class:`~torch.nn.HuberLoss` for details. 3696 """ 3697 if has_torch_function_variadic(input, target): 3698 return handle_torch_function( 3699 huber_loss, 3700 (input, target), 3701 input, 3702 target, 3703 reduction=reduction, 3704 delta=delta, 3705 ) 3706 if not (target.size() == input.size()): 3707 warnings.warn( 3708 f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " 3709 "This will likely lead to incorrect results due to broadcasting. " 3710 "Please ensure they have the same size.", 3711 stacklevel=2, 3712 ) 3713 3714 expanded_input, expanded_target = torch.broadcast_tensors(input, target) 3715 return torch._C._nn.huber_loss( 3716 expanded_input, expanded_target, _Reduction.get_enum(reduction), delta 3717 ) 3718 3719 3720def l1_loss( 3721 input: Tensor, 3722 target: Tensor, 3723 size_average: Optional[bool] = None, 3724 reduce: Optional[bool] = None, 3725 reduction: str = "mean", 3726) -> Tensor: # noqa: D400,D402 3727 r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor 3728 3729 Function that takes the mean element-wise absolute value difference. 3730 3731 See :class:`~torch.nn.L1Loss` for details. 3732 """ 3733 if has_torch_function_variadic(input, target): 3734 return handle_torch_function( 3735 l1_loss, 3736 (input, target), 3737 input, 3738 target, 3739 size_average=size_average, 3740 reduce=reduce, 3741 reduction=reduction, 3742 ) 3743 if not (target.size() == input.size()): 3744 warnings.warn( 3745 f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " 3746 "This will likely lead to incorrect results due to broadcasting. " 3747 "Please ensure they have the same size.", 3748 stacklevel=2, 3749 ) 3750 if size_average is not None or reduce is not None: 3751 reduction = _Reduction.legacy_get_string(size_average, reduce) 3752 3753 expanded_input, expanded_target = torch.broadcast_tensors(input, target) 3754 return torch._C._nn.l1_loss( 3755 expanded_input, expanded_target, _Reduction.get_enum(reduction) 3756 ) 3757 3758 3759def mse_loss( 3760 input: Tensor, 3761 target: Tensor, 3762 size_average: Optional[bool] = None, 3763 reduce: Optional[bool] = None, 3764 reduction: str = "mean", 3765) -> Tensor: # noqa: D400,D402 3766 r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor 3767 3768 Measures the element-wise mean squared error. 3769 See :class:`~torch.nn.MSELoss` for details. 
3770 """ 3771 if has_torch_function_variadic(input, target): 3772 return handle_torch_function( 3773 mse_loss, 3774 (input, target), 3775 input, 3776 target, 3777 size_average=size_average, 3778 reduce=reduce, 3779 reduction=reduction, 3780 ) 3781 if not (target.size() == input.size()): 3782 warnings.warn( 3783 f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " 3784 "This will likely lead to incorrect results due to broadcasting. " 3785 "Please ensure they have the same size.", 3786 stacklevel=2, 3787 ) 3788 if size_average is not None or reduce is not None: 3789 reduction = _Reduction.legacy_get_string(size_average, reduce) 3790 3791 expanded_input, expanded_target = torch.broadcast_tensors(input, target) 3792 return torch._C._nn.mse_loss( 3793 expanded_input, expanded_target, _Reduction.get_enum(reduction) 3794 ) 3795 3796 3797def margin_ranking_loss( 3798 input1: Tensor, 3799 input2: Tensor, 3800 target: Tensor, 3801 margin: float = 0, 3802 size_average: Optional[bool] = None, 3803 reduce: Optional[bool] = None, 3804 reduction: str = "mean", 3805) -> Tensor: # noqa: D400,D402 3806 r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor 3807 3808 See :class:`~torch.nn.MarginRankingLoss` for details. 3809 """ 3810 if has_torch_function_variadic(input1, input2, target): 3811 return handle_torch_function( 3812 margin_ranking_loss, 3813 (input1, input2, target), 3814 input1, 3815 input2, 3816 target, 3817 margin=margin, 3818 size_average=size_average, 3819 reduce=reduce, 3820 reduction=reduction, 3821 ) 3822 if size_average is not None or reduce is not None: 3823 reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) 3824 else: 3825 reduction_enum = _Reduction.get_enum(reduction) 3826 if input1.dim() != input2.dim() or input1.dim() != target.dim(): 3827 raise RuntimeError( 3828 f"margin_ranking_loss : All input tensors should have same dimension but got sizes: " 3829 f"input1: {input1.size()}, input2: {input2.size()}, target: {target.size()} " 3830 ) 3831 return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum) 3832 3833 3834def hinge_embedding_loss( 3835 input: Tensor, 3836 target: Tensor, 3837 margin: float = 1.0, 3838 size_average: Optional[bool] = None, 3839 reduce: Optional[bool] = None, 3840 reduction: str = "mean", 3841) -> Tensor: # noqa: D400,D402 3842 r"""hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean') -> Tensor 3843 3844 See :class:`~torch.nn.HingeEmbeddingLoss` for details. 
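
    A minimal usage sketch (shapes and values are illustrative only; target
    entries must be ``1`` or ``-1``):

    Examples::

        >>> input = torch.randn(4, requires_grad=True)
        >>> target = torch.tensor([1.0, -1.0, 1.0, -1.0])
        >>> loss = F.hinge_embedding_loss(input, target, margin=1.0)
        >>> loss.backward()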
3845 """ 3846 if has_torch_function_variadic(input, target): 3847 return handle_torch_function( 3848 hinge_embedding_loss, 3849 (input, target), 3850 input, 3851 target, 3852 margin=margin, 3853 size_average=size_average, 3854 reduce=reduce, 3855 reduction=reduction, 3856 ) 3857 if size_average is not None or reduce is not None: 3858 reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) 3859 else: 3860 reduction_enum = _Reduction.get_enum(reduction) 3861 return torch.hinge_embedding_loss(input, target, margin, reduction_enum) 3862 3863 3864def multilabel_margin_loss( 3865 input: Tensor, 3866 target: Tensor, 3867 size_average: Optional[bool] = None, 3868 reduce: Optional[bool] = None, 3869 reduction: str = "mean", 3870) -> Tensor: # noqa: D400,D402 3871 r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor 3872 3873 See :class:`~torch.nn.MultiLabelMarginLoss` for details. 3874 """ 3875 if has_torch_function_variadic(input, target): 3876 return handle_torch_function( 3877 multilabel_margin_loss, 3878 (input, target), 3879 input, 3880 target, 3881 size_average=size_average, 3882 reduce=reduce, 3883 reduction=reduction, 3884 ) 3885 if size_average is not None or reduce is not None: 3886 reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) 3887 else: 3888 reduction_enum = _Reduction.get_enum(reduction) 3889 return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) 3890 3891 3892def soft_margin_loss( 3893 input: Tensor, 3894 target: Tensor, 3895 size_average: Optional[bool] = None, 3896 reduce: Optional[bool] = None, 3897 reduction: str = "mean", 3898) -> Tensor: # noqa: D400,D402 3899 r""" 3900 soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor 3901 3902 See :class:`~torch.nn.SoftMarginLoss` for details. 3903 """ 3904 if has_torch_function_variadic(input, target): 3905 return handle_torch_function( 3906 soft_margin_loss, 3907 (input, target), 3908 input, 3909 target, 3910 size_average=size_average, 3911 reduce=reduce, 3912 reduction=reduction, 3913 ) 3914 if size_average is not None or reduce is not None: 3915 reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) 3916 else: 3917 reduction_enum = _Reduction.get_enum(reduction) 3918 return torch._C._nn.soft_margin_loss(input, target, reduction_enum) 3919 3920 3921def multilabel_soft_margin_loss( 3922 input: Tensor, 3923 target: Tensor, 3924 weight: Optional[Tensor] = None, 3925 size_average: Optional[bool] = None, 3926 reduce: Optional[bool] = None, 3927 reduction: str = "mean", 3928) -> Tensor: # noqa: D400,D402 3929 r"""multilabel_soft_margin_loss(input, target, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor 3930 3931 See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details. 
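
    A minimal usage sketch (shapes are illustrative only; the target is a
    multi-hot matrix with entries in ``{0, 1}``):

    Examples::

        >>> input = torch.randn(3, 4, requires_grad=True)
        >>> target = torch.randint(0, 2, (3, 4)).float()
        >>> loss = F.multilabel_soft_margin_loss(input, target)
        >>> loss.backward()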
3932 """ 3933 if has_torch_function_variadic(input, target, weight): 3934 return handle_torch_function( 3935 multilabel_soft_margin_loss, 3936 (input, target, weight), 3937 input, 3938 target, 3939 weight=weight, 3940 size_average=size_average, 3941 reduce=reduce, 3942 reduction=reduction, 3943 ) 3944 if size_average is not None or reduce is not None: 3945 reduction = _Reduction.legacy_get_string(size_average, reduce) 3946 3947 loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input)) 3948 3949 if weight is not None: 3950 loss = loss * weight 3951 3952 class_dim = input.dim() - 1 3953 C = input.size(class_dim) 3954 loss = loss.sum(dim=class_dim) / C # only return N loss values 3955 3956 if reduction == "none": 3957 ret = loss 3958 elif reduction == "mean": 3959 ret = loss.mean() 3960 elif reduction == "sum": 3961 ret = loss.sum() 3962 else: 3963 ret = input 3964 raise ValueError(reduction + " is not valid") 3965 return ret 3966 3967 3968def cosine_embedding_loss( 3969 input1: Tensor, 3970 input2: Tensor, 3971 target: Tensor, 3972 margin: float = 0, 3973 size_average: Optional[bool] = None, 3974 reduce: Optional[bool] = None, 3975 reduction: str = "mean", 3976) -> Tensor: # noqa: D400,D402 3977 r"""cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor 3978 3979 See :class:`~torch.nn.CosineEmbeddingLoss` for details. 3980 """ 3981 if has_torch_function_variadic(input1, input2, target): 3982 return handle_torch_function( 3983 cosine_embedding_loss, 3984 (input1, input2, target), 3985 input1, 3986 input2, 3987 target, 3988 margin=margin, 3989 size_average=size_average, 3990 reduce=reduce, 3991 reduction=reduction, 3992 ) 3993 if size_average is not None or reduce is not None: 3994 reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) 3995 else: 3996 reduction_enum = _Reduction.get_enum(reduction) 3997 return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum) 3998 3999 4000def multi_margin_loss( 4001 input: Tensor, 4002 target: Tensor, 4003 p: int = 1, 4004 margin: float = 1.0, 4005 weight: Optional[Tensor] = None, 4006 size_average: Optional[bool] = None, 4007 reduce: Optional[bool] = None, 4008 reduction: str = "mean", 4009) -> Tensor: # noqa: D400,D402 4010 r"""multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor 4011 4012 See :class:`~torch.nn.MultiMarginLoss` for details. 
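
    A minimal usage sketch (shapes and class indices are illustrative only):

    Examples::

        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.tensor([1, 0, 4])
        >>> loss = F.multi_margin_loss(input, target)
        >>> loss.backward()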
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            multi_margin_loss,
            (input, target, weight),
            input,
            target,
            p=p,
            margin=margin,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)
    if p != 1 and p != 2:
        raise ValueError("only p == 1 and p == 2 supported")
    if weight is not None:
        if weight.dim() != 1:
            raise ValueError("weight must be one-dimensional")

    return torch._C._nn.multi_margin_loss(
        input, target, p, margin, weight, reduction_enum
    )


pixel_shuffle = _add_docstr(
    torch.pixel_shuffle,
    r"""
pixel_shuffle(input, upscale_factor) -> Tensor

Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a
tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`.

See :class:`~torch.nn.PixelShuffle` for details.

Args:
    input (Tensor): the input tensor
    upscale_factor (int): factor to increase spatial resolution by

Examples::

    >>> input = torch.randn(1, 9, 4, 4)
    >>> output = torch.nn.functional.pixel_shuffle(input, 3)
    >>> print(output.size())
    torch.Size([1, 1, 12, 12])
""",
)

pixel_unshuffle = _add_docstr(
    torch.pixel_unshuffle,
    r"""
pixel_unshuffle(input, downscale_factor) -> Tensor

Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a
tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`.

See :class:`~torch.nn.PixelUnshuffle` for details.

Args:
    input (Tensor): the input tensor
    downscale_factor (int): factor to decrease spatial resolution by

Examples::

    >>> input = torch.randn(1, 1, 12, 12)
    >>> output = torch.nn.functional.pixel_unshuffle(input, 3)
    >>> print(output.size())
    torch.Size([1, 9, 4, 4])
""",
)

channel_shuffle = _add_docstr(
    torch.channel_shuffle,
    r"""
channel_shuffle(input, groups) -> Tensor

Divide the channels in a tensor of shape :math:`(*, C, H, W)`
into g groups and rearrange them as :math:`(*, \frac{C}{g}, g, H, W)`,
while keeping the original tensor shape.

See :class:`~torch.nn.ChannelShuffle` for details.

Args:
    input (Tensor): the input tensor
    groups (int): number of groups to divide channels in and rearrange.

Examples::

    >>> input = torch.arange(1, 17).view(1, 4, 2, 2)
    >>> print(input)
    [[[[1, 2],
       [3, 4]],
      [[5, 6],
       [7, 8]],
      [[9, 10],
       [11, 12]],
      [[13, 14],
       [15, 16]],
    ]]
    >>> output = torch.nn.functional.channel_shuffle(input, 2)
    >>> print(output)
    [[[[1, 2],
       [3, 4]],
      [[9, 10],
       [11, 12]],
      [[5, 6],
       [7, 8]],
      [[13, 14],
       [15, 16]],
    ]]
""",
)

native_channel_shuffle = _add_docstr(
    torch.native_channel_shuffle,
    r"""
native_channel_shuffle(input, groups) -> Tensor

Native kernel-level implementation of :func:`channel_shuffle`.
This function might become private in future releases; use with caution.

Divide the channels in a tensor of shape :math:`(*, C, H, W)`
into g groups and rearrange them as :math:`(*, \frac{C}{g}, g, H, W)`,
while keeping the original tensor shape.

See :class:`~torch.nn.ChannelShuffle` for details.

Args:
    input (Tensor): the input tensor
    groups (int): number of groups to divide channels in and rearrange.

Examples::

    >>> input = torch.arange(1, 17).view(1, 4, 2, 2)
    >>> print(input)
    [[[[1, 2],
       [3, 4]],
      [[5, 6],
       [7, 8]],
      [[9, 10],
       [11, 12]],
      [[13, 14],
       [15, 16]],
    ]]
    >>> output = torch.nn.functional.native_channel_shuffle(input, 2)
    >>> print(output)
    [[[[1, 2],
       [3, 4]],
      [[9, 10],
       [11, 12]],
      [[5, 6],
       [7, 8]],
      [[13, 14],
       [15, 16]],
    ]]
""",
)


@_overload
def upsample(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
) -> Tensor:  # noqa: B950
    pass


@_overload
def upsample(  # noqa: F811
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
) -> Tensor:  # noqa: B950
    pass


def upsample(  # noqa: F811
    input,
    size=None,
    scale_factor=None,
    mode="nearest",
    align_corners=None,
):
    r"""Upsample input.

    The provided tensor is upsampled to either the given :attr:`size` or the given
    :attr:`scale_factor`.

    .. warning::
        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
        It is equivalent to ``nn.functional.interpolate(...)``.

    Note:
        {backward_reproducibility_note}

    The algorithm used for upsampling is determined by :attr:`mode`.

    Currently temporal, spatial and volumetric upsampling are supported, i.e.
    expected inputs are 3-D, 4-D or 5-D in shape.

    The input dimensions are interpreted in the form:
    `mini-batch x channels x [optional depth] x [optional height] x width`.

    The modes available for upsampling are: `nearest`, `linear` (3D-only),
    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only)

    Args:
        input (Tensor): the input tensor
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
            output spatial size.
        scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
        mode (str): algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``
        align_corners (bool, optional): Geometrically, we consider the pixels of the
            input and output as squares rather than points.
            If set to ``True``, the input and output tensors are aligned by the
            center points of their corner pixels, preserving the values at the corner pixels.
            If set to ``False``, the input and output tensors are aligned by the corner
            points of their corner pixels, and the interpolation uses edge value padding
            for out-of-boundary values, making this operation *independent* of input size
            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
            Default: ``False``

    .. note::
        With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce
        negative values or values greater than 255 for images.
        Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot
        when displaying the image.

    .. warning::
        With ``align_corners = True``, the linearly interpolating modes
        (`linear`, `bilinear`, and `trilinear`) don't proportionally align the
        output and input pixels, and thus the output values can depend on the
        input size. This was the default behavior for these modes up to version
        0.3.1. Since then, the default behavior is ``align_corners = False``.
        See :class:`~torch.nn.Upsample` for concrete examples on how this
        affects the outputs.

    """
    warnings.warn(
        "`nn.functional.upsample` is deprecated. "
        "Use `nn.functional.interpolate` instead.",
        stacklevel=2,
    )
    return interpolate(input, size, scale_factor, mode, align_corners)


if upsample.__doc__:
    upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes)


def _is_integer(x) -> bool:
    r"""Type check that the input number is an integer.

    Will return True for int, SymInt, NumPy integers and Tensors with integer elements.
    """
    if isinstance(x, (int, torch.SymInt)):
        return True
    if np is not None and isinstance(x, np.integer):
        return True
    return isinstance(x, Tensor) and not x.is_floating_point()


@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[List[float]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    pass


@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[List[float]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    pass


@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    pass


@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:
    pass


def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[List[float]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    r"""Down- or up-sample the input.

    The tensor is interpolated to either the given :attr:`size` or the given
    :attr:`scale_factor`.

    The algorithm used for interpolation is determined by :attr:`mode`.

    Currently temporal, spatial and volumetric sampling are supported, i.e.
    expected inputs are 3-D, 4-D or 5-D in shape.

    The input dimensions are interpreted in the form:
    `mini-batch x channels x [optional depth] x [optional height] x width`.

    The modes available for resizing are: `nearest`, `linear` (3D-only),
    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`, `nearest-exact`

    Args:
        input (Tensor): the input tensor
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
            output spatial size.
        scale_factor (float or Tuple[float]): multiplier for spatial size. If `scale_factor` is a tuple,
            its length has to match the number of spatial dimensions; `input.dim() - 2`.
        mode (str): algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'``
        align_corners (bool, optional): Geometrically, we consider the pixels of the
            input and output as squares rather than points.
            If set to ``True``, the input and output tensors are aligned by the
            center points of their corner pixels, preserving the values at the corner pixels.
            If set to ``False``, the input and output tensors are aligned by the corner
            points of their corner pixels, and the interpolation uses edge value padding
            for out-of-boundary values, making this operation *independent* of input size
            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
            Default: ``False``
        recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
            interpolation calculation. If `recompute_scale_factor` is ``True``, then
            `scale_factor` must be passed in and `scale_factor` is used to compute the
            output `size`. The computed output `size` will be used to infer new scales for
            the interpolation. Note that when `scale_factor` is floating-point, it may differ
            from the recomputed `scale_factor` due to rounding and precision issues.
            If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
            be used directly for interpolation. Default: ``None``.
        antialias (bool, optional): flag to apply anti-aliasing. Default: ``False``. When the
            anti-alias option is used together with ``align_corners=False``, the interpolation
            result matches the Pillow result for downsampling operations. Supported modes:
            ``'bilinear'``, ``'bicubic'``.

    .. note::
        With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce
        negative values or values greater than 255 for images.
        Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot
        when displaying the image.

    .. note::
        Mode ``mode='nearest-exact'`` matches the Scikit-Image and PIL nearest-neighbour
        interpolation algorithms and fixes known issues with ``mode='nearest'``. The new
        mode was introduced (rather than changing ``mode='nearest'`` in place) to preserve
        backward compatibility.
        Mode ``mode='nearest'`` matches the buggy OpenCV ``INTER_NEAREST`` interpolation algorithm.

    .. note::
        The gradients for the dtype ``float16`` on CUDA may be inaccurate in the upsample operation
        when using modes ``['linear', 'bilinear', 'bicubic', 'trilinear', 'area']``.
        For more details, please refer to the discussion in
        `issue#104157 <https://github.com/pytorch/pytorch/issues/104157>`_.
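
    The example below is a minimal sketch; sizes, modes and scale factors are
    illustrative only:

    Examples::

        >>> input = torch.randn(1, 3, 8, 8)
        >>> # double the spatial size with bilinear interpolation
        >>> F.interpolate(input, scale_factor=2, mode="bilinear", align_corners=False).shape
        torch.Size([1, 3, 16, 16])
        >>> # resize to an explicit output size with nearest-exact
        >>> F.interpolate(input, size=(4, 4), mode="nearest-exact").shape
        torch.Size([1, 3, 4, 4])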
4413 4414 Note: 4415 {backward_reproducibility_note} 4416 """ 4417 if has_torch_function_unary(input): 4418 return handle_torch_function( 4419 interpolate, 4420 (input,), 4421 input, 4422 size=size, 4423 scale_factor=scale_factor, 4424 mode=mode, 4425 align_corners=align_corners, 4426 recompute_scale_factor=recompute_scale_factor, 4427 antialias=antialias, 4428 ) 4429 4430 if mode in ("nearest", "area", "nearest-exact"): 4431 if align_corners is not None: 4432 raise ValueError( 4433 "align_corners option can only be set with the " 4434 "interpolating modes: linear | bilinear | bicubic | trilinear" 4435 ) 4436 else: 4437 if align_corners is None: 4438 align_corners = False 4439 4440 dim = input.dim() - 2 # Number of spatial dimensions. 4441 4442 # Process size and scale_factor. Validate that exactly one is set. 4443 # Validate its length if it is a list, or expand it if it is a scalar. 4444 # After this block, exactly one of output_size and scale_factors will 4445 # be non-None, and it will be a list (or tuple). 4446 if size is not None and scale_factor is not None: 4447 raise ValueError("only one of size or scale_factor should be defined") 4448 elif size is not None: 4449 assert scale_factor is None 4450 scale_factors = None 4451 if isinstance(size, (list, tuple)): 4452 if len(size) != dim: 4453 raise ValueError( 4454 "Input and output must have the same number of spatial dimensions, but got " 4455 f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. " 4456 "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " 4457 "output size in (o1, o2, ...,oK) format." 4458 ) 4459 if not torch.jit.is_scripting(): 4460 if not all(_is_integer(x) for x in size): 4461 raise TypeError( 4462 "expected size to be one of int or Tuple[int] or Tuple[int, int] or " 4463 f"Tuple[int, int, int], but got size with types {[type(x) for x in size]}" 4464 ) 4465 output_size = size 4466 else: 4467 output_size = [size for _ in range(dim)] 4468 elif scale_factor is not None: 4469 assert size is None 4470 output_size = None 4471 if isinstance(scale_factor, (list, tuple)): 4472 if len(scale_factor) != dim: 4473 raise ValueError( 4474 "Input and scale_factor must have the same number of spatial dimensions, but " 4475 f"got input with spatial dimensions of {list(input.shape[2:])} and " 4476 f"scale_factor of shape {scale_factor}. " 4477 "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " 4478 "scale_factor in (s1, s2, ...,sK) format." 4479 ) 4480 scale_factors = scale_factor 4481 else: 4482 scale_factors = [scale_factor for _ in range(dim)] 4483 else: 4484 raise ValueError("either size or scale_factor should be defined") 4485 4486 if ( 4487 recompute_scale_factor is not None 4488 and recompute_scale_factor 4489 and size is not None 4490 ): 4491 raise ValueError( 4492 "recompute_scale_factor is not meaningful with an explicit size." 4493 ) 4494 4495 # "area" mode always requires an explicit size rather than scale factor. 4496 # Re-use the recompute_scale_factor code path. 4497 if mode == "area" and output_size is None: 4498 recompute_scale_factor = True 4499 4500 if recompute_scale_factor is not None and recompute_scale_factor: 4501 # We compute output_size here, then un-set scale_factors. 4502 # The C++ code will recompute it based on the (integer) output size. 
4503 assert scale_factors is not None 4504 if not torch.jit.is_scripting() and torch._C._get_tracing_state(): 4505 # make scale_factor a tensor in tracing so constant doesn't get baked in 4506 output_size = [ 4507 ( 4508 torch.floor( 4509 ( 4510 input.size(i + 2).float() 4511 * torch.tensor(scale_factors[i], dtype=torch.float32) 4512 ).float() 4513 ) 4514 ) 4515 for i in range(dim) 4516 ] 4517 elif torch.jit.is_scripting(): 4518 output_size = [ 4519 int(math.floor(float(input.size(i + 2)) * scale_factors[i])) 4520 for i in range(dim) 4521 ] 4522 else: 4523 output_size = [ 4524 _sym_int(input.size(i + 2) * scale_factors[i]) for i in range(dim) 4525 ] 4526 scale_factors = None 4527 4528 if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4): 4529 raise ValueError( 4530 "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input" 4531 ) 4532 4533 if input.dim() == 3 and mode == "nearest": 4534 return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) 4535 if input.dim() == 4 and mode == "nearest": 4536 return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) 4537 if input.dim() == 5 and mode == "nearest": 4538 return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) 4539 4540 if input.dim() == 3 and mode == "nearest-exact": 4541 return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors) 4542 if input.dim() == 4 and mode == "nearest-exact": 4543 return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors) 4544 if input.dim() == 5 and mode == "nearest-exact": 4545 return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors) 4546 4547 if input.dim() == 3 and mode == "area": 4548 assert output_size is not None 4549 return adaptive_avg_pool1d(input, output_size) 4550 if input.dim() == 4 and mode == "area": 4551 assert output_size is not None 4552 return adaptive_avg_pool2d(input, output_size) 4553 if input.dim() == 5 and mode == "area": 4554 assert output_size is not None 4555 return adaptive_avg_pool3d(input, output_size) 4556 4557 if input.dim() == 3 and mode == "linear": 4558 assert align_corners is not None 4559 return torch._C._nn.upsample_linear1d( 4560 input, output_size, align_corners, scale_factors 4561 ) 4562 if input.dim() == 4 and mode == "bilinear": 4563 assert align_corners is not None 4564 if antialias: 4565 return torch._C._nn._upsample_bilinear2d_aa( 4566 input, output_size, align_corners, scale_factors 4567 ) 4568 # Two levels are necessary to prevent TorchScript from touching 4569 # are_deterministic_algorithms_enabled. 
        if not torch.jit.is_scripting():
            if torch.are_deterministic_algorithms_enabled() and (
                input.is_cuda or input.is_xpu
            ):
                # Use slow decomp whose backward will be in terms of index_put
                # importlib is required because the import cannot be top level
                # (cycle) and cannot be nested (TS doesn't support)
                return importlib.import_module(
                    "torch._decomp.decompositions"
                )._upsample_linear_vec(input, output_size, align_corners, scale_factors)
        return torch._C._nn.upsample_bilinear2d(
            input, output_size, align_corners, scale_factors
        )
    if input.dim() == 5 and mode == "trilinear":
        assert align_corners is not None
        return torch._C._nn.upsample_trilinear3d(
            input, output_size, align_corners, scale_factors
        )
    if input.dim() == 4 and mode == "bicubic":
        assert align_corners is not None
        if antialias:
            return torch._C._nn._upsample_bicubic2d_aa(
                input, output_size, align_corners, scale_factors
            )
        return torch._C._nn.upsample_bicubic2d(
            input, output_size, align_corners, scale_factors
        )

    if input.dim() == 3 and mode == "bilinear":
        raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
    if input.dim() == 3 and mode == "trilinear":
        raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
    if input.dim() == 4 and mode == "linear":
        raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
    if input.dim() == 4 and mode == "trilinear":
        raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
    if input.dim() == 5 and mode == "linear":
        raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
    if input.dim() == 5 and mode == "bilinear":
        raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")

    raise NotImplementedError(
        "Input Error: Only 3D, 4D and 5D input Tensors supported"
        f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact"
        f" (got {mode})"
    )


if interpolate.__doc__:
    interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes)


@_overload
def upsample_nearest(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    pass


@_overload
def upsample_nearest(  # noqa: F811
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    pass


def upsample_nearest(input, size=None, scale_factor=None):  # noqa: F811
    r"""Upsamples the input, using nearest neighbours' pixel values.

    .. warning::
        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
        It is equivalent to ``nn.functional.interpolate(..., mode='nearest')``.

    Currently spatial and volumetric upsampling are supported (i.e. expected
    inputs are 4 or 5 dimensional).

    Args:
        input (Tensor): input
        size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
            size.
        scale_factor (int): multiplier for spatial size. Has to be an integer.

    Note:
        {backward_reproducibility_note}
    """
    # DeprecationWarning is ignored by default
    warnings.warn(
        "`nn.functional.upsample_nearest` is deprecated. "
" 4662 "Use `nn.functional.interpolate` instead.", 4663 stacklevel=2, 4664 ) 4665 return interpolate(input, size, scale_factor, mode="nearest") 4666 4667 4668if upsample_nearest.__doc__: 4669 upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes) 4670 4671 4672@_overload 4673def upsample_bilinear( # noqa: F811 4674 input: Tensor, 4675 size: Optional[int] = None, 4676 scale_factor: Optional[float] = None, 4677) -> Tensor: 4678 pass 4679 4680 4681@_overload 4682def upsample_bilinear( # noqa: F811 4683 input: Tensor, 4684 size: Optional[List[int]] = None, 4685 scale_factor: Optional[float] = None, 4686) -> Tensor: 4687 pass 4688 4689 4690@_overload 4691def upsample_bilinear( # noqa: F811 4692 input: Tensor, 4693 size: Optional[int] = None, 4694 scale_factor: Optional[List[float]] = None, 4695) -> Tensor: 4696 pass 4697 4698 4699@_overload 4700def upsample_bilinear( # noqa: F811 4701 input: Tensor, 4702 size: Optional[List[int]] = None, 4703 scale_factor: Optional[List[float]] = None, 4704) -> Tensor: 4705 pass 4706 4707 4708def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 4709 r"""Upsamples the input, using bilinear upsampling. 4710 4711 .. warning:: 4712 This function is deprecated in favor of :func:`torch.nn.functional.interpolate`. 4713 This is equivalent with 4714 ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``. 4715 4716 Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` fo 4717 volumetric (5 dimensional) inputs. 4718 4719 Args: 4720 input (Tensor): input 4721 size (int or Tuple[int, int]): output spatial size. 4722 scale_factor (int or Tuple[int, int]): multiplier for spatial size 4723 4724 Note: 4725 {backward_reproducibility_note} 4726 """ 4727 # DeprecationWarning is ignored by default 4728 warnings.warn( 4729 "`nn.functional.upsample_bilinear` is deprecated. " 4730 "Use `nn.functional.interpolate` instead.", 4731 stacklevel=2, 4732 ) 4733 return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) 4734 4735 4736if upsample_bilinear.__doc__: 4737 upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format( 4738 **reproducibility_notes 4739 ) 4740 4741GRID_SAMPLE_INTERPOLATION_MODES = { 4742 "bilinear": 0, 4743 "nearest": 1, 4744 "bicubic": 2, 4745} 4746 4747GRID_SAMPLE_PADDING_MODES = { 4748 "zeros": 0, 4749 "border": 1, 4750 "reflection": 2, 4751} 4752 4753 4754def grid_sample( 4755 input: Tensor, 4756 grid: Tensor, 4757 mode: str = "bilinear", 4758 padding_mode: str = "zeros", 4759 align_corners: Optional[bool] = None, 4760) -> Tensor: 4761 r"""Compute grid sample. 4762 4763 Given an :attr:`input` and a flow-field :attr:`grid`, computes the 4764 ``output`` using :attr:`input` values and pixel locations from :attr:`grid`. 4765 4766 Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are 4767 supported. 4768 4769 In the spatial (4-D) case, for :attr:`input` with shape 4770 :math:`(N, C, H_\text{in}, W_\text{in})` and :attr:`grid` with shape 4771 :math:`(N, H_\text{out}, W_\text{out}, 2)`, the output will have shape 4772 :math:`(N, C, H_\text{out}, W_\text{out})`. 4773 4774 For each output location ``output[n, :, h, w]``, the size-2 vector 4775 ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``, 4776 which are used to interpolate the output value ``output[n, :, h, w]``. 4777 In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the 4778 ``x``, ``y``, ``z`` pixel locations for interpolating 4779 ``output[n, :, d, h, w]``. 
    ``bilinear`` interpolation method to sample the input pixels.

    :attr:`grid` specifies the sampling pixel locations normalized by the
    :attr:`input` spatial dimensions. Therefore, it should have most values in
    the range of ``[-1, 1]``. For example, values ``x = -1, y = -1`` are the
    left-top pixel of :attr:`input`, and values ``x = 1, y = 1`` are the
    right-bottom pixel of :attr:`input`.

    If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding
    outputs are handled as defined by :attr:`padding_mode`. Options are

    * ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations,
    * ``padding_mode="border"``: use border values for out-of-bound grid locations,
    * ``padding_mode="reflection"``: use values at locations reflected by
      the border for out-of-bound grid locations. Locations far away from
      the border keep being reflected until they are in bounds,
      e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1``
      and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes
      ``x'' = -0.5``.

    Note:
        This function is often used in conjunction with :func:`affine_grid`
        to build `Spatial Transformer Networks`_.

    Note:
        When using the CUDA backend, this operation may induce nondeterministic
        behaviour in its backward pass that is not easily switched off.
        Please see the notes on :doc:`/notes/randomness` for background.

    Note:
        NaN values in :attr:`grid` are interpreted as ``-1``.

    Args:
        input (Tensor): input of shape :math:`(N, C, H_\text{in}, W_\text{in})` (4-D case)
            or :math:`(N, C, D_\text{in}, H_\text{in}, W_\text{in})` (5-D case)
        grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case)
            or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case)
        mode (str): interpolation mode to calculate output values
            ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'``
            Note: ``mode='bicubic'`` supports only 4-D input.
            When ``mode='bilinear'`` and the input is 5-D, the interpolation mode
            used internally will actually be trilinear. However, when the input is 4-D,
            the interpolation mode will legitimately be bilinear.
        padding_mode (str): padding mode for outside grid values
            ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'``
        align_corners (bool, optional): Geometrically, we consider the pixels of the
            input as squares rather than points.
            If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring
            to the center points of the input's corner pixels. If set to ``False``, they
            are instead considered as referring to the corner points of the input's corner
            pixels, making the sampling more resolution agnostic.
            This option parallels the ``align_corners`` option in
            :func:`interpolate`, and so whichever option is used here
            should also be used there to resize the input image before grid sampling.
            Default: ``False``

    Returns:
        output (Tensor): output Tensor

    .. _`Spatial Transformer Networks`:
        https://arxiv.org/abs/1506.02025
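
    The example below is a minimal sketch (shapes and grid values are
    illustrative only): an identity grid built with :func:`affine_grid`
    reproduces the input up to interpolation error.

    Examples::

        >>> input = torch.arange(16.0).view(1, 1, 4, 4)
        >>> theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
        >>> grid = F.affine_grid(theta, (1, 1, 4, 4), align_corners=False)
        >>> output = F.grid_sample(input, grid, align_corners=False)
        >>> torch.allclose(input, output)
        True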
    .. warning::
        When ``align_corners = True``, the grid positions depend on the pixel
        size relative to the input image size, and so the locations sampled by
        :func:`grid_sample` will differ for the same input given at different
        resolutions (that is, after being upsampled or downsampled).
        The default behavior up to version 1.2.0 was ``align_corners = True``.
        Since then, the default behavior has been changed to ``align_corners = False``,
        in order to bring it in line with the default for :func:`interpolate`.

    .. note::
        ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`.
        The constant :math:`\alpha` may differ from package to package.
        For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively.
        This algorithm may "overshoot" the range of values it's interpolating.
        For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255].
        Clamp the results with :func:`torch.clamp` to ensure they are within the valid range.
    .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation
    .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51
    .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908
    """
    if has_torch_function_variadic(input, grid):
        return handle_torch_function(
            grid_sample,
            (input, grid),
            input,
            grid,
            mode=mode,
            padding_mode=padding_mode,
            align_corners=align_corners,
        )
    if mode != "bilinear" and mode != "nearest" and mode != "bicubic":
        raise ValueError(
            f"nn.functional.grid_sample(): expected mode to be 'bilinear', 'nearest' or 'bicubic', but got: '{mode}'"
        )
    if (
        padding_mode != "zeros"
        and padding_mode != "border"
        and padding_mode != "reflection"
    ):
        raise ValueError(
            "nn.functional.grid_sample(): expected padding_mode "
            "to be 'zeros', 'border', or 'reflection', "
            f"but got: '{padding_mode}'"
        )

    if mode == "bilinear":
        mode_enum = 0
    elif mode == "nearest":
        mode_enum = 1
    else:  # mode == 'bicubic'
        mode_enum = 2

    if padding_mode == "zeros":
        padding_mode_enum = 0
    elif padding_mode == "border":
        padding_mode_enum = 1
    else:  # padding_mode == 'reflection'
        padding_mode_enum = 2

    if align_corners is None:
        warnings.warn(
            "Default grid_sample and affine_grid behavior has changed "
            "to align_corners=False since 1.3.0. Please specify "
            "align_corners=True if the old behavior is desired. "
            "See the documentation of grid_sample for details."
        )
        align_corners = False

    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)


def affine_grid(
    theta: Tensor,
    size: List[int],
    align_corners: Optional[bool] = None,
) -> Tensor:
    r"""Generate 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`.

    .. note::
        This function is often used in conjunction with :func:`grid_sample`
        to build `Spatial Transformer Networks`_.
4923 4924 Args: 4925 theta (Tensor): input batch of affine matrices with shape 4926 (:math:`N \times 2 \times 3`) for 2D or 4927 (:math:`N \times 3 \times 4`) for 3D 4928 size (torch.Size): the target output image size. 4929 (:math:`N \times C \times H \times W` for 2D or 4930 :math:`N \times C \times D \times H \times W` for 3D) 4931 Example: torch.Size((32, 3, 24, 24)) 4932 align_corners (bool, optional): if ``True``, consider ``-1`` and ``1`` 4933 to refer to the centers of the corner pixels rather than the image corners. 4934 Refer to :func:`grid_sample` for a more complete description. 4935 A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample` 4936 with the same setting for this option. 4937 Default: ``False`` 4938 4939 Returns: 4940 output (Tensor): output Tensor of size (:math:`N \times H \times W \times 2`) for 2D, or (:math:`N \times D \times H \times W \times 3`) for 3D 4941 4942 .. _`Spatial Transformer Networks`: 4943 https://arxiv.org/abs/1506.02025 4944 4945 .. warning:: 4946 When ``align_corners = True``, the grid positions depend on the pixel 4947 size relative to the input image size, and so the locations sampled by 4948 :func:`grid_sample` will differ for the same input given at different 4949 resolutions (that is, after being upsampled or downsampled). 4950 The default behavior up to version 1.2.0 was ``align_corners = True``. 4951 Since then, the default behavior has been changed to ``align_corners = False``, 4952 in order to bring it in line with the default for :func:`interpolate`. 4953 .. warning:: 4954 When ``align_corners = True``, 2D affine transforms on 1D data and 4955 3D affine transforms on 2D data (that is, when one of the spatial 4956 dimensions has unit size) are ill-defined, and not an intended use case. 4957 This is not a problem when ``align_corners = False``. 4958 Up to version 1.2.0, all grid points along a unit dimension were 4959 considered arbitrarily to be at ``-1``. 4960 From version 1.3.0, under ``align_corners = True`` all grid points 4961 along a unit dimension are considered to be at ``0`` 4962 (the center of the input image). 4963 """ 4964 if has_torch_function_unary(theta): 4965 return handle_torch_function( 4966 affine_grid, (theta,), theta, size, align_corners=align_corners 4967 ) 4968 if align_corners is None: 4969 warnings.warn( 4970 "Default grid_sample and affine_grid behavior has changed " 4971 "to align_corners=False since 1.3.0. Please specify " 4972 "align_corners=True if the old behavior is desired. " 4973 "See the documentation of grid_sample for details." 4974 ) 4975 align_corners = False 4976 4977 # enforce floating point dtype on theta 4978 if not theta.is_floating_point(): 4979 raise ValueError( 4980 f"Expected theta to have floating point type, but got {theta.dtype}" 4981 ) 4982 # check that shapes and sizes match 4983 if len(size) == 4: 4984 if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3: 4985 raise ValueError( 4986 f"Expected a batch of 2D affine matrices of shape Nx2x3 for size {size}. Got {theta.shape}." 4987 ) 4988 spatial_size = size[-2:] # spatial dimension sizes 4989 elif len(size) == 5: 4990 if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4: 4991 raise ValueError( 4992 f"Expected a batch of 3D affine matrices of shape Nx3x4 for size {size}. Got {theta.shape}." 4993 ) 4994 spatial_size = size[-3:] # spatial dimension sizes 4995 else: 4996 raise NotImplementedError( 4997 "affine_grid only supports 4D and 5D sizes, " 4998 "for 2D and 3D affine transforms, respectively. " 4999 f"Got size {size}."
5000 ) 5001 # check for empty span 5002 if align_corners and min(spatial_size) == 1: 5003 warnings.warn( 5004 "Since version 1.3.0, affine_grid behavior has changed " 5005 "for unit-size grids when align_corners=True. " 5006 "This is not an intended use case of affine_grid. " 5007 "See the documentation of affine_grid for details." 5008 ) 5009 elif min(size) <= 0: 5010 raise ValueError(f"Expected non-zero, positive output size. Got {size}") 5011 5012 return torch.affine_grid_generator(theta, size, align_corners) 5013 5014 5015 def pad( 5016 input: Tensor, 5017 pad: List[int], 5018 mode: str = "constant", 5019 value: Optional[float] = None, 5020 ) -> Tensor: 5021 r""" 5022 pad(input, pad, mode="constant", value=None) -> Tensor 5023 5024 Pads tensor. 5025 5026 Padding size: 5027 The padding size by which to pad some dimensions of :attr:`input` 5028 is described starting from the last dimension and moving forward. 5029 :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions 5030 of ``input`` will be padded. 5031 For example, to pad only the last dimension of the input tensor, 5032 :attr:`pad` has the form 5033 :math:`(\text{padding\_left}, \text{padding\_right})`; 5034 to pad the last 2 dimensions of the input tensor, use 5035 :math:`(\text{padding\_left}, \text{padding\_right},` 5036 :math:`\text{padding\_top}, \text{padding\_bottom})`; 5037 to pad the last 3 dimensions, use 5038 :math:`(\text{padding\_left}, \text{padding\_right},` 5039 :math:`\text{padding\_top}, \text{padding\_bottom},` 5040 :math:`\text{padding\_front}, \text{padding\_back})`. 5041 5042 Padding mode: 5043 See :class:`torch.nn.CircularPad2d`, :class:`torch.nn.ConstantPad2d`, 5044 :class:`torch.nn.ReflectionPad2d`, and :class:`torch.nn.ReplicationPad2d` 5045 for concrete examples on how each of the padding modes works. Constant 5046 padding is implemented for arbitrary dimensions. Circular, replicate and 5047 reflection padding are implemented for padding the last 3 dimensions of a 5048 4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor, 5049 or the last dimension of a 2D or 3D input tensor. 5050 5051 Note: 5052 When using the CUDA backend, this operation may induce nondeterministic 5053 behaviour in its backward pass that is not easily switched off. 5054 Please see the notes on :doc:`/notes/randomness` for background. 5055 5056 Args: 5057 input (Tensor): N-dimensional tensor 5058 pad (tuple): m-element tuple, where 5059 :math:`\frac{m}{2} \leq` number of input dimensions and :math:`m` is even. 5060 mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. 5061 Default: ``'constant'`` 5062 value: fill value for ``'constant'`` padding.
Default: ``0`` 5063 5064 Examples:: 5065 5066 >>> t4d = torch.empty(3, 3, 4, 2) 5067 >>> p1d = (1, 1) # pad last dim by 1 on each side 5068 >>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding 5069 >>> print(out.size()) 5070 torch.Size([3, 3, 4, 4]) 5071 >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2) 5072 >>> out = F.pad(t4d, p2d, "constant", 0) 5073 >>> print(out.size()) 5074 torch.Size([3, 3, 8, 4]) 5075 >>> t4d = torch.empty(3, 3, 4, 2) 5076 >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3) 5077 >>> out = F.pad(t4d, p3d, "constant", 0) 5078 >>> print(out.size()) 5079 torch.Size([3, 9, 7, 3]) 5080 """ 5081 if has_torch_function_unary(input): 5082 return handle_torch_function( 5083 torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value 5084 ) 5085 if not torch.jit.is_scripting(): 5086 if torch.are_deterministic_algorithms_enabled() and ( 5087 input.is_cuda or input.is_xpu 5088 ): 5089 if mode == "replicate": 5090 # Use slow decomp whose backward will be in terms of index_put. 5091 # importlib is required because the import cannot be top level 5092 # (cycle) and cannot be nested (TS doesn't support) 5093 return importlib.import_module( 5094 "torch._decomp.decompositions" 5095 )._replication_pad(input, pad) 5096 return torch._C._nn.pad(input, pad, mode, value) 5097 5098 5099 # TODO: Fix via https://github.com/pytorch/pytorch/issues/75798 5100 pad.__module__ = "torch.nn.functional" 5101 5102 # distance 5103 5104 5105 pairwise_distance = _add_docstr( 5106 torch.pairwise_distance, 5107 r""" 5108 pairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False) -> Tensor 5109 5110 See :class:`torch.nn.PairwiseDistance` for details 5111 """, 5112 ) 5113 5114 5115 pdist = _add_docstr( 5116 torch.pdist, 5117 r""" 5118 pdist(input, p=2) -> Tensor 5119 5120 Computes the p-norm distance between every pair of row vectors in the input. 5121 This is identical to the upper triangular portion, excluding the diagonal, of 5122 `torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster 5123 if the rows are contiguous. 5124 5125 If input has shape :math:`N \times M` then the output will have shape 5126 :math:`\frac{1}{2} N (N - 1)`. 5127 5128 This function is equivalent to ``scipy.spatial.distance.pdist(input, 5129 'minkowski', p=p)`` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is 5130 equivalent to ``scipy.spatial.distance.pdist(input, 'hamming') * M``. 5131 When :math:`p = \infty`, the closest scipy function is 5132 ``scipy.spatial.distance.pdist(input, lambda x, y: np.abs(x - y).max())``. 5133 5134 Args: 5135 input: input tensor of shape :math:`N \times M`. 5136 p: p value for the p-norm distance to calculate between each vector pair 5137 :math:`\in [0, \infty]`. 5138 """, 5139 ) 5140 5141 5142 cosine_similarity = _add_docstr( 5143 torch.cosine_similarity, 5144 r""" 5145 cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor 5146 5147 Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable 5148 to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is 5149 squeezed (see :func:`torch.squeeze`), resulting in the 5150 output tensor having 1 fewer dimension. 5151 5152 .. math :: 5153 \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2, \epsilon) \cdot \max(\Vert x_2 \Vert _2, \epsilon)} 5154 5155 Supports :ref:`type promotion <type-promotion-doc>`. 5156 5157 Args: 5158 x1 (Tensor): First input. 5159 x2 (Tensor): Second input.
5160 dim (int, optional): Dimension along which cosine similarity is computed. Default: 1 5161 eps (float, optional): Small value to avoid division by zero. 5162 Default: 1e-8 5163 5164 Example:: 5165 5166 >>> input1 = torch.randn(100, 128) 5167 >>> input2 = torch.randn(100, 128) 5168 >>> output = F.cosine_similarity(input1, input2) 5169 >>> print(output) 5170 """, 5171 ) 5172 5173 5174 one_hot = _add_docstr( 5175 torch._C._nn.one_hot, 5176 r""" 5177 one_hot(tensor, num_classes=-1) -> LongTensor 5178 5179 Takes LongTensor with index values of shape ``(*)`` and returns a tensor 5180 of shape ``(*, num_classes)`` that has zeros everywhere except where the 5181 index of the last dimension matches the corresponding value of the input tensor, 5182 in which case it will be 1. 5183 5184 See also `One-hot on Wikipedia`_ . 5185 5186 .. _One-hot on Wikipedia: 5187 https://en.wikipedia.org/wiki/One-hot 5188 5189 Arguments: 5190 tensor (LongTensor): class values of any shape. 5191 num_classes (int): Total number of classes. If set to -1, the number 5192 of classes will be inferred as one greater than the largest class 5193 value in the input tensor. 5194 5195 Returns: 5196 LongTensor that has one more dimension, with value 1 at the 5197 index of the last dimension indicated by the input, and 0 everywhere 5198 else. 5199 5200 Examples: 5201 >>> F.one_hot(torch.arange(0, 5) % 3) 5202 tensor([[1, 0, 0], 5203 [0, 1, 0], 5204 [0, 0, 1], 5205 [1, 0, 0], 5206 [0, 1, 0]]) 5207 >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5) 5208 tensor([[1, 0, 0, 0, 0], 5209 [0, 1, 0, 0, 0], 5210 [0, 0, 1, 0, 0], 5211 [1, 0, 0, 0, 0], 5212 [0, 1, 0, 0, 0]]) 5213 >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3) 5214 tensor([[[1, 0, 0], 5215 [0, 1, 0]], 5216 [[0, 0, 1], 5217 [1, 0, 0]], 5218 [[0, 1, 0], 5219 [0, 0, 1]]]) 5220 """, 5221 ) 5222 5223 5224 def triplet_margin_loss( 5225 anchor: Tensor, 5226 positive: Tensor, 5227 negative: Tensor, 5228 margin: float = 1.0, 5229 p: float = 2, 5230 eps: float = 1e-6, 5231 swap: bool = False, 5232 size_average: Optional[bool] = None, 5233 reduce: Optional[bool] = None, 5234 reduction: str = "mean", 5235 ) -> Tensor: 5236 r"""Compute the triplet loss between given input tensors and a margin greater than 0. 5237 5238 See :class:`~torch.nn.TripletMarginLoss` for details. 5239 """ 5240 if has_torch_function_variadic(anchor, positive, negative): 5241 return handle_torch_function( 5242 triplet_margin_loss, 5243 (anchor, positive, negative), 5244 anchor, 5245 positive, 5246 negative, 5247 margin=margin, 5248 p=p, 5249 eps=eps, 5250 swap=swap, 5251 size_average=size_average, 5252 reduce=reduce, 5253 reduction=reduction, 5254 ) 5255 if size_average is not None or reduce is not None: 5256 reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) 5257 else: 5258 reduction_enum = _Reduction.get_enum(reduction) 5259 if margin <= 0: 5260 raise ValueError(f"margin must be greater than 0, got {margin}") 5261 return torch.triplet_margin_loss( 5262 anchor, positive, negative, margin, p, eps, swap, reduction_enum 5263 ) 5264 5265 5266 def triplet_margin_with_distance_loss( 5267 anchor: Tensor, 5268 positive: Tensor, 5269 negative: Tensor, 5270 *, 5271 distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, 5272 margin: float = 1.0, 5273 swap: bool = False, 5274 reduction: str = "mean", 5275 ) -> Tensor: 5276 r"""Compute the triplet margin loss for input tensors using a custom distance function. 5277 5278 See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details.
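
    Example (a minimal sketch; :func:`torch.pairwise_distance` is just one
    valid choice of ``distance_function``)::

        >>> anchor = torch.randn(8, 16)
        >>> positive = torch.randn(8, 16)
        >>> negative = torch.randn(8, 16)
        >>> loss = F.triplet_margin_with_distance_loss(
        ...     anchor, positive, negative,
        ...     distance_function=torch.pairwise_distance, margin=1.0)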
5279 """ 5280 if torch.jit.is_scripting(): 5281 raise NotImplementedError( 5282 "F.triplet_margin_with_distance_loss does not support JIT scripting: " 5283 "functions requiring Callables cannot be scripted." 5284 ) 5285 5286 if has_torch_function_variadic(anchor, positive, negative): 5287 return handle_torch_function( 5288 triplet_margin_with_distance_loss, 5289 (anchor, positive, negative), 5290 anchor, 5291 positive, 5292 negative, 5293 distance_function=distance_function, 5294 margin=margin, 5295 swap=swap, 5296 reduction=reduction, 5297 ) 5298 5299 # Check validity of reduction mode 5300 if reduction not in ("mean", "sum", "none"): 5301 raise ValueError(f"{reduction} is not a valid value for reduction") 5302 5303 # Check validity of margin 5304 if margin <= 0: 5305 raise ValueError(f"margin must be greater than 0, got {margin}") 5306 5307 # Check dimensions 5308 a_dim = anchor.ndim 5309 p_dim = positive.ndim 5310 n_dim = negative.ndim 5311 if not (a_dim == p_dim and p_dim == n_dim): 5312 raise RuntimeError( 5313 f"The anchor, positive, and negative tensors are expected to have " 5314 f"the same number of dimensions, but got: anchor {a_dim}D, " 5315 f"positive {p_dim}D, and negative {n_dim}D inputs" 5316 ) 5317 5318 # Calculate loss 5319 if distance_function is None: 5320 distance_function = torch.pairwise_distance 5321 5322 dist_pos = distance_function(anchor, positive) 5323 dist_neg = distance_function(anchor, negative) 5324 # The distance swap is described in the paper "Learning shallow 5325 # convolutional feature descriptors with triplet losses" by V. Balntas, E. 5326 # Riba et al. If True, and if the positive example is closer to the 5327 # negative example than the anchor is, swaps the positive example and the 5328 # anchor in the loss computation. 5329 if swap: 5330 dist_swap = distance_function(positive, negative) 5331 dist_neg = torch.minimum(dist_neg, dist_swap) 5332 loss = torch.clamp_min(margin + dist_pos - dist_neg, 0) 5333 5334 # Apply reduction 5335 if reduction == "sum": 5336 return torch.sum(loss) 5337 elif reduction == "mean": 5338 return torch.mean(loss) 5339 else: # reduction == "none" 5340 return loss 5341 5342 5343def normalize( 5344 input: Tensor, 5345 p: float = 2.0, 5346 dim: int = 1, 5347 eps: float = 1e-12, 5348 out: Optional[Tensor] = None, 5349) -> Tensor: 5350 r"""Perform :math:`L_p` normalization of inputs over specified dimension. 5351 5352 For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each 5353 :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as 5354 5355 .. math:: 5356 v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}. 5357 5358 With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization. 5359 5360 Args: 5361 input: input tensor of any shape 5362 p (float): the exponent value in the norm formulation. Default: 2 5363 dim (int or tuple of ints): the dimension to reduce. Default: 1 5364 eps (float): small value to avoid division by zero. Default: 1e-12 5365 out (Tensor, optional): the output tensor. If :attr:`out` is used, this 5366 operation won't be differentiable. 
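
    Example (illustrative values)::

        >>> x = torch.tensor([[3.0, 4.0]])
        >>> F.normalize(x, p=2.0, dim=1)  # each row rescaled to unit L2 norm
        tensor([[0.6000, 0.8000]])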
5367 """ 5368 if has_torch_function_variadic(input, out): 5369 return handle_torch_function( 5370 normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out 5371 ) 5372 if out is None: 5373 denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input) 5374 return input / denom 5375 else: 5376 denom = input.norm(p, dim, keepdim=True).clamp_min_(eps).expand_as(input) 5377 return torch.div(input, denom, out=out) 5378 5379 5380def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None: 5381 assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name) 5382 5383 5384def unfold( 5385 input: Tensor, 5386 kernel_size: BroadcastingList2[int], 5387 dilation: BroadcastingList2[int] = 1, 5388 padding: BroadcastingList2[int] = 0, 5389 stride: BroadcastingList2[int] = 1, 5390) -> Tensor: 5391 r"""Extract sliding local blocks from a batched input tensor. 5392 5393 .. warning:: 5394 Currently, only 4-D input tensors (batched image-like tensors) are 5395 supported. 5396 5397 .. warning:: 5398 5399 More than one element of the unfolded tensor may refer to a single 5400 memory location. As a result, in-place operations (especially ones that 5401 are vectorized) may result in incorrect behavior. If you need to write 5402 to the tensor, please clone it first. 5403 5404 5405 See :class:`torch.nn.Unfold` for details 5406 """ 5407 if has_torch_function_unary(input): 5408 return handle_torch_function( 5409 unfold, 5410 (input,), 5411 input, 5412 kernel_size, 5413 dilation=dilation, 5414 padding=padding, 5415 stride=stride, 5416 ) 5417 return torch._C._nn.im2col( 5418 input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride) 5419 ) 5420 5421 5422def fold( 5423 input: Tensor, 5424 output_size: BroadcastingList2[int], 5425 kernel_size: BroadcastingList2[int], 5426 dilation: BroadcastingList2[int] = 1, 5427 padding: BroadcastingList2[int] = 0, 5428 stride: BroadcastingList2[int] = 1, 5429) -> Tensor: 5430 r"""Combine an array of sliding local blocks into a large containing tensor. 5431 5432 .. warning:: 5433 Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported. 5434 5435 See :class:`torch.nn.Fold` for details 5436 """ 5437 if has_torch_function_unary(input): 5438 return handle_torch_function( 5439 fold, 5440 (input,), 5441 input, 5442 output_size, 5443 kernel_size, 5444 dilation=dilation, 5445 padding=padding, 5446 stride=stride, 5447 ) 5448 return torch._C._nn.col2im( 5449 input, 5450 _pair(output_size), 5451 _pair(kernel_size), 5452 _pair(dilation), 5453 _pair(padding), 5454 _pair(stride), 5455 ) 5456 5457 5458# 5459# multihead attention 5460# 5461 5462 5463def _in_projection_packed( 5464 q: Tensor, 5465 k: Tensor, 5466 v: Tensor, 5467 w: Tensor, 5468 b: Optional[Tensor] = None, 5469) -> List[Tensor]: 5470 r"""Perform the in-projection step of the attention operation, using packed weights. 5471 5472 Output is a triple containing projection tensors for query, key and value. 5473 5474 Args: 5475 q, k, v: query, key and value tensors to be projected. For self-attention, 5476 these are typically the same tensor; for encoder-decoder attention, 5477 k and v are typically the same tensor. (We take advantage of these 5478 identities for performance if they are present.) Regardless, q, k and v 5479 must share a common embedding dimension; otherwise their shapes may vary. 5480 w: projection weights for q, k and v, packed into a single tensor. Weights 5481 are packed along dimension 0, in q, k, v order. 
5482 b: optional projection biases for q, k and v, packed into a single tensor 5483 in q, k, v order. 5484 5485 Shape: 5486 Inputs: 5487 - q: :math:`(..., E)` where E is the embedding dimension 5488 - k: :math:`(..., E)` where E is the embedding dimension 5489 - v: :math:`(..., E)` where E is the embedding dimension 5490 - w: :math:`(E * 3, E)` where E is the embedding dimension 5491 - b: :math:`E * 3` where E is the embedding dimension 5492 5493 Output: 5494 - in output list :math:`[q', k', v']`, each output tensor will have the 5495 same shape as the corresponding input tensor. 5496 """ 5497 E = q.size(-1) 5498 if k is v: 5499 if q is k: 5500 # self-attention 5501 proj = linear(q, w, b) 5502 # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() 5503 proj = ( 5504 proj.unflatten(-1, (3, E)) 5505 .unsqueeze(0) 5506 .transpose(0, -2) 5507 .squeeze(-2) 5508 .contiguous() 5509 ) 5510 return proj[0], proj[1], proj[2] 5511 else: 5512 # encoder-decoder attention 5513 w_q, w_kv = w.split([E, E * 2]) 5514 if b is None: 5515 b_q = b_kv = None 5516 else: 5517 b_q, b_kv = b.split([E, E * 2]) 5518 q_proj = linear(q, w_q, b_q) 5519 kv_proj = linear(k, w_kv, b_kv) 5520 # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() 5521 kv_proj = ( 5522 kv_proj.unflatten(-1, (2, E)) 5523 .unsqueeze(0) 5524 .transpose(0, -2) 5525 .squeeze(-2) 5526 .contiguous() 5527 ) 5528 return (q_proj, kv_proj[0], kv_proj[1]) 5529 else: 5530 w_q, w_k, w_v = w.chunk(3) 5531 if b is None: 5532 b_q = b_k = b_v = None 5533 else: 5534 b_q, b_k, b_v = b.chunk(3) 5535 return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) 5536 5537 5538def _in_projection( 5539 q: Tensor, 5540 k: Tensor, 5541 v: Tensor, 5542 w_q: Tensor, 5543 w_k: Tensor, 5544 w_v: Tensor, 5545 b_q: Optional[Tensor] = None, 5546 b_k: Optional[Tensor] = None, 5547 b_v: Optional[Tensor] = None, 5548) -> Tuple[Tensor, Tensor, Tensor]: 5549 r"""Perform the in-projection step of the attention operation. 5550 5551 This is simply a triple of linear projections, 5552 with shape constraints on the weights which 5553 ensure embedding dimension uniformity in the projected outputs. 5554 Output is a triple containing projection tensors for query, key and value. 5555 5556 Args: 5557 q, k, v: query, key and value tensors to be projected. 5558 w_q, w_k, w_v: weights for q, k and v, respectively. 5559 b_q, b_k, b_v: optional biases for q, k and v, respectively. 5560 5561 Shape: 5562 Inputs: 5563 - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any 5564 number of leading dimensions. 5565 - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any 5566 number of leading dimensions. 5567 - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any 5568 number of leading dimensions. 
5569 - w_q: :math:`(Eq, Eq)` 5570 - w_k: :math:`(Eq, Ek)` 5571 - w_v: :math:`(Eq, Ev)` 5572 - b_q: :math:`(Eq)` 5573 - b_k: :math:`(Eq)` 5574 - b_v: :math:`(Eq)` 5575 5576 Output: in output triple :math:`(q', k', v')`, 5577 - q': :math:`[Qdims..., Eq]` 5578 - k': :math:`[Kdims..., Eq]` 5579 - v': :math:`[Vdims..., Eq]` 5580 5581 """ 5582 Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1) 5583 assert w_q.shape == ( 5584 Eq, 5585 Eq, 5586 ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" 5587 assert w_k.shape == ( 5588 Eq, 5589 Ek, 5590 ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" 5591 assert w_v.shape == ( 5592 Eq, 5593 Ev, 5594 ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" 5595 assert b_q is None or b_q.shape == ( 5596 Eq, 5597 ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" 5598 assert b_k is None or b_k.shape == ( 5599 Eq, 5600 ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" 5601 assert b_v is None or b_v.shape == ( 5602 Eq, 5603 ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" 5604 return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) 5605 5606 5607 scaled_dot_product_attention = _add_docstr( 5608 torch._C._nn.scaled_dot_product_attention, 5609 r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, 5610 is_causal=False, scale=None, enable_gqa=False) -> Tensor 5611 5612 Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed, 5613 and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be 5614 specified as a keyword argument. 5615 5616 .. code-block:: python 5617 5618 # Efficient implementation equivalent to the following: 5619 def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, 5620 is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: 5621 L, S = query.size(-2), key.size(-2) 5622 scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale 5623 attn_bias = torch.zeros(L, S, dtype=query.dtype) 5624 if is_causal: 5625 assert attn_mask is None 5626 temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) 5627 attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 5628 attn_bias = attn_bias.to(query.dtype) 5629 5630 if attn_mask is not None: 5631 if attn_mask.dtype == torch.bool: 5632 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) 5633 else: 5634 attn_bias += attn_mask 5635 5636 if enable_gqa: 5637 key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) 5638 value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) 5639 5640 attn_weight = query @ key.transpose(-2, -1) * scale_factor 5641 attn_weight += attn_bias 5642 attn_weight = torch.softmax(attn_weight, dim=-1) 5643 attn_weight = torch.dropout(attn_weight, dropout_p, train=True) 5644 return attn_weight @ value 5645 5646 .. warning:: 5647 This function is beta and subject to change. 5648 5649 .. warning:: 5650 This function always applies dropout according to the specified ``dropout_p`` argument. 5651 To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module 5652 that makes the function call is not in training mode. 5653 5654 For example: 5655 5656 .. code-block:: python 5657 5658 class MyModel(nn.Module): 5659 def __init__(self, p=0.5): 5660 super().__init__() 5661 self.p = p 5662 5663 def forward(self, ...): 5664 return F.scaled_dot_product_attention(..., 5665 dropout_p=(self.p if self.training else 0.0)) 5666 5667 Note: 5668 5669 There are currently three supported implementations of scaled dot product attention: 5670 5671 - `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_ 5672 - `Memory-Efficient Attention`_ 5673 - A PyTorch implementation defined in C++ matching the above formulation 5674 5675 The function may call optimized kernels for improved performance when using the CUDA backend. 5676 For all other backends, the PyTorch implementation will be used. 5677 5678 All implementations are enabled by default. Scaled dot product attention attempts to automatically select the 5679 optimal implementation based on the inputs. In order to provide more fine-grained control over what implementation 5680 is used, the following functions are provided for enabling and disabling implementations. 5681 The context manager is the preferred mechanism: 5682 5683 - :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations. 5684 - :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention. 5685 - :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables Memory-Efficient Attention. 5686 - :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables the PyTorch C++ implementation. 5687 5688 Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation, 5689 disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`. 5690 In the event that a fused implementation is not available, a warning will be raised with the 5691 reasons why the fused implementation cannot run. 5692 5693 Due to the nature of fusing floating point operations, the output of this function may be different 5694 depending on what backend kernel is chosen. 5695 The C++ implementation supports torch.float64 and can be used when higher precision is required. 5696 For the math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16. 5697 For more information please see :doc:`/notes/numerical_accuracy` 5698 5699 Grouped Query Attention (GQA) is an experimental feature. It currently works only for the FlashAttention 5700 and math kernels on CUDA tensors, and does not support nested tensors. 5701 Constraints for GQA: 5702 5703 - number_of_heads_query % number_of_heads_key_value == 0 and, 5704 - number_of_heads_key == number_of_heads_value 5705 5706 Note: 5707 5708 {cudnn_reproducibility_note} 5709 """.format( 5710 **reproducibility_notes 5711 ) 5712 + r""" 5713 Args: 5714 query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`. 5715 key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`. 5716 value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`. 5717 attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, 5718 which is :math:`(N,..., L, S)`. Two types of masks are supported. 5719 A boolean mask where a value of True indicates that the element *should* take part in attention. 5720 A float mask of the same type as query, key, value that is added to the attention score.
5721 dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied 5722 is_causal (bool): If set to ``True``, the attention masking is a lower triangular matrix when the mask is a 5723 square matrix. The attention masking has the form of the upper left causal bias due to the alignment 5724 (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix. 5725 An error is thrown if both attn_mask and is_causal are set. 5726 scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set 5727 to :math:`\frac{1}{\sqrt{E}}`. 5728 enable_gqa (bool): If set to ``True``, Grouped Query Attention (GQA) is enabled. Default: ``False``. 5729 5730 Returns: 5731 output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`. 5732 5733 Shape legend: 5734 - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` 5735 - :math:`S: \text{Source sequence length}` 5736 - :math:`L: \text{Target sequence length}` 5737 - :math:`E: \text{Embedding dimension of the query and key}` 5738 - :math:`Ev: \text{Embedding dimension of the value}` 5739 - :math:`Hq: \text{Number of heads of query}` 5740 - :math:`H: \text{Number of heads of key and value}` 5741 5742 Examples: 5743 5744 >>> # Optionally use the context manager to ensure one of the fused kernels is run >>> from torch.nn.attention import sdpa_kernel, SDPBackend 5745 >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") 5746 >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") 5747 >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") 5748 >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 5749 ...     F.scaled_dot_product_attention(query, key, value) 5750 5751 5752 >>> # Sample call for GQA (e.g. Llama 3-style head counts) 5753 >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda") 5754 >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") 5755 >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") 5756 >>> with sdpa_kernel(backends=[SDPBackend.MATH]): 5757 ...     F.scaled_dot_product_attention(query, key, value, enable_gqa=True) 5758 5759 5760 .. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning: 5761 https://arxiv.org/abs/2307.08691 5762 .. _Memory-Efficient Attention: 5763 https://github.com/facebookresearch/xformers 5764 .. _Grouped-Query Attention: 5765 https://arxiv.org/pdf/2305.13245 5766 """, 5767 ) 5768 5769 5770 def _mha_shape_check( 5771 query: Tensor, 5772 key: Tensor, 5773 value: Tensor, 5774 key_padding_mask: Optional[Tensor], 5775 attn_mask: Optional[Tensor], 5776 num_heads: int, 5777 ): 5778 # Verifies the expected shape for `query`, `key`, `value`, `key_padding_mask` and `attn_mask` 5779 # and returns if the input is batched or not. 5780 # Raises an error if `query` is not a 2-D (unbatched) or 3-D (batched) tensor. 5781 5782 # Shape check.
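    # Summary of the checks below: a batched query is 3-D ((seq, batch, embed)
    # layout) and requires 3-D key/value, an optional 2-D key_padding_mask and
    # a 2-D or 3-D attn_mask; an unbatched query is 2-D and requires 2-D
    # key/value, an optional 1-D key_padding_mask and a 2-D or 3-D attn_mask
    # (a 3-D mask must then be per-head, i.e. (num_heads, tgt_len, src_len)).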
5783 if query.dim() == 3: 5784 # Batched Inputs 5785 is_batched = True 5786 assert key.dim() == 3 and value.dim() == 3, ( 5787 "For batched (3-D) `query`, expected `key` and `value` to be 3-D" 5788 f" but found {key.dim()}-D and {value.dim()}-D tensors respectively" 5789 ) 5790 if key_padding_mask is not None: 5791 assert key_padding_mask.dim() == 2, ( 5792 "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" 5793 f" but found {key_padding_mask.dim()}-D tensor instead" 5794 ) 5795 if attn_mask is not None: 5796 assert attn_mask.dim() in (2, 3), ( 5797 "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" 5798 f" but found {attn_mask.dim()}-D tensor instead" 5799 ) 5800 elif query.dim() == 2: 5801 # Unbatched Inputs 5802 is_batched = False 5803 assert key.dim() == 2 and value.dim() == 2, ( 5804 "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D" 5805 f" but found {key.dim()}-D and {value.dim()}-D tensors respectively" 5806 ) 5807 5808 if key_padding_mask is not None: 5809 assert key_padding_mask.dim() == 1, ( 5810 "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D" 5811 f" but found {key_padding_mask.dim()}-D tensor instead" 5812 ) 5813 5814 if attn_mask is not None: 5815 assert attn_mask.dim() in (2, 3), ( 5816 "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" 5817 f" but found {attn_mask.dim()}-D tensor instead" 5818 ) 5819 if attn_mask.dim() == 3: 5820 expected_shape = (num_heads, query.shape[0], key.shape[0]) 5821 assert ( 5822 attn_mask.shape == expected_shape 5823 ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}" 5824 else: 5825 raise AssertionError( 5826 f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor" 5827 ) 5828 5829 return is_batched 5830 5831 5832def _canonical_mask( 5833 mask: Optional[Tensor], 5834 mask_name: str, 5835 other_type: Optional[DType], 5836 other_name: str, 5837 target_type: DType, 5838 check_other: bool = True, 5839) -> Optional[Tensor]: 5840 if mask is not None: 5841 _mask_dtype = mask.dtype 5842 _mask_is_float = torch.is_floating_point(mask) 5843 if _mask_dtype != torch.bool and not _mask_is_float: 5844 raise AssertionError( 5845 f"only bool and floating types of {mask_name} are supported" 5846 ) 5847 if check_other and other_type is not None: 5848 if _mask_dtype != other_type: 5849 warnings.warn( 5850 f"Support for mismatched {mask_name} and {other_name} " 5851 "is deprecated. Use same type for both instead." 
5852 ) 5853 if not _mask_is_float: 5854 mask = torch.zeros_like(mask, dtype=target_type).masked_fill_( 5855 mask, float("-inf") 5856 ) 5857 return mask 5858 5859 5860 def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]: 5861 if input is None: 5862 return None 5863 elif isinstance(input, torch.Tensor): 5864 return input.dtype 5865 raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor") 5866 5867 5868 def multi_head_attention_forward( 5869 query: Tensor, 5870 key: Tensor, 5871 value: Tensor, 5872 embed_dim_to_check: int, 5873 num_heads: int, 5874 in_proj_weight: Optional[Tensor], 5875 in_proj_bias: Optional[Tensor], 5876 bias_k: Optional[Tensor], 5877 bias_v: Optional[Tensor], 5878 add_zero_attn: bool, 5879 dropout_p: float, 5880 out_proj_weight: Tensor, 5881 out_proj_bias: Optional[Tensor], 5882 training: bool = True, 5883 key_padding_mask: Optional[Tensor] = None, 5884 need_weights: bool = True, 5885 attn_mask: Optional[Tensor] = None, 5886 use_separate_proj_weight: bool = False, 5887 q_proj_weight: Optional[Tensor] = None, 5888 k_proj_weight: Optional[Tensor] = None, 5889 v_proj_weight: Optional[Tensor] = None, 5890 static_k: Optional[Tensor] = None, 5891 static_v: Optional[Tensor] = None, 5892 average_attn_weights: bool = True, 5893 is_causal: bool = False, 5894 ) -> Tuple[Tensor, Optional[Tensor]]: 5895 r"""Forward method for MultiHeadAttention. 5896 5897 See :class:`torch.nn.MultiheadAttention` for details. 5898 5899 Args: 5900 query, key, value: map a query and a set of key-value pairs to an output. 5901 See "Attention Is All You Need" for more details. 5902 embed_dim_to_check: total dimension of the model. 5903 num_heads: parallel attention heads. 5904 in_proj_weight, in_proj_bias: input projection weight and bias. 5905 bias_k, bias_v: bias of the key and value sequences to be added at dim=0. 5906 add_zero_attn: add a new batch of zeros to the key and 5907 value sequences at dim=1. 5908 dropout_p: probability of an element to be zeroed. 5909 out_proj_weight, out_proj_bias: the output projection weight and bias. 5910 training: apply dropout if ``True``. 5911 key_padding_mask: if provided, specified padding elements in the key will 5912 be ignored by the attention. This is a binary mask. When the value is True, 5913 the corresponding value on the attention layer will be filled with -inf. 5914 need_weights: output attn_output_weights. 5915 Default: `True` 5916 Note: `need_weights` defaults to `True`, but should be set to `False` 5917 for best performance when attention weights are not needed. 5918 *Setting need_weights to `True` 5919 leads to a significant performance degradation.* 5920 attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all 5921 the batches while a 3D mask allows specifying a different mask for the entries of each batch. 5922 is_causal: If specified, applies a causal mask as attention mask, and ignores 5923 attn_mask for computing scaled dot product attention. 5924 Default: ``False``. 5925 .. warning:: 5926 is_causal provides a hint that the attn_mask is the 5927 causal mask. Providing incorrect hints can result in 5928 incorrect execution, including forward and backward 5929 compatibility. 5930 use_separate_proj_weight: the function accepts the proj. weights for query, key, 5931 and value in different forms. If ``False``, in_proj_weight will be used, which is 5932 a combination of q_proj_weight, k_proj_weight, v_proj_weight. 5933 q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. 5934 static_k, static_v: static key and value used for attention operators. 5935 average_attn_weights: If ``True``, indicates that the returned ``attn_weights`` should be averaged across heads. 5936 Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect 5937 when ``need_weights=True``. Default: ``True`` 5938 5939 5940 Shape: 5941 Inputs: 5942 - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 5943 the embedding dimension. 5944 - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 5945 the embedding dimension. 5946 - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 5947 the embedding dimension. 5948 - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length. 5949 If a FloatTensor is provided, it will be directly added to the attention weight. 5950 If a BoolTensor is provided, the positions with the 5951 value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged. 5952 - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 5953 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, 5954 S is the source sequence length. attn_mask ensures that position ``i`` is allowed to attend only the unmasked 5955 positions. If a BoolTensor is provided, positions with ``True`` 5956 are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor 5957 is provided, it will be added to the attention weight. 5958 - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 5959 N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 5960 - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 5961 N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 5962 5963 Outputs: 5964 - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 5965 E is the embedding dimension. 5966 - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns 5967 attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or 5968 :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and 5969 :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per 5970 head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
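
    Example (a minimal, illustrative call using packed in-projection weights
    and no biases; all shapes are assumptions for this sketch only:
    ``L=2``, ``N=3``, ``E=16``, ``num_heads=4``)::

        >>> q = k = v = torch.randn(2, 3, 16)
        >>> w_in, w_out = torch.randn(48, 16), torch.randn(16, 16)
        >>> out, weights = F.multi_head_attention_forward(
        ...     q, k, v, 16, 4, w_in, None, None, None, False, 0.0, w_out, None)
        >>> out.shape, weights.shape
        (torch.Size([2, 3, 16]), torch.Size([3, 2, 2]))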
5971 """ 5972 tens_ops = ( 5973 query, 5974 key, 5975 value, 5976 in_proj_weight, 5977 in_proj_bias, 5978 bias_k, 5979 bias_v, 5980 out_proj_weight, 5981 out_proj_bias, 5982 ) 5983 if has_torch_function(tens_ops): 5984 return handle_torch_function( 5985 multi_head_attention_forward, 5986 tens_ops, 5987 query, 5988 key, 5989 value, 5990 embed_dim_to_check, 5991 num_heads, 5992 in_proj_weight, 5993 in_proj_bias, 5994 bias_k, 5995 bias_v, 5996 add_zero_attn, 5997 dropout_p, 5998 out_proj_weight, 5999 out_proj_bias, 6000 training=training, 6001 key_padding_mask=key_padding_mask, 6002 need_weights=need_weights, 6003 attn_mask=attn_mask, 6004 is_causal=is_causal, 6005 use_separate_proj_weight=use_separate_proj_weight, 6006 q_proj_weight=q_proj_weight, 6007 k_proj_weight=k_proj_weight, 6008 v_proj_weight=v_proj_weight, 6009 static_k=static_k, 6010 static_v=static_v, 6011 average_attn_weights=average_attn_weights, 6012 ) 6013 6014 is_batched = _mha_shape_check( 6015 query, key, value, key_padding_mask, attn_mask, num_heads 6016 ) 6017 6018 # For unbatched input, we unsqueeze at the expected batch dim to pretend that the input 6019 # is batched, run the computation, and squeeze the batch dimension before returning 6020 # so that the output doesn't carry this temporary batch dimension. 6021 if not is_batched: 6022 # unsqueeze if the input is unbatched 6023 query = query.unsqueeze(1) 6024 key = key.unsqueeze(1) 6025 value = value.unsqueeze(1) 6026 if key_padding_mask is not None: 6027 key_padding_mask = key_padding_mask.unsqueeze(0) 6028 6029 # set up shape vars 6030 tgt_len, bsz, embed_dim = query.shape 6031 src_len, _, _ = key.shape 6032 6033 key_padding_mask = _canonical_mask( 6034 mask=key_padding_mask, 6035 mask_name="key_padding_mask", 6036 other_type=_none_or_dtype(attn_mask), 6037 other_name="attn_mask", 6038 target_type=query.dtype, 6039 ) 6040 6041 if is_causal and attn_mask is None: 6042 raise RuntimeError( 6043 "Need attn_mask if specifying the is_causal hint. " 6044 "You may use the Transformer module method " 6045 "`generate_square_subsequent_mask` to create this mask." 6046 ) 6047 6048 if is_causal and key_padding_mask is None and not need_weights: 6049 # When we have a kpm or need weights, we need attn_mask. 6050 # Otherwise, we pass the is_causal hint straight through 6051 # to SDPA as the is_causal indicator. 6052 attn_mask = None 6053 else: 6054 attn_mask = _canonical_mask( 6055 mask=attn_mask, 6056 mask_name="attn_mask", 6057 other_type=None, 6058 other_name="", 6059 target_type=query.dtype, 6060 check_other=False, 6061 ) 6062 6063 if key_padding_mask is not None: 6064 # We have the attn_mask, and use that to merge kpm into it. 6065 # Turn off use of is_causal hint, as the merged mask is no 6066 # longer causal.
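                # (For example, a causal attn_mask merged with key padding
                # columns is no longer lower-triangular, so SDPA must receive
                # the materialized mask rather than the is_causal hint.)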
6067 is_causal = False 6068 6069 assert ( 6070 embed_dim == embed_dim_to_check 6071 ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" 6072 if isinstance(embed_dim, torch.Tensor): 6073 # embed_dim can be a tensor when JIT tracing 6074 head_dim = embed_dim.div(num_heads, rounding_mode="trunc") 6075 else: 6076 head_dim = embed_dim // num_heads 6077 assert ( 6078 head_dim * num_heads == embed_dim 6079 ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" 6080 if use_separate_proj_weight: 6081 # allow MHA to have different embedding dimensions when separate projection weights are used 6082 assert ( 6083 key.shape[:2] == value.shape[:2] 6084 ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" 6085 else: 6086 assert ( 6087 key.shape == value.shape 6088 ), f"key shape {key.shape} does not match value shape {value.shape}" 6089 6090 # 6091 # compute in-projection 6092 # 6093 if not use_separate_proj_weight: 6094 assert ( 6095 in_proj_weight is not None 6096 ), "use_separate_proj_weight is False but in_proj_weight is None" 6097 q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) 6098 else: 6099 assert ( 6100 q_proj_weight is not None 6101 ), "use_separate_proj_weight is True but q_proj_weight is None" 6102 assert ( 6103 k_proj_weight is not None 6104 ), "use_separate_proj_weight is True but k_proj_weight is None" 6105 assert ( 6106 v_proj_weight is not None 6107 ), "use_separate_proj_weight is True but v_proj_weight is None" 6108 if in_proj_bias is None: 6109 b_q = b_k = b_v = None 6110 else: 6111 b_q, b_k, b_v = in_proj_bias.chunk(3) 6112 q, k, v = _in_projection( 6113 query, 6114 key, 6115 value, 6116 q_proj_weight, 6117 k_proj_weight, 6118 v_proj_weight, 6119 b_q, 6120 b_k, 6121 b_v, 6122 ) 6123 6124 # prep attention mask 6125 6126 if attn_mask is not None: 6127 # ensure attn_mask's dim is 3 6128 if attn_mask.dim() == 2: 6129 correct_2d_size = (tgt_len, src_len) 6130 if attn_mask.shape != correct_2d_size: 6131 raise RuntimeError( 6132 f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." 6133 ) 6134 attn_mask = attn_mask.unsqueeze(0) 6135 elif attn_mask.dim() == 3: 6136 correct_3d_size = (bsz * num_heads, tgt_len, src_len) 6137 if attn_mask.shape != correct_3d_size: 6138 raise RuntimeError( 6139 f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." 6140 ) 6141 else: 6142 raise RuntimeError( 6143 f"attn_mask's dimension {attn_mask.dim()} is not supported" 6144 ) 6145 6146 # add bias along batch dimension (currently second) 6147 if bias_k is not None and bias_v is not None: 6148 assert static_k is None, "bias cannot be added to static key." 6149 assert static_v is None, "bias cannot be added to static value." 
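        # bias_k / bias_v each have shape (1, 1, embed_dim); repeating along the
        # batch dimension appends one extra key/value position to every sequence,
        # which is why attn_mask / key_padding_mask are widened by one column below.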
6150 k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 6151 v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 6152 if attn_mask is not None: 6153 attn_mask = pad(attn_mask, (0, 1)) 6154 if key_padding_mask is not None: 6155 key_padding_mask = pad(key_padding_mask, (0, 1)) 6156 else: 6157 assert bias_k is None 6158 assert bias_v is None 6159 6160 # 6161 # reshape q, k, v for multihead attention and make them batch first 6162 # 6163 q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 6164 if static_k is None: 6165 k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) 6166 else: 6167 # TODO finish disentangling control flow so we don't do in-projections when statics are passed 6168 assert ( 6169 static_k.size(0) == bsz * num_heads 6170 ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" 6171 assert ( 6172 static_k.size(2) == head_dim 6173 ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" 6174 k = static_k 6175 if static_v is None: 6176 v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) 6177 else: 6178 # TODO finish disentangling control flow so we don't do in-projections when statics are passed 6179 assert ( 6180 static_v.size(0) == bsz * num_heads 6181 ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" 6182 assert ( 6183 static_v.size(2) == head_dim 6184 ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" 6185 v = static_v 6186 6187 # add zero attention along batch dimension (now first) 6188 if add_zero_attn: 6189 zero_attn_shape = (bsz * num_heads, 1, head_dim) 6190 k = torch.cat( 6191 [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1 6192 ) 6193 v = torch.cat( 6194 [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1 6195 ) 6196 if attn_mask is not None: 6197 attn_mask = pad(attn_mask, (0, 1)) 6198 if key_padding_mask is not None: 6199 key_padding_mask = pad(key_padding_mask, (0, 1)) 6200 6201 # update source sequence length after adjustments 6202 src_len = k.size(1) 6203 6204 # merge key padding and attention masks 6205 if key_padding_mask is not None: 6206 assert key_padding_mask.shape == ( 6207 bsz, 6208 src_len, 6209 ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" 6210 key_padding_mask = ( 6211 key_padding_mask.view(bsz, 1, 1, src_len) 6212 .expand(-1, num_heads, -1, -1) 6213 .reshape(bsz * num_heads, 1, src_len) 6214 ) 6215 if attn_mask is None: 6216 attn_mask = key_padding_mask 6217 else: 6218 attn_mask = attn_mask + key_padding_mask 6219 6220 # adjust dropout probability 6221 if not training: 6222 dropout_p = 0.0 6223 6224 # 6225 # (deep breath) calculate attention and out projection 6226 # 6227 6228 if need_weights: 6229 B, Nt, E = q.shape 6230 q_scaled = q * math.sqrt(1.0 / float(E)) 6231 6232 assert not ( 6233 is_causal and attn_mask is None 6234 ), "FIXME: is_causal not implemented for need_weights" 6235 6236 if attn_mask is not None: 6237 attn_output_weights = torch.baddbmm( 6238 attn_mask, q_scaled, k.transpose(-2, -1) 6239 ) 6240 else: 6241 attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) 6242 attn_output_weights = softmax(attn_output_weights, dim=-1) 6243 if dropout_p > 0.0: 6244 attn_output_weights = dropout(attn_output_weights, p=dropout_p) 6245 6246 attn_output = torch.bmm(attn_output_weights, v) 6247 6248 attn_output = ( 6249 attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) 6250 ) 6251 attn_output = 
linear(attn_output, out_proj_weight, out_proj_bias) 6252 attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) 6253 6254 # optionally average attention weights over heads 6255 attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 6256 if average_attn_weights: 6257 attn_output_weights = attn_output_weights.mean(dim=1) 6258 6259 if not is_batched: 6260 # squeeze the output if input was unbatched 6261 attn_output = attn_output.squeeze(1) 6262 attn_output_weights = attn_output_weights.squeeze(0) 6263 return attn_output, attn_output_weights 6264 else: 6265 # attn_mask can be either (L,S) or (N*num_heads, L, S) 6266 # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S) 6267 # in order to match the input for SDPA of (N, num_heads, L, S) 6268 if attn_mask is not None: 6269 if attn_mask.size(0) == 1 and attn_mask.dim() == 3: 6270 attn_mask = attn_mask.unsqueeze(0) 6271 else: 6272 attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) 6273 6274 q = q.view(bsz, num_heads, tgt_len, head_dim) 6275 k = k.view(bsz, num_heads, src_len, head_dim) 6276 v = v.view(bsz, num_heads, src_len, head_dim) 6277 6278 attn_output = scaled_dot_product_attention( 6279 q, k, v, attn_mask, dropout_p, is_causal 6280 ) 6281 attn_output = ( 6282 attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) 6283 ) 6284 6285 attn_output = linear(attn_output, out_proj_weight, out_proj_bias) 6286 attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) 6287 if not is_batched: 6288 # squeeze the output if input was unbatched 6289 attn_output = attn_output.squeeze(1) 6290 return attn_output, None 6291