# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester
from torchvision.models.segmentation import deeplabv3, deeplabv3_resnet50  # @manual


class DL3Wrapper(torch.nn.Module):
    """Wrap torchvision's DeepLabV3-ResNet50 so ``forward`` returns one tensor.

    The wrapped torchvision model returns a dict-like result. This wrapper
    indexes out only its ``"out"`` entry, presumably so export and lowering
    see a single tensor output.
    """

    def __init__(self):
        super().__init__()
        # Uses the DEFAULT pretrained weights enum. Constructing this module
        # may therefore need to fetch weights if they are not cached locally.
        self.m = deeplabv3_resnet50(
            weights=deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT
        )

    def forward(self, *args):
        """Run the wrapped model on ``args`` and return its ``"out"`` entry."""
        return self.m(*args)["out"]


class TestDeepLabV3(unittest.TestCase):
    """Lower DeepLabV3-ResNet50 through XNNPACK and compare against eager."""

    # Populated once per class in ``setUpClass`` rather than in the class
    # body. Importing or collecting this module then does not construct the
    # model or load its pretrained weights.
    dl3 = None
    model_inputs = None

    @classmethod
    def setUpClass(cls):
        """Build the eval-mode model and a fixed example input once per class."""
        super().setUpClass()
        cls.dl3 = DL3Wrapper().eval()
        # A dedicated, seeded generator gives reproducible inputs (so output
        # mismatches can be replayed) without touching the global torch RNG.
        generator = torch.Generator().manual_seed(0)
        cls.model_inputs = (torch.randn(1, 3, 224, 224, generator=generator),)

    def test_fp32_dl3(self):
        """Export, lower, serialize, and run the fp32 model; compare outputs."""
        (
            Tester(self.dl3, self.model_inputs)
            .export()
            .to_edge_transform_and_lower()
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )