xref: /aosp_15_r20/external/pytorch/torch/onnx/symbolic_opset12.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# mypy: disable-error-code=arg-type
3from __future__ import annotations
4
5import functools
6import sys
7
8import torch
9from torch._C import _onnx as _C_onnx
10from torch.onnx import (
11    _type_utils,
12    errors,
13    symbolic_helper,
14    symbolic_opset9 as opset9,
15    utils,
16)
17from torch.onnx._internal import jit_utils, registration
18
19
20# EDITING THIS FILE? READ THIS FIRST!
21# see Note [Edit Symbolic Files] in README.md
22
23# This file exports ONNX ops for opset 12
24
# Public symbolic functions defined in this module, in alphabetical order.
__all__ = [
    "argmax", "argmin",
    "binary_cross_entropy_with_logits",
    "celu", "cross_entropy_loss",
    "dropout",
    "einsum",
    "ge",
    "le",
    "native_dropout", "nll_loss", "nll_loss2d", "nll_loss_nd",
    "outer",
    "pow",
    "tensordot",
    "unfold",
]

# Decorator factory: `registration.onnx_symbolic` with the opset pinned to 12,
# used below as `@_onnx_symbolic("aten::<op>")`.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)
46
47
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
    """Emit an ONNX ``Einsum`` node evaluating ``equation`` over ``tensors``.

    ONNX ``Einsum`` does not accept boolean inputs. When the first operand is
    boolean, every operand is cast to INT64, and the ``Einsum`` result is cast
    back to BOOL.

    Raises:
        RuntimeError: if ``tensors`` is empty.
    """
    if not tensors:
        raise RuntimeError("Einsum inputs are empty.")

    if not symbolic_helper._is_bool(tensors[0]):
        return g.op("Einsum", *tensors, equation_s=equation)

    # Only the first operand's dtype is inspected; all operands are cast.
    int64_operands = [
        g.op("Cast", operand, to_i=_C_onnx.TensorProtoDataType.INT64)
        for operand in tensors
    ]
    int64_result = g.op("Einsum", *int64_operands, equation_s=equation)
    return g.op("Cast", int64_result, to_i=_C_onnx.TensorProtoDataType.BOOL)
64
65
@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
    """Export ``aten::einsum``; the ``path`` argument is accepted but ignored."""
    operands = symbolic_helper._unpack_list(tensor_list)
    result = _einsum_helper(g, equation, operands)
    return result
71
72
@_onnx_symbolic("aten::outer")
@symbolic_helper.parse_args("v", "v")
def outer(g: jit_utils.GraphContext, input, other):
    """Export ``aten::outer`` as ``Einsum("i,j->ij")``.

    ``other`` is first cast to ``input``'s scalar type whenever its type
    differs from ``input``'s. This includes the case where ``other``'s type is
    unknown.
    """
    other_type = _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    )
    input_type = _type_utils.JitScalarType.from_value(input)
    if other_type != input_type:
        other = g.op("Cast", other, to_i=input_type.onnx_type())
    return _einsum_helper(g, "i,j->ij", [input, other])
86
87
def _dropout_returns_masked_input_and_mask(
    g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> tuple[torch._C.Value, torch._C.Value | None]:
    """Emit ONNX ``Dropout`` and return ``(output, mask)``.

    The training flag is first checked via ``check_training_mode``. When
    ``train`` is false, dropout is the identity: ``input`` is returned
    unchanged and the mask is ``None``.
    """
    symbolic_helper.check_training_mode(train, "dropout")
    if not train:
        return input, None

    # Dropout-12 takes the ratio and training flag as tensor inputs.
    ratio = g.op("Constant", value_t=torch.tensor(p))
    training_mode = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
    output, mask = g.op("Dropout", input, ratio, training_mode, outputs=2)
    return output, mask
100
101
@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
def dropout(g: jit_utils.GraphContext, input, p, train):
    """Export ``aten::dropout``, returning only the dropped-out tensor (no mask)."""
    return _dropout_returns_masked_input_and_mask(g, input, p, train)[0]
107
108
@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
def native_dropout(g: jit_utils.GraphContext, input, p, train):
    """Export ``aten::native_dropout`` as an ``(output, mask)`` pair.

    The mask is ``None`` in eval mode.
    """
    output, mask = _dropout_returns_masked_input_and_mask(g, input, p, train)
    return output, mask
113
114
@_onnx_symbolic("aten::nll_loss")
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
    """Export ``aten::nll_loss`` as ONNX ``NegativeLogLikelihoodLoss``.

    ``reduction`` is an index into ``["none", "mean", "sum"]`` (0, 1, 2).
    ``weight`` is forwarded as the third input only when it is not a
    ``None`` node.
    """
    reduction_modes = ["none", "mean", "sum"]
    reduction_mode = reduction_modes[symbolic_helper._maybe_get_const(reduction, "i")]

    # The ONNX spec gives ignore_index no default, so it is always set
    # explicitly (e.g. PyTorch's -100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")

    loss_inputs = [self, target]
    if not weight.node().mustBeNone():
        loss_inputs.append(weight)

    return g.op(
        "NegativeLogLikelihoodLoss",
        *loss_inputs,
        reduction_s=reduction_mode,
        ignore_index_i=ignore_index,
    )
146
147
@_onnx_symbolic("aten::nll_loss2d")
def nll_loss2d(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    """Export ``aten::nll_loss2d`` through the generic ``nll_loss`` symbolic."""
    loss = nll_loss(g, self, target, weight, reduction, ignore_index)
    return loss
153
154
@_onnx_symbolic("aten::nll_loss_nd")
def nll_loss_nd(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    """Export ``aten::nll_loss_nd`` through the generic ``nll_loss`` symbolic."""
    loss = nll_loss(g, self, target, weight, reduction, ignore_index)
    return loss
160
161
@_onnx_symbolic("aten::cross_entropy_loss")
def cross_entropy_loss(
    g: jit_utils.GraphContext,
    self,
    target,
    weight,
    reduction,
    ignore_index,
    label_smoothing,
):
    """Export ``aten::cross_entropy_loss`` as ONNX ``SoftmaxCrossEntropyLoss``.

    ``reduction`` is an index into ``["none", "mean", "sum"]`` (0, 1, 2).
    ``weight`` is forwarded as the third input only when it is not a
    ``None`` node.

    Raises:
        errors.SymbolicValueError: if ``label_smoothing`` is a positive constant.
    """
    reduction_modes = ["none", "mean", "sum"]
    reduction_mode = reduction_modes[symbolic_helper._maybe_get_const(reduction, "i")]

    smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
    if smoothing is not None and smoothing > 0.0:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX does not support label_smoothing", self
        )

    # The ONNX spec gives ignore_index no default, so it is always set
    # explicitly (e.g. PyTorch's -100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")

    loss_inputs = [self, target]
    if not weight.node().mustBeNone():
        loss_inputs.append(weight)

    return g.op(
        "SoftmaxCrossEntropyLoss",
        *loss_inputs,
        reduction_s=reduction_mode,
        ignore_index_i=ignore_index,
    )
207
208
@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def binary_cross_entropy_with_logits(
    g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
    """Export ``aten::binary_cross_entropy_with_logits`` from elementary ops.

    Builds ``-(pos_weight * target * log(sigmoid(input))
    + (1 - target) * log(1 - sigmoid(input)))``. The ``pos_weight`` factor is
    applied only if given, and the result is multiplied by ``weight`` if given.
    ``reduction`` 0/1/2 selects none/ReduceMean/ReduceSum; any other value is
    reported via ``symbolic_helper._onnx_unsupported``.
    """
    one = g.op("Constant", value_t=torch.tensor([1]))
    sigmoid_x = opset9.sigmoid(g, input)
    log_sigmoid_x = opset9.log(g, sigmoid_x)
    one_minus_sigmoid_x = opset9.sub(g, one, sigmoid_x)
    one_minus_target = opset9.sub(g, one, target)
    log_one_minus_sigmoid_x = opset9.log(g, one_minus_sigmoid_x)

    positive_term = opset9.mul(g, target, log_sigmoid_x)
    if pos_weight is not None and not symbolic_helper._is_none(pos_weight):
        positive_term = opset9.mul(g, positive_term, pos_weight)
    negative_term = opset9.mul(g, one_minus_target, log_one_minus_sigmoid_x)
    output = opset9.neg(g, opset9.add(g, positive_term, negative_term))

    if weight is not None and not symbolic_helper._is_none(weight):
        output = opset9.mul(g, weight, output)

    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    if reduction == 0:
        return output
    reduce_op = {1: "ReduceMean", 2: "ReduceSum"}.get(reduction)
    if reduce_op is None:
        return symbolic_helper._onnx_unsupported(
            "binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
            input,
        )
    return g.op(reduce_op, output, keepdims_i=0)
252
253
@_onnx_symbolic("aten::celu")
def celu(g: jit_utils.GraphContext, self, alpha):
    """Export ``aten::celu`` as ONNX ``Celu``.

    Double inputs are computed in float and the result is cast back to double.
    """
    alpha = symbolic_helper._maybe_get_const(alpha, "f")
    input_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if input_type != _type_utils.JitScalarType.DOUBLE:
        return g.op("Celu", self, alpha_f=alpha)

    # ONNX Celu is defined for float32, so round-trip double through float.
    as_float = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    celu_float = g.op("Celu", as_float, alpha_f=alpha)
    return g.op("Cast", celu_float, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
267
268
@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    """Export ``aten::argmax`` as ONNX ``ArgMax`` via the shared helper."""
    onnx_op = "ArgMax"
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, onnx_op)
278
279
@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    """Export ``aten::argmin`` as ONNX ``ArgMin`` via the shared helper."""
    onnx_op = "ArgMin"
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, onnx_op)
289
290
@_onnx_symbolic("aten::pow")
def pow(g: jit_utils.GraphContext, self, exponent):
    """Export ``aten::pow`` as a single ONNX ``Pow`` node."""
    power = g.op("Pow", self, exponent)
    return power
294
295
@_onnx_symbolic("aten::ge")
def ge(g: jit_utils.GraphContext, input, other):
    """Export ``aten::ge`` as ONNX ``GreaterOrEqual``."""
    comparison = g.op("GreaterOrEqual", input, other)
    return comparison
299
300
@_onnx_symbolic("aten::le")
def le(g: jit_utils.GraphContext, input, other):
    """Export ``aten::le`` as ONNX ``LessOrEqual``."""
    comparison = g.op("LessOrEqual", input, other)
    return comparison
304
305
@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "v", "v")
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
    """Export ``aten::unfold``: windows of length ``size``, every ``step``, along ``dimension``.

    If both ``size`` and ``step`` fold to constants, the opset 9 implementation
    is reused. Otherwise an ONNX ``Loop`` slices out one window per iteration.
    This path needs the static length of ``dimension`` and the static rank of
    ``input``.

    Args:
        g: Graph context that receives the new nodes.
        input: Tensor to unfold.
        dimension: Axis to unfold along. On the ``Loop`` path a negative axis
            is normalized against the input rank.
        size: Window length (constant or dynamic scalar ``Value``).
        step: Distance between consecutive window starts (constant or dynamic
            scalar ``Value``).

    Returns:
        On the ``Loop`` path, a tensor whose ``dimension`` axis indexes the
        windows and whose new trailing axis holds each window's ``size``
        elements. If the length of ``dimension`` is unknown, the call defers to
        ``symbolic_helper._unimplemented``.
    """
    const_size = symbolic_helper._maybe_get_const(size, "i")
    const_step = symbolic_helper._maybe_get_const(step, "i")
    if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
        const_step
    ):
        return opset9.unfold(g, input, dimension, const_size, const_step)

    sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
    if sizedim is None:
        return symbolic_helper._unimplemented("Unfold", "input size not accessible")

    ndim = symbolic_helper._get_tensor_rank(input)
    assert ndim is not None
    if dimension < 0:
        # The Slice axes, the unsqueeze axis and the post-loop transpose all
        # index axes explicitly, so they need a non-negative axis.
        dimension += ndim

    # Window k spans [low_indices[k], hi_indices[k]) along `dimension`.
    # Starts are range(0, sizedim, step); ends are range(size, sizedim + 1, step).
    low_start = g.op("Constant", value_t=torch.tensor(0))
    low_end = g.op("Constant", value_t=torch.tensor(sizedim))
    hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
    low_indices = g.op("Range", low_start, low_end, step)
    hi_indices = g.op("Range", size, hi_end, step)

    low_size = symbolic_helper._size_helper(
        g, low_indices, g.op("Constant", value_t=torch.tensor(0))
    )
    hi_size = symbolic_helper._size_helper(
        g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
    )

    # Inside the loop, each slice is transposed so `dimension` becomes the last axis.
    perm = list(range(ndim))
    perm.append(perm.pop(dimension))

    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op(
        "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
    )
    # One iteration per complete window.
    loop_len = g.op("Min", low_size, hi_size)

    loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
        g, "Loop", loop_len, loop_condition, n_blocks=1
    )

    loop_block = loop_context.block
    block_input_iter = utils._add_input_to_block(loop_block)
    # An ONNX Loop body's inputs are (iteration_num, condition, ...). The
    # condition input must therefore be declared, even though this body never
    # reads it.
    utils._add_input_to_block(loop_block)

    starts = loop_context.op("Gather", low_indices, block_input_iter)
    ends = loop_context.op("Gather", hi_indices, block_input_iter)
    # Slice along the axis being unfolded.
    axes = loop_context.op("Constant", value_t=torch.tensor([dimension]))
    starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
    ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
    stack = loop_context.op("Slice", input, starts, ends, axes)

    unsqueeze = symbolic_helper._unsqueeze_helper(
        loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
    )
    concat = loop_context.op("Concat", unsqueeze, axis_i=0)

    # Cast has a single data input; the target type is the `to` attribute.
    cond_out = loop_context.op(
        "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
    )
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, concat)

    loop_output = loop.node().output()
    # Each iteration yields a rank-(ndim + 1) tensor with a size-1 axis at
    # `dimension`. The Loop stacks these on a new leading axis, giving rank
    # ndim + 2. Swapping axis 0 with that size-1 axis and then squeezing it
    # leaves the window count at `dimension`.
    perm = list(range(ndim + 2))
    perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
    transpose = g.op("Transpose", loop_output, perm_i=perm)
    return symbolic_helper._squeeze_helper(g, transpose, [0])
380
381
@_onnx_symbolic("aten::tensordot")
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
    """Export ``aten::tensordot`` as a 2-D ``Einsum`` matrix product.

    Both inputs are permuted so the contracted axes are adjacent: trailing for
    ``input_a`` and leading for ``input_b``. Each is flattened to a matrix and
    multiplied with ``Einsum("ij,jk->ik")``. The product is then reshaped to
    ``input_a``'s free sizes followed by ``input_b``'s free sizes.

    Raises:
        errors.SymbolicValueError: if the rank of either input is unknown.
    """
    if out is not None:
        symbolic_helper._unimplemented(
            "Tensordot", "Out parameter is not supported for tensordot."
        )

    rank_a = symbolic_helper._get_tensor_rank(input_a)
    if rank_a is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
            input_a,
        )
    rank_b = symbolic_helper._get_tensor_rank(input_b)
    if rank_b is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
            input_b,
        )

    dims_a = [d + rank_a if d < 0 else d for d in dims_a]
    dims_b = [d + rank_b if d < 0 else d for d in dims_b]
    free_a = [axis for axis in range(rank_a) if axis not in dims_a]
    free_b = [axis for axis in range(rank_b) if axis not in dims_b]

    permuted_a = opset9.permute(g, input_a, free_a + dims_a)
    permuted_b = opset9.permute(g, input_b, dims_b + free_b)

    def _minus_one():
        # A fresh 1-D int64 [-1] constant ("infer this size") for each reshape.
        return g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))

    def _dims_slice(shape, start, end):
        # Slice the 1-D shape tensor `shape` to [start:end].
        return symbolic_helper._slice_helper(
            g, shape, axes=[0], starts=[start], ends=[end]
        )

    # A: collapse the contracted axes into one. The resulting last dim, read
    # back via Shape, gives the size K of the (free, K) matrix.
    shape_a = g.op("Shape", permuted_a)
    free_sizes_a = _dims_slice(shape_a, 0, len(free_a))
    collapsed_a = opset9._reshape_from_tensor(
        g, permuted_a, [free_sizes_a, _minus_one()]
    )
    contracted_size = _dims_slice(g.op("Shape", collapsed_a), -1, sys.maxsize)
    matrix_a = opset9._reshape_from_tensor(
        g, permuted_a, [_minus_one(), contracted_size]
    )

    # B: collapse the free axes into one. The resulting last dim gives the
    # size N of the (K, N) matrix.
    shape_b = g.op("Shape", permuted_b)
    free_sizes_b = _dims_slice(shape_b, len(dims_b), sys.maxsize)
    contracted_sizes_b = _dims_slice(shape_b, 0, len(dims_b))
    collapsed_b = opset9._reshape_from_tensor(
        g, permuted_b, [contracted_sizes_b, _minus_one()]
    )
    free_size_b = _dims_slice(g.op("Shape", collapsed_b), -1, sys.maxsize)
    matrix_b = opset9._reshape_from_tensor(g, permuted_b, [_minus_one(), free_size_b])

    product = einsum(
        g, "ij,jk->ik", g.op("prim::ListConstruct", matrix_a, matrix_b)
    )
    return opset9._reshape_from_tensor(g, product, [free_sizes_a, free_sizes_b])
466