1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdint.h>
17
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25
26 namespace tflite {
27 namespace ops {
28 namespace builtin {
29 namespace pack {
30 namespace {
31
// Index of the node's single output tensor, i.e. the packed result.
constexpr int kOutputTensor{0};
33
Prepare(TfLiteContext * context,TfLiteNode * node)34 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
35 TfLitePackParams* data =
36 reinterpret_cast<TfLitePackParams*>(node->builtin_data);
37
38 TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
39 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
40
41 const TfLiteTensor* input0;
42 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input0));
43 const int dimension_size = NumDimensions(input0) + 1;
44 if (data->axis < 0) {
45 data->axis += dimension_size;
46 }
47 TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
48 TF_LITE_ENSURE(context, data->axis >= 0);
49
50 if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 &&
51 input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt8 &&
52 input0->type != kTfLiteInt16 && input0->type != kTfLiteInt64) {
53 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by pack.",
54 TfLiteTypeGetName(input0->type));
55 return kTfLiteError;
56 }
57 // Make sure all inputs have the same shape and type.
58 for (int i = 1; i < data->values_count; ++i) {
59 const TfLiteTensor* input;
60 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
61 TF_LITE_ENSURE(context, HaveSameShapes(input0, input));
62 TF_LITE_ENSURE_TYPES_EQ(context, input0->type, input->type);
63 }
64
65 // Resize output. rank R will become rank R + 1
66 const TfLiteIntArray* input_shape = input0->dims;
67 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size);
68 int i = 0;
69 for (int index = 0; index < dimension_size; ++index) {
70 if (index == data->axis) {
71 output_shape->data[index] = data->values_count;
72 } else {
73 output_shape->data[index] = input_shape->data[i++];
74 }
75 }
76
77 TfLiteTensor* output;
78 TF_LITE_ENSURE_OK(context,
79 GetOutputSafe(context, node, kOutputTensor, &output));
80 TF_LITE_ENSURE_TYPES_EQ(context, output->type, input0->type);
81
82 // Guarantee input/output quantization params match as we do not support
83 // packing quantized tensors.
84 for (int i = 0; i < data->values_count; i++) {
85 const TfLiteTensor* input;
86 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
87 TF_LITE_ENSURE_EQ(context, input->params.zero_point,
88 output->params.zero_point);
89 TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
90 }
91
92 return context->ResizeTensor(context, output, output_shape);
93 }
94
95 template <typename T>
PackImpl(TfLiteContext * context,TfLiteNode * node,TfLiteTensor * output,int values_count,int axis)96 TfLiteStatus PackImpl(TfLiteContext* context, TfLiteNode* node,
97 TfLiteTensor* output, int values_count, int axis) {
98 TF_LITE_ENSURE(context, axis >= 0);
99
100 VectorOfTensors<T> all_inputs(*context, *node->inputs);
101 tflite::PackParams op_params;
102 op_params.axis = axis;
103 op_params.inputs_count = values_count;
104
105 reference_ops::Pack<T>(op_params, all_inputs.shapes(), all_inputs.data(),
106 GetTensorShape(output), GetTensorData<T>(output));
107 return kTfLiteOk;
108 }
109
Eval(TfLiteContext * context,TfLiteNode * node)110 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
111 const TfLitePackParams* data =
112 reinterpret_cast<TfLitePackParams*>(node->builtin_data);
113
114 TfLiteTensor* output;
115 TF_LITE_ENSURE_OK(context,
116 GetOutputSafe(context, node, kOutputTensor, &output));
117 switch (output->type) {
118 case kTfLiteFloat32: {
119 return PackImpl<float>(context, node, output, data->values_count,
120 data->axis);
121 }
122 case kTfLiteUInt8: {
123 return PackImpl<uint8_t>(context, node, output, data->values_count,
124 data->axis);
125 }
126 case kTfLiteInt8: {
127 return PackImpl<int8_t>(context, node, output, data->values_count,
128 data->axis);
129 }
130 case kTfLiteInt16: {
131 return PackImpl<int16_t>(context, node, output, data->values_count,
132 data->axis);
133 }
134 case kTfLiteInt32: {
135 return PackImpl<int32_t>(context, node, output, data->values_count,
136 data->axis);
137 }
138 case kTfLiteInt64: {
139 return PackImpl<int64_t>(context, node, output, data->values_count,
140 data->axis);
141 }
142 default: {
143 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by pack.",
144 TfLiteTypeGetName(output->type));
145 return kTfLiteError;
146 }
147 }
148 }
149
150 } // namespace
151 } // namespace pack
152
Register_PACK()153 TfLiteRegistration* Register_PACK() {
154 static TfLiteRegistration r = {nullptr, nullptr, pack::Prepare, pack::Eval};
155 return &r;
156 }
157
158 } // namespace builtin
159 } // namespace ops
160 } // namespace tflite
161