# mypy: allow-untyped-defs
"""
Note [ONNX operators that are added/updated from opset 8 to opset 9]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
    Compress
    ConstantOfShape
    EyeLike
    MaxUnpool
    OneHot
    Sinh
    Cosh
    Asinh
    Acosh
    Atanh
    Shrink
    IsNaN
    Sign
    Erf
    Scatter
    Where
    NonZero
    TfIdfVectorizer
    MeanVarianceNormalization

Updated operators:
    BatchNormalization: removed spatial attribute.
    Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types (integers) supported.
    Cast: more data types (strings) supported.
    Upsample: moved scales from attribute to input.
    Scan
"""

import functools
import warnings

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration


_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)

block_listed_operators = (
    "nonzero",
    "where",
    "scatter",
    "scatter_add",
    "erf",
    "sign",
    "isnan",
    "gather",
    "arange",
    "masked_fill",
    "index_fill",
    "index_copy",
    "repeat_interleave",
    "any",
    "all",
)

for block_listed_op in block_listed_operators:
    _onnx_symbolic(f"aten::{block_listed_op}")(
        symbolic_helper._block_list_in_opset(block_listed_op)
    )


@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
def _interpolate(name, dim, interpolate_mode):
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        align_corners = symbolic_helper._maybe_get_scalar(align_corners)
        if align_corners:
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        output_size = symbolic_helper._maybe_get_const(output_size, "is")
        if symbolic_helper._is_value(output_size):
            return symbolic_helper._unimplemented(
                name, "torch._C.Value (output_size) indexing"
            )
        if scales is None:
            # Scales are 1.0 for the batch and channel dims; spatial dims use
            # output_size / input_size.
            scales = [
                1.0
                if i < 2
                else float(output_size[-(dim - i)])
                / float(input.type().sizes()[-(dim - i)])
                for i in range(0, dim)
            ]
        return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)

    return symbolic_fn
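
# Worked example added for illustration (hypothetical shapes): for a 4-D input
# of shape (1, 3, 16, 16) upsampled to output_size (32, 32) with no explicit
# scales, the comprehension in _interpolate above yields
#
#     scales == [1.0, 1.0, 2.0, 2.0]
#
# i.e. batch and channel are never scaled, and the emitted node is roughly
# g.op("Upsample", input, mode_s="nearest", scales_f=[1.0, 1.0, 2.0, 2.0]).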


@_onnx_symbolic("aten::__interpolate")
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    align_corners = symbolic_helper._maybe_get_const(align_corners, "b")
    if not symbolic_helper._is_none(align_corners) and align_corners:
        return symbolic_helper._unimplemented("interpolate", "align_corners == True")

    if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value(
        scale_factor
    ):
        return symbolic_helper._unimplemented(
            "interpolate", "dynamic scales in opset 8"
        )

    if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size):
        return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8")

    scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Upsample", input, mode_s=mode, scales_f=scales)


# NOTE: We should create a wrapper for this kind of operation, after resolving the
# shape/type propagation issue for "cast" operators. Some symbolic functions depend
# on shape information of the input tensor, which is lost after casting.
def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
    floating_scalar_types = {
        _type_utils.JitScalarType.HALF,
        _type_utils.JitScalarType.FLOAT,
        _type_utils.JitScalarType.DOUBLE,
    }
    old_type = None
    # Cast the input tensors to Float if the scalarType of the first argument is
    # known and is not a floating-point type. If casting is performed, return the
    # old scalarType; otherwise return None.
    arg0_type = _type_utils.JitScalarType.from_value(
        args[0], _type_utils.JitScalarType.UNDEFINED
    )
    if arg0_type != _type_utils.JitScalarType.UNDEFINED:
        old_type = arg0_type
        if old_type not in floating_scalar_types:
            old_type = old_type.scalar_name()  # type: ignore[assignment]
            args = tuple(
                g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT)
                for arg in args
            )
        else:
            return (None,) + args
    else:
        warnings.warn(
            "Only floating-point datatypes are supported for these operators: "
            "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
            "the ONNX model to be incorrect if inputs have integer datatypes."
        )
    return (old_type,) + args


def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
    if to_type is None:
        return input
    return getattr(opset9, f"_cast_{to_type}")(g, input, False)


def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
    other = symbolic_helper._maybe_get_scalar(other)
    other = symbolic_helper._if_scalar_type_as(other, input)
    _, input, other = _try_cast_integer_to_float(g, input, other)
    return g.op(op_name, input, other)


# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
# integer input types are not supported in opset 8. Cast to float where possible.
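#
# Sketch added for illustration of the cast-then-restore pattern these symbolics
# share ("SomeFloatOnlyOp" is a placeholder, not a real ONNX operator):
#
#     old_type, self, other = _try_cast_integer_to_float(g, self, other)
#     result = g.op("SomeFloatOnlyOp", self, other)
#     return _cast_to_type(g, result, old_type)  # cast back to the original dtype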


@_onnx_symbolic("aten::gt")
def gt(g: jit_utils.GraphContext, input, other):
    return _comparison_operator(g, input, other, "Greater")


@_onnx_symbolic("aten::lt")
def lt(g: jit_utils.GraphContext, input, other):
    return _comparison_operator(g, input, other, "Less")


@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, other = _try_cast_integer_to_float(g, self, other)
        return _cast_to_type(g, g.op("MatMul", self, other), old_type)
    else:
        return g.op("MatMul", self, other)


@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
    return bmm(g, self, other)


@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    if self_rank is not None and self_rank > 2:
        weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
    elif self_rank == 0 and weight_sizes == [1]:
        # self and weight are both scalar but weight has rank == 1, squeeze weight.
        weight = symbolic_helper._squeeze_helper(g, weight, [0])
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
        return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
    else:
        return g.op("PRelu", self, weight)


@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    # Create a dummy C tensor. It is only needed for API purposes; its value is
    # irrelevant since beta = 0.
    scalar_type = symbolic_helper._try_get_scalar_type(self, other)
    if scalar_type is None:
        raise errors.SymbolicValueError(
            "mm can only operate on tensors with known types", self
        )
    zero_constant = g.op(
        "Constant",
        value_t=torch.tensor([0], dtype=scalar_type.dtype()),
    )

    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, other, zero_constant = _try_cast_integer_to_float(
            g, self, other, zero_constant
        )
        return _cast_to_type(
            g,
            g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0),
            old_type,
        )
    return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0)


@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
        return _cast_to_type(
            g,
            g.op(
                "Gemm",
                mat1,
                mat2,
                self,
                beta_f=symbolic_helper._scalar(beta),
                alpha_f=symbolic_helper._scalar(alpha),
            ),
            old_type,
        )
    else:
        return g.op(
            "Gemm",
            mat1,
            mat2,
            self,
            beta_f=symbolic_helper._scalar(beta),
            alpha_f=symbolic_helper._scalar(alpha),
        )


@_onnx_symbolic("aten::flatten")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim")
    end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim")

    dim = input.type().dim()
    if end_dim_i < 0:
        end_dim_i = dim + end_dim_i
    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim_i == 1 and end_dim_i == dim - 1:
        if symbolic_helper._try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(
                g, g.op("Flatten", input, axis_i=start_dim_i), old_type
            )
        else:
            return g.op("Flatten", input, axis_i=start_dim_i)
    if start_dim_i == 0 and end_dim_i == dim - 2:
        if symbolic_helper._try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(
                g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type
            )
        else:
            return g.op("Flatten", input, axis_i=end_dim_i + 1)

    return opset9.flatten(g, input, start_dim, end_dim)
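
# Example added for illustration (hypothetical shapes): flattening a rank-4
# tensor with start_dim=1 and end_dim=-1 maps to a single ONNX Flatten node
# with axis_i=1, producing a 2-D output of shape (N, C*H*W); any other
# start/end combination falls back to opset9.flatten in the function above.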


def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value):
    if dtype is None:
        scalar_type = _type_utils.JitScalarType.FLOAT
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    if not scalar_type.dtype().is_floating_point:
        result = g.op(
            "ConstantFill",
            sizes,
            dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(),
            input_as_shape_i=1,
            value_f=const_value,
        )
        return g.op("Cast", result, to_i=scalar_type.onnx_type())
    else:
        return g.op(
            "ConstantFill",
            sizes,
            dtype_i=scalar_type.onnx_type(),
            input_as_shape_i=1,
            value_f=const_value,
        )


@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty(
    g: jit_utils.GraphContext,
    sizes,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    return zeros(g, sizes, dtype, layout, device, pin_memory)


@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    return zeros_like(g, input, dtype, layout, device, pin_memory)


@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    # NOTE: no way to set device and layout in ONNX, so we ignore them
    return _constant_fill(g, sizes, dtype, 0)


@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 0)


@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    return _constant_fill(g, sizes, dtype, 1)


@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def ones_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 1)
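
# Sketch added for illustration (hypothetical tensor): torch.zeros(2, 3) is
# lowered through _constant_fill above to roughly
#
#     shape -> ConstantFill(shape, input_as_shape_i=1, value_f=0.0, dtype_i=FLOAT)
#
# (with `shape` coming from Shape(input) for the *_like variants), followed by a
# Cast back to the requested dtype when that dtype is not floating point.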


@_onnx_symbolic("aten::full")
def full(
    g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
    const_value = symbolic_helper._maybe_get_const(value, "t")
    if symbolic_helper._is_value(const_value):
        tmp = zeros(g, sizes, dtype, layout, device)
        return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
    else:
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return _constant_fill(g, sizes, dtype, const_value)


@_onnx_symbolic("aten::full_like")
@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v")
def full_like(
    g: jit_utils.GraphContext,
    input,
    fill_value,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, fill_value)


@_onnx_symbolic("aten::repeat")
def repeat(g: jit_utils.GraphContext, self, repeats):
    if not symbolic_helper._is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
    if symbolic_helper._is_packed_list(repeats):
        repeat_size_len = len(symbolic_helper._unpack_list(repeats))
    else:
        const_repeats = symbolic_helper._maybe_get_const(repeats, "is")
        repeat_size_len = len(const_repeats)
    if self.isCompleteTensor():
        sizes = self.type().sizes()
        diff_dims = repeat_size_len - len(sizes)
        if diff_dims > 0:
            self = opset9.view(
                g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes))
            )
    return g.op("Tile", self, repeats)
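
# Usage sketch (added for illustration; "model" and "example_inputs" are
# placeholders): the opset-8 symbolics registered in this module are selected
# by the TorchScript exporter when exporting with opset_version=8, e.g.
#
#     torch.onnx.export(model, example_inputs, "model.onnx", opset_version=8)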