1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25
26 namespace tflite {
27 namespace ops {
28 namespace builtin {
29 namespace split {
30
31 struct OpContext {
OpContexttflite::ops::builtin::split::OpContext32 OpContext(TfLiteContext* context, TfLiteNode* node) {
33 params = reinterpret_cast<TfLiteSplitParams*>(node->builtin_data);
34 axis = GetInput(context, node, 0);
35 input = GetInput(context, node, 1);
36 }
37 TfLiteSplitParams* params;
38 const TfLiteTensor* axis;
39 const TfLiteTensor* input;
40 };
41
UseDynamicOutputTensors(TfLiteContext * context,TfLiteNode * node)42 TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
43 for (int i = 0; i < NumOutputs(node); ++i) {
44 TfLiteTensor* tensor;
45 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
46 SetTensorToDynamic(tensor);
47 }
48 return kTfLiteOk;
49 }
50
ResizeOutputTensors(TfLiteContext * context,TfLiteNode * node,const TfLiteTensor * axis,const TfLiteTensor * input,int num_splits)51 TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
52 const TfLiteTensor* axis,
53 const TfLiteTensor* input, int num_splits) {
54 int axis_value = GetTensorData<int>(axis)[0];
55 if (axis_value < 0) {
56 axis_value += NumDimensions(input);
57 }
58
59 TF_LITE_ENSURE(context, axis_value >= 0);
60 TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
61
62 const int input_size = SizeOfDimension(input, axis_value);
63 TF_LITE_ENSURE(context, num_splits != 0);
64 TF_LITE_ENSURE_MSG(context, input_size % num_splits == 0,
65 "Not an even split");
66 const int slice_size = input_size / num_splits;
67
68 for (int i = 0; i < NumOutputs(node); ++i) {
69 TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
70 output_dims->data[axis_value] = slice_size;
71 TfLiteTensor* output;
72 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
73 TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
74 }
75
76 return kTfLiteOk;
77 }
78
Prepare(TfLiteContext * context,TfLiteNode * node)79 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
80 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
81
82 OpContext op_context(context, node);
83
84 TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);
85
86 auto input_type = op_context.input->type;
87 TF_LITE_ENSURE(context,
88 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
89 input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
90 input_type == kTfLiteInt32);
91 for (int i = 0; i < NumOutputs(node); ++i) {
92 TfLiteTensor* tensor;
93 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
94 tensor->type = input_type;
95 }
96
97 // If we know the contents of the 'axis' tensor, resize all outputs.
98 // Otherwise, wait until Eval().
99 if (IsConstantTensor(op_context.axis)) {
100 return ResizeOutputTensors(context, node, op_context.axis, op_context.input,
101 op_context.params->num_splits);
102 } else {
103 return UseDynamicOutputTensors(context, node);
104 }
105 }
106
Eval(TfLiteContext * context,TfLiteNode * node)107 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
108 OpContext op_context(context, node);
109
110 // When the 'axis' tensor is non-const we can't resize output tensors in
111 // Prepare(), and we have to do it now.
112 if (!IsConstantTensor(op_context.axis)) {
113 TF_LITE_ENSURE_OK(
114 context,
115 ResizeOutputTensors(context, node, op_context.axis, op_context.input,
116 op_context.params->num_splits));
117 }
118
119 int axis_value = GetTensorData<int>(op_context.axis)[0];
120 if (axis_value < 0) {
121 axis_value += NumDimensions(op_context.input);
122 }
123
124 TF_LITE_ENSURE(context, axis_value >= 0);
125 TF_LITE_ENSURE(context, axis_value < NumDimensions(op_context.input));
126
127 // TODO(b/173221795): Our usage of VectorOfTensors could be optimized by
128 // calculating it in Prepare, unless we defer shape calculation.
129 // We can improve the optimized_ops version to handle other
130 // cases too.
131 #define TF_LITE_SPLIT(scalar) \
132 VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
133 tflite::SplitParams op_params; \
134 op_params.num_split = NumOutputs(node); \
135 op_params.axis = axis_value; \
136 reference_ops::Split(op_params, GetTensorShape(op_context.input), \
137 GetTensorData<scalar>(op_context.input), \
138 all_outputs.shapes(), all_outputs.data());
139
140 switch (op_context.input->type) {
141 case kTfLiteFloat32: {
142 TF_LITE_SPLIT(float);
143 break;
144 }
145 case kTfLiteUInt8: {
146 TF_LITE_SPLIT(uint8_t);
147 break;
148 }
149 case kTfLiteInt8: {
150 TF_LITE_SPLIT(int8_t);
151 break;
152 }
153 case kTfLiteInt16: {
154 TF_LITE_SPLIT(int16_t);
155 break;
156 }
157 case kTfLiteInt32: {
158 TF_LITE_SPLIT(int32_t);
159 break;
160 }
161 default:
162 TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
163 TfLiteTypeGetName(op_context.input->type));
164 return kTfLiteError;
165 }
166 #undef TF_LITE_SPLIT
167
168 return kTfLiteOk;
169 }
170
171 } // namespace split
172
Register_SPLIT()173 TfLiteRegistration* Register_SPLIT() {
174 static TfLiteRegistration r = {nullptr, nullptr, split::Prepare, split::Eval};
175 return &r;
176 }
177
178 } // namespace builtin
179 } // namespace ops
180 } // namespace tflite
181