# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 20.

Note [ONNX Operators that are added/updated in opset 20]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
New operators:
    AffineGrid
    ConstantOfShape
    DFT
    Gelu
    GridSample
    ImageDecoder
    IsInf
    IsNaN
    ReduceMax
    ReduceMin
    RegexFullMatch
    StringConcat
    StringSplit
"""

import functools

import torch.nn.functional as F
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration


# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]


def convert_grid_sample_mode(mode_s):
    """Translate a PyTorch grid-sample mode name to its opset-20 ONNX spelling.

    ``"bilinear"`` maps to ``"linear"`` and ``"bicubic"`` maps to ``"cubic"``;
    every other value is returned unchanged.
    """
    if mode_s == "bilinear":
        return "linear"
    if mode_s == "bicubic":
        return "cubic"
    return mode_s


# Registration decorator pre-bound to opset 20 for every symbolic in this file.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)


@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def _grid_sampler(
    g: jit_utils.GraphContext,
    input: _C.Value,
    grid: _C.Value,
    mode_enum: int,
    padding_mode_enum: int,
    align_corners: bool,
):
    """Emit a ``GridSample`` node for ``aten::grid_sampler``.

    Args:
        g: Graph context used to create the node.
        input: First operand passed to ``GridSample``.
        grid: Second operand passed to ``GridSample``.
        mode_enum: Integer looked up among the values of
            ``F.GRID_SAMPLE_INTERPOLATION_MODES`` to recover the mode name.
        padding_mode_enum: Integer looked up among the values of
            ``F.GRID_SAMPLE_PADDING_MODES`` to recover the padding-mode name.
        align_corners: Forwarded as the integer ``align_corners`` attribute.

    Returns:
        The result of ``g.op("GridSample", ...)``.

    Raises:
        KeyError: If either enum value is not present in its lookup table.
    """
    # The functional tables map name -> enum; invert them to go enum -> name.
    torch_mode = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg, index]
    # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
    onnx_mode = convert_grid_sample_mode(torch_mode)
    padding_name = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg, index]
    align_flag = int(align_corners)
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=align_flag,
        mode_s=onnx_mode,
        padding_mode_s=padding_name,
    )


@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
def _affine_grid_generator(
    g: jit_utils.GraphContext,
    theta: _C.Value,
    size: _C.Value,
    align_corners: bool,
):
    """Emit an ``AffineGrid`` node for ``aten::affine_grid_generator``.

    ``theta`` and ``size`` are passed through as the node inputs and
    ``align_corners`` is converted to an integer attribute.
    """
    align_flag = int(align_corners)
    return g.op("AffineGrid", theta, size, align_corners_i=align_flag)


@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
    """Emit a ``Gelu`` node for ``aten::gelu``.

    ``approximate`` is forwarded unchanged as the node's string
    ``approximate`` attribute.
    """
    gelu_node = g.op("Gelu", self, approximate_s=approximate)
    return gelu_node