#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

import ctypes
import os
import unittest
from typing import Tuple

import torch
from torch.backends._nnapi.prepare import convert_model_to_nnapi
from torch.testing._internal.common_quantized import supported_qengines
from torch.testing._internal.common_utils import run_tests, TestCase


def qpt(t, scale, zero_point, dtype=torch.quint8):
    t = torch.tensor(t)
    return torch.quantize_per_tensor(t, scale, zero_point, dtype)


def nhwc(t):
    t = t.clone().contiguous(memory_format=torch.channels_last)
    t.nnapi_nhwc = True
    return t


@unittest.skipUnless(
    "qnnpack" in supported_qengines,
    "This PyTorch build has not been built with or does not support QNNPACK",
)
class TestNNAPI(TestCase):
    def setUp(self):
        # Avoid saturation in fbgemm
        torch.backends.quantized.engine = "qnnpack"

        libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH")
        if libneuralnetworks_path:
            ctypes.cdll.LoadLibrary(libneuralnetworks_path)
            print("Will attempt to run NNAPI models.")
            self.can_run_nnapi = True
        else:
            self.can_run_nnapi = False

    # Created for easy override by subclasses (e.g., TestNnapiBackend)
    def call_lowering_to_nnapi(self, traced_module, args):
        return convert_model_to_nnapi(traced_module, args)

    # Created for subclasses to set can_run_nnapi (e.g., TestNnapiBackend)
    def set_can_run_nnapi(self, can_run):
        self.can_run_nnapi = can_run

    def check(
        self,
        module,
        arg_or_args,
        *,
        trace_args=None,
        convert_args=None,
        atol_rtol=None,
        limit=None,
        expected_memory_format=None,
    ):
        """Trace ``module``, lower it to NNAPI, and compare against eager mode.

        If the NNAPI runtime is unavailable, only conversion is exercised.
        ``trace_args``/``convert_args`` override the inputs used for tracing
        and conversion, ``atol_rtol`` sets comparison tolerances, and ``limit``
        caps how many quantized elements may mismatch before the check fails.
        """
        with torch.no_grad():
            if isinstance(arg_or_args, torch.Tensor):
                args = [arg_or_args]
            else:
                args = arg_or_args
            module.eval()
            traced = torch.jit.trace(module, trace_args or args)
            nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args)
            if not self.can_run_nnapi:
                # Only test that the model was converted successfully.
                return
            eager_output = module(*args)
            nnapi_output = nnapi_module(*args)
            kwargs = {}
            if atol_rtol is not None:
                kwargs["atol"] = atol_rtol[0]
                kwargs["rtol"] = atol_rtol[1]
            self.assertEqual(eager_output, nnapi_output, **kwargs)
            if limit is not None:
                mismatches = eager_output.int_repr().to(
                    torch.int32
                ) - nnapi_output.int_repr().to(torch.int32)
                if mismatches.count_nonzero() > limit:
                    # Too many mismatches. Re-run the check with no tolerance
                    # to get a nice message.
                    self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
            if expected_memory_format:
                self.assertTrue(
                    nnapi_output.is_contiguous(memory_format=expected_memory_format)
                )

    def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
        torch.manual_seed(29)
        # NOTE: scale and zero_point are accepted but not currently used;
        # the quantized variant uses fixed qparams of (0.03, 128).
        inp_quant = qpt(inp_float, 0.03, 128)
        return [
            ("float", inp_float),
            ("float-nhwc", nhwc(inp_float)),
            ("quant", inp_quant),
            ("quant-nhwc", nhwc(inp_quant)),
        ]

    def test_prelu(self):
        arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        single_a = torch.nn.PReLU()
        self.check(single_a, arg)
        multi_a = torch.nn.PReLU(4)
        with torch.no_grad():
            multi_a.weight.copy_(torch.tensor([0.1, 0.2, 0.3, 0.4]))
        self.check(multi_a, nhwc(arg))

        # Test flexible size
        self.check(
            multi_a,
            arg,
            trace_args=[torch.zeros(1, 4, 3, 3)],
            convert_args=[nhwc(torch.zeros(1, 4, 0, 0))],
        )

    def test_quantize(self):
        self.check(
            torch.ao.nn.quantized.Quantize(0.25, 2, torch.quint8),
            nhwc(torch.tensor([[[[1.0]], [[2.0]]]])),
        )

    def test_dequantize(self):
        self.check(
            torch.ao.nn.quantized.DeQuantize(), nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2))
        )

    def test_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, arg):
                return arg.unsqueeze(self.dim)

        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))

    def test_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self, shape):
                super().__init__()
                self.shape = shape

            def forward(self, arg):
                return arg.reshape(self.shape)

        self.check(ReshapeModule((2, 4)), torch.randn(4, 2, 1, 1))

        self.check(ReshapeModule((8, -1)), nhwc(torch.randn(4, 2, 1, 1)))

        with self.assertRaisesRegex(Exception, "target size"):
            self.check(ReshapeModule((2, 4)), nhwc(torch.randn(4, 2, 1, 1)))

    def test_flatten(self):
        for mod in [
            torch.nn.Flatten(),
            torch.nn.Flatten(start_dim=2, end_dim=3),
            torch.nn.Flatten(start_dim=2, end_dim=4),
            torch.nn.Flatten(start_dim=0, end_dim=-2),
            torch.nn.Flatten(start_dim=0, end_dim=4),
        ]:
            self.check(mod, torch.randn(4, 2, 1, 3, 7))

        # flex inputs
        self.check(
            torch.nn.Flatten(),
            torch.randn(4, 2, 1, 3, 7),
            convert_args=[torch.zeros(0, 2, 1, 3, 7)],
        )

        # channels last
        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 1, 4, 7)))
        self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 3, 1, 1)))

        # Exceptions
        with self.assertRaisesRegex(Exception, "not supported on NHWC"):
            self.check(torch.nn.Flatten(), nhwc(torch.randn(1, 3, 4, 4)))
        with self.assertRaisesRegex(
            Exception, "Flattening flexible dims is not supported yet"
        ):
            self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
        with self.assertRaisesRegex(Exception, "Only 1 dim"):
            self.check(
                torch.nn.Flatten(start_dim=1, end_dim=-2), torch.randn(0, 2, 1, 3, 0)
            )

    def test_slice(self):
        class SliceModule(torch.nn.Module):
            def __init__(self, start, stop, step):
                super().__init__()
                self.start = start
                self.stop = stop
                self.step = step

            def forward(self, t):
                return t[1:, self.start : self.stop : self.step, :]

        class SliceModule2(torch.nn.Module):
            def forward(self, t):
                return t[3:]

        self.check(SliceModule(1, 5, 2), torch.randn(4, 6, 2))
        self.check(SliceModule2(), torch.randn(5))

        # flex inputs
        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2),
            convert_args=[torch.zeros(4, 6, 0)],
        )
        with self.assertRaisesRegex(Exception, "slice with flexible shape"):
            self.check(
                SliceModule(1, 5, 2),
                torch.randn(4, 6, 2),
                convert_args=[torch.zeros(0, 0, 0)],
            )

    def test_cat(self):
        class CatModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, t1, t2):
                return torch.cat([t1, t2], self.dim)

        self.check(
            CatModule(0),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(2, 2, 3, 3),
            ],
        )

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
        )

        self.check(
            CatModule(1),
            [
                nhwc(torch.randn(1, 2, 3, 3)),
                nhwc(torch.randn(1, 4, 3, 3)),
            ],
        )

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
            convert_args=[torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0)],
        )

    def test_pointwise_unary(self):
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):

                class UnaryModule(torch.nn.Module):
                    def forward(self, arg):
                        if op == "relu":
                            return torch.nn.functional.relu(arg)
                        if op == "sigmoid":
                            return torch.sigmoid(arg)
                        raise Exception("Bad op")  # noqa: TRY002

                self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))
                self.check(
                    UnaryModule(),
                    qpt(torch.tensor([-1.0, 1.0]), 1.0 / 256, 0),
                )

    def test_pointwise_binary(self):
        for op in ["add", "sub", "mul", "div"]:
            with self.subTest(op):

                class BinaryModule(torch.nn.Module):
                    def forward(self, lhs, rhs):
                        if op == "add":
                            return lhs + rhs
                        if op == "sub":
                            return lhs - rhs
                        if op == "mul":
                            return lhs * rhs
                        if op == "div":
                            return lhs / rhs
                        raise Exception("Bad op")  # noqa: TRY002

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([1.0, 2.0]),
                        torch.tensor([3.0, 4.0]),
                    ],
                )

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([[1.0, 2.0]]),
                        torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                    ],
                )

                with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"):
                    self.check(
                        BinaryModule(),
                        [
                            torch.tensor([1.0, 2.0]),
                            torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                        ],
                    )

    def test_pointwise_binary_const(self):
        const = torch.randn(1, 4, 6, 6)

        class ArgPlusConst(torch.nn.Module):
            def forward(self, arg):
                return arg + const

        class ConstPlusArg(torch.nn.Module):
            def forward(self, arg):
                return const + arg

        arg_contig = torch.randn(2, 4, 6, 6)
        arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))

        for mod_class in [ArgPlusConst, ConstPlusArg]:
            for use_nhwc in [False, True]:
                with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
                    arg = arg_nhwc if use_nhwc else arg_contig
                    memory_format = (
                        torch.channels_last if use_nhwc else torch.contiguous_format
                    )
                    self.check(mod_class(), arg, expected_memory_format=memory_format)

    def test_hardtanh(self):
        inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
        self.check(torch.nn.Hardtanh(), inp)
        self.check(torch.nn.Hardtanh(0.0, 6.0), inp)
        with self.assertRaisesRegex(Exception, "hardtanh with args"):
            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)

    def test_softmax(self):
        inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]])
        self.check(torch.nn.Softmax(), inp)
        self.check(torch.nn.Softmax(dim=0), inp)
        # Test flexible size
        self.check(
            torch.nn.Softmax(),
            inp,
            convert_args=[torch.zeros(0, 0)],
        )

    def test_to(self):
        class ToCPU(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                y = x.to("cpu")
                # add prelu since input operand can't be output
                return self.prelu(y)

        arg = torch.randn(1, 2, 3, 3)
        self.check(ToCPU(), arg)
        # Test flexible size
        self.check(
            ToCPU(),
            arg,
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_detach(self):
        class DetachModule(torch.nn.Module):
            def forward(self, x):
                y = x.detach()
                return torch.nn.functional.relu(y)

        self.check(DetachModule(), torch.randn(1, 2, 3, 3))
        self.check(
            DetachModule(),
            torch.randn(1, 2, 3, 3),
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_log_softmax(self):
        inp = torch.randn(3, 10)
        self.check(torch.nn.LogSoftmax(), inp)
        self.check(torch.nn.LogSoftmax(0), inp)

    def test_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dim, keep=False):
                super().__init__()
                self.dim = dim
                self.keep = keep

            def forward(self, t):
                return torch.mean(t, dim=self.dim, keepdim=self.keep)

        self.check(MeanModule(0), torch.randn(2, 3))
        self.check(MeanModule(1), torch.randn(2, 3))
        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))

    def test_max_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.MaxPool2d(2), inp)
                self.check(torch.nn.MaxPool2d((3, 4)), inp)
                self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)

    def test_avg_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                atol_rtol = None
                limit = None
                convert_dims = (2, 3, 0, 0)
                convert_arg = torch.zeros(*convert_dims)

                for model in (
                    torch.nn.AvgPool2d(2),
                    torch.nn.AvgPool2d((3, 4)),
                    torch.nn.AvgPool2d((3, 4), (1, 2)),
                ):
                    if "quant" in name:
                        atol_rtol = (1, 0)
                        limit = model(inp).numel()
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
                    if "nhwc" in name:
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_adaptive_avg_pool2d(self):
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp)
                with self.assertRaisesRegex(Exception, "with output size"):
                    self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp)

    def test_upsample_nearest2d(self):
        convert_args = dict(
            self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.3, 128)
        )
        for name, inp in self.float_and_quant_and_nhwc(
            torch.randn(2, 3, 12, 16), 0.3, 128
        ):
            with self.subTest(name):
                self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp)

                self.check(
                    torch.nn.UpsamplingNearest2d(size=(24, 32)),
                    inp,
                    convert_args=[convert_args[name]],
                )
                self.check(
                    torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)),
                    inp,
                    convert_args=[convert_args[name]],
                )

    def test_linear(self):
        torch.manual_seed(29)
        self.check(torch.nn.Linear(16, 32), torch.randn(2, 16))
        self.check(
            torch.nn.Linear(16, 32),
            torch.randn(2, 16),
            convert_args=[torch.zeros(0, 16)],
        )

    def test_conv2d(self):
        cases = [
            # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name
            ( 4,  8, (3, 3), 1, 0, 1, 1, (2, 4, 16, 16), "3x3"),        # noqa: E201,E241
            ( 4,  8, (3, 3), 1, 0, 1, 0, (2, 4, 16, 16), "3x3nobias"),  # noqa: E201,E241
            ( 4, 16, (3, 3), 1, 1, 1, 1, (2, 4, 16, 16), "3x3p1"),      # noqa: E201,E241
            ( 8,  8, (3, 3), 2, 0, 1, 1, (2, 8, 16, 16), "3x3s2"),      # noqa: E201,E241
            ( 4,  8, (5, 5), 1, 0, 1, 1, (2, 4, 16, 16), "5x5"),        # noqa: E201,E241
            ( 4,  4, (3, 3), 1, 0, 4, 1, (2, 4, 16, 16), "3x3dw"),      # noqa: E201,E241
            ( 8,  4, (1, 1), 1, 0, 1, 1, (2, 8, 16, 16), "1x1"),        # noqa: E201,E241
        ]

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            for case in cases:
                (
                    in_ch,
                    out_ch,
                    kernel,
                    stride,
                    padding,
                    groups,
                    bias,
                    input_dim,
                    name,
                ) = case
                with self.subTest(f"{kind}-{name}"):
                    inp = torch.randn(input_dim)
                    model = torch.nn.Conv2d(
                        in_ch,
                        out_ch,
                        kernel,
                        stride,
                        padding,
                        groups=groups,
                        bias=bool(bias),
                    )
                    output_size = model(inp).numel()
                    atol_rtol = None
                    limit = None
                    convert_dims = (0, in_ch, 0, 0)
                    convert_arg = torch.zeros(*convert_dims)

                    if "quant" in kind:
                        model = torch.nn.Sequential(model)
                        model.eval()
                        model.qconfig = torch.ao.quantization.get_default_qconfig(
                            "qnnpack"
                        )
                        model = torch.ao.quantization.prepare(model)
                        model(inp)
                        model = torch.ao.quantization.convert(model)
                        inp = qpt(inp, 1.0 / 16, 128)
                        # I've seen numerical differences between QNNPACK and NNAPI,
                        # but never more than 1 quantum, and never more than ~1% of
                        # the output in this test.
                        atol_rtol = (1, 0)
                        limit = output_size * 0.03
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)

                    if "nhwc" in kind:
                        inp = nhwc(inp)
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit,
                    )

    def test_conv2d_transpose(self):
        torch.manual_seed(29)
        in_ch, out_ch, kernel = (5, 7, (2, 2))
        input_dim = (4, 5, 3, 3)
        convert_dims = input_dim[:2] + (0, 0)

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            with self.subTest(kind):
                inp = torch.randn(input_dim)
                model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel)
                output_size = model(inp).numel()
                atol_rtol = (0.0002, 0)
                limit = None
                convert_arg = torch.zeros(*convert_dims)

                if "quant" in kind:
                    model = torch.ao.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel)
                    model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
                    inp = qpt(inp, 1.0 / 16, 128)
                    # I've seen numerical differences between QNNPACK and NNAPI,
                    # but never more than 1 quantum, and never more than ~10% of
                    # the output in this test.
                    atol_rtol = (1, 0)
                    limit = output_size * 0.1
                    convert_arg = qpt(convert_arg, 1.0 / 16, 128)

                if "nhwc" in kind:
                    inp = nhwc(inp)
                    convert_arg = nhwc(convert_arg)

                self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                self.check(
                    model,
                    inp,
                    convert_args=[convert_arg],
                    atol_rtol=atol_rtol,
                    limit=limit,
                )

    def test_qadd(self):
        func = torch.ao.nn.quantized.QFunctional()
        func.scale = 0.5
        func.zero_point = 120

        class AddMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add(lhs, rhs)

        class AddReluMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add_relu(lhs, rhs)

        class MulMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.mul(lhs, rhs)

        for name, mod in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]:
            with self.subTest(name):
                self.check(
                    mod(),
                    [
                        qpt([1.0, 2.0], 0.25, 128),
                        qpt([3.0, 4.0], 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ],
                )
        # NOTE: NNAPI qadd supports broadcast, but PT does not.
    def test_qlinear(self):
        torch.manual_seed(29)
        weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8)
        bias = torch.randn(16)
        mod = torch.ao.nn.quantized.Linear(32, 16)
        mod.set_weight_bias(weight, bias)
        inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
        self.check(mod, inp)

    def test_seblock_mul(self):
        class MulModel(torch.nn.Module):
            def forward(self, lhs, rhs):
                return lhs * rhs

        self.check(
            MulModel(),
            [
                nhwc(torch.randn(2, 3, 4, 4)),
                torch.randn(1, 3, 1, 1),
            ],
        )

    def test_multi_output(self):
        class MultiModel(torch.nn.Module):
            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
                the_sum = lhs + rhs
                the_diff = lhs - rhs
                return the_sum, the_diff

        self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])])


if __name__ == "__main__":
    run_tests()
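# ---------------------------------------------------------------------------
# Illustrative sketch only (not invoked by the tests above): the minimal
# trace-and-lower flow that TestNNAPI.check() wraps. The module, input shape,
# and helper name below are arbitrary examples, not part of the test suite.
# Conversion alone does not require an NNAPI driver; to actually execute a
# lowered model, point LIBNEURALNETWORKS_PATH at a libneuralnetworks.so before
# running this file (see TestNNAPI.setUp).
# ---------------------------------------------------------------------------
def _example_lower_to_nnapi():
    model = torch.nn.PReLU().eval()
    example_input = torch.randn(1, 4, 1, 1)
    # Trace with example inputs, then lower the traced module to NNAPI.
    traced = torch.jit.trace(model, example_input)
    return convert_model_to_nnapi(traced, [example_input])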