1# Owner(s): ["module: onnx"] 2 3import io 4 5import numpy as np 6 7import onnx 8import pytorch_test_common 9from pytorch_test_common import skipIfUnsupportedMinOpsetVersion 10 11import torch 12from torch.onnx import _constants, utils 13from torch.onnx._globals import GLOBALS 14from torch.onnx._internal import jit_utils 15from torch.testing._internal import common_utils 16 17 18def expect_tensor(scalar_type, shape=None): 19 def verify(actual_type): 20 np.testing.assert_equal(actual_type.scalarType(), scalar_type) 21 # if shape is not None: 22 # np.testing.assert_equal(actual_type.sizes(), shape) 23 if shape is not None: 24 np.testing.assert_equal(actual_type.varyingSizes(), shape) 25 26 return verify 27 28 29def as_graphcontext(graph: torch.Graph) -> jit_utils.GraphContext: 30 return jit_utils.GraphContext( 31 graph=graph, 32 block=graph.block(), 33 opset=_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET, 34 original_node=None, # type: ignore[arg-type] 35 params_dict={}, 36 env={}, 37 values_in_env=set(), 38 ) 39 40 41def g_op(graph: torch.Graph, op_name: str, *args, **kwargs): 42 return as_graphcontext(graph).op(op_name, *args, **kwargs) 43 44 45class TestONNXShapeInference(pytorch_test_common.ExportTestCase): 46 def setUp(self): 47 self.opset_version = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET 48 GLOBALS.export_onnx_opset_version = self.opset_version 49 50 def run_test(self, g, n, type_assertion_funcs): 51 if not isinstance(type_assertion_funcs, list): 52 type_assertion_funcs = [type_assertion_funcs] 53 54 torch._C._jit_pass_onnx_graph_shape_type_inference(g, {}, self.opset_version) 55 for out, type_assertion_func in zip(n.outputs(), type_assertion_funcs): 56 type_assertion_func(out.type()) 57 58 def create_empty_graph(self): 59 g = torch._C.Graph() 60 # kick off initialization for ConstantMap. 61 torch._C._jit_pass_onnx_graph_shape_type_inference(g, {}, self.opset_version) 62 return g 63 64 def insert_tensor_constant(self, g, tensor): 65 return g_op(g, "Constant", value_t=tensor) 66 67 def test_cast(self): 68 # Test cast with input of unknown scalar type. 69 g = self.create_empty_graph() 70 input = g.addInput() 71 cast_out = g_op(g, "Cast", input, to_i=1) 72 self.run_test(g, cast_out.node(), expect_tensor("Float")) 73 74 def test_constant_of_shape(self): 75 # Test ConstantOfShape with input of onnx::Shape node. 76 g = self.create_empty_graph() 77 constant = self.insert_tensor_constant(g, torch.ones(1, 2, 3, 4)) 78 shape = g_op(g, "Shape", constant) 79 constant_of_shape = g_op( 80 g, "ConstantOfShape", shape, value_t=torch.tensor([2.0]) 81 ) 82 self.run_test( 83 g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4)) 84 ) 85 86 def test_constant_of_shape_static(self): 87 # Test ConstantOfShape with input of prim::ListConstruct of static tensor 88 rank = 4 89 g = self.create_empty_graph() 90 constants = [ 91 self.insert_tensor_constant(g, torch.tensor(i + 1)) for i in range(rank) 92 ] 93 shape = g_op(g, "prim::ListConstruct", *constants) 94 shape.setType(torch._C.ListType.ofInts()) 95 constant_of_shape = g_op( 96 g, "ConstantOfShape", shape, value_t=torch.tensor([2.0]) 97 ) 98 self.run_test( 99 g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4)) 100 ) 101 102 def test_constant_of_shape_dynamic(self): 103 # Test ConstantOfShape with input of prim::ListConstruct of dynamic tensor 104 rank = 4 105 g = self.create_empty_graph() 106 inputs = [g.addInput() for i in range(rank)] 107 shape = g_op(g, "prim::ListConstruct", *inputs) 108 shape.setType(torch._C.ListType.ofInts()) 109 constant_of_shape = g_op( 110 g, "ConstantOfShape", shape, value_t=torch.tensor([2.0]) 111 ) 112 self.run_test( 113 g, 114 constant_of_shape.node(), 115 expect_tensor("Float", shape=(None, None, None, None)), 116 ) 117 118 def test_gather_dynamic_index(self): 119 g = self.create_empty_graph() 120 input = g.addInput() 121 input.setType( 122 input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16]) 123 ) 124 indices = g.addInput() 125 indices.setType(indices.type().with_dtype(torch.int64).with_sizes([None])) 126 output = g_op(g, "Gather", input, indices, axis_i=1) 127 self.run_test( 128 g, output.node(), expect_tensor("Float", shape=([None, None, 16, 16])) 129 ) 130 131 def test_gather_scalar_index(self): 132 g = self.create_empty_graph() 133 input = g.addInput() 134 input.setType( 135 input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16]) 136 ) 137 indices = self.insert_tensor_constant(g, torch.tensor(1)) 138 output = g_op(g, "Gather", input, indices, axis_i=1) 139 self.run_test(g, output.node(), expect_tensor("Float", shape=([None, 16, 16]))) 140 141 def test_reshape(self): 142 g = self.create_empty_graph() 143 constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 5)) 144 constant_2 = self.insert_tensor_constant(g, torch.tensor([2, 0, -1])) 145 shape = g_op(g, "Reshape", constant, constant_2) 146 self.run_test(g, shape.node(), expect_tensor("Float", shape=(2, 16, 25))) 147 148 g = self.create_empty_graph() 149 constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 4)) 150 constant_2 = self.insert_tensor_constant(g, torch.tensor([-1, 0, 4])) 151 shape = g_op(g, "Reshape", constant, constant_2) 152 self.run_test(g, shape.node(), expect_tensor("Float", shape=(10, 16, 4))) 153 154 g = self.create_empty_graph() 155 constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 4)) 156 constant_2 = self.insert_tensor_constant(g, torch.tensor([-1, 0, 0])) 157 shape = g_op(g, "Reshape", constant, constant_2) 158 self.run_test(g, shape.node(), expect_tensor("Float", shape=(8, 16, 5))) 159 160 def test_reshape_symbolic(self): 161 g = self.create_empty_graph() 162 input = g.addInput() 163 input.setType(input.type().with_sizes([None, None, 2, 8])) 164 constant = self.insert_tensor_constant(g, torch.tensor([0, 0, -1])) 165 output = g_op(g, "Reshape", input, constant) 166 self.run_test(g, output.node(), expect_tensor(None, shape=(None, None, 16))) 167 168 @skipIfUnsupportedMinOpsetVersion(14) 169 def test_reshape_allowzero(self): 170 g = self.create_empty_graph() 171 input = g.addInput() 172 input.setType(input.type().with_sizes([3, 4, 0])) 173 constant = self.insert_tensor_constant(g, torch.tensor([0, 4, 3])) 174 output = g_op(g, "Reshape", input, constant, allowzero_i=1) 175 self.run_test(g, output.node(), expect_tensor(None, shape=(0, 4, 3))) 176 177 def test_slice(self): 178 g = self.create_empty_graph() 179 input = g.addInput() 180 input.setType(input.type().with_sizes([None, None])) 181 start_input = g.addInput() 182 start_input.setType(start_input.type().with_sizes([None])) 183 end = self.insert_tensor_constant(g, torch.tensor([3])) 184 axis = self.insert_tensor_constant(g, torch.tensor([0])) 185 step = self.insert_tensor_constant(g, torch.tensor([1])) 186 slice = g_op(g, "Slice", input, start_input, end, axis, step) 187 self.run_test(g, slice.node(), expect_tensor(None, shape=(None, None))) 188 189 def test_slice_with_dynamic_start_index(self): 190 g = self.create_empty_graph() 191 input = self.insert_tensor_constant(g, torch.ones(2, 3, 4, 5)) 192 start_input = g.addInput() 193 start_input.setType(start_input.type().with_sizes([2])) 194 end = self.insert_tensor_constant(g, torch.tensor([3, 4])) 195 axis = self.insert_tensor_constant(g, torch.tensor([1, -1])) 196 slice = g_op(g, "Slice", input, start_input, end, axis) 197 self.run_test(g, slice.node(), expect_tensor(None, shape=(2, None, 4, None))) 198 199 def test_broadcast_matmul(self): 200 g = self.create_empty_graph() 201 constant = self.insert_tensor_constant(g, torch.ones(5, 1, 2)) 202 constant_2 = self.insert_tensor_constant(g, torch.ones(3, 1, 2, 1)) 203 shape = g_op(g, "MatMul", constant, constant_2) 204 self.run_test(g, shape.node(), expect_tensor("Float", shape=(3, 5, 1, 1))) 205 206 # test when first input is of rank 1 207 g = self.create_empty_graph() 208 constant = self.insert_tensor_constant(g, torch.ones(2)) 209 constant_2 = self.insert_tensor_constant(g, torch.ones(3, 1, 2, 1)) 210 shape = g_op(g, "MatMul", constant, constant_2) 211 self.run_test(g, shape.node(), expect_tensor("Float", shape=(3, 1, 1))) 212 213 # test when second input is of rank 1 214 g = self.create_empty_graph() 215 constant = self.insert_tensor_constant(g, torch.ones(5, 1, 2)) 216 constant_2 = self.insert_tensor_constant(g, torch.ones(2)) 217 shape = g_op(g, "MatMul", constant, constant_2) 218 self.run_test(g, shape.node(), expect_tensor("Float", shape=(5, 1))) 219 220 # test when both inputs are of rank 1 221 g = self.create_empty_graph() 222 constant = self.insert_tensor_constant(g, torch.ones(2)) 223 constant_2 = self.insert_tensor_constant(g, torch.ones(2)) 224 shape = g_op(g, "MatMul", constant, constant_2) 225 self.run_test(g, shape.node(), expect_tensor("Float", shape=())) 226 227 def test_expand(self): 228 g = self.create_empty_graph() 229 input = g.addInput() 230 constant = self.insert_tensor_constant(g, torch.ones(2, 4)) 231 input.setType(constant.type().with_sizes([None, None])) 232 shape = g_op(g, "Shape", input) 233 expand = g_op(g, "Expand", constant, shape) 234 self.run_test(g, expand.node(), expect_tensor("Float", shape=(None, None))) 235 236 def test_pad(self): 237 g = self.create_empty_graph() 238 input = g.addInput() 239 input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100])) 240 constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long)) 241 none = g_op(g, "prim::Constant").setType(torch.NoneType.get()) 242 pad = g_op(g, "Pad", input, constant, none, mode_s="constant") 243 self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, 322, 102))) 244 245 def test_pad_with_dynamic_input_shape(self): 246 g = self.create_empty_graph() 247 input = g.addInput() 248 input.setType(input.type().with_dtype(torch.float).with_sizes([3, None, None])) 249 constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long)) 250 none = g_op(g, "prim::Constant").setType(torch.NoneType.get()) 251 pad = g_op(g, "Pad", input, constant, none, mode_s="constant") 252 self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, None, None))) 253 254 def test_pad_with_dynamic_pad_size(self): 255 g = self.create_empty_graph() 256 input = g.addInput() 257 input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100])) 258 pad_size = g.addInput() 259 pad_size.setType(pad_size.type().with_dtype(torch.long).with_sizes([6])) 260 none = g_op(g, "prim::Constant").setType(torch.NoneType.get()) 261 pad = g_op(g, "Pad", input, pad_size, none, mode_s="constant") 262 self.run_test(g, pad.node(), expect_tensor("Float", shape=(None, None, None))) 263 264 def test_resize(self): 265 g = self.create_empty_graph() 266 input = g.addInput() 267 input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64])) 268 none = g_op(g, "prim::Constant").setType(torch.NoneType.get()) 269 scales = self.insert_tensor_constant( 270 g, torch.tensor([1, 1, 2, 2], dtype=torch.float) 271 ) 272 resize = g_op( 273 g, 274 "Resize", 275 input, 276 none, 277 scales, 278 coordinate_transformation_mode_s="align_corners", 279 cubic_coeff_a_f=-0.75, 280 mode_s="linear", 281 nearest_mode_s="floor", 282 ) 283 self.run_test(g, resize.node(), expect_tensor("Float", shape=(4, 32, 128, 128))) 284 285 def test_resize_after_concat(self): 286 g = self.create_empty_graph() 287 input = g.addInput() 288 input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64])) 289 none = g_op(g, "prim::Constant").setType(torch.NoneType.get()) 290 scale_1 = self.insert_tensor_constant( 291 g, torch.tensor([1, 1], dtype=torch.float) 292 ) 293 scale_2 = self.insert_tensor_constant( 294 g, torch.tensor([2, 2], dtype=torch.float) 295 ) 296 # `scales` values should be statically known due to constant folding in shape inference. 297 scales = g_op(g, "Concat", scale_1, scale_2, axis_i=0) 298 resize = g_op( 299 g, 300 "Resize", 301 input, 302 none, 303 scales, 304 coordinate_transformation_mode_s="align_corners", 305 cubic_coeff_a_f=-0.75, 306 mode_s="linear", 307 nearest_mode_s="floor", 308 ) 309 self.run_test(g, resize.node(), expect_tensor("Float", shape=(4, 32, 128, 128))) 310 311 def test_reduce_prod_with_axes(self): 312 g = self.create_empty_graph() 313 input = g.addInput() 314 input.setType(input.type().with_dtype(torch.long).with_sizes([2])) 315 reduce_prod = g_op(g, "ReduceProd", input, axes_i=[0]) 316 self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,))) 317 318 def test_reduce_prod_without_axes(self): 319 g = self.create_empty_graph() 320 input = g.addInput() 321 input.setType(input.type().with_dtype(torch.long).with_sizes([2])) 322 reduce_prod = g_op(g, "ReduceProd", input) 323 self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,))) 324 325 def test_proceeding_nodes_use_prim_pack_padded_output_dtype_correctly(self): 326 g = self.create_empty_graph() 327 input = g.addInput() 328 input.setType(input.type().with_dtype(torch.float).with_sizes([4, 16])) 329 length = g.addInput() 330 length.setType(length.type().with_dtype(torch.long).with_sizes([4])) 331 padded, batch_size = g_op(g, "prim::PackPadded", input, length, outputs=2) 332 # `prim::PackPadded` only occurs in tracing mode. Hence its outputs inherits 333 # shape and data type from traced graph. 334 padded.setType(padded.type().with_dtype(torch.float).with_sizes([None, None])) 335 batch_size.setType(batch_size.type().with_dtype(torch.long).with_sizes([None])) 336 # `Gather` should use the data type of `batch_size` as the data type of its output. 337 gather_idx = self.insert_tensor_constant(g, torch.tensor([0], dtype=torch.long)) 338 gather = g_op(g, "Gather", batch_size, gather_idx, axis_i=0) 339 self.run_test(g, gather.node(), expect_tensor("Long", shape=(None,))) 340 341 def test_squeeze_after_dynamic_if(self): 342 from torch.onnx.symbolic_opset11 import squeeze as squeeze11 343 344 g = self.create_empty_graph() 345 346 input = g.addInput() 347 input.setType(input.type().with_dtype(torch.float).with_sizes([1, None, 5])) 348 349 # Type is intentionally not bool to test that 350 # the added "Cast" node doesn't stop shape inference. 351 cond = g.addInput() 352 cond.setType(input.type().with_dtype(torch.int32).with_sizes([1])) 353 if_op, (if_context, else_context), new_node = jit_utils.add_op_with_blocks( 354 as_graphcontext(g), "If", cond, n_blocks=2 355 ) 356 block1_output = if_context.op("Add", input, input) 357 block2_output = else_context.op("Identity", input) 358 utils._add_output_to_block(if_context.block, block1_output) 359 utils._add_output_to_block(else_context.block, block2_output) 360 if_output = torch._C._jit_pass_fixup_onnx_controlflow_node( 361 new_node, _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET 362 )[0] 363 torch._C._jit_pass_onnx_node_shape_type_inference( 364 new_node, {}, _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET 365 ) 366 367 # Exporter will add "If" instead of raw "Squeeze" if it does not know 368 # that if the dimension it is squeezing has size 1. 369 squeezed = squeeze11(as_graphcontext(g), if_output, dim=0) 370 assert squeezed.node().kind() == "onnx::Squeeze" 371 self.run_test(g, squeezed.node(), expect_tensor("Float", shape=(None, 5))) 372 373 374class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase): 375 def setUp(self): 376 super().setUp() 377 self.opset_version = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET 378 379 def test_setType_maintains_output_shape_for_single_custom_op(self): 380 self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) 381 382 class CustomInverse(torch.nn.Module): 383 def forward(self, x): 384 return torch.inverse(x) + x 385 386 def linalg_inv_settype(g, self): 387 return g.op("com.microsoft::Inverse", self).setType(self.type()) 388 389 torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) 390 model = CustomInverse() 391 x = torch.randn(2, 3, 3) 392 f = io.BytesIO() 393 torch.onnx.export( 394 model, 395 (x,), 396 f, 397 opset_version=self.opset_version, 398 custom_opsets={"com.microsoft": 1}, 399 ) 400 401 model_proto = onnx.load(io.BytesIO(f.getvalue())) 402 model_value_info = model_proto.graph.value_info 403 self.assertIsNotNone(model_value_info) 404 assert model_value_info 405 dims = model_value_info[0].type.tensor_type.shape.dim 406 for i in range(len(dims)): 407 # If node output has shape info, it should have dim_value 408 # Otherwise, it has dim_params with dynamic shape 409 self.assertTrue(dims[i].HasField("dim_value")) 410 for dim, rank in zip(dims, x.size()): 411 self.assertEqual(dim.dim_value, rank) 412 413 def test_no_setType_for_single_custom_op(self): 414 self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) 415 416 class CustomInverse(torch.nn.Module): 417 def forward(self, x): 418 return torch.inverse(x) + x 419 420 def linalg_inv_no_settype(g, self): 421 return g.op("com.microsoft::Inverse", self) 422 423 torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_no_settype, 9) 424 model = CustomInverse() 425 x = torch.randn(2, 3, 3) 426 f = io.BytesIO() 427 torch.onnx.export( 428 model, 429 (x,), 430 f, 431 opset_version=self.opset_version, 432 custom_opsets={"com.microsoft": 1}, 433 ) 434 435 model_proto = onnx.load(io.BytesIO(f.getvalue())) 436 model_value_info = model_proto.graph.value_info 437 self.assertIsNotNone(model_value_info) 438 assert model_value_info 439 dims = model_value_info[0].type.tensor_type.shape.dim 440 for i in range(len(dims)): 441 # If node output has shape info, it should have dim_value 442 # Otherwise, it has dim_params with dynamic shape 443 self.assertTrue(dims[i].HasField("dim_param")) 444 445 def test_setType_maintains_output_shape_for_single_custom_op_with_dynamic_axes( 446 self, 447 ): 448 self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) 449 450 class CustomInverse(torch.nn.Module): 451 def forward(self, x): 452 return torch.inverse(x) + x 453 454 def linalg_inv_settype(g, self): 455 return g.op("com.microsoft::Inverse", self).setType( 456 self.type().with_dtype(torch.float).with_sizes([None, 3, 3]) 457 ) 458 459 torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) 460 model = CustomInverse() 461 x = torch.randn(2, 3, 3) 462 f = io.BytesIO() 463 torch.onnx.export( 464 model, 465 (x,), 466 f, 467 opset_version=self.opset_version, 468 custom_opsets={"com.microsoft": 1}, 469 input_names=["x"], 470 dynamic_axes={"x": {0: "batch"}}, 471 ) 472 473 model_proto = onnx.load(io.BytesIO(f.getvalue())) 474 model_value_info = model_proto.graph.value_info 475 self.assertIsNotNone(model_value_info) 476 assert model_value_info 477 dims = model_value_info[0].type.tensor_type.shape.dim 478 # The first axe should be dynamic as we defined when exporting 479 self.assertTrue(dims[0].HasField("dim_param")) 480 for i in range(1, len(dims)): 481 # If node output has shape info, it should have dim_value 482 # Otherwise, it has dim_params with dynamic shape 483 self.assertTrue(dims[i].HasField("dim_value")) 484 self.assertEqual(dims[i].dim_value, x.size()[i]) 485 486 def test_setType_maintains_output_shape_for_single_custom_op_with_onnx_ops(self): 487 self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) 488 489 class CustomInverse(torch.nn.Module): 490 def forward(self, x, y, z): 491 x = torch.inverse(x) 492 return x + y + z 493 494 def linalg_inv_settype(g, self): 495 return g.op("com.microsoft::Inverse", self).setType( 496 self.type().with_dtype(torch.float).with_sizes([2, 3, 10, 10]) 497 ) 498 499 torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) 500 model = CustomInverse() 501 x = torch.randn(2, 3, 10, 10) 502 y = torch.randn(2, 3, 10, 10) 503 z = torch.randn(2, 3, 10, 10) 504 f = io.BytesIO() 505 torch.onnx.export( 506 model, 507 (x, y, z), 508 f, 509 opset_version=self.opset_version, 510 custom_opsets={"com.microsoft": 1}, 511 ) 512 513 model_proto = onnx.load(io.BytesIO(f.getvalue())) 514 # To validate the shape of inverse Op, we need to find inverse output name, 515 # and then use it to identify its value_info for the shape. 516 output_name = "" 517 for node in model_proto.graph.node: 518 if node.op_type == "Inverse": 519 output_name = node.output[0] 520 break 521 assert output_name 522 model_value_info = model_proto.graph.value_info 523 self.assertIsNotNone(model_value_info) 524 assert model_value_info 525 for value_info in model_value_info: 526 assert value_info.name 527 if value_info.name == output_name: 528 dims = value_info.type.tensor_type.shape.dim 529 for i in range(len(dims)): 530 # If node output has shape info, it should have dim_value 531 # Otherwise, it has dim_params with dynamic shape 532 self.assertTrue(dims[i].HasField("dim_value")) 533 for dim, rank in zip(dims, x.size()): 534 self.assertEqual(dim.dim_value, rank) 535 536 537if __name__ == "__main__": 538 common_utils.run_tests() 539