xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/one_hot.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/tensor.h"
20 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22 
23 namespace tflite {
24 namespace ops {
25 namespace builtin {
26 namespace one_hot {
27 
// Input tensor slots of the ONE_HOT node.
constexpr int kIndicesTensor{0};
constexpr int kDepthTensor{1};
constexpr int kOnValueTensor{2};
constexpr int kOffValueTensor{3};
// The node's single output slot.
constexpr int kOutputTensor{0};
33 
34 // Convenience utility for destructuring a node into the appropriate tensors and
35 // data for the op. Note that this destructuring is quite cheap, so we can avoid
36 // allocating op-specific, persistent data on the heap.
37 struct OneHotContext {
OneHotContexttflite::ops::builtin::one_hot::OneHotContext38   OneHotContext(TfLiteContext* context, TfLiteNode* node) {
39     indices = GetInput(context, node, kIndicesTensor);
40     depth = GetInput(context, node, kDepthTensor);
41     on_value = GetInput(context, node, kOnValueTensor);
42     off_value = GetInput(context, node, kOffValueTensor);
43     output = GetOutput(context, node, kOutputTensor);
44 
45     const auto* params =
46         reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data);
47     const int indices_dims = indices->dims->size;
48     axis = (params->axis == -1) ? indices_dims : params->axis;
49     output_dims = indices_dims + 1;
50     dtype = on_value->type;
51   }
52 
53   const TfLiteTensor* indices;
54   const TfLiteTensor* depth;
55   const TfLiteTensor* on_value;
56   const TfLiteTensor* off_value;
57   TfLiteTensor* output;
58   int axis;
59   int output_dims;
60   TfLiteType dtype;
61 };
62 
63 template <typename T, typename TI>
OneHotComputeImpl(const OneHotContext & op_context)64 void OneHotComputeImpl(const OneHotContext& op_context) {
65   // prefix_dim_size == # of elements before the axis
66   // depth == # of elements per axis
67   // suffix_dim_size == # of elements after the axis
68   int prefix_dim_size = 1;
69   for (int i = 0; i < op_context.axis; ++i) {
70     prefix_dim_size *= op_context.indices->dims->data[i];
71   }
72   if (prefix_dim_size == 0) {
73     // If indices tensor is degenerate, return a degenerate tensor, just like
74     // TensorFlow does.
75     return;
76   }
77   const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
78   const int depth = *op_context.depth->data.i32;
79 
80   const T on_value = *GetTensorData<T>(op_context.on_value);
81   const T off_value = *GetTensorData<T>(op_context.off_value);
82 
83   // View the indices as a matrix of size:
84   //     prefix_dim_size x suffix_dim_size
85   // View the output as a matrix of size:
86   //     prefix_dim_size x depth x suffix_dim_size
87   // Then the output is:
88   //     output(i, j, k) == (indices(i, k) == j) ? on : off
89   T* output = GetTensorData<T>(op_context.output);
90   const TI* indices = GetTensorData<TI>(op_context.indices);
91   for (int i = 0; i < prefix_dim_size; ++i) {
92     for (int j = 0; j < depth; ++j) {
93       for (int k = 0; k < suffix_dim_size; ++k, ++output) {
94         *output = static_cast<int>(indices[i * suffix_dim_size + k]) == j
95                       ? on_value
96                       : off_value;
97       }
98     }
99   }
100 }
101 
102 template <typename T>
OneHotCompute(const OneHotContext & op_context)103 void OneHotCompute(const OneHotContext& op_context) {
104   if (op_context.indices->type == kTfLiteInt64) {
105     OneHotComputeImpl<T, int64_t>(op_context);
106   } else {
107     OneHotComputeImpl<T, int>(op_context);
108   }
109 }
110 
ResizeOutputTensor(TfLiteContext * context,const OneHotContext & op_context)111 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
112                                 const OneHotContext& op_context) {
113   TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0);
114   TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims);
115   for (int i = 0; i < op_context.output_dims; ++i) {
116     if (i < op_context.axis) {
117       output_size->data[i] = op_context.indices->dims->data[i];
118     } else if (i == op_context.axis) {
119       output_size->data[i] = *op_context.depth->data.i32;
120     } else {
121       output_size->data[i] = op_context.indices->dims->data[i - 1];
122     }
123   }
124   return context->ResizeTensor(context, op_context.output, output_size);
125 }
126 
Prepare(TfLiteContext * context,TfLiteNode * node)127 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
128   TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
129   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
130 
131   OneHotContext op_context{context, node};
132   switch (op_context.dtype) {
133     // TODO(b/111744875): Support uint8 and quantization.
134     case kTfLiteFloat32:
135     case kTfLiteInt16:
136     case kTfLiteInt32:
137     case kTfLiteInt64:
138     case kTfLiteInt8:
139     case kTfLiteUInt8:
140     case kTfLiteBool:
141       op_context.output->type = op_context.dtype;
142       break;
143     default:
144       TF_LITE_KERNEL_LOG(context, "Unknown output data type: %s",
145                          TfLiteTypeGetName(op_context.dtype));
146       return kTfLiteError;
147   }
148 
149   TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
150                               op_context.indices->type == kTfLiteInt64);
151   TF_LITE_ENSURE(context, op_context.axis >= 0 &&
152                               op_context.axis < op_context.output_dims);
153   TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
154   TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
155   TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
156   TF_LITE_ENSURE_TYPES_EQ(context, op_context.on_value->type, op_context.dtype);
157   TF_LITE_ENSURE_TYPES_EQ(context, op_context.off_value->type,
158                           op_context.dtype);
159 
160   if (!IsConstantTensor(op_context.depth)) {
161     SetTensorToDynamic(op_context.output);
162     return kTfLiteOk;
163   }
164 
165   return ResizeOutputTensor(context, op_context);
166 }
167 
Eval(TfLiteContext * context,TfLiteNode * node)168 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
169   OneHotContext op_context{context, node};
170 
171   if (IsDynamicTensor(op_context.output)) {
172     ResizeOutputTensor(context, op_context);
173   }
174 
175   switch (op_context.output->type) {
176     case kTfLiteFloat32:
177       OneHotCompute<float>(op_context);
178       break;
179     case kTfLiteInt32:
180       OneHotCompute<int>(op_context);
181       break;
182     case kTfLiteInt64:
183       OneHotCompute<int64_t>(op_context);
184       break;
185     case kTfLiteInt8:
186       OneHotCompute<int8_t>(op_context);
187       break;
188     case kTfLiteUInt8:
189       OneHotCompute<uint8_t>(op_context);
190       break;
191     case kTfLiteBool:
192       OneHotCompute<bool>(op_context);
193       break;
194     default:
195       return kTfLiteError;
196   }
197 
198   return kTfLiteOk;
199 }
200 
201 }  // namespace one_hot
202 
Register_ONE_HOT()203 TfLiteRegistration* Register_ONE_HOT() {
204   static TfLiteRegistration r = {
205       nullptr,
206       nullptr,
207       one_hot::Prepare,
208       one_hot::Eval,
209   };
210   return &r;
211 }
212 
213 }  // namespace builtin
214 }  // namespace ops
215 }  // namespace tflite
216