/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace split {

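// SPLIT op: splits a tensor into `num_splits` equal-sized slices along a
// given axis. Input 0 is a scalar `axis` tensor and input 1 is the tensor
// to split; the op produces one output tensor per slice.
//
// Worked example (illustrative, not from this file): splitting a [2, 6]
// tensor along axis 1 with num_splits = 3 yields three [2, 2] outputs; a
// num_splits that does not divide the axis dimension (e.g. 4) fails the
// "Not an even split" check in ResizeOutputTensors() below.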
struct OpContext {
  OpContext(TfLiteContext* context, TfLiteNode* node) {
    params = reinterpret_cast<TfLiteSplitParams*>(node->builtin_data);
    axis = GetInput(context, node, 0);
    input = GetInput(context, node, 1);
  }
  TfLiteSplitParams* params;
  const TfLiteTensor* axis;
  const TfLiteTensor* input;
};

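// Marks every output as dynamic so that its buffer is (re)allocated during
// Eval() rather than ahead of time; used when the axis value is not known
// at Prepare() time.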
TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
  for (int i = 0; i < NumOutputs(node); ++i) {
    TfLiteTensor* tensor;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
    SetTensorToDynamic(tensor);
  }
  return kTfLiteOk;
}

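// Reads the axis value (resolving negative indices against the input rank),
// verifies that the axis dimension divides evenly into `num_splits` slices,
// and resizes each output to the input shape with the axis dimension
// replaced by the slice size.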
TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
                                 const TfLiteTensor* axis,
                                 const TfLiteTensor* input, int num_splits) {
  int axis_value = GetTensorData<int>(axis)[0];
  if (axis_value < 0) {
    axis_value += NumDimensions(input);
  }

  TF_LITE_ENSURE(context, axis_value >= 0);
  TF_LITE_ENSURE(context, axis_value < NumDimensions(input));

  const int input_size = SizeOfDimension(input, axis_value);
  TF_LITE_ENSURE(context, num_splits != 0);
  TF_LITE_ENSURE_MSG(context, input_size % num_splits == 0,
                     "Not an even split");
  const int slice_size = input_size / num_splits;

  for (int i = 0; i < NumOutputs(node); ++i) {
    TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
    output_dims->data[axis_value] = slice_size;
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
    TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
  }

  return kTfLiteOk;
}

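// Validates the node (two inputs, one output per split, a supported input
// type), propagates the input type to the outputs, and resizes the outputs
// now if the axis is a constant tensor, deferring to Eval() otherwise.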
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);

  OpContext op_context(context, node);

  TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);

  auto input_type = op_context.input->type;
  TF_LITE_ENSURE(context,
                 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
                     input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
                     input_type == kTfLiteInt32);
  for (int i = 0; i < NumOutputs(node); ++i) {
    TfLiteTensor* tensor;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
    tensor->type = input_type;
  }

  // If we know the contents of the 'axis' tensor, resize all outputs.
  // Otherwise, wait until Eval().
  if (IsConstantTensor(op_context.axis)) {
    return ResizeOutputTensors(context, node, op_context.axis, op_context.input,
                               op_context.params->num_splits);
  } else {
    return UseDynamicOutputTensors(context, node);
  }
}

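// Performs the split: resizes the outputs first if that work was deferred
// from Prepare(), then dispatches to reference_ops::Split for the input's
// scalar type.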
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);

  // When the 'axis' tensor is non-const we can't resize output tensors in
  // Prepare(), and we have to do it now.
  if (!IsConstantTensor(op_context.axis)) {
    TF_LITE_ENSURE_OK(
        context,
        ResizeOutputTensors(context, node, op_context.axis, op_context.input,
                            op_context.params->num_splits));
  }

  int axis_value = GetTensorData<int>(op_context.axis)[0];
  if (axis_value < 0) {
    axis_value += NumDimensions(op_context.input);
  }

  TF_LITE_ENSURE(context, axis_value >= 0);
  TF_LITE_ENSURE(context, axis_value < NumDimensions(op_context.input));

  // TODO(b/173221795): Our usage of VectorOfTensors could be optimized by
  // calculating it in Prepare, unless we defer shape calculation.
  // We can improve the optimized_ops version to handle other
  // cases too.
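  // TF_LITE_SPLIT instantiates the type-agnostic reference kernel for one
  // scalar type: VectorOfTensors gathers the shapes and data pointers of
  // all outputs so reference_ops::Split can write every slice in one pass.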
#define TF_LITE_SPLIT(scalar)                                       \
  VectorOfTensors<scalar> all_outputs(*context, *node->outputs);    \
  tflite::SplitParams op_params;                                    \
  op_params.num_split = NumOutputs(node);                           \
  op_params.axis = axis_value;                                      \
  reference_ops::Split(op_params, GetTensorShape(op_context.input), \
                       GetTensorData<scalar>(op_context.input),     \
                       all_outputs.shapes(), all_outputs.data());

  switch (op_context.input->type) {
    case kTfLiteFloat32: {
      TF_LITE_SPLIT(float);
      break;
    }
    case kTfLiteUInt8: {
      TF_LITE_SPLIT(uint8_t);
      break;
    }
    case kTfLiteInt8: {
      TF_LITE_SPLIT(int8_t);
      break;
    }
    case kTfLiteInt16: {
      TF_LITE_SPLIT(int16_t);
      break;
    }
    case kTfLiteInt32: {
      TF_LITE_SPLIT(int32_t);
      break;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
                         TfLiteTypeGetName(op_context.input->type));
      return kTfLiteError;
  }
#undef TF_LITE_SPLIT

  return kTfLiteOk;
}

}  // namespace split

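// The registration's four entries are {init, free, prepare, invoke}; init
// and free are nullptr because this kernel keeps no per-instance state
// beyond the builtin data TFLite already owns.
//
// A minimal usage sketch (assuming the standard builtin resolver API):
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   resolver.AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT());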
TfLiteRegistration* Register_SPLIT() {
  static TfLiteRegistration r = {nullptr, nullptr, split::Prepare, split::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite