1# mypy: allow-untyped-defs 2# mypy: disable-error-code=arg-type 3from __future__ import annotations 4 5import functools 6import sys 7import warnings 8from typing import Sequence 9 10import torch 11import torch._C._onnx as _C_onnx 12import torch.onnx 13from torch import _C 14 15# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics 16from torch.onnx import ( 17 _constants, 18 _type_utils, 19 errors, 20 symbolic_helper, 21 symbolic_opset9 as opset9, 22) 23from torch.onnx._globals import GLOBALS 24from torch.onnx._internal import jit_utils, registration 25 26 27# EDITING THIS FILE? READ THIS FIRST! 28# see Note [Edit Symbolic Files] in README.md 29 30# This file exports ONNX ops for opset 10 31# Opset 10 is supported by ONNX release 1.5.0 32# release on 04/24/19 33 34 35__all__ = [ 36 "dequantize", 37 "div", 38 "embedding_bag", 39 "fake_quantize_per_tensor_affine", 40 "flip", 41 "fmod", 42 "isfinite", 43 "isinf", 44 "nan_to_num", 45 "quantize_per_tensor", 46 "quantized_add_relu", 47 "quantized_add", 48 "quantized_cat", 49 "quantized_conv1d_relu", 50 "quantized_conv2d_relu", 51 "quantized_conv3d_relu", 52 "quantized_conv1d", 53 "quantized_conv2d", 54 "quantized_conv3d", 55 "quantized_conv_transpose1d", 56 "quantized_conv_transpose2d", 57 "quantized_conv_transpose3d", 58 "quantized_group_norm", 59 "quantized_hardswish", 60 "quantized_instance_norm", 61 "quantized_layer_norm", 62 "quantized_leaky_relu", 63 "quantized_linear", 64 "quantized_linear_relu", 65 "quantized_mul", 66 "quantized_sigmoid", 67 "slice", 68 "sort", 69 "topk", 70] 71 72 73_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10) 74 75 76@_onnx_symbolic("aten::div") 77def div(g: jit_utils.GraphContext, self, other, *args): 78 if len(args) == 0: 79 return opset9.true_divide(g, self, other) 80 else: 81 return _div_rounding_mode(g, self, other, *args) 82 83 84@symbolic_helper.parse_args("v", "v", "s") 85def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): 86 if rounding_mode == "floor": 87 return _floor_divide(g, self, other) 88 else: 89 return opset9._div_rounding_mode(g, self, other, rounding_mode) 90 91 92@_onnx_symbolic("aten::_floor_divide") 93def _floor_divide(g: jit_utils.GraphContext, self, other): 94 if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): 95 out = opset9.true_divide(g, self, other) 96 return g.op("Floor", out) 97 else: 98 # Integer division does trunction rounding 99 div = g.op("Div", self, other) 100 # Division is negative if: self < 0 != other < 0 101 zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) 102 negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero)) 103 104 # For negative numbers with self % other != 0, subtract 1 to round down instead of up 105 mod = g.op("Mod", self, other, fmod_i=0) 106 fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) 107 108 one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) 109 fixup = g.op("Sub", div, one) 110 return g.op("Where", fixup_mask, fixup, div) 111 112 113@_onnx_symbolic("aten::sort") 114@symbolic_helper.parse_args("v", "i", "i", "none") 115def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): 116 return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) 117 118 119@_onnx_symbolic("aten::topk") 120@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") 121def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): 122 return symbolic_helper._topk_helper( 123 g, self, k, dim, largest=largest, sorted=sorted, out=out 124 ) 125 126 127def _aten_max_pool_onnx( 128 g: jit_utils.GraphContext, 129 self: _C.Value, 130 kernel_shape: Sequence[int], 131 strides: Sequence[int], 132 pads: Sequence[int], 133 dilations: Sequence[int], 134 ceil_mode: bool, 135 unbatched_rank: int, 136) -> _C.Value: 137 self_rank = g.op("Size", g.op("Shape", self)) 138 if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1 139 self = g.op( 140 "Unsqueeze", 141 self, 142 g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), 143 ) 144 145 pool_result, _ = g.op( 146 "MaxPool", 147 self, 148 outputs=2, 149 ceil_mode_i=ceil_mode, 150 dilations_i=dilations, 151 kernel_shape_i=kernel_shape, 152 pads_i=pads, 153 strides_i=strides, 154 ) 155 156 if self_rank == unbatched_rank: 157 pool_result = g.op( 158 "Squeeze", 159 pool_result, 160 g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), 161 ) 162 163 return pool_result 164 165 166# For MaxPool 167def _adjust_attributes_of_max_pool( 168 expand_size: int, 169 kernel_size: Sequence[int] | int, 170 stride: Sequence[int] | int, 171 padding: Sequence[int] | int, 172 dilation: Sequence[int] | int, 173) -> tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]: 174 """Adjust attributes of avg_pool to match ONNX specification.""" 175 176 if isinstance(dilation, int): 177 dilation = [dilation] * expand_size 178 179 if isinstance(kernel_size, int): 180 kernel_shape = [kernel_size] * expand_size 181 else: 182 kernel_shape = kernel_size # type: ignore[assignment] 183 184 if isinstance(padding, int): 185 pads = [padding] * expand_size * 2 # type: ignore[operator, assignment] 186 elif len(padding) == 1: 187 pads = padding * expand_size * 2 # type: ignore[operator, assignment] 188 elif len(padding) == 2: 189 # 2D padding 190 pads = padding * 2 # type: ignore[operator, assignment] 191 elif len(padding) == 3: 192 # 3D padding 193 pads = padding * 2 # type: ignore[operator, assignment] 194 else: 195 # When padding is already done for all dimensions, 196 # we don't need to double it 197 # eg: (1, 1, 1, 1, 1, 1) 198 pads = padding # type: ignore[assignment] 199 200 if isinstance(stride, int): 201 strides = [stride] * expand_size 202 elif not stride: 203 strides = kernel_shape 204 else: 205 strides = stride # type: ignore[assignment] 206 207 return (kernel_shape, strides, pads, dilation) 208 209 210def _aten_max_pool_with_indices_onnx( 211 g: jit_utils.GraphContext, 212 self: _C.Value, 213 kernel_shape: Sequence[int], 214 strides: Sequence[int], 215 pads: Sequence[int], 216 dilations: Sequence[int], 217 ceil_mode: bool, 218 unbatched_rank: int, 219 n_dims_one: Sequence[int], 220 n_dims_zero: Sequence[int], 221 n_dims_axes: Sequence[int], 222) -> tuple[_C.Value, Sequence[int]]: 223 self_rank = g.op("Size", g.op("Shape", self)) 224 if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1 225 self = g.op( 226 "Unsqueeze", 227 self, 228 g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), 229 ) 230 231 pool_result, indices = g.op( 232 "MaxPool", 233 self, 234 outputs=2, 235 ceil_mode_i=ceil_mode, 236 dilations_i=dilations, 237 kernel_shape_i=kernel_shape, 238 pads_i=pads, 239 strides_i=strides, 240 ) 241 _, flatten_indices = g.op( 242 "MaxPool", 243 self, 244 outputs=2, 245 dilations_i=dilations, 246 kernel_shape_i=n_dims_one, 247 strides_i=n_dims_one, 248 ) 249 250 ends = g.op("Constant", value_t=torch.tensor(n_dims_one)) 251 starts = g.op("Constant", value_t=torch.tensor(n_dims_zero)) 252 axes = g.op("Constant", value_t=torch.tensor(n_dims_axes)) 253 254 delta = g.op("Slice", flatten_indices, starts, ends, axes) 255 indices = g.op("Sub", indices, delta) 256 257 if self_rank == unbatched_rank: 258 pool_result = g.op( 259 "Squeeze", pool_result, value_t=torch.tensor([0], dtype=torch.int64) 260 ) 261 indices = g.op("Squeeze", indices, value_t=torch.tensor([0], dtype=torch.int64)) 262 263 return (pool_result, indices) 264 265 266@_onnx_symbolic( 267 "aten::max_pool1d", 268 decorate=[symbolic_helper._apply_params("max_pool1d", 1, return_indices=False)], 269) 270@_onnx_symbolic( 271 "aten::max_pool2d", 272 decorate=[symbolic_helper._apply_params("max_pool2d", 2, return_indices=False)], 273) 274@_onnx_symbolic( 275 "aten::max_pool3d", 276 decorate=[symbolic_helper._apply_params("max_pool3d", 3, return_indices=False)], 277) 278@_onnx_symbolic( 279 "aten::max_pool1d_with_indices", 280 decorate=[ 281 symbolic_helper._apply_params( 282 "max_pool1d_with_indices", 283 1, 284 return_indices=True, 285 ) 286 ], 287) 288@_onnx_symbolic( 289 "aten::max_pool2d_with_indices", 290 decorate=[ 291 symbolic_helper._apply_params( 292 "max_pool2d_with_indices", 293 2, 294 return_indices=True, 295 ) 296 ], 297) 298@_onnx_symbolic( 299 "aten::max_pool3d_with_indices", 300 decorate=[ 301 symbolic_helper._apply_params( 302 "max_pool3d_with_indices", 303 3, 304 return_indices=True, 305 ) 306 ], 307) 308def _max_pool(name: str, expand_size: int, return_indices: bool): 309 @symbolic_helper.quantized_args(True, False, False, False, False, False) 310 @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") 311 def symbolic_fn( 312 g: jit_utils.GraphContext, 313 input: _C.Value, 314 kernel_size: Sequence[int], 315 stride: Sequence[int], 316 padding: int | Sequence[int], 317 dilation: Sequence[int], 318 ceil_mode: bool, 319 ): 320 kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool( 321 expand_size, kernel_size, stride, padding, dilation 322 ) 323 324 if return_indices: 325 return _aten_max_pool_with_indices_onnx( 326 g, 327 input, 328 kernel_shape, 329 strides, 330 pads, 331 dilations, 332 ceil_mode, 333 expand_size + 1, 334 ([1] * expand_size), 335 ([0] * expand_size), 336 ([2 + i for i in range(expand_size)]), 337 ) 338 else: 339 return _aten_max_pool_onnx( 340 g, 341 input, 342 kernel_shape, 343 strides, 344 pads, 345 dilations, 346 ceil_mode, 347 expand_size + 1, 348 ) 349 350 return symbolic_fn 351 352 353# For AvgPool 354def _adjust_attributes_of_avg_pool( 355 expand_size: int, 356 kernel_size: Sequence[int] | int, 357 stride: Sequence[int] | int, 358 padding: Sequence[int] | int, 359) -> tuple[Sequence[int], Sequence[int], Sequence[int]]: 360 """Adjust attributes of avg_pool to match ONNX specification.""" 361 362 if isinstance(kernel_size, int): 363 kernel_shape = [kernel_size] * expand_size 364 else: 365 kernel_shape = kernel_size # type: ignore[assignment] 366 367 if isinstance(padding, int): 368 pads = [padding] * expand_size * 2 369 elif len(padding) == 1: 370 pads = padding * expand_size * 2 # type: ignore[operator, assignment] 371 elif len(padding) == 2: 372 pads = padding * expand_size # type: ignore[operator, assignment] 373 else: 374 pads = padding * 2 # type: ignore[operator, assignment] 375 376 if isinstance(stride, int): 377 strides = [stride] * expand_size 378 elif not stride: 379 strides = kernel_shape 380 else: 381 strides = stride # type: ignore[assignment] 382 383 return (kernel_shape, strides, pads) 384 385 386@_onnx_symbolic( 387 "aten::avg_pool1d", 388 decorate=[symbolic_helper._apply_params("avg_pool1d", 1)], 389) 390@_onnx_symbolic( 391 "aten::avg_pool2d", 392 decorate=[symbolic_helper._apply_params("avg_pool2d", 2)], 393) 394@_onnx_symbolic( 395 "aten::avg_pool3d", 396 decorate=[symbolic_helper._apply_params("avg_pool3d", 3)], 397) 398def _avg_pool(name, expand_size): 399 @symbolic_helper.quantized_args(True, False, False, False, False, False, False) 400 @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") 401 def symbolic_fn( 402 g, 403 input: _C.Value, 404 kernel_size: Sequence[int], 405 stride: Sequence[int], 406 padding: int | Sequence[int], 407 ceil_mode: int, 408 count_include_pad: int, 409 divisor_override=None, 410 ): 411 kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( 412 expand_size, kernel_size, stride, padding 413 ) 414 415 result = g.op( 416 "AveragePool", 417 input, 418 ceil_mode_i=ceil_mode, 419 count_include_pad_i=count_include_pad, 420 kernel_shape_i=kernel_shape, 421 pads_i=pads, 422 strides_i=strides, 423 ) 424 425 return result 426 427 return symbolic_fn 428 429 430@_onnx_symbolic( 431 "aten::upsample_nearest1d", 432 decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], 433) 434@_onnx_symbolic( 435 "aten::upsample_nearest2d", 436 decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], 437) 438@_onnx_symbolic( 439 "aten::upsample_nearest3d", 440 decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], 441) 442@_onnx_symbolic( 443 "aten::upsample_linear1d", 444 decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], 445) 446@_onnx_symbolic( 447 "aten::upsample_bilinear2d", 448 decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], 449) 450@_onnx_symbolic( 451 "aten::upsample_trilinear3d", 452 decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], 453) 454def _interpolate(name, dim, interpolate_mode): 455 @symbolic_helper.quantized_args(True, False, False) 456 def symbolic_fn(g, input, output_size, *args): 457 scales, align_corners = symbolic_helper._get_interpolate_attributes( 458 g, interpolate_mode, args 459 ) 460 symbolic_helper._interpolate_warning(interpolate_mode) 461 align_corners = symbolic_helper._maybe_get_scalar(align_corners) 462 if align_corners: 463 return symbolic_helper._unimplemented(name, "align_corners == True", input) 464 if scales is None: 465 scales = symbolic_helper._interpolate_size_to_scales( 466 g, input, output_size, dim 467 ) 468 return g.op("Resize", input, scales, mode_s=interpolate_mode) 469 470 return symbolic_fn 471 472 473@_onnx_symbolic("aten::__interpolate") 474def __interpolate( 475 g: jit_utils.GraphContext, 476 input, 477 size, 478 scale_factor, 479 mode, 480 align_corners, 481 recompute_scale_factor, 482 antialias, 483): 484 scales, mode = symbolic_helper._interpolate_get_scales_and_mode( 485 g, input, size, scale_factor, mode, align_corners 486 ) 487 return g.op("Resize", input, scales, mode_s=mode) 488 489 490def _slice( 491 g: jit_utils.GraphContext, 492 input: torch._C.Value, 493 axes: list | torch.Tensor | torch._C.Value, 494 starts: list | torch.Tensor | torch._C.Value, 495 ends: list | torch.Tensor | torch._C.Value, 496 steps: list | torch.Tensor | torch._C.Value | None = None, 497): 498 def is_none_value(value): 499 if value is None: 500 return True 501 return ( 502 isinstance(value, torch._C.Value) 503 and value.node().kind() == "prim::Constant" 504 and isinstance(value.type(), _C.NoneType) 505 ) 506 507 def to_slice_input(list_or_value, default_value=None): 508 # Convert input param into a 1D torch.Value. 509 if is_none_value(list_or_value) and default_value is not None: 510 list_or_value = [default_value] 511 512 if isinstance(list_or_value, (list, torch.Tensor)): 513 return g.op("Constant", value_t=torch.tensor(list_or_value)) 514 515 rank = symbolic_helper._get_tensor_rank(list_or_value) 516 if rank == 0: 517 return symbolic_helper._unsqueeze_helper(g, list_or_value, [0]) 518 if rank == 1: 519 return list_or_value 520 raise errors.SymbolicValueError( 521 f"Rank must be 0 or 1, not {rank}", list_or_value 522 ) 523 524 def get_const_value(list_or_value): 525 if isinstance(list_or_value, (list, torch.Tensor)): 526 if len(list_or_value) == 1: 527 return list_or_value[0] 528 return None 529 return symbolic_helper._maybe_get_const(list_or_value, "i") 530 531 # Check if slice is a no-op 532 if ( 533 get_const_value(starts) == 0 534 and get_const_value(ends) == _constants.INT64_MAX 535 and (steps is None or get_const_value(steps) == 1) 536 ): 537 return input 538 539 axes = to_slice_input(axes) 540 starts = to_slice_input(starts, default_value=0) 541 ends = to_slice_input(ends, default_value=_constants.INT64_MAX) 542 if steps is None: 543 return g.op("Slice", input, starts, ends, axes) 544 steps = to_slice_input(steps, default_value=1) 545 return g.op("Slice", input, starts, ends, axes, steps) 546 547 548@_onnx_symbolic("aten::slice") 549def slice(g: jit_utils.GraphContext, self, *args): 550 if len(args) == 4: 551 # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor 552 dims, start, end, step = args 553 elif len(args) == 3: 554 # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[] 555 start, end, step = args 556 dims = [0] 557 else: 558 raise errors.SymbolicValueError("Unknown aten::slice signature", self) 559 560 return symbolic_helper._slice_helper( 561 g, 562 self, 563 axes=dims, 564 starts=start, 565 ends=end, 566 steps=step, 567 ) 568 569 570@_onnx_symbolic("aten::flip") 571@symbolic_helper.parse_args("v", "is") 572def flip(g: jit_utils.GraphContext, input, dims): 573 return symbolic_helper._slice_helper( 574 g, 575 input, 576 axes=dims, 577 starts=[-1] * len(dims), 578 ends=[-_constants.INT64_MAX] * len(dims), 579 steps=[-1] * len(dims), 580 ) 581 582 583@_onnx_symbolic("aten::fmod") 584def fmod(g: jit_utils.GraphContext, input, other): 585 return g.op("Mod", input, other, fmod_i=1) 586 587 588@_onnx_symbolic("aten::embedding_bag") 589@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") 590def embedding_bag( 591 g: jit_utils.GraphContext, 592 embedding_matrix, 593 indices, 594 offsets, 595 scale_grad_by_freq, 596 mode, 597 sparse, 598 per_sample_weights, 599 include_last_offset, 600 padding_idx, 601): 602 if scale_grad_by_freq and GLOBALS.export_training: 603 return symbolic_helper._onnx_unsupported( 604 "embedding_bag with scale_grad_by_freq for training mode" 605 ) 606 if padding_idx is not None and padding_idx >= 0: 607 raise RuntimeError("embedding_bag with padding_idx") 608 609 warnings.warn( 610 "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. " 611 "Please use opset 11 or higher to export model for dynamic input shape.'" 612 ) 613 offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0) 614 if offsets_dim_0 is not None: 615 if include_last_offset: 616 offset_len = offsets_dim_0 - 1 617 offsets_extended = offsets 618 else: 619 offset_len = offsets_dim_0 620 offsets_extended = [ 621 offsets, 622 g.op("Constant", value_t=torch.tensor([sys.maxsize])), 623 ] 624 offsets_extended = g.op("Concat", *offsets_extended, axis_i=0) 625 list_ = [] 626 for i in range(offset_len): 627 start_ = symbolic_helper._unsqueeze_helper( 628 g, 629 opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), 630 [0], 631 ) 632 end_ = symbolic_helper._unsqueeze_helper( 633 g, 634 opset9.select( 635 g, offsets_extended, torch.tensor(0), torch.tensor(i + 1) 636 ), 637 [0], 638 ) 639 axes_ = g.op("Constant", value_t=torch.tensor([0])) 640 indices_row = g.op("Slice", indices, start_, end_, axes_) 641 642 embeddings = g.op("Gather", embedding_matrix, indices_row) 643 if not symbolic_helper._is_none(per_sample_weights): 644 per_sample_weights_row = g.op( 645 "Slice", per_sample_weights, start_, end_, axes_ 646 ) 647 per_sample_weights_row = symbolic_helper._unsqueeze_helper( 648 g, per_sample_weights_row, [1] 649 ) 650 embeddings = g.op("Mul", embeddings, per_sample_weights_row) 651 if mode == 0: 652 embeddings = symbolic_helper._reducesum_helper( 653 g, embeddings, axes_i=[0], keepdims_i=0 654 ) 655 elif mode == 1: 656 embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0) 657 else: 658 embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0) 659 660 embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0]) 661 list_.append(embeddings) 662 663 output = g.op("Concat", *list_, axis_i=0) 664 # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. 665 # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. 666 return output, None, None, None 667 else: 668 return symbolic_helper._onnx_unsupported( 669 "embedding_bag with unknown shape of offsets for opset 10 is not supported. " 670 "please use opset 11 or higher." 671 ) 672 673 674@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") 675@symbolic_helper.parse_args("v", "v", "v", "i", "i") 676def fake_quantize_per_tensor_affine( 677 g: jit_utils.GraphContext, 678 inputs, 679 scale, 680 zero_point, 681 quant_min=-128, 682 quant_max=127, 683): 684 # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127). 685 # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 686 if (quant_min, quant_max) == (0, 127): 687 symbolic_helper._onnx_opset_unsupported_detailed( 688 "fake_quantize_per_tensor_affine", 689 10, 690 13, 691 "Quantize range (0, 127) not supported, requires opset 13 Clip", 692 inputs, 693 ) 694 if (quant_min, quant_max) not in [(0, 255), (-128, 127)]: 695 raise errors.SymbolicValueError( 696 f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). " 697 f"Got ({quant_min}, {quant_max})", 698 inputs, 699 ) 700 scale = symbolic_helper._maybe_get_scalar(scale) 701 if scale is None: 702 symbolic_helper._onnx_opset_unsupported_detailed( 703 "fake_quantize_per_tensor_affine", 704 10, 705 13, 706 "Non-constant scale not supported", 707 inputs, 708 ) 709 scale = scale.float().data # Avoid exporter generating double type 710 if quant_min == 0: 711 zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) 712 else: 713 zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) 714 return g.op( 715 "DequantizeLinear", 716 g.op("QuantizeLinear", inputs, scale, zero_point), 717 scale, 718 zero_point, 719 ) 720 721 722@_onnx_symbolic("aten::isinf") 723def isinf(g: jit_utils.GraphContext, input): 724 return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)) 725 726 727@_onnx_symbolic("aten::isfinite") 728def isfinite(g: jit_utils.GraphContext, input): 729 inf_node = isinf(g, input) 730 nan_node = opset9.isnan(g, input) 731 return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node)) 732 733 734@_onnx_symbolic("aten::quantize_per_tensor") 735def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): 736 dtype = symbolic_helper._get_const(dtype, "i", "dtype") 737 # TODO(justinchuby): Extract all the cast ops into a helper function. 738 zero_point = g.op( 739 "Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type() 740 ) 741 scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) 742 return symbolic_helper.quantize_helper(g, input, scale, zero_point) 743 744 745@_onnx_symbolic("aten::dequantize") 746def dequantize(g: jit_utils.GraphContext, input): 747 return symbolic_helper.dequantize_helper(g, input)[0] 748 749 750@_onnx_symbolic("aten::nan_to_num") 751@symbolic_helper.parse_args("v", "f", "f", "f") 752def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf): 753 # Cannot create a int type tensor with inf/nan values, so we simply 754 # return the original tensor 755 if not symbolic_helper._is_fp(input): 756 return input 757 input_dtype = _type_utils.JitScalarType.from_value(input).dtype() 758 if nan is None: 759 nan = 0.0 760 nan_cond = opset9.isnan(g, input) 761 nan_result = g.op( 762 "Where", 763 nan_cond, 764 g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)), 765 input, 766 ) 767 768 # For None values of posinf, neginf we use the greatest/lowest finite 769 # value representable by input's dtype. 770 finfo = torch.finfo(input_dtype) 771 if posinf is None: 772 posinf = finfo.max 773 posinf_cond = opset9.logical_and( 774 g, 775 isinf(g, nan_result), 776 opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))), 777 ) 778 nan_posinf_result = g.op( 779 "Where", 780 posinf_cond, 781 g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)), 782 nan_result, 783 ) 784 785 if neginf is None: 786 neginf = finfo.min 787 neginf_cond = opset9.logical_and( 788 g, 789 isinf(g, nan_posinf_result), 790 opset9.lt( 791 g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0])) 792 ), 793 ) 794 return g.op( 795 "Where", 796 neginf_cond, 797 g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)), 798 nan_posinf_result, 799 ) 800 801 802# Quantized symbolics --------------------------------------------------------- 803# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export 804# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were 805# introduced in opset version 10. 806@_onnx_symbolic("quantized::linear") 807def quantized_linear( 808 g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point 809): 810 input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) 811 weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) 812 q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) 813 bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) 814 815 output = opset9.linear(g, input, weight, bias) 816 817 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 818 819 820@_onnx_symbolic("quantized::linear_relu") 821def quantized_linear_relu( 822 g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point 823): 824 input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) 825 weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) 826 q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) 827 bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) 828 829 output = opset9.linear(g, input, weight, bias) 830 output = opset9.relu(g, output) 831 832 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 833 834 835@_onnx_symbolic("quantized::add") 836def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): 837 x, _, _, _ = symbolic_helper.dequantize_helper(g, x) 838 y, _, _, _ = symbolic_helper.dequantize_helper(g, y) 839 840 output = opset9.add(g, x, y) 841 842 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 843 844 845@_onnx_symbolic("quantized::add_relu") 846def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): 847 x, _, _, _ = symbolic_helper.dequantize_helper(g, x) 848 y, _, _, _ = symbolic_helper.dequantize_helper(g, y) 849 850 output = opset9.add(g, x, y) 851 output = opset9.relu(g, output) 852 853 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 854 855 856@_onnx_symbolic("quantized::mul") 857def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): 858 x, _, _, _ = symbolic_helper.dequantize_helper(g, x) 859 y, _, _, _ = symbolic_helper.dequantize_helper(g, y) 860 861 output = opset9.mul(g, x, y) 862 863 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 864 865 866@_onnx_symbolic("quantized::hardswish") 867def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): 868 x, _, _, _ = symbolic_helper.dequantize_helper(g, x) 869 870 output = opset9.hardswish(g, x) 871 872 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 873 874 875@_onnx_symbolic("quantized::sigmoid") 876def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point): 877 x, _, _, _ = symbolic_helper.dequantize_helper(g, x) 878 879 output = opset9.sigmoid(g, x) 880 881 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 882 883 884@_onnx_symbolic("quantized::leaky_relu") 885def quantized_leaky_relu( 886 g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point 887): 888 x, _, _, _ = symbolic_helper.dequantize_helper(g, x) 889 890 output = opset9.leaky_relu(g, x, negative_slope, inplace) 891 892 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 893 894 895@_onnx_symbolic("quantized::layer_norm") 896def quantized_layer_norm( 897 g: jit_utils.GraphContext, 898 x, 899 normalized_shape, 900 weight, 901 bias, 902 eps, 903 op_scale, 904 op_zero_point, 905): 906 x, _, _, _ = symbolic_helper.dequantize_helper(g, x) 907 908 output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False) 909 910 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 911 912 913@_onnx_symbolic("quantized::group_norm") 914def quantized_group_norm( 915 g: jit_utils.GraphContext, 916 x, 917 num_groups, 918 weight, 919 bias, 920 eps, 921 op_scale, 922 op_zero_point, 923): 924 x, _, _, _ = symbolic_helper.dequantize_helper(g, x) 925 926 output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False) 927 928 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 929 930 931@_onnx_symbolic("quantized::instance_norm") 932@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v") 933def quantized_instance_norm( 934 g: jit_utils.GraphContext, 935 q_input, 936 weight, 937 bias, 938 eps, 939 op_scale, 940 op_zero_point, 941): 942 input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input) 943 944 output = opset9.instance_norm( 945 g, input, weight, bias, None, None, False, 0.0, eps, False 946 ) 947 948 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 949 950 951@_onnx_symbolic("quantized::conv1d_relu") 952def quantized_conv1d_relu( 953 g: jit_utils.GraphContext, 954 q_input, 955 q_weight, 956 bias, 957 stride, 958 padding, 959 dilation, 960 groups, 961 op_scale, 962 op_zero_point, 963): 964 input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) 965 weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) 966 q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) 967 bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) 968 969 output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) 970 output = opset9.relu(g, output) 971 972 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 973 974 975@_onnx_symbolic("quantized::conv2d_relu") 976def quantized_conv2d_relu( 977 g: jit_utils.GraphContext, 978 q_input, 979 q_weight, 980 bias, 981 stride, 982 padding, 983 dilation, 984 groups, 985 op_scale, 986 op_zero_point, 987): 988 input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) 989 weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) 990 q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) 991 bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) 992 993 output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) 994 output = opset9.relu(g, output) 995 996 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 997 998 999@_onnx_symbolic("quantized::conv3d_relu") 1000def quantized_conv3d_relu( 1001 g: jit_utils.GraphContext, 1002 q_input, 1003 q_weight, 1004 bias, 1005 stride, 1006 padding, 1007 dilation, 1008 groups, 1009 op_scale, 1010 op_zero_point, 1011): 1012 input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) 1013 weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) 1014 q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) 1015 bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) 1016 1017 output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) 1018 output = opset9.relu(g, output) 1019 1020 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 1021 1022 1023@_onnx_symbolic("quantized::conv1d") 1024def quantized_conv1d( 1025 g: jit_utils.GraphContext, 1026 q_input, 1027 q_weight, 1028 bias, 1029 stride, 1030 padding, 1031 dilation, 1032 groups, 1033 op_scale, 1034 op_zero_point, 1035): 1036 input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) 1037 weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) 1038 q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) 1039 bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) 1040 1041 output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) 1042 1043 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 1044 1045 1046@_onnx_symbolic("quantized::conv2d") 1047def quantized_conv2d( 1048 g: jit_utils.GraphContext, 1049 q_input, 1050 q_weight, 1051 bias, 1052 stride, 1053 padding, 1054 dilation, 1055 groups, 1056 op_scale, 1057 op_zero_point, 1058): 1059 input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) 1060 weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) 1061 q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) 1062 bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) 1063 1064 output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) 1065 1066 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 1067 1068 1069@_onnx_symbolic("quantized::conv3d") 1070def quantized_conv3d( 1071 g: jit_utils.GraphContext, 1072 q_input, 1073 q_weight, 1074 bias, 1075 stride, 1076 padding, 1077 dilation, 1078 groups, 1079 op_scale, 1080 op_zero_point, 1081): 1082 input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) 1083 weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) 1084 q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) 1085 bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) 1086 1087 output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) 1088 1089 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 1090 1091 1092@_onnx_symbolic("quantized::conv_transpose1d") 1093def quantized_conv_transpose1d( 1094 g: jit_utils.GraphContext, 1095 q_input, 1096 q_weight, 1097 bias, 1098 stride, 1099 padding, 1100 output_padding, 1101 dilation, 1102 groups, 1103 op_scale, 1104 op_zero_point, 1105): 1106 input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) 1107 weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) 1108 q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) 1109 bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) 1110 1111 output = opset9.conv_transpose2d( 1112 g, input, weight, bias, stride, padding, output_padding, groups, dilation 1113 ) 1114 1115 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 1116 1117 1118@_onnx_symbolic("quantized::conv_transpose2d") 1119def quantized_conv_transpose2d( 1120 g: jit_utils.GraphContext, 1121 q_input, 1122 q_weight, 1123 bias, 1124 stride, 1125 padding, 1126 output_padding, 1127 dilation, 1128 groups, 1129 op_scale, 1130 op_zero_point, 1131): 1132 input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) 1133 weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) 1134 q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) 1135 bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) 1136 1137 output = opset9.conv_transpose2d( 1138 g, input, weight, bias, stride, padding, output_padding, groups, dilation 1139 ) 1140 1141 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 1142 1143 1144@_onnx_symbolic("quantized::conv_transpose3d") 1145def quantized_conv_transpose3d( 1146 g: jit_utils.GraphContext, 1147 q_input, 1148 q_weight, 1149 bias, 1150 stride, 1151 padding, 1152 output_padding, 1153 dilation, 1154 groups, 1155 op_scale, 1156 op_zero_point, 1157): 1158 input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) 1159 weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) 1160 q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) 1161 bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) 1162 1163 output = opset9.conv_transpose3d( 1164 g, input, weight, bias, stride, padding, output_padding, groups, dilation 1165 ) 1166 1167 return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) 1168 1169 1170@_onnx_symbolic("quantized::cat") 1171@symbolic_helper.parse_args("v", "i", "v", "v") 1172def quantized_cat( 1173 g: jit_utils.GraphContext, 1174 q_inputs: _C.Value, 1175 dim: int, 1176 op_scale: _C.Value, 1177 op_zero_point: _C.Value, 1178) -> _C.Value: 1179 unpacked_inputs = symbolic_helper._unpack_list(q_inputs) 1180 dequantized = [ 1181 symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs 1182 ] 1183 concatenated = g.op("Concat", *dequantized, axis_i=dim) 1184 return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point) 1185