xref: /aosp_15_r20/external/pytorch/torch/backends/_nnapi/prepare.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3from typing import List, Optional
4
5import torch
6from torch.backends._nnapi.serializer import _NnapiSerializer
7
8
# Compilation preference codes. convert_model_to_nnapi takes one of these as
# its compilation_preference argument (defaulting to SUSTAINED_SPEED), and
# NnapiModule forwards the chosen value unchanged to Compilation.init2.
# NOTE(review): presumably these mirror the PreferenceCode values from
# Android's NeuralNetworks.h -- confirm before changing any of them.
(
    ANEURALNETWORKS_PREFER_LOW_POWER,
    ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
    ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
) = (0, 1, 2)
12
13
class NnapiModule(torch.nn.Module):
    """Torch Module that wraps an NNAPI Compilation.

    The ``torch.classes._nnapi.Compilation`` TorchBind object is built lazily
    on the first ``forward`` call (see ``init``). Before it is built, the
    shape-computation module produces the output templates and the weights
    are made contiguous. Every call then converts the inputs into the memory
    formats the compiled model expects and converts the outputs back.

    Memory-format codes (they are meant to match DimOrder in serializer.py):
        0 -- passed through (inputs are made contiguous)
        1 -- inputs are permuted with (0, 2, 3, 1) and made contiguous;
             outputs are permuted back with (0, 3, 1, 2)
        2 -- accepted for outputs only, passed through unchanged
    Any other code raises ValueError("Invalid mem_fmt").
    """

    # torch.classes._nnapi.Compilation is a TorchBind class that mypy cannot
    # resolve, hence the ignore.
    comp: Optional[torch.classes._nnapi.Compilation]  # type: ignore[name-defined]
    weights: List[torch.Tensor]
    out_templates: List[torch.Tensor]

    def __init__(
        self,
        shape_compute_module: torch.nn.Module,
        ser_model: torch.Tensor,
        weights: List[torch.Tensor],
        inp_mem_fmts: List[int],
        out_mem_fmts: List[int],
        compilation_preference: int,
        relax_f32_to_f16: bool,
    ):
        """Record everything needed to build the Compilation later.

        Args:
            shape_compute_module: module whose ``prepare(ser_model, args)``
                returns the list of output template tensors.
            ser_model: serialized NNAPI model tensor.
            weights: weight tensors handed to ``Compilation.init2``.
            inp_mem_fmts: one memory-format code per input.
            out_mem_fmts: one memory-format code per output.
            compilation_preference: preference code passed to ``init2``.
            relax_f32_to_f16: flag passed to ``init2``.
        """
        super().__init__()
        self.shape_compute_module = shape_compute_module
        self.ser_model = ser_model
        self.weights = weights
        self.inp_mem_fmts = inp_mem_fmts
        self.out_mem_fmts = out_mem_fmts
        self.compilation_preference = compilation_preference
        self.relax_f32_to_f16 = relax_f32_to_f16
        # Both of these are populated by init() on the first forward() call.
        self.out_templates = []
        self.comp = None

    @torch.jit.export
    def init(self, args: List[torch.Tensor]):
        """Build the NNAPI Compilation for the shapes of ``args``.

        Must only run once: it asserts that no Compilation exists yet.
        """
        assert self.comp is None
        # prepare() runs before init2 so that init2 sees ser_model after any
        # shape-dependent mutation performed by prepare().
        self.out_templates = self.shape_compute_module.prepare(self.ser_model, args)  # type: ignore[operator]
        self.weights = [tensor.contiguous() for tensor in self.weights]
        compilation = torch.classes._nnapi.Compilation()
        compilation.init2(
            self.ser_model,
            self.weights,
            self.compilation_preference,
            self.relax_f32_to_f16,
        )
        self.comp = compilation

    def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
        """Run the compiled model on ``args`` and return its outputs.

        Raises:
            ValueError: if an input or output memory-format code is unknown.
        """
        if self.comp is None:
            self.init(args)
        compilation = self.comp
        # Local rebinding lets TorchScript refine Optional -> Compilation.
        assert compilation is not None
        results = [torch.empty_like(template) for template in self.out_templates]

        assert len(args) == len(self.inp_mem_fmts)
        # Input format codes are meant to match DimOrder in serializer.py.
        # TODO: See if it's possible to use those directly.
        nnapi_inputs: List[torch.Tensor] = []
        for position, arg in enumerate(args):
            layout = self.inp_mem_fmts[position]
            if layout == 0:
                nnapi_inputs.append(arg.contiguous())
            elif layout == 1:
                nnapi_inputs.append(arg.permute(0, 2, 3, 1).contiguous())
            else:
                raise ValueError("Invalid mem_fmt")

        compilation.run(nnapi_inputs, results)

        assert len(results) == len(self.out_mem_fmts)
        # Output format codes: 1 needs permuting back, 0 and 2 pass through.
        for position in range(len(self.out_templates)):
            layout = self.out_mem_fmts[position]
            if layout == 1:
                results[position] = results[position].permute(0, 3, 1, 2)
            elif layout != 0 and layout != 2:
                raise ValueError("Invalid mem_fmt")
        return results
95
96
def convert_model_to_nnapi(
    model,
    inputs,
    serializer=None,
    return_shapes=None,
    use_int16_for_qint16=False,
    compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
    relax_f32_to_f16=False,
):
    """Convert a TorchScript model into a scripted module that runs through NNAPI.

    Args:
        model: TorchScript module to convert (passed to ``process_for_nnapi``).
        inputs: a single example tensor, or a list of example tensors.
        serializer: optional serializer, forwarded to ``process_for_nnapi``.
        return_shapes: forwarded to ``process_for_nnapi``.
        use_int16_for_qint16: forwarded to ``process_for_nnapi``.
        compilation_preference: preference code forwarded to ``NnapiModule``.
        relax_f32_to_f16: forwarded to ``NnapiModule``.

    Returns:
        A scripted wrapper module. Its generated ``forward`` takes one
        positional tensor per input. It returns ``retvals[0]`` when the
        serializer reports a negative ``retval_count``, and otherwise a tuple
        of ``retval_count`` tensors.
    """
    # process_for_nnapi also accepts a bare tensor and wraps it in a list.
    # Normalize here too; otherwise len(inputs) below would be the size of
    # the tensor's first dimension (or raise TypeError for a 0-d tensor)
    # rather than the number of model inputs.
    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]

    (
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        retval_count,
    ) = process_for_nnapi(
        model, inputs, serializer, return_shapes, use_int16_for_qint16
    )

    nnapi_model = NnapiModule(
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        compilation_preference,
        relax_f32_to_f16,
    )

    # NOTE(review): presumably defined per call for the same reason as
    # ShapeComputeModule in process_for_nnapi -- define() below attaches
    # forward to the class, not just this instance.
    class NnapiInterfaceWrapper(torch.nn.Module):
        """NNAPI list-ifying and de-list-ifying wrapper.

        NNAPI always expects a list of inputs and provides a list of outputs.
        This module allows us to accept inputs as separate arguments.
        It returns results as either a single tensor or tuple,
        matching the original module.
        """

        def __init__(self, mod):
            super().__init__()
            self.mod = mod

    wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
    wrapper_model = torch.jit.script(wrapper_model_py)
    # forward is generated as source so its arity matches the input count.
    # TODO: Maybe make these names match the original.
    arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
    if retval_count < 0:
        ret_expr = "retvals[0]"
    else:
        # Each element is followed by a comma, so the result is always a
        # tuple (even with one element).
        # NOTE(review): retval_count == 0 yields a bare `return`, i.e. None.
        ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
    wrapper_model.define(
        f"def forward(self, {arg_list}):\n"
        f"    retvals = self.mod([{arg_list}])\n"
        f"    return {ret_expr}\n"
    )
    return wrapper_model
154
155
def process_for_nnapi(
    model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False
):
    """Freeze ``model`` and serialize it for NNAPI.

    Args:
        model: TorchScript module; it is frozen with ``torch.jit.freeze``.
        inputs: a single example tensor, or a list of example tensors.
        serializer: object providing ``serialize_model(model, inputs,
            return_shapes)``. When None, ``_NnapiSerializer(config=None,
            use_int16_for_qint16=use_int16_for_qint16)`` is used.
        return_shapes: forwarded to ``serialize_model``.
        use_int16_for_qint16: only used when building the default serializer.

    Returns:
        ``(shape_compute_module, ser_model_tensor, used_weights, inp_mem_fmts,
        out_mem_fmts, retval_count)``:
        - ``ser_model_tensor`` is the serialized model as an int32 tensor.
        - ``shape_compute_module`` is a scripted module with a generated
          ``prepare(ser_model, args)`` method.
        - The remaining items come straight from ``serialize_model``.
    """
    model = torch.jit.freeze(model)

    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]

    # Compare against None explicitly: a caller-supplied serializer must not
    # be silently replaced just because the object happens to be falsy.
    if serializer is None:
        serializer = _NnapiSerializer(
            config=None, use_int16_for_qint16=use_int16_for_qint16
        )
    (
        ser_model,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        shape_compute_lines,
        retval_count,
    ) = serializer.serialize_model(model, inputs, return_shapes)
    ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)

    # We have to create a new class here every time this function is called
    # because module.define adds a method to the *class*, not the instance.
    class ShapeComputeModule(torch.nn.Module):
        """Code-gen-ed module for tensor shape computation.

        module.prepare will mutate ser_model according to the computed operand
        shapes, based on the shapes of args.  Returns a list of output templates.
        """

    shape_compute_module = torch.jit.script(ShapeComputeModule())
    # Each generated line becomes one statement of prepare()'s body.
    real_shape_compute_lines = [
        "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
    ] + [f"    {line}\n" for line in shape_compute_lines]
    shape_compute_module.define("".join(real_shape_compute_lines))

    return (
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts,
        retval_count,
    )
200