xref: /aosp_15_r20/external/armnn/python/pyarmnn/src/pyarmnn/_tensor/workload_tensors.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2020 Arm Ltd. All rights reserved.
2# SPDX-License-Identifier: MIT
3"""
4This file contains functions relating to WorkloadTensors.
5WorkloadTensors are the inputTensors and outputTensors that are consumed by IRuntime.EnqueueWorkload.
6"""
7from typing import Union, List, Tuple
8import logging
9
10import numpy as np
11
12from .tensor import Tensor
13from .const_tensor import ConstTensor
14
15
def make_input_tensors(inputs_binding_info: List[Tuple],
                       input_data: List[np.ndarray]) -> List[Tuple[int, ConstTensor]]:
    """Returns `inputTensors` to be used with `IRuntime.EnqueueWorkload`.

    This is the primary function for building `inputTensors` for `IRuntime.EnqueueWorkload`.
    Each element of the result pairs an input tensor id with a `ConstTensor` that wraps the
    matching entry of `input_data`. One or many inputs are supported, and the result can be
    passed straight to `IRuntime.EnqueueWorkload`.

    Note:
        `SetConstant()` is called in place on the `TensorInfo` held in each binding tuple, so the
        `TensorInfo` objects inside `inputs_binding_info` are modified by this call.

    Examples:
        Creating inputTensors.
        >>> import pyarmnn as ann
        >>> import numpy as np
        >>>
        >>> parser = ann.ITfLiteParser()
        >>> ...
        >>> example_image = np.array(...)
        >>> input_binding_info = parser.GetNetworkInputBindingInfo(...)
        >>>
        >>> input_tensors = ann.make_input_tensors([input_binding_info], [example_image])

    Args:
        inputs_binding_info (list of tuples): (int, `TensorInfo`) Binding information for input tensors obtained from
                                              `GetNetworkInputBindingInfo`.
        input_data (list of ndarrays): Tensor data to be used for inference, given in the same order as
                                       `inputs_binding_info`.

    Returns:
        list: `inputTensors` - A list of tuples (`int`, `ConstTensor`).

    Raises:
        ValueError: If `inputs_binding_info` and `input_data` have different lengths.
    """
    if len(inputs_binding_info) != len(input_data):
        raise ValueError("Length of 'inputs_binding_info' does not match length of 'input_data'")

    def _bind(binding, data):
        # binding is an (id, TensorInfo) pair; the info is flagged constant before wrapping the data.
        tensor_id = binding[0]
        tensor_info = binding[1]
        tensor_info.SetConstant()
        return tensor_id, ConstTensor(tensor_info, data)

    return [_bind(binding, data) for binding, data in zip(inputs_binding_info, input_data)]
61
62
def make_output_tensors(outputs_binding_info: List[Tuple]) -> List[Tuple[int, Tensor]]:
    """Returns `outputTensors` to be used with `IRuntime.EnqueueWorkload`.

    This is the primary function for building `outputTensors` for `IRuntime.EnqueueWorkload`.
    Each element of the result pairs an output tensor id with a new `Tensor` constructed from
    that binding's `TensorInfo`. The result can be passed straight to `IRuntime.EnqueueWorkload`.

    Examples:
        Creating outputTensors.
        >>> import pyarmnn as ann
        >>>
        >>> parser = ann.ITfLiteParser()
        >>> ...
        >>> output_binding_info = parser.GetNetworkOutputBindingInfo(...)
        >>>
        >>> output_tensors = ann.make_output_tensors([output_binding_info])

    Args:
        outputs_binding_info (list of tuples): (int, `TensorInfo`) Binding information for output tensors obtained from
                                               `GetNetworkOutputBindingInfo`.

    Returns:
        list: `outputTensors` - A list of tuples (`int`, `Tensor`), in the same order as `outputs_binding_info`.
    """
    # binding[0] is the output tensor id; binding[1] is the TensorInfo the Tensor is built from.
    return [(binding[0], Tensor(binding[1])) for binding in outputs_binding_info]
95
96
def workload_tensors_to_ndarray(workload_tensors: List[Tuple[int, Union[Tensor, ConstTensor]]]) -> List[np.ndarray]:
    """Returns a list of the underlying tensor data as ndarrays from `inputTensors` or `outputTensors`.

    We refer to `inputTensors` and `outputTensors` as workload tensors because
    they are used with `IRuntime.EnqueueWorkload`.
    This function works on either `inputTensors` or `outputTensors`. Its main use is to
    collect results from `outputTensors` after `IRuntime.EnqueueWorkload` has been called.

    For each tensor, the result of `get_memory_area()` is reshaped to the tensor's `GetShape()`.
    The shape of each tensor is also logged at INFO level on the root logger.

    Examples:
        Getting results after inference.
        >>> import pyarmnn as ann
        >>>
        >>> ...
        >>> runtime = ann.IRuntime(...)
        >>> ...
        >>> runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
        >>>
        >>> inference_results = workload_tensors_to_ndarray(output_tensors)

    Args:
        workload_tensors (inputTensors or outputTensors): `inputTensors` or `outputTensors` to get data from. See
                                                          `make_input_tensors` and `make_output_tensors`.

    Returns:
        list: List of `ndarrays` for the underlying tensor data from the given `inputTensors` or `outputTensors`,
              in the same order as `workload_tensors`.
    """
    arrays = []
    for index, (_, tensor) in enumerate(workload_tensors):
        # Query the shape once; it is used both for the reshape and for the log message.
        shape = tensor.GetShape()
        arrays.append(tensor.get_memory_area().reshape(list(shape)))
        # Lazy %-style args: the message is only formatted when INFO logging is enabled.
        logging.info("Workload tensor %s shape: %s", index, shape)

    return arrays
129