xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/models/edsr.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10
11from executorch.backends.xnnpack.test.tester import Tester
12from executorch.backends.xnnpack.test.tester.tester import Quantize
13from torchsr.models import edsr_r16f64
14
15
class TestEDSR(unittest.TestCase):
    """XNNPACK lowering tests for the torchsr EDSR (r16f64) model.

    Each test optionally quantizes the model, then runs it through the
    Tester pipeline (export -> to_edge_transform_and_lower -> to_executorch
    -> serialize) and finishes with run_method_and_compare_outputs.
    """

    # Fixed seed so the random example inputs (and any RNG-driven weight
    # initialization inside the model constructor) are identical on every
    # run, which keeps numerical-comparison failures reproducible.
    _SEED = 0

    @classmethod
    def setUpClass(cls):
        """Build the model and example inputs once for all tests in the class.

        This is done here rather than in the class body so that merely
        importing / collecting this module does not construct the model.
        """
        super().setUpClass()
        torch.manual_seed(cls._SEED)
        # NOTE(review): arguments presumably mean scale=2, pretrained=False
        # (no weight download) — confirm against torchsr.models.
        cls.edsr = edsr_r16f64(2, False).eval()
        cls.model_inputs = (torch.randn(1, 3, 224, 224),)

    def _lower_and_compare(self, tester):
        """Run the shared lowering pipeline on ``tester`` and compare outputs.

        ``tester`` may already have earlier stages applied (e.g. quantize);
        this runs export, to_edge_transform_and_lower, to_executorch,
        serialize and finally run_method_and_compare_outputs on it.
        """
        (
            tester.export()
            .to_edge_transform_and_lower()
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp32_edsr(self):
        """Lower the float32 model and compare outputs."""
        self._lower_and_compare(Tester(self.edsr, self.model_inputs))

    # NOTE: the leading underscore already keeps unittest from collecting this
    # method, so the skip reason below is never reported. Rename it to
    # test_qs8_edsr to have it show up as a skipped test.
    @unittest.skip("T187799178: Debugging Numerical Issues with Calibration")
    def _test_qs8_edsr(self):
        """Quantize (with the Quantize stage's default settings), lower, compare."""
        self._lower_and_compare(
            Tester(self.edsr, self.model_inputs).quantize()
        )

    # TODO: Delete this test and use only the calibrated one after T187799178.
    def test_qs8_edsr_no_calibrate(self):
        """Quantize without calibration, lower, and compare outputs."""
        self._lower_and_compare(
            Tester(self.edsr, self.model_inputs).quantize(
                Quantize(calibrate=False)
            )
        )
53