# mypy: allow-untyped-defs
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 13
import functools

import torch
import torch._C._onnx as _C_onnx
from torch.onnx import (
    _constants,
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset11 as opset11,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import jit_utils, registration


_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)


@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    softmax = g.op("Softmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        softmax = g.op(
            "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )

    return softmax


@_onnx_symbolic("aten::log_softmax")
@symbolic_helper.parse_args("v", "i", "none")
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    return_op = g.op("LogSoftmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return_op = g.op(
            "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    return return_op


@_onnx_symbolic("aten::frobenius_norm")
@symbolic_helper.parse_args("v", "v", "i")
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
    dim_val = symbolic_helper._maybe_get_const(dim, "is")
    if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0:
        return g.op("ReduceL2", self, keepdims_i=0)
    sqr = g.op("Mul", self, self)
    sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim)
    return g.op("Sqrt", sumsqr)


@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return split_out
        # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
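        # E.g. with split_size_or_sizes = [2, 3, 5] and _outputs = 3, the three
        # outputs are emitted as Slice(self, 0, 2), Slice(self, 2, 5) and
        # Slice(self, 5, 10) along `dim`; otherwise each output is read back
        # from the SplitToSequence result with SequenceAt.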
        if (
            symbolic_helper._is_packed_list(split_size_or_sizes)
            and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
        ):
            split_sizes = [
                symbolic_helper._unsqueeze_helper(g, v, [0])
                for v in symbolic_helper._unpack_list(split_size_or_sizes)
            ]

            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            res = []
            for i in range(_outputs):
                end = g.op(
                    "Add", start, split_sizes[i]
                )  # split_sizes is a list of the same length as _outputs
                res.append(g.op("Slice", self, start, end, axis))
                start = end
            return res
        return [
            g.op(
                "SequenceAt",
                split_out,
                g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
            )
            for i in range(_outputs)
        ]

    split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value")
    if split_val.dim() > 0:
        return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)
    split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")

    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        if _outputs is not None:
            size = split_size * _outputs
        else:
            raise errors.SymbolicValueError(
                "Unknown dimension size not supported", self
            )
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    splits = g.op("Constant", value_t=torch.tensor(splits))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::split_with_sizes")
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    return split(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::unsafe_split")
def unsafe_split(
    g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
):
    return split(g, self, split_size_or_sizes, dim, _outputs)


@_onnx_symbolic("aten::unsafe_split_with_sizes")
def unsafe_split_with_sizes(
    g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
):
    return split_with_sizes(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::tensor_split")
@symbolic_helper.parse_args("v", "v", "i", "i")
def tensor_split(
    g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None
):
    axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
    axis = opset11.unsqueeze(g, axis, 0)
    const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long))

    if symbolic_helper._is_split_static(indices_or_sections, _outputs):
        split_val = symbolic_helper._node_get(indices_or_sections.node(), "value")

        if split_val.dim() > 0:
            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            res = []
            assert _outputs is not None
            for i in range(_outputs - 1):
                end = g.op(
                    "Gather",
                    indices_or_sections,
                    g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
                    axis_i=0,
                )
                res.append(g.op("Slice", self, start, end, axis))
                start = end

            end = symbolic_helper._size_helper(g, self, axis)
            res.append(g.op("Slice", self, start, end, axis))
            return res

        split_size = symbolic_helper._get_const(
            indices_or_sections, "i", "indices_or_sections"
        )

        size = symbolic_helper._get_tensor_dim_size(self, dim)
        if size is None:
            if _outputs is not None:
                size = split_size * _outputs
            else:
                raise errors.SymbolicValueError(
                    "Unknown dimension size not supported", self
                )

        min_split_size = size // split_size
        num_splits_one_extra = size % split_size

        splits = num_splits_one_extra * [min_split_size + 1]
        leftover = (split_size - num_splits_one_extra) * [min_split_size]

        splits = g.op(
            "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long)
        )
        return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)

    if (
        symbolic_helper._is_tensor(indices_or_sections)
        and symbolic_helper._get_tensor_rank(indices_or_sections) == 1
    ):
        loop_len = symbolic_helper._size_helper(
            g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0))
        )
        loop_len = opset11.unsqueeze(g, loop_len, 0)
        loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL)

        # To make the first slice in the loop below work,
        # we prepend a zero so that it serves as the initial slice start.
        padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
        indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0)

        final_splits = g.op("SequenceEmpty")
        # Loop inputs
        loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
            g, "Loop", loop_len, loop_condition, final_splits, outputs=1, n_blocks=1
        )

        loop_block = loop_context.block
        block_input_iter = utils._add_input_to_block(loop_block)
        cond = utils._add_input_to_block(loop_block)
        final_splits = utils._add_input_to_block(loop_block)

        start = loop_context.op(
            "Gather", indices_or_sections, block_input_iter, axis_i=0
        )
        end = loop_context.op(
            "Gather",
            indices_or_sections,
            loop_context.op("Add", block_input_iter, const_1),
            axis_i=0,
        )

        slice = loop_context.op("Slice", self, start, end, axis)
        final_splits = loop_context.op("SequenceInsert", final_splits, slice)

        # Loop outputs
        cond_out = loop_context.op("Identity", loop_condition)
        utils._add_output_to_block(loop_block, cond_out)
        utils._add_output_to_block(loop_block, final_splits)

        loop_out = loop.node().output()
        start = g.op(
            "Gather",
            indices_or_sections,
            g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)),
            axis_i=0,
        )
        start = opset11.unsqueeze(g, start, 0)
        end = symbolic_helper._size_helper(g, self, axis)

        last_slice = g.op("Slice", self, start, end, axis)

        return g.op("SequenceInsert", loop_out, last_slice)

    else:  # scalar tensor
        dim_size = symbolic_helper._size_helper(g, self, axis)
        min_split_size = g.op("Div", dim_size, indices_or_sections)
        min_split_size_plus_1 = g.op(
            "Add",
            min_split_size,
            const_1,
        )
        num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections)
        splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra)
        leftover = g.op(
            "Tile",
            min_split_size,
            g.op(
                "Sub",
                opset11.unsqueeze(g, indices_or_sections, 0),
                num_splits_one_extra,
            ),
        )

        splits = g.op("Concat", splits, leftover, axis_i=0)
        if _outputs is None:
            return g.op("SplitToSequence", self, splits, axis_i=dim)
        return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    if _outputs is None:
        return g.op(
            "SplitToSequence",
            self,
            g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
            axis_i=dim,
            keepdims_i=0,
        )

    splits = g.op("Constant", value_t=torch.tensor([1] * _outputs))
    outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
    outputs = [outputs] if _outputs == 1 else outputs
    squeezed_outputs = [
        g.op("Squeeze", out, g.op("Constant", value_t=torch.tensor([dim])))
        for out in outputs
    ]
    return squeezed_outputs


@_onnx_symbolic("aten::nonzero_numpy")
# Emitted from `torch.nonzero(x, as_tuple=True)`
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
    return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs)


@_onnx_symbolic("aten::where")
@symbolic_helper.parse_args("v", "v", "v", "i")
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
    # Assumes that torch.where's first argument takes only Bool and Byte tensors.
    if not symbolic_helper._is_bool(condition):
        condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    if self is None:
        condition = opset9.nonzero(g, condition)
        return symbolic_helper._unbind_helper(
            g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
        )
    return g.op("Where", condition, self, other)


@_onnx_symbolic("aten::fake_quantize_per_channel_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
def fake_quantize_per_channel_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    axis,
    quant_min=-128,
    quant_max=127,
):
    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise errors.SymbolicValueError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    # ONNX defines zero_point to be int8 or uint8
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
    if (quant_min, quant_max) == (0, 127):
        quantized = g.op(
            "Clip",
            quantized,
            opset9.unused(g),
            g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
        )
    return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)


@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise errors.SymbolicValueError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
    if (
        _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
        != _type_utils.JitScalarType.FLOAT
    ):
        scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    if (quant_min, quant_max) == (0, 127):
        quantized = g.op(
            "Clip",
            quantized,
            opset9.unused(g),
            g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
        )
    return g.op("DequantizeLinear", quantized, scale, zero_point)


def _reduce_op_symbolic(onnx_op_name):
    def symbolic(g, self, dim=None, keepdim=None):
        self = symbolic_helper._maybe_cast_reduce_op_input(g, self)
        if dim is None:
            # all-reduce path
            return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
        else:
            keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
            return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)

    return symbolic


@_onnx_symbolic(
    "aten::sum",
    decorate=[symbolic_helper._apply_params("ReduceSum", "sum")],
)
def _reduce_with_dtype(onnx_op, name):
    symbolic = _reduce_op_symbolic(onnx_op)

    @symbolic_helper._overload_by_arg_count
    def reduce(g, *args, **kwargs):
        @symbolic_helper.parse_args("v", "none")
        def reduce_nodim(g, self, dtype):
            dtype_onnx = None
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
                self = g.op("Cast", self, to_i=dtype_onnx)
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)
            result = symbolic(g, self)
            if dtype_onnx is not None:
                result_dtype_onnx = _type_utils.JitScalarType.from_value(
                    result
                ).onnx_type()
                if result_dtype_onnx != dtype_onnx:
                    result = g.op("Cast", result, to_i=dtype_onnx)
            return result

        @symbolic_helper.parse_args("v", "v", "i", "none")
        def reduce_dim(g, self, dim, keepdim, dtype):
            dtype_onnx = None
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type()
                self = g.op("Cast", self, to_i=dtype_onnx)
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)
            result = symbolic(g, self, dim, keepdim)
            if dtype_onnx is not None:
                result_dtype_onnx = _type_utils.JitScalarType.from_value(
                    result
                ).onnx_type()
                if result_dtype_onnx != dtype_onnx:
                    result = g.op("Cast", result, to_i=dtype_onnx)
            return result

        return reduce_nodim, reduce_dim

    return reduce


# Ported from
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/core.py#L6097
# NOTE: Supporting aten::unflatten before opset13 needs a helper function to adjust ONNX op changes in Concat, Slice, ...
@_onnx_symbolic("aten::unflatten")
def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size):
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
        )

    # dim could be negative
    input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64))
    dim = g.op("Add", input_dim, dim)
    dim = g.op("Mod", dim, input_dim)

    input_size = g.op("Shape", input)

    head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
    head_end_idx = g.op(
        "Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
    )
    head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx)

    dim_plus_one = g.op(
        "Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
    )
    tail_start_idx = g.op(
        "Reshape",
        dim_plus_one,
        g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)),
    )
    tail_end_idx = g.op(
        "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
    )
    tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx)

    final_shape = g.op(
        "Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0
    )

    return symbolic_helper._reshape_helper(g, input, final_shape)


@_onnx_symbolic("aten::unsafe_chunk")
@symbolic_helper.parse_args("v", "i", "i", "i")
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
    if _outputs is None:
        return g.op(
            "SplitToSequence",
            self,
            g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
            axis_i=dim,
            keepdims_i=0,
        )

    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size")
    split_size = (size + chunks - 1) // chunks
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)

    # TODO: So far we don't have a module using this method. We'll keep
    # this as a constant unless we see a request for dynamic splits in any
    # user's modules.
    splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::tile")
def tile(g: jit_utils.GraphContext, self, dims):
    self_shape = g.op("Shape", self)
    self_rank = g.op("Size", self_shape)
    dims_rank = g.op("Size", dims)
    diff = g.op("Sub", self_rank, dims_rank)
    const_zero = g.op("Constant", value_t=torch.tensor([0]))

    # 1. If dims is shorter than self.shape, pad dims with 1
    dims_shorter_than_self_shape = g.op("Greater", diff, const_zero)
    (
        if_op_greater,
        (if_context_greater, else_context_greater),
        _,
    ) = jit_utils.add_op_with_blocks(
        g, "If", dims_shorter_than_self_shape, n_blocks=2, outputs=1
    )
    const_one = if_context_greater.op("Constant", value_t=torch.LongTensor([1]))
    diff_1d_greater = if_context_greater.op("Reshape", diff, const_one)
    expand_ones_greater = if_context_greater.op("Expand", const_one, diff_1d_greater)
    dims_ = if_context_greater.op("Concat", expand_ones_greater, dims, axis_i=0)
    utils._add_output_to_block(if_context_greater.block, dims_)
    identity_dim = else_context_greater.op("Identity", dims)
    utils._add_output_to_block(else_context_greater.block, identity_dim)
    dims_final = if_op_greater.node().output()

    # 2. If dims is longer than self.shape, pad self.shape with 1
    dims_longer_than_self_shape = g.op("Less", diff, const_zero)
    (
        if_op_less,
        (if_context_less, else_context_less),
        _,
    ) = jit_utils.add_op_with_blocks(
        g, "If", dims_longer_than_self_shape, n_blocks=2, outputs=1
    )
    const_one = if_context_less.op("Constant", value_t=torch.LongTensor([1]))
    diff_1d_less = if_context_less.op(
        "Reshape",
        if_context_less.op("Abs", diff),
        const_one,
    )
    expand_ones_less = if_context_less.op("Expand", const_one, diff_1d_less)
    self_final_shape = if_context_less.op(
        "Concat", expand_ones_less, self_shape, axis_i=0
    )
    self_ = if_context_less.op("Reshape", self, self_final_shape)
    utils._add_output_to_block(if_context_less.block, self_)
    identity_self = else_context_less.op("Identity", self)
    utils._add_output_to_block(else_context_less.block, identity_self)
    self_final = if_op_less.node().output()

    dims_final = g.op("Cast", dims_final, to_i=_C_onnx.TensorProtoDataType.INT64)
    return g.op("Tile", self_final, dims_final)


@_onnx_symbolic("aten::repeat_interleave")
def repeat_interleave(
    g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
):
    repeats_dim = symbolic_helper._get_tensor_rank(repeats)
    repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
    input_sizes = symbolic_helper._get_tensor_sizes(self)
    if repeats_dim is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.",
            self,
        )
    if repeats_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats size.",
            self,
        )
    if input_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown input size.",
            self,
        )

    final_dim = dim
    # If dim is None, flatten:
    # by default, use the flattened input array and return a flat output array.
    if symbolic_helper._is_none(dim):
        self = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1]))
        )
        dim = torch.tensor(0, dtype=torch.int64)
    else:
        dim = symbolic_helper._maybe_get_scalar(dim)

    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    output_sizes = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1

    # Check if all indices should be repeated the same number of times.
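    # E.g. repeats = 2 (a 0-d tensor) or repeats = [2] repeats every slice along
    # `dim` twice, which the single-value helper below handles without a Loop.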
    if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
        return symbolic_helper._repeat_interleave_single_value_repeat_helper(
            g, self, repeats, dim
        )

    cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None
    # If input size is dynamic or repeats vector is dynamic
    if output_sizes[dim] == 0 or cond_dynamic_repeats:
        reps = symbolic_helper._size_helper(g, self, dim)
        reps = opset11.unsqueeze(g, reps, 0)

        # Check if repeats is dynamic.
        # As repeats is dynamic, we use a Where node as a substitute for the if statement:
        # if repeats_dim == 1, expand repeats, otherwise use the original tensor.
        if cond_dynamic_repeats:
            repeat_dim = symbolic_helper._size_helper(
                g, repeats, g.op("Constant", value_t=torch.LongTensor([0]))
            )
            repeat_cond = g.op(
                "Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1]))
            )
            repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats)
    # There are cases when repeats is a 1-d tensor with multiple repeat values, while dim
    # is provided along one of the dynamic axes. A simple example would be
    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2.
    # Repeat interleaving can then be performed in pytorch when the value of * matches
    # the number of elements in repeats; for example, if * -> 2, the number of repeats
    # should be 2 as well.
    else:
        return opset9.repeat_interleave(g, self, repeats, final_dim)

    reps_like = g.op(
        "ConstantOfShape",
        g.op("Shape", repeats),
        value_t=torch.tensor([1], dtype=torch.long),
    )
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, self, reps_like, dim)

    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor.
    # The loop has the following pattern:
    # input (trip_count, cond)
    # int trip_count = ...;
    # bool cond = ...;
    # for (int i = 0; i < trip_count && cond; ++i) {
    #   cond = ...;
    # }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    loop_len = reps

    # Create an empty sequence to store final expansions
    final_splits = g.op("SequenceEmpty")

    # Loop inputs
    loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
        g, "Loop", loop_len, loop_condition, final_splits, n_blocks=1
    )

    loop_block = loop_context.block
    block_input_iter = utils._add_input_to_block(loop_block)
    cond = utils._add_input_to_block(loop_block)
    final_splits = utils._add_input_to_block(loop_block)

    r_split = loop_context.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_context.op("SequenceAt", i_splits, block_input_iter)

    i_split = opset11.unsqueeze(loop_context, i_split, dim + 1)
    r_concat = [
        loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])),
        r_split,
        loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])),
    ]
    r_concat = loop_context.op("Concat", *r_concat, axis_i=0)
    i_split = opset9.expand(loop_context, i_split, r_concat, None)
    i_split = symbolic_helper._reshape_helper(
        loop_context, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes))
    )
    final_splits = loop_context.op("SequenceInsert", final_splits, i_split)

    # Loop outputs
    cond_out = loop_context.op(
        "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
    )
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, final_splits)

    loop_out = loop.node().output()
    loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
    return loop_out


@_onnx_symbolic("aten::diagonal")
@symbolic_helper.parse_args("v", "i", "i", "i")
def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2):
    rank = symbolic_helper._get_tensor_rank(self)
    # Replace negative indexing when rank is known
    if rank is not None:
        dim1 = dim1 if dim1 >= 0 else dim1 + rank
        dim2 = dim2 if dim2 >= 0 else dim2 + rank

    dim1_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
    )
    dim2_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2]))
    )
    # Create appropriate mask
    mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
    mask = opset9.zeros(g, mask_shape, None, None, None)
    mask = g.op("EyeLike", mask, k_i=offset)

    # dim1 and dim2 are appended as dimensions at the end of the shape
    if rank is not None:
        axes = list(range(rank))
        axes.remove(dim1)
        axes.remove(dim2)
        self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])
    else:
        return symbolic_helper._unimplemented("diagonal", "unknown input rank")

    # Multiply input and mask to calculate values along the diagonal.
    # The mask consists of ones where diagonal values are to be calculated.
    # For example:
    # [[1.1, 1.2, 1.3],   * [[1, 0, 0]   = [[1.1, 0, 0],
    #  [2.1, 2.2, 2.3],      [0, 1, 0]      [0, 2.2, 0],
    #  [3.1, 3.2, 3.3]]      [0, 0, 1]]     [0, 0, 3.3]]
    result = g.op("Mul", self, mask)
    result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)

    # Calculate gather indices based on offset and dims.
    # If offset is greater than zero, set offset to zero, as this aids in
    # calculation of the selection window.
    offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
    if offset >= 0:
        diag_size = g.op(
            "Max",
            g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
        offset = 0
    else:
        diag_size = g.op(
            "Max",
            g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
    diag_size = g.op("Concat", diag_size, axis_i=0)

    # Calculate which diagonal values to select.
    # For example, in cases with offsets:
    # [[0, 1.1, 0]
    #  [0, 0, 2.2]]
    # we need to select the last two columns, so we create a tensor
    # with all columns that are to be selected.
    # So in this example, it is [1, 2].
    select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None)
    select_window = g.op(
        "CumSum",
        select_window_ones_fill,
        g.op("Constant", value_t=torch.LongTensor([0])),
    )
    select_window = g.op(
        "Add",
        select_window,
        g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])),
    )

    gather_shape = [
        opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis])))
        for axis in list(range(rank))[:-2]
    ]
    gather_shape.append(diag_size)
    gather_shape = g.op("Concat", *gather_shape, axis_i=0)
    gather_indices = opset9.zeros(g, gather_shape, 4, None, None)

    # There might be cases where the offset value is greater than the number of rows/columns
    # and causes the diagonal to overrun; as a result, diag_size would be zero.
    # For example, if
    #   offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
    #   diag_size = max(min(2, (4 - 9)), 0) = 0, based on the calculation above.
    # Cases with diagonal overrun always result in diag_size = max(0, negative value) = 0.
    # In cases without diagonal overrun, we select the appropriate rows/columns along which we
    # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has
    # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially
    # returning an empty tensor.
    overrun_cond = g.op(
        "Not",
        g.op(
            "Equal",
            diag_size,
            g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)),
        ),
    )

    if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
        g, "If", overrun_cond, n_blocks=2
    )

    gather_indices_if_block = if_context.op("Add", gather_indices, select_window)
    gather_indices_if_block = symbolic_helper._unsqueeze_helper(
        if_context, gather_indices_if_block, [rank - 1]
    )
    final_non_overrun = if_context.op(
        "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
    )
    final_overrun = opset9.zeros(else_context, gather_shape, 6, None, None)
    utils._add_output_to_block(if_context.block, final_non_overrun)
    utils._add_output_to_block(else_context.block, final_overrun)
    return if_op


# Quantized ops


@_onnx_symbolic("quantized::linear")
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.linear(g, input, weight, bias)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::linear_relu")
def quantized_linear_relu(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.linear(g, input, weight, bias)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv1d_relu")
def quantized_conv1d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d_relu")
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv3d_relu")
def quantized_conv3d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv1d")
def quantized_conv1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d")
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv3d")
def quantized_conv3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose1d")
def quantized_conv_transpose1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv_transpose2d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose2d")
def quantized_conv_transpose2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv_transpose2d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv_transpose3d")
def quantized_conv_transpose3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv_transpose3d(
        g, input, weight, bias, stride, padding, output_padding, groups, dilation
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
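

# All of the quantized::* symbolics above follow the same pattern:
#   1. dequantize_helper recovers float input/weight values together with their
#      scales (and, for per-channel weights, the quantization axis),
#   2. the bias is requantized using the input and weight scales and then
#      dequantized,
#   3. the corresponding opset9 float symbolic (linear / convNd / relu) builds
#      the actual computation, and
#   4. quantize_helper re-quantizes the float output with op_scale and
#      op_zero_point,
# so the exported graph is a DequantizeLinear -> float op -> QuantizeLinear chain.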