xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/pack.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdint.h>
17 
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 
26 namespace tflite {
27 namespace ops {
28 namespace builtin {
29 namespace pack {
30 namespace {
31 
// Index of the single output tensor of a PACK node (Prepare enforces exactly
// one output).
constexpr int kOutputTensor{0};
33 
Prepare(TfLiteContext * context,TfLiteNode * node)34 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
35   TfLitePackParams* data =
36       reinterpret_cast<TfLitePackParams*>(node->builtin_data);
37 
38   TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
39   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
40 
41   const TfLiteTensor* input0;
42   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input0));
43   const int dimension_size = NumDimensions(input0) + 1;
44   if (data->axis < 0) {
45     data->axis += dimension_size;
46   }
47   TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
48   TF_LITE_ENSURE(context, data->axis >= 0);
49 
50   if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 &&
51       input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt8 &&
52       input0->type != kTfLiteInt16 && input0->type != kTfLiteInt64) {
53     TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by pack.",
54                        TfLiteTypeGetName(input0->type));
55     return kTfLiteError;
56   }
57   // Make sure all inputs have the same shape and type.
58   for (int i = 1; i < data->values_count; ++i) {
59     const TfLiteTensor* input;
60     TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
61     TF_LITE_ENSURE(context, HaveSameShapes(input0, input));
62     TF_LITE_ENSURE_TYPES_EQ(context, input0->type, input->type);
63   }
64 
65   // Resize output. rank R will become rank R + 1
66   const TfLiteIntArray* input_shape = input0->dims;
67   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size);
68   int i = 0;
69   for (int index = 0; index < dimension_size; ++index) {
70     if (index == data->axis) {
71       output_shape->data[index] = data->values_count;
72     } else {
73       output_shape->data[index] = input_shape->data[i++];
74     }
75   }
76 
77   TfLiteTensor* output;
78   TF_LITE_ENSURE_OK(context,
79                     GetOutputSafe(context, node, kOutputTensor, &output));
80   TF_LITE_ENSURE_TYPES_EQ(context, output->type, input0->type);
81 
82   // Guarantee input/output quantization params match as we do not support
83   // packing quantized tensors.
84   for (int i = 0; i < data->values_count; i++) {
85     const TfLiteTensor* input;
86     TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
87     TF_LITE_ENSURE_EQ(context, input->params.zero_point,
88                       output->params.zero_point);
89     TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
90   }
91 
92   return context->ResizeTensor(context, output, output_shape);
93 }
94 
95 template <typename T>
PackImpl(TfLiteContext * context,TfLiteNode * node,TfLiteTensor * output,int values_count,int axis)96 TfLiteStatus PackImpl(TfLiteContext* context, TfLiteNode* node,
97                       TfLiteTensor* output, int values_count, int axis) {
98   TF_LITE_ENSURE(context, axis >= 0);
99 
100   VectorOfTensors<T> all_inputs(*context, *node->inputs);
101   tflite::PackParams op_params;
102   op_params.axis = axis;
103   op_params.inputs_count = values_count;
104 
105   reference_ops::Pack<T>(op_params, all_inputs.shapes(), all_inputs.data(),
106                          GetTensorShape(output), GetTensorData<T>(output));
107   return kTfLiteOk;
108 }
109 
Eval(TfLiteContext * context,TfLiteNode * node)110 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
111   const TfLitePackParams* data =
112       reinterpret_cast<TfLitePackParams*>(node->builtin_data);
113 
114   TfLiteTensor* output;
115   TF_LITE_ENSURE_OK(context,
116                     GetOutputSafe(context, node, kOutputTensor, &output));
117   switch (output->type) {
118     case kTfLiteFloat32: {
119       return PackImpl<float>(context, node, output, data->values_count,
120                              data->axis);
121     }
122     case kTfLiteUInt8: {
123       return PackImpl<uint8_t>(context, node, output, data->values_count,
124                                data->axis);
125     }
126     case kTfLiteInt8: {
127       return PackImpl<int8_t>(context, node, output, data->values_count,
128                               data->axis);
129     }
130     case kTfLiteInt16: {
131       return PackImpl<int16_t>(context, node, output, data->values_count,
132                                data->axis);
133     }
134     case kTfLiteInt32: {
135       return PackImpl<int32_t>(context, node, output, data->values_count,
136                                data->axis);
137     }
138     case kTfLiteInt64: {
139       return PackImpl<int64_t>(context, node, output, data->values_count,
140                                data->axis);
141     }
142     default: {
143       TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by pack.",
144                          TfLiteTypeGetName(output->type));
145       return kTfLiteError;
146     }
147   }
148 }
149 
150 }  // namespace
151 }  // namespace pack
152 
Register_PACK()153 TfLiteRegistration* Register_PACK() {
154   static TfLiteRegistration r = {nullptr, nullptr, pack::Prepare, pack::Eval};
155   return &r;
156 }
157 
158 }  // namespace builtin
159 }  // namespace ops
160 }  // namespace tflite
161