xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/bilinear2d.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11
12
class TestUpsampleBilinear2d(unittest.TestCase):
    """Lowering tests for bilinear 2D upsampling with the XNNPACK tester.

    The plain (non-antialiased) resizes are expected to lower into exactly
    one delegate call, with none of the ops in ``ops`` left in the edge
    graph, and then to serialize, run, and produce matching outputs. The
    antialiased resize is expected to stay outside the delegate.
    """

    class StaticResizeBilinear2dModule(torch.nn.Module):
        """Two chained bilinear resizes, both with ``align_corners=False``.

        The first resize targets an explicit output ``size`` of
        ``(x.shape[2] * 2, x.shape[3] * 3)``; the second upsamples by a
        uniform ``scale_factor`` of 3.0.
        """

        def forward(self, x):
            # Spatial extents are read from axes 2 and 3 of the input.
            target_size = (x.shape[2] * 2, x.shape[3] * 3)
            resized = torch.nn.functional.interpolate(
                x,
                size=target_size,
                mode="bilinear",
                align_corners=False,
                antialias=False,
            )
            return torch.nn.functional.interpolate(
                resized,
                scale_factor=3.0,
                mode="bilinear",
                align_corners=False,
                antialias=False,
            )

    class StaticResizeBilinear2dModuleWithAlignCorners(torch.nn.Module):
        """Same two-step resize as above, but with ``align_corners=True``.

        Both the explicit-``size`` resize and the ``scale_factor=3.0``
        resize set ``align_corners=True``.
        """

        def forward(self, x):
            target_size = (x.shape[2] * 2, x.shape[3] * 3)
            resized = torch.nn.functional.interpolate(
                x,
                size=target_size,
                mode="bilinear",
                align_corners=True,
                antialias=False,
            )
            return torch.nn.functional.interpolate(
                resized,
                scale_factor=3.0,
                mode="bilinear",
                align_corners=True,
                antialias=False,
            )

    class Bilinear2dAntiAlias(torch.nn.Module):
        """Two-step bilinear resize with ``antialias=True`` on both calls.

        The explicit-``size`` resize uses ``align_corners=True``. The
        ``scale_factor=3.0`` resize uses ``align_corners=False``, so both
        corner modes are exercised.
        """

        def forward(self, x):
            target_size = (x.shape[2] * 2, x.shape[3] * 3)
            resized = torch.nn.functional.interpolate(
                x,
                size=target_size,
                mode="bilinear",
                align_corners=True,
                antialias=True,
            )
            return torch.nn.functional.interpolate(
                resized,
                scale_factor=3.0,
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )

    # Edge-dialect op names that must be absent once the resize is delegated
    # (presumably what a non-delegated, decomposed bilinear resize would
    # leave behind). Dim order may or may not be enabled, which is why both
    # `_to_copy` and `_to_dim_order_copy` appear here; because of that, use
    # this set only with `check_not`.
    ops = {
        "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
        "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
        "executorch_exir_dialects_edge__ops_aten_index_Tensor",
        "executorch_exir_dialects_edge__ops_aten_arange_start_step",
        "executorch_exir_dialects_edge__ops_aten__to_copy_default",
        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default",
        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
        "executorch_exir_dialects_edge__ops_aten_clamp_default",
    }

    def _check_fully_delegated_and_compare(self, module_cls):
        """Lower a fresh ``module_cls()`` and validate the lowered program.

        The steps are:

        1. Build a random ``(2, 3, 4, 5)`` input.
        2. Export the module and run ``to_edge_transform_and_lower``.
        3. Assert that none of ``self.ops`` remain and that exactly one
           ``executorch_call_delegate`` is present.
        4. Convert to ExecuTorch, serialize, run the method, and compare
           outputs.
        """
        # The random input is drawn before the module is constructed.
        sample_inputs = (torch.randn(2, 3, 4, 5),)
        (
            Tester(module_cls(), sample_inputs)
            .export()
            .to_edge_transform_and_lower()
            .check_not(self.ops)
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp32_static_resize_bilinear2d(self):
        self._check_fully_delegated_and_compare(self.StaticResizeBilinear2dModule)

    def test_fp32_static_resize_bilinear2d_with_align_corners(self):
        self._check_fully_delegated_and_compare(
            self.StaticResizeBilinear2dModuleWithAlignCorners
        )

    def test_fp32_static_resize_bilinear2d_antialiased(self):
        """Antialiased bilinear resizes must not be partitioned.

        Both ``_upsample_bilinear2d_aa`` ops stay in the edge graph, and no
        delegate call is emitted.
        """
        sample_inputs = (torch.randn(2, 3, 4, 5),)
        aa_op_name = (
            "executorch_exir_dialects_edge__ops_aten__upsample_bilinear2d_aa_default"
        )
        (
            Tester(self.Bilinear2dAntiAlias(), sample_inputs)
            .export()
            .to_edge_transform_and_lower()
            .check_count({aa_op_name: 2})
            .check_not(["torch.ops.higher_order.executorch_call_delegate"])
        )
120        )
121