1# Owner(s): ["module: onnx"] 2 3import io 4import itertools 5 6import onnx 7 8import pytorch_test_common 9 10import torch 11import torch.onnx 12from torch.nn import Module 13from torch.onnx import producer_name, producer_version 14from torch.onnx._globals import GLOBALS 15from torch.testing._internal import common_utils 16 17 18def check_onnx_opset_operator( 19 model, ops, opset_version=GLOBALS.export_onnx_opset_version 20): 21 # check_onnx_components 22 assert ( 23 model.producer_name == producer_name 24 and model.producer_version == producer_version 25 and model.opset_import[0].version == opset_version 26 ) 27 28 # check the schema with the onnx checker 29 onnx.checker.check_model(model) 30 31 # check target type and attributes 32 graph = model.graph 33 # ops should contain an object for each node 34 # in graph.node, in the right order. 35 # At least the op_name should be specified, 36 # but the op's attributes can optionally be 37 # specified as well 38 assert len(ops) == len(graph.node) 39 for i in range(0, len(ops)): 40 assert graph.node[i].op_type == ops[i]["op_name"] 41 if "attributes" in ops[i]: 42 attributes = ops[i]["attributes"] 43 assert len(attributes) == len(graph.node[i].attribute) 44 for j in range(0, len(attributes)): 45 for attribute_field in attributes[j].keys(): 46 assert attributes[j][attribute_field] == getattr( 47 graph.node[i].attribute[j], attribute_field 48 ) 49 50 51def check_onnx_opsets_operator( 52 module, 53 x, 54 ops, 55 opset_versions, 56 training=torch.onnx.TrainingMode.EVAL, 57 input_names=None, 58 dynamic_axes=None, 59): 60 for opset_version in opset_versions: 61 f = io.BytesIO() 62 torch.onnx.export( 63 module, 64 x, 65 f, 66 opset_version=opset_version, 67 training=training, 68 input_names=input_names, 69 dynamic_axes=dynamic_axes, 70 ) 71 model = onnx.load(io.BytesIO(f.getvalue())) 72 check_onnx_opset_operator(model, ops[opset_version], opset_version) 73 74 75class TestONNXOpset(pytorch_test_common.ExportTestCase): 76 def test_opset_fallback(self): 77 class MyModule(Module): 78 def forward(self, x): 79 return torch.isnan(x) 80 81 ops = [{"op_name": "IsNaN"}] 82 ops = {9: ops, 10: ops} 83 x = torch.tensor([1.0, float("nan"), 2.0]) 84 check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10]) 85 86 def test_topk(self): 87 class MyModule(Module): 88 def forward(self, x): 89 return torch.topk(x, 3) 90 91 ops_9 = [ 92 { 93 "op_name": "TopK", 94 "attributes": [ 95 {"name": "axis", "i": -1, "type": 2}, 96 {"name": "k", "i": 3, "type": 2}, 97 ], 98 } 99 ] 100 ops_10 = [ 101 {"op_name": "Constant"}, 102 {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]}, 103 ] 104 ops = {9: ops_9, 10: ops_10} 105 x = torch.arange(1.0, 6.0, requires_grad=True) 106 check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10]) 107 108 # test with dynamic k 109 class MyModuleDynamic(torch.jit.ScriptModule): 110 @torch.jit.script_method 111 def forward(self, input, k): 112 return torch.topk(input, k) 113 114 ops_10 = [ 115 {"op_name": "Constant", "attributes": [{"name": "value", "type": 4}]}, 116 {"op_name": "Reshape"}, 117 {"op_name": "TopK", "attributes": [{"name": "axis", "i": -1, "type": 2}]}, 118 ] 119 ops = {10: ops_10} 120 x = torch.arange(1.0, 6.0, requires_grad=True) 121 k = torch.tensor(3) 122 module = MyModuleDynamic() 123 check_onnx_opsets_operator(module, (x, k), ops, opset_versions=[10]) 124 125 def test_maxpool(self): 126 module = torch.nn.MaxPool1d(2, stride=1) 127 128 ops_9 = [ 129 { 130 "op_name": "MaxPool", 131 "attributes": [ 132 {"name": "kernel_shape", "ints": [2], "type": 7}, 133 {"name": "pads", "ints": [0, 0], "type": 7}, 134 {"name": "strides", "ints": [1], "type": 7}, 135 ], 136 } 137 ] 138 ops_10 = [ 139 { 140 "op_name": "MaxPool", 141 "attributes": [ 142 {"name": "ceil_mode", "i": 0, "type": 2}, 143 {"name": "dilations", "ints": [1], "type": 7}, 144 {"name": "kernel_shape", "ints": [2], "type": 7}, 145 {"name": "pads", "ints": [0, 0], "type": 7}, 146 {"name": "strides", "ints": [1], "type": 7}, 147 ], 148 } 149 ] 150 ops = {9: ops_9, 10: ops_10} 151 x = torch.randn(20, 16, 50) 152 check_onnx_opsets_operator(module, x, ops, opset_versions=[9, 10]) 153 154 # add test with dilations 155 module = torch.nn.MaxPool1d(2, stride=1, dilation=2) 156 157 ops_10 = [ 158 { 159 "op_name": "MaxPool", 160 "attributes": [ 161 {"name": "ceil_mode", "i": 0, "type": 2}, 162 {"name": "dilations", "ints": [2], "type": 7}, 163 {"name": "kernel_shape", "ints": [2], "type": 7}, 164 {"name": "pads", "ints": [0, 0], "type": 7}, 165 {"name": "strides", "ints": [1], "type": 7}, 166 ], 167 } 168 ] 169 ops = {10: ops_10} 170 x = torch.randn(20, 16, 50) 171 check_onnx_opsets_operator(module, x, ops, opset_versions=[10]) 172 173 def test_upsample(self): 174 class MyModule(Module): 175 def forward(self, x): 176 size = [v * 2 for v in x.size()[2:]] 177 size = [int(i) for i in size] 178 return torch.nn.functional.interpolate(x, size=size, mode="nearest") 179 180 module = MyModule() 181 ops8 = [ 182 { 183 "op_name": "Upsample", 184 "attributes": [ 185 {"name": "mode", "s": (b"nearest"), "type": 3}, 186 {"name": "scales", "floats": [1.0, 1.0, 2.0, 2.0], "type": 6}, 187 ], 188 } 189 ] 190 ops9 = [ 191 {"op_name": "Constant"}, 192 { 193 "op_name": "Upsample", 194 "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}], 195 }, 196 ] 197 ops = {8: ops8, 9: ops9} 198 x = torch.randn(2, 2, 2, 2) 199 check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9]) 200 201 def test_cast_constant(self): 202 class MyModule(Module): 203 def forward(self, x): 204 return x - 1 205 206 module = MyModule() 207 ops_8 = [ 208 {"op_name": "Constant"}, 209 {"op_name": "Cast", "attributes": [{"name": "to", "i": 7, "type": 2}]}, 210 {"op_name": "Sub"}, 211 ] 212 ops_9 = [{"op_name": "Constant"}, {"op_name": "Sub"}] 213 ops = {8: ops_8, 9: ops_9} 214 x = torch.ones(5, 6, dtype=torch.long) 215 check_onnx_opsets_operator(module, x, ops, opset_versions=[8, 9]) 216 217 def test_slice(self): 218 class MyModule(Module): 219 def forward(self, x): 220 return x[0:1] 221 222 ops_9 = [ 223 { 224 "op_name": "Slice", 225 "attributes": [ 226 {"name": "axes", "ints": [0], "type": 7}, 227 {"name": "ends", "ints": [1], "type": 7}, 228 {"name": "starts", "ints": [0], "type": 7}, 229 ], 230 } 231 ] 232 ops_10 = [ 233 {"op_name": "Constant"}, 234 {"op_name": "Constant"}, 235 {"op_name": "Constant"}, 236 {"op_name": "Constant"}, 237 {"op_name": "Slice", "attributes": []}, 238 ] 239 ops = {9: ops_9, 10: ops_10} 240 x = torch.randn(3) 241 check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10]) 242 243 class DynamicSliceModel(torch.jit.ScriptModule): 244 @torch.jit.script_method 245 def forward(self, x): 246 return x[1 : x.size(0)] 247 248 module = DynamicSliceModel() 249 x = torch.rand(1, 2) 250 ops_10 = [ 251 {"op_name": "Shape"}, 252 {"op_name": "Constant"}, 253 {"op_name": "Gather", "attributes": [{"name": "axis", "i": 0, "type": 2}]}, 254 {"op_name": "Constant"}, 255 {"op_name": "Constant"}, 256 { 257 "op_name": "Unsqueeze", 258 "attributes": [{"name": "axes", "i": 0, "type": 7}], 259 }, 260 {"op_name": "Constant"}, 261 {"op_name": "Slice", "attributes": []}, 262 ] 263 ops = {10: ops_10} 264 check_onnx_opsets_operator( 265 module, 266 x, 267 ops, 268 opset_versions=[10], 269 input_names=["x"], 270 dynamic_axes={"x": [0, 1]}, 271 ) 272 273 ops_10 = [ 274 {"op_name": "Constant"}, 275 {"op_name": "Constant"}, 276 {"op_name": "Constant"}, 277 {"op_name": "Constant"}, 278 {"op_name": "Slice", "attributes": []}, 279 ] 280 ops = {10: ops_10} 281 check_onnx_opsets_operator(module, x, ops, opset_versions=[10]) 282 283 def test_flip(self): 284 class MyModule(Module): 285 def forward(self, x): 286 return torch.flip(x, dims=[0]) 287 288 ops_10 = [ 289 {"op_name": "Constant"}, 290 {"op_name": "Constant"}, 291 {"op_name": "Constant"}, 292 {"op_name": "Constant"}, 293 {"op_name": "Slice", "attributes": []}, 294 ] 295 ops = {10: ops_10} 296 import numpy 297 298 x = torch.tensor(numpy.arange(6.0).reshape(2, 3)) 299 check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[10]) 300 301 def test_dropout(self): 302 class MyModule(Module): 303 def __init__(self) -> None: 304 super().__init__() 305 self.dropout = torch.nn.Dropout(0.5) 306 307 def forward(self, x): 308 return self.dropout(x) 309 310 x = torch.randn(1, 2, 3) 311 312 # we should only export the onnx Dropout op in training mode; test both modes 313 314 # test training mode 315 ops = [ 316 { 317 "op_name": "Dropout", 318 "attributes": [{"name": "ratio", "f": 0.5, "type": 1}], 319 } 320 ] 321 ops = {9: ops, 10: ops} 322 check_onnx_opsets_operator( 323 MyModule(), 324 x, 325 ops, 326 opset_versions=[9, 10], 327 training=torch.onnx.TrainingMode.TRAINING, 328 ) 329 330 # test eval mode 331 ops = [{"op_name": "Identity"}] 332 ops = {9: ops, 10: ops} 333 check_onnx_opsets_operator( 334 MyModule(), 335 x, 336 ops, 337 opset_versions=[9, 10], 338 training=torch.onnx.TrainingMode.EVAL, 339 ) 340 341 def test_full(self): 342 class MyModule(Module): 343 def forward(self, x): 344 return torch.full((3, 4), x) 345 346 ops = [ 347 {"op_name": "Constant"}, 348 {"op_name": "ConstantOfShape"}, 349 {"op_name": "Add"}, 350 ] 351 ops = {9: ops, 10: ops} 352 x = torch.tensor(12.0) 353 check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10]) 354 355 def test_interpolate(self): 356 class MyModel(torch.nn.Module): 357 def forward(self, x): 358 size = [v * 2 for v in x.size()[2:]] 359 return torch.nn.functional.interpolate(x, size=size, mode="nearest") 360 361 ops_9 = [ 362 {"op_name": "Shape"}, 363 {"op_name": "Constant"}, 364 {"op_name": "Gather"}, 365 {"op_name": "Shape"}, 366 {"op_name": "Constant"}, 367 {"op_name": "Gather"}, 368 {"op_name": "Constant"}, 369 {"op_name": "Mul"}, 370 {"op_name": "Constant"}, 371 {"op_name": "Mul"}, 372 {"op_name": "Unsqueeze"}, 373 {"op_name": "Unsqueeze"}, 374 {"op_name": "Concat"}, 375 {"op_name": "Cast"}, 376 {"op_name": "Shape"}, 377 {"op_name": "Slice"}, 378 {"op_name": "Cast"}, 379 {"op_name": "Div"}, 380 {"op_name": "Constant"}, 381 {"op_name": "Concat"}, 382 { 383 "op_name": "Upsample", 384 "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}], 385 }, 386 ] 387 ops_10 = [ 388 {"op_name": "Shape"}, 389 {"op_name": "Constant"}, 390 {"op_name": "Gather"}, 391 {"op_name": "Shape"}, 392 {"op_name": "Constant"}, 393 {"op_name": "Gather"}, 394 {"op_name": "Constant"}, 395 {"op_name": "Mul"}, 396 {"op_name": "Constant"}, 397 {"op_name": "Mul"}, 398 {"op_name": "Unsqueeze"}, 399 {"op_name": "Unsqueeze"}, 400 {"op_name": "Concat"}, 401 {"op_name": "Cast"}, 402 {"op_name": "Shape"}, 403 {"op_name": "Constant"}, 404 {"op_name": "Constant"}, 405 {"op_name": "Constant"}, 406 {"op_name": "Slice"}, 407 {"op_name": "Cast"}, 408 {"op_name": "Div"}, 409 {"op_name": "Constant"}, 410 {"op_name": "Concat"}, 411 { 412 "op_name": "Resize", 413 "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}], 414 }, 415 ] 416 417 ops = {9: ops_9, 10: ops_10} 418 x = torch.randn(1, 2, 3, 4, requires_grad=True) 419 check_onnx_opsets_operator( 420 MyModel(), 421 x, 422 ops, 423 opset_versions=[9, 10], 424 input_names=["x"], 425 dynamic_axes={"x": [0, 1, 2, 3]}, 426 ) 427 428 ops_9 = [ 429 {"op_name": "Constant"}, 430 {"op_name": "Shape"}, 431 {"op_name": "Slice"}, 432 {"op_name": "Cast"}, 433 {"op_name": "Div"}, 434 {"op_name": "Constant"}, 435 {"op_name": "Concat"}, 436 { 437 "op_name": "Upsample", 438 "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}], 439 }, 440 ] 441 ops_10 = [ 442 {"op_name": "Constant"}, 443 {"op_name": "Shape"}, 444 {"op_name": "Constant"}, 445 {"op_name": "Constant"}, 446 {"op_name": "Constant"}, 447 {"op_name": "Slice"}, 448 {"op_name": "Cast"}, 449 {"op_name": "Div"}, 450 {"op_name": "Constant"}, 451 {"op_name": "Concat"}, 452 {"op_name": "Resize"}, 453 ] 454 455 ops = {9: ops_9, 10: ops_10} 456 x = torch.randn(1, 2, 3, 4, requires_grad=True) 457 check_onnx_opsets_operator(MyModel(), x, ops, opset_versions=[9, 10]) 458 459 class MyDynamicModel(torch.nn.Module): 460 def forward(self, x): 461 size = [v * 2 for v in x.size()[2:]] 462 # work around for now: turn the dynamic sizes into constant 463 size = [int(i) for i in size] 464 return torch.nn.functional.interpolate(x, size=size, mode="nearest") 465 466 ops_9 = [ 467 {"op_name": "Constant"}, 468 { 469 "op_name": "Upsample", 470 "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}], 471 }, 472 ] 473 ops_10 = [ 474 {"op_name": "Constant"}, 475 { 476 "op_name": "Resize", 477 "attributes": [{"name": "mode", "s": (b"nearest"), "type": 3}], 478 }, 479 ] 480 ops = {9: ops_9, 10: ops_10} 481 x = torch.randn(20, 16, 50) 482 check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10]) 483 484 def test_affine_grid(self): 485 class MyModule(Module): 486 def __init__(self, align_corners): 487 super().__init__() 488 self.align_corners = align_corners 489 490 def forward(self, theta, size): 491 return torch.nn.functional.affine_grid( 492 theta, size, align_corners=self.align_corners 493 ) 494 495 opset_version = 20 496 ops_2d = { 497 opset_version: [ 498 {"op_name": "Constant"}, 499 {"op_name": "Unsqueeze"}, 500 {"op_name": "Constant"}, 501 {"op_name": "Unsqueeze"}, 502 {"op_name": "Constant"}, 503 {"op_name": "Unsqueeze"}, 504 {"op_name": "Constant"}, 505 {"op_name": "Unsqueeze"}, 506 {"op_name": "Concat"}, 507 {"op_name": "AffineGrid"}, 508 ] 509 } 510 511 ops_3d = { 512 opset_version: [ 513 {"op_name": "Constant"}, 514 {"op_name": "Unsqueeze"}, 515 {"op_name": "Constant"}, 516 {"op_name": "Unsqueeze"}, 517 {"op_name": "Constant"}, 518 {"op_name": "Unsqueeze"}, 519 {"op_name": "Constant"}, 520 {"op_name": "Unsqueeze"}, 521 {"op_name": "Constant"}, 522 {"op_name": "Unsqueeze"}, 523 {"op_name": "Concat"}, 524 {"op_name": "AffineGrid"}, 525 ] 526 } 527 # 2D affine 528 theta_2d = torch.empty(1, 2, 3, dtype=torch.double) 529 size_2d = torch.Size([1, 1, 2, 2]) 530 # 3D affine 531 theta_3d = torch.empty(1, 3, 4, dtype=torch.double) 532 size_3d = torch.Size([1, 1, 2, 2, 2]) 533 534 for inputs, align_corners in itertools.product( 535 ((theta_2d, size_2d, ops_2d), (theta_3d, size_3d, ops_3d)), 536 (True, False), 537 ): 538 theta, size, ops = inputs 539 args = ( 540 theta, 541 size, 542 ) 543 check_onnx_opsets_operator( 544 MyModule(align_corners=align_corners), 545 args, 546 ops, 547 opset_versions=[opset_version], 548 training=torch.onnx.TrainingMode.TRAINING, 549 ) 550 check_onnx_opsets_operator( 551 MyModule(align_corners=align_corners), 552 args, 553 ops, 554 opset_versions=[opset_version], 555 training=torch.onnx.TrainingMode.EVAL, 556 ) 557 558 def test_grid_sample(self): 559 class MyModule(torch.nn.Module): 560 def __init__(self, mode, padding_mode, align_corners): 561 super().__init__() 562 self.mode = mode 563 self.padding_mode = padding_mode 564 self.align_corners = align_corners 565 566 def forward(self, x, grid): 567 return torch.nn.functional.grid_sample( 568 x, 569 grid, 570 mode=self.mode, 571 padding_mode=self.padding_mode, 572 align_corners=self.align_corners, 573 ) 574 575 for mode, padding_mode, align_corners, opset_version in itertools.product( 576 ("bilinear", "nearest", "bicubic"), 577 ("zeros", "border", "reflection"), 578 (True, False), 579 (16, 20), 580 ): 581 582 def test_eval_and_training( 583 ops, opset_version, mode, padding_mode, align_corners, x_shape, grid 584 ): 585 args = ( 586 torch.randn(*x_shape), # x 587 torch.randn(grid), # grid, 588 ) 589 check_onnx_opsets_operator( 590 MyModule( 591 mode=mode, 592 padding_mode=padding_mode, 593 align_corners=align_corners, 594 ), 595 args, 596 ops, 597 opset_versions=[opset_version], 598 training=torch.onnx.TrainingMode.TRAINING, 599 ) 600 check_onnx_opsets_operator( 601 MyModule( 602 mode=mode, 603 padding_mode=padding_mode, 604 align_corners=align_corners, 605 ), 606 args, 607 ops, 608 opset_versions=[opset_version], 609 training=torch.onnx.TrainingMode.EVAL, 610 ) 611 612 ops = {opset_version: [{"op_name": "GridSample"}]} 613 # mode = convert_grid_sample_mode(mode) if opset_version == 20 else mode 614 n, c, d_in, h_in, w_in, d_out, h_out, w_out = 1, 1, 2, 3, 2, 3, 2, 4 615 test_eval_and_training( 616 ops, 617 opset_version, 618 mode, 619 padding_mode, 620 align_corners, 621 (n, c, h_in, w_in), 622 (n, h_out, w_out, 2), 623 ) 624 if opset_version == 20 and mode != "bicubic": 625 test_eval_and_training( 626 ops, 627 opset_version, 628 mode, 629 padding_mode, 630 align_corners, 631 (n, c, d_in, h_in, w_in), 632 (n, d_out, h_out, w_out, 3), 633 ) 634 635 def test_flatten(self): 636 class MyModule(Module): 637 def forward(self, x): 638 return torch.flatten(x) 639 640 module = MyModule() 641 642 ops_0d = [{"op_name": "Constant"}, {"op_name": "Reshape"}] 643 ops_1d = [{"op_name": "Identity"}] 644 for shape in ([], [3]): 645 x = torch.randn(shape) 646 for opset_version in [9, 10]: 647 ops = {opset_version: (ops_0d if len(shape) == 0 else ops_1d)} 648 check_onnx_opsets_operator( 649 module, x, ops, opset_versions=[opset_version] 650 ) 651 652 653if __name__ == "__main__": 654 common_utils.run_tests() 655