# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 17.

Note [ONNX Operators that are added/updated in opset 17]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
New operators:
    BlackmanWindow
    DFT
    HammingWindow
    HannWindow
    LayerNormalization
    MelWeightMatrix
    STFT
    SequenceMap
"""

import functools
from typing import Optional, Sequence

import torch
from torch import _C
from torch.onnx import _type_utils, errors, symbolic_helper
from torch.onnx._internal import jit_utils, registration


# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

__all__ = ["layer_norm", "stft", "quantized_layer_norm"]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17)


@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
):
    """Export ``aten::layer_norm`` as a single ONNX ``LayerNormalization`` node.

    Args:
        g: Graph context the ONNX nodes are emitted into.
        input: Tensor to normalize.
        normalized_shape: Sizes of the trailing dimensions to normalize over.
            Only its length is used to pick the ONNX ``axis``. Its values are
            used to build default weight/bias constants.
        weight: Scale tensor. When it is ``None``, a constant of ones with shape
            ``normalized_shape`` is used instead.
        bias: Shift tensor. When it is ``None``, a constant of zeros with shape
            ``normalized_shape`` is used instead.
        eps: Forwarded unchanged as the ONNX ``epsilon`` attribute.
        cudnn_enable: Accepted for signature compatibility; not used here.

    Returns:
        The output value of the emitted ``LayerNormalization`` node.
    """
    # layer_norm normalizes over the last D dimensions, where
    # D == len(normalized_shape); ONNX expresses that as a negative start axis.
    axis = -len(normalized_shape)
    # Default weight/bias constants must match the input dtype (float if unknown).
    scalar_type = _type_utils.JitScalarType.from_value(
        input, _type_utils.JitScalarType.FLOAT
    )
    dtype = scalar_type.dtype()
    if symbolic_helper._is_none(weight):
        weight_value = torch.ones(normalized_shape, dtype=dtype)
        weight = g.op("Constant", value_t=weight_value)
    if symbolic_helper._is_none(bias):
        bias_value = torch.zeros(normalized_shape, dtype=dtype)
        bias = g.op("Constant", value_t=bias_value)
    return g.op(
        "LayerNormalization",
        input,
        weight,
        bias,
        epsilon_f=eps,
        axis_i=axis,
    )


@_onnx_symbolic("quantized::layer_norm")
def quantized_layer_norm(
    g: jit_utils.GraphContext,
    x,
    normalized_shape,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::layer_norm`` as dequantize -> LayerNormalization -> quantize.

    The input is dequantized with ``symbolic_helper.dequantize_helper``. It is
    normalized through :func:`layer_norm`, and the result is re-quantized with
    ``op_scale`` / ``op_zero_point`` via ``symbolic_helper.quantize_helper``.
    """
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    output = layer_norm(g, x, normalized_shape, weight, bias, eps, False)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


def _compute_edge_sizes(n_fft, window_size):
    """Return the ``(left, right)`` zero-padding needed to center a window of
    ``window_size`` samples inside an FFT frame of ``n_fft`` samples.

    When the total padding is odd, the extra sample goes on the right.
    """
    left = (n_fft - window_size) // 2
    right = n_fft - left - window_size
    return left, right


@_onnx_symbolic("aten::stft")
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b")
def stft(
    g: jit_utils.GraphContext,
    input: _C.Value,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[_C.Value] = None,
    normalized: bool = False,
    onesided: Optional[bool] = True,
    return_complex: Optional[bool] = False,
) -> _C.Value:
    """Associates `torch.stft` with the `STFT` ONNX operator.

    Note that torch.stft calls _VF.stft, without centering or padding options.
    Hence, this function does not contain these two arguments.
    See torch.stft source code for more info.

    Args:
        g: Graph to write the ONNX representation into
        input: Input tensor for the transformation, of rank 1 ``[signal]`` or
            rank 2 ``[batch, signal]``
        n_fft: FFT size
        hop_length: Size of the hop. Defaults to ``floor(n_fft / 4)``
        win_length: Size of the analysis window. Defaults to ``n_fft``
        window: Analysis window. Defaults to a window of all ones
        normalized: Whether to return a normalized STFT (the output is divided
            by ``sqrt(n_fft)``)
        onesided: Whether to return only half (+1) of the results, given the
            symmetry of the STFT. ``None`` is treated as ``True``.
        return_complex: Whether to return the complex value (Note: Must be
            `False` or `None`)

    Returns:
        op: Operator for torch.stft associated with STFT (ONNX)

    Raises:
        SymbolicValueError: if ``return_complex`` is truthy, if the input rank
            is unknown or greater than 2, or if ``win_length`` exceeds ``n_fft``.
        AssertionError: if a statically sized ``window`` does not have
            ``win_length`` (or ``n_fft`` when ``win_length`` is unset) samples.
    """
    # Checks
    if return_complex:
        raise errors.SymbolicValueError(
            msg="STFT does not currently support complex types", value=input
        )

    # Get STFT sizes
    frame_step_value = hop_length if hop_length is not None else n_fft // 4
    frame_step_const = g.op(
        "Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64)
    )
    frame_length_const = g.op(
        "Constant", value_t=torch.tensor(n_fft, dtype=torch.int64)
    )

    # Pre-process input if needed
    signal = input
    signal_rank = symbolic_helper._get_tensor_rank(signal)
    if signal_rank == 1:
        # Add batch dimension
        signal = g.op(
            "Unsqueeze",
            signal,
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
        )
    elif signal_rank is None or signal_rank > 2:
        raise errors.SymbolicValueError(
            msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. "
            f"Current rank of signal is {signal_rank}, please reduce it.",
            value=input,
        )

    # The Unsqueeze above does not change the dtype, so read it from `input`
    # rather than from a freshly created node whose type may not be populated.
    signal_scalar_type = _type_utils.JitScalarType.from_value(input)

    # ONNX's STFT needs a window of exactly `n_fft` samples, so a longer
    # analysis window can never be expressed, whether or not one was supplied.
    if win_length and win_length > n_fft:
        raise errors.SymbolicValueError(
            msg="The analysis window can't be longer than the size of the FFT. "
            f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.",
            value=input,
        )

    # Get window and make sure it's the same size as `win_length` or `n_fft`
    n_win = symbolic_helper._get_tensor_dim_size(window, dim=0)
    if n_win is not None:
        win_length_default = win_length if win_length else n_fft
        assert n_win == win_length_default, (
            "Analysis window size must equal `win_length` or `n_fft`. "
            f"Please, set `win_length` or `n_fft` to match `window` size ({n_win})"
        )

        # Center window around zeros if needed (required by ONNX's STFT)
        if n_win < n_fft:
            left, right = _compute_edge_sizes(n_fft, n_win)
            # Pad with zeros of the window's own dtype so all Concat inputs
            # agree; the padded window is cast to the signal dtype below.
            window_dtype = _type_utils.JitScalarType.from_value(
                window, _type_utils.JitScalarType.FLOAT
            ).dtype()
            left_win = g.op(
                "Constant", value_t=torch.zeros(left, dtype=window_dtype)
            )
            right_win = g.op(
                "Constant", value_t=torch.zeros(right, dtype=window_dtype)
            )
            window = g.op("Concat", left_win, window, right_win, axis_i=0)

    # Create window, if needed
    if symbolic_helper._is_none(window):
        if win_length:
            # Center a rectangular window of `win_length` ones inside `n_fft`
            left, right = _compute_edge_sizes(n_fft, win_length)
            torch_window = torch.hstack(
                (torch.zeros(left), torch.ones(win_length), torch.zeros(right))
            )
        else:
            # Rectangle window
            torch_window = torch.ones(n_fft)
        assert torch_window.shape[0] == n_fft
        window = g.op("Constant", value_t=torch_window)
    window = g.op("Cast", window, to_i=signal_scalar_type.onnx_type())

    # Run STFT
    result = g.op(
        "STFT",
        signal,
        frame_step_const,
        window,
        frame_length_const,
        onesided_i=1 if onesided is None or onesided else 0,
    )

    # Transpose to mimic torch.stft's behavior
    result = g.op("Transpose", result, perm_i=[0, 2, 1, 3])

    # Remove batch dimension, if needed
    if signal_rank == 1:
        result = g.op(
            "Squeeze",
            result,
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
        )

    # Normalize, if needed
    if normalized:
        sqrt_nfft = torch.sqrt(
            torch.tensor(n_fft, dtype=signal_scalar_type.dtype())
        )
        result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft))

    return result