xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/models/deeplab_v3.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11from torchvision.models.segmentation import deeplabv3, deeplabv3_resnet50  # @manual
12
13
class DL3Wrapper(torch.nn.Module):
    """Wrap torchvision's DeepLabV3-ResNet50 so it returns a single tensor.

    torchvision's segmentation model returns a dict of outputs; ``forward``
    selects only the ``"out"`` entry (presumably so the exported program has a
    single tensor output to compare — confirm against the Tester's needs).
    """

    def __init__(
        self,
        weights=deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT,
        **kwargs,
    ):
        """Build the wrapped DeepLabV3-ResNet50 model.

        Args:
            weights: Weights passed to ``deeplabv3_resnet50``. Defaults to
                ``DeepLabV3_ResNet50_Weights.DEFAULT`` (the previously
                hard-coded choice); pass ``None`` for randomly initialized
                segmentation heads.
            **kwargs: Forwarded unchanged to ``deeplabv3_resnet50`` (e.g.
                ``weights_backbone`` or ``num_classes``).
        """
        super().__init__()
        self.m = deeplabv3_resnet50(weights=weights, **kwargs)

    def forward(self, *args):
        """Run the wrapped model on ``args`` and return its ``"out"`` entry."""
        return self.m(*args)["out"]
23
24
class TestDeepLabV3(unittest.TestCase):
    """Export DeepLabV3-ResNet50, lower it to XNNPACK, and compare outputs."""

    # Inputs come from a local, fixed-seed generator, so failures are
    # reproducible and the global torch RNG state is left untouched.
    model_inputs = (
        torch.randn(1, 3, 224, 224, generator=torch.Generator().manual_seed(0)),
    )

    @classmethod
    def setUpClass(cls):
        """Build the eval-mode model once, when the test class actually runs.

        Constructing ``DL3Wrapper`` loads pretrained weights (possibly via a
        download), so doing it here instead of in the class body keeps that
        cost off module import and test collection.
        """
        super().setUpClass()
        cls.dl3 = DL3Wrapper().eval()

    def test_fp32_dl3(self):
        """Run the full fp32 export -> lower -> serialize -> compare pipeline."""
        (
            Tester(self.dl3, self.model_inputs)
            .export()
            .to_edge_transform_and_lower()
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )
40