1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <stdint.h>

#include <algorithm>
#include <functional>
#include <limits>

#include "tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h"

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
26
27 namespace tflite {
28 namespace ops {
29 namespace custom {
30 namespace unsorted_segment {
31
// Selects which reduction is applied to the data slices that share a segment
// id. The enumerator values are written out explicitly; they are the same
// values the implicit numbering would assign.
enum SegmentType {
  kSegmentMax = 0,
  kSegmentMin = 1,
  kSegmentProd = 2,
  kSegmentSum = 3,
};
38
// Positions of this op's tensors in the node's input list and output list.
constexpr int kInputDataTensor = 0;
constexpr int kInputSegmentIdsTensor = 1;
constexpr int kInputNumSegmentsTensor = 2;
constexpr int kOutputTensor = 0;
43
IsConstantOrPersistentTensor(const TfLiteTensor * tensor)44 inline bool IsConstantOrPersistentTensor(const TfLiteTensor* tensor) {
45 return tflite::IsConstantTensor(tensor) ||
46 (tensor->allocation_type == kTfLitePersistentRo);
47 }
48
49 template <typename T, template <typename T2> typename Op>
UnsortedSegmentRef(const tflite::RuntimeShape & input_shape,const T * input_data,const tflite::RuntimeShape & segment_ids_shape,const int32_t * segment_ids_data,const tflite::RuntimeShape & output_shape,T * output_data)50 void UnsortedSegmentRef(const tflite::RuntimeShape& input_shape,
51 const T* input_data,
52 const tflite::RuntimeShape& segment_ids_shape,
53 const int32_t* segment_ids_data,
54 const tflite::RuntimeShape& output_shape,
55 T* output_data) {
56 for (int i = 0; i < output_shape.FlatSize(); ++i) {
57 output_data[i] = Op<T>::kInitialValue;
58 }
59 Op<T> op;
60 int segment_flat_size = 1;
61 for (int i = 1; i < output_shape.DimensionsCount(); ++i) {
62 segment_flat_size *= output_shape.Dims(i);
63 }
64 for (int i = 0; i < segment_ids_shape.FlatSize(); i++) {
65 int output_index = segment_ids_data[i];
66 if (output_index < 0) continue;
67 for (int j = 0; j < segment_flat_size; ++j) {
68 output_data[output_index * segment_flat_size + j] =
69 op(output_data[output_index * segment_flat_size + j],
70 input_data[i * segment_flat_size + j]);
71 }
72 }
73 }
74
ResizeOutputTensor(TfLiteContext * context,const TfLiteTensor * data,const TfLiteTensor * segment_ids,const TfLiteTensor * num_segments,TfLiteTensor * output)75 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
76 const TfLiteTensor* data,
77 const TfLiteTensor* segment_ids,
78 const TfLiteTensor* num_segments,
79 TfLiteTensor* output) {
80 // The shape of segment_ids is permitted to be any non-empty prefix of
81 // the input data's shape. The shape of output's first dimension is always
82 // equal to num_segments. The remaining dimensions of output's shape are then
83 // taken to be the suffix of input shape after rank(segment_ids)th position.
84 // Public facing tensorflow erroneously describe unsorted_segment ops as only
85 // supporting segment_ids of rank 1, however tensorflow implementation
86 // supports higher dimensional segment_ids as described.
87 const int segment_ids_rank = tflite::NumDimensions(segment_ids);
88 const int data_rank = tflite::NumDimensions(data);
89 TF_LITE_ENSURE(context, segment_ids_rank <= data_rank);
90 for (int i = 0; i < segment_ids_rank; ++i) {
91 // segment_ids shape must be prefix of data shape.
92 TF_LITE_ENSURE_EQ(context, segment_ids->dims->data[i], data->dims->data[i]);
93 }
94 TF_LITE_ENSURE(context, (num_segments->dims->size == 1 &&
95 num_segments->dims->data[0] == 1) ||
96 num_segments->dims->size == 0);
97 // num_segments can be thought of as number of buckets (segments) in output,
98 // where each segment is the reduction of all elements mapped to that
99 // segment_ids. The shape of said elements is the respective
100 // suffix of the data shape.
101 int32_t num_segments_ = tflite::GetTensorData<int32_t>(num_segments)[0];
102 const int num_segment_ids = tflite::NumElements(segment_ids);
103 int max_index = -1;
104 for (int i = 0; i < num_segment_ids; i++) {
105 max_index = std::max(tflite::GetTensorData<int32_t>(segment_ids)[i], max_index);
106 }
107 // num_segments_ must be at greater than max_index else would map elements
108 // to non existent output segments.
109 TF_LITE_ENSURE(context, max_index < num_segments_);
110 const int output_rank = data_rank - segment_ids_rank + 1;
111 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
112 output_shape->data[0] = num_segments_;
113 // output_shape[1:] should be data_shape[Rank(segment_ids):]
114 for (int i = segment_ids_rank; i < data_rank; ++i) {
115 output_shape->data[i - segment_ids_rank + 1] = data->dims->data[i];
116 }
117 return context->ResizeTensor(context, output, output_shape);
118 }
119
Prepare(TfLiteContext * context,TfLiteNode * node)120 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
121 TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 3);
122 TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
123 const TfLiteTensor* data;
124 TF_LITE_ENSURE_OK(context,
125 tflite::GetInputSafe(context, node, kInputDataTensor, &data));
126 const TfLiteTensor* segment_ids;
127 TF_LITE_ENSURE_OK(context, tflite::GetInputSafe(context, node, kInputSegmentIdsTensor,
128 &segment_ids));
129 const TfLiteTensor* num_segments;
130 TF_LITE_ENSURE_OK(
131 context,
132 tflite::GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments));
133 TfLiteTensor* output;
134 TF_LITE_ENSURE_OK(context,
135 tflite::GetOutputSafe(context, node, kOutputTensor, &output));
136 TF_LITE_ENSURE(context,
137 data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
138 TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
139 TF_LITE_ENSURE_EQ(context, num_segments->type, kTfLiteInt32);
140
141 if (tflite::IsDynamicTensor(data) || !IsConstantOrPersistentTensor(segment_ids) ||
142 !IsConstantOrPersistentTensor(num_segments)) {
143 tflite::SetTensorToDynamic(output);
144 return kTfLiteOk;
145 }
146 return ResizeOutputTensor(context, data, segment_ids, num_segments, output);
147 }
148
// Binary functor for the "max" reduction: yields `b` when `a < b`, otherwise
// `a`. kInitialValue is the lowest value of T as reported by
// std::numeric_limits.
template <typename T>
struct SegmenMax {
  static constexpr T kInitialValue = std::numeric_limits<T>::lowest();
  inline T operator()(const T& a, const T& b) const { return (a < b) ? b : a; }
};
154
// Binary functor for the "min" reduction: yields `b` when `b < a`, otherwise
// `a`. kInitialValue is the maximum value of T as reported by
// std::numeric_limits.
template <typename T>
struct SegmenMin {
  static constexpr T kInitialValue = std::numeric_limits<T>::max();
  inline T operator()(const T& a, const T& b) const { return (b < a) ? b : a; }
};
160
// Binary functor for the "product" reduction: yields a * b. kInitialValue is
// one converted to T.
template <typename T>
struct SegmenProd {
  static constexpr T kInitialValue = static_cast<T>(1);
  inline T operator()(const T& a, const T& b) const {
    T product = a;
    product *= b;
    return product;
  }
};
166
// Binary functor for the "sum" reduction: yields a + b. kInitialValue is zero
// converted to T.
template <typename T>
struct SegmenSum {
  static constexpr T kInitialValue = static_cast<T>(0);
  inline T operator()(const T& a, const T& b) const {
    T sum = a;
    sum += b;
    return sum;
  }
};
172
173 template <typename T>
EvalType(TfLiteContext * context,const tflite::RuntimeShape & input_shape,const T * input_data,const tflite::RuntimeShape & segment_ids_shape,const int32_t * segment_ids_data,const tflite::RuntimeShape & output_shape,T * output_data,SegmentType segment_type)174 TfLiteStatus EvalType(TfLiteContext* context, const tflite::RuntimeShape& input_shape,
175 const T* input_data,
176 const tflite::RuntimeShape& segment_ids_shape,
177 const int32_t* segment_ids_data,
178 const tflite::RuntimeShape& output_shape, T* output_data,
179 SegmentType segment_type) {
180 switch (segment_type) {
181 case kSegmentProd:
182 unsorted_segment::UnsortedSegmentRef<T, SegmenProd>(
183 input_shape, input_data, segment_ids_shape, segment_ids_data,
184 output_shape, output_data);
185 break;
186 case kSegmentMax:
187 unsorted_segment::UnsortedSegmentRef<T, SegmenMax>(
188 input_shape, input_data, segment_ids_shape, segment_ids_data,
189 output_shape, output_data);
190 break;
191 case kSegmentSum:
192 unsorted_segment::UnsortedSegmentRef<T, SegmenSum>(
193 input_shape, input_data, segment_ids_shape, segment_ids_data,
194 output_shape, output_data);
195 break;
196 case kSegmentMin:
197 unsorted_segment::UnsortedSegmentRef<T, SegmenMin>(
198 input_shape, input_data, segment_ids_shape, segment_ids_data,
199 output_shape, output_data);
200 break;
201 default:
202 TF_LITE_KERNEL_LOG(context, "Not recognized segment type: %d",
203 segment_type);
204 return kTfLiteError;
205 }
206 return kTfLiteOk;
207 }
208
EvalGeneric(TfLiteContext * context,TfLiteNode * node,SegmentType segment_type)209 TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node,
210 SegmentType segment_type) {
211 const TfLiteTensor* data;
212 TF_LITE_ENSURE_OK(context,
213 tflite::GetInputSafe(context, node, kInputDataTensor, &data));
214 const TfLiteTensor* segment_ids;
215 TF_LITE_ENSURE_OK(context, tflite::GetInputSafe(context, node, kInputSegmentIdsTensor,
216 &segment_ids));
217 const TfLiteTensor* num_segments;
218 TF_LITE_ENSURE_OK(
219 context,
220 tflite::GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments));
221 TfLiteTensor* output;
222 TF_LITE_ENSURE_OK(context,
223 tflite::GetOutputSafe(context, node, kOutputTensor, &output));
224
225 if (tflite::IsDynamicTensor(output)) {
226 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, data, segment_ids,
227 num_segments, output));
228 }
229 TF_LITE_ENSURE_EQ(context, tflite::GetTensorShape(data).Dims(0),
230 tflite::GetTensorShape(segment_ids).Dims(0));
231
232 #define TF_LITE_UNSORTED_SEGMENT(dtype) \
233 EvalType<dtype>(context, tflite::GetTensorShape(data), tflite::GetTensorData<dtype>(data), \
234 tflite::GetTensorShape(segment_ids), \
235 tflite::GetTensorData<int32_t>(segment_ids), tflite::GetTensorShape(output), \
236 tflite::GetTensorData<dtype>(output), segment_type);
237 switch (data->type) {
238 case kTfLiteInt32:
239 TF_LITE_UNSORTED_SEGMENT(int32_t);
240 break;
241 case kTfLiteFloat32:
242 TF_LITE_UNSORTED_SEGMENT(float);
243 break;
244 default:
245 TF_LITE_KERNEL_LOG(
246 context, "Currently UnsortedSegment doesn't support data type: %s",
247 TfLiteTypeGetName(data->type));
248 return kTfLiteError;
249 }
250 #undef TF_LITE_UNSORTED_SEGMENT
251 return kTfLiteOk;
252 }
253
EvalProd(TfLiteContext * context,TfLiteNode * node)254 TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
255 return EvalGeneric(context, node, kSegmentProd);
256 }
EvalMax(TfLiteContext * context,TfLiteNode * node)257 TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
258 return EvalGeneric(context, node, kSegmentMax);
259 }
EvalSum(TfLiteContext * context,TfLiteNode * node)260 TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
261 return EvalGeneric(context, node, kSegmentSum);
262 }
EvalMin(TfLiteContext * context,TfLiteNode * node)263 TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) {
264 return EvalGeneric(context, node, kSegmentMin);
265 }
266
267 } // namespace unsorted_segment
268
Register_UNSORTED_SEGMENT_PROD()269 TfLiteRegistration* Register_UNSORTED_SEGMENT_PROD() {
270 static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
271 unsorted_segment::EvalProd};
272 return &r;
273 }
274
Register_UNSORTED_SEGMENT_MAX()275 TfLiteRegistration* Register_UNSORTED_SEGMENT_MAX() {
276 static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
277 unsorted_segment::EvalMax};
278 return &r;
279 }
280
Register_UNSORTED_SEGMENT_SUM()281 TfLiteRegistration* Register_UNSORTED_SEGMENT_SUM() {
282 static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
283 unsorted_segment::EvalSum};
284 return &r;
285 }
286
Register_UNSORTED_SEGMENT_MIN()287 TfLiteRegistration* Register_UNSORTED_SEGMENT_MIN() {
288 static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
289 unsorted_segment::EvalMin};
290 return &r;
291 }
292
293 } // namespace custom
294 } // namespace ops
295 } // namespace tflite