1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"] 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport ctypes 5*da0073e9SAndroid Build Coastguard Workerimport os 6*da0073e9SAndroid Build Coastguard Workerimport unittest 7*da0073e9SAndroid Build Coastguard Workerfrom typing import Tuple 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch 10*da0073e9SAndroid Build Coastguard Workerfrom torch.backends._nnapi.prepare import convert_model_to_nnapi 11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_quantized import supported_qengines 12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Workerdef qpt(t, scale, zero_point, dtype=torch.quint8): 16*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(t) 17*da0073e9SAndroid Build Coastguard Worker return torch.quantize_per_tensor(t, scale, zero_point, dtype) 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Workerdef nhwc(t): 21*da0073e9SAndroid Build Coastguard Worker t = t.clone().contiguous(memory_format=torch.channels_last) 22*da0073e9SAndroid Build Coastguard Worker t.nnapi_nhwc = True 23*da0073e9SAndroid Build Coastguard Worker return t 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker@unittest.skipUnless( 27*da0073e9SAndroid Build Coastguard Worker "qnnpack" in supported_qengines, 28*da0073e9SAndroid Build Coastguard Worker "This Pytorch Build has not been built with or does not support QNNPACK", 29*da0073e9SAndroid Build Coastguard Worker) 30*da0073e9SAndroid Build Coastguard Workerclass TestNNAPI(TestCase): 31*da0073e9SAndroid Build Coastguard Worker def setUp(self): 32*da0073e9SAndroid Build Coastguard Worker # Avoid saturation in fbgemm 33*da0073e9SAndroid Build Coastguard Worker torch.backends.quantized.engine = "qnnpack" 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH") 36*da0073e9SAndroid Build Coastguard Worker if libneuralnetworks_path: 37*da0073e9SAndroid Build Coastguard Worker ctypes.cdll.LoadLibrary(libneuralnetworks_path) 38*da0073e9SAndroid Build Coastguard Worker print("Will attempt to run NNAPI models.") 39*da0073e9SAndroid Build Coastguard Worker self.can_run_nnapi = True 40*da0073e9SAndroid Build Coastguard Worker else: 41*da0073e9SAndroid Build Coastguard Worker self.can_run_nnapi = False 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker # Created for easy override by subclasses (eg TestNnapiBackend) 44*da0073e9SAndroid Build Coastguard Worker def call_lowering_to_nnapi(self, traced_module, args): 45*da0073e9SAndroid Build Coastguard Worker return convert_model_to_nnapi(traced_module, args) 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker # Created for subclasses to set can_run_nnapi (eg TestNnapiBackend) 48*da0073e9SAndroid Build Coastguard Worker def set_can_run_nnapi(self, can_run): 49*da0073e9SAndroid Build Coastguard Worker self.can_run_nnapi = can_run 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker def check( 52*da0073e9SAndroid Build Coastguard Worker self, 53*da0073e9SAndroid Build Coastguard Worker module, 54*da0073e9SAndroid Build Coastguard Worker arg_or_args, 55*da0073e9SAndroid Build Coastguard Worker *, 56*da0073e9SAndroid Build Coastguard Worker trace_args=None, 57*da0073e9SAndroid Build Coastguard Worker convert_args=None, 58*da0073e9SAndroid Build Coastguard Worker atol_rtol=None, 59*da0073e9SAndroid Build Coastguard Worker limit=None, 60*da0073e9SAndroid Build Coastguard Worker expected_memory_format=None, 61*da0073e9SAndroid Build Coastguard Worker ): 62*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 63*da0073e9SAndroid Build Coastguard Worker if isinstance(arg_or_args, torch.Tensor): 64*da0073e9SAndroid Build Coastguard Worker args = [arg_or_args] 65*da0073e9SAndroid Build Coastguard Worker else: 66*da0073e9SAndroid Build Coastguard Worker args = arg_or_args 67*da0073e9SAndroid Build Coastguard Worker module.eval() 68*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(module, trace_args or args) 69*da0073e9SAndroid Build Coastguard Worker nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args) 70*da0073e9SAndroid Build Coastguard Worker if not self.can_run_nnapi: 71*da0073e9SAndroid Build Coastguard Worker # Only test that the model was converted successfully. 72*da0073e9SAndroid Build Coastguard Worker return 73*da0073e9SAndroid Build Coastguard Worker eager_output = module(*args) 74*da0073e9SAndroid Build Coastguard Worker nnapi_output = nnapi_module(*args) 75*da0073e9SAndroid Build Coastguard Worker kwargs = {} 76*da0073e9SAndroid Build Coastguard Worker if atol_rtol is not None: 77*da0073e9SAndroid Build Coastguard Worker kwargs["atol"] = atol_rtol[0] 78*da0073e9SAndroid Build Coastguard Worker kwargs["rtol"] = atol_rtol[1] 79*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_output, nnapi_output, **kwargs) 80*da0073e9SAndroid Build Coastguard Worker if limit is not None: 81*da0073e9SAndroid Build Coastguard Worker mismatches = eager_output.int_repr().to( 82*da0073e9SAndroid Build Coastguard Worker torch.int32 83*da0073e9SAndroid Build Coastguard Worker ) - nnapi_output.int_repr().to(torch.int32) 84*da0073e9SAndroid Build Coastguard Worker if mismatches.count_nonzero() > limit: 85*da0073e9SAndroid Build Coastguard Worker # Too many mismatches. Re-run the check with no tolerance 86*da0073e9SAndroid Build Coastguard Worker # to get a nice message. 87*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0) 88*da0073e9SAndroid Build Coastguard Worker if expected_memory_format: 89*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 90*da0073e9SAndroid Build Coastguard Worker nnapi_output.is_contiguous(memory_format=expected_memory_format) 91*da0073e9SAndroid Build Coastguard Worker ) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker def float_and_quant_and_nhwc(self, inp_float, scale, zero_point): 94*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(29) 95*da0073e9SAndroid Build Coastguard Worker inp_quant = qpt(inp_float, 0.03, 128) 96*da0073e9SAndroid Build Coastguard Worker return [ 97*da0073e9SAndroid Build Coastguard Worker ("float", inp_float), 98*da0073e9SAndroid Build Coastguard Worker ("float-nhwc", nhwc(inp_float)), 99*da0073e9SAndroid Build Coastguard Worker ("quant", inp_quant), 100*da0073e9SAndroid Build Coastguard Worker ("quant-nhwc", nhwc(inp_quant)), 101*da0073e9SAndroid Build Coastguard Worker ] 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker def test_prelu(self): 104*da0073e9SAndroid Build Coastguard Worker arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1) 105*da0073e9SAndroid Build Coastguard Worker single_a = torch.nn.PReLU() 106*da0073e9SAndroid Build Coastguard Worker self.check(single_a, arg) 107*da0073e9SAndroid Build Coastguard Worker multi_a = torch.nn.PReLU(4) 108*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 109*da0073e9SAndroid Build Coastguard Worker multi_a.weight.copy_(torch.tensor([0.1, 0.2, 0.3, 0.4])) 110*da0073e9SAndroid Build Coastguard Worker self.check(multi_a, nhwc(arg)) 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker # Test flexible size 113*da0073e9SAndroid Build Coastguard Worker self.check( 114*da0073e9SAndroid Build Coastguard Worker multi_a, 115*da0073e9SAndroid Build Coastguard Worker arg, 116*da0073e9SAndroid Build Coastguard Worker trace_args=[torch.zeros(1, 4, 3, 3)], 117*da0073e9SAndroid Build Coastguard Worker convert_args=[nhwc(torch.zeros(1, 4, 0, 0))], 118*da0073e9SAndroid Build Coastguard Worker ) 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker def test_quantize(self): 121*da0073e9SAndroid Build Coastguard Worker self.check( 122*da0073e9SAndroid Build Coastguard Worker torch.ao.nn.quantized.Quantize(0.25, 2, torch.quint8), 123*da0073e9SAndroid Build Coastguard Worker nhwc(torch.tensor([[[[1.0]], [[2.0]]]])), 124*da0073e9SAndroid Build Coastguard Worker ) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker def test_dequantize(self): 127*da0073e9SAndroid Build Coastguard Worker self.check( 128*da0073e9SAndroid Build Coastguard Worker torch.ao.nn.quantized.DeQuantize(), nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2)) 129*da0073e9SAndroid Build Coastguard Worker ) 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker def test_unsqueeze(self): 132*da0073e9SAndroid Build Coastguard Worker class UnsqueezeModule(torch.nn.Module): 133*da0073e9SAndroid Build Coastguard Worker def __init__(self, dim): 134*da0073e9SAndroid Build Coastguard Worker super().__init__() 135*da0073e9SAndroid Build Coastguard Worker self.dim = dim 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 138*da0073e9SAndroid Build Coastguard Worker return arg.unsqueeze(self.dim) 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2)) 141*da0073e9SAndroid Build Coastguard Worker self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2)) 142*da0073e9SAndroid Build Coastguard Worker self.check(UnsqueezeModule(0), torch.randn(4, 2, 2)) 143*da0073e9SAndroid Build Coastguard Worker self.check(UnsqueezeModule(1), torch.randn(4, 2, 2)) 144*da0073e9SAndroid Build Coastguard Worker self.check(UnsqueezeModule(2), torch.randn(4, 2, 2)) 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker def test_reshape(self): 147*da0073e9SAndroid Build Coastguard Worker class ReshapeModule(torch.nn.Module): 148*da0073e9SAndroid Build Coastguard Worker def __init__(self, shape): 149*da0073e9SAndroid Build Coastguard Worker super().__init__() 150*da0073e9SAndroid Build Coastguard Worker self.shape = shape 151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 153*da0073e9SAndroid Build Coastguard Worker return arg.reshape(self.shape) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker self.check(ReshapeModule((2, 4)), torch.randn(4, 2, 1, 1)) 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker self.check(ReshapeModule((8, -1)), nhwc(torch.randn(4, 2, 1, 1))) 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "target size"): 160*da0073e9SAndroid Build Coastguard Worker self.check(ReshapeModule((2, 4)), nhwc(torch.randn(4, 2, 1, 1))) 161*da0073e9SAndroid Build Coastguard Worker 162*da0073e9SAndroid Build Coastguard Worker def test_flatten(self): 163*da0073e9SAndroid Build Coastguard Worker for mod in [ 164*da0073e9SAndroid Build Coastguard Worker torch.nn.Flatten(), 165*da0073e9SAndroid Build Coastguard Worker torch.nn.Flatten(start_dim=2, end_dim=3), 166*da0073e9SAndroid Build Coastguard Worker torch.nn.Flatten(start_dim=2, end_dim=4), 167*da0073e9SAndroid Build Coastguard Worker torch.nn.Flatten(start_dim=0, end_dim=-2), 168*da0073e9SAndroid Build Coastguard Worker torch.nn.Flatten(start_dim=0, end_dim=4), 169*da0073e9SAndroid Build Coastguard Worker ]: 170*da0073e9SAndroid Build Coastguard Worker self.check(mod, torch.randn(4, 2, 1, 3, 7)) 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker # flex inputs 173*da0073e9SAndroid Build Coastguard Worker self.check( 174*da0073e9SAndroid Build Coastguard Worker torch.nn.Flatten(), 175*da0073e9SAndroid Build Coastguard Worker torch.randn(4, 2, 1, 3, 7), 176*da0073e9SAndroid Build Coastguard Worker convert_args=[torch.zeros(0, 2, 1, 3, 7)], 177*da0073e9SAndroid Build Coastguard Worker ) 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker # channels last 180*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 1, 4, 7))) 181*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.Flatten(), nhwc(torch.randn(2, 3, 1, 1))) 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker # Exceptions 184*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "not supported on NHWC"): 185*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.Flatten(), nhwc(torch.randn(1, 3, 4, 4))) 186*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 187*da0073e9SAndroid Build Coastguard Worker Exception, "Flattening flexible dims is not supported yet" 188*da0073e9SAndroid Build Coastguard Worker ): 189*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7)) 190*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "Only 1 dim"): 191*da0073e9SAndroid Build Coastguard Worker self.check( 192*da0073e9SAndroid Build Coastguard Worker torch.nn.Flatten(start_dim=1, end_dim=-2), torch.randn(0, 2, 1, 3, 0) 193*da0073e9SAndroid Build Coastguard Worker ) 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker def test_slice(self): 196*da0073e9SAndroid Build Coastguard Worker class SliceModule(torch.nn.Module): 197*da0073e9SAndroid Build Coastguard Worker def __init__(self, start, stop, step): 198*da0073e9SAndroid Build Coastguard Worker super().__init__() 199*da0073e9SAndroid Build Coastguard Worker self.start = start 200*da0073e9SAndroid Build Coastguard Worker self.stop = stop 201*da0073e9SAndroid Build Coastguard Worker self.step = step 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker def forward(self, t): 204*da0073e9SAndroid Build Coastguard Worker return t[1:, self.start : self.stop : self.step, :] 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker class SliceModule2(torch.nn.Module): 207*da0073e9SAndroid Build Coastguard Worker def forward(self, t): 208*da0073e9SAndroid Build Coastguard Worker return t[3:] 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Worker self.check(SliceModule(1, 5, 2), torch.randn(4, 6, 2)) 211*da0073e9SAndroid Build Coastguard Worker self.check(SliceModule2(), torch.randn(5)) 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker # flex inputs 214*da0073e9SAndroid Build Coastguard Worker self.check( 215*da0073e9SAndroid Build Coastguard Worker SliceModule(1, 5, 2), 216*da0073e9SAndroid Build Coastguard Worker torch.randn(4, 6, 2), 217*da0073e9SAndroid Build Coastguard Worker convert_args=[torch.zeros(4, 6, 0)], 218*da0073e9SAndroid Build Coastguard Worker ) 219*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "slice with flexible shape"): 220*da0073e9SAndroid Build Coastguard Worker self.check( 221*da0073e9SAndroid Build Coastguard Worker SliceModule(1, 5, 2), 222*da0073e9SAndroid Build Coastguard Worker torch.randn(4, 6, 2), 223*da0073e9SAndroid Build Coastguard Worker convert_args=[torch.zeros(0, 0, 0)], 224*da0073e9SAndroid Build Coastguard Worker ) 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker def test_cat(self): 227*da0073e9SAndroid Build Coastguard Worker class CatModule(torch.nn.Module): 228*da0073e9SAndroid Build Coastguard Worker def __init__(self, dim): 229*da0073e9SAndroid Build Coastguard Worker super().__init__() 230*da0073e9SAndroid Build Coastguard Worker self.dim = dim 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker def forward(self, t1, t2): 233*da0073e9SAndroid Build Coastguard Worker return torch.cat([t1, t2], self.dim) 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker self.check( 236*da0073e9SAndroid Build Coastguard Worker CatModule(0), 237*da0073e9SAndroid Build Coastguard Worker [ 238*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 2, 3, 3), 239*da0073e9SAndroid Build Coastguard Worker torch.randn(2, 2, 3, 3), 240*da0073e9SAndroid Build Coastguard Worker ], 241*da0073e9SAndroid Build Coastguard Worker ) 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker self.check( 244*da0073e9SAndroid Build Coastguard Worker CatModule(1), 245*da0073e9SAndroid Build Coastguard Worker [ 246*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 2, 3, 3), 247*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 4, 3, 3), 248*da0073e9SAndroid Build Coastguard Worker ], 249*da0073e9SAndroid Build Coastguard Worker ) 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker self.check( 252*da0073e9SAndroid Build Coastguard Worker CatModule(1), 253*da0073e9SAndroid Build Coastguard Worker [ 254*da0073e9SAndroid Build Coastguard Worker nhwc(torch.randn(1, 2, 3, 3)), 255*da0073e9SAndroid Build Coastguard Worker nhwc(torch.randn(1, 4, 3, 3)), 256*da0073e9SAndroid Build Coastguard Worker ], 257*da0073e9SAndroid Build Coastguard Worker ) 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker self.check( 260*da0073e9SAndroid Build Coastguard Worker CatModule(1), 261*da0073e9SAndroid Build Coastguard Worker [ 262*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 2, 3, 3), 263*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 4, 3, 3), 264*da0073e9SAndroid Build Coastguard Worker ], 265*da0073e9SAndroid Build Coastguard Worker convert_args=[torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0)], 266*da0073e9SAndroid Build Coastguard Worker ) 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker def test_pointwise_unary(self): 269*da0073e9SAndroid Build Coastguard Worker for op in ["relu", "sigmoid"]: 270*da0073e9SAndroid Build Coastguard Worker with self.subTest(op): 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker class UnaryModule(torch.nn.Module): 273*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 274*da0073e9SAndroid Build Coastguard Worker if op == "relu": 275*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.relu(arg) 276*da0073e9SAndroid Build Coastguard Worker if op == "sigmoid": 277*da0073e9SAndroid Build Coastguard Worker return torch.sigmoid(arg) 278*da0073e9SAndroid Build Coastguard Worker raise Exception("Bad op") # noqa: TRY002 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker self.check(UnaryModule(), torch.tensor([-1.0, 1.0])) 281*da0073e9SAndroid Build Coastguard Worker self.check( 282*da0073e9SAndroid Build Coastguard Worker UnaryModule(), 283*da0073e9SAndroid Build Coastguard Worker qpt(torch.tensor([-1.0, 1.0]), 1.0 / 256, 0), 284*da0073e9SAndroid Build Coastguard Worker ) 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker def test_pointwise_binary(self): 287*da0073e9SAndroid Build Coastguard Worker for op in ["add", "sub", "mul", "div"]: 288*da0073e9SAndroid Build Coastguard Worker with self.subTest(op): 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker class BinaryModule(torch.nn.Module): 291*da0073e9SAndroid Build Coastguard Worker def forward(self, lhs, rhs): 292*da0073e9SAndroid Build Coastguard Worker if op == "add": 293*da0073e9SAndroid Build Coastguard Worker return lhs + rhs 294*da0073e9SAndroid Build Coastguard Worker if op == "sub": 295*da0073e9SAndroid Build Coastguard Worker return lhs - rhs 296*da0073e9SAndroid Build Coastguard Worker if op == "mul": 297*da0073e9SAndroid Build Coastguard Worker return lhs * rhs 298*da0073e9SAndroid Build Coastguard Worker if op == "div": 299*da0073e9SAndroid Build Coastguard Worker return lhs / rhs 300*da0073e9SAndroid Build Coastguard Worker raise Exception("Bad op") # noqa: TRY002 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker self.check( 303*da0073e9SAndroid Build Coastguard Worker BinaryModule(), 304*da0073e9SAndroid Build Coastguard Worker [ 305*da0073e9SAndroid Build Coastguard Worker torch.tensor([1.0, 2.0]), 306*da0073e9SAndroid Build Coastguard Worker torch.tensor([3.0, 4.0]), 307*da0073e9SAndroid Build Coastguard Worker ], 308*da0073e9SAndroid Build Coastguard Worker ) 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker self.check( 311*da0073e9SAndroid Build Coastguard Worker BinaryModule(), 312*da0073e9SAndroid Build Coastguard Worker [ 313*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1.0, 2.0]]), 314*da0073e9SAndroid Build Coastguard Worker torch.tensor([[3.0, 4.0], [5.0, 6.0]]), 315*da0073e9SAndroid Build Coastguard Worker ], 316*da0073e9SAndroid Build Coastguard Worker ) 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"): 319*da0073e9SAndroid Build Coastguard Worker self.check( 320*da0073e9SAndroid Build Coastguard Worker BinaryModule(), 321*da0073e9SAndroid Build Coastguard Worker [ 322*da0073e9SAndroid Build Coastguard Worker torch.tensor([1.0, 2.0]), 323*da0073e9SAndroid Build Coastguard Worker torch.tensor([[3.0, 4.0], [5.0, 6.0]]), 324*da0073e9SAndroid Build Coastguard Worker ], 325*da0073e9SAndroid Build Coastguard Worker ) 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker def test_pointwise_binary_const(self): 328*da0073e9SAndroid Build Coastguard Worker const = torch.randn(1, 4, 6, 6) 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker class ArgPlusConst(torch.nn.Module): 331*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 332*da0073e9SAndroid Build Coastguard Worker return arg + const 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker class ConstPlusArg(torch.nn.Module): 335*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 336*da0073e9SAndroid Build Coastguard Worker return const + arg 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker arg_contig = torch.randn(2, 4, 6, 6) 339*da0073e9SAndroid Build Coastguard Worker arg_nhwc = nhwc(torch.randn(2, 4, 6, 6)) 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker for mod_class in [ArgPlusConst, ConstPlusArg]: 342*da0073e9SAndroid Build Coastguard Worker for use_nhwc in [False, True]: 343*da0073e9SAndroid Build Coastguard Worker with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc): 344*da0073e9SAndroid Build Coastguard Worker arg = arg_nhwc if use_nhwc else arg_contig 345*da0073e9SAndroid Build Coastguard Worker memory_format = ( 346*da0073e9SAndroid Build Coastguard Worker torch.channels_last if use_nhwc else torch.contiguous_format 347*da0073e9SAndroid Build Coastguard Worker ) 348*da0073e9SAndroid Build Coastguard Worker self.check(mod_class(), arg, expected_memory_format=memory_format) 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker def test_hardtanh(self): 351*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0]) 352*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.Hardtanh(), inp) 353*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.Hardtanh(0.0, 6.0), inp) 354*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "hardtanh with args"): 355*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.Hardtanh(0.0, 5.0), inp) 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker def test_softmax(self): 358*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]]) 359*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.Softmax(), inp) 360*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.Softmax(dim=0), inp) 361*da0073e9SAndroid Build Coastguard Worker # Test flexible size 362*da0073e9SAndroid Build Coastguard Worker self.check( 363*da0073e9SAndroid Build Coastguard Worker torch.nn.Softmax(), 364*da0073e9SAndroid Build Coastguard Worker inp, 365*da0073e9SAndroid Build Coastguard Worker convert_args=[torch.zeros(0, 0)], 366*da0073e9SAndroid Build Coastguard Worker ) 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker def test_to(self): 369*da0073e9SAndroid Build Coastguard Worker class ToCPU(torch.nn.Module): 370*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 371*da0073e9SAndroid Build Coastguard Worker super().__init__() 372*da0073e9SAndroid Build Coastguard Worker self.prelu = torch.nn.PReLU() 373*da0073e9SAndroid Build Coastguard Worker 374*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 375*da0073e9SAndroid Build Coastguard Worker y = x.to("cpu") 376*da0073e9SAndroid Build Coastguard Worker # add prelu since input operand can't be output 377*da0073e9SAndroid Build Coastguard Worker return self.prelu(y) 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker arg = torch.randn(1, 2, 3, 3) 380*da0073e9SAndroid Build Coastguard Worker self.check(ToCPU(), arg) 381*da0073e9SAndroid Build Coastguard Worker # Test flexible size 382*da0073e9SAndroid Build Coastguard Worker self.check( 383*da0073e9SAndroid Build Coastguard Worker ToCPU(), 384*da0073e9SAndroid Build Coastguard Worker arg, 385*da0073e9SAndroid Build Coastguard Worker convert_args=[torch.zeros(1, 2, 0, 0)], 386*da0073e9SAndroid Build Coastguard Worker ) 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker def test_detach(self): 389*da0073e9SAndroid Build Coastguard Worker class DetachModule(torch.nn.Module): 390*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 391*da0073e9SAndroid Build Coastguard Worker y = x.detach() 392*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.relu(y) 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker self.check(DetachModule(), torch.randn(1, 2, 3, 3)) 395*da0073e9SAndroid Build Coastguard Worker self.check( 396*da0073e9SAndroid Build Coastguard Worker DetachModule(), 397*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 2, 3, 3), 398*da0073e9SAndroid Build Coastguard Worker convert_args=[torch.zeros(1, 2, 0, 0)], 399*da0073e9SAndroid Build Coastguard Worker ) 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker def test_log_softmax(self): 402*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3, 10) 403*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.LogSoftmax(), inp) 404*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.LogSoftmax(0), inp) 405*da0073e9SAndroid Build Coastguard Worker 406*da0073e9SAndroid Build Coastguard Worker def test_mean(self): 407*da0073e9SAndroid Build Coastguard Worker class MeanModule(torch.nn.Module): 408*da0073e9SAndroid Build Coastguard Worker def __init__(self, dim, keep=False): 409*da0073e9SAndroid Build Coastguard Worker super().__init__() 410*da0073e9SAndroid Build Coastguard Worker self.dim = dim 411*da0073e9SAndroid Build Coastguard Worker self.keep = keep 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker def forward(self, t): 414*da0073e9SAndroid Build Coastguard Worker return torch.mean(t, dim=self.dim, keepdim=self.keep) 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker self.check(MeanModule(0), torch.randn(2, 3)) 417*da0073e9SAndroid Build Coastguard Worker self.check(MeanModule(1), torch.randn(2, 3)) 418*da0073e9SAndroid Build Coastguard Worker self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6)) 419*da0073e9SAndroid Build Coastguard Worker self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6))) 420*da0073e9SAndroid Build Coastguard Worker self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6))) 421*da0073e9SAndroid Build Coastguard Worker self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6))) 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker def test_max_pool2d(self): 424*da0073e9SAndroid Build Coastguard Worker for name, inp in self.float_and_quant_and_nhwc( 425*da0073e9SAndroid Build Coastguard Worker torch.randn(2, 3, 12, 16), 0.3, 128 426*da0073e9SAndroid Build Coastguard Worker ): 427*da0073e9SAndroid Build Coastguard Worker with self.subTest(name): 428*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.MaxPool2d(2), inp) 429*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.MaxPool2d((3, 4)), inp) 430*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp) 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker def test_avg_pool2d(self): 433*da0073e9SAndroid Build Coastguard Worker for name, inp in self.float_and_quant_and_nhwc( 434*da0073e9SAndroid Build Coastguard Worker torch.randn(2, 3, 12, 16), 0.3, 128 435*da0073e9SAndroid Build Coastguard Worker ): 436*da0073e9SAndroid Build Coastguard Worker with self.subTest(name): 437*da0073e9SAndroid Build Coastguard Worker atol_rtol = None 438*da0073e9SAndroid Build Coastguard Worker limit = None 439*da0073e9SAndroid Build Coastguard Worker convert_dims = (2, 3, 0, 0) 440*da0073e9SAndroid Build Coastguard Worker convert_arg = torch.zeros(*convert_dims) 441*da0073e9SAndroid Build Coastguard Worker 442*da0073e9SAndroid Build Coastguard Worker for model in ( 443*da0073e9SAndroid Build Coastguard Worker torch.nn.AvgPool2d(2), 444*da0073e9SAndroid Build Coastguard Worker torch.nn.AvgPool2d((3, 4)), 445*da0073e9SAndroid Build Coastguard Worker torch.nn.AvgPool2d((3, 4), (1, 2)), 446*da0073e9SAndroid Build Coastguard Worker ): 447*da0073e9SAndroid Build Coastguard Worker if "quant" in name: 448*da0073e9SAndroid Build Coastguard Worker atol_rtol = (1, 0) 449*da0073e9SAndroid Build Coastguard Worker limit = model(inp).numel() 450*da0073e9SAndroid Build Coastguard Worker convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128) 451*da0073e9SAndroid Build Coastguard Worker if "nhwc" in name: 452*da0073e9SAndroid Build Coastguard Worker convert_arg = nhwc(convert_arg) 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker self.check(model, inp, atol_rtol=atol_rtol, limit=limit) 455*da0073e9SAndroid Build Coastguard Worker self.check( 456*da0073e9SAndroid Build Coastguard Worker model, 457*da0073e9SAndroid Build Coastguard Worker inp, 458*da0073e9SAndroid Build Coastguard Worker convert_args=[convert_arg], 459*da0073e9SAndroid Build Coastguard Worker atol_rtol=atol_rtol, 460*da0073e9SAndroid Build Coastguard Worker limit=limit, 461*da0073e9SAndroid Build Coastguard Worker ) 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Worker def test_adaptive_avg_pool2d(self): 464*da0073e9SAndroid Build Coastguard Worker for name, inp in self.float_and_quant_and_nhwc( 465*da0073e9SAndroid Build Coastguard Worker torch.randn(2, 3, 12, 16), 0.3, 128 466*da0073e9SAndroid Build Coastguard Worker ): 467*da0073e9SAndroid Build Coastguard Worker with self.subTest(name): 468*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp) 469*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "with output size"): 470*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp) 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker def test_upsample_nearest2d(self): 473*da0073e9SAndroid Build Coastguard Worker convert_args = dict( 474*da0073e9SAndroid Build Coastguard Worker self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.3, 128) 475*da0073e9SAndroid Build Coastguard Worker ) 476*da0073e9SAndroid Build Coastguard Worker for name, inp in self.float_and_quant_and_nhwc( 477*da0073e9SAndroid Build Coastguard Worker torch.randn(2, 3, 12, 16), 0.3, 128 478*da0073e9SAndroid Build Coastguard Worker ): 479*da0073e9SAndroid Build Coastguard Worker with self.subTest(name): 480*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp) 481*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp) 482*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp) 483*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp) 484*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp) 485*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp) 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker self.check( 488*da0073e9SAndroid Build Coastguard Worker torch.nn.UpsamplingNearest2d(size=(24, 32)), 489*da0073e9SAndroid Build Coastguard Worker inp, 490*da0073e9SAndroid Build Coastguard Worker convert_args=[convert_args[name]], 491*da0073e9SAndroid Build Coastguard Worker ) 492*da0073e9SAndroid Build Coastguard Worker self.check( 493*da0073e9SAndroid Build Coastguard Worker torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), 494*da0073e9SAndroid Build Coastguard Worker inp, 495*da0073e9SAndroid Build Coastguard Worker convert_args=[convert_args[name]], 496*da0073e9SAndroid Build Coastguard Worker ) 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Worker def test_linear(self): 499*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(29) 500*da0073e9SAndroid Build Coastguard Worker self.check(torch.nn.Linear(16, 32), torch.randn(2, 16)) 501*da0073e9SAndroid Build Coastguard Worker self.check( 502*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(16, 32), 503*da0073e9SAndroid Build Coastguard Worker torch.randn(2, 16), 504*da0073e9SAndroid Build Coastguard Worker convert_args=[torch.zeros(0, 16)], 505*da0073e9SAndroid Build Coastguard Worker ) 506*da0073e9SAndroid Build Coastguard Worker 507*da0073e9SAndroid Build Coastguard Worker def test_conv2d(self): 508*da0073e9SAndroid Build Coastguard Worker cases = [ 509*da0073e9SAndroid Build Coastguard Worker # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name 510*da0073e9SAndroid Build Coastguard Worker (4, 8, (3, 3), 1, 0, 1, 1, (2, 4, 16, 16), "3x3"), # noqa: E201,E241 511*da0073e9SAndroid Build Coastguard Worker (4, 8, (3, 3), 1, 0, 1, 0, (2, 4, 16, 16), "3x3nobias"), # noqa: E201,E241 512*da0073e9SAndroid Build Coastguard Worker (4, 16, (3, 3), 1, 1, 1, 1, (2, 4, 16, 16), "3x3p1"), # noqa: E201,E241 513*da0073e9SAndroid Build Coastguard Worker (8, 8, (3, 3), 2, 0, 1, 1, (2, 8, 16, 16), "3x3s2"), # noqa: E201,E241 514*da0073e9SAndroid Build Coastguard Worker (4, 8, (5, 5), 1, 0, 1, 1, (2, 4, 16, 16), "5x5"), # noqa: E201,E241 515*da0073e9SAndroid Build Coastguard Worker (4, 4, (3, 3), 1, 0, 4, 1, (2, 4, 16, 16), "3x3dw"), # noqa: E201,E241 516*da0073e9SAndroid Build Coastguard Worker (8, 4, (1, 1), 1, 0, 1, 1, (2, 8, 16, 16), "1x1"), # noqa: E201,E241 517*da0073e9SAndroid Build Coastguard Worker ] 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Worker for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]: 520*da0073e9SAndroid Build Coastguard Worker for case in cases: 521*da0073e9SAndroid Build Coastguard Worker ( 522*da0073e9SAndroid Build Coastguard Worker in_ch, 523*da0073e9SAndroid Build Coastguard Worker out_ch, 524*da0073e9SAndroid Build Coastguard Worker kernel, 525*da0073e9SAndroid Build Coastguard Worker stride, 526*da0073e9SAndroid Build Coastguard Worker padding, 527*da0073e9SAndroid Build Coastguard Worker groups, 528*da0073e9SAndroid Build Coastguard Worker bias, 529*da0073e9SAndroid Build Coastguard Worker input_dim, 530*da0073e9SAndroid Build Coastguard Worker name, 531*da0073e9SAndroid Build Coastguard Worker ) = case 532*da0073e9SAndroid Build Coastguard Worker with self.subTest(f"{kind}-{name}"): 533*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(input_dim) 534*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Conv2d( 535*da0073e9SAndroid Build Coastguard Worker in_ch, 536*da0073e9SAndroid Build Coastguard Worker out_ch, 537*da0073e9SAndroid Build Coastguard Worker kernel, 538*da0073e9SAndroid Build Coastguard Worker stride, 539*da0073e9SAndroid Build Coastguard Worker padding, 540*da0073e9SAndroid Build Coastguard Worker groups=groups, 541*da0073e9SAndroid Build Coastguard Worker bias=bool(bias), 542*da0073e9SAndroid Build Coastguard Worker ) 543*da0073e9SAndroid Build Coastguard Worker output_size = model(inp).numel() 544*da0073e9SAndroid Build Coastguard Worker atol_rtol = None 545*da0073e9SAndroid Build Coastguard Worker limit = None 546*da0073e9SAndroid Build Coastguard Worker convert_dims = (0, in_ch, 0, 0) 547*da0073e9SAndroid Build Coastguard Worker convert_arg = torch.zeros(*convert_dims) 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Worker if "quant" in kind: 550*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Sequential(model) 551*da0073e9SAndroid Build Coastguard Worker model.eval() 552*da0073e9SAndroid Build Coastguard Worker model.qconfig = torch.ao.quantization.get_default_qconfig( 553*da0073e9SAndroid Build Coastguard Worker "qnnpack" 554*da0073e9SAndroid Build Coastguard Worker ) 555*da0073e9SAndroid Build Coastguard Worker model = torch.ao.quantization.prepare(model) 556*da0073e9SAndroid Build Coastguard Worker model(inp) 557*da0073e9SAndroid Build Coastguard Worker model = torch.ao.quantization.convert(model) 558*da0073e9SAndroid Build Coastguard Worker inp = qpt(inp, 1.0 / 16, 128) 559*da0073e9SAndroid Build Coastguard Worker # I've seen numerical differences between QNNPACK and NNAPI, 560*da0073e9SAndroid Build Coastguard Worker # but never more than 1 quantum, and never more than ~1% of 561*da0073e9SAndroid Build Coastguard Worker # the output in this test. 562*da0073e9SAndroid Build Coastguard Worker atol_rtol = (1, 0) 563*da0073e9SAndroid Build Coastguard Worker limit = output_size * 0.03 564*da0073e9SAndroid Build Coastguard Worker convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128) 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker if "nhwc" in kind: 567*da0073e9SAndroid Build Coastguard Worker inp = nhwc(inp) 568*da0073e9SAndroid Build Coastguard Worker convert_arg = nhwc(convert_arg) 569*da0073e9SAndroid Build Coastguard Worker 570*da0073e9SAndroid Build Coastguard Worker self.check(model, inp, atol_rtol=atol_rtol, limit=limit) 571*da0073e9SAndroid Build Coastguard Worker self.check( 572*da0073e9SAndroid Build Coastguard Worker model, 573*da0073e9SAndroid Build Coastguard Worker inp, 574*da0073e9SAndroid Build Coastguard Worker convert_args=[convert_arg], 575*da0073e9SAndroid Build Coastguard Worker atol_rtol=atol_rtol, 576*da0073e9SAndroid Build Coastguard Worker limit=limit, 577*da0073e9SAndroid Build Coastguard Worker ) 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker def test_conv2d_transpose(self): 580*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(29) 581*da0073e9SAndroid Build Coastguard Worker in_ch, out_ch, kernel = (5, 7, (2, 2)) 582*da0073e9SAndroid Build Coastguard Worker input_dim = (4, 5, 3, 3) 583*da0073e9SAndroid Build Coastguard Worker convert_dims = input_dim[:2] + (0, 0) 584*da0073e9SAndroid Build Coastguard Worker 585*da0073e9SAndroid Build Coastguard Worker for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]: 586*da0073e9SAndroid Build Coastguard Worker with self.subTest(kind): 587*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(input_dim) 588*da0073e9SAndroid Build Coastguard Worker model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel) 589*da0073e9SAndroid Build Coastguard Worker output_size = model(inp).numel() 590*da0073e9SAndroid Build Coastguard Worker atol_rtol = (0.0002, 0) 591*da0073e9SAndroid Build Coastguard Worker limit = None 592*da0073e9SAndroid Build Coastguard Worker convert_arg = torch.zeros(*convert_dims) 593*da0073e9SAndroid Build Coastguard Worker 594*da0073e9SAndroid Build Coastguard Worker if "quant" in kind: 595*da0073e9SAndroid Build Coastguard Worker model = torch.ao.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel) 596*da0073e9SAndroid Build Coastguard Worker model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") 597*da0073e9SAndroid Build Coastguard Worker inp = qpt(inp, 1.0 / 16, 128) 598*da0073e9SAndroid Build Coastguard Worker # I've seen numerical differences between QNNPACK and NNAPI, 599*da0073e9SAndroid Build Coastguard Worker # but never more than 1 quantum, and never more than ~10% of 600*da0073e9SAndroid Build Coastguard Worker # the output in this test. 601*da0073e9SAndroid Build Coastguard Worker atol_rtol = (1, 0) 602*da0073e9SAndroid Build Coastguard Worker limit = output_size * 0.1 603*da0073e9SAndroid Build Coastguard Worker convert_arg = qpt(convert_arg, 1.0 / 16, 128) 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Worker if "nhwc" in kind: 606*da0073e9SAndroid Build Coastguard Worker inp = nhwc(inp) 607*da0073e9SAndroid Build Coastguard Worker convert_arg = nhwc(convert_arg) 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker self.check(model, inp, atol_rtol=atol_rtol, limit=limit) 610*da0073e9SAndroid Build Coastguard Worker self.check( 611*da0073e9SAndroid Build Coastguard Worker model, 612*da0073e9SAndroid Build Coastguard Worker inp, 613*da0073e9SAndroid Build Coastguard Worker convert_args=[convert_arg], 614*da0073e9SAndroid Build Coastguard Worker atol_rtol=atol_rtol, 615*da0073e9SAndroid Build Coastguard Worker limit=limit, 616*da0073e9SAndroid Build Coastguard Worker ) 617*da0073e9SAndroid Build Coastguard Worker 618*da0073e9SAndroid Build Coastguard Worker def test_qadd(self): 619*da0073e9SAndroid Build Coastguard Worker func = torch.ao.nn.quantized.QFunctional() 620*da0073e9SAndroid Build Coastguard Worker func.scale = 0.5 621*da0073e9SAndroid Build Coastguard Worker func.zero_point = 120 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker class AddMod(torch.nn.Module): 624*da0073e9SAndroid Build Coastguard Worker def forward(self, lhs, rhs): 625*da0073e9SAndroid Build Coastguard Worker return func.add(lhs, rhs) 626*da0073e9SAndroid Build Coastguard Worker 627*da0073e9SAndroid Build Coastguard Worker class AddReluMod(torch.nn.Module): 628*da0073e9SAndroid Build Coastguard Worker def forward(self, lhs, rhs): 629*da0073e9SAndroid Build Coastguard Worker return func.add_relu(lhs, rhs) 630*da0073e9SAndroid Build Coastguard Worker 631*da0073e9SAndroid Build Coastguard Worker class MulMod(torch.nn.Module): 632*da0073e9SAndroid Build Coastguard Worker def forward(self, lhs, rhs): 633*da0073e9SAndroid Build Coastguard Worker return func.mul(lhs, rhs) 634*da0073e9SAndroid Build Coastguard Worker 635*da0073e9SAndroid Build Coastguard Worker for name, mod in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]: 636*da0073e9SAndroid Build Coastguard Worker with self.subTest(name): 637*da0073e9SAndroid Build Coastguard Worker self.check( 638*da0073e9SAndroid Build Coastguard Worker mod(), 639*da0073e9SAndroid Build Coastguard Worker [ 640*da0073e9SAndroid Build Coastguard Worker qpt([1.0, 2.0], 0.25, 128), 641*da0073e9SAndroid Build Coastguard Worker qpt([3.0, 4.0], 0.25, 128), 642*da0073e9SAndroid Build Coastguard Worker ], 643*da0073e9SAndroid Build Coastguard Worker ) 644*da0073e9SAndroid Build Coastguard Worker self.check( 645*da0073e9SAndroid Build Coastguard Worker mod(), 646*da0073e9SAndroid Build Coastguard Worker [ 647*da0073e9SAndroid Build Coastguard Worker qpt([[1.0, 2.0]], 0.25, 128), 648*da0073e9SAndroid Build Coastguard Worker qpt([[3.0, 4.0]], 0.25, 128), 649*da0073e9SAndroid Build Coastguard Worker ], 650*da0073e9SAndroid Build Coastguard Worker convert_args=[ 651*da0073e9SAndroid Build Coastguard Worker qpt([[1.0, 2.0]], 0.25, 128), 652*da0073e9SAndroid Build Coastguard Worker qpt(torch.zeros((1, 2)), 0.25, 128), 653*da0073e9SAndroid Build Coastguard Worker ], 654*da0073e9SAndroid Build Coastguard Worker ) 655*da0073e9SAndroid Build Coastguard Worker self.check( 656*da0073e9SAndroid Build Coastguard Worker mod(), 657*da0073e9SAndroid Build Coastguard Worker [ 658*da0073e9SAndroid Build Coastguard Worker qpt([[1.0, 2.0]], 0.25, 128), 659*da0073e9SAndroid Build Coastguard Worker qpt([[3.0, 4.0]], 0.25, 128), 660*da0073e9SAndroid Build Coastguard Worker ], 661*da0073e9SAndroid Build Coastguard Worker convert_args=[ 662*da0073e9SAndroid Build Coastguard Worker qpt(torch.zeros((1, 2)), 0.25, 128), 663*da0073e9SAndroid Build Coastguard Worker qpt([[3.0, 4.0]], 0.25, 128), 664*da0073e9SAndroid Build Coastguard Worker ], 665*da0073e9SAndroid Build Coastguard Worker ) 666*da0073e9SAndroid Build Coastguard Worker self.check( 667*da0073e9SAndroid Build Coastguard Worker mod(), 668*da0073e9SAndroid Build Coastguard Worker [ 669*da0073e9SAndroid Build Coastguard Worker qpt([[1.0, 2.0]], 0.25, 128), 670*da0073e9SAndroid Build Coastguard Worker qpt([[3.0, 4.0]], 0.25, 128), 671*da0073e9SAndroid Build Coastguard Worker ], 672*da0073e9SAndroid Build Coastguard Worker convert_args=[ 673*da0073e9SAndroid Build Coastguard Worker qpt(torch.zeros((1, 2)), 0.25, 128), 674*da0073e9SAndroid Build Coastguard Worker qpt(torch.zeros((1, 2)), 0.25, 128), 675*da0073e9SAndroid Build Coastguard Worker ], 676*da0073e9SAndroid Build Coastguard Worker ) 677*da0073e9SAndroid Build Coastguard Worker # NOTE: NNAPI qadd supports broadcast, but PT does not. 678*da0073e9SAndroid Build Coastguard Worker 679*da0073e9SAndroid Build Coastguard Worker def test_qlinear(self): 680*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(29) 681*da0073e9SAndroid Build Coastguard Worker weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8) 682*da0073e9SAndroid Build Coastguard Worker bias = torch.randn(16) 683*da0073e9SAndroid Build Coastguard Worker mod = torch.ao.nn.quantized.Linear(32, 16) 684*da0073e9SAndroid Build Coastguard Worker mod.set_weight_bias(weight, bias) 685*da0073e9SAndroid Build Coastguard Worker inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8) 686*da0073e9SAndroid Build Coastguard Worker self.check(mod, inp) 687*da0073e9SAndroid Build Coastguard Worker 688*da0073e9SAndroid Build Coastguard Worker def test_seblock_mul(self): 689*da0073e9SAndroid Build Coastguard Worker class MulModel(torch.nn.Module): 690*da0073e9SAndroid Build Coastguard Worker def forward(self, lhs, rhs): 691*da0073e9SAndroid Build Coastguard Worker return lhs * rhs 692*da0073e9SAndroid Build Coastguard Worker 693*da0073e9SAndroid Build Coastguard Worker self.check( 694*da0073e9SAndroid Build Coastguard Worker MulModel(), 695*da0073e9SAndroid Build Coastguard Worker [ 696*da0073e9SAndroid Build Coastguard Worker nhwc(torch.randn(2, 3, 4, 4)), 697*da0073e9SAndroid Build Coastguard Worker torch.randn(1, 3, 1, 1), 698*da0073e9SAndroid Build Coastguard Worker ], 699*da0073e9SAndroid Build Coastguard Worker ) 700*da0073e9SAndroid Build Coastguard Worker 701*da0073e9SAndroid Build Coastguard Worker def test_multi_output(self): 702*da0073e9SAndroid Build Coastguard Worker class MultiModel(torch.nn.Module): 703*da0073e9SAndroid Build Coastguard Worker def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]: 704*da0073e9SAndroid Build Coastguard Worker the_sum = lhs + rhs 705*da0073e9SAndroid Build Coastguard Worker the_diff = lhs - rhs 706*da0073e9SAndroid Build Coastguard Worker return the_sum, the_diff 707*da0073e9SAndroid Build Coastguard Worker 708*da0073e9SAndroid Build Coastguard Worker self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])]) 709*da0073e9SAndroid Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 712*da0073e9SAndroid Build Coastguard Worker run_tests() 713