#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

import ctypes
import os
import unittest
from typing import Tuple

import torch
from torch.backends._nnapi.prepare import convert_model_to_nnapi
from torch.testing._internal.common_quantized import supported_qengines
from torch.testing._internal.common_utils import run_tests, TestCase


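# Helper: build a per-tensor quantized tensor from a list or tensor,
# using the given scale, zero point, and quantized dtype.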
def qpt(t, scale, zero_point, dtype=torch.quint8):
    t = torch.tensor(t)
    return torch.quantize_per_tensor(t, scale, zero_point, dtype)


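# Helper: put a tensor in channels-last memory format and tag it with
# nnapi_nhwc so the NNAPI converter treats it as an NHWC operand.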
def nhwc(t):
    t = t.clone().contiguous(memory_format=torch.channels_last)
    t.nnapi_nhwc = True
    return t


@unittest.skipUnless(
    "qnnpack" in supported_qengines,
    "This PyTorch build was not built with, or does not support, QNNPACK",
)
class TestNNAPI(TestCase):
    def setUp(self):
        # Use qnnpack to avoid saturation in fbgemm
        torch.backends.quantized.engine = "qnnpack"

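        # If LIBNEURALNETWORKS_PATH points at an NNAPI support library, load it
        # and run the converted models end to end; otherwise only conversion
        # is exercised (see the early return in check()).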
        libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH")
        if libneuralnetworks_path:
            ctypes.cdll.LoadLibrary(libneuralnetworks_path)
            print("Will attempt to run NNAPI models.")
            self.can_run_nnapi = True
        else:
            self.can_run_nnapi = False

    # Created for easy override by subclasses (e.g., TestNnapiBackend)
    def call_lowering_to_nnapi(self, traced_module, args):
        return convert_model_to_nnapi(traced_module, args)

    # Created for subclasses to set can_run_nnapi (e.g., TestNnapiBackend)
    def set_can_run_nnapi(self, can_run):
        self.can_run_nnapi = can_run

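    # Trace the module, lower it to NNAPI, and (when an NNAPI library is
    # available) compare eager and NNAPI outputs.  atol_rtol sets the
    # comparison tolerances, limit caps the number of off-by-one quantized
    # mismatches, and expected_memory_format asserts the output layout.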
    def check(
        self,
        module,
        arg_or_args,
        *,
        trace_args=None,
        convert_args=None,
        atol_rtol=None,
        limit=None,
        expected_memory_format=None,
    ):
        with torch.no_grad():
            if isinstance(arg_or_args, torch.Tensor):
                args = [arg_or_args]
            else:
                args = arg_or_args
            module.eval()
            traced = torch.jit.trace(module, trace_args or args)
            nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args)
            if not self.can_run_nnapi:
                # Only test that the model was converted successfully.
                return
            eager_output = module(*args)
            nnapi_output = nnapi_module(*args)
            kwargs = {}
            if atol_rtol is not None:
                kwargs["atol"] = atol_rtol[0]
                kwargs["rtol"] = atol_rtol[1]
            self.assertEqual(eager_output, nnapi_output, **kwargs)
            if limit is not None:
                mismatches = eager_output.int_repr().to(
                    torch.int32
                ) - nnapi_output.int_repr().to(torch.int32)
                if mismatches.count_nonzero() > limit:
                    # Too many mismatches.  Re-run the check with no tolerance
                    # to get a nice message.
                    self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
            if expected_memory_format:
                self.assertTrue(
                    nnapi_output.is_contiguous(memory_format=expected_memory_format)
                )

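    # Return (name, input) pairs covering float, quantized, and NHWC variants
    # of inp_float.  Note: the quantized variant currently uses a fixed scale
    # of 0.03 and zero point of 128; the scale/zero_point arguments are unused.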
    def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
        torch.manual_seed(29)
        inp_quant = qpt(inp_float, 0.03, 128)
        return [
            ("float", inp_float),
            ("float-nhwc", nhwc(inp_float)),
            ("quant", inp_quant),
            ("quant-nhwc", nhwc(inp_quant)),
        ]

    def test_prelu(self):
        arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        single_a = torch.nn.PReLU()
        self.check(single_a, arg)
        multi_a = torch.nn.PReLU(4)
        with torch.no_grad():
            multi_a.weight.copy_(torch.tensor([0.1, 0.2, 0.3, 0.4]))
        self.check(multi_a, nhwc(arg))

        # Test flexible size
        self.check(
            multi_a,
            arg,
            trace_args=[torch.zeros(1, 4, 3, 3)],
            convert_args=[nhwc(torch.zeros(1, 4, 0, 0))],
        )

    def test_quantize(self):
        self.check(
            torch.ao.nn.quantized.Quantize(0.25, 2, torch.quint8),
            nhwc(torch.tensor([[[[1.0]], [[2.0]]]])),
        )

    def test_dequantize(self):
        self.check(
            torch.ao.nn.quantized.DeQuantize(), nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2))
        )

    def test_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, arg):
                return arg.unsqueeze(self.dim)

        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))

    def test_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self, shape):
                super().__init__()
                self.shape = shape

            def forward(self, arg):
                return arg.reshape(self.shape)

        self.check(ReshapeModule((2, 4)), torch.randn(4, 2, 1, 1))

        self.check(ReshapeModule((8, -1)), nhwc(torch.randn(4, 2, 1, 1)))

        with self.assertRaisesRegex(Exception, "target size"):
            self.check(ReshapeModule((2, 4)), nhwc(torch.randn(4, 2, 1, 1)))

    def test_flatten(self):
        for mod in [
            torch.nn.Flatten(),
            torch.nn.Flatten(start_dim=2, end_dim=3),
            torch.nn.Flatten(start_dim=2, end_dim=4),
            torch.nn.Flatten(start_dim=0, end_dim=-2),
            torch.nn.Flatten(start_dim=0, end_dim=4),
        ]:
            self.check(mod, torch.randn(4, 2, 1, 3, 7))

        # flex inputs
        self.check(
            torch.nn.Flatten(),
            torch.randn(4, 2, 1, 3, 7),
            convert_args=[torch.zeros(0, 2, 1, 3, 7)],
        )

        # channels last
        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 1, 4, 7)))
        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 3, 1, 1)))

        # Exceptions
        with self.assertRaisesRegex(Exception, "not supported on NHWC"):
            self.check(torch.nn.Flatten(), nhwc(torch.randn(1, 3, 4, 4)))
        with self.assertRaisesRegex(
            Exception, "Flattening flexible dims is not supported yet"
        ):
            self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
        with self.assertRaisesRegex(Exception, "Only 1 dim"):
            self.check(
                torch.nn.Flatten(start_dim=1, end_dim=-2), torch.randn(0, 2, 1, 3, 0)
            )

    def test_slice(self):
        class SliceModule(torch.nn.Module):
            def __init__(self, start, stop, step):
                super().__init__()
                self.start = start
                self.stop = stop
                self.step = step

            def forward(self, t):
                return t[1:, self.start : self.stop : self.step, :]

        class SliceModule2(torch.nn.Module):
            def forward(self, t):
                return t[3:]

        self.check(SliceModule(1, 5, 2), torch.randn(4, 6, 2))
        self.check(SliceModule2(), torch.randn(5))

        # flex inputs
        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2),
            convert_args=[torch.zeros(4, 6, 0)],
        )
        with self.assertRaisesRegex(Exception, "slice with flexible shape"):
            self.check(
                SliceModule(1, 5, 2),
                torch.randn(4, 6, 2),
                convert_args=[torch.zeros(0, 0, 0)],
            )

    def test_cat(self):
        class CatModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, t1, t2):
                return torch.cat([t1, t2], self.dim)

        self.check(
            CatModule(0),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(2, 2, 3, 3),
            ],
        )

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
        )

        self.check(
            CatModule(1),
            [
                nhwc(torch.randn(1, 2, 3, 3)),
                nhwc(torch.randn(1, 4, 3, 3)),
            ],
        )

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
            convert_args=[torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0)],
        )

    def test_pointwise_unary(self):
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):

                class UnaryModule(torch.nn.Module):
                    def forward(self, arg):
                        if op == "relu":
                            return torch.nn.functional.relu(arg)
                        if op == "sigmoid":
                            return torch.sigmoid(arg)
                        raise Exception("Bad op")  # noqa: TRY002

                self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))
                self.check(
                    UnaryModule(),
                    qpt(torch.tensor([-1.0, 1.0]), 1.0 / 256, 0),
                )

    def test_pointwise_binary(self):
        for op in ["add", "sub", "mul", "div"]:
            with self.subTest(op):

                class BinaryModule(torch.nn.Module):
                    def forward(self, lhs, rhs):
                        if op == "add":
                            return lhs + rhs
                        if op == "sub":
                            return lhs - rhs
                        if op == "mul":
                            return lhs * rhs
                        if op == "div":
                            return lhs / rhs
                        raise Exception("Bad op")  # noqa: TRY002

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([1.0, 2.0]),
                        torch.tensor([3.0, 4.0]),
                    ],
                )

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([[1.0, 2.0]]),
                        torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                    ],
                )

                with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"):
                    self.check(
                        BinaryModule(),
                        [
                            torch.tensor([1.0, 2.0]),
                            torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                        ],
                    )

    def test_pointwise_binary_const(self):
        const = torch.randn(1, 4, 6, 6)

        class ArgPlusConst(torch.nn.Module):
            def forward(self, arg):
                return arg + const

        class ConstPlusArg(torch.nn.Module):
            def forward(self, arg):
                return const + arg

        arg_contig = torch.randn(2, 4, 6, 6)
        arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))

        for mod_class in [ArgPlusConst, ConstPlusArg]:
            for use_nhwc in [False, True]:
                with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
                    arg = arg_nhwc if use_nhwc else arg_contig
                    memory_format = (
                        torch.channels_last if use_nhwc else torch.contiguous_format
                    )
                    self.check(mod_class(), arg, expected_memory_format=memory_format)

    def test_hardtanh(self):
        inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
        self.check(torch.nn.Hardtanh(), inp)
        self.check(torch.nn.Hardtanh(0.0, 6.0), inp)
        with self.assertRaisesRegex(Exception, "hardtanh with args"):
            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)

    def test_softmax(self):
        inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]])
        self.check(torch.nn.Softmax(), inp)
        self.check(torch.nn.Softmax(dim=0), inp)
        # Test flexible size
        self.check(
            torch.nn.Softmax(),
            inp,
            convert_args=[torch.zeros(0, 0)],
        )

    def test_to(self):
        class ToCPU(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                y = x.to("cpu")
                # add prelu since input operand can't be output
                return self.prelu(y)

        arg = torch.randn(1, 2, 3, 3)
        self.check(ToCPU(), arg)
        # Test flexible size
        self.check(
            ToCPU(),
            arg,
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_detach(self):
        class DetachModule(torch.nn.Module):
            def forward(self, x):
                y = x.detach()
                return torch.nn.functional.relu(y)

        self.check(DetachModule(), torch.randn(1, 2, 3, 3))
        self.check(
            DetachModule(),
            torch.randn(1, 2, 3, 3),
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_log_softmax(self):
        inp = torch.randn(3, 10)
        self.check(torch.nn.LogSoftmax(), inp)
        self.check(torch.nn.LogSoftmax(0), inp)

    def test_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dim, keep=False):
                super().__init__()
                self.dim = dim
                self.keep = keep

            def forward(self, t):
                return torch.mean(t, dim=self.dim, keepdim=self.keep)

        self.check(MeanModule(0), torch.randn(2, 3))
        self.check(MeanModule(1), torch.randn(2, 3))
        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))

    def test_max_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.MaxPool2d(2), inp)
                self.check(torch.nn.MaxPool2d((3, 4)), inp)
                self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)

    def test_avg_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                atol_rtol = None
                limit = None
                convert_dims = (2, 3, 0, 0)
                convert_arg = torch.zeros(*convert_dims)

                for model in (
                    torch.nn.AvgPool2d(2),
                    torch.nn.AvgPool2d((3, 4)),
                    torch.nn.AvgPool2d((3, 4), (1, 2)),
                ):
                    if "quant" in name:
                        atol_rtol = (1, 0)
                        limit = model(inp).numel()
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
                    if "nhwc" in name:
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_adaptive_avg_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp)
                with self.assertRaisesRegex(Exception, "with output size"):
                    self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp)

    def test_upsample_nearest2d(self):
        convert_args = dict(
            self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.3, 128)
        )
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp)

                self.check(
                    torch.nn.UpsamplingNearest2d(size=(24, 32)),
                    inp,
                    convert_args=[convert_args[name]],
                )
                self.check(
                    torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)),
                    inp,
                    convert_args=[convert_args[name]],
                )

    def test_linear(self):
        torch.manual_seed(29)
        self.check(torch.nn.Linear(16, 32), torch.randn(2, 16))
        self.check(
            torch.nn.Linear(16, 32),
            torch.randn(2, 16),
            convert_args=[torch.zeros(0, 16)],
        )

    def test_conv2d(self):
        cases = [
            # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim,      name
            (4, 8, (3, 3), 1, 0, 1, 1, (2, 4, 16, 16), "3x3"),  # noqa: E201,E241
            (4, 8, (3, 3), 1, 0, 1, 0, (2, 4, 16, 16), "3x3nobias"),  # noqa: E201,E241
            (4, 16, (3, 3), 1, 1, 1, 1, (2, 4, 16, 16), "3x3p1"),  # noqa: E201,E241
            (8, 8, (3, 3), 2, 0, 1, 1, (2, 8, 16, 16), "3x3s2"),  # noqa: E201,E241
            (4, 8, (5, 5), 1, 0, 1, 1, (2, 4, 16, 16), "5x5"),  # noqa: E201,E241
            (4, 4, (3, 3), 1, 0, 4, 1, (2, 4, 16, 16), "3x3dw"),  # noqa: E201,E241
            (8, 4, (1, 1), 1, 0, 1, 1, (2, 8, 16, 16), "1x1"),  # noqa: E201,E241
        ]

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            for case in cases:
                (
                    in_ch,
                    out_ch,
                    kernel,
                    stride,
                    padding,
                    groups,
                    bias,
                    input_dim,
                    name,
                ) = case
                with self.subTest(f"{kind}-{name}"):
                    inp = torch.randn(input_dim)
                    model = torch.nn.Conv2d(
                        in_ch,
                        out_ch,
                        kernel,
                        stride,
                        padding,
                        groups=groups,
                        bias=bool(bias),
                    )
                    output_size = model(inp).numel()
                    atol_rtol = None
                    limit = None
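                    # Zero dims in convert_args request flexible (runtime-determined)
                    # batch/spatial sizes, as in the flexible-size tests above; the
                    # channel dim stays fixed at in_ch.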
                    convert_dims = (0, in_ch, 0, 0)
                    convert_arg = torch.zeros(*convert_dims)

                    if "quant" in kind:
                        model = torch.nn.Sequential(model)
                        model.eval()
                        model.qconfig = torch.ao.quantization.get_default_qconfig(
                            "qnnpack"
                        )
                        model = torch.ao.quantization.prepare(model)
                        model(inp)
                        model = torch.ao.quantization.convert(model)
                        inp = qpt(inp, 1.0 / 16, 128)
                        # I've seen numerical differences between QNNPACK and NNAPI,
                        # but never more than 1 quantum, and never more than ~1% of
                        # the output in this test.
                        atol_rtol = (1, 0)
                        limit = output_size * 0.03
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)

                    if "nhwc" in kind:
                        inp = nhwc(inp)
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_conv2d_transpose(self):
        torch.manual_seed(29)
        in_ch, out_ch, kernel = (5, 7, (2, 2))
        input_dim = (4, 5, 3, 3)
        convert_dims = input_dim[:2] + (0, 0)

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            with self.subTest(kind):
                inp = torch.randn(input_dim)
                model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel)
                output_size = model(inp).numel()
                atol_rtol = (0.0002, 0)
                limit = None
                convert_arg = torch.zeros(*convert_dims)

                if "quant" in kind:
                    model = torch.ao.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel)
                    model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
                    inp = qpt(inp, 1.0 / 16, 128)
                    # I've seen numerical differences between QNNPACK and NNAPI,
                    # but never more than 1 quantum, and never more than ~10% of
                    # the output in this test.
                    atol_rtol = (1, 0)
                    limit = output_size * 0.1
                    convert_arg = qpt(convert_arg, 1.0 / 16, 128)

                if "nhwc" in kind:
                    inp = nhwc(inp)
                    convert_arg = nhwc(convert_arg)

                self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                self.check(
                    model,
                    inp,
                    convert_args=[convert_arg],
                    atol_rtol=atol_rtol,
                    limit=limit,
                )

    def test_qadd(self):
        func = torch.ao.nn.quantized.QFunctional()
        func.scale = 0.5
        func.zero_point = 120

        class AddMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add(lhs, rhs)

        class AddReluMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add_relu(lhs, rhs)

        class MulMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.mul(lhs, rhs)

        for name, mod in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]:
            with self.subTest(name):
                self.check(
                    mod(),
                    [
                        qpt([1.0, 2.0], 0.25, 128),
                        qpt([3.0, 4.0], 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ],
                )
                # NOTE: NNAPI qadd supports broadcast, but PT does not.

    def test_qlinear(self):
        torch.manual_seed(29)
        weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8)
        bias = torch.randn(16)
        mod = torch.ao.nn.quantized.Linear(32, 16)
        mod.set_weight_bias(weight, bias)
        inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
        self.check(mod, inp)

    def test_seblock_mul(self):
        class MulModel(torch.nn.Module):
            def forward(self, lhs, rhs):
                return lhs * rhs

        self.check(
            MulModel(),
            [
                nhwc(torch.randn(2, 3, 4, 4)),
                torch.randn(1, 3, 1, 1),
            ],
        )

    def test_multi_output(self):
        class MultiModel(torch.nn.Module):
            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
                the_sum = lhs + rhs
                the_diff = lhs - rhs
                return the_sum, the_diff

        self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])])


if __name__ == "__main__":
    run_tests()