1# mypy: allow-untyped-defs
2import array
3import enum
4import functools
5import logging
6import operator
7import struct
8import sys
9from typing import List, NamedTuple, Optional, Tuple
10
11import torch
12
13
14# TODO: Add type annotations
15# TODO: Check tensor types for ops
16
17
18LOG = logging.getLogger("nnapi_serialize")
19
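# Rough usage sketch (a minimal illustration only; the supported entry point is
# torch.backends._nnapi.prepare.convert_model_to_nnapi, and names other than
# _NnapiSerializer / serialize_model below are placeholders):
#
#     traced = torch.jit.trace(my_model, example_input)
#     ser = _NnapiSerializer(config=None)
#     (ser_model_ints, used_weights, inp_dim_orders, out_dim_orders,
#      shape_compute_lines, retval_count) = ser.serialize_model(traced, [example_input])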
20
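# NNAPI_OperandCode, NNAPI_OperationCode, and NNAPI_FuseCode mirror the
# corresponding ANeuralNetworks* enums from the NNAPI C headers
# (NeuralNetworks.h); the numeric values must stay in sync with them.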
21class NNAPI_OperandCode:
22    FLOAT32 = 0
23    INT32 = 1
24    UINT32 = 2
25    TENSOR_FLOAT32 = 3
26    TENSOR_INT32 = 4
27    TENSOR_QUANT8_ASYMM = 5
28    BOOL = 6
29    TENSOR_QUANT16_SYMM = 7
30    TENSOR_FLOAT16 = 8
31    TENSOR_BOOL8 = 9
32    FLOAT16 = 10
33    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11
34    TENSOR_QUANT16_ASYMM = 12
35
36
37class NNAPI_OperationCode:
38    ADD = 0
39    AVERAGE_POOL_2D = 1
40    CONCATENATION = 2
41    CONV_2D = 3
42    DEPTHWISE_CONV_2D = 4
43    DEPTH_TO_SPACE = 5
44    DEQUANTIZE = 6
45    EMBEDDING_LOOKUP = 7
46    FLOOR = 8
47    FULLY_CONNECTED = 9
48    HASHTABLE_LOOKUP = 10
49    L2_NORMALIZATION = 11
50    L2_POOL_2D = 12
51    LOCAL_RESPONSE_NORMALIZATION = 13
52    LOGISTIC = 14
53    LSH_PROJECTION = 15
54    LSTM = 16
55    MAX_POOL_2D = 17
56    MUL = 18
57    RELU = 19
58    RELU1 = 20
59    RELU6 = 21
60    RESHAPE = 22
61    RESIZE_BILINEAR = 23
62    RNN = 24
63    SOFTMAX = 25
64    SPACE_TO_DEPTH = 26
65    SVDF = 27
66    TANH = 28
67    BATCH_TO_SPACE_ND = 29
68    DIV = 30
69    MEAN = 31
70    PAD = 32
71    SPACE_TO_BATCH_ND = 33
72    SQUEEZE = 34
73    STRIDED_SLICE = 35
74    SUB = 36
75    TRANSPOSE = 37
76    ABS = 38
77    ARGMAX = 39
78    ARGMIN = 40
79    AXIS_ALIGNED_BBOX_TRANSFORM = 41
80    BIDIRECTIONAL_SEQUENCE_LSTM = 42
81    BIDIRECTIONAL_SEQUENCE_RNN = 43
82    BOX_WITH_NMS_LIMIT = 44
83    CAST = 45
84    CHANNEL_SHUFFLE = 46
85    DETECTION_POSTPROCESSING = 47
86    EQUAL = 48
87    EXP = 49
88    EXPAND_DIMS = 50
89    GATHER = 51
90    GENERATE_PROPOSALS = 52
91    GREATER = 53
92    GREATER_EQUAL = 54
93    GROUPED_CONV_2D = 55
94    HEATMAP_MAX_KEYPOINT = 56
95    INSTANCE_NORMALIZATION = 57
96    LESS = 58
97    LESS_EQUAL = 59
98    LOG = 60
99    LOGICAL_AND = 61
100    LOGICAL_NOT = 62
101    LOGICAL_OR = 63
102    LOG_SOFTMAX = 64
103    MAXIMUM = 65
104    MINIMUM = 66
105    NEG = 67
106    NOT_EQUAL = 68
107    PAD_V2 = 69
108    POW = 70
109    PRELU = 71
110    QUANTIZE = 72
111    QUANTIZED_16BIT_LSTM = 73
112    RANDOM_MULTINOMIAL = 74
113    REDUCE_ALL = 75
114    REDUCE_ANY = 76
115    REDUCE_MAX = 77
116    REDUCE_MIN = 78
117    REDUCE_PROD = 79
118    REDUCE_SUM = 80
119    ROI_ALIGN = 81
120    ROI_POOLING = 82
121    RSQRT = 83
122    SELECT = 84
123    SIN = 85
124    SLICE = 86
125    SPLIT = 87
126    SQRT = 88
127    TILE = 89
128    TOPK_V2 = 90
129    TRANSPOSE_CONV_2D = 91
130    UNIDIRECTIONAL_SEQUENCE_LSTM = 92
131    UNIDIRECTIONAL_SEQUENCE_RNN = 93
132    RESIZE_NEAREST_NEIGHBOR = 94
133
134
135class NNAPI_FuseCode:
136    FUSED_NONE = 0
137    FUSED_RELU = 1
138    FUSED_RELU1 = 2
139    FUSED_RELU6 = 3
140
141
142class OperandValueSourceType:
143    IMMEDIATE = 0
144    NUMBERED_BUFFER = 2
145    NUMBERED_MEMORY = 3
146
147
148# Scalar types that appear explicitly in models.
149# These must be kept in sync with
150# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
151# TODO: Expose these directly to Python to avoid maintaining this list.
152class TorchScalarTypes(enum.Enum):
153    QUINT8 = 13
154
155
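# Approximate equality with a relative tolerance (scaled by the smaller of the two values).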
156def approx_equal(lhs, rhs, tolerance=1e-6):
157    return abs(lhs - rhs) <= tolerance * min(lhs, rhs)
158
159
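# Size in bytes of a tensor operand with the given NNAPI operand type and dimensions.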
160def tensor_size(op_type, dims):
161    ITEM_SIZES = {
162        NNAPI_OperandCode.TENSOR_FLOAT32: 4,
163        NNAPI_OperandCode.TENSOR_INT32: 4,
164        NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1,
165        NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2,
166        NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2,
167    }
168    size = ITEM_SIZES[op_type]
169    for d in dims:
170        size *= d
171    return size
172
173
174def change_element(tup, index, value):
175    ls = list(tup)
176    ls[index] = value
177    return tuple(ls)
178
179
180class ConvPoolArgs2d(NamedTuple):
181    """Configuration arguments for a 2D convolution or pooling operation."""
182
183    kernel_h: int
184    kernel_w: int
185    stride_h: int
186    stride_w: int
187    pad_t: int
188    pad_b: int
189    pad_l: int
190    pad_r: int
191    dilation_h: int
192    dilation_w: int
193    group: int
194
195
196class DimOrder(enum.Enum):
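    # PRESUMED_CONTIGUOUS: NNAPI operand uses the same contiguous (NCHW) layout as PyTorch.
    # CHANNELS_LAST: NNAPI operand is stored as NHWC while the tracked PyTorch shape stays NCHW.
    # SCALAR_OR_VECTOR: rank-0 or rank-1 immediates, where layout is irrelevant.
    # UNKNOWN_CONSTANT: constant tensors (e.g. weights) whose layout is not yet fixed.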
197    PRESUMED_CONTIGUOUS = 0
198    CHANNELS_LAST = 1
199    SCALAR_OR_VECTOR = 2
200    UNKNOWN_CONSTANT = 999
201
202
203class Operand(NamedTuple):
204    """Representation of an NNAPI operand."""
205
206    # NNAPI operand type.  One of NNAPI_OperandCode.
207    # TODO: Make this an enum.
208    op_type: int
209
210    # This is always the PyTorch shape, which is NCHW for feature maps.
211    # The actual NNAPI operand might have a transposed shape.
212    # We use 0 for load-time dynamic shapes and -1 for runtime dynamic shapes.
213    shape: Tuple[int, ...]
214
215    # Specifies how the shape of the operand that we define in NNAPI
216    # relates to the shape we track above.
217    # - PRESUMED_CONTIGUOUS: physical NNAPI operand will exactly match
218    #   the shape of the PyTorch tensor.
219    # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and
220    #   the NNAPI operand will be represented explicitly as NHWC.
221    dim_order: DimOrder
222
223    # Quantization params
224    scale: float
225    zero_point: int
226
227    def use_nchw(self):
228        if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS:
229            return True
230        if self.dim_order is DimOrder.CHANNELS_LAST:
231            return False
232        raise Exception("Unknown dim order")  # noqa: TRY002
233
234
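# Elementwise broadcast of two equal-rank shapes,
# e.g. (1, 8, 1, 1) and (1, 8, 32, 32) -> (1, 8, 32, 32).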
235def broadcast_shapes(shape1, shape2):
236    assert len(shape1) > 0
237    assert len(shape2) > 0
238    s1 = list(shape1)
239    s2 = list(shape2)
240    # TODO: Support non-equal-rank broadcast where semantics match.
241    # This can be tricky for NHWC tensors because dimension orders
242    # don't match between PT and NNAPI, even though semantics match.
243    if len(s1) > len(s2):
244        # s2 = [1] * (len(s1) - len(s2)) + s2
245        raise Exception(  # noqa: TRY002
246            "Non-equal-rank broadcast is not supported yet."
247        )  # noqa: TRY002
248    if len(s2) > len(s1):
249        # s3 = [1] * (len(s2) - len(s1)) + s1
250        raise Exception(  # noqa: TRY002
251            "Non-equal-rank broadcast is not supported yet."
252        )  # noqa: TRY002
253    ret = []
254    for d1, d2 in zip(s1, s2):
255        if d1 == 1:
256            ret.append(d2)
257        elif d2 == 1:
258            ret.append(d1)
259        elif d1 == d2:
260            ret.append(d1)
261        else:
262            raise Exception(  # noqa: TRY002
263                f"Cannot broadcast shapes: {shape1} and {shape2}"
264            )  # noqa: TRY002
265    return tuple(ret)
266
267
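# Compute the NCHW output shape of a 2D convolution/pooling (or transposed
# convolution) from the NCHW input shape and a ConvPoolArgs2d.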
268def get_conv_pool_shape(image_shape, args, out_ch, transpose):
269    batch, in_c, in_h, in_w = image_shape
270
271    # TODO: Handle dilation
272    if args.dilation_h != 1 or args.dilation_w != 1:
273        raise Exception("Dilation not supported yet.")  # noqa: TRY002
274
275    if transpose:
276        out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b
277        out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_r
278    else:
279        out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1
280        out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1
281
282    # Handle variable-sized tensors.
283    if in_h == 0:
284        out_h = 0
285    if in_w == 0:
286        out_w = 0
287
288    out_shape = (batch, out_ch, out_h, out_w)
289    return out_shape
290
291
292def fix_shape(shape, dim_order):
293    # Return the actual shape that an operand should have in NNAPI,
294    # given a PyTorch shape and dimension order.  This is where we
295    # convert from PyTorch's "always NCHW" shape to explicit NHWC.
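    #   e.g. fix_shape((1, 3, 224, 224), DimOrder.CHANNELS_LAST) == (1, 224, 224, 3)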
296    if dim_order is DimOrder.PRESUMED_CONTIGUOUS:
297        return shape
298    if dim_order is DimOrder.CHANNELS_LAST:
299        return tuple([shape[0]] + list(shape[2:]) + [shape[1]])
300    if dim_order is DimOrder.SCALAR_OR_VECTOR:
301        assert len(shape) == 0 or len(shape) == 1
302        return shape
303    if dim_order is DimOrder.UNKNOWN_CONSTANT:
304        # XXX think this through
305        return shape
306    raise Exception(f"Bad dim_order: {dim_order!r}.")  # noqa: TRY002
307
308
309def reverse_map_dim(dim_order, d):
310    # Return the original PyTorch dimension position for a given dimension.
311    # d should be the dimension that NNAPI will see.
312    # reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x
313    # reverse_map_dim(CHANNELS_LAST, 3) == 1
314    if dim_order in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR):
315        return d
316    assert dim_order is DimOrder.CHANNELS_LAST
317    return [0, 2, 3, 1][d]
318
319
320def flex_name(op_id, dim):
321    # Return the local variable name for the computed flexible size
322    # for a given op and dimension.
323    return f"s_{op_id}_{dim}"
324
325
326class _NnapiSerializer:
327    def __init__(self, config, use_int16_for_qint16=False):
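        # Serialization state.  operands/operations mirror the structures in the
        # serialized model; values and value_data describe constant operand payloads;
        # operation_args is the flattened list of per-operation input/output operand ids;
        # flexible_shape_computation_lines collects generated code for load-time shapes.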
328        self.operands = []
329        self.values = []
330        self.operations = []
331        self.value_data = []
332        self.operation_args = []
333        self.inputs = []
334        self.outputs = []
335        self.flexible_shape_computation_lines = []
336
337        self.modules = {}
338        self.constants = {}
339        self.tensor_sequences = {}
340        self.jitval_operand_map = {}
341        self.cached_immediates = {}
342        self.used_weights = []
343        self.weight_offset = 0
344        self.use_int16_for_qint16 = use_int16_for_qint16
345
346        if config is None:
347            config = {}
348
349    def get_next_operand_id(self):
350        return len(self.operands)
351
352    # Add a tensor operand corresponding to a JIT Value.
353    # Returns the NNAPI operand ID.  Can be looked up later with
354    # get_tensor_operand_by_jitval.
355    def add_tensor_operand(self, jitval, oper):
356        assert isinstance(oper, Operand)
357        if jitval in self.jitval_operand_map:
358            raise Exception(f"Duplicate tensor: {jitval!r}")  # noqa: TRY002
359
360        operand_id = self.get_next_operand_id()
361        self.operands.append(oper)
362        self.jitval_operand_map[jitval] = operand_id
363        return operand_id
364
365    # Add a tensor operand that does not correspond to a JIT Value.
366    # Useful for cases where multiple NNAPI operands are required
367    # to implement one JIT IR node.  Returns the NNAPI operand ID.
368    def add_anonymous_tensor_operand(self, oper):
369        assert isinstance(oper, Operand)
370        operand_id = self.get_next_operand_id()
371        self.operands.append(oper)
372        return operand_id
373
374    def torch_tensor_to_operand(self, tensor, dim_order):
375        dtype = str(tensor.dtype).replace("torch.", "")
376        scale = 0.0
377        zero_point = 0
378        if dtype == "float32":
379            op_type = NNAPI_OperandCode.TENSOR_FLOAT32
380        elif dtype == "int32":
381            op_type = NNAPI_OperandCode.TENSOR_INT32
382        elif dtype == "quint8":
383            op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
384            scale = tensor.q_scale()
385            zero_point = tensor.q_zero_point()
386        elif dtype == "qint32":
387            op_type = NNAPI_OperandCode.TENSOR_INT32
388            scale = tensor.q_scale()
389            zero_point = tensor.q_zero_point()
390            assert zero_point == 0
391        elif dtype == "int16":
392            if self.use_int16_for_qint16:
393                nnapi_dtype = getattr(tensor, "nnapi_dtype", None)
394                op_codes = (
395                    NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
396                    NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
397                )
398                if nnapi_dtype in op_codes:
399                    op_type = nnapi_dtype
400                    scale = tensor.nnapi_scale
401                    zero_point = tensor.nnapi_zero_point
402                else:
403                    raise Exception(  # noqa: TRY002
404                        f"`nnapi_type` needs to be one of {op_codes} for `int16`"
405                    )
406            else:
407                raise Exception(  # noqa: TRY002
408                    "`int16` isn't supported. If you're trying to represent NNAPI"
409                    " qint16 with PyTorch int16, set `use_int16_for_qint16 = True`"
410                )
411        else:
412            raise Exception(  # noqa: TRY002
413                f"Can't handle input with dtype '{tensor.dtype}'"
414            )  # noqa: TRY002
415        return Operand(
416            shape=tuple(tensor.shape),
417            op_type=op_type,
418            dim_order=dim_order,
419            scale=scale,
420            zero_point=zero_point,
421        )
422
423    def add_tensor_operand_for_input(self, arg_idx, jitval, tensor):
424        dim_order = (
425            DimOrder.CHANNELS_LAST
426            if getattr(tensor, "nnapi_nhwc", False)
427            else DimOrder.PRESUMED_CONTIGUOUS
428        )
429        toper = self.torch_tensor_to_operand(tensor, dim_order)
430        operand_id = self.add_tensor_operand(jitval, toper)
431        self.inputs.append(operand_id)
432        for dim, size in enumerate(tensor.shape):
433            if size == 0:
434                self.compute_operand_shape(
435                    operand_id, dim, f"args[{arg_idx}].shape[{dim}]"
436                )
437        return operand_id
438
439    def add_tensor_operand_for_weight(
440        self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT
441    ):
442        toper = self.torch_tensor_to_operand(tensor, dim_order)
443        operand_id = len(self.operands)
444        self.operands.append(toper)
445        tsize = tensor_size(toper.op_type, toper.shape)
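        # psize is tsize rounded up to a multiple of 4 for alignment;
        # it is currently unused in this method.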
446        psize = ((tsize - 1) | 0x3) + 1
447        self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER))
448        buf_num = len(self.used_weights)
449        offset = 0
450        self.value_data.append(struct.pack("iii", buf_num, offset, tsize))
451        # For NHWC NNAPI op, lay out data in the same dim order by permuting torch tensor
452        if dim_order == DimOrder.CHANNELS_LAST:
453            tensor = tensor.permute(0, 2, 3, 1)
454        self.used_weights.append(tensor)
455        return operand_id
456
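    # Add (or reuse, via cached_immediates) a small constant operand whose bytes are
    # stored inline in the serialized model (OperandValueSourceType.IMMEDIATE).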
457    def add_immediate_operand(self, code, value, dims):
458        assert isinstance(dims, tuple)
459        cache_key = (code, value)
460        if cache_key not in self.cached_immediates:
461            operand_id = len(self.operands)
462            self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0))
463            self.values.append((operand_id, OperandValueSourceType.IMMEDIATE))
464            self.value_data.append(value)
465            self.cached_immediates[cache_key] = operand_id
466        return self.cached_immediates[cache_key]
467
468    def add_immediate_int_scalar(self, value):
469        return self.add_immediate_operand(
470            NNAPI_OperandCode.INT32, struct.pack("i", value), ()
471        )
472
473    def add_immediate_float_scalar(self, value):
474        return self.add_immediate_operand(
475            NNAPI_OperandCode.FLOAT32, struct.pack("f", value), ()
476        )
477
478    def add_immediate_bool_scalar(self, value):
479        return self.add_immediate_operand(
480            NNAPI_OperandCode.BOOL, b"\x01" if value else b"\x00", ()
481        )
482
483    def add_immediate_int_vector(self, value):
484        return self.add_immediate_operand(
485            NNAPI_OperandCode.TENSOR_INT32,
486            array.array("i", value).tobytes(),
487            (len(value),),
488        )
489
490    def has_operand_for_jitval(self, jitval):
491        return jitval in self.jitval_operand_map
492
493    def get_tensor_operand_by_jitval(self, jitval):
494        operand_id = self.jitval_operand_map[jitval]
495        return (operand_id, self.operands[operand_id])
496
497    def get_tensor_operand_by_jitval_fixed_size(self, jitval):
498        op_id, oper = self.get_tensor_operand_by_jitval(jitval)
499        for s in oper.shape:
500            if s == 0:
501                # TODO: Improve this error message, possibly after converting
502                # many callsites to support flexible size.
503                raise Exception(  # noqa: TRY002
504                    "Flexible size is not supported for this operand."
505                )  # noqa: TRY002
506            if s < 0:
507                # runtime flex
508                LOG.warning("Operand %s has runtime flex shape", oper)
509        return op_id, oper
510
511    def get_tensor_operand_or_constant(
512        self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS
513    ):
514        operand_id = self.jitval_operand_map.get(jitval)
515        if operand_id is None:
516            _, value = self.get_constant_value(jitval, "TensorType")
517            operand_id = self.add_tensor_operand_for_weight(value, dim_order)
518        return (operand_id, self.operands[operand_id])
519
520    def get_tensor_operand_for_weight(self, jitval):
521        _, value = self.get_constant_value(jitval, "TensorType")
522        operand_id = self.add_tensor_operand_for_weight(value)
523        return (operand_id, self.operands[operand_id])
524
525    def add_operation(self, opcode, inputs, outputs):
526        self.operations.append((opcode, len(inputs), len(outputs)))
527        self.operation_args.extend(inputs + outputs)
528
529    def add_tensor_sequence(self, jitval, values):
530        assert jitval not in self.tensor_sequences
531        self.tensor_sequences[jitval] = values
532
533    def add_constant_value(self, jitval, ctype, value):
534        assert jitval not in self.constants
535        self.constants[jitval] = (ctype, value)
536
537    def get_constant_value(self, jitval, typekind=None):
538        record = self.constants.get(jitval)
539        if record is None:
540            raise Exception(  # noqa: TRY002
541                f"Could not find constant value for '{jitval!r}'."
542            )  # noqa: TRY002
543        ctype, _ = record
544        if typekind is not None and ctype.kind() != typekind:
545            raise Exception(  # noqa: TRY002
546                f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'"
547            )
548        return record
549
550    def operand_to_template_torchscript(self, op_id, oper, shape=None):
551        """Return a TorchScript expression to build a template for a given operand."""
552        if shape is None:
553            shape = oper.shape
554        else:
555            assert len(shape) == len(oper.shape)
556
557        shape_parts = ["("]
558        for d, s in enumerate(shape):
559            if s > 0:
560                # Fixed shape dimension: just add the value.
561                shape_parts.append(str(s))
562            elif s == 0:
563                # Load time flexible shape dimension: it should have been computed in a variable.
564                shape_parts.append(flex_name(op_id, d))
565            elif s == -1:
566                # Runtime flexible shape
567                shape_parts.append("0")
568            else:
569                raise Exception(  # noqa: TRY002
570                    "Unknown dim value, dimensions should be >= -1"
571                )  # noqa: TRY002
572            shape_parts.append(",")
573        shape_parts.append(")")
574        shape_code = "".join(shape_parts)
575        if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
576            return f"torch.zeros({shape_code}, dtype=torch.float32)"
577        elif oper.op_type == NNAPI_OperandCode.TENSOR_INT32:
578            return f"torch.zeros({shape_code}, dtype=torch.int32)"
579        elif oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
580            return (
581                f"torch.quantize_per_tensor("
582                f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)"
583                f".expand({shape_code}).contiguous()"
584            )
585        elif oper.op_type in (
586            NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
587            NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
588        ):
589            if self.use_int16_for_qint16:
590                return f"torch.zeros({shape_code}, dtype=torch.int16)"
591            else:
592                raise Exception(  # noqa: TRY002
593                    "`int16` isn't supported. If you're trying to represent NNAPI"
594                    " qint16 with PyTorch int16, set `use_int16_for_qint16 = True`"
595                )
596
597        raise Exception(  # noqa: TRY002
598            f"Unsupported output operand type: {oper.op_type}"
599        )  # noqa: TRY002
600
601    def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim):
602        self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim))
603
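    # Record a line of generated code that computes a load-time flexible dimension
    # (referenced later via flex_name()).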
604    def compute_operand_shape(self, op_id, dim, expr):
605        self.flexible_shape_computation_lines.append(
606            f"{flex_name(op_id, dim)} = {expr}"
607        )
608
609    def transpose_to_nhwc(self, in_id, oper):
610        if oper.shape[2:] != (1, 1):
611            raise Exception(  # noqa: TRY002
612                "Automatic transpose only supported for H,W == 1,1"
613            )  # noqa: TRY002
614
615        out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST)
616
617        inputs = [None] * 2
618        inputs[0] = in_id
619        inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1])
620
621        outputs = [None] * 1
622        outputs[0] = self.add_anonymous_tensor_operand(out_oper)
623
624        self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs)
625
626        return outputs[0], out_oper
627
628    # Transpose inputs as necessary to allow broadcasting.
629    def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper):
630        if in0_oper.dim_order == in1_oper.dim_order:
631            return in0_id, in0_oper, in1_id, in1_oper
632
633        # Assume NHWC is preferred if there is a mismatch.
634        orders = (in0_oper.dim_order, in1_oper.dim_order)
635        if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST):
636            return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper)
637        if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS):
638            return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper)
639
640        raise Exception(  # noqa: TRY002
641            f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}"
642        )
643
644    def get_size_arg(self, jitval):
645        ctype, value = self.get_constant_value(jitval)
646        if ctype.kind() == "ListType":
647            assert ctype.getElementType().kind() == "IntType"
648            return value
649        raise Exception(  # noqa: TRY002
650            f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'"
651        )  # noqa: TRY002
652
653    def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config):
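        # Inferred packed_config layout for 2-D convs (from the asserts below):
        # [ndim, stride_h, stride_w, pad_h, pad_w, dil_h, dil_w,
        #  out_pad_h, out_pad_w, groups, <unused here>]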
654        pc = [i.item() for i in packed_config]
655        assert pc[0] == 2
656        strides = [pc[1], pc[2]]
657        paddings = [pc[3], pc[4]]
658        dilations = [pc[5], pc[6]]
659        output_padding = [pc[7], pc[8]]
660        group_num = pc[9]
661
662        assert len(pc) == 11
663        assert output_padding == [0, 0]
664
665        return self.get_conv_pool_args_2d_common(
666            kernel_size, strides, paddings, dilations, group_num
667        )
668
669    def get_conv_pool_args_2d_from_jit(
670        self, kernel_size, stride, padding, dilation=None, group=None
671    ):
672        strides = self.get_size_arg(stride)
673        paddings = self.get_size_arg(padding)
674        if dilation is None:
675            dilations = [1, 1]
676        else:
677            dilations = self.get_size_arg(dilation)
678        if group is not None:
679            _, group_num = self.get_constant_value(group, "IntType")
680        else:
681            group_num = None
682        return self.get_conv_pool_args_2d_common(
683            kernel_size, strides, paddings, dilations, group_num
684        )
685
686    def get_conv_pool_args_2d_common(
687        self, kernel_size, strides, paddings, dilations, group_num
688    ):
689        kernels = list(kernel_size)
690
691        assert len(kernels) == 2
692        assert len(strides) == 2
693        assert len(paddings) == 2
694        assert len(dilations) == 2
695
696        # NNAPI uses 4 values for padding.
697        ph, pw = paddings
698        real_paddings = [ph, ph, pw, pw]
699
700        return ConvPoolArgs2d(
701            *(kernels + strides + real_paddings + dilations + [group_num])
702        )
703
704    def serialize_model(self, model, inputs, return_shapes=None):
705        self.add_immediate_bool_scalar(False)
706        self.add_immediate_bool_scalar(True)
707
708        inp_dim_orders = []
709        out_dim_orders = []
710
711        self_jitval = next(model.graph.inputs())
712        self.add_constant_value(self_jitval, self_jitval.type(), model)
713
714        for arg_idx, (input_value, input_tensor) in enumerate(
715            zip(list(model.graph.inputs())[1:], inputs)
716        ):
717            op_id = self.add_tensor_operand_for_input(
718                arg_idx, input_value, input_tensor
719            )
720            inp_dim_orders.append(self.operands[op_id].dim_order.value)
721
722        for idx, node in enumerate(model.graph.nodes()):
723            LOG.debug("Processing node #%d: %r", idx, node)
724            self.add_node(node)
725
726        retn = model.graph.return_node()
727        assert retn.inputsSize() == 1
728        assert retn.outputsSize() == 0
729        retn_input = retn.inputsAt(0)
730        template_return_lines = ["return ["]
731        if retn_input.type().kind() == "TensorType":
732            return_values = [retn_input]
733            retval_count = -1
734        elif retn_input.type().kind() == "TupleType":
735            return_values = self.tensor_sequences[retn_input]
736            retval_count = len(return_values)
737        else:
738            raise Exception(  # noqa: TRY002
739                f"Unsupported return type: {retn_input.type()}"
740            )  # noqa: TRY002
741
742        if return_shapes is not None:
743            assert len(return_shapes) == len(return_values)
744        for i, v in enumerate(return_values):
745            op_id = self.jitval_operand_map[v]
746            self.outputs.append(op_id)
747            out_dim_orders.append(self.operands[op_id].dim_order.value)
748            shape = return_shapes[i] if return_shapes else None
749            template_return_lines.append(
750                self.operand_to_template_torchscript(op_id, self.operands[op_id], shape)
751                + ","
752            )
753        template_return_lines.append("]")
754
755        model = []
756
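        # Serialized layout: header, operand descriptors, value descriptors,
        # operation descriptors, per-operand shapes, value payloads,
        # operation args, model inputs, model outputs (everything 4-byte aligned).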
757        version = 1
758        header = struct.pack(
759            "iiiiii",
760            version,
761            len(self.operands),
762            len(self.values),
763            len(self.operations),
764            len(self.inputs),
765            len(self.outputs),
766        )
767        model.append(header)
768
769        serialized_values, serialized_value_data = self.serialize_values()
770
771        model.extend(
772            struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands
773        )
774        model.extend(serialized_values)
775        model.extend(struct.pack("iii", *x) for x in self.operations)
776
777        # Compact the model so we can get its length so far.
778        model = [b"".join(model)]
779        model_offset = len(model[0])
780        # model_offset is the index into the model (in 32-bit words, not bytes) of
781        # the next dimension to be serialized.  For 0-valued (load-time flexible)
782        # dims, emit code that patches ser_model in place before it reaches NNAPI.
783        assert model_offset % 4 == 0
784        model_offset = int(model_offset / 4)
785
786        for op_id, (_, dims, dim_order, _, _) in enumerate(self.operands):
787            shape = fix_shape(dims, dim_order)
788            for d, s in enumerate(shape):
789                if s == 0:
790                    pt_d = reverse_map_dim(dim_order, d)
791                    self.flexible_shape_computation_lines.append(
792                        f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}"
793                    )
794                model_offset += 1
795
796            # convert runtime flex shape from -1 to 0
797            shape = tuple(d if d != -1 else 0 for d in shape)
798            model.append(self.serialize_ints(shape))
799
800        model.extend(serialized_value_data)
801        model.append(self.serialize_ints(self.operation_args))
802        model.append(self.serialize_ints(self.inputs))
803        model.append(self.serialize_ints(self.outputs))
804
805        self.flexible_shape_computation_lines.extend(template_return_lines)
806
807        return (
808            array.array("i", b"".join(model)),
809            self.used_weights,
810            inp_dim_orders,
811            out_dim_orders,
812            self.flexible_shape_computation_lines,
813            retval_count,
814        )
815
816    def serialize_values(self):
817        serialized_values = []
818        serialized_value_data = []
819        assert len(self.values) == len(self.value_data)
820        for (op_index, source_type), data in zip(self.values, self.value_data):
821            source_length = len(data)
822
823            # Pad with 0 bytes out to a multiple of 4 for alignment.
824            physical_length = ((source_length - 1) | 0x3) + 1
825            padded_data = data + (b"\0" * (physical_length - source_length))
826
827            serialized_values.append(
828                struct.pack("iii", op_index, source_type, source_length)
829            )
830            serialized_value_data.append(padded_data)
831
832        return serialized_values, serialized_value_data
833
834    @staticmethod
835    def serialize_ints(ints):
836        return array.array("i", ints).tobytes()
837
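    # Dispatch table from TorchScript node kind to the handler that serializes it.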
838    ADDER_MAP = {
839        "prim::GetAttr": lambda self, node: self.add_getattr(node),
840        "prim::Constant": lambda self, node: self.add_constant_node(node),
841        "prim::ListConstruct": lambda self, node: self.add_list_construct(node),
842        "prim::TupleConstruct": lambda self, node: self.add_tuple_construct(node),
843        "aten::unsqueeze": lambda self, node: self.add_unsqueeze(node),
844        "aten::to": lambda self, node: self.add_to(node),
845        "aten::detach": lambda self, node: self._identity(node),
846        "aten::reshape": lambda self, node: self.add_reshape(node),
847        "aten::flatten": lambda self, node: self.add_flatten(node),
848        "aten::slice": lambda self, node: self.add_slice(node),
849        "aten::size": lambda self, node: self.add_size(node),
850        "aten::cat": lambda self, node: self.add_cat(node),
851        "aten::mean": lambda self, node: self.add_mean(node),
852        "aten::quantize_per_tensor": lambda self, node: self.add_quantize(node),
853        "aten::dequantize": lambda self, node: self.add_dequantize(node),
854        "aten::add": lambda self, node: self.add_add_sub_op(
855            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
856        ),
857        "aten::sub": lambda self, node: self.add_add_sub_op(
858            node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE
859        ),
860        "aten::mul": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
861            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
862        ),
863        "aten::div": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
864            node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE
865        ),
866        "aten::relu": lambda self, node: self.add_pointwise_simple_unary_op(
867            node, NNAPI_OperationCode.RELU
868        ),
869        "aten::sigmoid": lambda self, node: self.add_pointwise_simple_unary_op(
870            node, NNAPI_OperationCode.LOGISTIC
871        ),
872        "aten::softmax": lambda self, node: self.add_softmax(node),
873        "aten::hardtanh": lambda self, node: self.add_hardtanh(node),
874        "aten::avg_pool2d": lambda self, node: self.add_avg_pool2d(node),
875        "aten::max_pool2d": lambda self, node: self.add_pool2d_node(
876            node, NNAPI_OperationCode.MAX_POOL_2D
877        ),
878        "aten::adaptive_avg_pool2d": lambda self, node: self.add_adaptive_avg_pool2d(
879            node
880        ),
881        "aten::upsample_nearest2d": lambda self, node: self.add_upsample_nearest2d(
882            node
883        ),
884        "aten::prelu": lambda self, node: self.add_prelu_op(node),
885        "aten::addmm": lambda self, node: self.add_addmm(node),
886        "aten::linear": lambda self, node: self.add_linear(node),
887        "aten::_convolution": lambda self, node: self.add_conv_underscore(node),
888        "aten::conv2d": lambda self, node: self.add_conv2d(node),
889        "aten::log_softmax": lambda self, node: self.add_log_softmax(node),
890        "quantized::linear": lambda self, node: self.add_qlinear(node),
891        "quantized::conv2d": lambda self, node: self.add_qconv2d(
892            node, NNAPI_FuseCode.FUSED_NONE
893        ),
894        "quantized::conv2d_relu": lambda self, node: self.add_qconv2d(
895            node, NNAPI_FuseCode.FUSED_RELU
896        ),
897        "quantized::conv_transpose2d": lambda self, node: self.add_qconv2d(
898            node, NNAPI_FuseCode.FUSED_NONE, transpose=True
899        ),
900        "quantized::add": lambda self, node: self.add_qadd(
901            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
902        ),
903        "quantized::add_relu": lambda self, node: self.add_qadd(
904            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU
905        ),
906        "quantized::mul": lambda self, node: self.add_qadd(
907            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
908        ),
909    }
910
911    def add_node(self, node):
912        adder = self.ADDER_MAP.get(node.kind())
913        if not adder:
914            raise Exception(  # noqa: TRY002
915                f"Unsupported node kind ({node.kind()!r}) in node {node!r}"
916            )  # noqa: TRY002
917        adder(self, node)
918
919    def _identity(self, node):
920        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
921        jitval = node.outputsAt(0)
922        self.jitval_operand_map[jitval] = in_id
923
924    def add_getattr(self, node):
925        assert node.inputsSize() == 1
926        assert node.outputsSize() == 1
927        obj_ctype, obj = self.get_constant_value(node.inputsAt(0))
928        assert str(obj_ctype).startswith("__torch__.")
929        name = node.s("name")
930        value = getattr(obj, name)
931        output = node.outputsAt(0)
932        ctype = output.type()
933        self.add_constant_value(output, ctype, value)
934
935    def add_constant_node(self, node):
936        assert node.inputsSize() == 0
937        assert node.outputsSize() == 1
938        output = node.outputsAt(0)
939        ctype = output.type()
940        value = output.toIValue()
941        self.add_constant_value(output, ctype, value)
942
943    def add_list_construct(self, node):
944        assert node.outputsSize() == 1
945        output = node.outputsAt(0)
946        ctype = output.type()
947        const_vals: Optional[List] = []
948        tensors: Optional[List] = []
949        for inp in node.inputs():
950            if const_vals is not None and inp in self.constants:
951                _, val = self.get_constant_value(inp)
952                const_vals.append(val)
953            else:
954                const_vals = None
955            if tensors is not None and inp.type().kind() == "TensorType":
956                tensors.append(inp)
957            else:
958                tensors = None
959
960        if const_vals is not None:
961            # NOTE: Now that TorchScript supports list constants,
962            # this code path might not be used anymore.
963            self.add_constant_value(output, ctype, const_vals)
964        if tensors is not None:
965            self.add_tensor_sequence(output, tensors)
966        if const_vals is None and tensors is None:
967            raise Exception(  # noqa: TRY002
968                f"Unable to handle ListConstruct node.  Neither all constants nor all tensors. {node!r}"
969            )
970
971    def add_tuple_construct(self, node):
972        assert node.outputsSize() == 1
973        output = node.outputsAt(0)
974        values = list(node.inputs())
975        self.add_tensor_sequence(output, values)
976
977    def add_unsqueeze(self, node):
978        assert node.inputsSize() == 2
979        assert node.outputsSize() == 1
980
981        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
982
983        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
984        assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS
985
986        real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1
987        out_shape_list = list(in_oper.shape)
988        out_shape_list.insert(real_dim, 1)
989        out_shape = tuple(out_shape_list)
990        out_oper = in_oper._replace(shape=out_shape)
991
992        inputs = [None] * 2
993        inputs[0] = in_id
994        inputs[1] = self.add_immediate_int_scalar(dim)
995
996        outputs = [None] * 1
997        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
998
999        self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs)
1000
1001    def add_to(self, node):
1002        # Handle to("cpu") / to("gpu") case
1003        self._identity(node)
1004
1005    def add_reshape(self, node):
1006        assert node.inputsSize() == 2
1007        assert node.outputsSize() == 1
1008
1009        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
1010
1011        shape_ctype, shape = self.get_constant_value(node.inputsAt(1))
1012        assert shape_ctype.kind() == "ListType"
1013        assert shape_ctype.getElementType().kind() == "IntType"
1014        is_trivial_reshape = len(shape) == 2 and shape[1] == -1
1015
1016        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape:
1017            raise Exception(  # noqa: TRY002
1018                "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1]."
1019            )
1020
1021        # Bit of a hack here.  Use a real tensor to infer the output shape.
1022        out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape
1023        out_oper = in_oper._replace(
1024            shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
1025        )
1026
1027        inputs = [None] * 2
1028        inputs[0] = in_id
1029        inputs[1] = self.add_immediate_int_vector(shape)
1030
1031        outputs = [None] * 1
1032        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
1033
1034        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)
1035
1036    def add_flatten(self, node):
1037        assert node.inputsSize() == 3
1038        assert node.outputsSize() == 1
1039
1040        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
1041
1042        start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
1043        end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")
1044
1045        # channels last with channels == 1 or (height & width both 1)
1046        is_trivial_flatten = len(in_oper.shape) == 4 and (
1047            in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1)
1048        )
1049        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten:
1050            raise Exception(  # noqa: TRY002
1051                "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1"
1052            )
1053
1054        if start_dim < 0:
1055            start_dim += len(in_oper.shape)
1056        if end_dim < 0:
1057            end_dim += len(in_oper.shape)
1058
1059        out_shape = (
1060            in_oper.shape[:start_dim]
1061            + (functools.reduce(operator.mul, in_oper.shape[start_dim : end_dim + 1]),)
1062            + in_oper.shape[end_dim + 1 :]
1063        )
1064
1065        if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]):
1066            raise Exception(  # noqa: TRY002
1067                "Flattening flexible dims is not supported yet"
1068            )  # noqa: TRY002
1069        non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :]
1070        if non_flattened_dims.count(0) > 1:
1071            raise Exception("Only 1 dim can be flexible")  # noqa: TRY002
1072
1073        out_oper = in_oper._replace(
1074            shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
1075        )
1076        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
1077
1078        for idx, dim in enumerate(out_shape):
1079            if dim == 0:
1080                self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0))
1081
1082        inputs_1 = tuple(dim if dim != 0 else -1 for dim in out_shape)
1083        inputs = [None] * 2
1084        inputs[0] = in_id
1085        inputs[1] = self.add_immediate_int_vector(inputs_1)
1086
1087        outputs = [None] * 1
1088        outputs[0] = out_id
1089
1090        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)
1091
1092    def add_slice(self, node):
1093        assert node.inputsSize() == 5
1094        assert node.outputsSize() == 1
1095
1096        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
1097        _, dim_value = self.get_constant_value(node.inputsAt(1))
1098        _, start_value = self.get_constant_value(node.inputsAt(2))
1099        _, stop_value = self.get_constant_value(node.inputsAt(3))
1100        _, step_value = self.get_constant_value(node.inputsAt(4))
1101
1102        if start_value is None:
1103            start_value = 0
1104        if stop_value is None:
1105            stop_value = sys.maxsize
1106
1107        if start_value < 0:
1108            start_value += in_oper.shape[dim_value]
1109        elif start_value == sys.maxsize:
1110            start_value = 0
1111
1112        if start_value == 0 and stop_value == sys.maxsize:
1113            self._identity(node)
1114            return
1115
1116        if in_oper.shape[dim_value] == 0:
1117            raise Exception("Unable to slice with flexible shape")  # noqa: TRY002
1118
1119        if stop_value < 0:
1120            stop_value += in_oper.shape[dim_value]
1121        elif stop_value == sys.maxsize:
1122            stop_value = in_oper.shape[dim_value]
1123
1124        if start_value >= stop_value:
1125            raise Exception(  # noqa: TRY002
1126                "Slice start value should be less than stop value"
1127            )  # noqa: TRY002
1128
1129        out_len = (stop_value - start_value) // step_value
1130        out_shape = tuple(
1131            out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape)
1132        )
1133        out_id = self.add_tensor_operand(
1134            node.outputsAt(0), in_oper._replace(shape=out_shape)
1135        )
1136
1137        # flex inputs
1138        end_mask = 0
1139        for idx, dim in enumerate(out_shape):
1140            if dim == 0:
1141                self.forward_operand_shape(out_id, idx, in_id, idx)
1142                end_mask |= 1 << idx
1143
1144        inputs = [None] * 7
1145        inputs[0] = in_id
1146        inputs[1] = self.add_immediate_int_vector(
1147            [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))]
1148        )
1149        inputs[2] = self.add_immediate_int_vector(
1150            [
1151                stop_value if i == dim_value else dim
1152                for i, dim in enumerate(in_oper.shape)
1153            ]
1154        )
1155        inputs[3] = self.add_immediate_int_vector(
1156            [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))]
1157        )
1158        inputs[4] = self.add_immediate_int_scalar(0)  # begin mask
1159        inputs[5] = self.add_immediate_int_scalar(end_mask)
1160        inputs[6] = self.add_immediate_int_scalar(0)  # shrink axis mask
1161
1162        outputs = [None] * 1
1163        outputs[0] = out_id
1164
1165        self.add_operation(NNAPI_OperationCode.STRIDED_SLICE, inputs, outputs)
1166
1167    def add_size(self, node):
1168        assert node.inputsSize() == 2
1169        assert node.outputsSize() == 1
1170
1171        _, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
1172        _, value = self.constants[node.inputsAt(1)]
1173        res = in_oper.shape[value]
1174        output = node.outputsAt(0)
1175        self.add_constant_value(output, output.type(), res)
1176
1177    def add_cat(self, node):
1178        assert node.inputsSize() == 2
1179        assert node.outputsSize() == 1
1180
1181        tensors = self.tensor_sequences[node.inputsAt(0)]
1182        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
1183
1184        assert len(tensors) > 0
1185        in_ids = []
1186        out_oper = None
1187        out_dim_size = 0
1188        for inp in tensors:
1189            in_id, in_oper = self.get_tensor_operand_by_jitval(inp)
1190            if out_oper is None:
1191                out_shape = change_element(in_oper.shape, dim, -1)
1192                out_oper = in_oper._replace(shape=out_shape)
1193            assert in_oper.op_type == out_oper.op_type
1194            assert in_oper.dim_order == out_oper.dim_order
1195            assert change_element(in_oper.shape, dim, -1) == change_element(
1196                out_oper.shape, dim, -1
1197            )
1198            # TODO: Possibly check scale and zero point.
1199            in_ids.append(in_id)
1200            # TODO: Possibly support variable-sized inputs.
1201            out_dim_size += in_oper.shape[dim]
1202
1203        assert out_oper is not None
1204        out_oper = out_oper._replace(
1205            shape=change_element(out_oper.shape, dim, out_dim_size)
1206        )
1207
1208        if in_oper.dim_order == DimOrder.CHANNELS_LAST:  # type: ignore[possibly-undefined]
1209            assert len(out_oper.shape) == 4
1210            nnapi_dim = [0, 3, 1, 2][dim]
1211        else:
1212            nnapi_dim = dim
1213
1214        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
1215        for idx, d in enumerate(out_oper.shape):
1216            if d == 0:
1217                if idx == dim:
1218                    shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids)
1219                    self.compute_operand_shape(out_id, idx, shape)
1220                else:
1221                    self.forward_operand_shape(out_id, idx, in_ids[0], idx)
1222
1223        inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)]
1224
1225        outputs = [None] * 1
1226        outputs[0] = out_id
1227
1228        self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs)
1229
1230    def add_mean(self, node):
1231        assert node.inputsSize() == 4
1232        assert node.outputsSize() == 1
1233
1234        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
1235        dim_ctype, dim = self.get_constant_value(node.inputsAt(1))
1236        assert dim_ctype.kind() == "ListType"
1237        assert dim_ctype.getElementType().kind() == "IntType"
1238        _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType")
1239        # Expect None for dtype
1240        self.get_constant_value(node.inputsAt(3), "NoneType")
1241
1242        if in_oper.dim_order == DimOrder.CHANNELS_LAST:
1243            assert len(in_oper.shape) == 4
1244            nnapi_dim = [[0, 3, 1, 2][d] for d in dim]
1245        else:
1246            nnapi_dim = dim
1247
1248        collapsed_dims = set()
1249        for d in dim:
1250            if d < 0:
1251                d += len(in_oper.shape)
1252            collapsed_dims.add(d)
1253
1254        if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim:
1255            assert collapsed_dims.issuperset({2, 3})
1256            out_dim_order = DimOrder.PRESUMED_CONTIGUOUS
1257        else:
1258            out_dim_order = in_oper.dim_order
1259
1260        out_shape = []
1261        for i, s in enumerate(in_oper.shape):
1262            if i not in collapsed_dims:
1263                out_shape.append(s)
1264            elif keep_dim:
1265                out_shape.append(1)
1266
1267        out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order)
1268
1269        inputs = [None] * 3
1270        inputs[0] = in_id
1271        inputs[1] = self.add_immediate_int_vector(nnapi_dim)
1272        inputs[2] = self.add_immediate_int_scalar(keep_dim)
1273
1274        outputs = [None] * 1
1275        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
1276
1277        self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs)
1278
1279    def add_quantize(self, node):
1280        assert node.inputsSize() == 4
1281        assert node.outputsSize() == 1
1282
1283        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
1284        if in_oper.dim_order != DimOrder.CHANNELS_LAST:
1285            raise Exception(  # noqa: TRY002
1286                "Most hardware backends prefer NHWC quantized tensors.  "
1287                "Try setting `t.nnapi_nhwc = True` on your tensor inputs.  "
1288            )
1289        _, scale = self.get_constant_value(node.inputsAt(1), "FloatType")
1290        _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType")
1291        _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType")
1292        if scalar_type != TorchScalarTypes.QUINT8.value:
1293            raise Exception(  # noqa: TRY002
1294                "PyTorch NNAPI export only supports quantized tensors "
1295                "with the quint8 dtype."
1296            )
1297        op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
1298
1299        out_oper = in_oper._replace(
1300            op_type=op_type,
1301            scale=scale,
1302            zero_point=zero_point,
1303        )
1304
1305        inputs = [None] * 1
1306        inputs[0] = in_id
1307
1308        outputs = [None] * 1
1309        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
1310
1311        self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs)
1312
1313    def add_dequantize(self, node):
1314        assert node.inputsSize() == 1
1315        assert node.outputsSize() == 1
1316
1317        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
1318        out_oper = in_oper._replace(
1319            op_type=NNAPI_OperandCode.TENSOR_FLOAT32,
1320            scale=0.0,
1321            zero_point=0,
1322        )
1323
1324        inputs = [None] * 1
1325        inputs[0] = in_id
1326
1327        outputs = [None] * 1
1328        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
1329
1330        self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs)
1331
1332    def add_pointwise_simple_unary_op(self, node, opcode):
1333        assert node.inputsSize() == 1
1334        assert node.outputsSize() == 1
1335
1336        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
1337
1338        out_oper = in_oper
1339        if opcode == NNAPI_OperationCode.LOGISTIC:
1340            # NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale
1341            # must be 1.f / 256 and the zeroPoint must be 0.
1342            # https://fburl.com/h52stoog
1343            if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
1344                out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256)
1345
1346        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
1347
1348        for idx, dim in enumerate(in_oper.shape):
1349            if dim == 0:
1350                self.forward_operand_shape(out_id, idx, in_id, idx)
1351
1352        inputs = [None] * 1
1353        inputs[0] = in_id
1354
1355        outputs = [None] * 1
1356        outputs[0] = out_id
1357
1358        self.add_operation(opcode, inputs, outputs)
1359
1360    def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None):  # noqa: D401
1361        """Helper for pointwise binary broadcast ops with superfluous extra args."""
1362        assert node.outputsSize() == 1
1363
1364        assert node.inputsAt(0).type().kind() == "TensorType"
1365        assert node.inputsAt(1).type().kind() == "TensorType"
1366
1367        if self.has_operand_for_jitval(node.inputsAt(0)):
1368            in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
1369            in1_id, in1_oper = self.get_tensor_operand_or_constant(
1370                node.inputsAt(1), in0_oper.dim_order
1371            )
1372        elif self.has_operand_for_jitval(node.inputsAt(1)):
1373            in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1))
1374            in0_id, in0_oper = self.get_tensor_operand_or_constant(
1375                node.inputsAt(0), in1_oper.dim_order
1376            )
1377        else:
1378            raise Exception(  # noqa: TRY002
1379                f"Can't do a NNAPI binary op: {opcode} on two constants"
1380            )  # noqa: TRY002
1381
1382        assert in0_oper.op_type == in1_oper.op_type
1383        in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast(
1384            in0_id, in0_oper, in1_id, in1_oper
1385        )
1386        # NOTE: PyTorch and NNAPI have the same broadcast semantics.
1387        out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape)
1388        out_oper = in0_oper._replace(shape=out_shape)
1389        if qparams is not None:
1390            scale, zp = qparams
1391            out_oper = out_oper._replace(scale=scale, zero_point=zp)
1392
1393        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
1394        for idx, (d0, d1) in enumerate(zip(in0_oper.shape, in1_oper.shape)):
1395            if d0 == 1 and d1 == 0:
1396                self.forward_operand_shape(out_id, idx, in1_id, idx)
1397            elif d0 == 0 and d1 == 1:
1398                self.forward_operand_shape(out_id, idx, in0_id, idx)
1399            elif d0 == 0 and d1 == 0:
1400                self.flexible_shape_computation_lines.append(
1401                    f"assert {flex_name(in0_id, idx)} == {flex_name(in1_id, idx)}"
1402                )
1403                self.forward_operand_shape(out_id, idx, in0_id, idx)
1404
1405        inputs = [None] * 3
1406        inputs[0] = in0_id
1407        inputs[1] = in1_id
1408        inputs[2] = self.add_immediate_int_scalar(fuse_code)
1409
1410        outputs = [None] * 1
1411        outputs[0] = out_id
1412
1413        self.add_operation(opcode, inputs, outputs)
1414
1415    def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code):
1416        assert node.inputsSize() == 2
1417        self._do_add_binary(node, opcode, fuse_code)
1418
1419    def add_add_sub_op(self, node, opcode, fuse_code):
1420        assert node.inputsSize() == 3
1421
1422        _, alpha = self.get_constant_value(node.inputsAt(2), "IntType")
1423        if alpha != 1:
1424            raise Exception(  # noqa: TRY002
1425                "NNAPI does not support add/sub with alpha."
1426            )  # noqa: TRY002
1427
1428        self._do_add_binary(node, opcode, fuse_code)
1429
1430    def add_qadd(self, node, opcode, fuse_code):
1431        assert node.inputsSize() == 4
1432
1433        _, scale = self.get_constant_value(node.inputsAt(2), "FloatType")
1434        _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType")
1435
1436        self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point))
1437
1438    def add_softmax(self, node):
1439        assert node.inputsSize() == 3
1440        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
1441
1442        _, softmax_dim = self.get_constant_value(node.inputsAt(1), "IntType")
1443
1444        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
1445        for dim, size in enumerate(in_oper.shape):
1446            if size == 0:
1447                self.forward_operand_shape(out_id, dim, in_id, dim)
1448
1449        inputs = [None] * 3
1450        inputs[0] = in_id
1451        inputs[1] = self.add_immediate_float_scalar(
1452            1.0
1453        )  # positive scaling factor of exponent, beta
1454        inputs[2] = self.add_immediate_int_scalar(softmax_dim)
1455
1456        outputs = [None] * 1
1457        outputs[0] = out_id
1458
1459        self.add_operation(NNAPI_OperationCode.SOFTMAX, inputs, outputs)
1460
1461    def add_hardtanh(self, node):
1462        assert node.inputsSize() == 3
1463        assert node.outputsSize() == 1
1464
1465        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
1466        _, min_val = self.get_constant_value(node.inputsAt(1), "FloatType")
1467        _, max_val = self.get_constant_value(node.inputsAt(2), "FloatType")
1468
1469        op_map = {
1470            (-1, 1): NNAPI_OperationCode.RELU1,
1471            (0, 6): NNAPI_OperationCode.RELU6,  # noqa: E201
1472        }
1473
1474        opcode = op_map.get((min_val, max_val))
1475        if opcode is None:
1476            raise Exception(  # noqa: TRY002
1477                "NNAPI only supports hardtanh with args (-1, 1) or (0, 6)."
1478            )  # noqa: TRY002
1479
1480        inputs = [None] * 1
1481        inputs[0] = in_id
1482
1483        outputs = [None] * 1
1484        outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper)
1485
1486        self.add_operation(opcode, inputs, outputs)
1487
1488    def add_prelu_op(self, node):
1489        assert node.inputsSize() == 2
1490        assert node.outputsSize() == 1
1491
1492        assert node.inputsAt(0).type().kind() == "TensorType"
1493        assert node.inputsAt(1).type().kind() == "TensorType"
1494
1495        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
1496        w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1))
1497        assert len(w_oper.shape) == 1
1498        assert w_oper.shape[0] > 0
1499        if w_oper.shape[0] > 1:
1500            if in_oper.use_nchw():
1501                # TODO: Support this by adding trailing 1 dims.
1502                raise Exception(  # noqa: TRY002
1503                    "Per-channel PReLU only supports channels_last right now."
1504                )
1505
1506        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
1507        for dim, size in enumerate(in_oper.shape):
1508            if size > 0:
1509                pass
1510            elif dim <= 1:
1511                raise Exception(  # noqa: TRY002
1512                    "PReLU requires fixed size for dim 0 and dim 1."
1513                )  # noqa: TRY002
1514            else:
1515                self.forward_operand_shape(out_id, dim, in_id, dim)
1516
1517        inputs = [None] * 2
1518        inputs[0] = in_id
1519        inputs[1] = w_id
1520
1521        outputs = [None] * 1
1522        outputs[0] = out_id
1523
1524        self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs)
1525
1526    def add_pool2d_node(self, node, opcode):
1527        assert node.inputsSize() == 6
1528        assert node.outputsSize() == 1
1529        image, kernel, stride, padding, dilation, ceil_mode = node.inputs()
1530
1531        stride = stride or kernel
1532
1533        # TODO: Validate ceil_mode semantics.
1534
1535        args = self.get_conv_pool_args_2d_from_jit(
1536            self.get_size_arg(kernel), stride, padding, dilation
1537        )
1538        if args.dilation_h != 1 or args.dilation_w != 1:
1539            raise Exception("NNAPI does not support dilated pooling.")  # noqa: TRY002
1540
1541        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image)
1542        assert len(image_oper.shape) == 4
1543
1544        out_shape = get_conv_pool_shape(
1545            image_oper.shape, args, image_oper.shape[1], False
1546        )
1547        use_nchw = image_oper.use_nchw()
1548
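        # NNAPI *_POOL_2D (explicit-padding signature) takes: input, pad left,
        # pad right, pad top, pad bottom, stride width, stride height, filter
        # width, filter height, fuse/activation code, and an NCHW layout flag.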
1549        inputs = [None] * 11
1550        inputs[0] = image_id
1551        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
1552        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
1553        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
1554        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
1555        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
1556        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
1557        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
1558        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
1559        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
1560        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
1561
1562        outputs = [None] * 1
1563        outputs[0] = self.add_tensor_operand(
1564            node.outputsAt(0), image_oper._replace(shape=out_shape)
1565        )
1566
1567        self.add_operation(opcode, inputs, outputs)
1568
1569    def add_avg_pool2d(self, node):
1570        assert node.inputsSize() == 7
1571        assert node.outputsSize() == 1
1572        (
1573            image,
1574            kernel,
1575            stride,
1576            padding,
1577            ceil_mode,
1578            count_include_pad,
1579            divisor_override,
1580        ) = node.inputs()
1581
1582        _, count_include_pad_value = self.get_constant_value(count_include_pad)
1583        _, divisor_override_value = self.get_constant_value(divisor_override)
1584        if not count_include_pad_value or divisor_override_value:
1585            raise Exception(  # noqa: TRY002
1586                "NNAPI doesn't support count_include_pad=False or divisor_override"
1587            )
1588
1589        args = self.get_conv_pool_args_2d_from_jit(
1590            self.get_size_arg(kernel), stride, padding
1591        )
1592
1593        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
1594        assert len(image_oper.shape) == 4
1595
1596        out_shape = get_conv_pool_shape(
1597            image_oper.shape, args, image_oper.shape[1], False
1598        )
1599        use_nchw = image_oper.use_nchw()
1600
1601        inputs = [None] * 11
1602        inputs[0] = image_id
1603        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
1604        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
1605        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
1606        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
1607        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
1608        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
1609        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
1610        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
1611        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
1612        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
1613
1614        outputs = [None] * 1
1615        out_id = self.add_tensor_operand(
1616            node.outputsAt(0), image_oper._replace(shape=out_shape)
1617        )
1618        self._handle_conv_pool_flexible_input(out_id, image, args, False)
1619        outputs[0] = out_id
1620
1621        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)
1622
1623    def add_adaptive_avg_pool2d(self, node):
1624        assert node.inputsSize() == 2
1625        assert node.outputsSize() == 1
1626
1627        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(
1628            node.inputsAt(0)
1629        )
1630        assert len(image_oper.shape) == 4
1631
1632        size_ctype, size_arg = self.get_constant_value(node.inputsAt(1))
1633        assert size_ctype.kind() == "ListType"
1634        assert size_ctype.getElementType().kind() == "IntType"
1635        if size_arg != [1, 1]:
1636            raise Exception(  # noqa: TRY002
1637                "NNAPI only supports adaptive_avg_pool2d with output size (1, 1)."
1638            )
1639
1640        out_shape = image_oper.shape[0:2] + tuple(size_arg)
1641        use_nchw = image_oper.use_nchw()
1642
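        # Output size (1, 1) makes this a global average pool: no padding,
        # stride 1, and a filter that spans the full input height and width
        # (shape[2] and shape[3] below).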
1643        inputs = [None] * 11
1644        inputs[0] = image_id
1645        inputs[1] = self.add_immediate_int_scalar(0)
1646        inputs[2] = self.add_immediate_int_scalar(0)
1647        inputs[3] = self.add_immediate_int_scalar(0)
1648        inputs[4] = self.add_immediate_int_scalar(0)
1649        inputs[5] = self.add_immediate_int_scalar(1)
1650        inputs[6] = self.add_immediate_int_scalar(1)
1651        inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3])
1652        inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2])
1653        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
1654        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
1655
1656        outputs = [None] * 1
1657        outputs[0] = self.add_tensor_operand(
1658            node.outputsAt(0), image_oper._replace(shape=out_shape)
1659        )
1660
1661        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)
1662
1663    def add_upsample_nearest2d(self, node):
1664        assert node.inputsSize() == 3 or node.inputsSize() == 4
1665        assert node.outputsSize() == 1
1666        if node.inputsSize() == 3:
1667            image, size_jit, scale_jit = node.inputs()
1668        else:
1669            image, size_jit, scale_h_jit, scale_w_jit = node.inputs()
1670        size_ctype, size_arg = self.get_constant_value(size_jit)
1671
1672        if node.inputsSize() == 3:
1673            scale_ctype, scale_arg = self.get_constant_value(scale_jit)  # type: ignore[possibly-undefined]
1674        else:
1675            scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit)  # type: ignore[possibly-undefined]
1676            scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit)  # type: ignore[possibly-undefined]
1677
1678            # The only way for the 4-argument overload of upsample_nearest2d to
1679            # have been added to the graph without error is if the scale_h and
1680            # scale_w arguments are None
1681            assert scale_h_ctype.kind() == "NoneType"
1682            assert scale_w_ctype.kind() == "NoneType"
1683
1684            scale_ctype = scale_h_ctype
1685            scale_arg = scale_h_arg
1686
1687        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
1688        assert len(image_oper.shape) == 4
1689
1690        if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType":
1691            raise Exception("Size and scale cannot both be non-None.")  # noqa: TRY002
1692        elif size_ctype.kind() != "NoneType":
1693            assert size_ctype.kind() == "ListType"
1694            assert size_ctype.getElementType().kind() == "IntType"
1695            assert scale_ctype.kind() == "NoneType"
1696            assert scale_arg is None
1697            assert isinstance(size_arg, list)
1698            assert size_arg
1699            assert all(isinstance(val, int) for val in size_arg)
1700            if len(size_arg) == 1:
1701                size_arg = size_arg * 2
1702            assert len(size_arg) == 2
1703            out_h = size_arg[0]
1704            out_w = size_arg[1]
1705            arg_h = self.add_immediate_int_scalar(out_h)
1706            arg_w = self.add_immediate_int_scalar(out_w)
1707        elif scale_ctype.kind() != "NoneType":
1708            assert scale_ctype.kind() == "ListType"
1709            assert scale_ctype.getElementType().kind() == "FloatType"
1710            assert size_ctype.kind() == "NoneType"
1711            assert size_arg is None
1712            assert isinstance(scale_arg, list)
1713            assert scale_arg
1714            assert all(isinstance(val, float) for val in scale_arg)
1715            if len(scale_arg) == 1:
1716                scale_arg = scale_arg * 2
1717            assert len(scale_arg) == 2
1718            out_h = int(scale_arg[0] * image_oper.shape[2])
1719            out_w = int(scale_arg[1] * image_oper.shape[3])
1720            arg_h = self.add_immediate_float_scalar(scale_arg[0])
1721            arg_w = self.add_immediate_float_scalar(scale_arg[1])
1722        else:
1723            raise Exception("Size and scale cannot both be None.")  # noqa: TRY002
1724
1725        out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w)
1726        use_nchw = image_oper.use_nchw()
1727        out_id = self.add_tensor_operand(
1728            node.outputsAt(0), image_oper._replace(shape=out_shape)
1729        )
1730
1731        if image_oper.shape[0] == 0 or image_oper.shape[1] == 0:
1732            raise Exception("Flexible batch or channels not supported")  # noqa: TRY002
1733
1734        # Handle variable input size
1735        for dim in (2, 3):  # h, w indices
1736            if image_oper.shape[dim] == 0:
1737                if size_ctype.kind() != "NoneType":
1738                    self.compute_operand_shape(out_id, dim, size_arg[dim - 2])
1739                elif scale_ctype.kind() != "NoneType":
1740                    self.compute_operand_shape(
1741                        out_id,
1742                        dim,
1743                        f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})",
1744                    )
1745                else:
1746                    raise Exception(  # noqa: TRY002
1747                        "Size and scale cannot both be None."
1748                    )  # noqa: TRY002
1749
1750        inputs = [None] * 4
1751        inputs[0] = image_id
1752        inputs[1] = arg_w
1753        inputs[2] = arg_h
1754        inputs[3] = self.add_immediate_bool_scalar(use_nchw)
1755
1756        outputs = [None] * 1
1757        outputs[0] = out_id
1758
1759        self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs)
1760
1761    def add_addmm(self, node):
1762        assert node.inputsSize() == 5
1763        assert node.outputsSize() == 1
1764        jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs()
1765
1766        for jitval in (jit_beta, jit_alpha):
1767            scale_ctype, scale_value = self.get_constant_value(jitval)
1768            assert scale_ctype.kind() in ("IntType", "FloatType")
1769            if scale_value != 1:
1770                raise Exception(  # noqa: TRY002
1771                    "NNAPI Fully-Connected does not support alpha and beta."
1772                )
1773
1774        self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias)
1775
1776    def add_linear(self, node):
1777        assert node.inputsSize() == 3
1778        assert node.outputsSize() == 1
1779        jit_input, jit_weight, jit_bias = node.inputs()
1780
1781        self.add_addmm_or_linear(node, False, jit_input, jit_weight, jit_bias)
1782
1783    def add_addmm_or_linear(
1784        self, node, transpose_weight, jit_input, jit_weight, jit_bias
1785    ):
1786        input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input)
1787        bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias)
1788
1789        assert len(input_oper.shape) == 2
1790        assert len(bias_oper.shape) == 1
1791
1792        # TODO: Transform at load time to share weights with CPU model.
1793        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
1794        assert len(weight_tensor.shape) == 2
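        # NNAPI FULLY_CONNECTED expects weights of shape (num_units, input_size).
        # aten::linear weights already come as (out_features, in_features), while
        # the addmm path receives the (in, out)-shaped matrix, hence the
        # transpose below when transpose_weight is set.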
1795        if transpose_weight:
1796            nnapi_weight_tensor = weight_tensor.t().contiguous()
1797        else:
1798            nnapi_weight_tensor = weight_tensor.contiguous()
1799        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
1800        weight_oper = self.operands[weight_id]
1801
1802        out_shape = (input_oper.shape[0], weight_oper.shape[0])
1803        out_id = self.add_tensor_operand(
1804            node.outputsAt(0), input_oper._replace(shape=out_shape)
1805        )
1806
1807        if input_oper.shape[0] == 0:
1808            self.forward_operand_shape(out_id, 0, input_id, 0)
1809
1810        inputs = [None] * 4
1811        inputs[0] = input_id
1812        inputs[1] = weight_id
1813        inputs[2] = bias_id
1814        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
1815
1816        outputs = [None] * 1
1817        outputs[0] = out_id
1818
1819        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
1820
1821    def add_qlinear(self, node):
1822        assert node.inputsSize() == 4
1823        assert node.outputsSize() == 1
1824        (
1825            jit_input,
1826            jit_packed_weight,
1827            jit_scale,
1828            jit_zero_point,
1829        ) = node.inputs()
1830
1831        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
1832        # TODO: Support automatic reshape
1833        assert len(input_oper.shape) == 2
1834
1835        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
1836        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
1837        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
1838        assert weight_ctype.name() == "LinearPackedParamsBase"
1839        raw_weight, raw_bias = packed_weight.__getstate__()[0]
1840        assert raw_bias is not None
1841
1842        assert len(raw_weight.shape) == 2
1843        assert len(raw_bias.shape) == 1
1844        assert raw_bias.shape[0] == raw_weight.shape[0]
1845        assert raw_weight.shape[1] == input_oper.shape[1]
1846
1847        assert raw_weight.qscheme() == torch.per_tensor_affine
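        # NNAPI's TENSOR_QUANT8_ASYMM operands are unsigned, so signed qint8
        # weights are re-expressed as quint8 by shifting both the integer
        # representation and the zero point by 128, which leaves the real
        # (dequantized) values unchanged.  The bias is re-quantized to qint32
        # with scale = input_scale * weight_scale and zero point 0, as NNAPI
        # requires for quantized FULLY_CONNECTED / CONV_2D bias operands.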
1848        if raw_weight.dtype == torch.quint8:
1849            unsigned_weight = raw_weight
1850        else:
1851            assert raw_weight.dtype == torch.qint8
1852            unsigned_weight = torch._make_per_tensor_quantized_tensor(
1853                (raw_weight.int_repr().int() + 128).to(torch.uint8),
1854                scale=raw_weight.q_scale(),
1855                zero_point=raw_weight.q_zero_point() + 128,
1856            )
1857        weight_scale = unsigned_weight.q_scale()
1858        bias_scale = input_oper.scale * weight_scale
1859        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
1860        bias_id = self.add_tensor_operand_for_weight(int_bias)
1861
1862        multiplier = input_oper.scale * weight_scale / out_scale
1863        assert multiplier > 0
1864        if multiplier >= 1:
1865            raise Exception(  # noqa: TRY002
1866                "Quantized convolution multiplier is greater than 1.  "
1867                "This is supported by NNAPI, but not by most hardware backends.  "
1868                "Try training a model without quantization-aware training.  "
1869            )
1870
1871        # TODO: Transform at load time to share weights with CPU model.
1872        nnapi_weight_tensor = unsigned_weight.contiguous()
1873        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
1874        weight_oper = self.operands[weight_id]
1875
1876        out_shape = (input_oper.shape[0], weight_oper.shape[0])
1877        out_oper = input_oper._replace(
1878            shape=out_shape,
1879            scale=out_scale,
1880            zero_point=out_zero_point,
1881        )
1882
1883        inputs = [None] * 4
1884        inputs[0] = input_id
1885        inputs[1] = weight_id
1886        inputs[2] = bias_id
1887        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
1888
1889        outputs = [None] * 1
1890        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
1891
1892        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
1893
1894    def get_optional_bias(self, jit_bias, weight_tensor, transpose=False):
1895        ctype, value = self.get_constant_value(jit_bias)
1896        if ctype.kind() == "NoneType":
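            # NNAPI requires an explicit bias operand, so synthesize a zero bias
            # whose length matches the output-channel dimension of the weight:
            # index 0 for regular conv/linear weights (out, in, ...), index 1 for
            # transposed-conv weights (in, out, ...).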
1897            bias_idx = 1 if transpose else 0
1898            nnapi_bias_tensor = torch.zeros(
1899                weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype
1900            )
1901            bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor)
1902            bias_oper = self.operands[bias_id]
1903            return bias_id, bias_oper
1904        else:
1905            return self.get_tensor_operand_for_weight(jit_bias)
1906
1907    def add_conv2d(self, node):
1908        assert node.inputsSize() == 7
1909        assert node.outputsSize() == 1
1910
1911        (
1912            jit_image,
1913            jit_weight,
1914            jit_bias,
1915            jit_stride,
1916            jit_pad,
1917            jit_dilation,
1918            jit_groups,
1919        ) = node.inputs()
1920
1921        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
1922        bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
1923        args = self.get_conv_pool_args_2d_from_jit(
1924            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
1925        )
1926
1927        return self.add_conv2d_common(
1928            node.outputsAt(0),
1929            0.0,
1930            0,
1931            jit_image,
1932            weight_tensor,
1933            bias_id,
1934            args,
1935            False,  # transpose
1936            NNAPI_FuseCode.FUSED_NONE,
1937        )
1938
1939    def add_conv_underscore(self, node):
1940        assert node.inputsSize() == 13
1941        assert node.outputsSize() == 1
1942
1943        (
1944            jit_image,
1945            jit_weight,
1946            jit_bias,
1947            jit_stride,
1948            jit_pad,
1949            jit_dilation,
1950            jit_transpose,
1951            _,
1952            jit_groups,
1953            _,
1954            _,
1955            _,
1956            _,
1957        ) = node.inputs()
1958
1959        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
1960        _, transpose = self.get_constant_value(jit_transpose)
1961        bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
1962        args = self.get_conv_pool_args_2d_from_jit(
1963            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
1964        )
1965
1966        return self.add_conv2d_common(
1967            node.outputsAt(0),
1968            0.0,
1969            0,
1970            jit_image,
1971            weight_tensor,
1972            bias_id,
1973            args,
1974            transpose,
1975            NNAPI_FuseCode.FUSED_NONE,
1976        )
1977
1978    def add_log_softmax(self, node):
1979        assert node.inputsSize() == 3
1980        assert node.outputsSize() == 1
1981
1982        (jit_input, jit_dim, jit_half_to_float) = node.inputs()
1983        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
1984        _, dim = self.get_constant_value(jit_dim, "IntType")
1985
1986        out_shape = input_oper.shape
1987
1988        inputs = [None] * 3
1989        inputs[0] = input_id
1990        # specifying 1 as the scaling factor for the exponent, beta
1991        inputs[1] = self.add_immediate_float_scalar(1)
1992        inputs[2] = self.add_immediate_int_scalar(dim)
1993
1994        outputs = [None] * 1
1995        outputs[0] = self.add_tensor_operand(
1996            node.outputsAt(0), input_oper._replace(shape=out_shape)
1997        )
1998        self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs)
1999
2000    def add_qconv2d(self, node, fuse_code, transpose=False):
2001        assert node.inputsSize() == 4
2002        assert node.outputsSize() == 1
2003
2004        (
2005            jit_image,
2006            jit_packed_weight,
2007            jit_scale,
2008            jit_zero_point,
2009        ) = node.inputs()
2010
2011        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
2012        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
2013        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
2014        assert weight_ctype.name() == "Conv2dPackedParamsBase"
2015        (
2016            pack_version,
2017            tensors,
2018            opt_tensors,
2019        ) = packed_weight.__getstate__()[0]
2020        assert pack_version == "2"
2021        packed_config, raw_weight = tensors
2022        (raw_bias,) = opt_tensors
2023        assert raw_bias is not None
2024        args = self.get_conv_pool_args_2d_from_pack(
2025            raw_weight.shape[2:4], packed_config
2026        )
2027
2028        assert raw_weight.qscheme() == torch.per_tensor_affine
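        # Same signed-to-unsigned weight conversion as in add_qlinear above:
        # NNAPI quant8 asymm tensors are unsigned, so qint8 weights and their
        # zero point are shifted by 128.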
2029        if raw_weight.dtype == torch.quint8:
2030            unsigned_weight = raw_weight
2031        else:
2032            assert raw_weight.dtype == torch.qint8
2033            unsigned_weight = torch._make_per_tensor_quantized_tensor(
2034                (raw_weight.int_repr().int() + 128).to(torch.uint8),
2035                scale=raw_weight.q_scale(),
2036                zero_point=raw_weight.q_zero_point() + 128,
2037            )
2038        weight_scale = unsigned_weight.q_scale()
2039        _, image_oper = self.get_tensor_operand_by_jitval(jit_image)
2040        bias_scale = image_oper.scale * weight_scale
2041        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
2042        bias_id = self.add_tensor_operand_for_weight(int_bias)
2043
2044        multiplier = image_oper.scale * weight_scale / out_scale
2045        assert multiplier > 0
2046        if multiplier >= 1:
2047            raise Exception(  # noqa: TRY002
2048                "Quantized convolution multiplier is greater than 1.  "
2049                "This is supported by NNAPI, but not by most hardware backends.  "
2050                "Try training a model without quantization-aware training.  "
2051            )
2052
2053        return self.add_conv2d_common(
2054            node.outputsAt(0),
2055            out_scale,
2056            out_zero_point,
2057            jit_image,
2058            unsigned_weight,
2059            bias_id,
2060            args,
2061            transpose,
2062            fuse_code,
2063        )
2064
2065    def add_conv2d_common(
2066        self,
2067        jit_out,
2068        out_scale,
2069        out_zero_point,
2070        jit_image,
2071        weight_tensor,
2072        bias_id,
2073        args,
2074        transpose,
2075        fuse_code,
2076    ):
2077        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
2078        in_c = image_oper.shape[1]
2079
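        # NNAPI weight layouts differ from PyTorch's:
        #   CONV_2D            expects (out_c, kH, kW, in_c),  PyTorch uses (out_c, in_c, kH, kW)
        #   TRANSPOSE_CONV_2D  expects (out_c, kH, kW, in_c),  PyTorch uses (in_c, out_c, kH, kW)
        #   DEPTHWISE_CONV_2D  expects (1, kH, kW, out_c),     PyTorch uses (out_c, 1, kH, kW)
        # The permutations below convert between the two.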
2080        if args.group == 1:
2081            # Full convolution
2082            depthwise = False
2083            if transpose:
2084                weight_permutation = (1, 2, 3, 0)
2085            else:
2086                weight_permutation = (0, 2, 3, 1)
2087        elif args.group == in_c:
2088            # Depthwise convolution
2089            depthwise = True
2090            weight_permutation = (1, 2, 3, 0)
2091        else:
2092            raise Exception("Group convolution not supported yet.")  # noqa: TRY002
2093
2094        # TODO: Transform at load time to share weights with CPU model.
2095        nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous()
2096        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
2097        weight_oper = self.operands[weight_id]
2098
2099        bias_oper = self.operands[bias_id]
2100
2101        if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
2102            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
2103            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
2104        elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
2105            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
2106            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32
2107            assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale)
2108            assert bias_oper.zero_point == 0
2109        else:
2110            raise Exception(  # noqa: TRY002
2111                f"Unsupported input type for conv2d: {image_oper.op_type}"
2112            )  # noqa: TRY002
2113
2114        assert len(image_oper.shape) == 4
2115        assert len(weight_oper.shape) == 4
2116        assert len(bias_oper.shape) == 1
2117
2118        if depthwise:
2119            # Depthwise convolution
2120            one, kern_h, kern_w, out_c = weight_oper.shape
2121            assert one == 1
2122            assert out_c % in_c == 0
2123            channel_multiplier = out_c // in_c
2124            assert channel_multiplier == 1  # Don't support multiplier
2125            assert out_c == in_c
2126        else:
2127            # Full convolution
2128            out_c, kern_h, kern_w, kern_d = weight_oper.shape
2129            assert kern_d == in_c
2130
2131        assert out_c == bias_oper.shape[0]
2132
2133        use_nchw = image_oper.use_nchw()
2134
2135        if depthwise:
2136            num_args = 12
2137            opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D
2138        else:
2139            num_args = 11
2140            if transpose:
2141                opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D
2142            else:
2143                opcode = NNAPI_OperationCode.CONV_2D
2144
2145        inputs = [None] * num_args
2146        inputs[0] = image_id
2147        inputs[1] = weight_id
2148        inputs[2] = bias_id
2149        inputs[3] = self.add_immediate_int_scalar(args.pad_l)
2150        inputs[4] = self.add_immediate_int_scalar(args.pad_r)
2151        inputs[5] = self.add_immediate_int_scalar(args.pad_t)
2152        inputs[6] = self.add_immediate_int_scalar(args.pad_b)
2153        inputs[7] = self.add_immediate_int_scalar(args.stride_w)
2154        inputs[8] = self.add_immediate_int_scalar(args.stride_h)
2155        if depthwise:
2156            inputs[9] = self.add_immediate_int_scalar(1)
2157            inputs[10] = self.add_immediate_int_scalar(fuse_code)
2158            inputs[11] = self.add_immediate_bool_scalar(use_nchw)
2159        else:
2160            inputs[9] = self.add_immediate_int_scalar(fuse_code)
2161            inputs[10] = self.add_immediate_bool_scalar(use_nchw)
2162
2163        outputs = [None] * 1
2164        out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose)
2165        out_oper = image_oper._replace(
2166            shape=out_shape,
2167            scale=out_scale,
2168            zero_point=out_zero_point,
2169        )
2170        out_id = self.add_tensor_operand(jit_out, out_oper)
2171        self._handle_conv_pool_flexible_input(out_id, jit_image, args, transpose)
2172
2173        outputs[0] = out_id
2174        self.add_operation(opcode, inputs, outputs)
2175
2176    def _handle_conv_pool_flexible_input(self, out_id, jit_image, args, transpose):
2177        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
2178        batch, in_ch, in_h, in_w = image_oper.shape
2179
2180        if batch == 0:
2181            self.forward_operand_shape(out_id, 0, image_id, 0)
2182        if in_ch == 0:
2183            raise Exception("Input channels can't be flexible")  # noqa: TRY002
2184        # H & W
2185        if transpose:
2186            if in_h == 0:
2187                self.compute_operand_shape(
2188                    out_id,
2189                    2,
2190                    f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}",
2191                )
2192            if in_w == 0:
2193                self.compute_operand_shape(
2194                    out_id,
2195                    3,
2196                    f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}",
2197                )
2198        else:
2199            if in_h == 0:
2200                self.compute_operand_shape(
2201                    out_id,
2202                    2,
2203                    f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1",
2204                )
2205            if in_w == 0:
2206                self.compute_operand_shape(
2207                    out_id,
2208                    3,
2209                    f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1",
2210                )
2211
2212
2213def serialize_model(
2214    module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False
2215):
2216    """Convert to NNAPI and serialize torchscript module.
2217
2218    Parameters:
2219        module: Torchscript module to convert
2220        inputs: Tensors used to specify input details for NNAPI
2221        config (optional): Optional config to attach to module
2222        return_shapes (optional): Specify shape of outputs if
2223            your module uses runtime flexible shapes to set output
2224            buffer size for NNAPI
2225        use_int16_for_qint16 (optional): Use Pytorch int16 to represent NNAPI qint16 values
2226    """
2227    return _NnapiSerializer(config, use_int16_for_qint16).serialize_model(
2228        module, inputs, return_shapes
2229    )
2230
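# Illustrative usage sketch (MyModel and the example input are placeholders;
# the exact structure of the returned value is defined by
# _NnapiSerializer.serialize_model):
#
#   model = MyModel().eval()
#   example = torch.zeros(1, 3, 224, 224)
#   traced = torch.jit.trace(model, example)
#   serialized = serialize_model(traced, [example])
#
# Most callers go through torch.backends._nnapi.prepare, which uses this
# serializer to build an NNAPI-dispatching module.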