xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/node_visitor.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import copy
8from typing import Any, Dict, Tuple
9
10import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
11
12import numpy as np
13import torch
14from executorch.backends.qualcomm.utils.constants import (
15    QCOM_AXIS,
16    QCOM_AXIS_ORDER,
17    QCOM_BITWIDTH,
18    QCOM_DTYPE,
19    QCOM_ENCODING,
20    QCOM_OFFSET,
21    QCOM_QUANT_ATTRS,
22    QCOM_QUANT_MAX,
23    QCOM_QUANT_MIN,
24    QCOM_REQUANTIZE,
25    QCOM_SCALE,
26    QCOM_SCALE_OFFSET,
27    QCOM_SCALES,
28    QCOM_ZERO_POINT,
29    QCOM_ZERO_POINTS,
30)
31
32from executorch.exir.dialects._ops import ops as exir_ops
33
34from .utils import (
35    deduce_dtype,
36    get_parameter,
37    is_graph_input,
38    is_graph_output,
39    is_parameter,
40)
41
42
# torch dtype -> QNN data type for tensors that carry a quantization config.
# NodeVisitor.get_data_type indexes this map with the dtype deduced for the
# quantized tensor; all usable entries are fixed-point QNN types.
QNN_QUANT_TYPE_MAP = {
    torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8,
    torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16,
    torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32,
    # Note that there is no int64 tensor data type in Qnn.
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
}
# torch dtype -> QNN data type for tensors without a quantization config.
# NodeVisitor.get_data_type indexes this map with `tensor.dtype`.
QNN_TENSOR_TYPE_MAP = {
    torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
    torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
    torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
    torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
    torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
    torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
    torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
    torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
    # The Python builtin `float` is accepted as a key too and maps to the
    # same QNN type as torch.float32.
    float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
}

# Quantize/dequantize ops whose quant attrs are treated as per-channel:
# NodeVisitor.get_quant_encoding_conf routes encodings found in this set to
# make_qnn_per_channel_config.
PER_CHANNEL_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
}

# Quantize/dequantize ops whose quant attrs are treated as per-tensor:
# NodeVisitor.get_quant_tensor_value reads a single scale / zero point for
# encodings found in this set and the per-channel lists otherwise.
PER_TENSOR_ENCODING = {
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
}
75
76
class NodeVisitor:
    """
    Node visitor pattern for visiting nodes in an edge IR graph

    Concrete visitors subclass this class and implement `define_node`. The
    base class provides the shared helpers that read tensors off graph nodes,
    translate quantization attributes into QNN quantization configs, and
    create (and cache) `PyQnnWrapper.TensorWrapper` objects.
    """

    def __init__(
        self,
        external_ids,
        edge_program: torch.export.ExportedProgram,
        enable_tensor_dump,
    ) -> None:
        """
        Args:
            external_ids: Mapping from graph input/output nodes to their
                external ids; a falsy value is replaced by an empty dict
            edge_program: Exported program whose graph is being visited
            enable_tensor_dump: When truthy, native tensors are exposed as
                app-readable tensors (see `get_tensor_type`)
        """
        self.external_ids = external_ids or {}
        self.edge_program = edge_program
        self.enable_tensor_dump = enable_tensor_dump

    def get_tensor(self, input_node, op_node, idx=None):
        """
        Get tensor value/shape with axis_order

        Args:
            input_node: Node to read the tensor from. Parameters are fetched
                through `get_parameter`, everything else from `meta["val"]`
            op_node: Node whose `meta[QCOM_AXIS_ORDER]`, when present, is
                used to permute the tensor
            idx: Optional integer index applied to the fetched value before
                the permutation

        Returns:
            The (possibly indexed and permuted) tensor
        """
        if idx is not None:
            assert isinstance(idx, int)

        if is_parameter(input_node, self.edge_program):
            source = get_parameter(input_node, self.edge_program)
        else:
            source = input_node.meta["val"]
        tensor = source if idx is None else source[idx]

        # Scalars (rank 0) have no axes to reorder.
        if len(tensor.shape) != 0 and QCOM_AXIS_ORDER in op_node.meta:
            tensor = tensor.permute(dims=op_node.meta[QCOM_AXIS_ORDER]).contiguous()
        return tensor

    @staticmethod
    def _uses_4bit_encoding(quant_config: Dict) -> bool:
        """Return True for an int8 config whose quant range spans at most 16 levels."""
        return (
            quant_config[QCOM_DTYPE] == torch.int8
            and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
        )

    def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
        """
        Build the QNN per-channel (axis scale/offset) quantization config

        Args:
            node: Node owning the quant attrs; its first user decides
                whether the quantization axis is overridden
            quant_attrs: Per-channel quant attrs (scales, zero points, axis,
                dtype, quant min/max)

        Returns:
            Tuple of (QNN quantization encoding, config dict). The config is
            a deep copy of `quant_attrs` extended with QCOM_AXIS,
            QCOM_SCALE_OFFSET and, for the 4-bit case, QCOM_BITWIDTH
        """
        quant_config = copy.deepcopy(quant_attrs)

        scales = quant_attrs[QCOM_SCALES]
        zero_points = quant_attrs[QCOM_ZERO_POINTS]
        assert len(scales) == len(
            zero_points
        ), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}"

        # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
        # The offset handed to QNN is the negated zero point.
        scale_offset = [
            PyQnnWrapper.Qnn_ScaleOffset_t(scale, -zero_point)
            for scale, zero_point in zip(scales, zero_points)
        ]

        user_0 = list(node.users)[0]
        # Memory layout of QNN conv weight always ends in Output. Like conv2d is HWIO
        is_conv_weight = (
            "convolution" in user_0.target.__name__ and user_0.args[1] == node
        )
        quant_config[QCOM_AXIS] = 3 if is_conv_weight else quant_attrs[QCOM_AXIS]

        quant_config[QCOM_SCALE_OFFSET] = scale_offset
        # special case for 4 bits
        if self._uses_4bit_encoding(quant_config):
            quant_config[QCOM_BITWIDTH] = 4
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
                quant_config,
            )
        return (
            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
            quant_config,
        )

    def make_qnn_per_tensor_config(self, quant_attrs: Dict):
        """
        Build the QNN per-tensor (scale/offset) quantization config

        Args:
            quant_attrs: Per-tensor quant attrs (zero point, dtype, quant
                min/max)

        Returns:
            Tuple of (QNN quantization encoding, config dict). The config is
            a deep copy of `quant_attrs` extended with QCOM_OFFSET and, for
            the 4-bit case, QCOM_BITWIDTH
        """
        quant_config = copy.deepcopy(quant_attrs)
        # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h
        # The offset handed to QNN is the negated zero point.
        quant_config[QCOM_OFFSET] = -quant_attrs[QCOM_ZERO_POINT]
        # special case for 4 bits
        if self._uses_4bit_encoding(quant_config):
            quant_config[QCOM_BITWIDTH] = 4
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET,
                quant_config,
            )
        return (
            PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
            quant_config,
        )

    def get_quant_encoding_conf(
        self, node: torch.fx.Node, is_input_tensor: bool = False
    ) -> Tuple[Any, Dict]:
        """
        Select the quantization encoding and config for a node

        Args:
            node: Node whose `meta` may carry QCOM_QUANT_ATTRS and
                QCOM_REQUANTIZE
            is_input_tensor: When True and the node has QCOM_REQUANTIZE, the
                requantize attrs are used instead of QCOM_QUANT_ATTRS

        Returns:
            Tuple of (QNN quantization encoding, config dict); the undefined
            encoding with an empty dict when the node has no quant attrs
        """
        if not node.meta.get(QCOM_QUANT_ATTRS, None):
            return (
                PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
                {},
            )

        use_requantize = QCOM_REQUANTIZE in node.meta and is_input_tensor
        if use_requantize:
            quant_attrs = node.meta[QCOM_REQUANTIZE]
        else:
            quant_attrs = node.meta[QCOM_QUANT_ATTRS]

        if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
            return self.make_qnn_per_channel_config(node, quant_attrs)

        return self.make_qnn_per_tensor_config(quant_attrs)

    def get_quant_tensor_value(
        self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
    ) -> torch.Tensor:
        """
        Quantize a tensor's values with the given quant attrs

        Args:
            tensor: Tensor holding the values to quantize
            quant_attrs: Source of the encoding and of scale(s) / zero
                point(s)
            quant_configs: Source of the target dtype (QCOM_DTYPE) and of the
                optional QCOM_BITWIDTH

        Returns:
            `round(tensor / scale + zero_point)` cast to the target dtype;
            for a 4-bit config only the low 4 bits of each value are kept
        """
        if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING:
            scale = quant_attrs[QCOM_SCALE]
            zero_point = quant_attrs[QCOM_ZERO_POINT]
        else:  # per channel case
            scale = quant_attrs[QCOM_SCALES]
            zero_point = quant_attrs[QCOM_ZERO_POINTS]

        dtype = quant_configs[QCOM_DTYPE]

        tensor = tensor.div(scale).add(zero_point).round().to(dtype)
        # Make the backends access data correctly
        if quant_configs.get(QCOM_BITWIDTH) == 4:
            mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
            tensor = torch.bitwise_and(mask, tensor)
        return tensor

    def get_tensor_type(
        self,
        node: torch.fx.Node,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
    ) -> PyQnnWrapper.Qnn_TensorType_t:
        """
        Resolve the QNN tensor type for a node

        Args:
            node: Node the tensor belongs to
            tensor_type: Tensor type requested by the caller, returned as-is
                when no rule below overrides it

        Returns:
            APP_WRITE for graph inputs, APP_READ for graph outputs, STATIC
            for parameters, APP_READ for native tensors when tensor dump is
            enabled, otherwise `tensor_type`. The rules are checked in that
            order, so a node that is both input and output is APP_WRITE
        """
        is_input = is_graph_input(node, self.edge_program)
        is_output = is_graph_output(node)
        # handle logic for input/output tensors
        if is_input or is_output:
            assert (
                node in self.external_ids
            ), f"Node {node}, is_input: {is_input}, is_output: {is_output}, ext_ids: {self.external_ids.keys()}"
            if is_input:
                return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_WRITE
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ

        if is_parameter(node, self.edge_program):
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
        # dump all tensor, set to app read, and we only dump native tensors
        if (
            self.enable_tensor_dump
            and tensor_type == PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
        ):
            return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ
        return tensor_type

    def get_data_type(
        self,
        tensor: torch.Tensor,
        quant_config: Dict,
    ) -> PyQnnWrapper.Qnn_DataType_t:
        """
        Map a tensor to its QNN data type

        Args:
            tensor: Tensor whose dtype is inspected
            quant_config: Quantization config; when non-empty its QCOM_DTYPE
                entry is overwritten in place with the deduced dtype

        Returns:
            Entry of QNN_QUANT_TYPE_MAP for quantized tensors, entry of
            QNN_TENSOR_TYPE_MAP otherwise

        Raises:
            KeyError: If the dtype has no entry in the selected map
        """
        if quant_config:
            quant_config[QCOM_DTYPE] = deduce_dtype(tensor, quant_config)
            return QNN_QUANT_TYPE_MAP[quant_config[QCOM_DTYPE]]

        return QNN_TENSOR_TYPE_MAP[tensor.dtype]

    def define_custom_tensor_wrapper(
        self,
        node_name: str,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
        dtype: PyQnnWrapper.Qnn_DataType_t,
        quant_encoding: PyQnnWrapper.Qnn_QuantizationEncoding_t,
        quant_configs: dict,
        dims: torch.Size,
        tensor: torch.Tensor,
        is_fake_tensor: bool,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
        wrapper_idx: int = 0,
    ) -> PyQnnWrapper.TensorWrapper:
        """
        Create (or fetch from cache) a TensorWrapper from explicit attributes

        Args:
            node_name: Name of the wrapper and first-level cache key
            tensor_type: QNN tensor type
            dtype: QNN data type
            quant_encoding: QNN quantization encoding
            quant_configs: Quantization config dict
            dims: Tensor dimensions
            tensor: Currently unused
            is_fake_tensor: Only fake tensors are supported; for any other
                tensor None is returned and nothing is cached
            nodes_to_wrappers: Wrapper cache indexed by node name, then by
                wrapper index. It is indexed directly with `node_name`, so
                it must already hold (or auto-create) that entry
            wrapper_idx: Second-level cache key

        Returns:
            The cached or newly created wrapper, or None for non-fake tensors
        """
        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached

        if not is_fake_tensor:
            # Can implement non-fake tensor when there is a need
            return None

        tensor_wrapper = PyQnnWrapper.TensorWrapper(
            node_name,
            tensor_type,
            dtype,
            quant_encoding,
            quant_configs,
            len(dims),
            dims,
            np.array([]),
            False,
        )
        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
        return tensor_wrapper

    def define_tensor(
        self,
        node: torch.fx.Node,
        tensor: torch.Tensor,
        tensor_type: PyQnnWrapper.Qnn_TensorType_t,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
        is_input_tensor: bool,
        node_name: str = None,
        wrapper_idx: int = 0,
    ) -> PyQnnWrapper.TensorWrapper:
        """
        Convert torch.Tensor to TensorWrapper

        Args:
            node: EdgeIR Node
            tensor: EdgeIR Tensor
            tensor_type: QNN tensor type
            nodes_to_wrappers: Wrapper cache indexed by node name, then by
                             wrapper index. It is indexed directly with the
                             node name, so it must already hold (or
                             auto-create) that entry
            is_input_tensor: Whether tensor is a fake input tensor relatively to
                             the op builder that is calling this function
            node_name: Cache key to use instead of `node.name`. It only
                             selects the cache entry; the wrapper's tensor
                             name is always derived from `node.name`
            wrapper_idx: Second-level cache key, also appended to the
                             tensor name

        Returns:
            The cached or newly created wrapper
        """
        if node_name is None:
            node_name = node.name

        if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
            return cached

        tensor_name = f"{node.name}_{wrapper_idx}"
        if is_graph_input(node, self.edge_program):
            tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
        if is_graph_output(node):
            tensor_name = "output_" + tensor_name
        # A rank-0 tensor is described to QNN with dims [1].
        dims = [1] if len(tensor.size()) == 0 else tensor.size()
        tensor_type = self.get_tensor_type(node, tensor_type)
        quant_encoding, quant_configs = self.get_quant_encoding_conf(
            node, is_input_tensor
        )
        dtype = self.get_data_type(tensor, quant_configs)

        if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
            # Fake tensors carry no values: pass an empty array and False.
            data, data_flag = np.array([]), False
        else:
            if quant_configs:
                tensor = self.get_quant_tensor_value(
                    tensor,
                    node.meta[QCOM_QUANT_ATTRS],
                    quant_configs,
                )
            # Real tensors pass their values and True.
            # NOTE(review): meaning of the trailing flag is defined by
            # PyQnnWrapper.TensorWrapper - confirm there.
            data, data_flag = tensor.detach().numpy(), True

        tensor_wrapper = PyQnnWrapper.TensorWrapper(
            tensor_name,
            tensor_type,
            dtype,
            quant_encoding,
            quant_configs,
            len(dims),
            dims,
            data,
            data_flag,
        )
        nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
        return tensor_wrapper

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Convert torch.fx.Node to OpWrapper"""
        raise NotImplementedError("NodeVisitor must be extended!")
361
362
# Registry of visitor classes: maps every entry of a visitor class's `target`
# attribute to that class. Filled by register_node_visitor and read by
# get_node_visitors.
_node_visitor_dict = {}
365
366
def register_node_visitor(visitor):
    """
    Register node visitor into _node_visitor_dict

    Every entry of `visitor.target` becomes a key that maps to the visitor
    class; an already registered key is overwritten.

    Args:
        visitor: A NodeVisitor subclass that defines a `target` attribute

    Returns:
        The visitor class itself, unchanged. Returning it keeps the class
        name bound to the class when this function is applied as a class
        decorator (a decorator returning None would rebind the name to None)

    Raises:
        AssertionError: If `visitor` is not a NodeVisitor subclass with a
            `target` attribute
    """
    assert (
        isinstance(visitor, type)
        and issubclass(visitor, NodeVisitor)
        and hasattr(visitor, "target")
    ), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
    for target in visitor.target:
        _node_visitor_dict[target] = visitor
    return visitor
376
377
def generate_node_to_external_map(
    edge_program: torch.export.ExportedProgram,
) -> Dict[torch.fx.Node, int]:
    """
    Assign an external id to every graph input and graph output node

    Ids are handed out in visiting order: all graph inputs first, then all
    graph outputs, each in the order the nodes appear in the graph.

    Args:
        edge_program: Exported program whose graph nodes are scanned

    Returns:
        Mapping from node to its external id
    """
    graph_nodes = edge_program.graph_module.graph.nodes

    # The order in which we visit the placeholder node is same as the *args
    # order for the forward(*args) signature for this gm. Using the order of
    # the nodes as external_id to extract the right arg from *args at runtime
    graph_inputs = [
        candidate
        for candidate in graph_nodes
        if is_graph_input(candidate, edge_program)
    ]
    graph_outputs = [
        candidate for candidate in graph_nodes if is_graph_output(candidate)
    ]

    external_ids = {}
    for external_node in graph_inputs + graph_outputs:
        # The next id is the number of nodes mapped so far.
        external_ids[external_node] = len(external_ids)
    return external_ids
392
393
def get_node_visitors(
    edge_program: torch.export.ExportedProgram,
    enable_tensor_dump=False,
) -> Dict[str, NodeVisitor]:
    """
    Instantiate every registered visitor class for the given program

    Args:
        edge_program: Exported program handed to each visitor
        enable_tensor_dump: Forwarded to each visitor's constructor

    Returns:
        Dict with the same keys as _node_visitor_dict, each mapped to a new
        visitor instance. All instances share one node-to-external-id map
    """
    # Built once and shared by every visitor instance.
    shared_external_ids = generate_node_to_external_map(edge_program)

    instances = {}
    for registered_target in _node_visitor_dict:
        visitor = _node_visitor_dict[registered_target]
        assert callable(
            visitor
        ), f"Expeting a callable class, but got {visitor} of type {type(visitor)}"
        instances[registered_target] = visitor(
            shared_external_ids, edge_program, enable_tensor_dump
        )
    return instances
409