1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdint.h>
17 
18 #include <algorithm>
19 #include <functional>
20 
21 #include "tensorflow_lite_support/custom_ops/kernel/unsorted_segment.h"
22 
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
25 #include "tensorflow/lite/kernels/kernel_util.h"
26 
27 namespace tflite {
28 namespace ops {
29 namespace custom {
30 namespace unsorted_segment {
31 
// Which reduction a node performs; selects the functor that
// UnsortedSegmentRef is instantiated with.
enum SegmentType {
  kSegmentMax,
  kSegmentMin,
  kSegmentProd,
  kSegmentSum,
};

// Indices of the node's tensors within its input/output lists.
static const int kInputDataTensor = 0;
static const int kInputSegmentIdsTensor = 1;
static const int kInputNumSegmentsTensor = 2;
static const int kOutputTensor = 0;
43 
IsConstantOrPersistentTensor(const TfLiteTensor * tensor)44 inline bool IsConstantOrPersistentTensor(const TfLiteTensor* tensor) {
45   return tflite::IsConstantTensor(tensor) ||
46          (tensor->allocation_type == kTfLitePersistentRo);
47 }
48 
49 template <typename T, template <typename T2> typename Op>
UnsortedSegmentRef(const tflite::RuntimeShape & input_shape,const T * input_data,const tflite::RuntimeShape & segment_ids_shape,const int32_t * segment_ids_data,const tflite::RuntimeShape & output_shape,T * output_data)50 void UnsortedSegmentRef(const tflite::RuntimeShape& input_shape,
51                                const T* input_data,
52                                const tflite::RuntimeShape& segment_ids_shape,
53                                const int32_t* segment_ids_data,
54                                const tflite::RuntimeShape& output_shape,
55                                T* output_data) {
56   for (int i = 0; i < output_shape.FlatSize(); ++i) {
57     output_data[i] = Op<T>::kInitialValue;
58   }
59   Op<T> op;
60   int segment_flat_size = 1;
61   for (int i = 1; i < output_shape.DimensionsCount(); ++i) {
62     segment_flat_size *= output_shape.Dims(i);
63   }
64   for (int i = 0; i < segment_ids_shape.FlatSize(); i++) {
65     int output_index = segment_ids_data[i];
66     if (output_index < 0) continue;
67     for (int j = 0; j < segment_flat_size; ++j) {
68       output_data[output_index * segment_flat_size + j] =
69           op(output_data[output_index * segment_flat_size + j],
70              input_data[i * segment_flat_size + j]);
71     }
72   }
73 }
74 
// Validates shapes and resizes `output` to:
//   [num_segments] + data_shape[rank(segment_ids):]
// Also checks that every segment id is smaller than num_segments.
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                const TfLiteTensor* data,
                                const TfLiteTensor* segment_ids,
                                const TfLiteTensor* num_segments,
                                TfLiteTensor* output) {
  // The shape of segment_ids is permitted to be any non-empty prefix of
  // the input data's shape. The shape of output's first dimension is always
  // equal to num_segments. The remaining dimensions of output's shape are then
  // taken to be the suffix of input shape after rank(segment_ids)th position.
  // Public facing tensorflow erroneously describes unsorted_segment ops as
  // only supporting segment_ids of rank 1, however the tensorflow
  // implementation supports higher dimensional segment_ids as described.
  const int segment_ids_rank = tflite::NumDimensions(segment_ids);
  const int data_rank = tflite::NumDimensions(data);
  TF_LITE_ENSURE(context, segment_ids_rank <= data_rank);
  for (int i = 0; i < segment_ids_rank; ++i) {
    // segment_ids shape must be a prefix of data shape.
    TF_LITE_ENSURE_EQ(context, segment_ids->dims->data[i], data->dims->data[i]);
  }
  // num_segments must be a scalar or a single-element 1-D tensor.
  TF_LITE_ENSURE(context, (num_segments->dims->size == 1 &&
                           num_segments->dims->data[0] == 1) ||
                              num_segments->dims->size == 0);
  // num_segments can be thought of as number of buckets (segments) in output,
  // where each segment is the reduction of all elements mapped to that
  // segment_ids. The shape of said elements is the respective
  // suffix of the data shape.
  int32_t num_segments_ = tflite::GetTensorData<int32_t>(num_segments)[0];
  const int num_segment_ids = tflite::NumElements(segment_ids);
  int max_index = -1;
  for (int i = 0; i < num_segment_ids; i++) {
    max_index = std::max(tflite::GetTensorData<int32_t>(segment_ids)[i], max_index);
  }
  // num_segments_ must be greater than max_index, otherwise elements would
  // map to non-existent output segments.
  TF_LITE_ENSURE(context, max_index < num_segments_);
  const int output_rank = data_rank - segment_ids_rank + 1;
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
  output_shape->data[0] = num_segments_;
  // output_shape[1:] should be data_shape[Rank(segment_ids):]
  for (int i = segment_ids_rank; i < data_rank; ++i) {
    output_shape->data[i - segment_ids_rank + 1] = data->dims->data[i];
  }
  // ResizeTensor takes ownership of output_shape, so no explicit free here.
  return context->ResizeTensor(context, output, output_shape);
}
119 
// Validates input/output counts and types. Resizes the output now when the
// shape-determining inputs are known; otherwise defers by marking it dynamic.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
  const TfLiteTensor* data;
  TF_LITE_ENSURE_OK(context,
                    tflite::GetInputSafe(context, node, kInputDataTensor, &data));
  const TfLiteTensor* segment_ids;
  TF_LITE_ENSURE_OK(context, tflite::GetInputSafe(context, node, kInputSegmentIdsTensor,
                                          &segment_ids));
  const TfLiteTensor* num_segments;
  TF_LITE_ENSURE_OK(
      context,
      tflite::GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    tflite::GetOutputSafe(context, node, kOutputTensor, &output));
  // This kernel supports only int32 and float32 payloads (see EvalGeneric).
  TF_LITE_ENSURE(context,
                 data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
  TF_LITE_ENSURE_EQ(context, num_segments->type, kTfLiteInt32);

  // The output shape depends on the *values* of segment_ids and num_segments,
  // not just their shapes, so both must be constant (or persistent read-only)
  // to resize here; otherwise the resize happens at eval time.
  if (tflite::IsDynamicTensor(data) || !IsConstantOrPersistentTensor(segment_ids) ||
      !IsConstantOrPersistentTensor(num_segments)) {
    tflite::SetTensorToDynamic(output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, data, segment_ids, num_segments, output);
}
148 
// Reduction functor for UNSORTED_SEGMENT_MAX: keeps the larger of two
// elements. kInitialValue is the identity for max (smallest finite T).
template <typename T>
struct SegmenMax {
  inline T operator()(const T& a, const T& b) const { return a > b ? a : b; }
  static constexpr T kInitialValue = std::numeric_limits<T>::lowest();
};
154 
// Reduction functor for UNSORTED_SEGMENT_MIN: keeps the smaller of two
// elements. kInitialValue is the identity for min (largest finite T).
template <typename T>
struct SegmenMin {
  inline T operator()(const T& a, const T& b) const { return a < b ? a : b; }
  static constexpr T kInitialValue = std::numeric_limits<T>::max();
};
160 
// Reduction functor for UNSORTED_SEGMENT_PROD: multiplies two elements.
// kInitialValue is the multiplicative identity.
template <typename T>
struct SegmenProd {
  inline T operator()(const T& a, const T& b) const {
    T product = a;
    product *= b;
    return product;
  }
  static constexpr T kInitialValue = T(1);
};
166 
// Reduction functor for UNSORTED_SEGMENT_SUM: adds two elements.
// kInitialValue is the additive identity.
template <typename T>
struct SegmenSum {
  inline T operator()(const T& a, const T& b) const {
    T total = a;
    total += b;
    return total;
  }
  static constexpr T kInitialValue = T(0);
};
172 
173 template <typename T>
EvalType(TfLiteContext * context,const tflite::RuntimeShape & input_shape,const T * input_data,const tflite::RuntimeShape & segment_ids_shape,const int32_t * segment_ids_data,const tflite::RuntimeShape & output_shape,T * output_data,SegmentType segment_type)174 TfLiteStatus EvalType(TfLiteContext* context, const tflite::RuntimeShape& input_shape,
175                       const T* input_data,
176                       const tflite::RuntimeShape& segment_ids_shape,
177                       const int32_t* segment_ids_data,
178                       const tflite::RuntimeShape& output_shape, T* output_data,
179                       SegmentType segment_type) {
180   switch (segment_type) {
181     case kSegmentProd:
182       unsorted_segment::UnsortedSegmentRef<T, SegmenProd>(
183           input_shape, input_data, segment_ids_shape, segment_ids_data,
184           output_shape, output_data);
185       break;
186     case kSegmentMax:
187       unsorted_segment::UnsortedSegmentRef<T, SegmenMax>(
188           input_shape, input_data, segment_ids_shape, segment_ids_data,
189           output_shape, output_data);
190       break;
191     case kSegmentSum:
192       unsorted_segment::UnsortedSegmentRef<T, SegmenSum>(
193           input_shape, input_data, segment_ids_shape, segment_ids_data,
194           output_shape, output_data);
195       break;
196     case kSegmentMin:
197       unsorted_segment::UnsortedSegmentRef<T, SegmenMin>(
198           input_shape, input_data, segment_ids_shape, segment_ids_data,
199           output_shape, output_data);
200       break;
201     default:
202       TF_LITE_KERNEL_LOG(context, "Not recognized segment type: %d",
203                          segment_type);
204       return kTfLiteError;
205   }
206   return kTfLiteOk;
207 }
208 
// Shared invoke path for all four unsorted segment ops: fetches the node's
// tensors, performs a deferred output resize if needed, then dispatches on
// the data tensor's dtype.
TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node,
                         SegmentType segment_type) {
  const TfLiteTensor* data;
  TF_LITE_ENSURE_OK(context,
                    tflite::GetInputSafe(context, node, kInputDataTensor, &data));
  const TfLiteTensor* segment_ids;
  TF_LITE_ENSURE_OK(context, tflite::GetInputSafe(context, node, kInputSegmentIdsTensor,
                                          &segment_ids));
  const TfLiteTensor* num_segments;
  TF_LITE_ENSURE_OK(
      context,
      tflite::GetInputSafe(context, node, kInputNumSegmentsTensor, &num_segments));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    tflite::GetOutputSafe(context, node, kOutputTensor, &output));

  if (tflite::IsDynamicTensor(output)) {
    // Output was marked dynamic in Prepare because segment_ids/num_segments
    // were not constant then; their values are available now, so resize here.
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, data, segment_ids,
                                                  num_segments, output));
  }
  // The first dimension of data and segment_ids must match so each segment id
  // addresses one leading slice of data.
  TF_LITE_ENSURE_EQ(context, tflite::GetTensorShape(data).Dims(0),
                    tflite::GetTensorShape(segment_ids).Dims(0));

// NOTE(review): EvalType's returned status is discarded here; it can only
// fail for an unrecognized segment_type, which this file never passes.
#define TF_LITE_UNSORTED_SEGMENT(dtype)                                        \
  EvalType<dtype>(context, tflite::GetTensorShape(data), tflite::GetTensorData<dtype>(data),   \
                  tflite::GetTensorShape(segment_ids),                                 \
                  tflite::GetTensorData<int32_t>(segment_ids), tflite::GetTensorShape(output), \
                  tflite::GetTensorData<dtype>(output), segment_type);
  switch (data->type) {
    case kTfLiteInt32:
      TF_LITE_UNSORTED_SEGMENT(int32_t);
      break;
    case kTfLiteFloat32:
      TF_LITE_UNSORTED_SEGMENT(float);
      break;
    default:
      TF_LITE_KERNEL_LOG(
          context, "Currently UnsortedSegment doesn't support data type: %s",
          TfLiteTypeGetName(data->type));
      return kTfLiteError;
  }
#undef TF_LITE_UNSORTED_SEGMENT
  return kTfLiteOk;
}
253 
// Invoke handler for UNSORTED_SEGMENT_PROD.
TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
  return EvalGeneric(context, node, kSegmentProd);
}
// Invoke handler for UNSORTED_SEGMENT_MAX.
TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
  return EvalGeneric(context, node, kSegmentMax);
}
// Invoke handler for UNSORTED_SEGMENT_SUM.
TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
  return EvalGeneric(context, node, kSegmentSum);
}
// Invoke handler for UNSORTED_SEGMENT_MIN.
TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) {
  return EvalGeneric(context, node, kSegmentMin);
}
266 
267 }  // namespace unsorted_segment
268 
TfLiteRegistration* Register_UNSORTED_SEGMENT_PROD() {
  // Registration fields are {init, free, prepare, invoke}; this kernel keeps
  // no per-node state, so init/free are null.
  static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
                                 unsorted_segment::EvalProd};
  return &r;
}
274 
TfLiteRegistration* Register_UNSORTED_SEGMENT_MAX() {
  // Registration fields are {init, free, prepare, invoke}; this kernel keeps
  // no per-node state, so init/free are null.
  static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
                                 unsorted_segment::EvalMax};
  return &r;
}
280 
TfLiteRegistration* Register_UNSORTED_SEGMENT_SUM() {
  // Registration fields are {init, free, prepare, invoke}; this kernel keeps
  // no per-node state, so init/free are null.
  static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
                                 unsorted_segment::EvalSum};
  return &r;
}
286 
TfLiteRegistration* Register_UNSORTED_SEGMENT_MIN() {
  // Registration fields are {init, free, prepare, invoke}; this kernel keeps
  // no per-node state, so init/free are null.
  static TfLiteRegistration r = {nullptr, nullptr, unsorted_segment::Prepare,
                                 unsorted_segment::EvalMin};
  return &r;
}
292 
293 }  // namespace custom
294 }  // namespace ops
295 }  // namespace tflite