# mypy: allow-untyped-defs
import array
import enum
import functools
import logging
import operator
import struct
import sys
from typing import List, NamedTuple, Optional, Tuple

import torch


# TODO: Add type annotations
# TODO: Check tensor types for ops


LOG = logging.getLogger("nnapi_serialize")


class NNAPI_OperandCode:
    FLOAT32 = 0
    INT32 = 1
    UINT32 = 2
    TENSOR_FLOAT32 = 3
    TENSOR_INT32 = 4
    TENSOR_QUANT8_ASYMM = 5
    BOOL = 6
    TENSOR_QUANT16_SYMM = 7
    TENSOR_FLOAT16 = 8
    TENSOR_BOOL8 = 9
    FLOAT16 = 10
    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11
    TENSOR_QUANT16_ASYMM = 12


class NNAPI_OperationCode:
    ADD = 0
    AVERAGE_POOL_2D = 1
    CONCATENATION = 2
    CONV_2D = 3
    DEPTHWISE_CONV_2D = 4
    DEPTH_TO_SPACE = 5
    DEQUANTIZE = 6
    EMBEDDING_LOOKUP = 7
    FLOOR = 8
    FULLY_CONNECTED = 9
    HASHTABLE_LOOKUP = 10
    L2_NORMALIZATION = 11
    L2_POOL_2D = 12
    LOCAL_RESPONSE_NORMALIZATION = 13
    LOGISTIC = 14
    LSH_PROJECTION = 15
    LSTM = 16
    MAX_POOL_2D = 17
    MUL = 18
    RELU = 19
    RELU1 = 20
    RELU6 = 21
    RESHAPE = 22
    RESIZE_BILINEAR = 23
    RNN = 24
    SOFTMAX = 25
    SPACE_TO_DEPTH = 26
    SVDF = 27
    TANH = 28
    BATCH_TO_SPACE_ND = 29
    DIV = 30
    MEAN = 31
    PAD = 32
    SPACE_TO_BATCH_ND = 33
    SQUEEZE = 34
    STRIDED_SLICE = 35
    SUB = 36
    TRANSPOSE = 37
    ABS = 38
    ARGMAX = 39
    ARGMIN = 40
    AXIS_ALIGNED_BBOX_TRANSFORM = 41
    BIDIRECTIONAL_SEQUENCE_LSTM = 42
    BIDIRECTIONAL_SEQUENCE_RNN = 43
    BOX_WITH_NMS_LIMIT = 44
    CAST = 45
    CHANNEL_SHUFFLE = 46
    DETECTION_POSTPROCESSING = 47
    EQUAL = 48
    EXP = 49
    EXPAND_DIMS = 50
    GATHER = 51
    GENERATE_PROPOSALS = 52
    GREATER = 53
    GREATER_EQUAL = 54
    GROUPED_CONV_2D = 55
    HEATMAP_MAX_KEYPOINT = 56
    INSTANCE_NORMALIZATION = 57
    LESS = 58
    LESS_EQUAL = 59
    LOG = 60
    LOGICAL_AND = 61
    LOGICAL_NOT = 62
    LOGICAL_OR = 63
    LOG_SOFTMAX = 64
    MAXIMUM = 65
    MINIMUM = 66
    NEG = 67
    NOT_EQUAL = 68
    PAD_V2 = 69
    POW = 70
    PRELU = 71
    QUANTIZE = 72
    QUANTIZED_16BIT_LSTM = 73
    RANDOM_MULTINOMIAL = 74
    REDUCE_ALL = 75
    REDUCE_ANY = 76
    REDUCE_MAX = 77
    REDUCE_MIN = 78
    REDUCE_PROD = 79
    REDUCE_SUM = 80
    ROI_ALIGN = 81
    ROI_POOLING = 82
    RSQRT = 83
    SELECT = 84
    SIN = 85
    SLICE = 86
    SPLIT = 87
    SQRT = 88
    TILE = 89
    TOPK_V2 = 90
    TRANSPOSE_CONV_2D = 91
    UNIDIRECTIONAL_SEQUENCE_LSTM = 92
    UNIDIRECTIONAL_SEQUENCE_RNN = 93
    RESIZE_NEAREST_NEIGHBOR = 94


class NNAPI_FuseCode:
    FUSED_NONE = 0
    FUSED_RELU = 1
    FUSED_RELU1 = 2
    FUSED_RELU6 = 3


class OperandValueSourceType:
    IMMEDIATE = 0
    NUMBERED_BUFFER = 2
    NUMBERED_MEMORY = 3


# Scalar types that appear explicitly in models.
# These must be kept in sync with
# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
# TODO: Expose these directly to Python to avoid maintaining this list.
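# (For reference, 13 is the enum value of torch.quint8 in that list.)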
class TorchScalarTypes(enum.Enum):
    QUINT8 = 13


def approx_equal(lhs, rhs, tolerance=1e-6):
    return abs(lhs - rhs) <= tolerance * min(lhs, rhs)


def tensor_size(op_type, dims):
    ITEM_SIZES = {
        NNAPI_OperandCode.TENSOR_FLOAT32: 4,
        NNAPI_OperandCode.TENSOR_INT32: 4,
        NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1,
        NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2,
        NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2,
    }
    size = ITEM_SIZES[op_type]
    for d in dims:
        size *= d
    return size


def change_element(tup, index, value):
    ls = list(tup)
    ls[index] = value
    return tuple(ls)


class ConvPoolArgs2d(NamedTuple):
    """Configuration arguments for a convolution."""

    kernel_h: int
    kernel_w: int
    stride_h: int
    stride_w: int
    pad_t: int
    pad_b: int
    pad_l: int
    pad_r: int
    dilation_h: int
    dilation_w: int
    group: int


class DimOrder(enum.Enum):
    PRESUMED_CONTIGUOUS = 0
    CHANNELS_LAST = 1
    SCALAR_OR_VECTOR = 2
    UNKNOWN_CONSTANT = 999


class Operand(NamedTuple):
    """Representation of an NNAPI operand."""

    # NNAPI operand type.  One of NNAPI_OperandCode.
    # TODO: Make this an enum.
    op_type: int

    # This is always the PyTorch shape, which is NCHW for feature maps.
    # The actual NNAPI operand might have a transposed shape.
    # We use 0 for load-time dynamic shapes and -1 for runtime dynamic shapes.
    shape: Tuple[int, ...]

    # Specifies how the shape of the operand that we define in NNAPI
    # relates to the shape we track above.
    # - PRESUMED_CONTIGUOUS: physical NNAPI operand will exactly match
    #   the shape of the PyTorch tensor.
    # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and
    #   the NNAPI operand will be represented explicitly as NHWC.
    dim_order: DimOrder

    # Quantization params
    scale: float
    zero_point: int

    def use_nchw(self):
        if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS:
            return True
        if self.dim_order is DimOrder.CHANNELS_LAST:
            return False
        raise Exception("Unknown dim order")  # noqa: TRY002


def broadcast_shapes(shape1, shape2):
    assert len(shape1) > 0
    assert len(shape2) > 0
    s1 = list(shape1)
    s2 = list(shape2)
    # TODO: Support non-equal-rank broadcast where semantics match.
    # This can be tricky for NHWC tensors because dimension orders
    # don't match between PT and NNAPI, even though semantics match.
    if len(s1) > len(s2):
        # s2 = [1] * (len(s1) - len(s2)) + s2
        raise Exception(  # noqa: TRY002
            "Non-equal-rank broadcast is not supported yet."
        )  # noqa: TRY002
    if len(s2) > len(s1):
        # s3 = [1] * (len(s2) - len(s1)) + s1
        raise Exception(  # noqa: TRY002
            "Non-equal-rank broadcast is not supported yet."
        )  # noqa: TRY002
    ret = []
    for d1, d2 in zip(s1, s2):
        if d1 == 1:
            ret.append(d2)
        elif d2 == 1:
            ret.append(d1)
        elif d1 == d2:
            ret.append(d1)
        else:
            raise Exception(  # noqa: TRY002
                f"Cannot broadcast shapes: {shape1} and {shape2}"
            )  # noqa: TRY002
    return tuple(ret)


def get_conv_pool_shape(image_shape, args, out_ch, transpose):
    batch, in_c, in_h, in_w = image_shape

    # TODO: Handle dilation
    if args.dilation_h != 1 or args.dilation_w != 1:
        raise Exception("Dilation not supported yet.")  # noqa: TRY002

    if transpose:
        out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b
        out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_r
    else:
        out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1
        out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1

    # Handle variable-sized tensors.
    if in_h == 0:
        out_h = 0
    if in_w == 0:
        out_w = 0

    out_shape = (batch, out_ch, out_h, out_w)
    return out_shape


def fix_shape(shape, dim_order):
    # Return the actual shape that an operand should have in NNAPI,
    # given a PyTorch shape and dimension order.  This is where we
    # convert from PyTorch's "always NCHW" shape to explicit NHWC.
    if dim_order is DimOrder.PRESUMED_CONTIGUOUS:
        return shape
    if dim_order is DimOrder.CHANNELS_LAST:
        return tuple([shape[0]] + list(shape[2:]) + [shape[1]])
    if dim_order is DimOrder.SCALAR_OR_VECTOR:
        assert len(shape) == 0 or len(shape) == 1
        return shape
    if dim_order is DimOrder.UNKNOWN_CONSTANT:
        # XXX think this through
        return shape
    raise Exception(f"Bad dim_order: {dim_order!r}.")  # noqa: TRY002


def reverse_map_dim(dim_order, d):
    # Return the original PyTorch dimension position for a given dimension.
    # d should be the dimension that NNAPI will see.
    # reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x
    # reverse_map_dim(CHANNELS_LAST, 3) == 1
    if dim_order in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR):
        return d
    assert dim_order is DimOrder.CHANNELS_LAST
    return [0, 2, 3, 1][d]


def flex_name(op_id, dim):
    # Return the local variable name for the computed flexible size
    # for a given op and dimension.
    return f"s_{op_id}_{dim}"


class _NnapiSerializer:
    def __init__(self, config, use_int16_for_qint16=False):
        self.operands = []
        self.values = []
        self.operations = []
        self.value_data = []
        self.operation_args = []
        self.inputs = []
        self.outputs = []
        self.flexible_shape_computation_lines = []

        self.modules = {}
        self.constants = {}
        self.tensor_sequences = {}
        self.jitval_operand_map = {}
        self.cached_immediates = {}
        self.used_weights = []
        self.weight_offset = 0
        self.use_int16_for_qint16 = use_int16_for_qint16

        if config is None:
            config = {}

    def get_next_operand_id(self):
        return len(self.operands)

    # Add a tensor operand corresponding to a JIT Value.
    # Returns the NNAPI operand ID.  Can be looked up later with
    # get_tensor_operand_by_jitval.
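    # Illustrative usage (names from this file):
    #   oper = self.torch_tensor_to_operand(t, DimOrder.PRESUMED_CONTIGUOUS)
    #   op_id = self.add_tensor_operand(jitval, oper)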
    def add_tensor_operand(self, jitval, oper):
        assert isinstance(oper, Operand)
        if jitval in self.jitval_operand_map:
            raise Exception(f"Duplicate tensor: {jitval!r}")  # noqa: TRY002

        operand_id = self.get_next_operand_id()
        self.operands.append(oper)
        self.jitval_operand_map[jitval] = operand_id
        return operand_id

    # Add a tensor operand that does not correspond to a JIT Value.
    # Useful for cases where multiple NNAPI operands are required
    # to implement one JIT IR node.  Returns the NNAPI operand ID.
    def add_anonymous_tensor_operand(self, oper):
        assert isinstance(oper, Operand)
        operand_id = self.get_next_operand_id()
        self.operands.append(oper)
        return operand_id

    def torch_tensor_to_operand(self, tensor, dim_order):
        dtype = str(tensor.dtype).replace("torch.", "")
        scale = 0.0
        zero_point = 0
        if dtype == "float32":
            op_type = NNAPI_OperandCode.TENSOR_FLOAT32
        elif dtype == "int32":
            op_type = NNAPI_OperandCode.TENSOR_INT32
        elif dtype == "quint8":
            op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
            scale = tensor.q_scale()
            zero_point = tensor.q_zero_point()
        elif dtype == "qint32":
            op_type = NNAPI_OperandCode.TENSOR_INT32
            scale = tensor.q_scale()
            zero_point = tensor.q_zero_point()
            assert zero_point == 0
        elif dtype == "int16":
            if self.use_int16_for_qint16:
                nnapi_dtype = getattr(tensor, "nnapi_dtype", None)
                op_codes = (
                    NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
                    NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
                )
                if nnapi_dtype in op_codes:
                    op_type = nnapi_dtype
                    scale = tensor.nnapi_scale
                    zero_point = tensor.nnapi_zero_point
                else:
                    raise Exception(  # noqa: TRY002
                        f"`nnapi_dtype` needs to be one of {op_codes} for `int16`"
                    )
            else:
                raise Exception(  # noqa: TRY002
                    "`int16` isn't supported. If you're trying to represent NNAPI"
                    " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`"
                )
        else:
            raise Exception(  # noqa: TRY002
                f"Can't handle input with dtype '{tensor.dtype}'"
            )  # noqa: TRY002
        return Operand(
            shape=tuple(tensor.shape),
            op_type=op_type,
            dim_order=dim_order,
            scale=scale,
            zero_point=zero_point,
        )

    def add_tensor_operand_for_input(self, arg_idx, jitval, tensor):
        dim_order = (
            DimOrder.CHANNELS_LAST
            if getattr(tensor, "nnapi_nhwc", False)
            else DimOrder.PRESUMED_CONTIGUOUS
        )
        toper = self.torch_tensor_to_operand(tensor, dim_order)
        operand_id = self.add_tensor_operand(jitval, toper)
        self.inputs.append(operand_id)
        for dim, size in enumerate(tensor.shape):
            if size == 0:
                self.compute_operand_shape(
                    operand_id, dim, f"args[{arg_idx}].shape[{dim}]"
                )
        return operand_id

    def add_tensor_operand_for_weight(
        self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT
    ):
        toper = self.torch_tensor_to_operand(tensor, dim_order)
        operand_id = len(self.operands)
        self.operands.append(toper)
        tsize = tensor_size(toper.op_type, toper.shape)
        psize = ((tsize - 1) | 0x3) + 1
        self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER))
        buf_num = len(self.used_weights)
        offset = 0
        self.value_data.append(struct.pack("iii", buf_num, offset, tsize))
        # For NHWC NNAPI op, lay out data in the same dim order by permuting torch tensor
        if dim_order == DimOrder.CHANNELS_LAST:
            tensor = tensor.permute(0, 2, 3, 1)
        self.used_weights.append(tensor)
        return operand_id

    def add_immediate_operand(self, code, value, dims):
        assert isinstance(dims, tuple)
        cache_key = (code, value)
        if cache_key not in self.cached_immediates:
            operand_id = len(self.operands)
            self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0))
            self.values.append((operand_id, OperandValueSourceType.IMMEDIATE))
            self.value_data.append(value)
            self.cached_immediates[cache_key] = operand_id
        return self.cached_immediates[cache_key]

    def add_immediate_int_scalar(self, value):
        return self.add_immediate_operand(
            NNAPI_OperandCode.INT32, struct.pack("i", value), ()
        )

    def add_immediate_float_scalar(self, value):
        return self.add_immediate_operand(
            NNAPI_OperandCode.FLOAT32, struct.pack("f", value), ()
        )

    def add_immediate_bool_scalar(self, value):
        return self.add_immediate_operand(
            NNAPI_OperandCode.BOOL, b"\x01" if value else b"\x00", ()
        )

    def add_immediate_int_vector(self, value):
        return self.add_immediate_operand(
            NNAPI_OperandCode.TENSOR_INT32,
            array.array("i", value).tobytes(),
            (len(value),),
        )

    def has_operand_for_jitval(self, jitval):
        return jitval in self.jitval_operand_map

    def get_tensor_operand_by_jitval(self, jitval):
        operand_id = self.jitval_operand_map[jitval]
        return (operand_id, self.operands[operand_id])

    def get_tensor_operand_by_jitval_fixed_size(self, jitval):
        op_id, oper = self.get_tensor_operand_by_jitval(jitval)
        for s in oper.shape:
            if s == 0:
                # TODO: Improve this error message, possibly after converting
                # many callsites to support flexible size.
                raise Exception(  # noqa: TRY002
                    "Flexible size is not supported for this operand."
                )  # noqa: TRY002
            if s < 0:
                # runtime flex
                LOG.warning("Operand %s has runtime flex shape", oper)
        return op_id, oper

    def get_tensor_operand_or_constant(
        self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS
    ):
        operand_id = self.jitval_operand_map.get(jitval)
        if operand_id is None:
            _, value = self.get_constant_value(jitval, "TensorType")
            operand_id = self.add_tensor_operand_for_weight(value, dim_order)
        return (operand_id, self.operands[operand_id])

    def get_tensor_operand_for_weight(self, jitval):
        _, value = self.get_constant_value(jitval, "TensorType")
        operand_id = self.add_tensor_operand_for_weight(value)
        return (operand_id, self.operands[operand_id])

    def add_operation(self, opcode, inputs, outputs):
        self.operations.append((opcode, len(inputs), len(outputs)))
        self.operation_args.extend(inputs + outputs)

    def add_tensor_sequence(self, jitval, values):
        assert jitval not in self.tensor_sequences
        self.tensor_sequences[jitval] = values

    def add_constant_value(self, jitval, ctype, value):
        assert jitval not in self.constants
        self.constants[jitval] = (ctype, value)

    def get_constant_value(self, jitval, typekind=None):
        record = self.constants.get(jitval)
        if record is None:
            raise Exception(  # noqa: TRY002
                f"Could not find constant value for '{jitval!r}'."
            )  # noqa: TRY002
        ctype, _ = record
        if typekind is not None and ctype.kind() != typekind:
            raise Exception(  # noqa: TRY002
                f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'"
            )
        return record

    def operand_to_template_torchscript(self, op_id, oper, shape=None):
        """Return a TorchScript expression to build a template for a given operand."""
        if shape is None:
            shape = oper.shape
        else:
            assert len(shape) == len(oper.shape)

        shape_parts = ["("]
        for d, s in enumerate(shape):
            if s > 0:
                # Fixed shape dimension: just add the value.
                shape_parts.append(str(s))
            elif s == 0:
                # Load time flexible shape dimension: it should have been computed in a variable.
                shape_parts.append(flex_name(op_id, d))
            elif s == -1:
                # Runtime flexible shape
                shape_parts.append("0")
            else:
                raise Exception(  # noqa: TRY002
                    "Unknown dim value, dimensions should be >= -1"
                )  # noqa: TRY002
            shape_parts.append(",")
        shape_parts.append(")")
        shape_code = "".join(shape_parts)
        if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
            return f"torch.zeros({shape_code}, dtype=torch.float32)"
        elif oper.op_type == NNAPI_OperandCode.TENSOR_INT32:
            return f"torch.zeros({shape_code}, dtype=torch.int32)"
        elif oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
            return (
                f"torch.quantize_per_tensor("
                f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)"
                f".expand({shape_code}).contiguous()"
            )
        elif oper.op_type in (
            NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
            NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
        ):
            if self.use_int16_for_qint16:
                return f"torch.zeros({shape_code}, dtype=torch.int16)"
            else:
                raise Exception(  # noqa: TRY002
                    "`int16` isn't supported. If you're trying to represent NNAPI"
                    " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`"
                )

        raise Exception(  # noqa: TRY002
            f"Unsupported output operand type: {oper.op_type}"
        )  # noqa: TRY002

    def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim):
        self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim))

    def compute_operand_shape(self, op_id, dim, expr):
        self.flexible_shape_computation_lines.append(
            f"{flex_name(op_id, dim)} = {expr}"
        )

    def transpose_to_nhwc(self, in_id, oper):
        if oper.shape[2:] != (1, 1):
            raise Exception(  # noqa: TRY002
                "Automatic transpose only supported for H,W == 1,1"
            )  # noqa: TRY002

        out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST)

        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1])

        outputs = [None] * 1
        outputs[0] = self.add_anonymous_tensor_operand(out_oper)

        self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs)

        return outputs[0], out_oper

    # Transpose inputs as necessary to allow broadcasting.
    def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper):
        if in0_oper.dim_order == in1_oper.dim_order:
            return in0_id, in0_oper, in1_id, in1_oper

        # Assume NHWC is preferred if there is a mismatch.
        orders = (in0_oper.dim_order, in1_oper.dim_order)
        if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST):
            return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper)
        if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS):
            return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper)

        raise Exception(  # noqa: TRY002
            f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}"
        )

    def get_size_arg(self, jitval):
        ctype, value = self.get_constant_value(jitval)
        if ctype.kind() == "ListType":
            assert ctype.getElementType().kind() == "IntType"
            return value
        raise Exception(  # noqa: TRY002
            f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'"
        )  # noqa: TRY002

    def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config):
        pc = [i.item() for i in packed_config]
        assert pc[0] == 2
        strides = [pc[1], pc[2]]
        paddings = [pc[3], pc[4]]
        dilations = [pc[5], pc[6]]
        output_padding = [pc[7], pc[8]]
        group_num = pc[9]

        assert len(pc) == 11
        assert output_padding == [0, 0]

        return self.get_conv_pool_args_2d_common(
            kernel_size, strides, paddings, dilations, group_num
        )

    def get_conv_pool_args_2d_from_jit(
        self, kernel_size, stride, padding, dilation=None, group=None
    ):
        strides = self.get_size_arg(stride)
        paddings = self.get_size_arg(padding)
        if dilation is None:
            dilations = [1, 1]
        else:
            dilations = self.get_size_arg(dilation)
        if group is not None:
            _, group_num = self.get_constant_value(group, "IntType")
        else:
            group_num = None
        return self.get_conv_pool_args_2d_common(
            kernel_size, strides, paddings, dilations, group_num
        )

    def get_conv_pool_args_2d_common(
        self, kernel_size, strides, paddings, dilations, group_num
    ):
        kernels = list(kernel_size)

        assert len(kernels) == 2
        assert len(strides) == 2
        assert len(paddings) == 2
        assert len(dilations) == 2

        # NNAPI uses 4 values for padding.
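        # e.g. paddings == [1, 2] expands to
        # [pad_t, pad_b, pad_l, pad_r] == [1, 1, 2, 2].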
        ph, pw = paddings
        real_paddings = [ph, ph, pw, pw]

        return ConvPoolArgs2d(
            *(kernels + strides + real_paddings + dilations + [group_num])
        )

    def serialize_model(self, model, inputs, return_shapes=None):
        self.add_immediate_bool_scalar(False)
        self.add_immediate_bool_scalar(True)

        inp_dim_orders = []
        out_dim_orders = []

        self_jitval = next(model.graph.inputs())
        self.add_constant_value(self_jitval, self_jitval.type(), model)

        for arg_idx, (input_value, input_tensor) in enumerate(
            zip(list(model.graph.inputs())[1:], inputs)
        ):
            op_id = self.add_tensor_operand_for_input(
                arg_idx, input_value, input_tensor
            )
            inp_dim_orders.append(self.operands[op_id].dim_order.value)

        for idx, node in enumerate(model.graph.nodes()):
            LOG.debug("Processing node #%d: %r", idx, node)
            self.add_node(node)

        retn = model.graph.return_node()
        assert retn.inputsSize() == 1
        assert retn.outputsSize() == 0
        retn_input = retn.inputsAt(0)
        template_return_lines = ["return ["]
        if retn_input.type().kind() == "TensorType":
            return_values = [retn_input]
            retval_count = -1
        elif retn_input.type().kind() == "TupleType":
            return_values = self.tensor_sequences[retn_input]
            retval_count = len(return_values)
        else:
            raise Exception(  # noqa: TRY002
                f"Unsupported return type: {retn_input.type()}"
            )  # noqa: TRY002

        if return_shapes is not None:
            assert len(return_shapes) == len(return_values)
        for i, v in enumerate(return_values):
            op_id = self.jitval_operand_map[v]
            self.outputs.append(op_id)
            out_dim_orders.append(self.operands[op_id].dim_order.value)
            shape = return_shapes[i] if return_shapes else None
            template_return_lines.append(
                self.operand_to_template_torchscript(op_id, self.operands[op_id], shape)
                + ","
            )
        template_return_lines.append("]")

        model = []

        version = 1
        header = struct.pack(
            "iiiiii",
            version,
            len(self.operands),
            len(self.values),
            len(self.operations),
            len(self.inputs),
            len(self.outputs),
        )
        model.append(header)

        serialized_values, serialized_value_data = self.serialize_values()

        model.extend(
            struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands
        )
        model.extend(serialized_values)
        model.extend(struct.pack("iii", *x) for x in self.operations)

        # Compact the model so we can get its length so far.
        model = [b"".join(model)]
        model_offset = len(model[0])
        # Model offset is the index into the model (in 32-bit words, not bytes)
        # of the next dimension we're about to serialize.  If it's 0,
        # generate code to mutate it before passing to NNAPI.
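        # Every dimension of every operand occupies one int32 word below;
        # flexible (size 0) dimensions get a patch line generated for them.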
        assert model_offset % 4 == 0
        model_offset = int(model_offset / 4)

        for op_id, (_, dims, dim_order, _, _) in enumerate(self.operands):
            shape = fix_shape(dims, dim_order)
            for d, s in enumerate(shape):
                if s == 0:
                    pt_d = reverse_map_dim(dim_order, d)
                    self.flexible_shape_computation_lines.append(
                        f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}"
                    )
                model_offset += 1

            # convert runtime flex shape from -1 to 0
            shape = tuple(d if d != -1 else 0 for d in shape)
            model.append(self.serialize_ints(shape))

        model.extend(serialized_value_data)
        model.append(self.serialize_ints(self.operation_args))
        model.append(self.serialize_ints(self.inputs))
        model.append(self.serialize_ints(self.outputs))

        self.flexible_shape_computation_lines.extend(template_return_lines)

        return (
            array.array("i", b"".join(model)),
            self.used_weights,
            inp_dim_orders,
            out_dim_orders,
            self.flexible_shape_computation_lines,
            retval_count,
        )

    def serialize_values(self):
        serialized_values = []
        serialized_value_data = []
        assert len(self.values) == len(self.value_data)
        for (op_index, source_type), data in zip(self.values, self.value_data):
            source_length = len(data)

            # Pad with 0 bytes out to a multiple of 4 for alignment.
            physical_length = ((source_length - 1) | 0x3) + 1
            padded_data = data + (b"\0" * (physical_length - source_length))

            serialized_values.append(
                struct.pack("iii", op_index, source_type, source_length)
            )
            serialized_value_data.append(padded_data)

        return serialized_values, serialized_value_data

    @staticmethod
    def serialize_ints(ints):
        return array.array("i", ints).tobytes()

    ADDER_MAP = {
        "prim::GetAttr": lambda self, node: self.add_getattr(node),
        "prim::Constant": lambda self, node: self.add_constant_node(node),
        "prim::ListConstruct": lambda self, node: self.add_list_construct(node),
        "prim::TupleConstruct": lambda self, node: self.add_tuple_construct(node),
        "aten::unsqueeze": lambda self, node: self.add_unsqueeze(node),
        "aten::to": lambda self, node: self.add_to(node),
        "aten::detach": lambda self, node: self._identity(node),
        "aten::reshape": lambda self, node: self.add_reshape(node),
        "aten::flatten": lambda self, node: self.add_flatten(node),
        "aten::slice": lambda self, node: self.add_slice(node),
        "aten::size": lambda self, node: self.add_size(node),
        "aten::cat": lambda self, node: self.add_cat(node),
        "aten::mean": lambda self, node: self.add_mean(node),
        "aten::quantize_per_tensor": lambda self, node: self.add_quantize(node),
        "aten::dequantize": lambda self, node: self.add_dequantize(node),
        "aten::add": lambda self, node: self.add_add_sub_op(
            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
        ),
        "aten::sub": lambda self, node: self.add_add_sub_op(
            node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE
        ),
        "aten::mul": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
        ),
        "aten::div": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
            node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE
        ),
        "aten::relu": lambda self, node: self.add_pointwise_simple_unary_op(
            node, NNAPI_OperationCode.RELU
        ),
        "aten::sigmoid": lambda self, node: self.add_pointwise_simple_unary_op(
            node, NNAPI_OperationCode.LOGISTIC
        ),
        "aten::softmax": lambda self, node: self.add_softmax(node),
        "aten::hardtanh": lambda self, node: self.add_hardtanh(node),
        "aten::avg_pool2d": lambda self, node: self.add_avg_pool2d(node),
        "aten::max_pool2d": lambda self, node: self.add_pool2d_node(
            node, NNAPI_OperationCode.MAX_POOL_2D
        ),
        "aten::adaptive_avg_pool2d": lambda self, node: self.add_adaptive_avg_pool2d(
            node
        ),
        "aten::upsample_nearest2d": lambda self, node: self.add_upsample_nearest2d(
            node
        ),
        "aten::prelu": lambda self, node: self.add_prelu_op(node),
        "aten::addmm": lambda self, node: self.add_addmm(node),
        "aten::linear": lambda self, node: self.add_linear(node),
        "aten::_convolution": lambda self, node: self.add_conv_underscore(node),
        "aten::conv2d": lambda self, node: self.add_conv2d(node),
        "aten::log_softmax": lambda self, node: self.add_log_softmax(node),
        "quantized::linear": lambda self, node: self.add_qlinear(node),
        "quantized::conv2d": lambda self, node: self.add_qconv2d(
            node, NNAPI_FuseCode.FUSED_NONE
        ),
        "quantized::conv2d_relu": lambda self, node: self.add_qconv2d(
            node, NNAPI_FuseCode.FUSED_RELU
        ),
        "quantized::conv_transpose2d": lambda self, node: self.add_qconv2d(
            node, NNAPI_FuseCode.FUSED_NONE, transpose=True
        ),
        "quantized::add": lambda self, node: self.add_qadd(
            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
        ),
        "quantized::add_relu": lambda self, node: self.add_qadd(
            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU
        ),
        "quantized::mul": lambda self, node: self.add_qadd(
            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
        ),
    }

    def add_node(self, node):
        adder = self.ADDER_MAP.get(node.kind())
        if not adder:
            raise Exception(  # noqa: TRY002
                f"Unsupported node kind ({node.kind()!r}) in node {node!r}"
            )  # noqa: TRY002
        adder(self, node)

    def _identity(self, node):
        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
        jitval = node.outputsAt(0)
        self.jitval_operand_map[jitval] = in_id

    def add_getattr(self, node):
        assert node.inputsSize() == 1
        assert node.outputsSize() == 1
        obj_ctype, obj = self.get_constant_value(node.inputsAt(0))
        assert str(obj_ctype).startswith("__torch__.")
        name = node.s("name")
        value = getattr(obj, name)
        output = node.outputsAt(0)
        ctype = output.type()
        self.add_constant_value(output, ctype, value)

    def add_constant_node(self, node):
        assert node.inputsSize() == 0
        assert node.outputsSize() == 1
        output = node.outputsAt(0)
        ctype = output.type()
        value = output.toIValue()
        self.add_constant_value(output, ctype, value)

    def add_list_construct(self, node):
        assert node.outputsSize() == 1
        output = node.outputsAt(0)
        ctype = output.type()
        const_vals: Optional[List] = []
        tensors: Optional[List] = []
        for inp in node.inputs():
            if const_vals is not None and inp in self.constants:
                _, val = self.get_constant_value(inp)
                const_vals.append(val)
            else:
                const_vals = None
            if tensors is not None and inp.type().kind() == "TensorType":
                tensors.append(inp)
            else:
                tensors = None

        if const_vals is not None:
            # NOTE: Now that TorchScript supports list constants,
            # this code path might not be used anymore.
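            # e.g. a ListConstruct of the int constants 2 and 2 is recorded
            # here as the Python list [2, 2].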
            self.add_constant_value(output, ctype, const_vals)
        if tensors is not None:
            self.add_tensor_sequence(output, tensors)
        if const_vals is None and tensors is None:
            raise Exception(  # noqa: TRY002
                f"Unable to handle ListConstruct node. Neither all constants nor all tensors. {node!r}"
            )

    def add_tuple_construct(self, node):
        assert node.outputsSize() == 1
        output = node.outputsAt(0)
        values = list(node.inputs())
        self.add_tensor_sequence(output, values)

    def add_unsqueeze(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))

        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
        assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS

        real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1
        out_shape_list = list(in_oper.shape)
        out_shape_list.insert(real_dim, 1)
        out_shape = tuple(out_shape_list)
        out_oper = in_oper._replace(shape=out_shape)

        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_scalar(dim)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs)

    def add_to(self, node):
        # Handle to("cpu") / to("gpu") case
        self._identity(node)

    def add_reshape(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))

        shape_ctype, shape = self.get_constant_value(node.inputsAt(1))
        assert shape_ctype.kind() == "ListType"
        assert shape_ctype.getElementType().kind() == "IntType"
        is_trivial_reshape = len(shape) == 2 and shape[1] == -1

        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape:
            raise Exception(  # noqa: TRY002
                "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1]."
            )

        # Bit of a hack here.  Use a real tensor to infer the output shape.
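        # e.g. in_oper.shape == (2, 3, 4) with shape == [2, -1]
        # yields out_shape == torch.Size([2, 12]).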
        out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape
        out_oper = in_oper._replace(
            shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
        )

        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector(shape)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)

    def add_flatten(self, node):
        assert node.inputsSize() == 3
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))

        start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
        end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")

        # channels last with channels == 1 or (height & width both 1)
        is_trivial_flatten = len(in_oper.shape) == 4 and (
            in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1)
        )
        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten:
            raise Exception(  # noqa: TRY002
                "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1"
            )

        if start_dim < 0:
            start_dim += len(in_oper.shape)
        if end_dim < 0:
            end_dim += len(in_oper.shape)

        out_shape = (
            in_oper.shape[:start_dim]
            + (functools.reduce(operator.mul, in_oper.shape[start_dim : end_dim + 1]),)
            + in_oper.shape[end_dim + 1 :]
        )

        if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]):
            raise Exception(  # noqa: TRY002
                "Flattening flexible dims is not supported yet"
            )  # noqa: TRY002
        non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :]
        if non_flattened_dims.count(0) > 1:
            raise Exception("Only 1 dim can be flexible")  # noqa: TRY002

        out_oper = in_oper._replace(
            shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
        )
        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)

        for idx, dim in enumerate(out_shape):
            if dim == 0:
                self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0))

        inputs_1 = tuple(dim if dim != 0 else -1 for dim in out_shape)
        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector(inputs_1)

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)

    def add_slice(self, node):
        assert node.inputsSize() == 5
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
        _, dim_value = self.get_constant_value(node.inputsAt(1))
        _, start_value = self.get_constant_value(node.inputsAt(2))
        _, stop_value = self.get_constant_value(node.inputsAt(3))
        _, step_value = self.get_constant_value(node.inputsAt(4))

        if start_value is None:
            start_value = 0
        if stop_value is None:
            stop_value = sys.maxsize

        if start_value < 0:
            start_value += in_oper.shape[dim_value]
        elif start_value == sys.maxsize:
            start_value = 0

        if start_value == 0 and stop_value == sys.maxsize:
            self._identity(node)
            return

        if in_oper.shape[dim_value] == 0:
            raise Exception("Unable to slice with flexible shape")  # noqa: TRY002

        if stop_value < 0:
            stop_value += in_oper.shape[dim_value]
        elif stop_value == sys.maxsize:
            stop_value = in_oper.shape[dim_value]

        if start_value >= stop_value:
            raise Exception(  # noqa: TRY002
                "Slice start value should be less than stop value"
            )  # noqa: TRY002

        out_len = (stop_value - start_value) // step_value
        out_shape = tuple(
            out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape)
        )
        out_id = self.add_tensor_operand(
            node.outputsAt(0), in_oper._replace(shape=out_shape)
        )

        # flex inputs
        end_mask = 0
        for idx, dim in enumerate(out_shape):
            if dim == 0:
                self.forward_operand_shape(out_id, idx, in_id, idx)
                end_mask |= 1 << idx

        inputs = [None] * 7
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector(
            [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))]
        )
        inputs[2] = self.add_immediate_int_vector(
            [
                stop_value if i == dim_value else dim
                for i, dim in enumerate(in_oper.shape)
            ]
        )
        inputs[3] = self.add_immediate_int_vector(
            [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))]
        )
        inputs[4] = self.add_immediate_int_scalar(0)  # begin mask
        inputs[5] = self.add_immediate_int_scalar(end_mask)
        inputs[6] = self.add_immediate_int_scalar(0)  # shrink axis mask

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.STRIDED_SLICE, inputs, outputs)

    def add_size(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        _, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
        _, value = self.constants[node.inputsAt(1)]
        res = in_oper.shape[value]
        output = node.outputsAt(0)
        self.add_constant_value(output, output.type(), res)

    def add_cat(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        tensors = self.tensor_sequences[node.inputsAt(0)]
        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")

        assert len(tensors) > 0
        in_ids = []
        out_oper = None
        out_dim_size = 0
        for inp in tensors:
            in_id, in_oper = self.get_tensor_operand_by_jitval(inp)
            if out_oper is None:
                out_shape = change_element(in_oper.shape, dim, -1)
                out_oper = in_oper._replace(shape=out_shape)
            assert in_oper.op_type == out_oper.op_type
            assert in_oper.dim_order == out_oper.dim_order
            assert change_element(in_oper.shape, dim, -1) == change_element(
                out_oper.shape, dim, -1
            )
            # TODO: Possibly check scale and zero point.
            in_ids.append(in_id)
            # TODO: Possibly support variable-sized inputs.
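            # e.g. concatenating shapes (1, 2, 8, 8) and (1, 3, 8, 8) along
            # dim 1 accumulates out_dim_size to 5.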
            out_dim_size += in_oper.shape[dim]

        assert out_oper is not None
        out_oper = out_oper._replace(
            shape=change_element(out_oper.shape, dim, out_dim_size)
        )

        if in_oper.dim_order == DimOrder.CHANNELS_LAST:  # type: ignore[possibly-undefined]
            assert len(out_oper.shape) == 4
            nnapi_dim = [0, 3, 1, 2][dim]
        else:
            nnapi_dim = dim

        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
        for idx, d in enumerate(out_oper.shape):
            if d == 0:
                if idx == dim:
                    shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids)
                    self.compute_operand_shape(out_id, idx, shape)
                else:
                    self.forward_operand_shape(out_id, idx, in_ids[0], idx)

        inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)]

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs)

    def add_mean(self, node):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
        dim_ctype, dim = self.get_constant_value(node.inputsAt(1))
        assert dim_ctype.kind() == "ListType"
        assert dim_ctype.getElementType().kind() == "IntType"
        _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType")
        # Expect None for dtype
        self.get_constant_value(node.inputsAt(3), "NoneType")

        if in_oper.dim_order == DimOrder.CHANNELS_LAST:
            assert len(in_oper.shape) == 4
            nnapi_dim = [[0, 3, 1, 2][d] for d in dim]
        else:
            nnapi_dim = dim

        collapsed_dims = set()
        for d in dim:
            if d < 0:
                d += len(in_oper.shape)
            collapsed_dims.add(d)

        if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim:
            assert collapsed_dims.issuperset({2, 3})
            out_dim_order = DimOrder.PRESUMED_CONTIGUOUS
        else:
            out_dim_order = in_oper.dim_order

        out_shape = []
        for i, s in enumerate(in_oper.shape):
            if i not in collapsed_dims:
                out_shape.append(s)
            elif keep_dim:
                out_shape.append(1)

        out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order)

        inputs = [None] * 3
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector(nnapi_dim)
        inputs[2] = self.add_immediate_int_scalar(keep_dim)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs)

    def add_quantize(self, node):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
        if in_oper.dim_order != DimOrder.CHANNELS_LAST:
            raise Exception(  # noqa: TRY002
                "Most hardware backends prefer NHWC quantized tensors. "
                "Try setting `t.nnapi_nhwc = True` on your tensor inputs. "
            )
        _, scale = self.get_constant_value(node.inputsAt(1), "FloatType")
        _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType")
        _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType")
        if scalar_type != TorchScalarTypes.QUINT8.value:
            raise Exception(  # noqa: TRY002
                "PyTorch NNAPI export only supports quantized tensors "
                "with the quint8 dtype."
            )
        op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM

        out_oper = in_oper._replace(
            op_type=op_type,
            scale=scale,
            zero_point=zero_point,
        )

        inputs = [None] * 1
        inputs[0] = in_id

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs)

    def add_dequantize(self, node):
        assert node.inputsSize() == 1
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
        out_oper = in_oper._replace(
            op_type=NNAPI_OperandCode.TENSOR_FLOAT32,
            scale=0.0,
            zero_point=0,
        )

        inputs = [None] * 1
        inputs[0] = in_id

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs)

    def add_pointwise_simple_unary_op(self, node, opcode):
        assert node.inputsSize() == 1
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))

        out_oper = in_oper
        if opcode == NNAPI_OperationCode.LOGISTIC:
            # NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale
            # must be 1.f / 256 and the zeroPoint must be 0.
            # https://fburl.com/h52stoog
            if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
                out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256)

        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)

        for idx, dim in enumerate(in_oper.shape):
            if dim == 0:
                self.forward_operand_shape(out_id, idx, in_id, idx)

        inputs = [None] * 1
        inputs[0] = in_id

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(opcode, inputs, outputs)

    def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None):  # noqa: D401
        """Helper for pointwise binary broadcast ops with superfluous extra args."""
        assert node.outputsSize() == 1

        assert node.inputsAt(0).type().kind() == "TensorType"
        assert node.inputsAt(1).type().kind() == "TensorType"

        if self.has_operand_for_jitval(node.inputsAt(0)):
            in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
            in1_id, in1_oper = self.get_tensor_operand_or_constant(
                node.inputsAt(1), in0_oper.dim_order
            )
        elif self.has_operand_for_jitval(node.inputsAt(1)):
            in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1))
            in0_id, in0_oper = self.get_tensor_operand_or_constant(
                node.inputsAt(0), in1_oper.dim_order
            )
        else:
            raise Exception(  # noqa: TRY002
                f"Can't do a NNAPI binary op: {opcode} on two constants"
            )  # noqa: TRY002

        assert in0_oper.op_type == in1_oper.op_type
        in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast(
            in0_id, in0_oper, in1_id, in1_oper
        )
        # NOTE: PyTorch and NNAPI have the same broadcast semantics.
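        # e.g. (1, 4, 1, 1) and (1, 4, 8, 8) broadcast to (1, 4, 8, 8);
        # broadcast_shapes above only supports equal-rank inputs.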
        out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape)
        out_oper = in0_oper._replace(shape=out_shape)
        if qparams is not None:
            scale, zp = qparams
            out_oper = out_oper._replace(scale=scale, zero_point=zp)

        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
        for idx, (d0, d1) in enumerate(zip(in0_oper.shape, in1_oper.shape)):
            if d0 == 1 and d1 == 0:
                self.forward_operand_shape(out_id, idx, in1_id, idx)
            elif d0 == 0 and d1 == 1:
                self.forward_operand_shape(out_id, idx, in0_id, idx)
            elif d0 == 0 and d1 == 0:
                self.flexible_shape_computation_lines.append(
                    f"assert {flex_name(in0_id, idx)} == {flex_name(in1_id, idx)}"
                )
                self.forward_operand_shape(out_id, idx, in0_id, idx)

        inputs = [None] * 3
        inputs[0] = in0_id
        inputs[1] = in1_id
        inputs[2] = self.add_immediate_int_scalar(fuse_code)

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(opcode, inputs, outputs)

    def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code):
        assert node.inputsSize() == 2
        self._do_add_binary(node, opcode, fuse_code)

    def add_add_sub_op(self, node, opcode, fuse_code):
        assert node.inputsSize() == 3

        _, alpha = self.get_constant_value(node.inputsAt(2), "IntType")
        if alpha != 1:
            raise Exception(  # noqa: TRY002
                "NNAPI does not support add/sub with alpha."
            )  # noqa: TRY002

        self._do_add_binary(node, opcode, fuse_code)

    def add_qadd(self, node, opcode, fuse_code):
        assert node.inputsSize() == 4

        _, scale = self.get_constant_value(node.inputsAt(2), "FloatType")
        _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType")

        self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point))

    def add_softmax(self, node):
        assert node.inputsSize() == 3
        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))

        _, softmax_dim = self.get_constant_value(node.inputsAt(1), "IntType")

        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
        for dim, size in enumerate(in_oper.shape):
            if size == 0:
                self.forward_operand_shape(out_id, dim, in_id, dim)

        inputs = [None] * 3
        inputs[0] = in_id
        inputs[1] = self.add_immediate_float_scalar(
            1.0
        )  # positive scaling factor of exponent, beta
        inputs[2] = self.add_immediate_int_scalar(softmax_dim)

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.SOFTMAX, inputs, outputs)

    def add_hardtanh(self, node):
        assert node.inputsSize() == 3
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
        _, min_val = self.get_constant_value(node.inputsAt(1), "FloatType")
        _, max_val = self.get_constant_value(node.inputsAt(2), "FloatType")

        op_map = {
            (-1, 1): NNAPI_OperationCode.RELU1,
            (0, 6): NNAPI_OperationCode.RELU6,  # noqa: E201
        }

        opcode = op_map.get((min_val, max_val))
        if opcode is None:
            raise Exception(  # noqa: TRY002
                "NNAPI only supports hardtanh with args (-1, 1) or (0, 6)."
            )  # noqa: TRY002

        inputs = [None] * 1
        inputs[0] = in_id

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper)

        self.add_operation(opcode, inputs, outputs)

    def add_prelu_op(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        assert node.inputsAt(0).type().kind() == "TensorType"
        assert node.inputsAt(1).type().kind() == "TensorType"

        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
        w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1))
        assert len(w_oper.shape) == 1
        assert w_oper.shape[0] > 0
        if w_oper.shape[0] > 1:
            if in_oper.use_nchw():
                # TODO: Support this by adding trailing 1 dims.
                raise Exception(  # noqa: TRY002
                    "Per-channel PReLU only supports channels_last right now."
                )

        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
        for dim, size in enumerate(in_oper.shape):
            if size > 0:
                pass
            elif dim <= 1:
                raise Exception(  # noqa: TRY002
                    "PReLU requires fixed size for dim 0 and dim 1."
                )  # noqa: TRY002
            else:
                self.forward_operand_shape(out_id, dim, in_id, dim)

        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = w_id

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs)

    def add_pool2d_node(self, node, opcode):
        assert node.inputsSize() == 6
        assert node.outputsSize() == 1
        image, kernel, stride, padding, dilation, ceil_mode = node.inputs()

        stride = stride or kernel

        # TODO: Validate ceil_mode semantics.

        args = self.get_conv_pool_args_2d_from_jit(
            self.get_size_arg(kernel), stride, padding, dilation
        )
        if args.dilation_h != 1 or args.dilation_w != 1:
            raise Exception("NNAPI does not support dilated pooling.")  # noqa: TRY002

        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image)
        assert len(image_oper.shape) == 4

        out_shape = get_conv_pool_shape(
            image_oper.shape, args, image_oper.shape[1], False
        )
        use_nchw = image_oper.use_nchw()

        inputs = [None] * 11
        inputs[0] = image_id
        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
        inputs[10] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(
            node.outputsAt(0), image_oper._replace(shape=out_shape)
        )

        self.add_operation(opcode, inputs, outputs)

    def add_avg_pool2d(self, node):
        assert node.inputsSize() == 7
        assert node.outputsSize() == 1
        (
            image,
            kernel,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        ) = node.inputs()

        _, count_include_pad_value = self.get_constant_value(count_include_pad)
        _, divisor_override_value = self.get_constant_value(divisor_override)
        if not count_include_pad_value or divisor_override_value:
            raise Exception(  # noqa: TRY002
                "NNAPI doesn't support count_include_pad=False or divisor_override"
            )

        args = self.get_conv_pool_args_2d_from_jit(
            self.get_size_arg(kernel), stride, padding
        )

        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
        assert len(image_oper.shape) == 4

        out_shape = get_conv_pool_shape(
            image_oper.shape, args, image_oper.shape[1], False
        )
        use_nchw = image_oper.use_nchw()

        inputs = [None] * 11
        inputs[0] = image_id
        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
        inputs[10] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        out_id = self.add_tensor_operand(
            node.outputsAt(0), image_oper._replace(shape=out_shape)
        )
        self._handle_conv_pool_flexible_input(out_id, image, args, False)
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)

    def add_adaptive_avg_pool2d(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(
            node.inputsAt(0)
        )
        assert len(image_oper.shape) == 4

        size_ctype, size_arg = self.get_constant_value(node.inputsAt(1))
        assert size_ctype.kind() == "ListType"
        assert size_ctype.getElementType().kind() == "IntType"
        if size_arg != [1, 1]:
            raise Exception(  # noqa: TRY002
                "NNAPI only supports adaptive_avg_pool2d with output size (1, 1)."
            )

        out_shape = image_oper.shape[0:2] + tuple(size_arg)
        use_nchw = image_oper.use_nchw()

        inputs = [None] * 11
        inputs[0] = image_id
        inputs[1] = self.add_immediate_int_scalar(0)
        inputs[2] = self.add_immediate_int_scalar(0)
        inputs[3] = self.add_immediate_int_scalar(0)
        inputs[4] = self.add_immediate_int_scalar(0)
        inputs[5] = self.add_immediate_int_scalar(1)
        inputs[6] = self.add_immediate_int_scalar(1)
        inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3])
        inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2])
        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
        inputs[10] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(
            node.outputsAt(0), image_oper._replace(shape=out_shape)
        )

        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)

    def add_upsample_nearest2d(self, node):
        assert node.inputsSize() == 3 or node.inputsSize() == 4
        assert node.outputsSize() == 1
        if node.inputsSize() == 3:
            image, size_jit, scale_jit = node.inputs()
        else:
            image, size_jit, scale_h_jit, scale_w_jit = node.inputs()
        size_ctype, size_arg = self.get_constant_value(size_jit)

        if node.inputsSize() == 3:
            scale_ctype, scale_arg = self.get_constant_value(scale_jit)  # type: ignore[possibly-undefined]
        else:
            scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit)  # type: ignore[possibly-undefined]
            scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit)  # type: ignore[possibly-undefined]

            # The only way for the 4-argument overload of upsample_nearest2d to
            # have been added to the graph without error is if the scale_h and
            # scale_w arguments are None
            assert scale_h_ctype.kind() == "NoneType"
            assert scale_w_ctype.kind() == "NoneType"

            scale_ctype = scale_h_ctype
            scale_arg = scale_h_arg

        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
        assert len(image_oper.shape) == 4

        if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType":
            raise Exception("Size and scale cannot both be non-None.")  # noqa: TRY002
        elif size_ctype.kind() != "NoneType":
            assert size_ctype.kind() == "ListType"
            assert size_ctype.getElementType().kind() == "IntType"
            assert scale_ctype.kind() == "NoneType"
            assert scale_arg is None
            assert isinstance(size_arg, list)
            assert size_arg
            assert all(isinstance(val, int) for val in size_arg)
            if len(size_arg) == 1:
                size_arg = size_arg * 2
            assert len(size_arg) == 2
            out_h = size_arg[0]
            out_w = size_arg[1]
            arg_h = self.add_immediate_int_scalar(out_h)
            arg_w = self.add_immediate_int_scalar(out_w)
        elif scale_ctype.kind() != "NoneType":
            assert scale_ctype.kind() == "ListType"
            assert scale_ctype.getElementType().kind() == "FloatType"
            assert size_ctype.kind() == "NoneType"
            assert size_arg is None
            assert isinstance(scale_arg, list)
            assert scale_arg
            assert all(isinstance(val, float) for val in scale_arg)
            if len(scale_arg) == 1:
                scale_arg = scale_arg * 2
            assert len(scale_arg) == 2
            out_h = int(scale_arg[0] * image_oper.shape[2])
            out_w = int(scale_arg[1] * image_oper.shape[3])
            arg_h = self.add_immediate_float_scalar(scale_arg[0])
            arg_w = self.add_immediate_float_scalar(scale_arg[1])
        else:
            raise Exception("Size and scale cannot both be None.")  # noqa: TRY002

        out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w)
        use_nchw = image_oper.use_nchw()
        out_id = self.add_tensor_operand(
            node.outputsAt(0), image_oper._replace(shape=out_shape)
        )

        if image_oper.shape[0] == 0 or image_oper.shape[1] == 0:
            raise Exception("Flexible batch or channels not supported")  # noqa: TRY002

        # Handle variable input size
        for dim in (2, 3):  # h, w indices
            if image_oper.shape[dim] == 0:
                if size_ctype.kind() != "NoneType":
                    self.compute_operand_shape(out_id, dim, size_arg[dim - 2])
                elif scale_ctype.kind() != "NoneType":
                    self.compute_operand_shape(
                        out_id,
                        dim,
                        f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})",
                    )
                else:
                    raise Exception(  # noqa: TRY002
                        "Size and scale cannot both be None."
                    )  # noqa: TRY002

        inputs = [None] * 4
        inputs[0] = image_id
        inputs[1] = arg_w
        inputs[2] = arg_h
        inputs[3] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs)

    def add_addmm(self, node):
        assert node.inputsSize() == 5
        assert node.outputsSize() == 1
        jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs()

        for jitval in (jit_beta, jit_alpha):
            scale_ctype, scale_value = self.get_constant_value(jitval)
            assert scale_ctype.kind() in ("IntType", "FloatType")
            if scale_value != 1:
                raise Exception(  # noqa: TRY002
                    "NNAPI Fully-Connected does not support alpha and beta."
                )

        self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias)

    def add_linear(self, node):
        assert node.inputsSize() == 3
        assert node.outputsSize() == 1
        jit_input, jit_weight, jit_bias = node.inputs()

        self.add_addmm_or_linear(node, False, jit_input, jit_weight, jit_bias)

    def add_addmm_or_linear(
        self, node, transpose_weight, jit_input, jit_weight, jit_bias
    ):
        input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input)
        bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias)

        assert len(input_oper.shape) == 2
        assert len(bias_oper.shape) == 1

        # TODO: Transform at load time to share weights with CPU model.
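        # NNAPI's FULLY_CONNECTED op expects its weight as [num_units, input_size].
        # aten::addmm receives the weight as (input_size, num_units), so it is
        # transposed below; aten::linear already provides (num_units, input_size)
        # and is used as-is.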
        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
        assert len(weight_tensor.shape) == 2
        if transpose_weight:
            nnapi_weight_tensor = weight_tensor.t().contiguous()
        else:
            nnapi_weight_tensor = weight_tensor.contiguous()
        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
        weight_oper = self.operands[weight_id]

        out_shape = (input_oper.shape[0], weight_oper.shape[0])
        out_id = self.add_tensor_operand(
            node.outputsAt(0), input_oper._replace(shape=out_shape)
        )

        if input_oper.shape[0] == 0:
            self.forward_operand_shape(out_id, 0, input_id, 0)

        inputs = [None] * 4
        inputs[0] = input_id
        inputs[1] = weight_id
        inputs[2] = bias_id
        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)

        outputs = [None] * 1
        outputs[0] = out_id

        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)

    def add_qlinear(self, node):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1
        (
            jit_input,
            jit_packed_weight,
            jit_scale,
            jit_zero_point,
        ) = node.inputs()

        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
        # TODO: Support automatic reshape
        assert len(input_oper.shape) == 2

        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
        assert weight_ctype.name() == "LinearPackedParamsBase"
        raw_weight, raw_bias = packed_weight.__getstate__()[0]
        assert raw_bias is not None

        assert len(raw_weight.shape) == 2
        assert len(raw_bias.shape) == 1
        assert raw_bias.shape[0] == raw_weight.shape[0]
        assert raw_weight.shape[1] == input_oper.shape[1]

        assert raw_weight.qscheme() == torch.per_tensor_affine
        if raw_weight.dtype == torch.quint8:
            unsigned_weight = raw_weight
        else:
            assert raw_weight.dtype == torch.qint8
            unsigned_weight = torch._make_per_tensor_quantized_tensor(
                (raw_weight.int_repr().int() + 128).to(torch.uint8),
                scale=raw_weight.q_scale(),
                zero_point=raw_weight.q_zero_point() + 128,
            )
        weight_scale = unsigned_weight.q_scale()
        bias_scale = input_oper.scale * weight_scale
        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
        bias_id = self.add_tensor_operand_for_weight(int_bias)

        multiplier = input_oper.scale * weight_scale / out_scale
        assert multiplier > 0
        if multiplier >= 1:
            raise Exception(  # noqa: TRY002
                "Quantized fully-connected multiplier is greater than or equal to 1. "
                "This is supported by NNAPI, but not by most hardware backends. "
                "Try training a model without quantization-aware training. "
            )

        # TODO: Transform at load time to share weights with CPU model.
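        # The weight becomes an unsigned TENSOR_QUANT8_ASYMM operand, and the
        # int32 bias computed above uses scale = input_scale * weight_scale with
        # zero point 0, matching what NNAPI expects for quantized FULLY_CONNECTED.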
        nnapi_weight_tensor = unsigned_weight.contiguous()
        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
        weight_oper = self.operands[weight_id]

        out_shape = (input_oper.shape[0], weight_oper.shape[0])
        out_oper = input_oper._replace(
            shape=out_shape,
            scale=out_scale,
            zero_point=out_zero_point,
        )

        inputs = [None] * 4
        inputs[0] = input_id
        inputs[1] = weight_id
        inputs[2] = bias_id
        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)

    def get_optional_bias(self, jit_bias, weight_tensor, transpose=False):
        ctype, value = self.get_constant_value(jit_bias)
        if ctype.kind() == "NoneType":
            bias_idx = 1 if transpose else 0
            nnapi_bias_tensor = torch.zeros(
                weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype
            )
            bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor)
            bias_oper = self.operands[bias_id]
            return bias_id, bias_oper
        else:
            return self.get_tensor_operand_for_weight(jit_bias)

    def add_conv2d(self, node):
        assert node.inputsSize() == 7
        assert node.outputsSize() == 1

        (
            jit_image,
            jit_weight,
            jit_bias,
            jit_stride,
            jit_pad,
            jit_dilation,
            jit_groups,
        ) = node.inputs()

        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
        bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
        args = self.get_conv_pool_args_2d_from_jit(
            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
        )

        return self.add_conv2d_common(
            node.outputsAt(0),
            0.0,
            0,
            jit_image,
            weight_tensor,
            bias_id,
            args,
            False,  # transpose
            NNAPI_FuseCode.FUSED_NONE,
        )

    def add_conv_underscore(self, node):
        assert node.inputsSize() == 13
        assert node.outputsSize() == 1

        (
            jit_image,
            jit_weight,
            jit_bias,
            jit_stride,
            jit_pad,
            jit_dilation,
            jit_transpose,
            _,
            jit_groups,
            _,
            _,
            _,
            _,
        ) = node.inputs()

        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
        _, transpose = self.get_constant_value(jit_transpose)
        bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
        args = self.get_conv_pool_args_2d_from_jit(
            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
        )

        return self.add_conv2d_common(
            node.outputsAt(0),
            0.0,
            0,
            jit_image,
            weight_tensor,
            bias_id,
            args,
            transpose,
            NNAPI_FuseCode.FUSED_NONE,
        )

    def add_log_softmax(self, node):
        assert node.inputsSize() == 3
        assert node.outputsSize() == 1

        (jit_input, jit_dim, jit_half_to_float) = node.inputs()
        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
        _, dim = self.get_constant_value(jit_dim, "IntType")

        out_shape = input_oper.shape

        inputs = [None] * 3
        inputs[0] = input_id
        # specifying 1 as the scaling factor for the exponent, beta
        inputs[1] = self.add_immediate_float_scalar(1)
        inputs[2] = self.add_immediate_int_scalar(dim)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(
            node.outputsAt(0), input_oper._replace(shape=out_shape)
        )
        self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs)

    def add_qconv2d(self, node, fuse_code, transpose=False):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1

        (
            jit_image,
            jit_packed_weight,
            jit_scale,
            jit_zero_point,
        ) = node.inputs()

        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
        assert weight_ctype.name() == "Conv2dPackedParamsBase"
        (
            pack_version,
            tensors,
            opt_tensors,
        ) = packed_weight.__getstate__()[0]
        assert pack_version == "2"
        packed_config, raw_weight = tensors
        (raw_bias,) = opt_tensors
        assert raw_bias is not None
        args = self.get_conv_pool_args_2d_from_pack(
            raw_weight.shape[2:4], packed_config
        )

        assert raw_weight.qscheme() == torch.per_tensor_affine
        if raw_weight.dtype == torch.quint8:
            unsigned_weight = raw_weight
        else:
            assert raw_weight.dtype == torch.qint8
            unsigned_weight = torch._make_per_tensor_quantized_tensor(
                (raw_weight.int_repr().int() + 128).to(torch.uint8),
                scale=raw_weight.q_scale(),
                zero_point=raw_weight.q_zero_point() + 128,
            )
        weight_scale = unsigned_weight.q_scale()
        _, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        bias_scale = image_oper.scale * weight_scale
        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
        bias_id = self.add_tensor_operand_for_weight(int_bias)

        multiplier = image_oper.scale * weight_scale / out_scale
        assert multiplier > 0
        if multiplier >= 1:
            raise Exception(  # noqa: TRY002
                "Quantized convolution multiplier is greater than 1. "
                "This is supported by NNAPI, but not by most hardware backends. "
                "Try training a model without quantization-aware training. "
            )

        return self.add_conv2d_common(
            node.outputsAt(0),
            out_scale,
            out_zero_point,
            jit_image,
            unsigned_weight,
            bias_id,
            args,
            transpose,
            fuse_code,
        )

    def add_conv2d_common(
        self,
        jit_out,
        out_scale,
        out_zero_point,
        jit_image,
        weight_tensor,
        bias_id,
        args,
        transpose,
        fuse_code,
    ):
        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        in_c = image_oper.shape[1]

        if args.group == 1:
            # Full convolution
            depthwise = False
            if transpose:
                weight_permutation = (1, 2, 3, 0)
            else:
                weight_permutation = (0, 2, 3, 1)
        elif args.group == in_c:
            # Depthwise convolution
            depthwise = True
            weight_permutation = (1, 2, 3, 0)
        else:
            raise Exception("Group convolution not supported yet.")  # noqa: TRY002

        # TODO: Transform at load time to share weights with CPU model.
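        # Convert the PyTorch OIHW weight layout to what NNAPI expects:
        # CONV_2D takes [out_c, kernel_h, kernel_w, in_c] (permutation (0, 2, 3, 1)),
        # DEPTHWISE_CONV_2D takes [1, kernel_h, kernel_w, out_c], and transposed
        # convolution weights keep out_c on their second axis, hence (1, 2, 3, 0).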
        nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous()
        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
        weight_oper = self.operands[weight_id]

        bias_oper = self.operands[bias_id]

        if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
        elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32
            assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale)
            assert bias_oper.zero_point == 0
        else:
            raise Exception(  # noqa: TRY002
                f"Unsupported input type for conv2d: {image_oper.op_type}"
            )  # noqa: TRY002

        assert len(image_oper.shape) == 4
        assert len(weight_oper.shape) == 4
        assert len(bias_oper.shape) == 1

        if depthwise:
            # Depthwise convolution
            one, kern_h, kern_w, out_c = weight_oper.shape
            assert one == 1
            assert out_c % in_c == 0
            channel_multiplier = out_c // in_c
            assert channel_multiplier == 1  # Don't support multiplier
            assert out_c == in_c
        else:
            # Full convolution
            out_c, kern_h, kern_w, kern_d = weight_oper.shape
            assert kern_d == in_c

        assert out_c == bias_oper.shape[0]

        use_nchw = image_oper.use_nchw()

        if depthwise:
            num_args = 12
            opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D
        else:
            num_args = 11
            if transpose:
                opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D
            else:
                opcode = NNAPI_OperationCode.CONV_2D

        inputs = [None] * num_args
        inputs[0] = image_id
        inputs[1] = weight_id
        inputs[2] = bias_id
        inputs[3] = self.add_immediate_int_scalar(args.pad_l)
        inputs[4] = self.add_immediate_int_scalar(args.pad_r)
        inputs[5] = self.add_immediate_int_scalar(args.pad_t)
        inputs[6] = self.add_immediate_int_scalar(args.pad_b)
        inputs[7] = self.add_immediate_int_scalar(args.stride_w)
        inputs[8] = self.add_immediate_int_scalar(args.stride_h)
        if depthwise:
            inputs[9] = self.add_immediate_int_scalar(1)
            inputs[10] = self.add_immediate_int_scalar(fuse_code)
            inputs[11] = self.add_immediate_bool_scalar(use_nchw)
        else:
            inputs[9] = self.add_immediate_int_scalar(fuse_code)
            inputs[10] = self.add_immediate_bool_scalar(use_nchw)

        outputs = [None] * 1
        out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose)
        out_oper = image_oper._replace(
            shape=out_shape,
            scale=out_scale,
            zero_point=out_zero_point,
        )
        out_id = self.add_tensor_operand(jit_out, out_oper)
        self._handle_conv_pool_flexible_input(out_id, jit_image, args, transpose)

        outputs[0] = out_id
        self.add_operation(opcode, inputs, outputs)

    def _handle_conv_pool_flexible_input(self, out_id, jit_image, args, transpose):
        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
        batch, in_ch, in_h, in_w = image_oper.shape

        if batch == 0:
            self.forward_operand_shape(out_id, 0, image_id, 0)
        if in_ch == 0:
            raise Exception("Input channels can't be flexible")  # noqa: TRY002
        # H & W
        if transpose:
            if in_h == 0:
                self.compute_operand_shape(
                    out_id,
                    2,
                    f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}",
                )
            if in_w == 0:
                self.compute_operand_shape(
                    out_id,
                    3,
                    f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}",
                )
        else:
            if in_h == 0:
                self.compute_operand_shape(
                    out_id,
                    2,
                    f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1",
                )
            if in_w == 0:
                self.compute_operand_shape(
                    out_id,
                    3,
                    f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1",
                )


def serialize_model(
    module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False
):
    """Convert a TorchScript module to NNAPI and serialize it.

    Parameters:
        module: TorchScript module to convert
        inputs: Tensors used to specify input details for NNAPI
        config (optional): Optional config to attach to the module
        return_shapes (optional): Specify the shapes of the outputs if
            your module uses runtime flexible shapes, so NNAPI can size
            its output buffers
        use_int16_for_qint16 (optional): Use PyTorch int16 to represent NNAPI qint16 values
    """
    return _NnapiSerializer(config, use_int16_for_qint16).serialize_model(
        module, inputs, return_shapes
    )
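
# Example usage (illustrative sketch only; "MyModel" and the input shape are
# hypothetical, and the structure of the returned value is defined by
# _NnapiSerializer.serialize_model):
#
#     traced = torch.jit.trace(MyModel().eval(), torch.zeros(1, 3, 224, 224))
#     serialized = serialize_model(traced, [torch.zeros(1, 3, 224, 224)])
#
# Most callers instead go through torch.backends._nnapi.prepare.convert_model_to_nnapi,
# which builds on this function.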