xref: /aosp_15_r20/external/pytorch/test/onnx/test_models_onnxruntime.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2
3import os
4import unittest
5from collections import OrderedDict
6from typing import List, Mapping, Tuple
7
8import onnx_test_common
9import parameterized
10import PIL
11import pytorch_test_common
12import test_models
13import torchvision
14from pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest
15from torchvision import ops
16from torchvision.models.detection import (
17    faster_rcnn,
18    image_list,
19    keypoint_rcnn,
20    mask_rcnn,
21    roi_heads,
22    rpn,
23    transform,
24)
25
26import torch
27from torch import nn
28from torch.testing._internal import common_utils
29
30
def exportTest(
    self,
    model,
    inputs,
    rtol=1e-2,
    atol=1e-7,
    opset_versions=None,
    acceptable_error_percentage=None,
):
    """Export ``model`` at each requested opset and check it with ONNX Runtime.

    Installed as the ``exportTest`` method of the generated ``TestModels``
    classes in this module, so ``self`` is a test-case instance.

    For every opset the model is checked via
    ``onnx_test_common.run_model_test``. When ``self.is_script_test_enabled``
    is true and the opset is greater than 11, a ``torch.jit.script``-ed copy
    of the model is checked as well.

    Args:
        model: the module under test.
        inputs: forwarded to ``run_model_test`` as ``input_args``.
        rtol: relative tolerance forwarded to ``run_model_test``.
        atol: absolute tolerance forwarded to ``run_model_test``.
        opset_versions: opsets to test. A falsy value (``None`` or an empty
            sequence) selects opsets 7 through 14.
        acceptable_error_percentage: forwarded to ``run_model_test``.
    """
    if not opset_versions:
        opset_versions = list(range(7, 15))

    # Identical comparison settings for the traced and the scripted checks.
    check_kwargs = {
        "input_args": inputs,
        "rtol": rtol,
        "atol": atol,
        "acceptable_error_percentage": acceptable_error_percentage,
    }

    for opset in opset_versions:
        # Per-opset configuration lives on the test case itself; presumably
        # run_model_test reads it from there (not visible in this file).
        self.opset_version = opset
        self.onnx_shape_inference = True
        onnx_test_common.run_model_test(self, model, **check_kwargs)

        # The scripted variant is only exercised for opsets newer than 11.
        if self.is_script_test_enabled and opset > 11:
            scripted = torch.jit.script(model)
            onnx_test_common.run_model_test(self, scripted, **check_kwargs)
64
65
# Traced-only variant of the shared model tests: every attribute of
# test_models.TestModels is copied onto an ExportTestCase subclass, the
# scripted check is disabled, and the module-level exportTest is swapped in.
TestModels = type(
    "TestModels",
    (pytorch_test_common.ExportTestCase,),
    {
        **test_models.TestModels.__dict__,
        "is_script_test_enabled": False,
        "is_script": False,
        "exportTest": exportTest,
    },
)
76
77
# Model tests for scripting with the new JIT APIs and ONNX shape inference:
# the same attributes as TestModels, but with the scripted-model check enabled.
TestModels_new_jit_API = type(
    "TestModels_new_jit_API",
    (pytorch_test_common.ExportTestCase,),
    {
        **TestModels.__dict__,
        "exportTest": exportTest,
        "is_script_test_enabled": True,
        "is_script": True,
        "onnx_shape_inference": True,
    },
)
90
91
def _get_image(rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
    """Load an image from the ``assets`` directory next to this file as a tensor.

    Args:
        rel_path: "/"-separated path relative to the assets directory. It is
            split on "/" and re-joined with ``os.path.join`` so it resolves on
            every OS.
        size: target size handed to PIL's ``Image.resize``, which interprets
            it as ``(width, height)``.

    Returns:
        The image converted to RGB, bilinearly resized, and converted with
        ``torchvision.transforms.ToTensor``.
    """
    # ``import PIL`` alone does not load the ``PIL.Image`` submodule; import it
    # explicitly instead of relying on some other module having done so.
    import PIL.Image

    data_dir = os.path.join(os.path.dirname(__file__), "assets")
    path = os.path.join(data_dir, *rel_path.split("/"))
    # Use the opened image as a context manager so the underlying file is
    # closed deterministically. ``convert`` returns an independent image, so
    # ``image`` stays valid after the ``with`` block.
    with PIL.Image.open(path) as src:
        image = src.convert("RGB").resize(size, PIL.Image.BILINEAR)

    return torchvision.transforms.ToTensor()(image)
98
99
def _get_test_images() -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Return two single-image lists loaded from the test assets.

    The first list holds ``grace_hopper_517x606.jpg`` resized with size
    ``(100, 320)``; the second holds ``rgb_pytorch.png`` resized with size
    ``(250, 380)``.
    """
    hopper = _get_image("grace_hopper_517x606.jpg", (100, 320))
    logo = _get_image("rgb_pytorch.png", (250, 380))
    return [hopper], [logo]
105
106
def _get_features(images):
    """Build five random feature maps sized relative to ``images``.

    The result is an ``OrderedDict`` keyed "0" through "4". Level ``i`` has
    spatial size equal to the last two dims of ``images`` floor-divided by
    4, 8, 16, 32 and 64 respectively. Every map uses a fixed leading
    dimension of 2 and 256 channels, regardless of the leading dimensions of
    ``images``. Values are drawn with ``torch.rand``, so they depend on the
    global RNG state.
    """
    rows, cols = images.shape[-2:]
    strides = (4, 8, 16, 32, 64)
    # The generator is consumed in order, so torch.rand is called in the
    # same level order as before.
    return OrderedDict(
        (str(level), torch.rand(2, 256, rows // stride, cols // stride))
        for level, stride in enumerate(strides)
    )
118
119
def _init_test_generalized_rcnn_transform():
    """Create the GeneralizedRCNNTransform used by the RoI-heads test.

    Configured with min_size=100, max_size=200 and per-channel mean/std
    values (these match the usual ImageNet normalization constants).
    """
    return transform.GeneralizedRCNNTransform(
        100,  # min_size
        200,  # max_size
        [0.485, 0.456, 0.406],  # image_mean
        [0.229, 0.224, 0.225],  # image_std
    )
126
127
def _init_test_rpn():
    """Create the RegionProposalNetwork used by the RoI-heads test.

    Five anchor sizes (32, 64, 128, 256, 512), each with aspect ratios
    0.5/1.0/2.0, feed a 256-channel RPNHead. Thresholds and pre/post-NMS
    top-n values are fixed constants.
    """
    sizes = tuple((s,) for s in (32, 64, 128, 256, 512))
    ratios = tuple((0.5, 1.0, 2.0) for _ in sizes)
    anchor_gen = rpn.AnchorGenerator(sizes, ratios)
    head = rpn.RPNHead(256, anchor_gen.num_anchors_per_location()[0])

    pre_nms_top_n = {"training": 2000, "testing": 1000}
    post_nms_top_n = {"training": 2000, "testing": 1000}

    return rpn.RegionProposalNetwork(
        anchor_gen,
        head,
        0.7,  # foreground IoU threshold
        0.3,  # background IoU threshold
        256,  # batch size per image
        0.5,  # positive fraction
        pre_nms_top_n,
        post_nms_top_n,
        0.7,  # NMS threshold
        score_thresh=0.0,
    )
157
158
def _init_test_roi_heads_faster_rcnn():
    """Create the Faster R-CNN RoIHeads used by the RoI-heads test.

    Box features are pooled from feature maps "0".."3" with a
    MultiScaleRoIAlign (output_size=7, sampling_ratio=2). They then pass
    through a TwoMLPHead of width 1024 and a FastRCNNPredictor for 91
    classes.
    """
    out_channels = 256
    num_classes = 91
    representation_size = 1024

    box_fg_iou_thresh = 0.5
    box_bg_iou_thresh = 0.5
    box_batch_size_per_image = 512
    box_positive_fraction = 0.25
    bbox_reg_weights = None
    box_score_thresh = 0.05
    box_nms_thresh = 0.5
    box_detections_per_img = 100

    box_roi_pool = ops.MultiScaleRoIAlign(
        featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
    )

    # The MLP head consumes the flattened pooled features. This assumes a
    # square pooling output, since only output_size[0] is read and squared.
    resolution = box_roi_pool.output_size[0]
    box_head = faster_rcnn.TwoMLPHead(out_channels * resolution**2, representation_size)
    box_predictor = faster_rcnn.FastRCNNPredictor(representation_size, num_classes)

    return roi_heads.RoIHeads(
        box_roi_pool,
        box_head,
        box_predictor,
        box_fg_iou_thresh,
        box_bg_iou_thresh,
        box_batch_size_per_image,
        box_positive_fraction,
        bbox_reg_weights,
        box_score_thresh,
        box_nms_thresh,
        box_detections_per_img,
    )
196
197
@parameterized.parameterized_class(
    ("is_script",),
    [(True,), (False,)],
    class_name_func=onnx_test_common.parameterize_class_name,
)
class TestModelsONNXRuntime(onnx_test_common._TestONNXRuntime):
    """ONNX Runtime export tests for torchvision detection and vision models.

    The class is parameterized over ``is_script`` (True / False). Tests
    decorated with ``skipScriptTest`` are meant to be skipped for the
    scripted variant.
    """

    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()  # Faster RCNN model is not scriptable
    def test_faster_rcnn(self):
        """Export Faster R-CNN with static and dynamic image shapes."""
        # NOTE(review): pretrained_backbone=True presumably downloads backbone
        # weights — confirm the test environment allows that.
        model = faster_rcnn.fasterrcnn_resnet50_fpn(
            pretrained=False, pretrained_backbone=True, min_size=200, max_size=300
        )
        model.eval()
        x1 = torch.randn(3, 200, 300, requires_grad=True)
        x2 = torch.randn(3, 200, 300, requires_grad=True)
        self.run_test(model, ([x1, x2],), rtol=1e-3, atol=1e-5)
        self.run_test(
            model,
            ([x1, x2],),
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
            rtol=1e-3,
            atol=1e-5,
        )
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        images, test_images = _get_test_images()
        self.run_test(
            model,
            (images,),
            additional_test_inputs=[(images,), (test_images,), (dummy_image,)],
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
            rtol=1e-3,
            atol=1e-5,
        )
        self.run_test(
            model,
            (dummy_image,),
            additional_test_inputs=[(dummy_image,), (images,)],
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
            rtol=1e-3,
            atol=1e-5,
        )

    @unittest.skip("Failing after ONNX 1.13.0")
    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_mask_rcnn(self):
        """Export Mask R-CNN with named, dynamically shaped outputs."""
        model = mask_rcnn.maskrcnn_resnet50_fpn(
            pretrained=False, pretrained_backbone=True, min_size=200, max_size=300
        )
        images, test_images = _get_test_images()
        self.run_test(model, (images,), rtol=1e-3, atol=1e-5)
        self.run_test(
            model,
            (images,),
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
            rtol=1e-3,
            atol=1e-5,
        )
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        self.run_test(
            model,
            (images,),
            additional_test_inputs=[(images,), (test_images,), (dummy_image,)],
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
            rtol=1e-3,
            atol=1e-5,
        )
        self.run_test(
            model,
            (dummy_image,),
            additional_test_inputs=[(dummy_image,), (images,)],
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
            rtol=1e-3,
            atol=1e-5,
        )

    @unittest.skip("Failing, see https://github.com/pytorch/pytorch/issues/66528")
    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_keypoint_rcnn(self):
        """Export Keypoint R-CNN with a dynamically shaped image input."""
        model = keypoint_rcnn.keypointrcnn_resnet50_fpn(
            pretrained=False, pretrained_backbone=False, min_size=200, max_size=300
        )
        images, test_images = _get_test_images()
        self.run_test(model, (images,), rtol=1e-3, atol=1e-5)
        self.run_test(
            model,
            (images,),
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
            rtol=1e-3,
            atol=1e-5,
        )
        dummy_images = [torch.ones(3, 100, 100) * 0.3]
        self.run_test(
            model,
            (images,),
            additional_test_inputs=[(images,), (test_images,), (dummy_images,)],
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
            rtol=5e-3,
            atol=1e-5,
        )
        self.run_test(
            model,
            (dummy_images,),
            additional_test_inputs=[(dummy_images,), (test_images,)],
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
            rtol=5e-3,
            atol=1e-5,
        )

    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_roi_heads(self):
        """Export an RPN + RoIHeads + postprocess pipeline on random features."""

        class RoIHeadsModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.transform = _init_test_generalized_rcnn_transform()
                self.rpn = _init_test_rpn()
                self.roi_heads = _init_test_roi_heads_faster_rcnn()

            def forward(self, images, features: Mapping[str, torch.Tensor]):
                # torchvision's ImageList, RoI box clipping and postprocess
                # all expect image sizes as (height, width), i.e. the last
                # two dims of each image in that order.
                original_image_sizes = [
                    (img.shape[-2], img.shape[-1]) for img in images
                ]

                images_m = image_list.ImageList(
                    images, [(i.shape[-2], i.shape[-1]) for i in images]
                )
                proposals, _ = self.rpn(images_m, features)
                detections, _ = self.roi_heads(
                    features, proposals, images_m.image_sizes
                )
                detections = self.transform.postprocess(
                    detections, images_m.image_sizes, original_image_sizes
                )
                return detections

        images = torch.rand(2, 3, 100, 100)
        features = _get_features(images)
        images2 = torch.rand(2, 3, 150, 150)
        test_features = _get_features(images2)

        model = RoIHeadsModule()
        model.eval()
        # Eager sanity run before export.
        model(images, features)

        self.run_test(
            model,
            (images, features),
            input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
            dynamic_axes={
                "input1": [0, 1, 2, 3],
                "input2": [0, 1, 2, 3],
                "input3": [0, 1, 2, 3],
                "input4": [0, 1, 2, 3],
                "input5": [0, 1, 2, 3],
                "input6": [0, 1, 2, 3],
            },
            additional_test_inputs=[(images, features), (images2, test_features)],
        )

    @skipScriptTest()  # TODO: #75625
    @skipIfUnsupportedMinOpsetVersion(20)
    def test_transformer_encoder(self):
        """Export a three-layer nn.TransformerEncoder."""

        class MyModule(torch.nn.Module):
            def __init__(self, ninp, nhead, nhid, dropout, nlayers):
                super().__init__()
                encoder_layers = nn.TransformerEncoderLayer(ninp, nhead, nhid, dropout)
                self.transformer_encoder = nn.TransformerEncoder(
                    encoder_layers, nlayers
                )

            def forward(self, input):
                return self.transformer_encoder(input)

        x = torch.rand(10, 32, 512)
        self.run_test(MyModule(512, 8, 2048, 0.0, 3), (x,), atol=1e-5)

    @skipScriptTest()
    def test_mobilenet_v3(self):
        """Export the (non-quantized) quantization-ready MobileNetV3-Large."""
        # weights=None is the non-deprecated spelling of pretrained=False,
        # matching the shufflenet test below.
        model = torchvision.models.quantization.mobilenet_v3_large(weights=None)
        dummy_input = torch.randn(1, 3, 224, 224)
        self.run_test(model, (dummy_input,))

    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_shufflenet_v2_dynamic_axes(self):
        """Export ShuffleNetV2 with a dynamic batch axis on input and output."""
        model = torchvision.models.shufflenet_v2_x0_5(weights=None)
        dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
        test_inputs = torch.randn(3, 3, 224, 224, requires_grad=True)
        self.run_test(
            model,
            (dummy_input,),
            additional_test_inputs=[(dummy_input,), (test_inputs,)],
            input_names=["input_images"],
            output_names=["outputs"],
            # dynamic_axes keys must match input_names/output_names exactly;
            # previously "output" was used, so the output batch axis was
            # never made dynamic.
            dynamic_axes={
                "input_images": {0: "batch_size"},
                "outputs": {0: "batch_size"},
            },
            rtol=1e-3,
            atol=1e-5,
        )
437
438
if __name__ == "__main__":
    # Run this module's tests through PyTorch's common test runner.
    common_utils.run_tests()
441