xref: /aosp_15_r20/external/pytorch/torch/onnx/symbolic_opset17.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# mypy: disable-error-code=arg-type
3"""This file exports ONNX ops for opset 17.
4
5Note [ONNX Operators that are added/updated in opset 17]
6
7~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
9New operators:
10    BlackmanWindow
11    DFT
12    HammingWindow
13    HannWindow
14    LayerNormalization
15    MelWeightMatrix
16    STFT
17    SequenceMap
18"""
19
20import functools
21from typing import Optional, Sequence
22
23import torch
24from torch import _C
25from torch.onnx import _type_utils, errors, symbolic_helper
26from torch.onnx._internal import jit_utils, registration
27
28
29# EDITING THIS FILE? READ THIS FIRST!
30# see Note [Edit Symbolic Files] in README.md
31
# Names exported by `from <this module> import *`.
__all__ = ["layer_norm", "stft", "quantized_layer_norm"]

# `registration.onnx_symbolic` with `opset=17` pre-bound; applied below as a
# decorator taking the operator name (e.g. "aten::layer_norm").
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17)
35
36
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
):
    """Emit an ONNX `LayerNormalization` node for `aten::layer_norm`.

    Args:
        g: Graph context the nodes are added to.
        input: Value to normalize.
        normalized_shape: Shape of the trailing dimensions to normalize over.
        weight: Scale value; when it is a none value, a constant of ones with
            shape `normalized_shape` is emitted instead.
        bias: Shift value; when it is a none value, a constant of zeros with
            shape `normalized_shape` is emitted instead.
        eps: Passed to the node as its `epsilon` attribute.
        cudnn_enable: Unused by this symbolic.

    Returns:
        The output of the emitted `LayerNormalization` node.
    """
    # Normalization covers the last len(normalized_shape) dimensions, so the
    # first normalized axis is counted from the end.
    first_normalized_axis = -len(normalized_shape)

    # dtype for the fallback constants; FLOAT is the default passed to
    # `from_value` for the case where the input's scalar type is not known.
    fill_dtype = _type_utils.JitScalarType.from_value(
        input, _type_utils.JitScalarType.FLOAT
    ).dtype()

    def _filled_constant(factory):
        # `factory` is torch.ones or torch.zeros.
        return g.op(
            "Constant", value_t=factory(normalized_shape, dtype=fill_dtype)
        )

    scale = (
        _filled_constant(torch.ones) if symbolic_helper._is_none(weight) else weight
    )
    shift = (
        _filled_constant(torch.zeros) if symbolic_helper._is_none(bias) else bias
    )

    return g.op(
        "LayerNormalization",
        input,
        scale,
        shift,
        epsilon_f=eps,
        axis_i=first_normalized_axis,
    )
71
72
@_onnx_symbolic("quantized::layer_norm")
def quantized_layer_norm(
    g: jit_utils.GraphContext,
    x,
    normalized_shape,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Symbolic for `quantized::layer_norm`.

    Dequantizes `x`, applies `layer_norm` to the result, and requantizes the
    output with `op_scale` and `op_zero_point`.
    """
    # Only the first of the four values returned by the helper is needed.
    dequantized, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    normalized = layer_norm(
        g, dequantized, normalized_shape, weight, bias, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
89
90
def _compute_edge_sizes(n_fft, window_size):
    """Return the `(left, right)` padding that centers a window in `n_fft`.

    The total padding is `n_fft - window_size`; the left edge gets the floored
    half and the right edge gets whatever remains, so an odd total puts the
    extra element on the right.
    """
    total_padding = n_fft - window_size
    left_edge = total_padding // 2
    return left_edge, total_padding - left_edge
97
98
@_onnx_symbolic("aten::stft")
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b")
def stft(
    g: jit_utils.GraphContext,
    input: _C.Value,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[_C.Value] = None,
    normalized: bool = False,
    onesided: Optional[bool] = True,
    return_complex: Optional[bool] = False,
) -> _C.Value:
    """Associates `torch.stft` with the `STFT` ONNX operator.
    Note that torch.stft calls _VF.stft, without centering or padding options.
    Hence, this function does not contain these two arguments.
    See torch.stft source code for more info.

    Args:
        g: Graph to write the ONNX representation into
        input: Input tensor for the transformation
        n_fft: FFT size
        hop_length: Size of the hop. Defaults to `n_fft // 4`
        win_length: Size of the analysis window. Defaults to `n_fft`
        window: Analysis window. Defaults to a window of ones of length
            `win_length` (or `n_fft`), zero-padded on both sides up to `n_fft`
        normalized: Whether to return a normalized STFT (result divided by
            `sqrt(n_fft)`)
        onesided: Whether to return only half (+1) of the results, given the
            symmetry of the STFT. `None` is treated as `True`
        return_complex: Whether to return the complex value (Note: Must be
            `False` or `None`)

    Returns:
        op: Operator for torch.stft associated with STFT (ONNX)

    Raises:
        errors.SymbolicValueError: If `return_complex` is truthy, if the rank
            of `input` is unknown or is not 1 or 2, or if no `window` is given
            and `win_length` is larger than `n_fft`.
        AssertionError: If the statically known size of `window` differs from
            `win_length` (or from `n_fft` when `win_length` is not set).
    """
    # Checks
    if return_complex:
        raise errors.SymbolicValueError(
            msg="STFT does not currently support complex types", value=input
        )

    # Get STFT sizes
    frame_step_value = hop_length if hop_length is not None else n_fft // 4
    frame_step_const = g.op(
        "Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64)
    )
    frame_length_const = g.op(
        "Constant", value_t=torch.tensor(n_fft, dtype=torch.int64)
    )

    # Pre-process input if needed
    signal = input
    signal_rank = symbolic_helper._get_tensor_rank(signal)
    if signal_rank == 1:
        # Add batch dimension
        signal = g.op(
            "Unsqueeze",
            signal,
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
        )
    elif signal_rank != 2:
        # Covers an unknown rank (None), rank 0 and rank > 2: only
        # [signal] and [batch, signal] layouts are handled below.
        raise errors.SymbolicValueError(
            msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. "
            f"Current rank of signal is {signal_rank}, please reduce it.",
            value=input,
        )

    # Get window and make sure it's the same size as `win_length` or `n_fft`
    n_win = symbolic_helper._get_tensor_dim_size(window, dim=0)
    if n_win is not None:
        win_length_default = win_length if win_length else n_fft
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O` and the message is a plain string.
        # The exception type is kept as AssertionError for existing callers.
        if n_win != win_length_default:
            raise AssertionError(
                "Analysis window size must equal `win_length` or `n_fft`. "
                f"Please, set `win_length` or `n_fft` to match `window` size ({n_win})"
            )

        # Center window around zeros if needed (required by ONNX's STFT)
        if n_win < n_fft:
            left, right = _compute_edge_sizes(n_fft, n_win)
            # Concat requires every input to share one element type, and the
            # Cast to the signal type only happens afterwards, so the padding
            # must use the window's own dtype (FLOAT when it is unknown).
            pad_dtype = _type_utils.JitScalarType.from_value(
                window, _type_utils.JitScalarType.FLOAT
            ).dtype()
            left_win = g.op("Constant", value_t=torch.zeros(left, dtype=pad_dtype))
            right_win = g.op(
                "Constant", value_t=torch.zeros(right, dtype=pad_dtype)
            )
            window = g.op("Concat", left_win, window, right_win, axis_i=0)

    # Create window, if needed
    if symbolic_helper._is_none(window):
        if win_length:
            if win_length > n_fft:
                raise errors.SymbolicValueError(
                    msg="The analysis window can't be longer than the size of the FFT. "
                    f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.",
                    value=input,
                )

            # Center window, if needed
            left, right = _compute_edge_sizes(n_fft, win_length)
            torch_window = torch.hstack(
                (torch.zeros(left), torch.ones(win_length), torch.zeros(right))
            )
        else:
            # Rectangle window
            torch_window = torch.ones(n_fft)
        # Internal invariant: the generated window always spans the full FFT.
        assert torch_window.shape[0] == n_fft
        window = g.op("Constant", value_t=torch_window)
    # The window is cast to the signal's element type before it reaches STFT.
    window = g.op(
        "Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type()
    )

    # Run STFT
    result = g.op(
        "STFT",
        signal,
        frame_step_const,
        window,
        frame_length_const,
        onesided_i=1 if onesided is None or onesided else 0,
    )

    # Transpose to mimic torch.stft's behavior
    result = g.op("Transpose", result, perm_i=[0, 2, 1, 3])

    # Remove batch dimension, if needed
    if signal_rank == 1:
        result = g.op(
            "Squeeze",
            result,
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
        )

    # Normalize, if needed
    if normalized:
        sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype()))
        result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft))

    return result
232