# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 9.

Opset 9 is supported by ONNX release 1.4.1,
released on 01/23/19.
"""

from __future__ import annotations

import builtins
import functools
import math
import sys
import warnings
from typing import Callable, Sequence, TYPE_CHECKING

import torch
import torch._C._onnx as _C_onnx
import torch.nn.modules.utils
import torch.onnx
from torch import _C

# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration


if TYPE_CHECKING:
    from torch.types import Number

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

__all__ = [
    "abs",
    "acos",
    "add",
    "addcmul",
    "addmm",
    "alias",
    "amax",
    "amin",
    "aminmax",
    "arange",
    "argmax",
    "argmin",
    "as_strided",
    "as_tensor",
    "asin",
    "atan",
    "atan2",
    "baddbmm",
    "batch_norm",
    "bernoulli",
    "bitwise_not",
    "bitwise_or",
    "bmm",
    "broadcast_tensors",
    "broadcast_to",
    "bucketize",
    "cat",
    "cdist",
    "ceil",
    "clamp_max",
    "clamp_min",
    "clamp",
    "clone",
    "constant_pad_nd",
    "contiguous",
    "conv_tbc",
    "conv_transpose1d",
    "conv_transpose2d",
    "conv_transpose3d",
    "conv1d",
    "conv2d",
    "conv3d",
    "convert_element_type",
    "convolution",
    "cos",
    "cosine_similarity",
    "cross",
    "cumsum",
    "detach",
    "dim",
    "div",
    "dot",
    "dropout",
    "elu",
    "embedding_bag",
    "embedding",
    "empty_like",
    "empty",
    "eq",
    "erf",
    "exp",
    "expand_as",
    "expand",
    "eye",
    "fill",
    "flatten",
    "floor_divide",
    "floor",
    "floordiv",
    "frobenius_norm",
    "full_like",
    "full",
    "gather",
    "ge",
    "gelu",
    "get_pool_ceil_padding",
    "glu",
    "group_norm",
    "gt",
    "hann_window",
    "hardshrink",
    "hardsigmoid",
    "hardswish",
    "hardtanh",
    "index_add",
    "index_copy",
    "index_fill",
    "index_put",
    "index_select",
    "index",
    "instance_norm",
    "is_floating_point",
    "is_pinned",
    "isnan",
    "item",
    "kl_div",
    "layer_norm",
    "le",
    "leaky_relu",
    "lerp",
    "lift",
    "linalg_cross",
    "linalg_matrix_norm",
    "linalg_norm",
    "linalg_vector_norm",
    "linear",
    "linspace",
    "log_sigmoid",
    "log_softmax",
    "log",
    "log10",
    "log1p",
    "log2",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "logit",
    "logsumexp",
    "lstm_cell",
    "lstm",
    "lt",
    "masked_fill",
    "masked_fill_",
    "matmul",
    "max_pool1d_with_indices",
    "max_pool2d_with_indices",
    "max_pool3d_with_indices",
    "max",
    "maximum",
    "meshgrid",
    "min",
    "minimum",
    "mish",
    "mm",
    "movedim",
    "mse_loss",
    "mul",
    "multinomial",
    "mv",
    "narrow",
    "native_layer_norm",
    "ne",
    "neg",
    "new_empty",
    "new_full",
    "new_ones",
    "new_zeros",
    "nonzero_numpy",
    "nonzero",
    "norm",
    "numel",
    "numpy_T",
    "one_hot",
    "ones_like",
    "ones",
    "onnx_placeholder",
    "pad",
    "pairwise_distance",
    "permute",
    "pixel_shuffle",
    "pixel_unshuffle",
    "pow",
    "prelu",
    "prim_constant_chunk",
    "prim_constant_split",
    "prim_constant",
    "prim_data",
    "prim_device",
    "prim_dtype",
    "prim_if",
    "prim_layout",
    "prim_list_construct",
    "prim_list_unpack",
    "prim_loop",
    "prim_max",
    "prim_min",
    "prim_shape",
    "prim_tolist",
    "prim_tuple_construct",
    "prim_type",
    "prim_unchecked_cast",
    "prim_uninitialized",
    "rand_like",
    "rand",
    "randint_like",
    "randint",
    "randn_like",
    "randn",
    "reciprocal",
    "reflection_pad",
    "relu",
    "relu6",
    "remainder",
    "repeat_interleave",
    "repeat",
    "replication_pad",
    "reshape_as",
    "reshape",
    "roll",
    "rrelu",
    "rsqrt",
    "rsub",
    "scalar_tensor",
    "scatter_add",
    "scatter",
    "select",
    "selu",
    "sigmoid",
    "sign",
    "silu",
    "sin",
    "size",
    "slice",
    "softmax",
    "softplus",
    "softshrink",
    "sort",
    "split_with_sizes",
    "split",
    "sqrt",
    "square",
    "squeeze",
    "stack",
    "std_mean",
    "std",
    "sub",
    "t",
    "take",
    "tan",
    "tanh",
    "tanhshrink",
    "tensor",
    "threshold",
    "to",
    "topk",
    "transpose",
    "true_divide",
    "type_as",
    "unbind",
    "unfold",
    "unsafe_chunk",
    "unsafe_split_with_sizes",
    "unsafe_split",
    "unsqueeze",
    "unsupported_complex_operators",
    "noop_complex_operators",
    "unused",
    "var_mean",
    "var",
    "view_as",
    "view",
    "where",
    "wrap_logical_op_with_cast_to",
    "wrap_logical_op_with_negation",
    "zeros_like",
    "zeros",
    "zero",
]


_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)


def _export(name: str):
    """Exports the function in the current global namespace."""

    def wrapper(func):
        globals()[name] = func
        __all__.append(name)
        return func

    return wrapper


def unused(g):
    """Represents "missing" optional inputs."""
    n = g.op("prim::Constant")
    n.setType(_C.OptionalType.ofTensor())
    return n


@_onnx_symbolic("aten::_shape_as_tensor")
def _shape_as_tensor(g: jit_utils.GraphContext, input):
    return g.op("Shape", input)


@_onnx_symbolic("aten::_reshape_from_tensor")
def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape):
    if isinstance(shape, list):
        shape = g.op("Concat", *shape, axis_i=0)
    return reshape(g, input, shape)


@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
def reshape(g: jit_utils.GraphContext, self, shape):
    return symbolic_helper._reshape_helper(g, self, shape)


@_onnx_symbolic("aten::reshape_as")
@symbolic_helper.quantized_args(True)
def reshape_as(g: jit_utils.GraphContext, self, other):
    shape = g.op("Shape", other)
    return reshape(g, self, shape)


@_onnx_symbolic("aten::add")
def add(g: jit_utils.GraphContext, self, other, alpha=None):
    """
    Consumes the add function and returns the corresponding ONNX operator.

    This function is not meant to be called directly by the user.

    Args:
        g (GraphContext): The graph context.
        self (Tensor): The first operand.
        other (Tensor): The second operand.
        alpha (float, optional): The scaling factor for the second operand. Defaults to None.

    Returns:
        ONNX operator.
    """
    if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Add", 9, 11, "Add between list of tensors not supported", self
        )
    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
        other = g.op("Mul", other, alpha)
    return g.op("Add", self, other)


@_onnx_symbolic("aten::sub")
def sub(g: jit_utils.GraphContext, self, other, alpha=None):
    """
    Consumes the sub function and returns the corresponding ONNX operator.

    This function is not meant to be called directly by the user.

    Args:
        g (GraphContext): The graph context.
        self (Tensor): The first operand.
        other (Tensor): The second operand.
        alpha (Optional[Tensor]): A scaling factor to apply to the second operand.
            If `alpha` is not provided, it defaults to 1.

    Returns:
        ONNX operator
    """
    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
        other = g.op("Mul", other, alpha)
    return g.op("Sub", self, other)


@_onnx_symbolic("aten::rsub")
def rsub(g: jit_utils.GraphContext, self, other, alpha=None):
    return sub(g, other, self, alpha=alpha)


@_onnx_symbolic("aten::mul")
def mul(g: jit_utils.GraphContext, self, other):
    if symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other):
        # ONNX Mul doesn't support Boolean, so use And as an equivalent operator.
        return g.op("And", self, other)
    else:
        return g.op("Mul", self, other)


@_onnx_symbolic("aten::div")
def div(g: jit_utils.GraphContext, self, other, *args):
    if len(args) == 0:
        return true_divide(g, self, other)
    else:
        return _div_rounding_mode(g, self, other, *args)


@_onnx_symbolic("aten::addcmul")
@symbolic_helper.parse_args("v", "v", "v", "f")
def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0):
    value_tens = g.op("Constant", value_t=torch.tensor([value]))
    return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens))


@symbolic_helper.parse_args("v", "v", "s")
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    if rounding_mode is None:
        return true_divide(g, self, other)
    elif rounding_mode == "floor":
        return _floor_divide(g, self, other)
    elif rounding_mode == "trunc":
        return _trunc_divide(g, self, other)
    else:
        raise errors.SymbolicValueError(
            f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"',
            self,
        )


def _trunc_divide(g: jit_utils.GraphContext, self, other):
    out = g.op("Div", self, other)
    # The correct operation is truncate, which is not supported in ONNX.
    # We cannot call Floor since it behaves differently for negative numbers
    # (e.g. -0.1 should become -0).
    # - if scalar_type information is not available, assume that
    #   we need to call floor (treat as float)
    out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.INT64)

    # Matching PyTorch's behavior:
    # - if self is fp, the output's type is self's type
    # - if self is not fp and other is fp, the output is of type JitScalarType.FLOAT
    # - if self is not fp and other is not fp, the output's type is self's type
    # - the output type defaults to Float
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if scalar_type != _type_utils.JitScalarType.UNDEFINED:
        if not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other):
            out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        else:
            out = g.op(
                "Cast",
                out,
                to_i=scalar_type.onnx_type(),
            )
    else:
        out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return out


def _floor_divide(g: jit_utils.GraphContext, self, other):
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        out = true_divide(g, self, other)
        return g.op("Floor", out)
    else:
        # Integer division does truncation rounding
        div = g.op("Div", self, other)
        # Division is negative if: self < 0 != other < 0
        zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
        negative = g.op(
            "Xor",
            symbolic_helper._lt_helper(g, self, zero),
            symbolic_helper._lt_helper(g, other, zero),
        )

        # For negative numbers with self % other != 0, subtract 1 to round down instead of up
        mod = g.op("Sub", self, g.op("Mul", div, other))
        fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))

        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        fixup = g.op("Mul", fixup_mask, one)
        return g.op("Sub", div, fixup)

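
# Worked example for the integer branch of _floor_divide above (values chosen
# for illustration only): with self = -7 and other = 2, ONNX "Div" truncates
# toward zero and yields -3; mod = -7 - (-3 * 2) = -1 is non-zero and the
# operands have opposite signs, so the fixup subtracts 1, giving -4 == floor(-7 / 2).
# When the division is exact or both operands share a sign, the fixup mask is 0
# and the truncated quotient is returned unchanged.
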
@_onnx_symbolic("aten::floor_divide")
def floor_divide(g: jit_utils.GraphContext, self, other):
    # Deprecated behavior, floor_divide actually truncates
    return _trunc_divide(g, self, other)


@_onnx_symbolic("aten::floordiv")
def floordiv(g: jit_utils.GraphContext, self, other):
    return floor_divide(g, self, other)


@_onnx_symbolic("aten::true_divide")
def true_divide(g: jit_utils.GraphContext, self, other):
    """Division where both inputs are cast to floating types.

    If both inputs are floating, performs div as usual.
    If only one input is a floating type, the other input is cast to its type.
    If neither input is a floating type, both inputs are cast to the default scalar type.
    """

    # Case 1: either value is floating.
    # Performs div as usual.
    # Implicit casting will be handled in the scalar type analysis pass.
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        return g.op("Div", self, other)

    # Case 2: neither is floating.
    # Casts both inputs to the default scalar type.
    scalar_type = torch.get_default_dtype()
    onnx_scalar_type = _C_onnx.TensorProtoDataType.FLOAT
    assert scalar_type is torch.float or scalar_type is torch.double
    if torch.get_default_dtype() is torch.double:
        onnx_scalar_type = _C_onnx.TensorProtoDataType.DOUBLE

    self = g.op("Cast", self, to_i=onnx_scalar_type)
    other = g.op("Cast", other, to_i=onnx_scalar_type)
    return g.op("Div", self, other)


@_onnx_symbolic("aten::reciprocal")
def reciprocal(g: jit_utils.GraphContext, self):
    # torch.reciprocal implicitly casts to float, so we do the same.
    if not symbolic_helper._is_fp(self):
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return g.op("Reciprocal", self)


@_onnx_symbolic("aten::cat")
@symbolic_helper.parse_args("v", "i")
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    """Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension.

    Parameters:
        g (jit_utils.GraphContext): Graph context.
        tensor_list (List[torch.Tensor]): List of tensors to concatenate.
        dim (int): Dimension along which to concatenate the tensors.

    Returns:
        ONNX graph node representing the concatenated tensor.
    """
    tensors = symbolic_helper._unpack_list(tensor_list)
    # torch.cat ignores empty tensors such as `torch.Tensor([])`.
    # These need to be removed as inputs to ONNX's Concat too, otherwise shape inference
    # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else).
    nonempty_tensors = []
    for t in tensors:
        if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size(
            t, 0
        ):
            continue
        nonempty_tensors.append(t)
    assert len(nonempty_tensors) > 0
    assert all(
        symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None
        or symbolic_helper._get_tensor_rank(t) is None
        or symbolic_helper._get_tensor_rank(t)
        == symbolic_helper._get_tensor_rank(nonempty_tensors[0])
        for t in nonempty_tensors
    )
    tensor_list.node().removeAllInputs()
    for t in nonempty_tensors:
        tensor_list.node().addInput(t)

    tensors = symbolic_helper._unpack_list(tensor_list)
    return g.op("Concat", *tensors, axis_i=dim)


@_onnx_symbolic("aten::stack")
@symbolic_helper.parse_args("v", "i")
def stack(g: jit_utils.GraphContext, tensor_list, dim):
    unsqueezed = [
        symbolic_helper._unsqueeze_helper(g, t, [dim])
        for t in symbolic_helper._unpack_list(tensor_list)
    ]
    return g.op("Concat", *unsqueezed, axis_i=dim)


@_onnx_symbolic("aten::list")
def _list(g: jit_utils.GraphContext, self):
    return self


@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    # Create a dummy C tensor. Only needed for API purposes; the value is
    # unused since beta = 0.
    C = g.op("Constant", value_t=torch.tensor([1]))
    return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)


@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
    return g.op("MatMul", self, other)


@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
    return g.op("MatMul", self, other)


@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    scalar_type = None
    self_scalar_type = symbolic_helper._try_get_scalar_type(self)
    mat1_scalar_type = symbolic_helper._try_get_scalar_type(mat1)
    mat2_scalar_type = symbolic_helper._try_get_scalar_type(mat2)
    if self_scalar_type is not None:
        scalar_type = self_scalar_type
    elif mat1_scalar_type is not None:
        scalar_type = mat1_scalar_type
    elif mat2_scalar_type is not None:
        scalar_type = mat2_scalar_type

    mat1_rank = symbolic_helper._get_tensor_rank(mat1)
    mat2_rank = symbolic_helper._get_tensor_rank(mat2)

    def is_not_none_nor(v, u):
        return v is not None and v != u

    if scalar_type is not None and (
        is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2)
    ):
        res1 = g.op("MatMul", mat1, mat2)
        res2 = self

        alpha = symbolic_helper._scalar(alpha)
        beta = symbolic_helper._scalar(beta)

        if alpha != 1:
            alpha = g.op(
                "Constant", value_t=torch.tensor(alpha, dtype=scalar_type.dtype())
            )
            res1 = g.op("Mul", res1, alpha)
        if beta != 1:
            beta = g.op(
                "Constant",
                value_t=torch.tensor(
                    symbolic_helper._scalar(beta), dtype=scalar_type.dtype()
                ),
            )
            res2 = g.op("Mul", res2, beta)

        return g.op("Add", res1, res2)

    return g.op(
        "Gemm",
        mat1,
        mat2,
        self,
        beta_f=symbolic_helper._scalar(beta),
        alpha_f=symbolic_helper._scalar(alpha),
    )


@_onnx_symbolic("aten::neg")
def neg(g: jit_utils.GraphContext, self):
    return g.op("Neg", self)


@_onnx_symbolic("aten::sqrt")
def sqrt(g: jit_utils.GraphContext, self):
    if _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    ) in {
        _type_utils.JitScalarType.UINT8,
        _type_utils.JitScalarType.INT8,
        _type_utils.JitScalarType.INT16,
        _type_utils.JitScalarType.INT,
        _type_utils.JitScalarType.INT64,
    }:
        # torch converts all int inputs to sqrt to float
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)

    return g.op("Sqrt", self)


@_onnx_symbolic("aten::rsqrt")
def rsqrt(g: jit_utils.GraphContext, self):
    return g.op(
        "Div", symbolic_helper._if_scalar_type_as(torch.ones(1), self), sqrt(g, self)
    )


@_onnx_symbolic("aten::tanh")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp
@symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128)
def tanh(g: jit_utils.GraphContext, self):
    return g.op("Tanh", self)


@_onnx_symbolic("aten::sin")
def sin(g: jit_utils.GraphContext, self):
    return g.op("Sin", self)


@_onnx_symbolic("aten::cos")
def cos(g: jit_utils.GraphContext, self):
    return g.op("Cos", self)


@_onnx_symbolic("aten::tan")
def tan(g: jit_utils.GraphContext, self):
    return g.op("Tan", self)


@_onnx_symbolic("aten::asin")
def asin(g: jit_utils.GraphContext, self):
    return g.op("Asin", self)


@_onnx_symbolic("aten::acos")
def acos(g: jit_utils.GraphContext, self):
    return g.op("Acos", self)


@_onnx_symbolic("aten::atan")
def atan(g: jit_utils.GraphContext, self):
    return g.op("Atan", self)


@_onnx_symbolic("aten::atan2")
def atan2(g: jit_utils.GraphContext, self, other):
    # self is y and other is x in the Cartesian plane
    slope = g.op("Div", self, other)
    atan = g.op("Atan", slope)
    const_zero = g.op("Constant", value_t=torch.tensor(0))
    const_pi = g.op("Constant", value_t=torch.tensor(math.pi))

    condition_second_or_third_quadrant = g.op("Greater", self, const_zero)
    second_third_quadrant = g.op(
        "Where",
        condition_second_or_third_quadrant,
        g.op("Add", atan, const_pi),
        g.op("Sub", atan, const_pi),
    )

    condition_14_or_23_quadrant = g.op("Less", other, const_zero)
    result = g.op("Where", condition_14_or_23_quadrant, second_third_quadrant, atan)

    return result

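
# Illustrative check of the quadrant handling above: for self (y) = 1 and
# other (x) = -1, Atan(y / x) gives -pi/4; since x < 0 and y > 0, the "Where"
# nodes select atan + pi, i.e. 3*pi/4, which matches torch.atan2(1., -1.).
# When x >= 0, the raw Atan result is kept as-is.
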
@_onnx_symbolic("aten::sigmoid")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
def sigmoid(g: jit_utils.GraphContext, self):
    """Converts the corresponding PyTorch function into ONNX operators.

    It is not meant to be called directly by a user.

    Args:
        g (jit_utils.GraphContext): Graph context.
        self (Tensor): the input tensor.
    Returns:
        ONNX operator
    """
    return g.op("Sigmoid", self)


@_onnx_symbolic("aten::sign")
def sign(g: jit_utils.GraphContext, self):
    return g.op("Sign", self)


@symbolic_helper.quantized_args(True)
def _slice(g: jit_utils.GraphContext, input, axes, starts, ends):
    assert len(starts) == len(ends)
    if len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX:
        return input
    return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)


@_onnx_symbolic(
    "aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")]
)
@_onnx_symbolic(
    "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")]
)
# torch.prod does not support multidimensional "dim"
@_onnx_symbolic(
    "aten::prod",
    decorate=[
        symbolic_helper._apply_params(
            "ReduceProd", "prod", allow_multi_dim_support=False
        )
    ],
)
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
    return symbolic_helper._reduce_with_dtype_helper(
        onnx_op, name, allow_multi_dim_support
    )


@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
def cumsum(g: jit_utils.GraphContext, input, dim, dtype):
    symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input)


@_onnx_symbolic("aten::_sample_dirichlet")
def _sample_dirichlet(g: jit_utils.GraphContext, self, generator):
    return symbolic_helper._onnx_unsupported("_sample_dirichlet", self)


@_onnx_symbolic("aten::_standard_gamma")
def _standard_gamma(g: jit_utils.GraphContext, self, generator):
    return symbolic_helper._onnx_unsupported("_standard_gamma", self)


@_onnx_symbolic("aten::t")
def t(g: jit_utils.GraphContext, self):
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None or rank < 2:
        # The transpose of a 1d or 0d tensor is itself. ONNX does not define the behavior
        # clearly and onnxruntime fails on these cases. So we add an Identity node to
        # mirror the behavior of eager mode.
        return g.op("Identity", self)
    return g.op("Transpose", self, perm_i=(1, 0))


@_onnx_symbolic("aten::numpy_T")
@symbolic_helper.quantized_args(True)
def numpy_T(g: jit_utils.GraphContext, input):
    ndim = symbolic_helper._get_tensor_rank(input)
    assert ndim is not None
    perm = list(reversed(range(0, ndim)))
    return g.op("Transpose", input, perm_i=perm)


@_onnx_symbolic("aten::expand")
@symbolic_helper.quantized_args(True)
def expand(g: jit_utils.GraphContext, self, size, implicit):
    """Implement the expand function for a pytorch tensor in ONNX according to specified `size`"""
    size = symbolic_helper._maybe_get_const(size, "is")
    if not symbolic_helper._is_value(size):
        size = g.op("Constant", value_t=torch.LongTensor(size))
    elif symbolic_helper._is_packed_list(size):
        # Expand with -1 dim value means dim is unchanged.
        # Since onnx::expand supports two-way broadcasting,
        # -1 dim value can be exported to onnx as 1
        size = symbolic_helper._reshape_helper(
            g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1]))
        )
    dtype = _type_utils.JitScalarType.INT64
    ones = ones_like(g, size, dtype)
    neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
    size = where(g, g.op("Equal", size, neg_ones), ones, size)
    return g.op("Expand", self, size)


@_onnx_symbolic("aten::broadcast_to")
@symbolic_helper.quantized_args(True)
def broadcast_to(g: jit_utils.GraphContext, self, size):
    size = symbolic_helper._maybe_get_const(size, "is")
    if not symbolic_helper._is_value(size):
        size = g.op("Constant", value_t=torch.LongTensor(size))
    elif symbolic_helper._is_packed_list(size):
        # Expand with -1 dim value means dim is unchanged.
        # Since onnx::expand supports two-way broadcasting,
        # -1 dim value can be exported to onnx as 1
        size = symbolic_helper._reshape_helper(
            g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1]))
        )
    dtype = _type_utils.JitScalarType.INT64
    ones = ones_like(g, size, dtype)
    neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
    size = where(g, g.op("Equal", size, neg_ones), ones, size)
    return g.op("Expand", self, size)


@_onnx_symbolic("aten::expand_as")
@symbolic_helper.quantized_args(True, True)
def expand_as(g: jit_utils.GraphContext, self, other):
    self_t = symbolic_helper._maybe_get_const(self, "t")
    if isinstance(self_t, torch.Tensor):
        orig_type = self_t.dtype
        self_t = self_t.to(torch.double)
        dims = []
        for d in range(self_t.dim()):
            if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t):
                dims.append(d)
                self = g.op(
                    "Constant", value_t=self_t.mean(dims, keepdim=True).to(orig_type)
                )

    shape = g.op("Shape", other)
    return g.op("Expand", self, shape)


@_onnx_symbolic("aten::embedding")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i", "b", "v")
def embedding(
    g: jit_utils.GraphContext,
    weight,
    indices,
    padding_idx,
    scale_grad_by_freq,
    sparse,
):
    if scale_grad_by_freq and GLOBALS.export_training:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of embedding with scale_grad_by_freq=True "
            "for training mode. ONNX does not support scaling the gradients.",
            weight,
        )
    if padding_idx >= 0 and GLOBALS.export_training:
        warnings.warn(
            "Warning: ONNX export of embedding with padding_idx >= 0 "
            "for training mode. "
            "ONNX does not support not updating the embedding vector at padding_idx during training."
        )

    return g.op("Gather", weight, indices)


@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    if not symbolic_helper._is_none(per_sample_weights):
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with per_sample_weights"
        )

    return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix)


@_onnx_symbolic("aten::size")
@symbolic_helper.quantized_args(True, quantize_output=False)
def size(g: jit_utils.GraphContext, self, dim=None):
    if dim is None:
        return g.op("Shape", self)
    if symbolic_helper._maybe_get_const(dim, "i") < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            dim = symbolic_helper._maybe_get_const(dim, "i") + rank
            dim = g.op("Constant", value_t=torch.tensor(dim))
    return symbolic_helper._size_helper(g, self, dim)


@_onnx_symbolic("aten::transpose")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "i")
def transpose(g: jit_utils.GraphContext, self, dim0, dim1):
    if dim0 == dim1:  # micro-optimization
        return self

    # NB: Transpose in ONNX is actually a Permute
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None:
        axes = list(range(rank))
        axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
        return g.op("Transpose", self, perm_i=axes)
    else:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of transpose for tensor of unknown rank.",
            self,
        )


@_onnx_symbolic("aten::permute")
@symbolic_helper.parse_args("v", "is")
def permute(g: jit_utils.GraphContext, self, dims):
    if dims == list(range(0, len(dims))):
        return self
    return g.op("Transpose", self, perm_i=dims)


@_onnx_symbolic("aten::view")
@symbolic_helper.quantized_args(True)
def view(g: jit_utils.GraphContext, self, size):
    return reshape(g, self, size)


@_onnx_symbolic("aten::view_as")
def view_as(g: jit_utils.GraphContext, self, other):
    shape = g.op("Shape", other)
    return reshape(g, self, shape)


@_onnx_symbolic("aten::unsafe_chunk")
@symbolic_helper.parse_args("v", "i", "i", "i")
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self
        )
    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        return symbolic_helper._unimplemented(
            "unsafe_chunk", "unknown dimension size", self
        )
    split_size = (size + chunks - 1) // chunks
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs)

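
# Example of the split computation in unsafe_chunk above (numbers are
# illustrative): for size = 10 and chunks = 4, split_size = (10 + 3) // 4 = 3,
# so splits = [3, 3, 3] plus a leftover of 1, i.e. [3, 3, 3, 1] -- the same
# chunk sizes torch.chunk produces eagerly.
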
@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "split", 9, 11, "Dynamic number of outputs not supported", self
        )
    split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value")
    if split_val.dim() > 0:
        return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs)
    split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")

    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        if _outputs is not None:
            size = split_size * _outputs
        else:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "split", 9, 11, "Unknown dimension size not supported", self
            )
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::unsafe_split")
def unsafe_split(
    g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
):
    return split(g, self, split_size_or_sizes, dim, _outputs)


@_onnx_symbolic("aten::split_with_sizes")
@symbolic_helper.parse_args("v", "is", "i", "i")
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    if not symbolic_helper._is_split_static(split_sizes, _outputs):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self
        )
    return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::unsafe_split_with_sizes")
def unsafe_split_with_sizes(
    g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
):
    return split_with_sizes(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unbind", 9, 11, "Dynamic number of outputs not supported", self
        )

    outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs)
    outputs = [outputs] if _outputs == 1 else outputs
    squeezed_outputs = [
        symbolic_helper._squeeze_helper(g, out, [dim]) for out in outputs
    ]
    return squeezed_outputs


@_onnx_symbolic("aten::select")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "v")
def select(g: jit_utils.GraphContext, self, dim, index):
    """Implement the select functionality for a pytorch tensor in ONNX.

    Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor.
    """
    index = symbolic_helper._maybe_get_scalar(index)
    if (not symbolic_helper._is_value(index)) and (index < 0):
        if index == -1:
            end_index = _constants.INT64_MAX
        else:
            end_index = index + 1
        slice_node = symbolic_helper._slice_helper(
            g, self, axes=[dim], starts=[index], ends=[end_index]
        )
        return symbolic_helper._squeeze_helper(g, slice_node, [dim])
    else:
        # FIXME(justinchuby): can index be an int and not a value?
        return g.op("Gather", self, index, axis_i=dim)


@_onnx_symbolic("aten::square")
def square(g: jit_utils.GraphContext, self):
    return g.op("Mul", self, self)


@_onnx_symbolic("aten::squeeze")
def squeeze(g: jit_utils.GraphContext, self, dim=None):
    if dim is None:
        return g.op("Squeeze", self)

    squeeze_dim = symbolic_helper._get_const(dim, "i", "dim")
    # Handle negative dims
    if squeeze_dim < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            warnings.warn(
                "ONNX export squeeze with negative axis "
                + str(squeeze_dim)
                + " might cause the onnx model to be incorrect. "
                + "Negative axis is not supported in ONNX. "
                + "Axis is converted to "
                + str(squeeze_dim + rank)
                + " based on input shape at export time. "
                + "Passing a tensor of different rank in execution will be incorrect."
            )
            squeeze_dim += rank
        else:
            return symbolic_helper._unimplemented(
                "squeeze", "negative axis with unknown input rank", self
            )

    dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim)
    if dim_size is None:
        warnings.warn(
            "This model contains a squeeze operation on dimension "
            + str(squeeze_dim)
            + " on an input "
            + "with unknown shape. Note that if the size of dimension "
            + str(squeeze_dim)
            + " of the input "
            + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on "
            + "non-singleton dimensions, it is recommended to export this model using opset "
            + "version 11 or higher."
        )
        return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim])
    if dim_size > 1:
        warnings.warn(
            "This model contains a squeeze operation on dimension "
            + str(squeeze_dim)
            + ". The size of "
            + "this dimension in the given input is "
            + str(dim_size)
            + ". The model will "
            + "be exported without the squeeze node. If the model is intended to be used with dynamic "
            + "input shapes, please use opset version 11 to "
            + "export the model."
        )
        return self

    warnings.warn(
        "This model contains a squeeze operation on dimension "
        + str(squeeze_dim)
        + ". If the model is "
        + "intended to be used with dynamic input shapes, please use opset version 11 to export the model."
    )
    return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim])


@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    weight_rank = len(weight_sizes)
    if self_rank is not None:
        if self_rank > 2:
            # make weight unidirectionally broadcastable
            weight = symbolic_helper._unsqueeze_helper(
                g, weight, list(range(1, self_rank - 1))
            )
        elif self_rank == 0 and weight_sizes == [1]:
            # self and weight are both scalar but weight has rank == 1, squeeze weight.
            weight = symbolic_helper._squeeze_helper(g, weight, [0])
            weight_rank = 0

    if self_rank is not None and weight_rank is not None:
        assert (
            self_rank >= weight_rank
        ), f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}"
    return g.op("PRelu", self, weight)


@_onnx_symbolic("aten::silu")
def silu(g: jit_utils.GraphContext, input):
    return g.op("Mul", input, g.op("Sigmoid", input))


@_onnx_symbolic("aten::mish")
def mish(g: jit_utils.GraphContext, input):
    return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input)))


@_onnx_symbolic("aten::relu")
@symbolic_helper.quantized_args(True)
def relu(g: jit_utils.GraphContext, input):
    return symbolic_helper._op_with_optional_float_cast(
        g, "Relu", input, opset_before=14
    )


@_onnx_symbolic("aten::relu6")
@symbolic_helper.quantized_args(True)
def relu6(g: jit_utils.GraphContext, input):
    return clamp(g, input, 0, 6)


@_onnx_symbolic("aten::ceil")
def ceil(g: jit_utils.GraphContext, input):
    return g.op("Ceil", input)


@_onnx_symbolic("aten::floor")
def floor(g: jit_utils.GraphContext, input):
    return g.op("Floor", input)


@_onnx_symbolic("aten::len")
def _len(g: jit_utils.GraphContext, self):
    sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0])))
    return symbolic_helper._squeeze_helper(g, sz_0, [0])


@_onnx_symbolic("aten::threshold")
@symbolic_helper.parse_args("v", "t", "t")
def threshold(g: jit_utils.GraphContext, self, threshold, value):
    # See Note [Export inplace]
    if symbolic_helper._scalar(threshold) != 0:
        return symbolic_helper._unimplemented("threshold", "non-zero threshold", self)
    if symbolic_helper._scalar(value) != 0:
        return symbolic_helper._unimplemented("threshold", "non-zero value", self)
    return g.op("Relu", self)


@_onnx_symbolic("aten::leaky_relu")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "b")
def leaky_relu(
    g: jit_utils.GraphContext,
    input: _C.Value,
    negative_slope: float,
    inplace: bool = False,
):
    # See Note [Export inplace]
    return g.op("LeakyRelu", input, alpha_f=negative_slope)


@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
def glu(g: jit_utils.GraphContext, input, dim):
    dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
    if dim_size is not None:
        assert dim_size % 2 == 0

    first, second = g.op("Split", input, axis_i=dim, outputs=2)
    return g.op("Mul", first, g.op("Sigmoid", second))

@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    # Softmax does normalization at vector level.
    # PyTorch and ONNX use different strategies to split the input tensor into vectors.
    # Thus dim and axis have different meanings.
    # PyTorch slices the input tensor into vectors along the `dim`-th dimension.
    # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
    # If input is a 2 x 3 tensor:
    # input = [[1.0, 1.0, 1.0],
    #          [1.0, 1.0, 1.0]]
    # with dim = 0, the result is:
    # result = [[0.5, 0.5, 0.5],
    #           [0.5, 0.5, 0.5]]
    # with axis = 0, the result is:
    # result = [[0.167, 0.167, 0.167],
    #           [0.167, 0.167, 0.167]]
    # So only when dim and axis both equal ndim - 1 (the last dimension)
    # are their semantics equivalent.
    # So use Softmax when dim and axis both equal ndim - 1;
    # otherwise transpose the input to put the vectors to be normalized in the last dimension.
    # When the input rank is not known at export time, we compute softmax using a subgraph
    # with other operators.
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is not None:
        # TODO: remove this as onnx opset 11 spec allows negative axes
        if dim < 0:
            dim = input_dim + dim

        is_transpose_required = input_dim != dim + 1

        if is_transpose_required:
            axes = list(range(input_dim))
            axes[dim], axes[-1] = axes[-1], axes[dim]
            input = g.op("Transpose", input, perm_i=axes)
            dim = input_dim - 1

        softmax = g.op("Softmax", input, axis_i=dim)
        if dtype and dtype.node().kind() != "prim::Constant":
            parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
            softmax = g.op(
                "Cast",
                softmax,
                to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(),
            )

        if is_transpose_required:
            softmax = g.op("Transpose", softmax, perm_i=axes)  # type: ignore[possibly-undefined]
        return softmax

    # Apply max normalization.
    input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1))

    exp = g.op("Exp", input)
    sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim])
    softmax = g.op("Div", exp, sum)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        softmax = g.op(
            "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    return softmax


@_onnx_symbolic("aten::softplus")
def softplus(g: jit_utils.GraphContext, self, beta, threshold):
    beta_const = symbolic_helper._maybe_get_const(beta, "f")
    if beta_const != 1:
        return g.op("Div", g.op("Softplus", g.op("Mul", self, beta)), beta)
    return g.op("Softplus", self)


@_onnx_symbolic("aten::get_pool_ceil_padding")
def get_pool_ceil_padding(input, kernel_size, stride, padding):
    # TODO(justinchuby): Looks like this op is deprecated in torch
    sizes = symbolic_helper._get_tensor_sizes(input)
    dim = sizes[-len(padding) :] if sizes is not None else None
    if dim is None or any(i is None for i in dim):
        return symbolic_helper._unimplemented(
            "get_pool_ceil_padding", "input size not accessible", input
        )
    ceiled_output_dim = [
        int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i])))
        + 1
        for i in range(0, len(padding))
    ]
    # ensure last pooling starts inside
    ceiled_output_dim = [
        (
            ceiled_output_dim[i] - 1
            if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i]))
            else ceiled_output_dim[i]
        )
        for i in range(0, len(ceiled_output_dim))
    ]
    padding_ceil = [
        (
            0
            if (stride[i] == 1)
            else (
                kernel_size[i]
                - (
                    dim[i]
                    + 2 * padding[i]
                    - ((ceiled_output_dim[i] - 1) * stride[i] + 1)
                )
            )
        )
        for i in range(0, len(padding))
    ]
    # ensure padding is not > kernel_size
    padding_ceil = [
        (
            (
                int(padding_ceil[i])
                if padding_ceil[i] < kernel_size[i] - 1
                else int(kernel_size[i] - 1)
            )
            if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i]))
            else int(padding_ceil[i])
        )
        for i in range(0, len(padding_ceil))
    ]
    return padding_ceil

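
# Illustrative trace of get_pool_ceil_padding (assumed example values): for a
# spatial size of 6 with kernel_size = 3, stride = 2 and padding = 0,
# ceiled_output_dim = ceil((6 - 3) / 2) + 1 = 3 and padding_ceil = [2], so the
# callers below extend the end padding and the exported MaxPool/AveragePool
# also produces 3 outputs, matching PyTorch's ceil_mode=True result.
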
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
        ),
        _export("max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
        ),
        _export("max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
        ),
        _export("max_pool3d"),
    ],
)
def _max_pool(name, tuple_fn, ndims, return_indices):
    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        if set(tuple_fn(dilation)) != {1}:
            return symbolic_helper._unimplemented(name, "dilation", input)
        if not stride:
            stride = kernel_size
        padding = tuple(tuple_fn(padding))
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
            padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding))
        else:
            padding = padding * 2
        kwargs = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": padding,
            "strides_i": tuple_fn(stride),
        }
        # Easy but hacky way to get flattened indices values
        # to be used to convert the indices values to non-flattened.
        # In ONNX the indices are computed as a flattened 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by PyTorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This will result in a tensor of indices in which each index has its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of this axis in the indices to convert.
        # This step will result in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information:
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        if return_indices:
            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
            _, flattened_indices = g.op(
                "MaxPool",
                input,
                outputs=2,
                kernel_shape_i=[1 for _ in range(ndims)],
                strides_i=[1 for _ in range(ndims)],
            )
            # convert indices to have non-flattened indices values
            s = symbolic_helper._slice_helper(
                g,
                flattened_indices,
                axes=[2 + i for i in range(ndims)],
                starts=list(tuple_fn(0)),
                ends=list(tuple_fn(1)),
            )
            indices = sub(g, indices, s)
            return r, indices
        else:
            r = g.op("MaxPool", input, outputs=1, **kwargs)
            return r

    return symbolic_fn


max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")(
    _max_pool(
        "max_pool1d_with_indices",
        torch.nn.modules.utils._single,
        1,
        return_indices=True,
    )
)
max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")(
    _max_pool(
        "max_pool2d_with_indices",
        torch.nn.modules.utils._pair,
        2,
        return_indices=True,
    )
)
max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")(
    _max_pool(
        "max_pool3d_with_indices",
        torch.nn.modules.utils._triple,
        3,
        return_indices=True,
    )
)

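
# Note on the index conversion inside _max_pool (an illustrative sketch): ONNX
# MaxPool returns indices flattened over the whole (N, C, D1, ..., Dn) tensor,
# e.g. for an input of shape (1, 2, 3) the second channel's elements have flat
# indices 3, 4, 5, while PyTorch expects per-(N, C)-plane indices 0, 1, 2. The
# auxiliary kernel-1/stride-1 MaxPool produces, at spatial position 0, exactly
# that plane's base offset (3 here), so subtracting the sliced base offsets
# converts the ONNX indices to PyTorch's convention.
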
@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[
        symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single),
        _export("avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[
        symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair),
        _export("avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[
        symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple),
        _export("avg_pool3d"),
    ],
)
def _avg_pool(name, tuple_fn):
    @symbolic_helper.quantized_args(True)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: int | Sequence[int],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        if not stride:
            stride = kernel_size
        padding = symbolic_helper._avgpool_helper(
            tuple_fn, padding, kernel_size, stride, divisor_override, name
        )
        assert isinstance(padding, tuple)
        adjusted_padding = padding
        # Although onnx::AvgPool provides count_include_pad,
        # the corner case of average pooling with ceil_mode in
        # PyTorch allows the sliding window to go off bound, which leads to
        # this accommodation.
        # More detail on https://github.com/pytorch/pytorch/issues/57178
        if count_include_pad:
            input = symbolic_helper._op_with_optional_float_cast(
                g,
                "Pad",
                input,
                pads_i=((0,) * 2 + padding) * 2,
                mode_s="constant",
                value_f=0.0,
                opset_before=11,
            )
            adjusted_padding = (0,) * len(padding)
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
            adjusted_padding = adjusted_padding + tuple(
                a + b for (a, b) in zip(padding_ceil, adjusted_padding)
            )
        else:
            adjusted_padding = adjusted_padding * 2
        output = g.op(
            "AveragePool",
            input,
            kernel_shape_i=tuple_fn(kernel_size),
            strides_i=tuple_fn(stride),
            pads_i=adjusted_padding,
        )
        return output

    return symbolic_fn


@_onnx_symbolic(
    "aten::adaptive_avg_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single
        ),
        _export("adaptive_avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair
        ),
        _export("adaptive_avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple
        ),
        _export("adaptive_avg_pool3d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool1d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool1d",
            "MaxPool",
            torch.nn.modules.utils._single,
            max_pool1d_with_indices,
        ),
        _export("adaptive_max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool2d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool2d",
            "MaxPool",
            torch.nn.modules.utils._pair,
            max_pool2d_with_indices,
        ),
        _export("adaptive_max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool3d",
    decorate=[
        symbolic_helper._apply_params(
            "adaptive_max_pool3d",
            "MaxPool",
            torch.nn.modules.utils._triple,
            max_pool3d_with_indices,
        ),
        _export("adaptive_max_pool3d"),
    ],
)
def _adaptive_pool(name, type, tuple_fn, fn=None):
    @symbolic_helper.quantized_args(True, False)
    def symbolic_fn(g, input, output_size):
        # _adaptive_pool is supported for cases where output_size is 1 for all dimensions,
        # by executing a GlobalPool.
        # It is also supported for cases where the output size is a factor of the input size.
        # For these cases the stride and kernel size are uniform along all the indices of
        # the same dimension, which makes it possible to export it to ONNX.
        # For MaxPool, GlobalMaxPool does not return indices,
        # so we try using max_poolxd_with_indices, and if that is not possible
        # (input is not a complete tensor or output size is not a factor of the input size)
        # then we fall back to GlobalMaxPool and return None for the indices.
        output_size_value = output_size
        try:
            output_size = symbolic_helper._parse_arg(output_size, "is")
        except Exception:
            # FIXME(justinchuby): Avoid catching Exception.
            # Catch a more specific exception instead.
            return symbolic_helper._onnx_unsupported(
                "adaptive pooling, since output_size is not constant.", input
            )
        if output_size == [1] * len(output_size) and type == "AveragePool":
            return g.op("GlobalAveragePool", input)
        sizes = symbolic_helper._get_tensor_sizes(input)
        try:
            dim = sizes[2:]
        except Exception:
            # FIXME(justinchuby): Avoid catching Exception.
            # Catch a more specific exception instead.
            dim = None
        if dim is None or any(i is None for i in dim):
            if output_size == [1] * len(output_size):
                return g.op("GlobalMaxPool", input), None
            return symbolic_helper._unimplemented(
                name, "input size not accessible", input
            )
        # verify that input size % output size == 0 for all dims
        mod = [dim[i] % output_size[i] for i in range(0, len(dim))]
        if mod != [0] * len(mod):
            if output_size == [1] * len(output_size):
                return g.op("GlobalMaxPool", input), None
            return symbolic_helper._unimplemented(
                name, "output size that is not a factor of input size", output_size_value
            )
        k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
        # call max_poolxd_with_indices to get indices in the output
        if type == "MaxPool":
            return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
        output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k))
        return output

    return symbolic_fn

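
# Example of the adaptive-pool lowering above (illustrative sizes): for an input
# with spatial size 12 and output_size = [4], mod == [0], so k = [3] and the op
# is exported as a regular pool with kernel_shape = strides = 3; adaptive
# pooling with output_size = [1] instead maps to the corresponding Global*Pool.
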
def _prepare_onnx_paddings(dim: int, pad):
    """Generate paddings in ONNX order based on pad in pytorch.

    Args:
        dim: the dimension of the tensor.
        pad: the paddings in pytorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...
    """
    # The desired order of paddings is
    # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
    # n is the dimension of input.
    # assume zero-dimensions in the beginning
    paddings = list(pad[:]) + [0] * (dim * 2 - len(pad))
    # reverse order and collate first beginnings and then ends
    paddings = paddings[-2::-2] + paddings[-1::-2]
    return paddings

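
# Worked example for _prepare_onnx_paddings (illustrative values): with dim = 3
# and pad = [1, 2, 3, 4] (PyTorch order: last-dim begin/end, then the next dim),
# the list is zero-extended to [1, 2, 3, 4, 0, 0] and reordered to
# [0, 3, 1, 0, 4, 2], i.e. ONNX order [x0_begin, x1_begin, x2_begin,
# x0_end, x1_end, x2_end].
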
def _convert_padding_node(input):
    padding = symbolic_helper._maybe_get_const(input, "is")
    if symbolic_helper._is_value(padding) and symbolic_helper._is_packed_list(padding):
        input_list = symbolic_helper._unpack_list(padding)
        try:
            padding = [
                symbolic_helper._get_const(v, "i", "padding") for v in input_list
            ]
        except Exception:
            # FIXME(justinchuby): Avoid catching Exception.
            # Catch a more specific exception instead.
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "Pad", 9, 11, "The sizes of the padding must be constant", input
            )
    return padding


@_onnx_symbolic("aten::constant_pad_nd")
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value):
    mode = "constant"
    try:
        value = symbolic_helper._get_const(value, "f", "value")
    except Exception:
        # FIXME(justinchuby): Avoid catching Exception.
        # Catch a more specific exception instead.
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Pad", 9, 11, "The value for the padding must be constant", value
        )

    padding = _convert_padding_node(padding)
    paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding)
    return symbolic_helper._op_with_optional_float_cast(
        g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11
    )


def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value):
    padding = _convert_padding_node(pad)
    assert len(padding) % 2 == 0
    ndim = len(padding) // 2

    cur = input
    for idx in range(ndim):
        pad_r = padding[-(2 * idx + 1)]
        pad_l = padding[-(2 * idx + 2)]
        tensors = []
        if pad_l > 0:
            left = symbolic_helper._slice_helper(
                g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX]
            )
            tensors.append(left)

        if pad_l < 0 or pad_r < 0:
            start = builtins.max(0, -pad_l)
            end = -(builtins.max(0, -pad_r))
            middle = symbolic_helper._slice_helper(
                g,
                cur,
                axes=[2 + idx],
                starts=[start],
                ends=[end],
            )
            tensors.append(middle)
        else:
            tensors.append(cur)

        if pad_r > 0:
            right = symbolic_helper._slice_helper(
                g, cur, axes=[2 + idx], starts=[0], ends=[pad_r]
            )
            tensors.append(right)

        cur = g.op("Concat", *tensors, axis_i=(2 + idx))

    return cur

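
# Illustrative behaviour of _pad_circular: for an input of shape (N, C, 4) with
# values [a, b, c, d] along the last axis and pad = (1, 2), the loop
# concatenates the last element, the full input, and the first two elements
# along axis 2, producing [d, a, b, c, d, a, b] -- the same wrap-around result
# as torch.nn.functional.pad(..., mode="circular").
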
errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) 1847 1848 1849@_onnx_symbolic( 1850 "aten::upsample_nearest1d", 1851 decorate=[ 1852 symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"), 1853 _export("upsample_nearest1d"), 1854 ], 1855) 1856@_onnx_symbolic( 1857 "aten::upsample_nearest2d", 1858 decorate=[ 1859 symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"), 1860 _export("upsample_nearest2d"), 1861 ], 1862) 1863@_onnx_symbolic( 1864 "aten::upsample_nearest3d", 1865 decorate=[ 1866 symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"), 1867 _export("upsample_nearest3d"), 1868 ], 1869) 1870@_onnx_symbolic( 1871 "aten::upsample_linear1d", 1872 decorate=[ 1873 symbolic_helper._apply_params("upsample_linear1d", 3, "linear"), 1874 _export("upsample_linear1d"), 1875 ], 1876) 1877@_onnx_symbolic( 1878 "aten::upsample_bilinear2d", 1879 decorate=[ 1880 symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"), 1881 _export("upsample_bilinear2d"), 1882 ], 1883) 1884@_onnx_symbolic( 1885 "aten::upsample_trilinear3d", 1886 decorate=[ 1887 symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"), 1888 _export("upsample_trilinear3d"), 1889 ], 1890) 1891def _interpolate(name: str, dim: int, interpolate_mode: str): 1892 def symbolic_fn(g, input, output_size, *args): 1893 scales, align_corners = symbolic_helper._get_interpolate_attributes( 1894 g, interpolate_mode, args 1895 ) 1896 symbolic_helper._interpolate_warning(interpolate_mode) 1897 align_corners = symbolic_helper._maybe_get_scalar(align_corners) 1898 if align_corners: 1899 return symbolic_helper._unimplemented(name, "align_corners == True", input) 1900 if scales is None: 1901 scales = symbolic_helper._interpolate_size_to_scales( 1902 g, input, output_size, dim 1903 ) 1904 return g.op("Upsample", input, scales, mode_s=interpolate_mode) 1905 1906 return symbolic_fn 1907 1908 1909@_onnx_symbolic("aten::__interpolate") 1910def __interpolate( 1911 g: jit_utils.GraphContext, 1912 input, 1913 size, 1914 scale_factor, 1915 mode, 1916 align_corners, 1917 recompute_scale_factor, 1918 antialias, 1919): 1920 scales, mode = symbolic_helper._interpolate_get_scales_and_mode( 1921 g, input, size, scale_factor, mode, align_corners 1922 ) 1923 return g.op("Upsample", input, scales, mode_s=mode) 1924 1925 1926@_onnx_symbolic("aten::bitwise_not") 1927def bitwise_not(g: jit_utils.GraphContext, input): 1928 if not symbolic_helper._is_bool(input): 1929 raise errors.SymbolicValueError( 1930 "ONNX export does NOT support exporting bitwise Not " 1931 "for non-boolean input values", 1932 input, 1933 ) 1934 return g.op("Not", input) 1935 1936 1937@_onnx_symbolic("aten::bitwise_or") 1938def bitwise_or(g, self, other): 1939 if not symbolic_helper._is_bool(self): 1940 raise errors.SymbolicValueError( 1941 "ONNX export does NOT support exporting bitwise OR " 1942 "for non-boolean input values. self: ", 1943 self, 1944 ) 1945 if not symbolic_helper._is_bool(other): 1946 raise errors.SymbolicValueError( 1947 "ONNX export does NOT support exporting bitwise OR " 1948 "for non-boolean input values. 
other: ", 1949 other, 1950 ) 1951 return g.op("Or", self, other) 1952 1953 1954def wrap_logical_op_with_cast_to(to_type): 1955 def decorator(fn): 1956 @functools.wraps(fn) 1957 def wrap_with_cast(g, input, other): 1958 to_cast_func = globals()[f"_cast_{to_type}"] 1959 return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False)) 1960 1961 return wrap_with_cast 1962 1963 return decorator 1964 1965 1966def wrap_logical_op_with_negation(func: Callable) -> Callable: 1967 @functools.wraps(func) 1968 def wrap_with_not(g, input, other): 1969 return g.op("Not", func(g, input, other)) 1970 1971 return wrap_with_not 1972 1973 1974@_onnx_symbolic("aten::__not_") 1975def __not_(g: jit_utils.GraphContext, self): 1976 if not symbolic_helper._is_bool(self): 1977 raise errors.SymbolicValueError( 1978 "ONNX export does NOT support exporting bitwise Not " 1979 "for non-boolean input values", 1980 self, 1981 ) 1982 return g.op("Not", self) 1983 1984 1985@_onnx_symbolic("aten::eq") 1986@symbolic_helper.quantized_args(True, True) 1987def eq(g: jit_utils.GraphContext, self, other): 1988 if isinstance(self.type(), _C.DeviceObjType) and isinstance( 1989 other.type(), _C.DeviceObjType 1990 ): 1991 # ONNX doesn't have devices, so consider them all to be equal. 1992 # The no-op check for equality will get constant-folded. 1993 return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool)) 1994 self_node = self.node() 1995 other_node = other.node() 1996 if self_node.kind() == other_node.kind() == "onnx::Constant": 1997 if self_node.kindOf("value") == other_node.kindOf("value") == "s": 1998 # Exporting strings to ONNX is not supported. 1999 # If both strings are constant, we can compare them directly. 2000 # The no-op check for equality will get constant-folded. 
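# Illustrative (assumed export-time folding, hypothetical operands): comparing two
# constant strings such as "mean" == "mean" lowers to a scalar bool Constant here rather
# than an onnx::Equal node, since ONNX cannot compare strings; non-constant operands fall
# through to the Equal emitted below.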
2001 return g.op( 2002 "Constant", 2003 value_t=torch.tensor( 2004 self_node.s("value") == other_node.s("value"), 2005 dtype=torch.bool, 2006 ), 2007 ) 2008 2009 return g.op("Equal", self, other) 2010 2011 2012@_onnx_symbolic("aten::ne") 2013@symbolic_helper.quantized_args(True, True) 2014@wrap_logical_op_with_negation 2015def ne(g: jit_utils.GraphContext, self, other): 2016 return eq(g, self, other) 2017 2018 2019@_onnx_symbolic("aten::gt") 2020@symbolic_helper.quantized_args(True, True) 2021def gt(g: jit_utils.GraphContext, input, other): 2022 return _gt_impl(g, input, other) 2023 2024 2025def _gt_impl(g: jit_utils.GraphContext, input, other): 2026 if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): 2027 input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) 2028 other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) 2029 return g.op("Greater", input, other) 2030 2031 2032@_onnx_symbolic("aten::lt") 2033@symbolic_helper.quantized_args(True, True) 2034def lt(g: jit_utils.GraphContext, input, other): 2035 return _lt_impl(g, input, other) 2036 2037 2038def _lt_impl(g: jit_utils.GraphContext, input, other): 2039 if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): 2040 input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) 2041 other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) 2042 return g.op("Less", input, other) 2043 2044 2045@_onnx_symbolic("aten::ge") 2046@symbolic_helper.quantized_args(True, True) 2047@wrap_logical_op_with_negation 2048def ge(g: jit_utils.GraphContext, input, other): 2049 return _lt_impl(g, input, other) 2050 2051 2052@_onnx_symbolic("aten::le") 2053@symbolic_helper.quantized_args(True, True) 2054@wrap_logical_op_with_negation 2055def le(g: jit_utils.GraphContext, input, other): 2056 return _gt_impl(g, input, other) 2057 2058 2059@_onnx_symbolic("aten::__and_") 2060def __and_(g: jit_utils.GraphContext, input, other): 2061 if not symbolic_helper._is_bool(input): 2062 raise errors.SymbolicValueError( 2063 "ONNX export does NOT support exporting bitwise AND " 2064 "for non-boolean input values", 2065 input, 2066 ) 2067 if not symbolic_helper._is_bool(other): 2068 raise errors.SymbolicValueError( 2069 "ONNX export does NOT support exporting bitwise AND " 2070 "for non-boolean input values", 2071 other, 2072 ) 2073 return g.op("And", input, other) 2074 2075 2076@_onnx_symbolic("aten::__or_") 2077def __or_(g: jit_utils.GraphContext, input, other): 2078 if not symbolic_helper._is_bool(input): 2079 raise errors.SymbolicValueError( 2080 "ONNX export does NOT support exporting bitwise OR " 2081 "for non-boolean input values", 2082 input, 2083 ) 2084 if not symbolic_helper._is_bool(other): 2085 raise errors.SymbolicValueError( 2086 "ONNX export does NOT support exporting bitwise OR " 2087 "for non-boolean input values", 2088 other, 2089 ) 2090 return g.op("Or", input, other) 2091 2092 2093@_onnx_symbolic("aten::__xor_") 2094def __xor_(g: jit_utils.GraphContext, input, other): 2095 if not symbolic_helper._is_bool(input): 2096 raise errors.SymbolicValueError( 2097 "ONNX export does NOT support exporting bitwise XOR " 2098 "for non-boolean input values", 2099 input, 2100 ) 2101 if not symbolic_helper._is_bool(other): 2102 raise errors.SymbolicValueError( 2103 "ONNX export does NOT support exporting bitwise XOR " 2104 "for non-boolean input values", 2105 other, 2106 ) 2107 return g.op("Xor", input, other) 2108 2109 2110@_onnx_symbolic("aten::logical_and") 
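# Note: wrap_logical_op_with_cast_to("Bool") (defined above) looks up _cast_Bool from this
# module's globals at call time and casts both operands to BOOL before the wrapped op, so
# e.g. torch.logical_and on integer tensors still exports as a plain onnx::And.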
2111@wrap_logical_op_with_cast_to("Bool") 2112def logical_and(g: jit_utils.GraphContext, input, other): 2113 return g.op("And", input, other) 2114 2115 2116@_onnx_symbolic("aten::logical_or") 2117@wrap_logical_op_with_cast_to("Bool") 2118def logical_or(g: jit_utils.GraphContext, input, other): 2119 return g.op("Or", input, other) 2120 2121 2122@_onnx_symbolic("aten::logical_xor") 2123@wrap_logical_op_with_cast_to("Bool") 2124def logical_xor(g: jit_utils.GraphContext, input, other): 2125 return g.op("Xor", input, other) 2126 2127 2128@_onnx_symbolic("aten::logical_not") 2129def logical_not(g: jit_utils.GraphContext, input): 2130 return g.op("Not", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)) 2131 2132 2133@_onnx_symbolic("aten::__rshift_") 2134def __rshift_(g: jit_utils.GraphContext, self, other): 2135 # make sure to cast other to self's type 2136 # (when self is long, make sure that other is not float) 2137 self_scalar_type = _type_utils.JitScalarType.from_value(self) 2138 if ( 2139 _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) 2140 != self_scalar_type 2141 ): 2142 other = g.op( 2143 "Cast", 2144 other, 2145 to_i=self_scalar_type.onnx_type(), 2146 ) 2147 2148 two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) 2149 # exponent (same type as self) has to be float or double in onnx::Pow 2150 if not symbolic_helper._is_fp(self): 2151 other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) 2152 two_pow = g.op("Pow", two, other) 2153 two_pow = g.op( 2154 "Cast", 2155 two_pow, 2156 to_i=self_scalar_type.onnx_type(), 2157 ) 2158 rshift = g.op("Div", self, two_pow) 2159 return rshift 2160 2161 2162@_onnx_symbolic("aten::__lshift_") 2163def __lshift_(g: jit_utils.GraphContext, self, other): 2164 # make sure to cast other to self's type 2165 # (when self is long, make sure that other is not float) 2166 self_scalar_type = _type_utils.JitScalarType.from_value(self) 2167 if ( 2168 _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) 2169 != self_scalar_type 2170 ): 2171 other = g.op( 2172 "Cast", 2173 other, 2174 to_i=self_scalar_type.onnx_type(), 2175 ) 2176 2177 two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) 2178 # exponent (same type as self) has to be float or double in onnx::Pow 2179 if not symbolic_helper._is_fp(self): 2180 other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) 2181 two_pow = g.op("Pow", two, other) 2182 two_pow = g.op( 2183 "Cast", 2184 two_pow, 2185 to_i=self_scalar_type.onnx_type(), 2186 ) 2187 lshift = g.op("Mul", self, two_pow) 2188 return lshift 2189 2190 2191@_onnx_symbolic("aten::where") 2192@symbolic_helper.parse_args("v", "v", "v", "i") 2193def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): 2194 # Assumes that torch.where's first argument takes only Bool and Byte tensors. 2195 if not symbolic_helper._is_bool(condition): 2196 condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) 2197 if self is None: 2198 condition = nonzero(g, condition) 2199 return symbolic_helper._unbind_helper( 2200 g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs 2201 ) 2202 return g.op("Where", condition, self, other) 2203 2204 2205@_onnx_symbolic("aten::log_softmax") 2206@symbolic_helper.parse_args("v", "i", "none") 2207def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): 2208 # PyTorch dim and ONNX axis have different meanings. 2209 # See Softmax comment for details. 
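# Illustrative (hypothetical rank/dim): for a rank-4 input with dim=1, the export below
# becomes Transpose(perm=[0, 3, 2, 1]) -> LogSoftmax(axis=3) -> Transpose(perm=[0, 3, 2, 1]),
# i.e. the requested dim is swapped with the last axis and swapped back afterwards.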
2210 # TODO: remove this as onnx opset 11 spec allows negative axes 2211 input_dim = symbolic_helper._get_tensor_rank(input) 2212 if input_dim is None: 2213 return symbolic_helper._unimplemented( 2214 "dim", 2215 "ONNX and PyTorch use different strategies to split the input. " 2216 "Input rank must be known at export time.", 2217 ) 2218 if dim < 0: 2219 dim = input_dim + dim 2220 is_transpose_required = input_dim != dim + 1 2221 # ONNX only supports log_softmax with dim = -1. Transpose must be added before and after log_softmax to support other cases. 2222 if is_transpose_required: 2223 axes = list(range(input_dim)) 2224 axes[dim], axes[-1] = axes[-1], axes[dim] 2225 input = g.op("Transpose", input, perm_i=axes) 2226 dim = input_dim - 1 2227 return_op = g.op("LogSoftmax", input, axis_i=dim) 2228 if dtype and dtype.node().kind() != "prim::Constant": 2229 parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") 2230 return_op = g.op( 2231 "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() 2232 ) 2233 if is_transpose_required: 2234 return_op = g.op("Transpose", return_op, perm_i=axes) # type: ignore[possibly-undefined] 2235 return return_op 2236 2237 2238@_onnx_symbolic("aten::_log_softmax") 2239@symbolic_helper.parse_args("v", "i", "i") 2240def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float): 2241 if ( 2242 half_to_float 2243 and _type_utils.JitScalarType.from_value( 2244 input, _type_utils.JitScalarType.UNDEFINED 2245 ) 2246 == _type_utils.JitScalarType.HALF 2247 ): 2248 input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) 2249 return log_softmax(g, input, dim) 2250 2251 2252@_onnx_symbolic("aten::_convolution") 2253@symbolic_helper.parse_args( 2254 "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i" 2255) 2256def _convolution( 2257 g: jit_utils.GraphContext, 2258 input, 2259 weight, 2260 bias, 2261 stride, 2262 padding, 2263 dilation, 2264 transposed, 2265 output_padding, 2266 groups, 2267 benchmark, 2268 deterministic, 2269 cudnn_enabled, 2270 allow_tf32=None, 2271): 2272 weight_size = symbolic_helper._get_tensor_sizes(weight) 2273 try: 2274 kernel_shape = weight_size[2:] 2275 except Exception: 2276 # FIXME(justinchuby): Avoid catching Exception. 2277 # Catch a more specific exception instead. 2278 kernel_shape = None 2279 2280 if kernel_shape is None or any(i is None for i in kernel_shape): 2281 raise errors.SymbolicValueError( 2282 "Unsupported: ONNX export of convolution for kernel of unknown shape.", 2283 input, 2284 ) 2285 2286 args = [input, weight] 2287 # ONNX only supports 1D bias 2288 if ( 2289 not symbolic_helper._is_none(bias) 2290 and symbolic_helper._get_tensor_rank(bias) == 1 2291 ): 2292 args.append(bias) 2293 2294 kwargs = { 2295 "kernel_shape_i": weight_size[2:], 2296 "strides_i": stride, 2297 # NB: ONNX supports asymmetric padding, whereas PyTorch supports only 2298 # symmetric padding 2299 "pads_i": padding + padding, 2300 "dilations_i": dilation, 2301 "group_i": groups, 2302 } 2303 2304 if any(o != 0 for o in output_padding): 2305 # ONNX supports both output_shape and output_padding. they are equivalent expressive. 2306 # output_padding is more straightforward, so we use it here. 
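# Illustrative numbers (hypothetical): a 1-d ConvTranspose with input_shape=5, stride=2,
# kernel_shape=3, padding=1 and output_padding=1 yields output_shape = 2*(5-1)+1+3-2*1 = 10,
# consistent with the formula below.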
2307 # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2 2308 assert transposed 2309 assert len(stride) == len(output_padding) 2310 kwargs["output_padding_i"] = output_padding 2311 2312 n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs) 2313 2314 if ( 2315 not symbolic_helper._is_none(bias) 2316 and symbolic_helper._get_tensor_rank(bias) != 1 2317 ): 2318 return g.op("Add", n, bias) 2319 else: 2320 return n 2321 2322 2323@_onnx_symbolic("aten::_convolution_mode") 2324@symbolic_helper.parse_args( 2325 "v", 2326 "v", 2327 "v", 2328 "is", 2329 "s", 2330 "is", 2331 "i", 2332) 2333def _convolution_mode( 2334 g: jit_utils.GraphContext, 2335 input, 2336 weight, 2337 bias, 2338 stride, 2339 padding, 2340 dilation, 2341 groups, 2342): 2343 weight_size = symbolic_helper._get_tensor_sizes(weight) 2344 try: 2345 kernel_shape = weight_size[2:] 2346 except Exception: 2347 # FIXME(justinchuby): Avoid catching Exception. 2348 # Catch a more specific exception instead. 2349 kernel_shape = None 2350 2351 if kernel_shape is None or any(i is None for i in kernel_shape): 2352 raise errors.SymbolicValueError( 2353 "Unsupported: ONNX export of convolution for kernel of unknown shape.", 2354 input, 2355 ) 2356 2357 args = [input, weight] 2358 # ONNX only supports 1D bias 2359 if ( 2360 not symbolic_helper._is_none(bias) 2361 and symbolic_helper._get_tensor_rank(bias) == 1 2362 ): 2363 args.append(bias) 2364 2365 if padding == "valid": 2366 padding = "VALID" 2367 elif padding == "same": 2368 padding = "SAME_UPPER" 2369 kwargs = { 2370 "kernel_shape_i": weight_size[2:], 2371 "strides_i": stride, 2372 "auto_pad_s": padding, 2373 "dilations_i": dilation, 2374 "group_i": groups, 2375 } 2376 2377 n = g.op("Conv", *args, **kwargs) 2378 2379 if ( 2380 not symbolic_helper._is_none(bias) 2381 and symbolic_helper._get_tensor_rank(bias) != 1 2382 ): 2383 return g.op("Add", n, bias) 2384 else: 2385 return n 2386 2387 2388@_onnx_symbolic("aten::convolution") 2389@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i") 2390def convolution( 2391 g: jit_utils.GraphContext, 2392 input, 2393 weight, 2394 bias, 2395 stride, 2396 padding, 2397 dilation, 2398 transposed, 2399 output_padding, 2400 groups, 2401): 2402 return _convolution( 2403 g, 2404 input, 2405 weight, 2406 bias, 2407 stride, 2408 padding, 2409 dilation, 2410 transposed, 2411 output_padding, 2412 groups, 2413 None, 2414 None, 2415 None, 2416 None, 2417 ) 2418 2419 2420@_onnx_symbolic("aten::conv1d") 2421@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") 2422def conv1d( 2423 g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups 2424): 2425 str_padding = symbolic_helper._parse_arg(padding, "s") 2426 if str_padding in ["valid", "same"]: 2427 return _convolution_mode( 2428 g, 2429 input, 2430 weight, 2431 bias, 2432 stride, 2433 str_padding, 2434 dilation, 2435 groups, 2436 ) 2437 else: 2438 padding = symbolic_helper._parse_arg(padding, "is") 2439 return _convolution( 2440 g, 2441 input, 2442 weight, 2443 bias, 2444 stride, 2445 padding, 2446 dilation, 2447 False, 2448 (), 2449 groups, 2450 None, 2451 None, 2452 None, 2453 None, 2454 ) 2455 2456 2457@_onnx_symbolic("aten::conv2d") 2458@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") 2459def conv2d( 2460 g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups 2461): 2462 str_padding = symbolic_helper._parse_arg(padding, "s") 2463 if str_padding in ["valid", "same"]: 
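# Illustrative: F.conv2d(x, w, padding="same") takes this branch and is exported through
# _convolution_mode with auto_pad_s="SAME_UPPER" ("valid" maps to "VALID"), instead of the
# explicit pads_i attribute used for integer padding in the else branch.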
2464 return _convolution_mode( 2465 g, 2466 input, 2467 weight, 2468 bias, 2469 stride, 2470 str_padding, 2471 dilation, 2472 groups, 2473 ) 2474 else: 2475 padding = symbolic_helper._parse_arg(padding, "is") 2476 return _convolution( 2477 g, 2478 input, 2479 weight, 2480 bias, 2481 stride, 2482 padding, 2483 dilation, 2484 False, 2485 (), 2486 groups, 2487 None, 2488 None, 2489 None, 2490 None, 2491 ) 2492 2493 2494@_onnx_symbolic("aten::conv3d") 2495@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") 2496def conv3d( 2497 g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups 2498): 2499 str_padding = symbolic_helper._parse_arg(padding, "s") 2500 if str_padding in ["valid", "same"]: 2501 return _convolution_mode( 2502 g, 2503 input, 2504 weight, 2505 bias, 2506 stride, 2507 str_padding, 2508 dilation, 2509 groups, 2510 ) 2511 else: 2512 padding = symbolic_helper._parse_arg(padding, "is") 2513 return _convolution( 2514 g, 2515 input, 2516 weight, 2517 bias, 2518 stride, 2519 padding, 2520 dilation, 2521 False, 2522 (), 2523 groups, 2524 None, 2525 None, 2526 None, 2527 None, 2528 ) 2529 2530 2531@_onnx_symbolic("aten::conv_transpose1d") 2532@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") 2533def conv_transpose1d( 2534 g: jit_utils.GraphContext, 2535 input, 2536 weight, 2537 bias, 2538 stride, 2539 padding, 2540 output_padding, 2541 groups, 2542 dilation, 2543): 2544 return _convolution( 2545 g, 2546 input, 2547 weight, 2548 bias, 2549 stride, 2550 padding, 2551 dilation, 2552 True, 2553 output_padding, 2554 groups, 2555 None, 2556 None, 2557 None, 2558 None, 2559 ) 2560 2561 2562@_onnx_symbolic("aten::conv_transpose2d") 2563@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") 2564def conv_transpose2d( 2565 g: jit_utils.GraphContext, 2566 input, 2567 weight, 2568 bias, 2569 stride, 2570 padding, 2571 output_padding, 2572 groups, 2573 dilation, 2574): 2575 return _convolution( 2576 g, 2577 input, 2578 weight, 2579 bias, 2580 stride, 2581 padding, 2582 dilation, 2583 True, 2584 output_padding, 2585 groups, 2586 None, 2587 None, 2588 None, 2589 None, 2590 ) 2591 2592 2593@_onnx_symbolic("aten::conv_transpose3d") 2594@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") 2595def conv_transpose3d( 2596 g: jit_utils.GraphContext, 2597 input, 2598 weight, 2599 bias, 2600 stride, 2601 padding, 2602 output_padding, 2603 groups, 2604 dilation, 2605): 2606 return _convolution( 2607 g, 2608 input, 2609 weight, 2610 bias, 2611 stride, 2612 padding, 2613 dilation, 2614 True, 2615 output_padding, 2616 groups, 2617 None, 2618 None, 2619 None, 2620 None, 2621 ) 2622 2623 2624@_onnx_symbolic("aten::batch_norm") 2625@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") 2626def batch_norm( 2627 g: jit_utils.GraphContext, 2628 input, 2629 weight, 2630 bias, 2631 running_mean, 2632 running_var, 2633 training, 2634 momentum, 2635 eps, 2636 cudnn_enabled, 2637): 2638 symbolic_helper.check_training_mode(training, "batch_norm") 2639 2640 if ( 2641 torch.is_autocast_enabled() 2642 and not symbolic_helper.args_have_same_dtype( 2643 [input, weight, bias, running_mean, running_var] 2644 ) 2645 and GLOBALS.export_onnx_opset_version < 15 2646 ): 2647 return symbolic_helper._onnx_opset_unsupported_detailed( 2648 "BatchNormalization", 2649 9, 2650 15, 2651 "All input tensors must have the same `dtype`." 
2652 " Turn off Autocast or export using opset version 15.", 2653 input, 2654 ) 2655 2656 weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( 2657 g, input, weight, bias, running_mean, running_var 2658 ) 2659 out = g.op( 2660 "BatchNormalization", 2661 input, 2662 weight, 2663 bias, 2664 running_mean, 2665 running_var, 2666 epsilon_f=eps, 2667 momentum_f=1 - momentum, 2668 outputs=1 if not training else 5, 2669 ) 2670 if not training: 2671 return out 2672 else: 2673 res, new_running_mean, new_running_var, saved_mean, saved_var = out 2674 new_running_mean.setType(running_mean.type()) 2675 new_running_var.setType(running_var.type()) 2676 saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName()) 2677 saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName()) 2678 return res 2679 2680 2681@_onnx_symbolic("aten::native_layer_norm") 2682@symbolic_helper.quantized_args(True, False, False, False) 2683@symbolic_helper.parse_args("v", "is", "v", "v", "f") 2684def native_layer_norm( 2685 g: jit_utils.GraphContext, 2686 input: _C.Value, 2687 normalized_shape: Sequence[int], 2688 weight: _C.Value, 2689 bias: _C.Value, 2690 eps: float, 2691) -> tuple[_C.Value, _C.Value, _C.Value]: 2692 axes = [-i for i in range(len(normalized_shape), 0, -1)] 2693 2694 two_cst = symbolic_helper._generate_wrapped_number(g, 2.0) 2695 eps_cst = symbolic_helper._generate_wrapped_number(g, eps) 2696 2697 if g.opset < 18: 2698 mean = g.op("ReduceMean", input, axes_i=axes) 2699 else: 2700 mean = g.op( 2701 "ReduceMean", 2702 input, 2703 g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), 2704 ) 2705 2706 numerator = sub(g, input, mean) 2707 2708 # Cast it to eps dtype to avoid precision loss 2709 is_type_half = ( 2710 _type_utils.JitScalarType.from_value(numerator) 2711 == _type_utils.JitScalarType.HALF 2712 ) 2713 if is_type_half: 2714 eps_dtype = _type_utils.JitScalarType.from_value(eps_cst) 2715 numerator = g.op( 2716 "Cast", numerator, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type() 2717 ) 2718 2719 # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula 2720 if g.opset < 18: 2721 variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes) 2722 else: 2723 variance = g.op( 2724 "ReduceMean", 2725 pow(g, numerator, two_cst), 2726 g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), 2727 ) 2728 2729 denominator = sqrt(g, g.op("Add", variance, eps_cst)) 2730 normalized = g.op("Div", numerator, denominator) 2731 2732 # Cast back to input type as eps related ops are all done 2733 if is_type_half: 2734 input_dtype = _type_utils.JitScalarType.from_value(input) 2735 normalized = g.op( 2736 "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() 2737 ) 2738 2739 if not (weight is None or symbolic_helper._is_none(weight)): 2740 normalized = mul(g, normalized, weight) 2741 if not (bias is None or symbolic_helper._is_none(bias)): 2742 normalized = add(g, normalized, bias) 2743 2744 # rdenominator := 1 / sqrt(variance + eps) 2745 # According to aten::native_layer_norm, rdenominator should have the same dtype as input, 2746 # mean and normalized, so we need to Cast it back 2747 if is_type_half: 2748 denominator = g.op( 2749 "Cast", 2750 denominator, 2751 to_i=_type_utils.JitScalarType(input_dtype).onnx_type(), # type: ignore[possibly-undefined] 2752 ) 2753 rdenominator = g.op("Reciprocal", denominator) 2754 else: 2755 rdenominator = reciprocal(g, denominator) 2756 2757 return normalized, 
mean, rdenominator 2758 2759 2760@_onnx_symbolic("aten::layer_norm") 2761@symbolic_helper.quantized_args(True, False, False, False) 2762@symbolic_helper.parse_args("v", "is", "v", "v", "f", "b") 2763def layer_norm( 2764 g: jit_utils.GraphContext, 2765 input: _C.Value, 2766 normalized_shape: Sequence[int], 2767 weight: _C.Value, 2768 bias: _C.Value, 2769 eps: float, 2770 cudnn_enable: bool, 2771) -> _C.Value: 2772 normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps) 2773 return normalized 2774 2775 2776@_onnx_symbolic("aten::instance_norm") 2777@symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b") 2778def instance_norm( 2779 g: jit_utils.GraphContext, 2780 input, 2781 weight, 2782 bias, 2783 running_mean, 2784 running_var, 2785 use_input_stats: bool, 2786 momentum: Number, 2787 eps: Number, 2788 cudnn_enabled: bool, 2789): 2790 symbolic_helper.check_training_mode(use_input_stats, "instance_norm") 2791 channel_size = symbolic_helper._get_tensor_dim_size(input, 1) 2792 if weight is None or symbolic_helper._is_none(weight): 2793 if channel_size is None: 2794 raise errors.SymbolicValueError( 2795 "Unsupported: ONNX export of instance_norm for unknown channel size.", 2796 input, 2797 ) 2798 weight_value = torch.tensor( 2799 [1.0] * channel_size, 2800 dtype=_type_utils.JitScalarType.from_value(input).dtype(), 2801 ) 2802 weight = g.op("Constant", value_t=weight_value) 2803 if bias is None or symbolic_helper._is_none(bias): 2804 if channel_size is None: 2805 raise errors.SymbolicValueError( 2806 "Unsupported: ONNX export of instance_norm for unknown channel size.", 2807 input, 2808 ) 2809 bias_value = torch.tensor( 2810 [0.0] * channel_size, 2811 dtype=_type_utils.JitScalarType.from_value(input).dtype(), 2812 ) 2813 bias = g.op("Constant", value_t=bias_value) 2814 if ( 2815 running_mean is None 2816 or symbolic_helper._is_none(running_mean) 2817 or running_var is None 2818 or symbolic_helper._is_none(running_var) 2819 ): 2820 return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps) 2821 else: 2822 input_size = symbolic_helper._get_tensor_sizes(input) 2823 # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm. 
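# Illustrative (hypothetical sizes): for N=2, C=3 the input is reshaped to [1, 6, H, W] and
# weight/bias/running_mean/running_var are tiled N times (length 3 -> 6), so a single
# BatchNormalization normalizes every (n, c) slice independently; the output is then
# reshaped back to [2, 3, H, W].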
2824 # For more information instance_norm(): 2825 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542 2826 input_size_reshape = input_size.copy() 2827 n = input_size[0] 2828 if n is None: 2829 raise errors.SymbolicValueError( 2830 "Unsupported: ONNX export of instance_norm training for unknown " 2831 "batch size.", 2832 input, 2833 ) 2834 c = input_size[1] 2835 input_size_reshape[0] = 1 2836 input_size_reshape[1] = n * c 2837 weight_ = repeat( 2838 g, weight, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) 2839 ) 2840 bias_ = repeat( 2841 g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) 2842 ) 2843 running_mean_ = repeat( 2844 g, 2845 running_mean, 2846 g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), 2847 ) 2848 running_var_ = repeat( 2849 g, 2850 running_var, 2851 g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), 2852 ) 2853 input_reshaped = g.op( 2854 "Reshape", 2855 input, 2856 g.op("Constant", value_t=torch.LongTensor(input_size_reshape)), 2857 ) 2858 out = batch_norm( 2859 g, 2860 input_reshaped, 2861 weight_, 2862 bias_, 2863 running_mean_, 2864 running_var_, 2865 use_input_stats, 2866 momentum, 2867 eps, 2868 cudnn_enabled, 2869 ) 2870 return view(g, out, g.op("Constant", value_t=torch.tensor(input_size))) 2871 2872 2873@_onnx_symbolic("aten::unfold") 2874@symbolic_helper.parse_args("v", "i", "i", "i") 2875def unfold(g: jit_utils.GraphContext, input, dimension, size, step): 2876 sizes = symbolic_helper._get_tensor_sizes(input) 2877 # FIXME(justinchuby): Get rid of the try catch here to improve readability 2878 try: 2879 sizedim = sizes[dimension] 2880 except Exception: 2881 # FIXME(justinchuby): Avoid catching Exception. 2882 # Catch a more specific exception instead. 
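# Illustrative trace of the static-shape branch below (hypothetical sizes): for sizedim=10,
# size=3, step=2, zipping range(0, 10, 2) with range(3, 11, 2) yields the windows
# (0:3), (2:5), (4:7), (6:9) -- the four slices torch.unfold would produce -- which are then
# transposed, unsqueezed and concatenated along `dimension`.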
2883 sizedim = None 2884 if sizedim is not None: 2885 low_indices = range(0, sizedim, step) 2886 hi_indices = range(size, sizedim + 1, step) 2887 stack = [ 2888 symbolic_helper._slice_helper( 2889 g, input, axes=[dimension], starts=[low], ends=[hi] 2890 ) 2891 for low, hi in zip(low_indices, hi_indices) 2892 ] 2893 ndim = len(sizes) 2894 perm = list(range(0, ndim)) 2895 perm.append(perm.pop(dimension)) 2896 unsqueeze = [ 2897 symbolic_helper._unsqueeze_helper( 2898 g, g.op("Transpose", t, perm_i=perm), [dimension] 2899 ) 2900 for t in stack 2901 ] 2902 return g.op("Concat", *unsqueeze, axis_i=dimension) 2903 else: 2904 return symbolic_helper._unimplemented( 2905 "Unfold", "input size not accessible", input 2906 ) 2907 2908 2909@_onnx_symbolic("aten::elu") 2910@symbolic_helper.quantized_args(True) 2911@symbolic_helper.parse_args("v", "t", "t", "t") 2912def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale): 2913 if scale and scale != 1.0: 2914 return symbolic_helper._unimplemented( 2915 "scale", "does not support scale in Elu", scale 2916 ) 2917 if input_scale and input_scale != 1.0: 2918 return symbolic_helper._unimplemented( 2919 "input_scale", "does not support input_scale in Elu", input_scale 2920 ) 2921 # See Note [Export inplace] 2922 return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha)) 2923 2924 2925@_onnx_symbolic("aten::selu") 2926@symbolic_helper.quantized_args(True) 2927def selu(g: jit_utils.GraphContext, input): 2928 return g.op("Selu", input) 2929 2930 2931@_onnx_symbolic("aten::index_select") 2932@symbolic_helper.parse_args("v", "i", "v") 2933def index_select(g: jit_utils.GraphContext, self, dim, index): 2934 # In case of a scalar index, index_select returns a tensor with the same rank as the input. 2935 # To match this behavior in ONNX, we make index a 1D tensor so that the following gather 2936 # also produces a tensor with the same rank as the input. 
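# Illustrative (hypothetical shapes, following the comment above): for a (4, 5) input with
# dim=0 and scalar index 2, reshaping the index to a length-1 1-D tensor makes the ONNX
# Gather return shape (1, 5), preserving the input's rank.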
2937 return symbolic_helper._select_helper(g, self, dim, index) 2938 2939 2940@_onnx_symbolic("aten::index_put") 2941def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate): 2942 if symbolic_helper._is_packed_list(indices_list_value): 2943 indices_list = symbolic_helper._unpack_list(indices_list_value) 2944 else: 2945 indices_list = [indices_list_value] 2946 2947 accumulate = symbolic_helper._parse_arg(accumulate, "b") 2948 2949 if len(indices_list) == 0: 2950 if accumulate: 2951 return add(g, self, values) 2952 return values 2953 symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self) 2954 2955 2956@_onnx_symbolic("aten::index_fill") 2957def index_fill(g: jit_utils.GraphContext, self, dim, index, value): 2958 dim_value = symbolic_helper._parse_arg(dim, "i") 2959 expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( 2960 g, self, dim, index 2961 ) 2962 value = symbolic_helper._maybe_get_scalar(value) 2963 value = symbolic_helper._if_scalar_type_as(value, self) 2964 expanded_value = expand(g, value, expanded_index_shape, None) 2965 2966 return scatter(g, self, dim, expanded_index, expanded_value) 2967 2968 2969@_onnx_symbolic("aten::index_copy") 2970def index_copy(g: jit_utils.GraphContext, self, dim, index, source): 2971 dim_value = symbolic_helper._parse_arg(dim, "i") 2972 expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( 2973 g, self, dim, index 2974 ) 2975 return scatter(g, self, dim, expanded_index, source) 2976 2977 2978@_onnx_symbolic("aten::bucketize") 2979@symbolic_helper.parse_args("v", "v", "b", "b") 2980def bucketize( 2981 g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False 2982): 2983 out_type = _C_onnx.TensorProtoDataType.INT64 2984 if out_int32: 2985 out_type = _C_onnx.TensorProtoDataType.INT32 2986 # A tensor expanded_boundaries is created such that it 2987 # contains a copy of boundaries for each element of self. 2988 new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0) 2989 # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops 2990 # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md 2991 tensor_rank = symbolic_helper._get_tensor_rank(self) 2992 assert tensor_rank is not None 2993 unsqueeze_axes = list(range(1, tensor_rank + 1)) 2994 expanded_boundaries = expand( 2995 g, 2996 symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes), 2997 new_shape, 2998 None, 2999 ) 3000 # Compare each element of self to boundaries to get a tensor 3001 # with leading 1s and trailing 0s. 3002 # e.g., 4 > [1, 3, 4] = [1, 1, 0] 3003 # The index of the last 1 is the bucket where the element should go. 3004 if right: 3005 cond = ge(g, self, expanded_boundaries) 3006 else: 3007 cond = gt(g, self, expanded_boundaries) 3008 cond_out = g.op("Cast", cond, to_i=out_type) 3009 # Sum to get the number of 1s corresponding to each element, 3010 # which is the same as the bucket index. 
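# Illustrative: with boundaries [1, 3, 4] and element 4, right=False counts boundaries
# strictly below the element while right=True (using `ge` above) also counts equal ones,
# giving buckets 2 and 3 respectively, matching torch.bucketize.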
3011 # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2 3012 return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0) 3013 3014 3015@_onnx_symbolic("aten::type_as") 3016def type_as(g: jit_utils.GraphContext, self, other): 3017 self_dtype = symbolic_helper._try_get_scalar_type(self) 3018 other_dtype = symbolic_helper._try_get_scalar_type(other) 3019 if self_dtype == other_dtype and self_dtype is not None: 3020 return self 3021 if other_dtype is not None: 3022 return g.op( 3023 "Cast", 3024 self, 3025 to_i=other_dtype.onnx_type(), 3026 ) 3027 3028 raise errors.SymbolicValueError( 3029 "Unsupported: ONNX export of type_as for tensor " 3030 "of unknown dtype. Please check if the dtype of the " 3031 "parameter passed to the type_as function is correct.", 3032 other, 3033 ) 3034 3035 3036@_onnx_symbolic("aten::cosine_similarity") 3037@symbolic_helper.parse_args("v", "v", "i", "f") 3038def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps): 3039 cross = symbolic_helper._reducesum_helper( 3040 g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0 3041 ) 3042 x1_l2 = symbolic_helper._reducesum_helper( 3043 g, mul(g, x1, x1), axes_i=[dim], keepdims_i=0 3044 ) 3045 x2_l2 = symbolic_helper._reducesum_helper( 3046 g, mul(g, x2, x2), axes_i=[dim], keepdims_i=0 3047 ) 3048 div_tens = max( 3049 g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps])) 3050 ) 3051 return div(g, cross, div_tens) 3052 3053 3054@_onnx_symbolic("aten::pairwise_distance") 3055def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim): 3056 if not symbolic_helper._is_value(eps): 3057 eps = g.op("Constant", value_t=torch.tensor([eps])) 3058 inv_p = div( 3059 g, 3060 g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)), 3061 add(g, p, eps), 3062 ) 3063 summation = symbolic_helper._reducesum_helper( 3064 g, 3065 pow(g, sub(g, input1, input2), p), 3066 axes_i=[-1], 3067 keepdims_i=symbolic_helper._parse_arg(keepdim, "i"), 3068 ) 3069 return pow(g, summation, inv_p) 3070 3071 3072@_onnx_symbolic("aten::clone") 3073# ignore clone operators that are inserted by PyTorch autograd 3074def clone(g: jit_utils.GraphContext, input, unused_memory_format): 3075 return input 3076 3077 3078@_onnx_symbolic("aten::abs") 3079def abs(g: jit_utils.GraphContext, self): 3080 return g.op("Abs", self) 3081 3082 3083@_onnx_symbolic("aten::log") 3084def log(g: jit_utils.GraphContext, self): 3085 return g.op("Log", self) 3086 3087 3088@_onnx_symbolic("aten::log1p") 3089def log1p(g: jit_utils.GraphContext, self): 3090 return log(g, add(g, symbolic_helper._if_scalar_type_as(torch.ones(1), self), self)) 3091 3092 3093@_onnx_symbolic("aten::log10") 3094def log10(g: jit_utils.GraphContext, self): 3095 _ln10 = 2.30258509299404568401 3096 return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor([_ln10]))) 3097 3098 3099@_onnx_symbolic("aten::pow") 3100def pow(g: jit_utils.GraphContext, self, exponent): 3101 f_dtype = _type_utils.JitScalarType.from_value(self) 3102 if not symbolic_helper._is_fp(self): 3103 f_dtype = _type_utils.JitScalarType.FLOAT 3104 self = g.op("Cast", self, to_i=f_dtype.onnx_type()) 3105 if not symbolic_helper._is_fp(exponent): 3106 exponent = g.op( 3107 "Cast", 3108 exponent, 3109 to_i=f_dtype.onnx_type(), 3110 ) 3111 pow = g.op("Pow", self, exponent) 3112 return pow 3113 3114 3115@_onnx_symbolic("aten::clamp") 3116def clamp(g: jit_utils.GraphContext, self, min, max): 3117 # min or max may be None that we need to dispatch to 3118 # Clip separately, as 
ONNX does not have None syntax 3119 if symbolic_helper._is_none(min): 3120 return clamp_max(g, self, max) 3121 elif symbolic_helper._is_none(max): 3122 return clamp_min(g, self, min) 3123 else: 3124 if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max): 3125 return symbolic_helper._op_with_optional_float_cast( 3126 g, 3127 "Clip", 3128 self, 3129 min_f=symbolic_helper._parse_arg(min, "f"), 3130 max_f=symbolic_helper._parse_arg(max, "f"), 3131 opset_before=12, 3132 ) 3133 else: 3134 return clamp_max(g, clamp_min(g, self, min), max) 3135 3136 3137@_onnx_symbolic("aten::clamp_min") 3138@symbolic_helper.parse_args("v", "v") 3139def clamp_min(g: jit_utils.GraphContext, self, min): 3140 if symbolic_helper._is_constant(min): 3141 return symbolic_helper._op_with_optional_float_cast( 3142 g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12 3143 ) 3144 else: 3145 dtype = _type_utils.JitScalarType.from_value(self) 3146 min = g.op("Cast", min, to_i=dtype.onnx_type()) 3147 return symbolic_helper._op_with_optional_float_cast( 3148 g, "Max", self, min, opset_before=12 3149 ) 3150 3151 3152@_onnx_symbolic("aten::clamp_max") 3153@symbolic_helper.parse_args("v", "v") 3154def clamp_max(g: jit_utils.GraphContext, self, max): 3155 if symbolic_helper._is_constant(max): 3156 return symbolic_helper._op_with_optional_float_cast( 3157 g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12 3158 ) 3159 else: 3160 dtype = _type_utils.JitScalarType.from_value(self) 3161 max = g.op("Cast", max, to_i=dtype.onnx_type()) 3162 return symbolic_helper._op_with_optional_float_cast( 3163 g, "Min", self, max, opset_before=12 3164 ) 3165 3166 3167@_onnx_symbolic("aten::max") 3168# torch.max (same for torch.min) actually has two interfaces smashed together: 3169# torch.max(x, dim, keepdim) and torch.max(x, y) 3170# TODO(justinchuby): Support multiple quantized args in output 3171def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): 3172 return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) 3173 3174 3175@_onnx_symbolic("aten::maximum") 3176@symbolic_helper.quantized_args(True, True) 3177def maximum(g: jit_utils.GraphContext, input, other): 3178 return max(g, input, dim_or_y=other) 3179 3180 3181@_onnx_symbolic("aten::min") 3182# TODO(justinchuby): Support multiple quantized args in output 3183def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): 3184 return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) 3185 3186 3187@_onnx_symbolic("aten::minimum") 3188@symbolic_helper.quantized_args(True, True) 3189def minimum(g: jit_utils.GraphContext, input, other): 3190 return min(g, input, dim_or_y=other) 3191 3192 3193@_onnx_symbolic("aten::amax") 3194@symbolic_helper.quantized_args(True) 3195@symbolic_helper.parse_args("v", "is", "i") 3196def amax(g: jit_utils.GraphContext, self, dim, keepdim): 3197 return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim) 3198 3199 3200@_onnx_symbolic("aten::amin") 3201@symbolic_helper.quantized_args(True) 3202@symbolic_helper.parse_args("v", "is", "i") 3203def amin(g: jit_utils.GraphContext, self, dim, keepdim): 3204 return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim) 3205 3206 3207@_onnx_symbolic("aten::aminmax") 3208@symbolic_helper.quantized_args(True) 3209@symbolic_helper.parse_args("v", "v", "i") 3210def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): 3211 reduce_kwargs = {"keepdims_i": keepdim} 3212 if not symbolic_helper._is_none(dim): 3213 dim = 
symbolic_helper._get_const(dim, "i", "dim") 3214 reduce_kwargs["axes_i"] = [dim] 3215 3216 return g.op("ReduceMin", self, **reduce_kwargs), g.op( 3217 "ReduceMax", self, **reduce_kwargs 3218 ) 3219 3220 3221@_onnx_symbolic("aten::exp") 3222def exp(g: jit_utils.GraphContext, self): 3223 return g.op("Exp", self) 3224 3225 3226@_onnx_symbolic("aten::dropout_") 3227@_onnx_symbolic("aten::dropout") 3228@symbolic_helper.parse_args("v", "f", "i") 3229def dropout(g: jit_utils.GraphContext, input, p, train): 3230 symbolic_helper.check_training_mode(train, "dropout") 3231 # if train is False, dropout is no-op 3232 if not train: 3233 return input 3234 r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) 3235 return r 3236 3237 3238@_onnx_symbolic( 3239 "aten::alpha_dropout_", 3240 decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")], 3241) # See Note [Export inplace] 3242@_onnx_symbolic( 3243 "aten::feature_alpha_dropout_", 3244 decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")], 3245) 3246@_onnx_symbolic( 3247 "aten::feature_dropout_", 3248 decorate=[symbolic_helper._apply_params("aten::feature_dropout_")], 3249) 3250@_onnx_symbolic( 3251 "aten::feature_alpha_dropout", 3252 decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")], 3253) 3254@_onnx_symbolic( 3255 "aten::alpha_dropout", 3256 decorate=[symbolic_helper._apply_params("aten::alpha_dropout")], 3257) 3258@_onnx_symbolic( 3259 "aten::feature_dropout", 3260 decorate=[symbolic_helper._apply_params("aten::feature_dropout")], 3261) 3262def _unsupported_dropout(name: str): 3263 @symbolic_helper.parse_args("v", "none", "b") 3264 def feature_dropout(g, input, p, train): 3265 # NB: In inference mode, FeatureDropout is exported as an identity op. 3266 if train: 3267 return symbolic_helper._unimplemented(name, "training mode", input) 3268 return input 3269 3270 return feature_dropout 3271 3272 3273@_onnx_symbolic("aten::norm") 3274@symbolic_helper.parse_args("v", "t", "is", "i", "v") 3275def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None): 3276 if p == 1: 3277 f = symbolic_helper._reduce_op_symbolic_helper("ReduceL1") 3278 elif p == 2: 3279 f = symbolic_helper._reduce_op_symbolic_helper("ReduceL2") 3280 else: 3281 raise errors.SymbolicValueError( 3282 "ONNX export only p-norms with p of 1 or 2", self 3283 ) 3284 result = f(g, self, dim=dim, keepdim=keepdim) 3285 if dtype is not None: 3286 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 3287 result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) 3288 return result 3289 3290 3291@_onnx_symbolic("aten::conv_tbc") 3292@symbolic_helper.parse_args("v", "v", "v", "i") 3293def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad): 3294 # input must have 3 dimensions, see: 3295 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 3296 # input = (time, batch, in_channels) 3297 # weight = (kernel_width, in_channels, out_channels) 3298 # bias = (out_channels,) 3299 input = g.op("Transpose", input, perm_i=[1, 2, 0]) 3300 weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) 3301 conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) 3302 return g.op("Transpose", conv, perm_i=[2, 0, 1]) 3303 3304 3305@_onnx_symbolic("aten::_unique") 3306@symbolic_helper.parse_args("v", "i", "i") 3307def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse): 3308 return symbolic_helper._onnx_unsupported("_unique", input) 3309 3310 3311@_onnx_symbolic("aten::_unique2") 
3312@symbolic_helper.parse_args("v", "i", "i", "i") 3313def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts): 3314 symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input) 3315 3316 3317@_onnx_symbolic("aten::_cast_Byte") 3318@_deprecation.deprecated( 3319 "2.0", 3320 "the future", 3321 "Avoid using this function and create a Cast node instead", 3322) 3323def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking): 3324 return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8) 3325 3326 3327@_onnx_symbolic("aten::_cast_Char") 3328@_deprecation.deprecated( 3329 "2.0", 3330 "the future", 3331 "Avoid using this function and create a Cast node instead", 3332) 3333def _cast_Char(g: jit_utils.GraphContext, input, non_blocking): 3334 return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8) 3335 3336 3337@_onnx_symbolic("aten::_cast_Short") 3338@_deprecation.deprecated( 3339 "2.0", 3340 "the future", 3341 "Avoid using this function and create a Cast node instead", 3342) 3343def _cast_Short(g: jit_utils.GraphContext, input, non_blocking): 3344 return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16) 3345 3346 3347@_onnx_symbolic("aten::_cast_Int") 3348@_deprecation.deprecated( 3349 "2.0", 3350 "the future", 3351 "Avoid using this function and create a Cast node instead", 3352) 3353def _cast_Int(g: jit_utils.GraphContext, input, non_blocking): 3354 return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) 3355 3356 3357@_onnx_symbolic("aten::_cast_Long") 3358@_deprecation.deprecated( 3359 "2.0", 3360 "the future", 3361 "Avoid using this function and create a Cast node instead", 3362) 3363def _cast_Long(g: jit_utils.GraphContext, input, non_blocking): 3364 return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) 3365 3366 3367@_onnx_symbolic("aten::_cast_Half") 3368@_deprecation.deprecated( 3369 "2.0", 3370 "the future", 3371 "Avoid using this function and create a Cast node instead", 3372) 3373def _cast_Half(g: jit_utils.GraphContext, input, non_blocking): 3374 return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16) 3375 3376 3377@_onnx_symbolic("aten::_cast_Float") 3378@_deprecation.deprecated( 3379 "2.0", 3380 "the future", 3381 "Avoid using this function and create a Cast node instead", 3382) 3383def _cast_Float(g: jit_utils.GraphContext, input, non_blocking): 3384 return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) 3385 3386 3387@_onnx_symbolic("aten::_cast_Double") 3388@_deprecation.deprecated( 3389 "2.0", 3390 "the future", 3391 "Avoid using this function and create a Cast node instead", 3392) 3393def _cast_Double(g: jit_utils.GraphContext, input, non_blocking): 3394 return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE) 3395 3396 3397@_onnx_symbolic("aten::_cast_Bool") 3398@_deprecation.deprecated( 3399 "2.0", 3400 "the future", 3401 "Avoid using this function and create a Cast node instead", 3402) 3403def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking): 3404 return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL) 3405 3406 3407@_onnx_symbolic("aten::empty") 3408@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") 3409def empty( 3410 g: jit_utils.GraphContext, 3411 sizes, 3412 dtype, 3413 layout, 3414 device, 3415 pin_memory=False, 3416 memory_format=None, 3417): 3418 return zeros(g, sizes, dtype, layout, device, pin_memory) 3419 3420 3421@_onnx_symbolic("aten::empty_like") 3422@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") 
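# Note: ONNX has no notion of uninitialized memory, so aten::empty/empty_like are exported
# as zeros/zeros_like; the contents of the resulting tensor should not be relied upon.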
3423def empty_like( 3424 g: jit_utils.GraphContext, 3425 input, 3426 dtype=None, 3427 layout=None, 3428 device=None, 3429 pin_memory=False, 3430 memory_format=None, 3431): 3432 return zeros_like(g, input, dtype, layout, device, pin_memory) 3433 3434 3435@_onnx_symbolic("aten::new_empty") 3436def new_empty( 3437 g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False 3438): 3439 self_dtype = symbolic_helper._try_get_scalar_type(self) 3440 if symbolic_helper._is_none(dtype) and self_dtype is not None: 3441 dtype = self_dtype 3442 return empty(g, sizes, dtype, layout, device, pin_memory) 3443 3444 3445@_onnx_symbolic("aten::scalar_tensor") 3446def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options): 3447 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 3448 if dtype is None: 3449 dtype = _type_utils.JitScalarType.FLOAT 3450 scalar = g.op("Cast", scalar, to_i=_type_utils.JitScalarType(dtype).onnx_type()) 3451 return scalar 3452 3453 3454@_onnx_symbolic("aten::tensor") 3455def tensor( 3456 g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False 3457): 3458 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 3459 if symbolic_helper._is_packed_list(data): 3460 if dtype is None: 3461 dtype = _type_utils.JitScalarType.from_value( 3462 symbolic_helper._unpack_list(data)[0] 3463 ) 3464 input_list = [] 3465 for t in symbolic_helper._unpack_list(data): 3466 shape_reference = g.op("Constant", value_t=torch.LongTensor([1])) 3467 t = symbolic_helper._reshape_helper(g, t, shape_reference) 3468 t = g.op("Cast", t, to_i=_type_utils.JitScalarType(dtype).onnx_type()) 3469 input_list.append(t) 3470 return g.op("Concat", *input_list, axis_i=0) 3471 else: 3472 if dtype is None: 3473 dtype = _type_utils.JitScalarType.from_value(data) 3474 if symbolic_helper._is_list(data) and ( 3475 symbolic_helper._is_tensor_list(data) 3476 or symbolic_helper._is_scalar_list(data) 3477 ): 3478 data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1) 3479 return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type()) 3480 3481 3482@_onnx_symbolic("aten::as_tensor") 3483def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None): 3484 return tensor(g, data, dtype, device) 3485 3486 3487@_onnx_symbolic("aten::zeros") 3488@symbolic_helper.parse_args("v", "i", "v", "v", "v") 3489def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): 3490 # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it 3491 if dtype is None: 3492 scalar_type = _type_utils.JitScalarType.FLOAT 3493 else: 3494 scalar_type = _type_utils.JitScalarType(dtype) 3495 sizes_ = symbolic_helper._maybe_get_const(sizes, "is") 3496 if isinstance(sizes_, list) and len(sizes_) == 0: 3497 sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) 3498 return g.op( 3499 "ConstantOfShape", 3500 sizes, 3501 value_t=torch.tensor([0], dtype=scalar_type.dtype()), 3502 ) 3503 3504 3505@_onnx_symbolic("aten::zeros_like") 3506@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") 3507def zeros_like( 3508 g: jit_utils.GraphContext, 3509 input, 3510 dtype=None, 3511 layout=None, 3512 device=None, 3513 pin_memory=False, 3514 memory_format=None, 3515): 3516 shape = g.op("Shape", input) 3517 if symbolic_helper._is_none(dtype): 3518 scalar_type = _type_utils.JitScalarType.from_value( 3519 input, _type_utils.JitScalarType.FLOAT 3520 ) 3521 else: 3522 scalar_type = _type_utils.JitScalarType(dtype) 3523 return g.op( 3524 
"ConstantOfShape", 3525 shape, 3526 value_t=torch.tensor([0], dtype=scalar_type.dtype()), 3527 ) 3528 3529 3530@_onnx_symbolic("aten::new_zeros") 3531def new_zeros( 3532 g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False 3533): 3534 self_dtype = symbolic_helper._try_get_scalar_type(self) 3535 3536 if symbolic_helper._is_none(dtype) and self_dtype is not None: 3537 dtype = self_dtype 3538 return zeros(g, sizes, dtype, layout, device, pin_memory) 3539 3540 3541@_onnx_symbolic("aten::zero") 3542def zero(g: jit_utils.GraphContext, self): 3543 self_dtype = symbolic_helper._try_get_scalar_type(self) 3544 return zeros_like(g, self, self_dtype) 3545 3546 3547@_onnx_symbolic("aten::ones") 3548@symbolic_helper.parse_args("v", "i", "v", "v", "v") 3549def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): 3550 if dtype is None: 3551 scalar_type = _type_utils.JitScalarType.FLOAT 3552 else: 3553 scalar_type = _type_utils.JitScalarType(dtype) 3554 sizes_ = symbolic_helper._maybe_get_const(sizes, "is") 3555 if isinstance(sizes_, list) and len(sizes_) == 0: 3556 sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) 3557 return g.op( 3558 "ConstantOfShape", 3559 sizes, 3560 value_t=torch.tensor([1], dtype=scalar_type.dtype()), 3561 ) 3562 3563 3564@_onnx_symbolic("aten::ones_like") 3565@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") 3566def ones_like( 3567 g: jit_utils.GraphContext, 3568 input, 3569 dtype=None, 3570 layout=None, 3571 device=None, 3572 pin_memory=False, 3573 memory_format=None, 3574): 3575 shape = g.op("Shape", input) 3576 if symbolic_helper._is_none(dtype): 3577 scalar_type = _type_utils.JitScalarType.from_value( 3578 input, _type_utils.JitScalarType.FLOAT 3579 ) 3580 else: 3581 scalar_type = _type_utils.JitScalarType(dtype) 3582 return g.op( 3583 "ConstantOfShape", 3584 shape, 3585 value_t=torch.tensor([1], dtype=scalar_type.dtype()), 3586 ) 3587 3588 3589@_onnx_symbolic("aten::new_ones") 3590def new_ones( 3591 g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False 3592): 3593 self_dtype = symbolic_helper._try_get_scalar_type(self) 3594 if symbolic_helper._is_none(dtype) and self_dtype is not None: 3595 dtype = self_dtype 3596 return ones(g, sizes, dtype, layout, device, pin_memory) 3597 3598 3599@_onnx_symbolic("aten::full") 3600def full( 3601 g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False 3602): 3603 const_value = symbolic_helper._maybe_get_const(value, "t") 3604 if symbolic_helper._is_value(const_value): 3605 dtype = _type_utils.JitScalarType.FLOAT if dtype is None else dtype 3606 tmp = zeros(g, sizes, dtype, layout, device) 3607 return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) 3608 else: 3609 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 3610 if dtype is None: 3611 scalar_type = _type_utils.JitScalarType.FLOAT 3612 else: 3613 scalar_type = _type_utils.JitScalarType(dtype) 3614 sizes_ = symbolic_helper._maybe_get_const(sizes, "is") 3615 if isinstance(sizes_, list) and len(sizes_) == 0: 3616 sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) 3617 return g.op( 3618 "ConstantOfShape", 3619 sizes, 3620 value_t=const_value.view(1).to(scalar_type.dtype()), 3621 ) 3622 3623 3624@_onnx_symbolic("aten::full_like") 3625def full_like( 3626 g: jit_utils.GraphContext, 3627 input, 3628 fill_value, 3629 dtype=None, 3630 layout=None, 3631 device=None, 3632 pin_memory=False, 3633 memory_format=None, 3634): 3635 
fill_value = symbolic_helper._maybe_get_const(fill_value, "f") 3636 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 3637 if dtype is None: 3638 scalar_type = _type_utils.JitScalarType.from_value( 3639 input, _type_utils.JitScalarType.FLOAT 3640 ) 3641 else: 3642 scalar_type = _type_utils.JitScalarType(dtype) 3643 if symbolic_helper._is_value(fill_value): 3644 tmp = zeros_like(g, input, dtype, layout, device) 3645 fill_value = g.op("Cast", fill_value, to_i=scalar_type.onnx_type()) 3646 return add(g, tmp, fill_value, g.op("Constant", value_t=torch.tensor(1))) 3647 else: 3648 shape = g.op("Shape", input) 3649 return g.op( 3650 "ConstantOfShape", 3651 shape, 3652 value_t=torch.tensor([fill_value], dtype=scalar_type.dtype()), 3653 ) 3654 3655 3656@_onnx_symbolic("aten::new_full") 3657def new_full( 3658 g: jit_utils.GraphContext, 3659 self, 3660 size, 3661 fill_value, 3662 dtype, 3663 layout, 3664 device, 3665 pin_memory=False, 3666): 3667 self_dtype = symbolic_helper._try_get_scalar_type(self) 3668 if symbolic_helper._is_none(dtype) and self_dtype is not None: 3669 dtype = self_dtype 3670 return full(g, size, fill_value, dtype, layout, device, pin_memory) 3671 3672 3673@_onnx_symbolic("aten::eye") 3674def eye(g: jit_utils.GraphContext, *args): 3675 if len(args) == 5: 3676 # aten::eye(n, dtype, layout, device, pin_memory) 3677 n, dtype, layout, device, pin_memory = args 3678 dim_size = symbolic_helper._unsqueeze_helper(g, n, [0]) 3679 shape = g.op("Concat", dim_size, dim_size, axis_i=0) 3680 tensor = zeros(g, shape, dtype, layout, device) 3681 return g.op("EyeLike", tensor) 3682 if len(args) == 6: 3683 # aten::eye(n, m, dtype, layout, device, pin_memory) 3684 n, m, dtype, layout, device, pin_memory = args 3685 shape = g.op( 3686 "Concat", 3687 symbolic_helper._unsqueeze_helper(g, n, [0]), 3688 symbolic_helper._unsqueeze_helper(g, m, [0]), 3689 axis_i=0, 3690 ) 3691 tensor = zeros(g, shape, dtype, layout, device) 3692 return g.op("EyeLike", tensor) 3693 3694 return symbolic_helper._unimplemented("aten::eye", f"with {len(args)} arguments") 3695 3696 3697@_onnx_symbolic("aten::slice") 3698def slice(g: jit_utils.GraphContext, self, *args): 3699 if len(args) == 4: 3700 # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor 3701 dim, start, end, step = args 3702 step = symbolic_helper._parse_arg(step, "i") 3703 if step != 1: 3704 raise errors.SymbolicValueError("step!=1 is currently not supported", self) 3705 is_start_none = start.node().kind() == "prim::Constant" and isinstance( 3706 start.type(), _C.NoneType 3707 ) 3708 is_end_none = end.node().kind() == "prim::Constant" and isinstance( 3709 end.type(), _C.NoneType 3710 ) 3711 is_start_onnx_const = start.node().kind() == "onnx::Constant" 3712 is_end_onnx_const = end.node().kind() == "onnx::Constant" 3713 if ( 3714 ((not is_start_none) and (not is_start_onnx_const)) 3715 or ((not is_end_none) and (not is_end_onnx_const)) 3716 or dim.node().kind() != "onnx::Constant" 3717 ): 3718 if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: 3719 raise errors.SymbolicValueError( 3720 "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice " 3721 "is a deprecated experimental op. 
Please use statically allocated " 3722 "variables or export to a higher opset version.", 3723 self, 3724 ) 3725 else: 3726 start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0]) 3727 end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0]) 3728 dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0]) 3729 return g.op( 3730 "DynamicSlice", 3731 self, 3732 start_unsqueezed, 3733 end_unsqueezed, 3734 dim_unsqueezed, 3735 ) 3736 else: 3737 start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") 3738 end = ( 3739 _constants.INT64_MAX 3740 if is_end_none 3741 else symbolic_helper._parse_arg(end, "i") 3742 ) 3743 dim = symbolic_helper._parse_arg(dim, "i") 3744 return symbolic_helper._slice_helper( 3745 g, self, axes=[dim], starts=[start], ends=[end] 3746 ) 3747 elif len(args) == 3: 3748 # aten::slice(t[] l, int start, int end, int step) -> t[] 3749 start, end, step = args 3750 dim = 0 3751 is_start_none = start.node().kind() == "prim::Constant" and isinstance( 3752 start.type(), _C.NoneType 3753 ) 3754 is_end_none = end.node().kind() == "prim::Constant" and isinstance( 3755 end.type(), _C.NoneType 3756 ) 3757 start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") 3758 end = ( 3759 _constants.INT64_MAX 3760 if is_end_none 3761 else symbolic_helper._parse_arg(end, "i") 3762 ) 3763 return symbolic_helper._slice_helper( 3764 g, self, axes=[dim], starts=[start], ends=[end] 3765 ) 3766 3767 return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments") 3768 3769 3770@_onnx_symbolic("aten::hardtanh") 3771@symbolic_helper.quantized_args(True) 3772@symbolic_helper.parse_args("v", "f", "f") 3773def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): 3774 return symbolic_helper._op_with_optional_float_cast( 3775 g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12 3776 ) 3777 3778 3779@_onnx_symbolic("aten::hardswish") 3780@symbolic_helper.quantized_args(True) 3781@symbolic_helper.parse_args("v") 3782def hardswish(g: jit_utils.GraphContext, self): 3783 hs = hardsigmoid(g, self) 3784 return g.op("Mul", self, hs) 3785 3786 3787@_onnx_symbolic("aten::hardsigmoid") 3788# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp 3789@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) 3790@symbolic_helper.parse_args("v") 3791def hardsigmoid(g: jit_utils.GraphContext, self): 3792 # Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid. 
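    # ONNX HardSigmoid computes max(0, min(1, alpha * x + beta)) with beta
    # defaulting to 0.5, while PyTorch's Hardsigmoid is clamp(x / 6 + 0.5, 0, 1),
    # so alpha = 1/6 with the default beta reproduces it exactly.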
    # See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
    return g.op("HardSigmoid", self, alpha_f=1 / 6)


@_onnx_symbolic("aten::tanhshrink")
@symbolic_helper.parse_args("v")
def tanhshrink(g: jit_utils.GraphContext, self):
    return g.op("Sub", self, tanh(g, self))


@_onnx_symbolic("aten::hardshrink")
@symbolic_helper.parse_args("v", "f")
def hardshrink(g: jit_utils.GraphContext, self, lambd):
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    )
    lambd_op = g.op(
        "Constant",
        value_t=torch.tensor(lambd, dtype=scalar_type.dtype()),
    )
    cond = logical_or(g, gt(g, self, lambd_op), lt(g, self, neg(g, lambd_op)))
    return g.op(
        "Where",
        cond,
        self,
        g.op(
            "Constant",
            value_t=torch.tensor(0, dtype=scalar_type.dtype()),
        ),
    )


@_onnx_symbolic("aten::softshrink")
@symbolic_helper.parse_args("v", "f")
def softshrink(g: jit_utils.GraphContext, self, lambd):
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    )
    lambd_op = g.op(
        "Constant",
        value_t=torch.tensor(lambd, dtype=scalar_type.dtype()),
    )
    gt_cond = gt(g, self, lambd_op)
    gt_out = g.op(
        "Where",
        gt_cond,
        sub(g, self, lambd_op),
        g.op(
            "Constant",
            value_t=torch.tensor(0, dtype=scalar_type.dtype()),
        ),
    )
    lt_cond = lt(g, self, neg(g, lambd_op))
    lt_out = g.op(
        "Where",
        lt_cond,
        add(g, self, lambd_op),
        g.op(
            "Constant",
            value_t=torch.tensor(0, dtype=scalar_type.dtype()),
        ),
    )
    return add(g, gt_out, lt_out)


@_onnx_symbolic("aten::alias")
def alias(g: jit_utils.GraphContext, self):
    return self


@_onnx_symbolic("aten::unsqueeze")
@symbolic_helper.parse_args("v", "i")
def unsqueeze(g: jit_utils.GraphContext, self, dim):
    """Implement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`."""
    # Handle negative dim
    if dim < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            warnings.warn(
                "ONNX export unsqueeze with negative axis "
                + str(dim)
                + " might cause the onnx model to be incorrect. "
                + "Negative axis is not supported in ONNX. "
                + "Axis is converted to "
                + str(dim + rank + 1)
                + " based on input shape at export time. "
                + "Passing a tensor of a different rank at execution time will produce incorrect results."
            )
            dim = dim + rank + 1
        else:
            return symbolic_helper._unimplemented(
                "unsqueeze", "negative axis with unknown input rank", self
            )

    return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim])


@_onnx_symbolic("aten::sort")
# TODO(justinchuby): Support multiple quantized args in output
@symbolic_helper.parse_args("v", "i", "i", "none")
def sort(g: jit_utils.GraphContext, self, dim, descending, out=None):
    if out is not None:
        symbolic_helper._unimplemented(
            "Sort", "Out parameter is not supported for sort", self
        )
    self_sizes = symbolic_helper._get_tensor_sizes(self)
    try:
        dim_size = self_sizes[dim]
    except Exception:
        # FIXME(justinchuby): Avoid catching Exception.
        # Catch a more specific exception instead.
        dim_size = None

    if dim_size is None:
        return symbolic_helper._unimplemented("Sort", "input size not accessible", self)

    return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2)


@_onnx_symbolic("aten::numel")
def numel(g: jit_utils.GraphContext, self):
    return symbolic_helper._numel_helper(g, self)


@_onnx_symbolic("aten::topk")
# TODO(justinchuby): Support multiple quantized args in output
@symbolic_helper.parse_args("v", "i", "i", "i", "i", "none")
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    if out is not None:
        symbolic_helper._unimplemented(
            "TopK", "Out parameter is not supported for topk", self
        )
    if not largest:
        symbolic_helper._unimplemented("TopK", "Ascending TopK is not supported", self)

    return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)


@_onnx_symbolic("prim::convert_element_type")
def convert_element_type(g: jit_utils.GraphContext, self, *args):
    dtype = symbolic_helper._get_const(args[0], "i", "dtype")
    return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())


@_onnx_symbolic("aten::to")
def to(g: jit_utils.GraphContext, self, *args):
    def is_aten_to_device_only(args):
        if len(args) == 4:
            # aten::to(Tensor, Device, bool, bool, memory_format)
            return (
                args[0].node().kind() == "prim::device"
                or args[0].type().isSubtypeOf(_C.ListType.ofInts())
                or isinstance(args[0].type(), _C.DeviceObjType)
            )
        elif len(args) == 5:
            # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
            # When dtype is None, this is an aten::to(device) call
            dtype = symbolic_helper._get_const(args[1], "i", "dtype")
            return dtype is None
        elif len(args) in (6, 7):
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
            # When dtype is None, this is an aten::to(device) call
            dtype = symbolic_helper._get_const(args[0], "i", "dtype")
            return dtype is None
        return False

    # ONNX doesn't have a concept of a device, so we ignore device-only casts
    if is_aten_to_device_only(args):
        return self

    if len(args) == 4:
        # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=<Tensor>]()
        # In this case, the constant value is a tensor, not an int,
        # so symbolic_helper._maybe_get_const(args[0], 'i') would not work.
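        # Unpack the dtype by hand: a 0-dim constant tensor holds the dtype enum
        # and is reduced via item(), while a non-scalar constant is kept as a
        # tensor so its scalar type can be read from the value below.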
3968 dtype = args[0] 3969 if ( 3970 symbolic_helper._is_value(args[0]) 3971 and args[0].node().kind() == "onnx::Constant" 3972 ): 3973 tval = symbolic_helper._node_get(args[0].node(), "value") 3974 if isinstance(tval, torch.Tensor): 3975 if len(tval.shape) == 0: 3976 tval = tval.item() 3977 dtype = int(tval) 3978 else: 3979 dtype = tval 3980 3981 if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor): 3982 # aten::to(Tensor, Tensor, bool, bool, memory_format) 3983 dtype = _type_utils.JitScalarType.from_value(args[0]) 3984 return g.op( 3985 "Cast", 3986 self, 3987 to_i=dtype.onnx_type(), 3988 ) 3989 else: 3990 # aten::to(Tensor, ScalarType, bool, bool, memory_format) 3991 # memory_format is ignored 3992 return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) 3993 elif len(args) == 5: 3994 # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) 3995 dtype = symbolic_helper._get_const(args[1], "i", "dtype") 3996 # memory_format is ignored 3997 return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) 3998 elif len(args) == 6: 3999 # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor 4000 dtype = symbolic_helper._get_const(args[0], "i", "dtype") 4001 # Layout, device and memory_format are ignored 4002 return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) 4003 elif len(args) == 7: 4004 # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor 4005 dtype = symbolic_helper._get_const(args[0], "i", "dtype") 4006 # Layout, device and memory_format are ignored 4007 return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) 4008 4009 return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self) 4010 4011 4012@_onnx_symbolic("aten::repeat") 4013def repeat(g: jit_utils.GraphContext, self, repeats): 4014 dtype = _type_utils.JitScalarType.INT64 4015 shape_ = ones_like(g, repeats, dtype) 4016 self = g.op("Expand", self, shape_) 4017 return g.op("Tile", self, repeats) 4018 4019 4020@_onnx_symbolic("aten::repeat_interleave") 4021def repeat_interleave( 4022 g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None 4023): 4024 repeats_dim = symbolic_helper._get_tensor_rank(repeats) 4025 repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) 4026 input_sizes = symbolic_helper._get_tensor_sizes(self) 4027 if repeats_dim is None: 4028 raise errors.SymbolicValueError( 4029 "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", 4030 self, 4031 ) 4032 if repeats_sizes is None: 4033 raise errors.SymbolicValueError( 4034 "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", 4035 self, 4036 ) 4037 if input_sizes is None: 4038 raise errors.SymbolicValueError( 4039 "Unsupported: ONNX export of repeat_interleave for unknown input size.", 4040 self, 4041 ) 4042 4043 # if dim is None flatten 4044 # By default, use the flattened input array, and return a flat output array 4045 if symbolic_helper._is_none(dim): 4046 self = symbolic_helper._reshape_helper( 4047 g, self, g.op("Constant", value_t=torch.tensor([-1])) 4048 ) 4049 dim = torch.tensor(0, dtype=torch.int64) 4050 else: 4051 dim = symbolic_helper._maybe_get_scalar(dim) 4052 4053 # Handle cases where dim is negative 4054 if dim < 0: 4055 dim += len(input_sizes) 4056 4057 input_sizes_temp = input_sizes.copy() 4058 for idx, input_size in enumerate(input_sizes): 4059 if input_size is None: 4060 input_sizes[idx], input_sizes_temp[idx] = 
0, -1 4061 4062 # Cases where repeats is an int or single value tensor 4063 if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): 4064 if input_sizes[dim] == 0: 4065 return symbolic_helper._onnx_opset_unsupported_detailed( 4066 "repeat_interleave", 4067 9, 4068 13, 4069 "Unsupported along dimension with unknown input size", 4070 self, 4071 ) 4072 return symbolic_helper._repeat_interleave_single_value_repeat_helper( 4073 g, self, repeats, dim 4074 ) 4075 4076 # Cases where repeats is a 1 dim Tensor 4077 elif repeats_dim == 1: 4078 if input_sizes[dim] == 0: 4079 return symbolic_helper._onnx_opset_unsupported_detailed( 4080 "repeat_interleave", 4081 9, 4082 13, 4083 "Unsupported along dimension with unknown input size", 4084 self, 4085 ) 4086 if repeats_sizes[0] is None: 4087 return symbolic_helper._onnx_opset_unsupported_detailed( 4088 "repeat_interleave", 4089 9, 4090 13, 4091 "Unsupported for cases with dynamic repeats", 4092 self, 4093 ) 4094 assert ( 4095 repeats_sizes[0] == input_sizes[dim] 4096 ), "repeats must have the same size as input along dim" 4097 reps = repeats_sizes[0] 4098 else: 4099 raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self) 4100 4101 final_splits = [] 4102 r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0) 4103 i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim) 4104 input_sizes[dim], input_sizes_temp[dim] = -1, 1 4105 for idx, r_split in enumerate(r_splits): 4106 i_split = unsqueeze(g, i_splits[idx], dim + 1) 4107 r_concat = [ 4108 g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])), 4109 r_split, 4110 g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])), 4111 ] 4112 r_concat = g.op("Concat", *r_concat, axis_i=0) 4113 i_split = expand(g, i_split, r_concat, None) 4114 i_split = symbolic_helper._reshape_helper( 4115 g, 4116 i_split, 4117 g.op("Constant", value_t=torch.LongTensor(input_sizes)), 4118 allowzero=0, 4119 ) 4120 final_splits.append(i_split) 4121 return g.op("Concat", *final_splits, axis_i=dim) 4122 4123 4124@_onnx_symbolic("aten::pixel_shuffle") 4125@symbolic_helper.parse_args("v", "i") 4126def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): 4127 dims = symbolic_helper._get_tensor_sizes(self) 4128 if len(dims) != 4: 4129 return symbolic_helper._unimplemented( 4130 "pixel_shuffle", "only support 4d input", self 4131 ) 4132 if any(i is None for i in dims[1:]): 4133 after_view = symbolic_helper._reshape_helper( 4134 g, 4135 symbolic_helper._unsqueeze_helper(g, self, [2, 3]), 4136 g.op( 4137 "Constant", 4138 value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]), 4139 ), 4140 allowzero=0, 4141 ) 4142 after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) 4143 # For dynamic input shapes, two reshapes are performed 4144 reshape_h = symbolic_helper._reshape_helper( 4145 g, 4146 after_transpose, 4147 g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])), 4148 allowzero=0, 4149 ) 4150 reshape_w = symbolic_helper._reshape_helper( 4151 g, 4152 reshape_h, 4153 g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])), 4154 allowzero=0, 4155 ) 4156 return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5]) 4157 else: 4158 output_channel = dims[1] // upscale_factor // upscale_factor 4159 after_view = symbolic_helper._reshape_helper( 4160 g, 4161 self, 4162 g.op( 4163 "Constant", 4164 value_t=torch.tensor( 4165 [ 4166 -1, 4167 output_channel, 4168 upscale_factor, 4169 
upscale_factor, 4170 dims[2], 4171 dims[3], 4172 ] 4173 ), 4174 ), 4175 allowzero=0, 4176 ) 4177 after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) 4178 return symbolic_helper._reshape_helper( 4179 g, 4180 after_transpose, 4181 g.op( 4182 "Constant", 4183 value_t=torch.tensor( 4184 [ 4185 -1, 4186 output_channel, 4187 dims[2] * upscale_factor, 4188 dims[3] * upscale_factor, 4189 ] 4190 ), 4191 ), 4192 allowzero=0, 4193 ) 4194 4195 4196@_onnx_symbolic("aten::pixel_unshuffle") 4197@symbolic_helper.parse_args("v", "i") 4198def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor): 4199 dims = symbolic_helper._get_tensor_sizes(self) 4200 if len(dims) != 4: 4201 return symbolic_helper._unimplemented( 4202 "pixel_shuffle", "only support 4d input", self 4203 ) 4204 if any(i is None for i in dims[1:]): 4205 # For dynamic input shapes, two reshapes are performed 4206 reshape_h = symbolic_helper._reshape_helper( 4207 g, 4208 symbolic_helper._unsqueeze_helper(g, self, [3]), 4209 g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])), 4210 allowzero=0, 4211 ) 4212 reshape_w = symbolic_helper._reshape_helper( 4213 g, 4214 reshape_h, 4215 g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])), 4216 allowzero=0, 4217 ) 4218 after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4]) 4219 final_reshape = symbolic_helper._reshape_helper( 4220 g, 4221 after_transpose, 4222 g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])), 4223 allowzero=0, 4224 ) 4225 return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3]) 4226 else: 4227 output_channel = dims[1] * downscale_factor * downscale_factor 4228 after_view = symbolic_helper._reshape_helper( 4229 g, 4230 self, 4231 g.op( 4232 "Constant", 4233 value_t=torch.tensor( 4234 [ 4235 -1, 4236 dims[1], 4237 dims[2] // downscale_factor, 4238 downscale_factor, 4239 dims[3] // downscale_factor, 4240 downscale_factor, 4241 ] 4242 ), 4243 ), 4244 allowzero=0, 4245 ) 4246 after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4]) 4247 return symbolic_helper._reshape_helper( 4248 g, 4249 after_transpose, 4250 g.op( 4251 "Constant", 4252 value_t=torch.tensor( 4253 [ 4254 -1, 4255 output_channel, 4256 dims[2] // downscale_factor, 4257 dims[3] // downscale_factor, 4258 ] 4259 ), 4260 ), 4261 allowzero=0, 4262 ) 4263 4264 4265def _generic_rnn( 4266 g: jit_utils.GraphContext, 4267 variant, 4268 input, 4269 initial_states, 4270 all_weights, 4271 has_biases, 4272 num_layers, 4273 dropout, 4274 train, 4275 bidirectional, 4276 batch_first=None, 4277 batch_sizes=None, 4278): 4279 warnings.warn( 4280 "Exporting a model to ONNX with a batch_size other than 1, " 4281 + "with a variable length with " 4282 + variant 4283 + " can cause an error " 4284 + "when running the ONNX model with a different batch size. " 4285 + "Make sure to save the model with a batch size of 1, " 4286 + "or define the initial states (h0/c0) as inputs of the model. 
" 4287 ) 4288 4289 onnxActivations = [ 4290 "Relu", 4291 "Tanh", 4292 "Sigmoid", 4293 "Affine", 4294 "LeakyRelu", 4295 "ThresholdedRelu", 4296 "ScaledTanh", 4297 "HardSigmoid", 4298 "Elu", 4299 "Softsign", 4300 "Softplus", 4301 ] 4302 variantToOnnxActivationMap = dict( 4303 zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations) 4304 ) 4305 weights_per_layer = 4 if has_biases else 2 4306 # this means that projections are used inside LSTM, so need to tell user that it's not supported 4307 if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * ( 4308 1 + bidirectional 4309 ): 4310 return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input) 4311 assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional) 4312 layer_weights = [ 4313 all_weights[i : i + weights_per_layer] 4314 for i in range(0, len(all_weights), weights_per_layer) 4315 ] 4316 if batch_first: 4317 # batch, seq, feat -> seq, batch, feat 4318 input = g.op("Transpose", input, perm_i=[1, 0, 2]) 4319 if dropout and train: 4320 return symbolic_helper._unimplemented( 4321 "RNN/GRU/LSTM", "dropout in training mode", input 4322 ) 4323 4324 if variant.startswith("RNN"): 4325 nonlinearity = variantToOnnxActivationMap[variant[4:].lower()] 4326 variant = "RNN" 4327 4328 w_hh = all_weights[1] 4329 hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1) 4330 if hidden_size is None: 4331 return symbolic_helper._unimplemented( 4332 "RNN/GRU/LSTM", "unknown hidden size", input 4333 ) 4334 4335 unidirectional = not bidirectional 4336 4337 prev_output = input 4338 4339 h_outs = [] 4340 if variant == "RNN" or variant == "GRU": 4341 h0 = initial_states 4342 elif variant == "LSTM": 4343 h0, c0 = initial_states 4344 c_outs = [] 4345 4346 sequence_lens = unused(g) if batch_sizes is None else batch_sizes 4347 4348 if variant == "GRU": 4349 # pytorch is reset, input, hidden 4350 # onnx is input, reset, hidden 4351 reform_permutation = [(1, 2), (0, 1), (2, 3)] 4352 elif variant == "LSTM": 4353 # pytorch is input, forget, cell, output. 4354 # onnx is input, output, forget, cell. 
4355 reform_permutation = [(0, 1), (3, 4), (1, 3)] 4356 4357 def reform_weights(g, w, n, intervals): 4358 slices = [ 4359 symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n]) 4360 for x, y in intervals 4361 ] 4362 return g.op("Concat", *slices, axis_i=0) 4363 4364 def transform_weights_no_bias(layer_index): 4365 weights = layer_weights[layer_index] 4366 if variant == "RNN": 4367 weight_ih, weight_hh = weights 4368 elif variant == "GRU" or variant == "LSTM": 4369 weight_ih, weight_hh = ( 4370 reform_weights(g, w, hidden_size, reform_permutation) for w in weights 4371 ) 4372 return tuple( 4373 symbolic_helper._unsqueeze_helper(g, x, [0]) 4374 for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined] 4375 ) 4376 4377 def transform_weights(layer_index): 4378 weights = layer_weights[layer_index] 4379 if variant == "RNN": 4380 weight_ih, weight_hh, bias_ih, bias_hh = weights 4381 elif variant == "GRU" or variant == "LSTM": 4382 weight_ih, weight_hh, bias_ih, bias_hh = ( 4383 reform_weights(g, w, hidden_size, reform_permutation) for w in weights 4384 ) 4385 bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0) # type: ignore[possibly-undefined] 4386 return tuple( 4387 symbolic_helper._unsqueeze_helper(g, x, [0]) 4388 for x in (weight_ih, weight_hh, bias_concat) # type: ignore[possibly-undefined] 4389 ) 4390 4391 def retrieve_state(x, start, end): 4392 return ( 4393 x 4394 if num_layers == 1 4395 else symbolic_helper._slice_helper( 4396 g, x, axes=[0], starts=[start], ends=[end] 4397 ) 4398 ) 4399 4400 for i in range(num_layers): 4401 if unidirectional: 4402 if weights_per_layer == 4: 4403 weight_ih, weight_hh, bias_concat = transform_weights(i) 4404 else: 4405 weight_ih, weight_hh = transform_weights_no_bias(i) 4406 bias_concat = unused(g) 4407 4408 state_indices = i, i + 1 4409 else: 4410 if weights_per_layer == 4: 4411 weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i) 4412 weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1) 4413 bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0) 4414 else: 4415 weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i) 4416 weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1) 4417 bias_concat = unused(g) 4418 4419 weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0) 4420 weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0) 4421 4422 state_indices = 2 * i, 2 * i + 2 4423 4424 inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens] 4425 4426 inputs.append(retrieve_state(h0, *state_indices)) # type: ignore[possibly-undefined] 4427 if variant == "LSTM": 4428 inputs.append(retrieve_state(c0, *state_indices)) # type: ignore[possibly-undefined] 4429 4430 extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"} 4431 if variant == "RNN": 4432 if bidirectional: 4433 activation = [nonlinearity, nonlinearity] # type: ignore[possibly-undefined] 4434 else: 4435 activation = [nonlinearity] # type: ignore[possibly-undefined] 4436 4437 prev_output, h_out = g.op( 4438 "RNN", 4439 *inputs, 4440 outputs=2, 4441 hidden_size_i=hidden_size, 4442 activations_s=activation, 4443 **extra_kwargs, 4444 ) 4445 elif variant == "GRU": 4446 prev_output, h_out = g.op( 4447 "GRU", 4448 *inputs, 4449 outputs=2, 4450 hidden_size_i=hidden_size, 4451 linear_before_reset_i=1, 4452 **extra_kwargs, 4453 ) 4454 elif variant == "LSTM": 4455 prev_output, h_out, c_out = g.op( 4456 "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs 4457 ) 4458 4459 if 
bidirectional: 4460 # The ONNX RNN/GRU/LSTM produce an output of dimensions 4461 # seq_len, num_directions, batch, hidden_size 4462 # We have to convert to match pytorch's expected 4463 # seq_len, batch, num_directions * hidden_size 4464 # by first moving num_directions before hidden_size with 4465 # Transpose, and then combining it with hidden_size 4466 # with Reshape. 4467 prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3]) 4468 prev_output = symbolic_helper._reshape_helper( 4469 g, 4470 prev_output, 4471 g.op("Constant", value_t=torch.LongTensor([0, 0, -1])), 4472 allowzero=0, 4473 ) 4474 else: 4475 prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1]) 4476 4477 h_outs.append(h_out) # type: ignore[possibly-undefined] 4478 if variant == "LSTM": 4479 c_outs.append(c_out) # type: ignore[possibly-undefined] 4480 if batch_first: 4481 # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size 4482 prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2]) 4483 h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0) # type: ignore[possibly-undefined] 4484 if variant == "RNN" or variant == "GRU": 4485 return prev_output, h_outs 4486 elif variant == "LSTM": 4487 c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0) # type: ignore[possibly-undefined] 4488 return prev_output, h_outs, c_outs 4489 4490 4491@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") 4492def _lstm_full( 4493 g: jit_utils.GraphContext, 4494 input, 4495 hidden_v, 4496 weight_v, 4497 has_biases, 4498 num_layers, 4499 dropout, 4500 train, 4501 bidirectional, 4502 batch_first, 4503): 4504 hidden, weight = ( 4505 symbolic_helper._unpack_list(hidden_v), 4506 symbolic_helper._unpack_list(weight_v), 4507 ) 4508 return _generic_rnn( 4509 g, 4510 "LSTM", 4511 input, 4512 hidden, 4513 weight, 4514 has_biases, 4515 num_layers, 4516 dropout, 4517 train, 4518 bidirectional, 4519 batch_first, 4520 ) 4521 4522 4523@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") 4524def _lstm_packed( 4525 g: jit_utils.GraphContext, 4526 input, 4527 batch_sizes, 4528 hidden_v, 4529 weight_v, 4530 has_biases, 4531 num_layers, 4532 dropout, 4533 train, 4534 bidirectional, 4535): 4536 hidden, weight = ( 4537 symbolic_helper._unpack_list(hidden_v), 4538 symbolic_helper._unpack_list(weight_v), 4539 ) 4540 return _generic_rnn( 4541 g, 4542 "LSTM", 4543 input, 4544 hidden, 4545 weight, 4546 has_biases, 4547 num_layers, 4548 dropout, 4549 train, 4550 bidirectional, 4551 batch_sizes=batch_sizes, 4552 ) 4553 4554 4555@_onnx_symbolic("aten::lstm") 4556def lstm(g: jit_utils.GraphContext, *args): 4557 if symbolic_helper._is_tensor_list(args[3]): 4558 return _lstm_packed(g, *args) 4559 else: 4560 return _lstm_full(g, *args) 4561 4562 4563@_onnx_symbolic("aten::lstm_cell") 4564def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh): 4565 input = symbolic_helper._unsqueeze_helper(g, self, [0]) 4566 hidden = symbolic_helper._unpack_list(hidden) 4567 hidden = [symbolic_helper._unsqueeze_helper(g, x, [0]) for x in hidden] 4568 weight = ( 4569 (w_ih, w_hh, b_ih, b_hh) if symbolic_helper._is_tensor(b_ih) else (w_ih, w_hh) 4570 ) 4571 has_biases = True if symbolic_helper._is_tensor(b_ih) else False 4572 _, h_outs, c_outs = _generic_rnn( 4573 g, 4574 "LSTM", 4575 input, 4576 hidden, 4577 weight, 4578 has_biases, 4579 num_layers=1, 4580 dropout=0, 4581 train=0, 4582 bidirectional=False, 4583 batch_first=False, 4584 ) 4585 
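    # lstm_cell is exported as a one-step, single-layer LSTM: the input was
    # unsqueezed to a length-1 sequence above, so the hidden/cell outputs are
    # squeezed back to (batch, hidden_size) below.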
return symbolic_helper._squeeze_helper( 4586 g, h_outs, [0] 4587 ), symbolic_helper._squeeze_helper(g, c_outs, [0]) 4588 4589 4590@_onnx_symbolic( 4591 "aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")] 4592) 4593@_onnx_symbolic( 4594 "aten::rnn_tanh", 4595 decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")], 4596) 4597@_onnx_symbolic( 4598 "aten::rnn_relu", 4599 decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")], 4600) 4601def _one_hidden_rnn(kind: str): 4602 @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") 4603 def _rnn_full( 4604 g, 4605 input, 4606 hidden, 4607 weight_v, 4608 has_biases, 4609 num_layers, 4610 dropout, 4611 train, 4612 bidirectional, 4613 batch_first, 4614 ): 4615 weight = symbolic_helper._unpack_list(weight_v) 4616 return _generic_rnn( 4617 g, 4618 kind, 4619 input, 4620 hidden, 4621 weight, 4622 has_biases, 4623 num_layers, 4624 dropout, 4625 train, 4626 bidirectional, 4627 batch_first, 4628 ) 4629 4630 @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") 4631 def _rnn_packed( 4632 g, 4633 input, 4634 batch_sizes, 4635 hidden, 4636 weight_v, 4637 has_biases, 4638 num_layers, 4639 dropout, 4640 train, 4641 bidirectional, 4642 ): 4643 weight = symbolic_helper._unpack_list(weight_v) 4644 return _generic_rnn( 4645 g, 4646 kind, 4647 input, 4648 hidden, 4649 weight, 4650 has_biases, 4651 num_layers, 4652 dropout, 4653 train, 4654 bidirectional, 4655 batch_sizes=batch_sizes, 4656 ) 4657 4658 def symbolic(g, *args): 4659 if symbolic_helper._is_tensor_list(args[3]): 4660 return _rnn_packed(g, *args) 4661 else: 4662 return _rnn_full(g, *args) 4663 4664 return symbolic 4665 4666 4667@_onnx_symbolic("aten::_dim_arange") 4668@symbolic_helper.parse_args("v", "i") 4669def _dim_arange(g: jit_utils.GraphContext, like, dim): 4670 like_shape = g.op("Shape", like) 4671 stop = g.op( 4672 "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 4673 ) 4674 # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) 4675 return arange(g, stop, 4, None, None, None) 4676 4677 4678@_onnx_symbolic("aten::detach") 4679def detach(g: jit_utils.GraphContext, input): 4680 # Erase aten::detach nodes because ONNX is inference only 4681 return input 4682 4683 4684@_onnx_symbolic("aten::contiguous") 4685@symbolic_helper.parse_args("v", "i") 4686def contiguous(g: jit_utils.GraphContext, input, memory_format): 4687 if memory_format > 2: # allower values are any, preserve and contiguous_format 4688 raise errors.SymbolicValueError( 4689 "onnx memory_format support is not implemented", input 4690 ) 4691 return input 4692 4693 4694@_onnx_symbolic("aten::_pack_padded_sequence") 4695@symbolic_helper.parse_args("v", "v", "i") 4696def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first): 4697 # Currently there is no PackPadded operator in ONNX. We rely on an 4698 # optimization pass to remove this later. It is an error if all 4699 # PackPadded operators cannot be optimized out. 4700 if batch_first: 4701 input = g.op("Transpose", input, perm_i=[1, 0, 2]) 4702 if not lengths.type().isSubtypeOf(torch._C.TensorType.get()): 4703 raise errors.SymbolicValueError( 4704 "'lengths' must be a Tensor for ONNX export", input 4705 ) 4706 # We know it's a TensorType so this check is now safe. 4707 # It's really only necessary because those operators expand to something that 4708 # only works with int32 types in Caffe2... 
4709 if ( 4710 _type_utils.JitScalarType.from_value( 4711 lengths, _type_utils.JitScalarType.UNDEFINED 4712 ) 4713 != _type_utils.JitScalarType.INT 4714 ): 4715 lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32) 4716 return g.op("prim::PackPadded", input, lengths, outputs=2) 4717 4718 4719@_onnx_symbolic("aten::_pad_packed_sequence") 4720@symbolic_helper.parse_args("v", "v", "i", "t", "v") 4721def _pad_packed_sequence( 4722 g: jit_utils.GraphContext, 4723 data, 4724 batch_sizes, 4725 batch_first, 4726 padding_value, 4727 total_length, 4728): 4729 # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence 4730 # It is only useful/used when training using data_parallel model, so 4731 # It shouldn't be relevant for ONNX anyway 4732 data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2) 4733 if batch_first: 4734 data = g.op("Transpose", data, perm_i=[1, 0, 2]) 4735 return data, lengths 4736 4737 4738@_onnx_symbolic("aten::randint") 4739def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options): 4740 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 4741 low_i = symbolic_helper._get_const(low, "i", "low") 4742 high_i = symbolic_helper._get_const(high, "i", "high") 4743 if dtype is None: 4744 scalar_type = _type_utils.JitScalarType.INT64 4745 else: 4746 scalar_type = _type_utils.JitScalarType(dtype) 4747 if low_i is None: 4748 raise symbolic_helper._onnx_unsupported("randint", low) 4749 if high_i is None: 4750 raise symbolic_helper._onnx_unsupported("randint", high) 4751 4752 shape = symbolic_helper._maybe_get_const(shapes, "is") 4753 if symbolic_helper._is_value(shape): 4754 shape_const = g.op( 4755 "ConstantOfShape", 4756 shapes, 4757 value_t=torch.tensor([0], dtype=torch.float), 4758 ) 4759 randn = g.op( 4760 "RandomUniformLike", 4761 shape_const, 4762 low_f=low_i, 4763 high_f=high_i, 4764 ) 4765 else: 4766 randn = g.op( 4767 "RandomUniform", 4768 shape_i=shape, 4769 low_f=low_i, 4770 high_f=high_i, 4771 ) 4772 4773 # cast to integer type 4774 int_dtype = _type_utils.JitScalarType.INT64 4775 randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) 4776 if int_dtype != scalar_type: 4777 randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) 4778 return randint 4779 4780 4781@_onnx_symbolic("aten::randint_like") 4782def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options): 4783 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 4784 low_i = symbolic_helper._get_const(low, "i", "low") 4785 high_i = symbolic_helper._get_const(high, "i", "high") 4786 if dtype is None: 4787 scalar_type = _type_utils.JitScalarType.INT64 4788 else: 4789 scalar_type = _type_utils.JitScalarType(dtype) 4790 if low_i is None: 4791 raise symbolic_helper._onnx_unsupported("randint", low) 4792 if high_i is None: 4793 raise symbolic_helper._onnx_unsupported("randint", high) 4794 4795 randn = g.op( 4796 "RandomUniformLike", 4797 self, 4798 low_f=low_i, 4799 high_f=high_i, 4800 ) 4801 4802 # cast to integer type 4803 int_dtype = _type_utils.JitScalarType.INT64 4804 randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) 4805 if int_dtype != scalar_type: 4806 randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) 4807 return randint 4808 4809 4810@_onnx_symbolic("aten::randn") 4811def randn(g: jit_utils.GraphContext, shapes, dtype, *options): 4812 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 4813 if dtype is None: 4814 scalar_type = _type_utils.JitScalarType.FLOAT 4815 else: 4816 scalar_type = 
_type_utils.JitScalarType(dtype) 4817 shape = symbolic_helper._maybe_get_const(shapes, "is") 4818 if symbolic_helper._is_value(shape): 4819 shape_const = g.op( 4820 "ConstantOfShape", 4821 shapes, 4822 value_t=torch.tensor([0], dtype=torch.float), 4823 ) 4824 return g.op( 4825 "RandomNormalLike", 4826 shape_const, 4827 dtype_i=scalar_type.onnx_type(), 4828 ) 4829 return g.op( 4830 "RandomNormal", 4831 shape_i=shape, 4832 dtype_i=scalar_type.onnx_type(), 4833 ) 4834 4835 4836@_onnx_symbolic("aten::rand") 4837def rand(g: jit_utils.GraphContext, shapes, dtype, *options): 4838 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 4839 if dtype is None: 4840 scalar_type = _type_utils.JitScalarType.FLOAT 4841 else: 4842 scalar_type = _type_utils.JitScalarType(dtype) 4843 shape = symbolic_helper._maybe_get_const(shapes, "is") 4844 if symbolic_helper._is_value(shape): 4845 shape_const = g.op( 4846 "ConstantOfShape", 4847 shapes, 4848 value_t=torch.tensor([0], dtype=torch.float), 4849 ) 4850 return g.op( 4851 "RandomUniformLike", 4852 shape_const, 4853 dtype_i=scalar_type.onnx_type(), 4854 ) 4855 return g.op( 4856 "RandomUniform", 4857 shape_i=shape, 4858 dtype_i=scalar_type.onnx_type(), 4859 ) 4860 4861 4862@_onnx_symbolic("aten::randn_like") 4863def randn_like( 4864 g: jit_utils.GraphContext, 4865 self, 4866 dtype, 4867 layout=None, 4868 device=None, 4869 pin_memory=False, 4870 memory_format=None, 4871): 4872 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 4873 if dtype is None: 4874 scalar_type = _type_utils.JitScalarType.from_value( 4875 self, _type_utils.JitScalarType.FLOAT 4876 ) 4877 else: 4878 scalar_type = _type_utils.JitScalarType(dtype) 4879 return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type()) 4880 4881 4882@_onnx_symbolic("aten::rand_like") 4883def rand_like( 4884 g: jit_utils.GraphContext, 4885 self, 4886 dtype, 4887 layout=None, 4888 device=None, 4889 pin_memory=False, 4890 memory_format=None, 4891): 4892 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 4893 if dtype is None: 4894 dtype = _type_utils.JitScalarType.from_value( 4895 self, _type_utils.JitScalarType.FLOAT 4896 ) 4897 return g.op( 4898 "RandomUniformLike", self, dtype_i=_type_utils.JitScalarType(dtype).onnx_type() 4899 ) 4900 4901 4902@_onnx_symbolic("aten::rrelu") 4903@symbolic_helper.parse_args("v", "f", "f", "i", "none") 4904def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator): 4905 if not training: 4906 slope = (upper + lower) / 2.0 4907 return g.op("LeakyRelu", input, alpha_f=slope) 4908 p = g.op("RandomUniformLike", input, high_f=upper, low_f=lower) 4909 return g.op("PRelu", input, p) 4910 4911 4912@_onnx_symbolic("aten::bernoulli") 4913def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): 4914 if out is not None and not symbolic_helper._is_none(out): 4915 symbolic_helper._unimplemented( 4916 "Bernoulli", "out parameter is not supported for bernoulli", input 4917 ) 4918 if generator is not None and not symbolic_helper._is_none(generator): 4919 symbolic_helper._unimplemented( 4920 "Bernoulli", "generator is not supported for bernoulli", input 4921 ) 4922 4923 dtype = _type_utils.JitScalarType.from_value( 4924 input, _type_utils.JitScalarType.UNDEFINED 4925 ) 4926 if dtype == _type_utils.JitScalarType.UNDEFINED: 4927 return symbolic_helper._unimplemented( 4928 "Bernoulli", "input dtype not accessible", input 4929 ) 4930 4931 rands = g.op( 4932 "RandomUniformLike", 4933 input, 4934 high_f=1.0, 4935 low_f=0.0, 4936 
dtype_i=dtype.onnx_type(), 4937 ) 4938 prob = p if p is not None and not symbolic_helper._is_none(p) else input 4939 output = g.op("Less", rands, prob) 4940 return g.op("Cast", output, to_i=dtype.onnx_type()) 4941 4942 4943@_onnx_symbolic("aten::log_sigmoid") 4944@symbolic_helper.parse_args("v") 4945def log_sigmoid(g: jit_utils.GraphContext, input): 4946 p = g.op("Sigmoid", input) 4947 return g.op("Log", p) 4948 4949 4950@_onnx_symbolic("aten::erf") 4951@symbolic_helper.parse_args("v") 4952def erf(g: jit_utils.GraphContext, input): 4953 return g.op("Erf", input) 4954 4955 4956@_onnx_symbolic("aten::flatten") 4957@symbolic_helper.quantized_args(True, False, False) 4958@symbolic_helper.parse_args("v", "i", "i") 4959def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): 4960 dim = symbolic_helper._get_tensor_rank(input) 4961 if dim is None: 4962 return symbolic_helper._unimplemented( 4963 "dim", 4964 "ONNX and PyTorch use different strategies to split the input. " 4965 "Input rank must be known at export time.", 4966 input, 4967 ) 4968 4969 if dim == 0: 4970 return symbolic_helper._reshape_helper(g, input, [1]) 4971 if dim == 1: 4972 return g.op("Identity", input) 4973 # TODO: remove this as onnx opset 11 spec allows negative axes 4974 if end_dim < 0: 4975 end_dim = dim + end_dim 4976 # use ONNX's Flatten operator for cases where the output shape is 2D 4977 if start_dim == 1 and end_dim == dim - 1: 4978 return g.op("Flatten", input, axis_i=start_dim) 4979 if start_dim == 0 and end_dim == dim - 2: 4980 return g.op("Flatten", input, axis_i=end_dim + 1) 4981 4982 return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) 4983 4984 4985@_onnx_symbolic("aten::nonzero") 4986@symbolic_helper.parse_args("v") 4987def nonzero(g: jit_utils.GraphContext, input): 4988 """Emitted from `torch.nonzero(x, as_tuple=False)`""" 4989 return t(g, g.op("NonZero", input)) 4990 4991 4992@_onnx_symbolic("aten::nonzero_numpy") 4993# Emitted from `torch.nonzero(x, as_tuple=True)` 4994def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): 4995 return unbind(g, nonzero(g, input), 1, _outputs=_outputs) 4996 4997 4998@_onnx_symbolic("aten::isnan") 4999@symbolic_helper.parse_args("v") 5000def isnan(g: jit_utils.GraphContext, input): 5001 output = g.op("IsNaN", input) 5002 return output 5003 5004 5005@_onnx_symbolic("aten::any") 5006def _any(g: jit_utils.GraphContext, *args): 5007 # aten::any(Tensor self) 5008 if len(args) == 1: 5009 input = args[0] 5010 dim, keepdim = None, 0 5011 # aten::any(Tensor self, int[]? dim, bool keepdim) 5012 else: 5013 input, dim, keepdim = args 5014 # Can be int list or single int 5015 dim = symbolic_helper._parse_arg(dim, "t") 5016 dim = [int(d) for d in dim.view(-1)] 5017 keepdim = symbolic_helper._parse_arg(keepdim, "i") 5018 input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) 5019 input_sum = symbolic_helper._reducesum_helper( 5020 g, input, axes_i=dim, keepdims_i=keepdim 5021 ) 5022 return gt(g, input_sum, g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))) 5023 5024 5025@_onnx_symbolic("aten::all") 5026def _all(g: jit_utils.GraphContext, *args): 5027 input = g.op("Not", args[0]) 5028 # aten::all(Tensor self) 5029 if len(args) == 1: 5030 return g.op("Not", _any(g, input)) 5031 # aten::all(Tensor self, int[]? 
dim, bool keepdim) 5032 else: 5033 return g.op("Not", _any(g, input, args[1], args[2])) 5034 5035 5036@_onnx_symbolic("aten::narrow") 5037@symbolic_helper.parse_args("v", "i", "i", "i") 5038def narrow(g: jit_utils.GraphContext, input, dim, start, length): 5039 return symbolic_helper._slice_helper( 5040 g, input, axes=[dim], starts=[start], ends=[start + length] 5041 ) 5042 5043 5044@_onnx_symbolic("aten::argmax") 5045@symbolic_helper.parse_args("v", "v", "b") 5046def argmax( 5047 g: jit_utils.GraphContext, 5048 input: torch._C.Value, 5049 dim: torch._C.Value, 5050 keepdim: bool, 5051): 5052 return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") 5053 5054 5055@_onnx_symbolic("aten::argmin") 5056@symbolic_helper.parse_args("v", "v", "b") 5057def argmin( 5058 g: jit_utils.GraphContext, 5059 input: torch._C.Value, 5060 dim: torch._C.Value, 5061 keepdim: bool, 5062): 5063 return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") 5064 5065 5066@_onnx_symbolic("aten::scatter") 5067@symbolic_helper.parse_args("v", "i", "v", "v") 5068def scatter(g: jit_utils.GraphContext, self, dim, index, src): 5069 src_type = _type_utils.JitScalarType.from_value( 5070 src, _type_utils.JitScalarType.UNDEFINED 5071 ) 5072 src = symbolic_helper._maybe_get_scalar(src) 5073 if symbolic_helper._is_value(src): 5074 return g.op("Scatter", self, index, src, axis_i=dim) 5075 else: 5076 # Check if scalar "src" has same type as self (PyTorch allows different 5077 # type for scalar src (but not when src is tensor)). If not, insert Cast node. 5078 self_scalar_type = _type_utils.JitScalarType.from_value(self) 5079 if self_scalar_type != src_type: 5080 src = g.op("Cast", src, to_i=self_scalar_type.onnx_type()) 5081 return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim) 5082 5083 5084@_onnx_symbolic("aten::scatter_add") 5085@symbolic_helper.parse_args("v", "i", "v", "v") 5086def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): 5087 scalar_type = symbolic_helper._try_get_scalar_type(self) 5088 if scalar_type is None: 5089 return symbolic_helper._unimplemented( 5090 "scatter_add", "input dtype not accessible", self 5091 ) 5092 sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False) 5093 if sizes: 5094 to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=scalar_type.dtype())) 5095 else: 5096 to_add = zeros_like(g, self, scalar_type) 5097 to_add = symbolic_helper._scatter_helper(g, to_add, dim, index, src) 5098 return add(g, self, to_add) 5099 5100 5101@_onnx_symbolic("aten::log2") 5102def log2(g: jit_utils.GraphContext, self): 5103 _ln2 = 0.693147180559945309 5104 return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor(_ln2))) 5105 5106 5107@_onnx_symbolic("aten::is_floating_point") 5108def is_floating_point(g: jit_utils.GraphContext, self): 5109 if symbolic_helper._is_fp(self): 5110 return g.op("Constant", value_t=torch.BoolTensor([1])) 5111 return g.op("Constant", value_t=torch.BoolTensor([0])) 5112 5113 5114@_onnx_symbolic("aten::__is_") 5115def __is_(g: jit_utils.GraphContext, self, other): 5116 if symbolic_helper._is_none(other): 5117 if symbolic_helper._is_none(self): 5118 return g.op("Constant", value_t=torch.BoolTensor([1])) 5119 return g.op("Constant", value_t=torch.BoolTensor([0])) 5120 return eq(g, self, other) 5121 5122 5123@_onnx_symbolic("aten::__isnot_") 5124@wrap_logical_op_with_negation 5125def __isnot_(g: jit_utils.GraphContext, self, other): 5126 return __is_(g, self, other) 5127 5128 
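# Illustrative eager-mode sketch (not used by the exporter; the helper name is
# hypothetical) of how `scatter_add` above is lowered for opset 9: scatter `src`
# into a zero tensor with plain overwrite semantics, then add the result to
# `self`. It also mirrors the export's caveat that duplicate entries in `index`
# along `dim` keep only one contribution instead of being summed.
def _scatter_add_reference_sketch(
    self: torch.Tensor, dim: int, index: torch.Tensor, src: torch.Tensor
) -> torch.Tensor:
    # Eager equivalent of Scatter(zeros, index, src) followed by Add(self, ...).
    to_add = torch.zeros_like(self).scatter_(dim, index, src)
    return self + to_add


# e.g. _scatter_add_reference_sketch(
#     torch.zeros(3), 0, torch.tensor([0, 2]), torch.tensor([1.0, 2.0])
# ) -> tensor([1., 0., 2.])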
5129@_onnx_symbolic("aten::one_hot") 5130def one_hot(g: jit_utils.GraphContext, self, num_classes): 5131 values = g.op("Constant", value_t=torch.LongTensor([0, 1])) 5132 # onnxruntime supports limited type combinations for OneHot. 5133 if _type_utils.JitScalarType.from_value( 5134 num_classes, _type_utils.JitScalarType.UNDEFINED 5135 ) in { 5136 _type_utils.JitScalarType.UINT8, 5137 _type_utils.JitScalarType.INT8, 5138 _type_utils.JitScalarType.INT, 5139 _type_utils.JitScalarType.INT16, 5140 }: 5141 num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64) 5142 return g.op("OneHot", self, num_classes, values, axis_i=-1) 5143 5144 5145@_onnx_symbolic("aten::gather") 5146@symbolic_helper.parse_args("v", "i", "v", "v") 5147def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): 5148 if symbolic_helper._maybe_get_const(sparse_grad, "i"): 5149 return symbolic_helper._unimplemented("gather", "sparse_grad == True", self) 5150 # NOTE: This workaround is needed since GatherElement is only supported 5151 # since opset 11, and Gather in ONNX is not the same as torch.gather. 5152 scalar_type = _type_utils.JitScalarType.from_value(self) 5153 values = g.op("Constant", value_t=torch.LongTensor([0, 1])) 5154 depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim]))) 5155 index = g.op( 5156 "Cast", 5157 g.op("OneHot", index, depth, values, axis_i=dim), 5158 to_i=scalar_type.onnx_type(), 5159 ) 5160 mul = g.op("Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), index) 5161 return symbolic_helper._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0) 5162 5163 5164@symbolic_helper.parse_args("v", "is", "i", "i") 5165def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim): 5166 return symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim) 5167 5168 5169@_onnx_symbolic("aten::std") 5170def std(g: jit_utils.GraphContext, input, *args): 5171 var, _ = var_mean(g, input, *args) 5172 return g.op("Sqrt", var) 5173 5174 5175@_onnx_symbolic("aten::var") 5176def var(g: jit_utils.GraphContext, input, *args): 5177 var, _ = var_mean(g, input, *args) 5178 return var 5179 5180 5181@_onnx_symbolic("aten::var_mean") 5182def var_mean(g: jit_utils.GraphContext, input, *args): 5183 if len(args) == 1: 5184 return _var_mean(g, input, None, args[0], None) 5185 else: 5186 return _var_mean(g, input, *args) 5187 5188 5189@_onnx_symbolic("aten::std_mean") 5190def std_mean(g: jit_utils.GraphContext, input, *args): 5191 var, mean = var_mean(g, input, *args) 5192 return g.op("Sqrt", var), mean 5193 5194 5195@_onnx_symbolic("aten::logsumexp") 5196@symbolic_helper.parse_args("v", "is", "i") 5197def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): 5198 return g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim) 5199 5200 5201@_onnx_symbolic("aten::arange") 5202def arange(g: jit_utils.GraphContext, *args): 5203 def _get_arange_dtype(dtype): 5204 dtype = symbolic_helper._maybe_get_const(dtype, "i") 5205 return dtype 5206 5207 def _float_step_convert(range_tensor): 5208 if symbolic_helper._is_fp(range_tensor): 5209 range_tensor = g.op( 5210 "Cast", 5211 g.op("Ceil", range_tensor), 5212 to_i=_type_utils.JitScalarType.INT64.onnx_type(), 5213 ) 5214 return range_tensor 5215 5216 if len(args) == 2 or len(args) == 5: 5217 if len(args) == 2: 5218 # aten::arange(Scalar end, Tensor out) 5219 dtype = None 5220 else: 5221 # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) 5222 dtype = _get_arange_dtype(args[1]) 
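        # Opset 9 has no Range op. The branches below emulate it: build a ones
        # tensor whose length is the (ceiled) number of steps, use NonZero to
        # enumerate indices 0..N-1, squeeze out the extra axis, then scale by
        # `step` and offset by `start` where applicable.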
5223 dtype, end, start, step = symbolic_helper._arange_cast_helper( 5224 g, end=args[0], dtype=dtype 5225 ) 5226 end = symbolic_helper._unsqueeze_helper(g, end, [0]) 5227 range_tensor = _float_step_convert(end) 5228 arange_tensor = symbolic_helper._squeeze_helper( 5229 g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1] 5230 ) 5231 return g.op( 5232 "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() 5233 ) 5234 elif len(args) == 4 or len(args) == 7: 5235 if len(args) == 4: 5236 # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) 5237 dtype = None 5238 else: 5239 # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) 5240 dtype = _get_arange_dtype(args[3]) 5241 dtype, end, start, step = symbolic_helper._arange_cast_helper( 5242 g, start=args[0], end=args[1], step=args[2], dtype=dtype 5243 ) 5244 step = symbolic_helper._unsqueeze_helper(g, step, [0]) 5245 end = symbolic_helper._unsqueeze_helper(g, end, [0]) 5246 start = symbolic_helper._unsqueeze_helper(g, start, [0]) 5247 range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step)) 5248 arange_tensor = symbolic_helper._squeeze_helper( 5249 g, nonzero(g, ones(g, range_tensor, None, None, None)), [1] 5250 ) 5251 arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start) 5252 return g.op( 5253 "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() 5254 ) 5255 elif len(args) == 6: 5256 # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) 5257 dtype = _get_arange_dtype(args[2]) 5258 dtype, end, start, step = symbolic_helper._arange_cast_helper( 5259 g, start=args[0], end=args[1], dtype=dtype 5260 ) 5261 end = symbolic_helper._unsqueeze_helper(g, end, [0]) 5262 start = symbolic_helper._unsqueeze_helper(g, start, [0]) 5263 range_tensor = _float_step_convert(g.op("Sub", end, start)) 5264 arange_tensor = g.op( 5265 "Add", 5266 symbolic_helper._squeeze_helper( 5267 g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1] 5268 ), 5269 start, 5270 ) 5271 return g.op( 5272 "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() 5273 ) 5274 5275 return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments") 5276 5277 5278@_onnx_symbolic("aten::linspace") 5279def linspace( 5280 g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory 5281): 5282 range_tensor = symbolic_helper._arange_helper(g, steps, None) 5283 step = div( 5284 g, 5285 sub(g, end, start), 5286 sub(g, steps, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))), 5287 ) 5288 return add(g, mul(g, range_tensor, step), start) 5289 5290 5291@_onnx_symbolic("aten::lift") 5292def lift(g: jit_utils.GraphContext, self): 5293 # at::lift() is a no-op from the perspective of tracing for onnx 5294 return self 5295 5296 5297@_onnx_symbolic("aten::masked_fill") 5298def masked_fill(g: jit_utils.GraphContext, self, mask, value): 5299 """Implement the masked_fill functionality available for a pytorch tensor in ONNX. 5300 5301 Fills elements of the input tensor with `value` where `mask` is True. 
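    The mask is cast to BOOL and ONNX `Where(mask, value, self)` selects between
    the (broadcast) fill value and the original elements.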
5302 """ 5303 mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) 5304 value = symbolic_helper._maybe_get_scalar(value) 5305 return g.op("Where", mask, symbolic_helper._if_scalar_type_as(value, self), self) 5306 5307 5308@_onnx_symbolic("aten::masked_fill_") 5309def masked_fill_(g: jit_utils.GraphContext, self, mask, value): 5310 return masked_fill(g, self, mask, value) 5311 5312 5313@_onnx_symbolic("aten::index") 5314def index(g: jit_utils.GraphContext, self, index): 5315 if symbolic_helper._is_packed_list(index): 5316 indices = symbolic_helper._unpack_list(index) 5317 else: 5318 indices = [index] 5319 5320 def try_mask_to_index(index): 5321 if not symbolic_helper._is_none(index) and ( 5322 _type_utils.JitScalarType.from_value( 5323 index, _type_utils.JitScalarType.UNDEFINED 5324 ) 5325 == _type_utils.JitScalarType.UINT8 5326 or symbolic_helper._is_bool(index) 5327 ): 5328 if g.opset < 9: 5329 raise errors.SymbolicValueError( 5330 "Exporting masked indices are only supported after ONNX opset 9.", 5331 self, 5332 ) 5333 warnings.warn( 5334 "Exporting aten::index operator with indices of type Byte. " 5335 "Only 1-D indices are supported. In any other case, " 5336 "this will produce an incorrect ONNX graph." 5337 ) 5338 index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1]) 5339 return index 5340 5341 indices = [try_mask_to_index(idx) for idx in indices] 5342 if len(indices) == 1: 5343 return symbolic_helper._select_helper( 5344 g, self, 0, indices[0], apply_reshape=False 5345 ) 5346 else: 5347 # Multiple tensors as indices. Each tensor could either be 5348 # 1. prim::Constant() 5349 # representing ":" in python indexing. E.g. tensor[:, :] 5350 # 2. prim::Constant[value=...] or tensor output 5351 # representing advanced indexing. E.g. tensor[[0, 1], [2, 0]]. 5352 # For more info on advanced indexing, 5353 # check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing 5354 5355 # Consider a general case of 5356 # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] 5357 # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":". 5358 # Same results can be achieved through transposing t into 5359 # t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] 5360 # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t 5361 # and process the tensor indices. 5362 # t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n] 5363 # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)) 5364 # After gather, reshape and transpose back. 5365 adv_idx_indices = [ 5366 i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx) 5367 ] 5368 5369 if len(adv_idx_indices) == 0: 5370 return self 5371 elif len(adv_idx_indices) == 1: 5372 return index_select( 5373 g, self, adv_idx_indices[0], indices[adv_idx_indices[0]] 5374 ) 5375 else: 5376 rank = symbolic_helper._get_tensor_rank(self) 5377 if rank is None: 5378 return symbolic_helper._unimplemented( 5379 "aten::index", 5380 "operator of advanced indexing on tensor of unknown rank. ", 5381 self, 5382 ) 5383 # TODO: If indexing is supported natively in ONNX in future opsets, 5384 # update the warning to recommend exporting with higher opset version. 5385 warnings.warn( 5386 "Exporting aten::index operator of advanced indexing in opset " 5387 f"{GLOBALS.export_onnx_opset_version}" 5388 " is achieved by combination of multiple ONNX operators, " 5389 "including Reshape, Transpose, Concat, and Gather. 
" 5390 "If indices include negative values, the exported graph will produce incorrect results." 5391 ) 5392 adv_idx_count = len(adv_idx_indices) 5393 shape_tensor = _shape_as_tensor(g, self) 5394 dim_tensor_list = [ 5395 g.op( 5396 "Gather", 5397 shape_tensor, 5398 g.op("Constant", value_t=torch.LongTensor([dim])), 5399 axis_i=0, 5400 ) 5401 for dim in range(rank) 5402 ] 5403 5404 self = g.op( 5405 "Transpose", 5406 self, 5407 perm_i=adv_idx_indices 5408 + [i for i in range(rank) if i not in adv_idx_indices], 5409 ) 5410 self = g.op("Flatten", self, axis_i=adv_idx_count) 5411 5412 # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well. 5413 cum_adv_index = indices[adv_idx_indices[-1]] 5414 multiplier = dim_tensor_list[adv_idx_indices[-1]] 5415 for i in range(adv_idx_count - 2, -1, -1): 5416 adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier) 5417 cum_adv_index = g.op("Add", cum_adv_index, adv_index) 5418 multiplier = g.op( 5419 "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]] 5420 ) 5421 5422 # perform gather 5423 self = index_select(g, self, 0, cum_adv_index) 5424 5425 cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index) 5426 # check if all advanced indices are consecutive. 5427 # Refer to https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing 5428 # to understand how the subarray position is decided. 5429 if adv_idx_indices == list( 5430 range(adv_idx_indices[0], adv_idx_indices[-1] + 1) 5431 ): 5432 # unfold regular index axes 5433 folded_adv_idx_shape_list = [ 5434 g.op("Constant", value_t=torch.LongTensor([-1])) 5435 ] + [ 5436 dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices 5437 ] 5438 folded_adv_idx_shape = g.op( 5439 "Concat", *folded_adv_idx_shape_list, axis_i=0 5440 ) 5441 self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape) 5442 5443 # Transpose folded advanced indexed axis to its original location. 
                adv_idx_permute = (
                    list(range(1, adv_idx_indices[0] + 1))
                    + [0]
                    + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1))
                )
                self = g.op("Transpose", self, perm_i=adv_idx_permute)

                # unfold advanced index axes
                final_shape_list = (
                    [dim_tensor_list[i] for i in range(adv_idx_indices[0])]
                    + [cum_adv_index_shape_tensor]
                    + [
                        dim_tensor_list[i]
                        for i in range(adv_idx_indices[0], rank)
                        if i not in adv_idx_indices
                    ]
                )
                final_shape = g.op("Concat", *final_shape_list, axis_i=0)
            else:
                final_shape = g.op(
                    "Concat",
                    cum_adv_index_shape_tensor,
                    *[
                        dim_tensor_list[i]
                        for i in range(rank)
                        if i not in adv_idx_indices
                    ],
                    axis_i=0,
                )

            return symbolic_helper._reshape_helper(g, self, final_shape)
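

# The comments in `index` above describe how multi-tensor (advanced) indexing is
# lowered to a single 1-D Gather over a flattened tensor. A minimal, illustrative
# sketch of that index arithmetic in plain PyTorch (not part of the original
# module; the helper name below is made up):
def _example_advanced_index_flattening():
    """Illustrative only: t[i, j] == t.reshape(x_1 * x_2, -1)[i * x_2 + j]."""
    t = torch.arange(24).reshape(2, 3, 4)  # x_1 = 2, x_2 = 3, one ":" axis of size 4
    i = torch.tensor([0, 1])
    j = torch.tensor([2, 0])
    flat = t.reshape(2 * 3, 4)  # flatten the advanced-index axes, keep the rest
    # flat tensor index = i * x_2 + j, matching the \sum ind_i * \prod x_j formula above
    gathered = flat[i * 3 + j]
    assert torch.equal(gathered, t[i, j])
    return gathered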


@_onnx_symbolic("aten::linalg_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
def linalg_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: Sequence[int] | None,
    keepdim: bool,
    dtype: torch._C.Value,
):
    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html
    ord_value = None
    if dim is None:
        if symbolic_helper._is_none(ord):
            self = symbolic_helper._reshape_helper(g, self, [-1])
            ord = g.op("Constant", value_t=torch.LongTensor([2]))
        self_dim = symbolic_helper._get_tensor_rank(self)
        if self_dim is None:
            return symbolic_helper._unimplemented(
                "dim", "Input rank must be known at export time.", self
            )
        if self_dim == 1:
            ord_value = symbolic_helper._parse_arg(ord, "f")
        else:
            dim = [0, 1]
    else:
        if len(dim) == 1:
            if symbolic_helper._is_none(ord):
                ord = g.op("Constant", value_t=torch.LongTensor([2]))
            ord_value = symbolic_helper._parse_arg(ord, "f")
    if ord_value:
        return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype)
    return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)


@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
def linalg_vector_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: float,
    dim: Sequence[int] | None,
    keepdim: bool,
    dtype: torch._C.Value,
):
    return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype)


@_onnx_symbolic("aten::linalg_matrix_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
def linalg_matrix_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: list[int],
    keepdim: bool,
    dtype: torch._C.Value,
):
    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html
    ord_value = symbolic_helper._parse_arg(ord, "s")
    if ord_value == "fro":
        return frobenius_norm(g, self, dim, keepdim)
    elif ord_value == "nuc":
        return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self)
    else:
        ord_value = symbolic_helper._parse_arg(ord, "f")
        if ord_value is None:
            return frobenius_norm(g, self, dim, keepdim)
        if ord_value == 2 or ord_value == -2:
            # ord = 2/-2 unimplemented due to lack of operators
            # used to calculate singular values
            return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self)
        # Wrap the dim vector to handle negative dim values
        self_dim = symbolic_helper._get_tensor_rank(self)
        if self_dim is None:
            return symbolic_helper._unimplemented(
                "linalg.matrix_norm", "Input rank must be known at export time.", self
            )
        # Common implementation for cases with
        # ord = 1/-1 and ord = inf/-inf
        if dim[0] < 0:
            dim[0] += self_dim
        if dim[1] < 0:
            dim[1] += self_dim

        if ord_value == math.inf or ord_value == -math.inf:
            dim[0], dim[1] = dim[1], dim[0]
        if dim[1] > dim[0] and not keepdim:
            dim[1] -= 1
        sum = symbolic_helper._reducesum_helper(
            g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim
        )
        if ord_value > 0:
            result, indices = max(
                g,
                sum,
                dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])),
                keepdim=keepdim,
            )
        else:
            result, indices = min(
                g,
                sum,
                dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])),
                keepdim=keepdim,
            )
        return result
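

# For ord in {1, -1, inf, -inf}, the implementation above reduces to the max/min
# of absolute row or column sums (Abs -> ReduceSum along one dim -> Max/Min along
# the other, with the dims swapped for inf/-inf). A minimal illustrative check in
# plain PyTorch (not part of the original module; the helper name is made up):
def _example_matrix_norm_row_col_sums():
    """Illustrative only: ord=inf is the max absolute row sum, ord=1 the max column sum."""
    a = torch.tensor([[1.0, -2.0], [3.0, 4.0]])
    row_sums = a.abs().sum(dim=1)
    col_sums = a.abs().sum(dim=0)
    assert torch.isclose(torch.linalg.matrix_norm(a, ord=math.inf), row_sums.max())
    assert torch.isclose(torch.linalg.matrix_norm(a, ord=1), col_sums.max())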


@_onnx_symbolic("aten::linalg_cross")
@symbolic_helper.parse_args("v", "v", "i")
def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1):
    return cross(g, input, other, dim)


@_onnx_symbolic("aten::frobenius_norm")
@symbolic_helper.parse_args("v", "is", "b")
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
    sqr = g.op("Mul", self, self)
    sumsqr = symbolic_helper._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim)
    return g.op("Sqrt", sumsqr)


@_onnx_symbolic("aten::multinomial")
@symbolic_helper.parse_args("v", "i", "b", "v")
def multinomial(
    g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None
):
    if generator is not None and not symbolic_helper._is_none(generator):
        symbolic_helper._unimplemented(
            "Multinomial", "generator is not supported for multinomial", input
        )
    if not replacement and num_samples > 1:
        symbolic_helper._unimplemented(
            "Multinomial",
            "replacement=False when num_samples > 1 is not supported for multinomial",
            input,
        )

    log_input = log(g, input)
    return g.op(
        "Multinomial",
        log_input,
        dtype_i=_C_onnx.TensorProtoDataType.INT64,
        sample_size_i=num_samples,
    )


@_onnx_symbolic("aten::baddbmm")
def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha):
    scalar_type = _type_utils.JitScalarType.from_value(self)
    batch_mul = matmul(g, batch1, batch2)
    mul_a = mul(
        g,
        batch_mul,
        g.op("Cast", alpha, to_i=scalar_type.onnx_type()),
    )
    mul_b = mul(
        g,
        self,
        g.op("Cast", beta, to_i=scalar_type.onnx_type()),
    )
    return add(g, mul_a, mul_b)


@_onnx_symbolic("aten::meshgrid")
@symbolic_helper.parse_args("v", "s")
def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: str | None = None):
    if indexing is None:
        indexing = "ij"
    elif indexing not in {"ij", "xy"}:
        raise errors.SymbolicValueError(
            f"Unsupported indexing: {indexing}", tensor_list
        )
    unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list)
    if indexing == "xy":
        unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1]
    tensors = [
        symbolic_helper._reshape_helper(
            g, t, g.op("Constant", value_t=torch.LongTensor([-1]))
        )
        for t in unpacked_tensor_list
    ]
    tensors_shape = [g.op("Shape", t) for t in tensors]
    out_shape = g.op("Concat", *tensors_shape, axis_i=0)
    out = []
    for i, t in enumerate(tensors):
        shape_i = [g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))] * len(
            tensors
        )
        shape_i[i] = tensors_shape[i]
        t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0))
        out.append(g.op("Expand", t_reshaped, out_shape))
    if indexing == "xy":
        out[0], out[1] = out[1], out[0]
    return g.op("prim::ListConstruct", *out)


@_onnx_symbolic("aten::remainder")
def remainder(g: jit_utils.GraphContext, input, other):
    div = _floor_divide(g, input, other)
    quo = g.op("Mul", div, other)
    return g.op("Sub", input, quo)


@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"):
    if approximate == "tanh":
        kBeta = math.sqrt(2 / math.pi)
        kKappa = 0.044715

        beta = torch.tensor(kBeta, dtype=torch.double)
        kappa = torch.tensor(kKappa, dtype=torch.double)
        one = torch.tensor(1.0, dtype=torch.double)
        half = torch.tensor(0.5, dtype=torch.double)

        self_cube = mul(g, self, mul(g, self, self))
        inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube)))
        return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner))))
    else:
        _sqrt2 = 1.4142135623730951
        erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double)))
        erf_plusone = add(
            g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double))
        )
        return mul(
            g,
            mul(g, self, erf_plusone),
            g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)),
        )
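

# The "tanh" branch of `gelu` above encodes the approximation
#   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
# A minimal illustrative check in plain PyTorch (not part of the original module;
# the helper name is made up, and `approximate="tanh"` assumes a torch version
# that supports it):
def _example_gelu_tanh_approximation():
    """Illustrative only: compare the closed-form approximation with torch's gelu."""
    x = torch.linspace(-3.0, 3.0, steps=7, dtype=torch.double)
    k_beta = math.sqrt(2 / math.pi)
    k_kappa = 0.044715
    approx = 0.5 * x * (1 + torch.tanh(k_beta * (x + k_kappa * x**3)))
    assert torch.allclose(approx, torch.nn.functional.gelu(x, approximate="tanh"))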


@_onnx_symbolic("aten::group_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "i", "v", "v", "f", "i")
def group_norm(
    g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled
):
    channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
    if channel_size is not None:
        assert channel_size % num_groups == 0
    input_rank = symbolic_helper._get_tensor_rank(input)
    if input_rank is None:
        return symbolic_helper._unimplemented("group_norm", "unknown input rank", input)
    # 0 in the shape list keeps dimension value unchanged.
    shape = [0, num_groups, -1]
    input_reshaped = symbolic_helper._reshape_helper(
        g, input, g.op("Constant", value_t=torch.LongTensor(shape))
    )

    # C is always divisible by num_groups.
    # Because of the shape difference, we need to apply weight and bias after
    # the instance norm computation and reshape.
    weight_ = g.op(
        "Constant",
        value_t=torch.tensor(
            [1.0] * num_groups,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        ),
    )
    bias_ = g.op(
        "Constant",
        value_t=torch.tensor(
            [0.0] * num_groups,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        ),
    )

    norm_reshaped = g.op(
        "InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps
    )
    norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input))

    if weight is None or weight.node().mustBeNone():
        weight_value = torch.tensor(
            [1.0], dtype=_type_utils.JitScalarType.from_value(input).dtype()
        )
        weight = g.op("Constant", value_t=weight_value)
    if bias is None or bias.node().mustBeNone():
        bias_value = torch.tensor(
            [0.0], dtype=_type_utils.JitScalarType.from_value(input).dtype()
        )
        bias = g.op("Constant", value_t=bias_value)

    # Norm has shape [N, C, *] so we reshape weight and bias to [C, *]
    axes = list(range(1, input_rank - 1))
    return add(
        g,
        mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes)),
        symbolic_helper._unsqueeze_helper(g, bias, axes),
    )
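

# `group_norm` above reuses InstanceNormalization by folding each group into the
# channel axis: (N, C, *) is reshaped to (N, num_groups, -1), normalized per
# group, reshaped back, and only then scaled by weight and bias. A minimal
# illustrative check in plain PyTorch (not part of the original module; the
# helper name is made up):
def _example_group_norm_via_instance_norm():
    """Illustrative only: group norm == instance norm over a (N, num_groups, -1) view."""
    n, c, h, w, num_groups = 2, 6, 4, 4, 3
    x = torch.randn(n, c, h, w)
    reshaped = x.reshape(n, num_groups, -1)
    normed = torch.nn.functional.instance_norm(reshaped, eps=1e-5).reshape_as(x)
    expected = torch.nn.functional.group_norm(x, num_groups, eps=1e-5)
    assert torch.allclose(normed, expected, atol=1e-5)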


@_onnx_symbolic("aten::_weight_norm")
@symbolic_helper.parse_args("v", "v", "i")
def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim):
    rank = symbolic_helper._get_tensor_rank(weight_v)
    if rank is not None:
        # W = g * ((v) / ||v||)
        # Compute norm_except_dim for l2 norm. dim = None means over all dims.
        # torch's weight_norm module sets dim = -1 if it's None.
        # This conflicts with the logic for negative axes to access dims backwards
        # TODO: Might need a fix in torch group_norm module
        axes = list(range(rank))
        if dim is not None:
            if dim < -1:
                dim += rank
            if dim != -1:
                axes.remove(dim)
        norm_v = norm(g, weight_v, 2, axes, 1)
        div = g.op("Div", weight_v, norm_v)
        return g.op("Mul", div, weight_g)
    raise errors.SymbolicValueError(
        "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.",
        weight_v,
    )


@_onnx_symbolic("aten::dim")
def dim(g: jit_utils.GraphContext, self):
    """Implement the dim functionality available for a pytorch tensor in ONNX"""
    # ONNX does not support dim directly in this opset, so we use two ops to get the info
    shape = g.op("Shape", self)
    return g.op("Size", shape)


@_onnx_symbolic("aten::__contains_")
def __contains_(g: jit_utils.GraphContext, self, element):
    unpacked_list = symbolic_helper._unpack_list(self)
    if all(
        symbolic_helper._is_constant(x) for x in unpacked_list
    ) and symbolic_helper._is_constant(element):
        return g.op(
            "Constant",
            value_t=torch.tensor(
                symbolic_helper._node_get(element.node(), "value")
                in (symbolic_helper._node_get(x.node(), "value") for x in unpacked_list)
            ),
        )

    raise errors.SymbolicValueError(
        "Unsupported: ONNX export of __contains__ for non-constant list or element.",
        self,
    )


@_onnx_symbolic("aten::__getitem_")
def __getitem_(g: jit_utils.GraphContext, self, i):
    return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i)


@_onnx_symbolic("aten::item")
def item(g: jit_utils.GraphContext, self):
    return self


@_onnx_symbolic("aten::take")
def take(g: jit_utils.GraphContext, self, index):
    self_flattened = symbolic_helper._reshape_helper(
        g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    )
    out = index_select(g, self_flattened, 0, index)
    out = reshape_as(g, out, index)
    return out


def _kl_div_log_target_impl(g: jit_utils.GraphContext, input, target):
    diff_ = sub(g, target, input)
    exp_ = exp(g, target)
    output = mul(g, exp_, diff_)
    return output


def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target):
    log_ = log(g, target)
    diff_ = sub(g, log_, input)
    output_pos = mul(g, target, diff_)
    zeros_ = zeros_like(g, output_pos)
    mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0)))
    output = where(g, mask_, output_pos, zeros_)
    return output


@_onnx_symbolic("aten::kl_div")
@symbolic_helper.parse_args("v", "v", "i", "b")
def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target):
    if log_target:
        output = _kl_div_log_target_impl(g, input, target)
    else:
        output = _kl_div_non_log_target_impl(g, input, target)

    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return symbolic_helper._reducesum_helper(g, output, keepdims_i=0)
    else:
        return symbolic_helper._onnx_unsupported(
            "kl_div with reduction other than none, mean, or sum.", input
        )


@_onnx_symbolic("aten::mse_loss")
@symbolic_helper.parse_args("v", "v", "i")
def mse_loss(g: jit_utils.GraphContext, input, target, reduction):
    output = mul(g, sub(g, input, target), sub(g, input, target))
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return symbolic_helper._reducesum_helper(g, output, keepdims_i=0)
    else:
        return symbolic_helper._onnx_unsupported(
            "mse_loss with reduction other than none, mean, or sum.", input
        )


@_onnx_symbolic("aten::as_strided")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "is", "i")
def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None):
    sizes = symbolic_helper._maybe_get_const(sizes, "is")
    rank = len(strides)
    self_1d = symbolic_helper._reshape_helper(
        g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    )
    ind: torch.Tensor | None
    if not symbolic_helper._is_value(sizes):
        ind = torch.tensor([0], dtype=torch.long)
        for i, (size, stride) in enumerate(zip(sizes, strides)):
            r_size = [1] * rank
            r_size[i] = -1
            ind = ind + torch.arange(size).view(r_size) * stride
        if offset:
            ind = ind + offset
        return g.op("Gather", self_1d, g.op("Constant", value_t=ind))
    else:
        ind = None
        for i, stride in enumerate(strides):
            r_size = [1] * rank
            r_size[i] = -1
            size = select(
                g,
                sizes,
                g.op("Constant", value_t=torch.tensor([0])),
                g.op("Constant", value_t=torch.tensor(i)),
            )
            tmp_ind = symbolic_helper._reshape_helper(
                g,
                arange(g, size, 4, None, None, None),
                g.op("Constant", value_t=torch.tensor(r_size)),
            )
            tmp_ind = g.op(
                "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride]))
            )
            if ind is None:
                ind = tmp_ind
            else:
                ind = g.op("Add", ind, tmp_ind)
        if offset:
            ind = g.op("Add", ind, g.op("Constant", value_t=torch.tensor([offset])))
        return g.op("Gather", self_1d, ind)


@_onnx_symbolic("aten::__derive_index")
def __derive_index(g: jit_utils.GraphContext, index, start, step):
    return g.op("Add", start, g.op("Mul", index, step))


@_onnx_symbolic("aten::__range_length")
# Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp
# if (step > 0 && lo < hi) {
#   push(stack, 1 + (hi - 1 - lo) / step);
# } else if (step < 0 && lo > hi) {
#   push(stack, 1 + (lo - 1 - hi) / (0 - step));
# } else {
#   push(stack, 0);
# }
def __range_length(g: jit_utils.GraphContext, lo, hi, step):
    sub = g.op("Sub", hi, lo)
    div = g.op("Ceil", true_divide(g, sub, step))
    return g.op("Cast", div, to_i=_C_onnx.TensorProtoDataType.INT64)


@_onnx_symbolic("aten::linear")
def linear(g: jit_utils.GraphContext, input, weight, bias):
    rank = symbolic_helper._get_tensor_rank(input)
    weight = t(g, weight)
    if rank == 2 and not bias.node().mustBeNone():
        alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        output = addmm(g, bias, input, weight, alpha, beta)
    else:
        output = matmul(g, input, weight)
        if not bias.node().mustBeNone():
            output = add(g, bias, output)

    return output


@_onnx_symbolic("aten::hann_window")
@symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v")
def hann_window(
    g: jit_utils.GraphContext,
    window_length,
    periodic=True,
    dtype: int | None = None,
    layout=None,
    device=None,
    pin_memory=None,
    requires_grad=False,
):
    if dtype is None:
        dtype_ = torch.get_default_dtype()
        if not dtype_ or not dtype_.is_floating_point:
            dtype_ = torch.float
        scalar_type = _type_utils.JitScalarType.from_dtype(dtype_)
    else:
        scalar_type = _type_utils.JitScalarType(dtype)

    n_array = arange(g, window_length, 4, None, None, None)
    output = g.op("Cast", n_array, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    output = mul(
        g, g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float)), output
    )

    if periodic is False:
        window_length = sub(
            g, window_length, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int))
        )
    output = div(g, output, window_length)
    output = g.op(
        "Cast",
        square(g, sin(g, output)),
        to_i=scalar_type.onnx_type(),
    )

    return output
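

# `hann_window` above builds the window as sin^2(pi * n / N), reducing N by one
# first when periodic=False. A minimal illustrative check in plain PyTorch (not
# part of the original module; the helper name is made up):
def _example_hann_window_formula():
    """Illustrative only: the periodic Hann window equals sin^2(pi * n / N)."""
    n_points = 8
    n = torch.arange(n_points, dtype=torch.float)
    expected = torch.hann_window(n_points, periodic=True)
    assert torch.allclose(torch.sin(math.pi * n / n_points) ** 2, expected)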


@_onnx_symbolic("aten::mv")
def mv(g: jit_utils.GraphContext, self, vec):
    return matmul(g, self, vec)


@_onnx_symbolic("aten::dot")
def dot(g: jit_utils.GraphContext, self, other):
    return matmul(g, self, other)


@_onnx_symbolic("aten::movedim")
@symbolic_helper.parse_args("v", "t", "t")
def movedim(g: jit_utils.GraphContext, self, source, destination):
    # This is a pythonic implementation mostly taken from aten/src/ATen/native/TensorShape.cpp::movedim
    source = source.view(-1)
    destination = destination.view(-1)

    assert source.size() == destination.size()

    if (source == destination).all():
        return self

    self_rank = symbolic_helper._get_tensor_rank(self)
    assert self_rank is not None

    perm = list(range(self_rank))

    src_dims = perm.copy()
    dst_dims = perm.copy()

    for src, dst in zip(source.tolist(), destination.tolist()):
        perm[dst] = src
        src_dims[src] = -1
        dst_dims[dst] = -1

    src_dims = [dim for dim in src_dims if dim != -1]
    dst_dims = [dim for dim in dst_dims if dim != -1]

    for src, dst in zip(src_dims, dst_dims):
        perm[dst] = src

    return g.op("Transpose", self, perm_i=perm)


@_onnx_symbolic("aten::fill")
@symbolic_helper.parse_args("v", "v")
def fill(g: jit_utils.GraphContext, self, value):
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    )
    return full_like(g, self, value, scalar_type)


@_onnx_symbolic("aten::index_add")
def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
    warnings.warn(
        "ONNX export does not support duplicated values in the 'index' field; "
        "this will cause the ONNX model to be incorrect."
    )

    # ONNX does not support "alpha" argument, unlike aten index_add
    # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context
    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
        return symbolic_helper._unimplemented("index_add", "alpha != 1", self)

    dim = symbolic_helper._maybe_get_const(dim, "i")
    if dim is None:
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting 'index_add_()' function with "
            "unknown 'dim' value.",
            self,
        )

    self_dim_rank = symbolic_helper._get_tensor_rank(self)
    other_dim_rank = symbolic_helper._get_tensor_rank(other)

    if self_dim_rank is None or other_dim_rank is None:
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting 'index_add_()' function while "
            "the rank of self tensor or tensor to be added is unknown.",
            self,
        )

    if other_dim_rank != self_dim_rank:
        delta = self_dim_rank - other_dim_rank
        for i in range(delta):
            other = symbolic_helper._unsqueeze_helper(
                g, other, [symbolic_helper._get_tensor_rank(other)]
            )

    other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim)
    self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim)

    if (other_dim_size is not None) and (self_dim_size is not None):
        if other_dim_size > self_dim_size:
            raise errors.SymbolicValueError(
                "ONNX export does not support exporting 'index_add_()' function with "
                "duplicated values in 'index' parameter yet.",
                self,
            )

    # Construct a new shape. It's almost the same as self, except that the size of
    # the 'dim' dimension is 1, so that we can expand other dimensions as expected.
    new_shape_axes = list(range(self_dim_rank))
    new_shape_starts = [0 for i in range(self_dim_rank)]
    new_shape_ends = [sys.maxsize if (i != dim) else 1 for i in range(self_dim_rank)]

    new_shape = symbolic_helper._slice_helper(
        g, self, axes=new_shape_axes, starts=new_shape_starts, ends=new_shape_ends
    )
    other = expand_as(g, other, new_shape)

    for i in range(dim):
        index = symbolic_helper._unsqueeze_helper(g, index, [0])

    for i in range(self_dim_rank - dim - 1):
        index = symbolic_helper._unsqueeze_helper(
            g, index, [symbolic_helper._get_tensor_rank(index)]
        )

    return scatter_add(g, self, dim, expand_as(g, index, other), other)


@_onnx_symbolic("aten::roll")
@symbolic_helper.parse_args("v", "is", "is")
def roll(g: jit_utils.GraphContext, self, shifts, dims):
    assert len(shifts) == len(dims)

    result = self
    for i in range(len(shifts)):
        shapes = []
        shape = symbolic_helper._slice_helper(
            g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[sys.maxsize]
        )
        shapes.append(shape)
        shape = symbolic_helper._slice_helper(
            g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]]
        )
        shapes.append(shape)
        result = g.op("Concat", *shapes, axis_i=dims[i])

    return result


@_onnx_symbolic("aten::cross")
@symbolic_helper.parse_args("v", "v", "i")
def cross(g: jit_utils.GraphContext, input, other, dim=None):
    dim = symbolic_helper._get_dim_for_cross(input, dim)
    # If we have two tensors such that
    # A = [a, b, c], B = [d, e, f], we permute the tensors such that,
    # after the first roll,
    # A' = [b, c, a], B' = [f, d, e], so that we calculate (b*f, c*d, a*e)
    roll_x_1 = roll(g, input, [2], [dim])
    roll_y_1 = roll(g, other, [1], [dim])
    # After the second roll,
    # A' = [c, a, b], B' = [e, f, d], so that we calculate (c*e, a*f, b*d)
    roll_x_2 = roll(g, input, [1], [dim])
    roll_y_2 = roll(g, other, [2], [dim])
    # The cross product is calculated as
    # result = [(b*f - c*e), (c*d - a*f), (a*e - b*d)]
    return sub(g, mul(g, roll_x_1, roll_y_1), mul(g, roll_x_2, roll_y_2))


@_onnx_symbolic("aten::cdist")
def cdist(
    g: jit_utils.GraphContext,
    x1,
    x2,
    p=2.0,
    compute_mode="use_mm_for_euclid_dist_if_necessary",
):
    # X1.shape = (B * P * D), X2.shape = (B * R * D)
    # In order to respect numpy style broadcasting as demonstrated in
    # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
    # we unsqueeze both input tensors.
    # Currently we ignore the 'compute_mode' variable and default to using
    # matrix multiplication to calculate the euclidean distance.
    rank = symbolic_helper._get_tensor_rank(x1)
    assert rank is not None
    broadcasted_x1 = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1])
    broadcasted_x2 = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2])
    return pairwise_distance(
        g, broadcasted_x1, broadcasted_x2, p, eps=1e-06, keepdim=False
    )
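

# `cdist` above relies on broadcasting: unsqueezing x1 to (B, P, 1, D) and x2 to
# (B, 1, R, D) lets a pairwise distance over the last axis produce a (B, P, R)
# result. A minimal illustrative check in plain PyTorch (not part of the original
# module; the helper name is made up):
def _example_cdist_via_broadcasting():
    """Illustrative only: compare torch.cdist with the unsqueeze-and-norm formulation."""
    x1 = torch.randn(2, 3, 4)  # (B, P, D)
    x2 = torch.randn(2, 5, 4)  # (B, R, D)
    diff = x1.unsqueeze(2) - x2.unsqueeze(1)  # (B, P, R, D)
    assert torch.allclose(diff.norm(p=2, dim=-1), torch.cdist(x1, x2, p=2.0), atol=1e-5)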


@_onnx_symbolic("aten::lerp")
def lerp(g: jit_utils.GraphContext, self, end, weight):
    # Conditional for better numerical stability; this has been discussed in
    # https://github.com/pytorch/pytorch/pull/18871
    diff = g.op("Sub", end, self)
    return where(
        g,
        g.op("Less", weight, g.op("Constant", value_t=torch.tensor(0.5))),
        g.op("Add", self, g.op("Mul", weight, diff)),
        g.op(
            "Sub",
            end,
            g.op(
                "Mul",
                diff,
                g.op("Sub", g.op("Constant", value_t=torch.tensor(1.0)), weight),
            ),
        ),
    )


@_onnx_symbolic("aten::broadcast_tensors")
def broadcast_tensors(g: jit_utils.GraphContext, self):
    all_tensors = symbolic_helper._unpack_list(self)
    t_with_final_shape = zeros_like(g, all_tensors[0])

    # Add operator supports multidirectional broadcasting. So we leverage this function
    # to infer the final shape generated by the broadcast.
    for t in all_tensors:
        t_with_final_shape = add(g, t_with_final_shape, t)

    t_list = [expand_as(g, t, t_with_final_shape) for t in all_tensors]
    return g.op("prim::ListConstruct", *t_list)


@_onnx_symbolic("aten::is_pinned")
def is_pinned(g: jit_utils.GraphContext, self, device=None):
    # Unused by ONNX.
    return None


@_onnx_symbolic("prim::ConstantSplit")
def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim):
    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        return symbolic_helper._unimplemented(
            "prim::ConstantSplit", "unknown dimension size", self
        )
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits))


# TODO: It would be better to export this as a chunk directly, as this is
# less sensitive to changes in input size.
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this
# method, and use the desugared version
@_onnx_symbolic("prim::ConstantChunk")
def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
    dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
    if dim_size is None:
        return symbolic_helper._unimplemented(
            "prim::ConstantChunk", "unknown dimension size", self
        )
    split_size = (dim_size + chunks - 1) // chunks
    return prim_constant_split(g, self, split_size, dim)


@_onnx_symbolic("prim::shape")
def prim_shape(g: jit_utils.GraphContext, self):
    return g.op("Shape", self)


@_onnx_symbolic("prim::max")
def prim_max(g: jit_utils.GraphContext, self, other):
    return symbolic_helper._op_with_optional_float_cast(
        g, "Max", self, other, opset_before=12
    )


@_onnx_symbolic("prim::min")
def prim_min(g: jit_utils.GraphContext, self, other=None):
    if not other:
        if symbolic_helper._is_packed_list(self):
            self = stack(g, self, g.op("Constant", value_t=torch.tensor([0])))
        return min(g, self)
    return min(g, self, other)


@_onnx_symbolic("prim::data")
def prim_data(g: jit_utils.GraphContext, self):
    return self


@_onnx_symbolic("prim::layout")
def prim_layout(g: jit_utils.GraphContext, self):
    # Always return 'torch.strided'. Other layout types are not supported by JIT 'TensorType'.
    # Layout class defined in 'c10/core/Layout.h'.
    return g.op("Constant", value_t=torch.tensor(0))


@_onnx_symbolic("prim::ListConstruct")
def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
    return None


@_onnx_symbolic("prim::ListUnpack")
def prim_list_unpack(
    g: jit_utils.GraphContext, *inputs, **kwargs
) -> list[_C.Value] | None:
    if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct":
        # Cancel the previous node if it is ListConstruct by returning its inputs
        # TODO(justinchuby): Use a public method in the helper module
        return symbolic_helper._unpack_list(inputs[0])

    return None


@_onnx_symbolic("prim::TupleConstruct")
def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
    return None


@_onnx_symbolic("prim::Uninitialized")
def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs):
    return None


# exists to refine the type of the Value
# if x is an optional Tensor, unchecked_cast will cast
# x to Tensor, so the rest of the graph knows that x is a Tensor
# this doesn't do anything at runtime and is a no-op in ONNX
@_onnx_symbolic("prim::unchecked_cast")
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
    return self


@_onnx_symbolic("prim::dtype")
def prim_dtype(g: jit_utils.GraphContext, self):
    scalar_type = symbolic_helper._try_get_scalar_type(self)
    if scalar_type is None:
        scalar_type = _type_utils.JitScalarType.FLOAT
    # This node records a torch dtype as an int
    return g.op("Constant", value_t=torch.tensor(scalar_type))


@_onnx_symbolic("prim::tolist")
def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val):
    """tolist is currently supported only for 1D input tensors.

    dim_val and elem_ty_val represent dimension and type annotations
    that need to match the dimension and type of the input tensor.
    """
    dim = symbolic_helper._maybe_get_const(dim_val, "i")
    if dim > 1:
        return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input)
    return input


# -----------------------------------------------------------------------------
# Symbolic functions that need extra context
# -----------------------------------------------------------------------------
@_onnx_symbolic("prim::device")
def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
    output_type = g.original_node.output().type()
    if isinstance(output_type, _C.DeviceObjType):
        return None

    return symbolic_helper._unimplemented(
        "prim::device",
        f"output type should be 'DeviceObjType', not '{output_type.kind()}'",
        g.original_node.output(),
    )


@_onnx_symbolic("prim::Loop")
def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
    node = g.original_node
    env = g.env
    values_in_env = g.values_in_env
    params_dict = g.params_dict

    operator_export_type = GLOBALS.operator_export_type
    opset_version = GLOBALS.export_onnx_opset_version

    old_blocks = tuple(node.blocks())
    new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
        g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks)
    )

    for old_block, new_block_context in zip(old_blocks, new_block_contexts):
        # Copy input metadata to subblock
        #
        #   prim::Loop(iter, cond, input_1, ..., input_n)
        #     block0(iter, input_1, ..., input_n)
        #
        # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`.
        for i, b_in in enumerate(old_block.inputs()):
            if i == 0 and i < len(inputs):
                b_in.setType(inputs[i].type())
            # Optional block inputs may switch between None and non-None inside
            # the loop body, so even if the loop input is not optional, the block
            # input may still need to be optional.
            if (
                i > 0
                and (i + 1) < len(inputs)
                and not isinstance(b_in.type(), _C.OptionalType)
            ):
                b_in.setType(inputs[i + 1].type())
        torch._C._jit_pass_onnx_block(
            old_block,
            new_block_context.block,
            operator_export_type,
            env,
            values_in_env,
            False,
        )
    fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
        new_node, opset_version
    )
    # Run shape type inference for Loop after subblock is converted.
    if GLOBALS.onnx_shape_inference:
        torch._C._jit_pass_onnx_node_shape_type_inference(
            new_node, params_dict, opset_version
        )
    return fixed_outputs


@_onnx_symbolic("prim::If")
def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]:
    n = g.original_node
    block = g.block
    env = g.env
    values_in_env = g.values_in_env
    params_dict = g.params_dict

    operator_export_type = GLOBALS.operator_export_type
    opset_version = GLOBALS.export_onnx_opset_version

    static_if = inputs[0].node().kind() == "onnx::Constant"
    if static_if:
        # Fold static if
        #
        # The torch IR
        # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu),
        #    %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ...
        #    %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #    %21 : Long(device=cpu) = aten::eq(%20, %64)
        #    %22 : Long(device=cpu) = prim::If(%21)
        #      block0():
        #        %23 : Long(device=cpu) = aten::is_floating_point(%input.1)
        #        -> (%23)
        #      block1():
        #        -> (%65)
        #    %input.53 : Tensor, %weight : Tensor = prim::If(%22)
        #      block0():
        #        -> (%embedding_matrix.1, %input.1)
        #      block1():
        #        -> (%input.1, %embedding_matrix.1)
        #    %26 : int[] = aten::size(%input.53)
        #
        # The converted ONNX graph
        # %10 : Bool(device=cpu) = onnx::Constant[value={0}]()
        # %14 : Bool(device=cpu) = onnx::Equal(%13, %8)
        # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
        # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1)
        input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist()
        const_value = (
            all(input_flag) if isinstance(input_flag, list) else bool(input_flag)
        )
        block_idx = 0 if const_value else 1
        current_b = list(n.blocks())[block_idx]
        env = torch._C._jit_pass_onnx_block(
            current_b,
            block,
            operator_export_type,
            env,
            values_in_env,
            True,
        )
        if_output_list = list(n.outputs())
        current_b_list = list(current_b.outputs())

        final_b_list = []
        for idx in range(len(if_output_list)):
            if current_b_list[idx] not in env:
                raise errors.SymbolicValueError(
                    f"The sub block ATen output {current_b_list[idx]} is not in env.",
                    current_b_list[idx],
                )  # type:ignore[operator]
            onnx_b = env[current_b_list[idx]]
            final_b_list.append(onnx_b)
        return final_b_list
    else:
        old_blocks = tuple(n.blocks())
        new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
            g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks)
        )

        for old_block, new_block_context in zip(old_blocks, new_block_contexts):
            torch._C._jit_pass_onnx_block(
                old_block,
                new_block_context.block,
                operator_export_type,
                env,
                values_in_env,
                False,
            )
        fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
            new_node, opset_version
        )
        # Run shape type inference for If after subblock is converted.
        if GLOBALS.onnx_shape_inference:
            torch._C._jit_pass_onnx_node_shape_type_inference(
                new_node, params_dict, opset_version
            )
        return fixed_outputs


@_onnx_symbolic("prim::Constant")
def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs):
    node = g.original_node

    if node.mustBeNone():
        return None
    # This must go before checking for string values, because some device constants
    # have string values, but we want to keep them as unconverted Device types so
    # that eq() can work on them.
    if isinstance(node.output().type(), _C.DeviceObjType):
        return None
    if node.kindOf("value") == "t":
        return g.op("Constant", value_t=symbolic_helper._node_get(node, "value"))
    if node.kindOf("value") == "s":
        return g.op("Constant", value_s=symbolic_helper._node_get(node, "value"))
    if node.output().type().isSubtypeOf(
        _C.ListType.ofInts()
    ) or node.output().type().isSubtypeOf(_C.ListType.ofFloats()):
        return g.op(
            "Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value"))
        )
    if node.output().type().isSubtypeOf(_C.ListType.ofStrings()):
        str_constants = [
            g.op("Constant", value_s=s)
            for s in symbolic_helper._node_get(node, "value")
        ]
        return g.op("prim::ListConstruct", *str_constants)

    raise errors.SymbolicValueError(
        f"Unsupported prim::Constant kind: '{node.kindOf('value')}'. "
        f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
        node.output(),
    )


@_onnx_symbolic("prim::type")
def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs):
    if device_value.node().kind() == "prim::device":
        device = jit_utils.get_device_from_value(device_value.node().input())
        if device is not None:
            return g.op("Constant", value_s=str(device))

    return symbolic_helper._unimplemented(
        "prim::type",
        "Device type cannot be statically determined.",
        device_value,
    )


@_onnx_symbolic("onnx::Placeholder")
def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs):
    node = g.original_node
    block = g.block
    env = g.env
    values_in_env = g.values_in_env

    return torch._C._jit_onnx_convert_pattern_from_subblock(
        block, node, env, values_in_env
    )


@_onnx_symbolic("aten::resolve_conj")
@_onnx_symbolic("aten::resolve_neg")
def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
    # ONNX does not have operators to *directly* manipulate real/imaginary components
    # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real,
    # which results in failures due to missing operators for complex numbers

    # `aten::resolve_conj` and `aten::resolve_neg` can safely be implemented as no-op
    return input


@_onnx_symbolic("aten::_conj")
@_onnx_symbolic("aten::conj_physical")
def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
    # ONNX does not have operators to *directly* manipulate real/imaginary components
    # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real,
    # which results in failures due to missing operators for complex numbers

    # `aten::_conj` and `aten::conj_physical` raise an exception when the input is complex
    if symbolic_helper.is_complex_value(input):
        # FIXME(justinchuby): report correct name for symbolic being executed
        return symbolic_helper._onnx_unsupported(
            "aten::_conj, aten::conj_physical",
            input,
        )

    # but they can safely be implemented as no-ops for real numbers only
    return noop_complex_operators(g, input)


@_onnx_symbolic("aten::logit")
def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value):
    one = g.op("Constant", value_t=torch.tensor(1.0))

    if not symbolic_helper._is_none(eps):
        eps = g.op(
            "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()
        )
        one_sub_eps = g.op("Sub", one, eps)
        self_less_equal_one_sub_eps = g.op("Greater", one_sub_eps, self)
        temporary_self = g.op("Where", self_less_equal_one_sub_eps, self, one_sub_eps)

        temporary_self_less_eps = g.op("Less", temporary_self, eps)
        z = g.op("Where", temporary_self_less_eps, eps, temporary_self)
    else:
        z = self

    sub = g.op("Sub", one, z)
    div = g.op("Div", z, sub)
    return g.op("Log", div)
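

# `logit` above computes log(z / (1 - z)), with z clamped into [eps, 1 - eps]
# when eps is given. A minimal illustrative check in plain PyTorch (not part of
# the original module; the helper name is made up):
def _example_logit_with_eps():
    """Illustrative only: compare the clamp-then-log formulation with torch.logit."""
    x = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
    eps = 1e-6
    z = x.clamp(min=eps, max=1 - eps)
    assert torch.allclose(torch.log(z / (1 - z)), torch.logit(x, eps=eps))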