xref: /aosp_15_r20/external/pytorch/test/onnx/test_models.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: onnx"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport unittest
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport pytorch_test_common
6*da0073e9SAndroid Build Coastguard Workerfrom model_defs.dcgan import _netD, _netG, bsz, imgsz, nz, weights_init
7*da0073e9SAndroid Build Coastguard Workerfrom model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2
8*da0073e9SAndroid Build Coastguard Workerfrom model_defs.mnist import MNIST
9*da0073e9SAndroid Build Coastguard Workerfrom model_defs.op_test import ConcatNet, DummyNet, FakeQuantNet, PermuteNet, PReluNet
10*da0073e9SAndroid Build Coastguard Workerfrom model_defs.squeezenet import SqueezeNet
11*da0073e9SAndroid Build Coastguard Workerfrom model_defs.srresnet import SRResNet
12*da0073e9SAndroid Build Coastguard Workerfrom model_defs.super_resolution import SuperResolutionNet
13*da0073e9SAndroid Build Coastguard Workerfrom pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest
14*da0073e9SAndroid Build Coastguard Workerfrom torchvision.models import shufflenet_v2_x1_0
15*da0073e9SAndroid Build Coastguard Workerfrom torchvision.models.alexnet import alexnet
16*da0073e9SAndroid Build Coastguard Workerfrom torchvision.models.densenet import densenet121
17*da0073e9SAndroid Build Coastguard Workerfrom torchvision.models.googlenet import googlenet
18*da0073e9SAndroid Build Coastguard Workerfrom torchvision.models.inception import inception_v3
19*da0073e9SAndroid Build Coastguard Workerfrom torchvision.models.mnasnet import mnasnet1_0
20*da0073e9SAndroid Build Coastguard Workerfrom torchvision.models.mobilenet import mobilenet_v2
21*da0073e9SAndroid Build Coastguard Workerfrom torchvision.models.resnet import resnet50
22*da0073e9SAndroid Build Coastguard Workerfrom torchvision.models.segmentation import deeplabv3_resnet101, fcn_resnet101
23*da0073e9SAndroid Build Coastguard Workerfrom torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
24*da0073e9SAndroid Build Coastguard Workerfrom torchvision.models.video import mc3_18, r2plus1d_18, r3d_18
25*da0073e9SAndroid Build Coastguard Workerfrom verify import verify
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Workerimport torch
28*da0073e9SAndroid Build Coastguard Workerfrom torch.ao import quantization
29*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd import Variable
30*da0073e9SAndroid Build Coastguard Workerfrom torch.onnx import OperatorExportTypes
31*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal import common_utils
32*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import skipIfNoLapack
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker
# Decided once, when this module is imported: should test models/inputs be
# moved onto the CUDA device?
_USE_CUDA = torch.cuda.is_available()


def toC(x):
    """Return ``x.cuda()`` if CUDA was available at import time, else ``x`` unchanged."""
    if not _USE_CUDA:
        return x
    return x.cuda()


BATCH_SIZE = 2
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker
class TestModels(pytorch_test_common.ExportTestCase):
    """Export a set of reference networks to ONNX and check them with Caffe2.

    Each test builds a model plus a sample input and hands both to
    ``exportTest``, which traces the model and then calls ``verify`` with the
    Caffe2 ONNX backend.
    """

    opset_version = 9  # Caffe2 doesn't support the default.
    # NOTE(review): this attribute is not read by exportTest below; verify()
    # is called without a keep_initializers_as_inputs argument.
    keep_initializers_as_inputs = False

    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, **kwargs):
        """Trace ``model`` on ``inputs`` and verify its ONNX export with Caffe2.

        Args:
            model: module to export; it is held in eval mode (via
                ``select_model_mode_for_export``) while tracing and verifying.
            inputs: example input(s) used for tracing and passed to ``verify``.
            rtol: relative tolerance forwarded to ``verify``.
            atol: absolute tolerance forwarded to ``verify``.
            **kwargs: accepted for call-site compatibility but NOT forwarded
                anywhere. For example, ``acceptable_error_percentage`` passed by
                ``test_inception`` currently has no effect.
                NOTE(review): confirm whether ``verify`` should receive these.
        """
        # Imported inside the method, so importing this test module does not
        # itself import the Caffe2 backend.
        import caffe2.python.onnx.backend as backend

        with torch.onnx.select_model_mode_for_export(
            model, torch.onnx.TrainingMode.EVAL
        ):
            graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
            # Sanity-check the traced JIT graph before the numeric comparison.
            torch._C._jit_pass_lint(graph)
            verify(
                model,
                inputs,
                backend,
                rtol=rtol,
                atol=atol,
                opset_version=self.opset_version,
            )

    def test_ops(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(DummyNet()), toC(x))

    def test_prelu(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(PReluNet(), x)

    @skipScriptTest()
    def test_concat(self):
        input_a = Variable(torch.randn(BATCH_SIZE, 3))
        input_b = Variable(torch.randn(BATCH_SIZE, 3))
        # ConcatNet takes a single tuple argument, hence the extra nesting.
        inputs = ((toC(input_a), toC(input_b)),)
        self.exportTest(toC(ConcatNet()), inputs)

    def test_permute(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
        self.exportTest(PermuteNet(), x)

    @skipScriptTest()
    def test_embedding_sequential_1(self):
        x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
        self.exportTest(EmbeddingNetwork1(), x)

    @skipScriptTest()
    def test_embedding_sequential_2(self):
        x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
        self.exportTest(EmbeddingNetwork2(), x)

    @unittest.skip("This model takes too much memory")
    def test_srresnet(self):
        x = Variable(torch.randn(1, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x)
        )

    @skipIfNoLapack
    def test_super_resolution(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0))
        self.exportTest(toC(SuperResolutionNet(upscale_factor=3)), toC(x), atol=1e-6)

    def test_alexnet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(alexnet()), toC(x))

    def test_mnist(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
        self.exportTest(toC(MNIST()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16(self):
        # VGG 16-layer model (configuration "D")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16_bn(self):
        # VGG 16-layer model (configuration "D") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16_bn()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19(self):
        # VGG 19-layer model (configuration "E")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19_bn(self):
        # VGG 19-layer model (configuration "E") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19_bn()), toC(x))

    def test_resnet(self):
        # ResNet50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

    # This test is numerically unstable. Sporadic single element mismatch occurs occasionally.
    # NOTE(review): exportTest does not forward acceptable_error_percentage,
    # so the tolerance requested here is not actually applied.
    def test_inception(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
        self.exportTest(toC(inception_v3()), toC(x), acceptable_error_percentage=0.01)

    def test_squeezenet(self):
        # SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and
        # <0.5MB model size
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        # Previously built with version=1.1, so the 1.0 variant was never tested.
        sqnet_v1_0 = SqueezeNet(version=1.0)
        self.exportTest(toC(sqnet_v1_0), toC(x))

        # SqueezeNet 1.1 has 2.4x less computation and slightly fewer params
        # than SqueezeNet 1.0, without sacrificing accuracy.
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        sqnet_v1_1 = SqueezeNet(version=1.1)
        self.exportTest(toC(sqnet_v1_1), toC(x))

    def test_densenet(self):
        # Densenet-121 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)

    @skipScriptTest()
    def test_dcgan_netD(self):
        netD = _netD(1)
        netD.apply(weights_init)
        images = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1))
        self.exportTest(toC(netD), toC(images))

    @skipScriptTest()
    def test_dcgan_netG(self):
        netG = _netG(1)
        netG.apply(weights_init)
        latent = Variable(torch.empty(bsz, nz, 1, 1).normal_(0, 1))
        self.exportTest(toC(netG), toC(latent))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fake_quant(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(FakeQuantNet()), toC(x))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_qat_resnet_pertensor(self):
        # Quantize ResNet50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        qat_resnet50 = resnet50()

        # Use per tensor for weight. Per channel support will come with opset 13
        qat_resnet50.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=quantization.default_fake_quant,
        )
        quantization.prepare_qat(qat_resnet50, inplace=True)
        qat_resnet50.apply(torch.ao.quantization.enable_observer)
        qat_resnet50.apply(torch.ao.quantization.enable_fake_quant)

        # One forward pass with observers enabled, then freeze the qparams
        # before exporting.
        _ = qat_resnet50(x)
        for module in qat_resnet50.modules():
            if isinstance(module, quantization.FakeQuantize):
                module.calculate_qparams()
        qat_resnet50.apply(torch.ao.quantization.disable_observer)

        self.exportTest(toC(qat_resnet50), toC(x))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_resnet_per_channel(self):
        # Quantize ResNet50 model
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        qat_resnet50 = resnet50()

        qat_resnet50.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=quantization.default_per_channel_weight_fake_quant,
        )
        quantization.prepare_qat(qat_resnet50, inplace=True)
        qat_resnet50.apply(torch.ao.quantization.enable_observer)
        qat_resnet50.apply(torch.ao.quantization.enable_fake_quant)

        # One forward pass with observers enabled, then freeze the qparams
        # before exporting.
        _ = qat_resnet50(x)
        for module in qat_resnet50.modules():
            if isinstance(module, quantization.FakeQuantize):
                module.calculate_qparams()
        qat_resnet50.apply(torch.ao.quantization.disable_observer)

        self.exportTest(toC(qat_resnet50), toC(x))

    @skipScriptTest(skip_before_opset_version=15, reason="None type in outputs")
    def test_googlenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mnasnet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(mnasnet1_0()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mobilenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)

    @skipScriptTest()  # prim_data
    def test_shufflenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_fcn(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(fcn_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_deeplab(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(deeplabv3_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    def test_r3d_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(r3d_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mc3_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(mc3_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_r2plus1d_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(r2plus1d_18()), toC(x), rtol=1e-3, atol=1e-5)
285*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Script entry point: run this module's tests through common_utils.run_tests().
    common_utils.run_tests()
288