xref: /aosp_15_r20/external/pytorch/test/test_nnapi.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"]
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport ctypes
5*da0073e9SAndroid Build Coastguard Workerimport os
6*da0073e9SAndroid Build Coastguard Workerimport unittest
7*da0073e9SAndroid Build Coastguard Workerfrom typing import Tuple
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerfrom torch.backends._nnapi.prepare import convert_model_to_nnapi
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_quantized import supported_qengines
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
def qpt(t, scale, zero_point, dtype=torch.quint8):
    """Quantize-per-tensor helper: wrap *t* in a tensor and quantize it."""
    return torch.quantize_per_tensor(torch.tensor(t), scale, zero_point, dtype)
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
def nhwc(t):
    """Return a channels-last clone of *t*, tagged so the NNAPI converter
    treats it as an NHWC input."""
    out = t.clone().contiguous(memory_format=torch.channels_last)
    out.nnapi_nhwc = True
    return out
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker@unittest.skipUnless(
27*da0073e9SAndroid Build Coastguard Worker    "qnnpack" in supported_qengines,
28*da0073e9SAndroid Build Coastguard Worker    "This Pytorch Build has not been built with or does not support QNNPACK",
29*da0073e9SAndroid Build Coastguard Worker)
30*da0073e9SAndroid Build Coastguard Workerclass TestNNAPI(TestCase):
    def setUp(self):
        """Select the qnnpack engine and detect whether NNAPI models can run."""
        # Avoid saturation in fbgemm
        torch.backends.quantized.engine = "qnnpack"

        # If LIBNEURALNETWORKS_PATH points at a libneuralnetworks shared
        # library, load it so converted models can actually be executed.
        # Otherwise tests only verify that conversion succeeds (see check()).
        libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH")
        if libneuralnetworks_path:
            ctypes.cdll.LoadLibrary(libneuralnetworks_path)
            print("Will attempt to run NNAPI models.")
            self.can_run_nnapi = True
        else:
            self.can_run_nnapi = False
42*da0073e9SAndroid Build Coastguard Worker
    # Created for easy override by subclasses (eg TestNnapiBackend)
    def call_lowering_to_nnapi(self, traced_module, args):
        """Lower *traced_module* to an NNAPI module; split out as an override hook."""
        return convert_model_to_nnapi(traced_module, args)
46*da0073e9SAndroid Build Coastguard Worker
    # Created for subclasses to set can_run_nnapi (eg TestNnapiBackend)
    def set_can_run_nnapi(self, can_run):
        """Override whether converted models are executed (vs. convert-only)."""
        self.can_run_nnapi = can_run
50*da0073e9SAndroid Build Coastguard Worker
    def check(
        self,
        module,
        arg_or_args,
        *,
        trace_args=None,
        convert_args=None,
        atol_rtol=None,
        limit=None,
        expected_memory_format=None,
    ):
        """Trace *module*, lower it to NNAPI, and compare against eager output.

        Args:
            module: nn.Module under test.
            arg_or_args: a single tensor or a list of args used for tracing,
                conversion, and execution (unless overridden below).
            trace_args: alternative args for torch.jit.trace only.
            convert_args: alternative (e.g. flexible/zero-sized) args for the
                NNAPI conversion step only.
            atol_rtol: optional (atol, rtol) pair for the output comparison.
            limit: for quantized outputs, maximum allowed count of mismatched
                elements before failing with a zero-tolerance comparison.
            expected_memory_format: if set, assert the NNAPI output is
                contiguous in this memory format.
        """
        with torch.no_grad():
            # Normalize to a list of positional args.
            if isinstance(arg_or_args, torch.Tensor):
                args = [arg_or_args]
            else:
                args = arg_or_args
            module.eval()
            traced = torch.jit.trace(module, trace_args or args)
            nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args)
            if not self.can_run_nnapi:
                # Only test that the model was converted successfully.
                return
            eager_output = module(*args)
            nnapi_output = nnapi_module(*args)
            kwargs = {}
            if atol_rtol is not None:
                kwargs["atol"] = atol_rtol[0]
                kwargs["rtol"] = atol_rtol[1]
            self.assertEqual(eager_output, nnapi_output, **kwargs)
            if limit is not None:
                # Count elementwise mismatches on the integer representations
                # of the quantized outputs (int32 to avoid uint8 wraparound).
                mismatches = eager_output.int_repr().to(
                    torch.int32
                ) - nnapi_output.int_repr().to(torch.int32)
                if mismatches.count_nonzero() > limit:
                    # Too many mismatches.  Re-run the check with no tolerance
                    # to get a nice message.
                    self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
            if expected_memory_format:
                self.assertTrue(
                    nnapi_output.is_contiguous(memory_format=expected_memory_format)
                )
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker    def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
94*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(29)
95*da0073e9SAndroid Build Coastguard Worker        inp_quant = qpt(inp_float, 0.03, 128)
96*da0073e9SAndroid Build Coastguard Worker        return [
97*da0073e9SAndroid Build Coastguard Worker            ("float", inp_float),
98*da0073e9SAndroid Build Coastguard Worker            ("float-nhwc", nhwc(inp_float)),
99*da0073e9SAndroid Build Coastguard Worker            ("quant", inp_quant),
100*da0073e9SAndroid Build Coastguard Worker            ("quant-nhwc", nhwc(inp_quant)),
101*da0073e9SAndroid Build Coastguard Worker        ]
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker    def test_prelu(self):
104*da0073e9SAndroid Build Coastguard Worker        arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
105*da0073e9SAndroid Build Coastguard Worker        single_a = torch.nn.PReLU()
106*da0073e9SAndroid Build Coastguard Worker        self.check(single_a, arg)
107*da0073e9SAndroid Build Coastguard Worker        multi_a = torch.nn.PReLU(4)
108*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
109*da0073e9SAndroid Build Coastguard Worker            multi_a.weight.copy_(torch.tensor([0.1, 0.2, 0.3, 0.4]))
110*da0073e9SAndroid Build Coastguard Worker        self.check(multi_a, nhwc(arg))
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker        # Test flexible size
113*da0073e9SAndroid Build Coastguard Worker        self.check(
114*da0073e9SAndroid Build Coastguard Worker            multi_a,
115*da0073e9SAndroid Build Coastguard Worker            arg,
116*da0073e9SAndroid Build Coastguard Worker            trace_args=[torch.zeros(1, 4, 3, 3)],
117*da0073e9SAndroid Build Coastguard Worker            convert_args=[nhwc(torch.zeros(1, 4, 0, 0))],
118*da0073e9SAndroid Build Coastguard Worker        )
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker    def test_quantize(self):
121*da0073e9SAndroid Build Coastguard Worker        self.check(
122*da0073e9SAndroid Build Coastguard Worker            torch.ao.nn.quantized.Quantize(0.25, 2, torch.quint8),
123*da0073e9SAndroid Build Coastguard Worker            nhwc(torch.tensor([[[[1.0]], [[2.0]]]])),
124*da0073e9SAndroid Build Coastguard Worker        )
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker    def test_dequantize(self):
127*da0073e9SAndroid Build Coastguard Worker        self.check(
128*da0073e9SAndroid Build Coastguard Worker            torch.ao.nn.quantized.DeQuantize(), nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2))
129*da0073e9SAndroid Build Coastguard Worker        )
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker    def test_unsqueeze(self):
132*da0073e9SAndroid Build Coastguard Worker        class UnsqueezeModule(torch.nn.Module):
133*da0073e9SAndroid Build Coastguard Worker            def __init__(self, dim):
134*da0073e9SAndroid Build Coastguard Worker                super().__init__()
135*da0073e9SAndroid Build Coastguard Worker                self.dim = dim
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker            def forward(self, arg):
138*da0073e9SAndroid Build Coastguard Worker                return arg.unsqueeze(self.dim)
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
141*da0073e9SAndroid Build Coastguard Worker        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
142*da0073e9SAndroid Build Coastguard Worker        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
143*da0073e9SAndroid Build Coastguard Worker        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
144*da0073e9SAndroid Build Coastguard Worker        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker    def test_reshape(self):
147*da0073e9SAndroid Build Coastguard Worker        class ReshapeModule(torch.nn.Module):
148*da0073e9SAndroid Build Coastguard Worker            def __init__(self, shape):
149*da0073e9SAndroid Build Coastguard Worker                super().__init__()
150*da0073e9SAndroid Build Coastguard Worker                self.shape = shape
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker            def forward(self, arg):
153*da0073e9SAndroid Build Coastguard Worker                return arg.reshape(self.shape)
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker        self.check(ReshapeModule((2, 4)), torch.randn(4, 2, 1, 1))
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker        self.check(ReshapeModule((8, -1)), nhwc(torch.randn(4, 2, 1, 1)))
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "target size"):
160*da0073e9SAndroid Build Coastguard Worker            self.check(ReshapeModule((2, 4)), nhwc(torch.randn(4, 2, 1, 1)))
161*da0073e9SAndroid Build Coastguard Worker
162*da0073e9SAndroid Build Coastguard Worker    def test_flatten(self):
163*da0073e9SAndroid Build Coastguard Worker        for mod in [
164*da0073e9SAndroid Build Coastguard Worker            torch.nn.Flatten(),
165*da0073e9SAndroid Build Coastguard Worker            torch.nn.Flatten(start_dim=2, end_dim=3),
166*da0073e9SAndroid Build Coastguard Worker            torch.nn.Flatten(start_dim=2, end_dim=4),
167*da0073e9SAndroid Build Coastguard Worker            torch.nn.Flatten(start_dim=0, end_dim=-2),
168*da0073e9SAndroid Build Coastguard Worker            torch.nn.Flatten(start_dim=0, end_dim=4),
169*da0073e9SAndroid Build Coastguard Worker        ]:
170*da0073e9SAndroid Build Coastguard Worker            self.check(mod, torch.randn(4, 2, 1, 3, 7))
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker        # flex inputs
173*da0073e9SAndroid Build Coastguard Worker        self.check(
174*da0073e9SAndroid Build Coastguard Worker            torch.nn.Flatten(),
175*da0073e9SAndroid Build Coastguard Worker            torch.randn(4, 2, 1, 3, 7),
176*da0073e9SAndroid Build Coastguard Worker            convert_args=[torch.zeros(0, 2, 1, 3, 7)],
177*da0073e9SAndroid Build Coastguard Worker        )
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker        # channels last
180*da0073e9SAndroid Build Coastguard Worker        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 1, 4, 7)))
181*da0073e9SAndroid Build Coastguard Worker        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 3, 1, 1)))
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker        # Exceptions
184*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "not supported on NHWC"):
185*da0073e9SAndroid Build Coastguard Worker            self.check(torch.nn.Flatten(), nhwc(torch.randn(1, 3, 4, 4)))
186*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
187*da0073e9SAndroid Build Coastguard Worker            Exception, "Flattening flexible dims is not supported yet"
188*da0073e9SAndroid Build Coastguard Worker        ):
189*da0073e9SAndroid Build Coastguard Worker            self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
190*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "Only 1 dim"):
191*da0073e9SAndroid Build Coastguard Worker            self.check(
192*da0073e9SAndroid Build Coastguard Worker                torch.nn.Flatten(start_dim=1, end_dim=-2), torch.randn(0, 2, 1, 3, 0)
193*da0073e9SAndroid Build Coastguard Worker            )
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker    def test_slice(self):
196*da0073e9SAndroid Build Coastguard Worker        class SliceModule(torch.nn.Module):
197*da0073e9SAndroid Build Coastguard Worker            def __init__(self, start, stop, step):
198*da0073e9SAndroid Build Coastguard Worker                super().__init__()
199*da0073e9SAndroid Build Coastguard Worker                self.start = start
200*da0073e9SAndroid Build Coastguard Worker                self.stop = stop
201*da0073e9SAndroid Build Coastguard Worker                self.step = step
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker            def forward(self, t):
204*da0073e9SAndroid Build Coastguard Worker                return t[1:, self.start : self.stop : self.step, :]
205*da0073e9SAndroid Build Coastguard Worker
206*da0073e9SAndroid Build Coastguard Worker        class SliceModule2(torch.nn.Module):
207*da0073e9SAndroid Build Coastguard Worker            def forward(self, t):
208*da0073e9SAndroid Build Coastguard Worker                return t[3:]
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker        self.check(SliceModule(1, 5, 2), torch.randn(4, 6, 2))
211*da0073e9SAndroid Build Coastguard Worker        self.check(SliceModule2(), torch.randn(5))
212*da0073e9SAndroid Build Coastguard Worker
213*da0073e9SAndroid Build Coastguard Worker        # flex inputs
214*da0073e9SAndroid Build Coastguard Worker        self.check(
215*da0073e9SAndroid Build Coastguard Worker            SliceModule(1, 5, 2),
216*da0073e9SAndroid Build Coastguard Worker            torch.randn(4, 6, 2),
217*da0073e9SAndroid Build Coastguard Worker            convert_args=[torch.zeros(4, 6, 0)],
218*da0073e9SAndroid Build Coastguard Worker        )
219*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "slice with flexible shape"):
220*da0073e9SAndroid Build Coastguard Worker            self.check(
221*da0073e9SAndroid Build Coastguard Worker                SliceModule(1, 5, 2),
222*da0073e9SAndroid Build Coastguard Worker                torch.randn(4, 6, 2),
223*da0073e9SAndroid Build Coastguard Worker                convert_args=[torch.zeros(0, 0, 0)],
224*da0073e9SAndroid Build Coastguard Worker            )
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker    def test_cat(self):
227*da0073e9SAndroid Build Coastguard Worker        class CatModule(torch.nn.Module):
228*da0073e9SAndroid Build Coastguard Worker            def __init__(self, dim):
229*da0073e9SAndroid Build Coastguard Worker                super().__init__()
230*da0073e9SAndroid Build Coastguard Worker                self.dim = dim
231*da0073e9SAndroid Build Coastguard Worker
232*da0073e9SAndroid Build Coastguard Worker            def forward(self, t1, t2):
233*da0073e9SAndroid Build Coastguard Worker                return torch.cat([t1, t2], self.dim)
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker        self.check(
236*da0073e9SAndroid Build Coastguard Worker            CatModule(0),
237*da0073e9SAndroid Build Coastguard Worker            [
238*da0073e9SAndroid Build Coastguard Worker                torch.randn(1, 2, 3, 3),
239*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 2, 3, 3),
240*da0073e9SAndroid Build Coastguard Worker            ],
241*da0073e9SAndroid Build Coastguard Worker        )
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker        self.check(
244*da0073e9SAndroid Build Coastguard Worker            CatModule(1),
245*da0073e9SAndroid Build Coastguard Worker            [
246*da0073e9SAndroid Build Coastguard Worker                torch.randn(1, 2, 3, 3),
247*da0073e9SAndroid Build Coastguard Worker                torch.randn(1, 4, 3, 3),
248*da0073e9SAndroid Build Coastguard Worker            ],
249*da0073e9SAndroid Build Coastguard Worker        )
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker        self.check(
252*da0073e9SAndroid Build Coastguard Worker            CatModule(1),
253*da0073e9SAndroid Build Coastguard Worker            [
254*da0073e9SAndroid Build Coastguard Worker                nhwc(torch.randn(1, 2, 3, 3)),
255*da0073e9SAndroid Build Coastguard Worker                nhwc(torch.randn(1, 4, 3, 3)),
256*da0073e9SAndroid Build Coastguard Worker            ],
257*da0073e9SAndroid Build Coastguard Worker        )
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker        self.check(
260*da0073e9SAndroid Build Coastguard Worker            CatModule(1),
261*da0073e9SAndroid Build Coastguard Worker            [
262*da0073e9SAndroid Build Coastguard Worker                torch.randn(1, 2, 3, 3),
263*da0073e9SAndroid Build Coastguard Worker                torch.randn(1, 4, 3, 3),
264*da0073e9SAndroid Build Coastguard Worker            ],
265*da0073e9SAndroid Build Coastguard Worker            convert_args=[torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0)],
266*da0073e9SAndroid Build Coastguard Worker        )
267*da0073e9SAndroid Build Coastguard Worker
    def test_pointwise_unary(self):
        """Elementwise unary ops (relu, sigmoid) on float and quantized inputs."""
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):

                # forward() closes over the loop variable `op`; the branch is
                # resolved when torch.jit.trace executes the module, i.e.
                # within this same loop iteration.
                class UnaryModule(torch.nn.Module):
                    def forward(self, arg):
                        if op == "relu":
                            return torch.nn.functional.relu(arg)
                        if op == "sigmoid":
                            return torch.sigmoid(arg)
                        raise Exception("Bad op")  # noqa: TRY002

                self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))
                # Same op on a quantized input.
                self.check(
                    UnaryModule(),
                    qpt(torch.tensor([-1.0, 1.0]), 1.0 / 256, 0),
                )
285*da0073e9SAndroid Build Coastguard Worker
    def test_pointwise_binary(self):
        """Elementwise binary ops (add/sub/mul/div) incl. broadcasting rules."""
        for op in ["add", "sub", "mul", "div"]:
            with self.subTest(op):

                # forward() closes over the loop variable `op`; the branch is
                # resolved when torch.jit.trace executes the module, i.e.
                # within this same loop iteration.
                class BinaryModule(torch.nn.Module):
                    def forward(self, lhs, rhs):
                        if op == "add":
                            return lhs + rhs
                        if op == "sub":
                            return lhs - rhs
                        if op == "mul":
                            return lhs * rhs
                        if op == "div":
                            return lhs / rhs
                        raise Exception("Bad op")  # noqa: TRY002

                # Same-shape operands.
                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([1.0, 2.0]),
                        torch.tensor([3.0, 4.0]),
                    ],
                )

                # Equal-rank broadcast (1x2 against 2x2) is supported.
                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([[1.0, 2.0]]),
                        torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                    ],
                )

                # Broadcasting across different ranks is not supported.
                with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"):
                    self.check(
                        BinaryModule(),
                        [
                            torch.tensor([1.0, 2.0]),
                            torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                        ],
                    )
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker    def test_pointwise_binary_const(self):
328*da0073e9SAndroid Build Coastguard Worker        const = torch.randn(1, 4, 6, 6)
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker        class ArgPlusConst(torch.nn.Module):
331*da0073e9SAndroid Build Coastguard Worker            def forward(self, arg):
332*da0073e9SAndroid Build Coastguard Worker                return arg + const
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker        class ConstPlusArg(torch.nn.Module):
335*da0073e9SAndroid Build Coastguard Worker            def forward(self, arg):
336*da0073e9SAndroid Build Coastguard Worker                return const + arg
337*da0073e9SAndroid Build Coastguard Worker
338*da0073e9SAndroid Build Coastguard Worker        arg_contig = torch.randn(2, 4, 6, 6)
339*da0073e9SAndroid Build Coastguard Worker        arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))
340*da0073e9SAndroid Build Coastguard Worker
341*da0073e9SAndroid Build Coastguard Worker        for mod_class in [ArgPlusConst, ConstPlusArg]:
342*da0073e9SAndroid Build Coastguard Worker            for use_nhwc in [False, True]:
343*da0073e9SAndroid Build Coastguard Worker                with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
344*da0073e9SAndroid Build Coastguard Worker                    arg = arg_nhwc if use_nhwc else arg_contig
345*da0073e9SAndroid Build Coastguard Worker                    memory_format = (
346*da0073e9SAndroid Build Coastguard Worker                        torch.channels_last if use_nhwc else torch.contiguous_format
347*da0073e9SAndroid Build Coastguard Worker                    )
348*da0073e9SAndroid Build Coastguard Worker                    self.check(mod_class(), arg, expected_memory_format=memory_format)
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker    def test_hardtanh(self):
351*da0073e9SAndroid Build Coastguard Worker        inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
352*da0073e9SAndroid Build Coastguard Worker        self.check(torch.nn.Hardtanh(), inp)
353*da0073e9SAndroid Build Coastguard Worker        self.check(torch.nn.Hardtanh(0.0, 6.0), inp)
354*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "hardtanh with args"):
355*da0073e9SAndroid Build Coastguard Worker            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)
356*da0073e9SAndroid Build Coastguard Worker
357*da0073e9SAndroid Build Coastguard Worker    def test_softmax(self):
358*da0073e9SAndroid Build Coastguard Worker        inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]])
359*da0073e9SAndroid Build Coastguard Worker        self.check(torch.nn.Softmax(), inp)
360*da0073e9SAndroid Build Coastguard Worker        self.check(torch.nn.Softmax(dim=0), inp)
361*da0073e9SAndroid Build Coastguard Worker        # Test flexible size
362*da0073e9SAndroid Build Coastguard Worker        self.check(
363*da0073e9SAndroid Build Coastguard Worker            torch.nn.Softmax(),
364*da0073e9SAndroid Build Coastguard Worker            inp,
365*da0073e9SAndroid Build Coastguard Worker            convert_args=[torch.zeros(0, 0)],
366*da0073e9SAndroid Build Coastguard Worker        )
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker    def test_to(self):
369*da0073e9SAndroid Build Coastguard Worker        class ToCPU(torch.nn.Module):
370*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
371*da0073e9SAndroid Build Coastguard Worker                super().__init__()
372*da0073e9SAndroid Build Coastguard Worker                self.prelu = torch.nn.PReLU()
373*da0073e9SAndroid Build Coastguard Worker
374*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
375*da0073e9SAndroid Build Coastguard Worker                y = x.to("cpu")
376*da0073e9SAndroid Build Coastguard Worker                # add prelu since input operand can't be output
377*da0073e9SAndroid Build Coastguard Worker                return self.prelu(y)
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker        arg = torch.randn(1, 2, 3, 3)
380*da0073e9SAndroid Build Coastguard Worker        self.check(ToCPU(), arg)
381*da0073e9SAndroid Build Coastguard Worker        # Test flexible size
382*da0073e9SAndroid Build Coastguard Worker        self.check(
383*da0073e9SAndroid Build Coastguard Worker            ToCPU(),
384*da0073e9SAndroid Build Coastguard Worker            arg,
385*da0073e9SAndroid Build Coastguard Worker            convert_args=[torch.zeros(1, 2, 0, 0)],
386*da0073e9SAndroid Build Coastguard Worker        )
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker    def test_detach(self):
389*da0073e9SAndroid Build Coastguard Worker        class DetachModule(torch.nn.Module):
390*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
391*da0073e9SAndroid Build Coastguard Worker                y = x.detach()
392*da0073e9SAndroid Build Coastguard Worker                return torch.nn.functional.relu(y)
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker        self.check(DetachModule(), torch.randn(1, 2, 3, 3))
395*da0073e9SAndroid Build Coastguard Worker        self.check(
396*da0073e9SAndroid Build Coastguard Worker            DetachModule(),
397*da0073e9SAndroid Build Coastguard Worker            torch.randn(1, 2, 3, 3),
398*da0073e9SAndroid Build Coastguard Worker            convert_args=[torch.zeros(1, 2, 0, 0)],
399*da0073e9SAndroid Build Coastguard Worker        )
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker    def test_log_softmax(self):
402*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(3, 10)
403*da0073e9SAndroid Build Coastguard Worker        self.check(torch.nn.LogSoftmax(), inp)
404*da0073e9SAndroid Build Coastguard Worker        self.check(torch.nn.LogSoftmax(0), inp)
405*da0073e9SAndroid Build Coastguard Worker
406*da0073e9SAndroid Build Coastguard Worker    def test_mean(self):
407*da0073e9SAndroid Build Coastguard Worker        class MeanModule(torch.nn.Module):
408*da0073e9SAndroid Build Coastguard Worker            def __init__(self, dim, keep=False):
409*da0073e9SAndroid Build Coastguard Worker                super().__init__()
410*da0073e9SAndroid Build Coastguard Worker                self.dim = dim
411*da0073e9SAndroid Build Coastguard Worker                self.keep = keep
412*da0073e9SAndroid Build Coastguard Worker
413*da0073e9SAndroid Build Coastguard Worker            def forward(self, t):
414*da0073e9SAndroid Build Coastguard Worker                return torch.mean(t, dim=self.dim, keepdim=self.keep)
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker        self.check(MeanModule(0), torch.randn(2, 3))
417*da0073e9SAndroid Build Coastguard Worker        self.check(MeanModule(1), torch.randn(2, 3))
418*da0073e9SAndroid Build Coastguard Worker        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
419*da0073e9SAndroid Build Coastguard Worker        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
420*da0073e9SAndroid Build Coastguard Worker        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
421*da0073e9SAndroid Build Coastguard Worker        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))
422*da0073e9SAndroid Build Coastguard Worker
423*da0073e9SAndroid Build Coastguard Worker    def test_max_pool2d(self):
424*da0073e9SAndroid Build Coastguard Worker        for name, inp in self.float_and_quant_and_nhwc(
425*da0073e9SAndroid Build Coastguard Worker            torch.randn(2, 3, 12, 16), 0.3, 128
426*da0073e9SAndroid Build Coastguard Worker        ):
427*da0073e9SAndroid Build Coastguard Worker            with self.subTest(name):
428*da0073e9SAndroid Build Coastguard Worker                self.check(torch.nn.MaxPool2d(2), inp)
429*da0073e9SAndroid Build Coastguard Worker                self.check(torch.nn.MaxPool2d((3, 4)), inp)
430*da0073e9SAndroid Build Coastguard Worker                self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker    def test_avg_pool2d(self):
433*da0073e9SAndroid Build Coastguard Worker        for name, inp in self.float_and_quant_and_nhwc(
434*da0073e9SAndroid Build Coastguard Worker            torch.randn(2, 3, 12, 16), 0.3, 128
435*da0073e9SAndroid Build Coastguard Worker        ):
436*da0073e9SAndroid Build Coastguard Worker            with self.subTest(name):
437*da0073e9SAndroid Build Coastguard Worker                atol_rtol = None
438*da0073e9SAndroid Build Coastguard Worker                limit = None
439*da0073e9SAndroid Build Coastguard Worker                convert_dims = (2, 3, 0, 0)
440*da0073e9SAndroid Build Coastguard Worker                convert_arg = torch.zeros(*convert_dims)
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker                for model in (
443*da0073e9SAndroid Build Coastguard Worker                    torch.nn.AvgPool2d(2),
444*da0073e9SAndroid Build Coastguard Worker                    torch.nn.AvgPool2d((3, 4)),
445*da0073e9SAndroid Build Coastguard Worker                    torch.nn.AvgPool2d((3, 4), (1, 2)),
446*da0073e9SAndroid Build Coastguard Worker                ):
447*da0073e9SAndroid Build Coastguard Worker                    if "quant" in name:
448*da0073e9SAndroid Build Coastguard Worker                        atol_rtol = (1, 0)
449*da0073e9SAndroid Build Coastguard Worker                        limit = model(inp).numel()
450*da0073e9SAndroid Build Coastguard Worker                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
451*da0073e9SAndroid Build Coastguard Worker                    if "nhwc" in name:
452*da0073e9SAndroid Build Coastguard Worker                        convert_arg = nhwc(convert_arg)
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
455*da0073e9SAndroid Build Coastguard Worker                    self.check(
456*da0073e9SAndroid Build Coastguard Worker                        model,
457*da0073e9SAndroid Build Coastguard Worker                        inp,
458*da0073e9SAndroid Build Coastguard Worker                        convert_args=[convert_arg],
459*da0073e9SAndroid Build Coastguard Worker                        atol_rtol=atol_rtol,
460*da0073e9SAndroid Build Coastguard Worker                        limit=limit,
461*da0073e9SAndroid Build Coastguard Worker                    )
462*da0073e9SAndroid Build Coastguard Worker
463*da0073e9SAndroid Build Coastguard Worker    def test_adaptive_avg_pool2d(self):
464*da0073e9SAndroid Build Coastguard Worker        for name, inp in self.float_and_quant_and_nhwc(
465*da0073e9SAndroid Build Coastguard Worker            torch.randn(2, 3, 12, 16), 0.3, 128
466*da0073e9SAndroid Build Coastguard Worker        ):
467*da0073e9SAndroid Build Coastguard Worker            with self.subTest(name):
468*da0073e9SAndroid Build Coastguard Worker                self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp)
469*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(Exception, "with output size"):
470*da0073e9SAndroid Build Coastguard Worker                    self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp)
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker    def test_upsample_nearest2d(self):
473*da0073e9SAndroid Build Coastguard Worker        convert_args = dict(
474*da0073e9SAndroid Build Coastguard Worker            self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.3, 128)
475*da0073e9SAndroid Build Coastguard Worker        )
476*da0073e9SAndroid Build Coastguard Worker        for name, inp in self.float_and_quant_and_nhwc(
477*da0073e9SAndroid Build Coastguard Worker            torch.randn(2, 3, 12, 16), 0.3, 128
478*da0073e9SAndroid Build Coastguard Worker        ):
479*da0073e9SAndroid Build Coastguard Worker            with self.subTest(name):
480*da0073e9SAndroid Build Coastguard Worker                self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp)
481*da0073e9SAndroid Build Coastguard Worker                self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp)
482*da0073e9SAndroid Build Coastguard Worker                self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp)
483*da0073e9SAndroid Build Coastguard Worker                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp)
484*da0073e9SAndroid Build Coastguard Worker                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp)
485*da0073e9SAndroid Build Coastguard Worker                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp)
486*da0073e9SAndroid Build Coastguard Worker
487*da0073e9SAndroid Build Coastguard Worker                self.check(
488*da0073e9SAndroid Build Coastguard Worker                    torch.nn.UpsamplingNearest2d(size=(24, 32)),
489*da0073e9SAndroid Build Coastguard Worker                    inp,
490*da0073e9SAndroid Build Coastguard Worker                    convert_args=[convert_args[name]],
491*da0073e9SAndroid Build Coastguard Worker                )
492*da0073e9SAndroid Build Coastguard Worker                self.check(
493*da0073e9SAndroid Build Coastguard Worker                    torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)),
494*da0073e9SAndroid Build Coastguard Worker                    inp,
495*da0073e9SAndroid Build Coastguard Worker                    convert_args=[convert_args[name]],
496*da0073e9SAndroid Build Coastguard Worker                )
497*da0073e9SAndroid Build Coastguard Worker
498*da0073e9SAndroid Build Coastguard Worker    def test_linear(self):
499*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(29)
500*da0073e9SAndroid Build Coastguard Worker        self.check(torch.nn.Linear(16, 32), torch.randn(2, 16))
501*da0073e9SAndroid Build Coastguard Worker        self.check(
502*da0073e9SAndroid Build Coastguard Worker            torch.nn.Linear(16, 32),
503*da0073e9SAndroid Build Coastguard Worker            torch.randn(2, 16),
504*da0073e9SAndroid Build Coastguard Worker            convert_args=[torch.zeros(0, 16)],
505*da0073e9SAndroid Build Coastguard Worker        )
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker    def test_conv2d(self):
508*da0073e9SAndroid Build Coastguard Worker        cases = [
509*da0073e9SAndroid Build Coastguard Worker            # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim,      name
510*da0073e9SAndroid Build Coastguard Worker            (4, 8, (3, 3), 1, 0, 1, 1, (2, 4, 16, 16), "3x3"),  # noqa: E201,E241
511*da0073e9SAndroid Build Coastguard Worker            (4, 8, (3, 3), 1, 0, 1, 0, (2, 4, 16, 16), "3x3nobias"),  # noqa: E201,E241
512*da0073e9SAndroid Build Coastguard Worker            (4, 16, (3, 3), 1, 1, 1, 1, (2, 4, 16, 16), "3x3p1"),  # noqa: E201,E241
513*da0073e9SAndroid Build Coastguard Worker            (8, 8, (3, 3), 2, 0, 1, 1, (2, 8, 16, 16), "3x3s2"),  # noqa: E201,E241
514*da0073e9SAndroid Build Coastguard Worker            (4, 8, (5, 5), 1, 0, 1, 1, (2, 4, 16, 16), "5x5"),  # noqa: E201,E241
515*da0073e9SAndroid Build Coastguard Worker            (4, 4, (3, 3), 1, 0, 4, 1, (2, 4, 16, 16), "3x3dw"),  # noqa: E201,E241
516*da0073e9SAndroid Build Coastguard Worker            (8, 4, (1, 1), 1, 0, 1, 1, (2, 8, 16, 16), "1x1"),  # noqa: E201,E241
517*da0073e9SAndroid Build Coastguard Worker        ]
518*da0073e9SAndroid Build Coastguard Worker
519*da0073e9SAndroid Build Coastguard Worker        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
520*da0073e9SAndroid Build Coastguard Worker            for case in cases:
521*da0073e9SAndroid Build Coastguard Worker                (
522*da0073e9SAndroid Build Coastguard Worker                    in_ch,
523*da0073e9SAndroid Build Coastguard Worker                    out_ch,
524*da0073e9SAndroid Build Coastguard Worker                    kernel,
525*da0073e9SAndroid Build Coastguard Worker                    stride,
526*da0073e9SAndroid Build Coastguard Worker                    padding,
527*da0073e9SAndroid Build Coastguard Worker                    groups,
528*da0073e9SAndroid Build Coastguard Worker                    bias,
529*da0073e9SAndroid Build Coastguard Worker                    input_dim,
530*da0073e9SAndroid Build Coastguard Worker                    name,
531*da0073e9SAndroid Build Coastguard Worker                ) = case
532*da0073e9SAndroid Build Coastguard Worker                with self.subTest(f"{kind}-{name}"):
533*da0073e9SAndroid Build Coastguard Worker                    inp = torch.randn(input_dim)
534*da0073e9SAndroid Build Coastguard Worker                    model = torch.nn.Conv2d(
535*da0073e9SAndroid Build Coastguard Worker                        in_ch,
536*da0073e9SAndroid Build Coastguard Worker                        out_ch,
537*da0073e9SAndroid Build Coastguard Worker                        kernel,
538*da0073e9SAndroid Build Coastguard Worker                        stride,
539*da0073e9SAndroid Build Coastguard Worker                        padding,
540*da0073e9SAndroid Build Coastguard Worker                        groups=groups,
541*da0073e9SAndroid Build Coastguard Worker                        bias=bool(bias),
542*da0073e9SAndroid Build Coastguard Worker                    )
543*da0073e9SAndroid Build Coastguard Worker                    output_size = model(inp).numel()
544*da0073e9SAndroid Build Coastguard Worker                    atol_rtol = None
545*da0073e9SAndroid Build Coastguard Worker                    limit = None
546*da0073e9SAndroid Build Coastguard Worker                    convert_dims = (0, in_ch, 0, 0)
547*da0073e9SAndroid Build Coastguard Worker                    convert_arg = torch.zeros(*convert_dims)
548*da0073e9SAndroid Build Coastguard Worker
549*da0073e9SAndroid Build Coastguard Worker                    if "quant" in kind:
550*da0073e9SAndroid Build Coastguard Worker                        model = torch.nn.Sequential(model)
551*da0073e9SAndroid Build Coastguard Worker                        model.eval()
552*da0073e9SAndroid Build Coastguard Worker                        model.qconfig = torch.ao.quantization.get_default_qconfig(
553*da0073e9SAndroid Build Coastguard Worker                            "qnnpack"
554*da0073e9SAndroid Build Coastguard Worker                        )
555*da0073e9SAndroid Build Coastguard Worker                        model = torch.ao.quantization.prepare(model)
556*da0073e9SAndroid Build Coastguard Worker                        model(inp)
557*da0073e9SAndroid Build Coastguard Worker                        model = torch.ao.quantization.convert(model)
558*da0073e9SAndroid Build Coastguard Worker                        inp = qpt(inp, 1.0 / 16, 128)
559*da0073e9SAndroid Build Coastguard Worker                        # I've seen numerical differences between QNNPACK and NNAPI,
560*da0073e9SAndroid Build Coastguard Worker                        # but never more than 1 quantum, and never more than ~1% of
561*da0073e9SAndroid Build Coastguard Worker                        # the output in this test.
562*da0073e9SAndroid Build Coastguard Worker                        atol_rtol = (1, 0)
563*da0073e9SAndroid Build Coastguard Worker                        limit = output_size * 0.03
564*da0073e9SAndroid Build Coastguard Worker                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard Worker                    if "nhwc" in kind:
567*da0073e9SAndroid Build Coastguard Worker                        inp = nhwc(inp)
568*da0073e9SAndroid Build Coastguard Worker                        convert_arg = nhwc(convert_arg)
569*da0073e9SAndroid Build Coastguard Worker
570*da0073e9SAndroid Build Coastguard Worker                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
571*da0073e9SAndroid Build Coastguard Worker                    self.check(
572*da0073e9SAndroid Build Coastguard Worker                        model,
573*da0073e9SAndroid Build Coastguard Worker                        inp,
574*da0073e9SAndroid Build Coastguard Worker                        convert_args=[convert_arg],
575*da0073e9SAndroid Build Coastguard Worker                        atol_rtol=atol_rtol,
576*da0073e9SAndroid Build Coastguard Worker                        limit=limit,
577*da0073e9SAndroid Build Coastguard Worker                    )
578*da0073e9SAndroid Build Coastguard Worker
579*da0073e9SAndroid Build Coastguard Worker    def test_conv2d_transpose(self):
580*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(29)
581*da0073e9SAndroid Build Coastguard Worker        in_ch, out_ch, kernel = (5, 7, (2, 2))
582*da0073e9SAndroid Build Coastguard Worker        input_dim = (4, 5, 3, 3)
583*da0073e9SAndroid Build Coastguard Worker        convert_dims = input_dim[:2] + (0, 0)
584*da0073e9SAndroid Build Coastguard Worker
585*da0073e9SAndroid Build Coastguard Worker        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
586*da0073e9SAndroid Build Coastguard Worker            with self.subTest(kind):
587*da0073e9SAndroid Build Coastguard Worker                inp = torch.randn(input_dim)
588*da0073e9SAndroid Build Coastguard Worker                model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel)
589*da0073e9SAndroid Build Coastguard Worker                output_size = model(inp).numel()
590*da0073e9SAndroid Build Coastguard Worker                atol_rtol = (0.0002, 0)
591*da0073e9SAndroid Build Coastguard Worker                limit = None
592*da0073e9SAndroid Build Coastguard Worker                convert_arg = torch.zeros(*convert_dims)
593*da0073e9SAndroid Build Coastguard Worker
594*da0073e9SAndroid Build Coastguard Worker                if "quant" in kind:
595*da0073e9SAndroid Build Coastguard Worker                    model = torch.ao.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel)
596*da0073e9SAndroid Build Coastguard Worker                    model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
597*da0073e9SAndroid Build Coastguard Worker                    inp = qpt(inp, 1.0 / 16, 128)
598*da0073e9SAndroid Build Coastguard Worker                    # I've seen numerical differences between QNNPACK and NNAPI,
599*da0073e9SAndroid Build Coastguard Worker                    # but never more than 1 quantum, and never more than ~10% of
600*da0073e9SAndroid Build Coastguard Worker                    # the output in this test.
601*da0073e9SAndroid Build Coastguard Worker                    atol_rtol = (1, 0)
602*da0073e9SAndroid Build Coastguard Worker                    limit = output_size * 0.1
603*da0073e9SAndroid Build Coastguard Worker                    convert_arg = qpt(convert_arg, 1.0 / 16, 128)
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker                if "nhwc" in kind:
606*da0073e9SAndroid Build Coastguard Worker                    inp = nhwc(inp)
607*da0073e9SAndroid Build Coastguard Worker                    convert_arg = nhwc(convert_arg)
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker                self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
610*da0073e9SAndroid Build Coastguard Worker                self.check(
611*da0073e9SAndroid Build Coastguard Worker                    model,
612*da0073e9SAndroid Build Coastguard Worker                    inp,
613*da0073e9SAndroid Build Coastguard Worker                    convert_args=[convert_arg],
614*da0073e9SAndroid Build Coastguard Worker                    atol_rtol=atol_rtol,
615*da0073e9SAndroid Build Coastguard Worker                    limit=limit,
616*da0073e9SAndroid Build Coastguard Worker                )
617*da0073e9SAndroid Build Coastguard Worker
618*da0073e9SAndroid Build Coastguard Worker    def test_qadd(self):
619*da0073e9SAndroid Build Coastguard Worker        func = torch.ao.nn.quantized.QFunctional()
620*da0073e9SAndroid Build Coastguard Worker        func.scale = 0.5
621*da0073e9SAndroid Build Coastguard Worker        func.zero_point = 120
622*da0073e9SAndroid Build Coastguard Worker
623*da0073e9SAndroid Build Coastguard Worker        class AddMod(torch.nn.Module):
624*da0073e9SAndroid Build Coastguard Worker            def forward(self, lhs, rhs):
625*da0073e9SAndroid Build Coastguard Worker                return func.add(lhs, rhs)
626*da0073e9SAndroid Build Coastguard Worker
627*da0073e9SAndroid Build Coastguard Worker        class AddReluMod(torch.nn.Module):
628*da0073e9SAndroid Build Coastguard Worker            def forward(self, lhs, rhs):
629*da0073e9SAndroid Build Coastguard Worker                return func.add_relu(lhs, rhs)
630*da0073e9SAndroid Build Coastguard Worker
631*da0073e9SAndroid Build Coastguard Worker        class MulMod(torch.nn.Module):
632*da0073e9SAndroid Build Coastguard Worker            def forward(self, lhs, rhs):
633*da0073e9SAndroid Build Coastguard Worker                return func.mul(lhs, rhs)
634*da0073e9SAndroid Build Coastguard Worker
635*da0073e9SAndroid Build Coastguard Worker        for name, mod in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]:
636*da0073e9SAndroid Build Coastguard Worker            with self.subTest(name):
637*da0073e9SAndroid Build Coastguard Worker                self.check(
638*da0073e9SAndroid Build Coastguard Worker                    mod(),
639*da0073e9SAndroid Build Coastguard Worker                    [
640*da0073e9SAndroid Build Coastguard Worker                        qpt([1.0, 2.0], 0.25, 128),
641*da0073e9SAndroid Build Coastguard Worker                        qpt([3.0, 4.0], 0.25, 128),
642*da0073e9SAndroid Build Coastguard Worker                    ],
643*da0073e9SAndroid Build Coastguard Worker                )
644*da0073e9SAndroid Build Coastguard Worker                self.check(
645*da0073e9SAndroid Build Coastguard Worker                    mod(),
646*da0073e9SAndroid Build Coastguard Worker                    [
647*da0073e9SAndroid Build Coastguard Worker                        qpt([[1.0, 2.0]], 0.25, 128),
648*da0073e9SAndroid Build Coastguard Worker                        qpt([[3.0, 4.0]], 0.25, 128),
649*da0073e9SAndroid Build Coastguard Worker                    ],
650*da0073e9SAndroid Build Coastguard Worker                    convert_args=[
651*da0073e9SAndroid Build Coastguard Worker                        qpt([[1.0, 2.0]], 0.25, 128),
652*da0073e9SAndroid Build Coastguard Worker                        qpt(torch.zeros((1, 2)), 0.25, 128),
653*da0073e9SAndroid Build Coastguard Worker                    ],
654*da0073e9SAndroid Build Coastguard Worker                )
655*da0073e9SAndroid Build Coastguard Worker                self.check(
656*da0073e9SAndroid Build Coastguard Worker                    mod(),
657*da0073e9SAndroid Build Coastguard Worker                    [
658*da0073e9SAndroid Build Coastguard Worker                        qpt([[1.0, 2.0]], 0.25, 128),
659*da0073e9SAndroid Build Coastguard Worker                        qpt([[3.0, 4.0]], 0.25, 128),
660*da0073e9SAndroid Build Coastguard Worker                    ],
661*da0073e9SAndroid Build Coastguard Worker                    convert_args=[
662*da0073e9SAndroid Build Coastguard Worker                        qpt(torch.zeros((1, 2)), 0.25, 128),
663*da0073e9SAndroid Build Coastguard Worker                        qpt([[3.0, 4.0]], 0.25, 128),
664*da0073e9SAndroid Build Coastguard Worker                    ],
665*da0073e9SAndroid Build Coastguard Worker                )
666*da0073e9SAndroid Build Coastguard Worker                self.check(
667*da0073e9SAndroid Build Coastguard Worker                    mod(),
668*da0073e9SAndroid Build Coastguard Worker                    [
669*da0073e9SAndroid Build Coastguard Worker                        qpt([[1.0, 2.0]], 0.25, 128),
670*da0073e9SAndroid Build Coastguard Worker                        qpt([[3.0, 4.0]], 0.25, 128),
671*da0073e9SAndroid Build Coastguard Worker                    ],
672*da0073e9SAndroid Build Coastguard Worker                    convert_args=[
673*da0073e9SAndroid Build Coastguard Worker                        qpt(torch.zeros((1, 2)), 0.25, 128),
674*da0073e9SAndroid Build Coastguard Worker                        qpt(torch.zeros((1, 2)), 0.25, 128),
675*da0073e9SAndroid Build Coastguard Worker                    ],
676*da0073e9SAndroid Build Coastguard Worker                )
677*da0073e9SAndroid Build Coastguard Worker                # NOTE: NNAPI qadd supports broadcast, but PT does not.
678*da0073e9SAndroid Build Coastguard Worker
679*da0073e9SAndroid Build Coastguard Worker    def test_qlinear(self):
680*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(29)
681*da0073e9SAndroid Build Coastguard Worker        weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8)
682*da0073e9SAndroid Build Coastguard Worker        bias = torch.randn(16)
683*da0073e9SAndroid Build Coastguard Worker        mod = torch.ao.nn.quantized.Linear(32, 16)
684*da0073e9SAndroid Build Coastguard Worker        mod.set_weight_bias(weight, bias)
685*da0073e9SAndroid Build Coastguard Worker        inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
686*da0073e9SAndroid Build Coastguard Worker        self.check(mod, inp)
687*da0073e9SAndroid Build Coastguard Worker
688*da0073e9SAndroid Build Coastguard Worker    def test_seblock_mul(self):
689*da0073e9SAndroid Build Coastguard Worker        class MulModel(torch.nn.Module):
690*da0073e9SAndroid Build Coastguard Worker            def forward(self, lhs, rhs):
691*da0073e9SAndroid Build Coastguard Worker                return lhs * rhs
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker        self.check(
694*da0073e9SAndroid Build Coastguard Worker            MulModel(),
695*da0073e9SAndroid Build Coastguard Worker            [
696*da0073e9SAndroid Build Coastguard Worker                nhwc(torch.randn(2, 3, 4, 4)),
697*da0073e9SAndroid Build Coastguard Worker                torch.randn(1, 3, 1, 1),
698*da0073e9SAndroid Build Coastguard Worker            ],
699*da0073e9SAndroid Build Coastguard Worker        )
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker    def test_multi_output(self):
702*da0073e9SAndroid Build Coastguard Worker        class MultiModel(torch.nn.Module):
703*da0073e9SAndroid Build Coastguard Worker            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
704*da0073e9SAndroid Build Coastguard Worker                the_sum = lhs + rhs
705*da0073e9SAndroid Build Coastguard Worker                the_diff = lhs - rhs
706*da0073e9SAndroid Build Coastguard Worker                return the_sum, the_diff
707*da0073e9SAndroid Build Coastguard Worker
708*da0073e9SAndroid Build Coastguard Worker        self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])])
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
712*da0073e9SAndroid Build Coastguard Worker    run_tests()
713