# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional

import torch
from torch.backends._nnapi.serializer import _NnapiSerializer


ANEURALNETWORKS_PREFER_LOW_POWER = 0
ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2


class NnapiModule(torch.nn.Module):
    """Torch Module that wraps an NNAPI Compilation.

    This module handles preparing the weights, initializing the
    NNAPI TorchBind object, and adjusting the memory formats
    of all inputs and outputs.
    """

    # _nnapi.Compilation is defined
    comp: Optional[torch.classes._nnapi.Compilation]  # type: ignore[name-defined]
    weights: List[torch.Tensor]
    out_templates: List[torch.Tensor]

    def __init__(
        self,
        shape_compute_module: torch.nn.Module,
        ser_model: torch.Tensor,
        weights: List[torch.Tensor],
        inp_mem_fmts: List[int],
        out_mem_fmts: List[int],
        compilation_preference: int,
        relax_f32_to_f16: bool,
    ):
        super().__init__()
        self.shape_compute_module = shape_compute_module
        self.ser_model = ser_model
        self.weights = weights
        self.inp_mem_fmts = inp_mem_fmts
        self.out_mem_fmts = out_mem_fmts
        self.out_templates = []
        self.comp = None
        self.compilation_preference = compilation_preference
        self.relax_f32_to_f16 = relax_f32_to_f16

    @torch.jit.export
    def init(self, args: List[torch.Tensor]):
        assert self.comp is None
        self.out_templates = self.shape_compute_module.prepare(self.ser_model, args)  # type: ignore[operator]
        self.weights = [w.contiguous() for w in self.weights]
        comp = torch.classes._nnapi.Compilation()
        comp.init2(
            self.ser_model,
            self.weights,
            self.compilation_preference,
            self.relax_f32_to_f16,
        )

        self.comp = comp

    def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
        if self.comp is None:
            self.init(args)
        comp = self.comp
        assert comp is not None
        outs = [torch.empty_like(out) for out in self.out_templates]

        assert len(args) == len(self.inp_mem_fmts)
        fixed_args = []
        for idx in range(len(args)):
            fmt = self.inp_mem_fmts[idx]
            # These constants match the values in DimOrder in serializer.py
            # TODO: See if it's possible to use those directly.
            if fmt == 0:
                fixed_args.append(args[idx].contiguous())
            elif fmt == 1:
                fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
            else:
                raise ValueError("Invalid mem_fmt")
        comp.run(fixed_args, outs)
        assert len(outs) == len(self.out_mem_fmts)
        for idx in range(len(self.out_templates)):
            fmt = self.out_mem_fmts[idx]
            # These constants match the values in DimOrder in serializer.py
            # TODO: See if it's possible to use those directly.
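            # fmt == 1 corresponds to channels-last (NHWC): NNAPI produced this
            # output in NHWC order, and the permute below restores PyTorch's
            # NCHW convention (the inverse of the input-side (0, 2, 3, 1)
            # permutation applied above).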
            if fmt in (0, 2):
                pass
            elif fmt == 1:
                outs[idx] = outs[idx].permute(0, 3, 1, 2)
            else:
                raise ValueError("Invalid mem_fmt")
        return outs


def convert_model_to_nnapi(
    model,
    inputs,
    serializer=None,
    return_shapes=None,
    use_int16_for_qint16=False,
    compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
    relax_f32_to_f16=False,
):
    (
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        retval_count,
    ) = process_for_nnapi(
        model, inputs, serializer, return_shapes, use_int16_for_qint16
    )

    nnapi_model = NnapiModule(
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        compilation_preference,
        relax_f32_to_f16,
    )

    class NnapiInterfaceWrapper(torch.nn.Module):
        """NNAPI list-ifying and de-list-ifying wrapper.

        NNAPI always expects a list of inputs and provides a list of outputs.
        This module allows us to accept inputs as separate arguments.
        It returns results as either a single tensor or tuple,
        matching the original module.
        """

        def __init__(self, mod):
            super().__init__()
            self.mod = mod

    wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
    wrapper_model = torch.jit.script(wrapper_model_py)
    # TODO: Maybe make these names match the original.
    arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
    if retval_count < 0:
        ret_expr = "retvals[0]"
    else:
        ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
    wrapper_model.define(
        f"def forward(self, {arg_list}):\n"
        f"    retvals = self.mod([{arg_list}])\n"
        f"    return {ret_expr}\n"
    )
    return wrapper_model


def process_for_nnapi(
    model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False
):
    model = torch.jit.freeze(model)

    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]

    serializer = serializer or _NnapiSerializer(
        config=None, use_int16_for_qint16=use_int16_for_qint16
    )
    (
        ser_model,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        shape_compute_lines,
        retval_count,
    ) = serializer.serialize_model(model, inputs, return_shapes)
    ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)

    # We have to create a new class here every time this function is called
    # because module.define adds a method to the *class*, not the instance.
    class ShapeComputeModule(torch.nn.Module):
        """Code-gen-ed module for tensor shape computation.

        module.prepare will mutate ser_model according to the computed operand
        shapes, based on the shapes of args. Returns a list of output templates.
        """

    shape_compute_module = torch.jit.script(ShapeComputeModule())
    real_shape_compute_lines = [
        "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
    ] + [f"    {line}\n" for line in shape_compute_lines]
    shape_compute_module.define("".join(real_shape_compute_lines))

    return (
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        retval_count,
    )
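

# A minimal usage sketch, not part of this module: ExampleModel, the input
# shape, and the output path are hypothetical stand-ins. convert_model_to_nnapi
# expects a scripted or traced model plus example inputs with fixed shapes; the
# converted module is typically saved for the mobile (lite) interpreter and run
# on an NNAPI-capable Android device:
#
#     example_input = torch.zeros(1, 3, 224, 224)
#     traced = torch.jit.trace(ExampleModel().eval(), example_input)
#     nnapi_model = convert_model_to_nnapi(traced, example_input)
#     nnapi_model._save_for_lite_interpreter("model_nnapi.ptl")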