1# mypy: allow-untyped-defs 2"""This file exports ONNX ops for opset 15. 3 4Note [ONNX operators that are added/updated in opset 15] 5~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 6https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set 7New operators: 8 Bernoulli 9 CastLike 10 Optional 11 OptionalGetElement 12 OptionalHasElement 13 14Updated operators: 15 BatchNormalization https://github.com/onnx/onnx/pull/3545 16 Backwards compatible 17 TODO: test coverage for mixed types inputs. 18 Pow https://github.com/onnx/onnx/pull/3412 19 Backwards compatible 20 TODO: bfloat16 support. 21 Shape https://github.com/onnx/onnx/pull/3580 22 Backwards compatible 23 TODO: optional start/end attribute. 24""" 25 26# EDITING THIS FILE? READ THIS FIRST! 27# see Note [Edit Symbolic Files] in README.md 28 29import functools 30 31import torch 32from torch import _C 33from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 34from torch.onnx._internal import jit_utils, registration 35 36 37_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15) 38 39 40@_onnx_symbolic("aten::__is_") 41def aten__is_(g: jit_utils.GraphContext, self, other): 42 if symbolic_helper._is_none(other): 43 if isinstance(self.type(), _C.OptionalType): 44 none = g.op("OptionalHasElement", self) 45 return g.op("Not", none) 46 else: 47 return g.op("Constant", value_t=torch.BoolTensor([0])) 48 return opset9.eq(g, self, other) 49 50 51@_onnx_symbolic("aten::__isnot_") 52@opset9.wrap_logical_op_with_negation # type: ignore[has-type] 53def aten__isnot_(g: jit_utils.GraphContext, self, other): 54 return aten__is_(g, self, other) 55 56 57@_onnx_symbolic("aten::bernoulli") 58def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): 59 if out is not None and not symbolic_helper._is_none(out): 60 symbolic_helper._unimplemented( 61 "Bernoulli", "out parameter is not supported for bernoulli", input 62 ) 63 if generator is not None and not symbolic_helper._is_none(generator): 64 symbolic_helper._unimplemented( 65 "Bernoulli", "generator is not supported for bernoulli", input 66 ) 67 if p is None or symbolic_helper._is_none(p): 68 return g.op("Bernoulli", input) 69 return opset9.bernoulli(g, input, p, generator, out) 70 71 72@_onnx_symbolic("prim::unchecked_cast") 73def prim_unchecked_cast(g: jit_utils.GraphContext, self): 74 # exists to refine the type of the Value 75 # if x is Optional[Tensor], unchecked_cast will cast 76 # x to Tensor, so the rest of the graph knows that x is a Tensor. 77 if isinstance(self.type(), _C.OptionalType): 78 return g.op("OptionalGetElement", self) 79 80 return self 81