xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/bucketize.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdint.h>
17 
18 #include <algorithm>
19 
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/tensor.h"
23 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 
26 namespace tflite {
27 namespace ops {
28 namespace builtin {
29 namespace bucketize {
30 namespace {
31 
// Index of the single input tensor and the single output tensor of the node.
constexpr int kInputTensor{0};
constexpr int kOutputTensor{0};
34 
// Per-node state created by Init() and released by Free().
struct OpData {
  // Sorted (ascending) bucket boundaries. Not owned: the array is owned by the
  // buffer housing TfLiteBucketizeParams.
  // Default member initializers keep a default-initialized OpData from holding
  // indeterminate values; aggregate initialization still works (C++14+).
  const float* boundaries = nullptr;
  // Number of entries in `boundaries`.
  int num_boundaries = 0;
};
40 
Init(TfLiteContext * context,const char * buffer,size_t length)41 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
42   auto* op_data = new OpData();
43   const auto* params = reinterpret_cast<const TfLiteBucketizeParams*>(buffer);
44 
45   op_data->boundaries = params->boundaries;
46   op_data->num_boundaries = params->num_boundaries;
47   return op_data;
48 }
49 
Free(TfLiteContext * context,void * buffer)50 void Free(TfLiteContext* context, void* buffer) {
51   delete reinterpret_cast<OpData*>(buffer);
52 }
53 
Prepare(TfLiteContext * context,TfLiteNode * node)54 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
55   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
56   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
57   OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
58   if (!std::is_sorted(opdata->boundaries,
59                       opdata->boundaries + opdata->num_boundaries)) {
60     TF_LITE_KERNEL_LOG(context, "Expected sorted boundaries");
61     return kTfLiteError;
62   }
63 
64   const TfLiteTensor* input;
65   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
66 
67   if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
68       input->type != kTfLiteInt64 && input->type != kTfLiteFloat64) {
69     TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by bucketize.",
70                        TfLiteTypeGetName(input->type));
71     return kTfLiteError;
72   }
73 
74   TfLiteTensor* output;
75   TF_LITE_ENSURE_OK(context,
76                     GetOutputSafe(context, node, kOutputTensor, &output));
77   output->type = kTfLiteInt32;
78 
79   TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
80   return context->ResizeTensor(context, output, output_shape);
81 }
82 
83 template <typename T>
Bucketize(const RuntimeShape & input_shape,const T * input_data,const float * boundaries,int num_boundaries,const RuntimeShape & output_shape,int32_t * output_data)84 inline void Bucketize(const RuntimeShape& input_shape, const T* input_data,
85                       const float* boundaries, int num_boundaries,
86                       const RuntimeShape& output_shape, int32_t* output_data) {
87   const int flat_size = MatchingFlatSize(input_shape, output_shape);
88 
89   for (int i = 0; i < flat_size; i++) {
90     auto first_bigger_it = std::upper_bound(
91         boundaries, boundaries + num_boundaries, input_data[i]);
92     output_data[i] = first_bigger_it - boundaries;
93   }
94 }
95 
96 template <typename T>
BucketizeImpl(TfLiteContext * context,TfLiteNode * node)97 TfLiteStatus BucketizeImpl(TfLiteContext* context, TfLiteNode* node) {
98   const TfLiteTensor* input;
99   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
100   OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
101   TfLiteTensor* output;
102   TF_LITE_ENSURE_OK(context,
103                     GetOutputSafe(context, node, kOutputTensor, &output));
104   TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt32);
105 
106   Bucketize<T>(GetTensorShape(input), GetTensorData<T>(input),
107                opdata->boundaries, opdata->num_boundaries,
108                GetTensorShape(output), GetTensorData<int32_t>(output));
109 
110   return kTfLiteOk;
111 }
112 
Eval(TfLiteContext * context,TfLiteNode * node)113 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
114   const TfLiteTensor* input;
115   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
116 
117   switch (input->type) {
118     case kTfLiteFloat32: {
119       return BucketizeImpl<float>(context, node);
120     }
121     case kTfLiteFloat64: {
122       return BucketizeImpl<double>(context, node);
123     }
124     case kTfLiteInt32: {
125       return BucketizeImpl<int32_t>(context, node);
126     }
127     case kTfLiteInt64: {
128       return BucketizeImpl<int64_t>(context, node);
129     }
130     default: {
131       TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by bucketize.",
132                          TfLiteTypeGetName(input->type));
133       return kTfLiteError;
134     }
135   }
136 }
137 
138 }  // namespace
139 }  // namespace bucketize
140 
Register_BUCKETIZE()141 TfLiteRegistration* Register_BUCKETIZE() {
142   static TfLiteRegistration r = {bucketize::Init, bucketize::Free,
143                                  bucketize::Prepare, bucketize::Eval};
144   return &r;
145 }
146 
147 }  // namespace builtin
148 }  // namespace ops
149 }  // namespace tflite
150