/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>
#include <stdint.h>
#include <stdlib.h>

#include <cstring>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace roll {
namespace {

// Extracts the data of an int32_t or int64_t tensor into a vector of
// int32_t, narrowing int64_t values.
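// Illustrative example (assumed values): an int64_t tensor holding {2, -1}
// is returned as std::vector<int32_t>{2, -1}; values outside the int32_t
// range would be truncated by the cast below.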
std::vector<int32_t> ExtractIntegerVector(const TfLiteTensor* t) {
  TFLITE_DCHECK(t->type == kTfLiteInt32 || t->type == kTfLiteInt64);
  const RuntimeShape& shape = GetTensorShape(t);
  std::vector<int32_t> result(shape.FlatSize());
  if (t->type == kTfLiteInt32) {
    memcpy(result.data(), t->data.raw_const, t->bytes);
  } else {
    const int64_t* data = GetTensorData<int64_t>(t);
    for (int i = 0; i < result.size(); ++i) {
      result[i] = static_cast<int32_t>(data[i]);
    }
  }
  return result;
}

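// Rolls `input` into `output` along every dimension whose entry in
// `shift_map` is nonzero, iterating from the innermost to the outermost
// dimension. For each rolled dimension the data is copied as two contiguous
// strides: the trailing `shift` elements first, then the rest. Illustrative
// example (assumed values): a 1-D tensor {0, 1, 2, 3, 4} with shift_map {2}
// yields {3, 4, 0, 1, 2}. When several dimensions are rolled, the previous
// result is staged in `cache` so the next pass can read it while writing to
// `output`.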
template <typename T>
inline void Pool(const std::vector<int32_t>& shift_map,
                 const RuntimeShape& shape, const TfLiteTensor* input,
                 TfLiteTensor* cache, TfLiteTensor* output) {
  int stride = 1, outer_size, next_stride;
  bool in_place_rolling = false;
  for (int i = shift_map.size() - 1; i >= 0; --i, stride = next_stride) {
    next_stride = stride * shape.Dims(i);
    if (shift_map[i] == 0) continue;

    TFLITE_DCHECK_EQ(shape.FlatSize() % next_stride, 0);
    outer_size = shape.FlatSize() / next_stride;
    const TfLiteTensor* source = input;
    if (in_place_rolling) {
      SequentialTensorWriter<T> writer(output, cache);
      writer.WriteN(0, shape.FlatSize());
      source = cache;
    }
    SequentialTensorWriter<T> writer(source, output);
    for (int j = 0; j < outer_size; ++j) {
      // Copies the first stride.
      const int begin_1 =
          j * next_stride + (shape.Dims(i) - shift_map[i]) * stride;
      const int size_1 = shift_map[i] * stride;
      writer.WriteN(begin_1, size_1);
      // Copies the second stride.
      const int begin_2 = j * next_stride;
      const int size_2 = (shape.Dims(i) - shift_map[i]) * stride;
      writer.WriteN(begin_2, size_2);
    }
    in_place_rolling = true;
  }

  // Copies input to output if no rolling is needed.
  if (!in_place_rolling) {
    SequentialTensorWriter<T> writer(input, output);
    writer.WriteN(0, shape.FlatSize());
    return;
  }
}

}  // namespace

constexpr int kInputTensor = 0;
constexpr int kShiftTensor = 1;
constexpr int kAxisTensor = 2;
constexpr int kOutputTensor = 0;
constexpr int kTensorNotAllocated = -1;

struct OpData {
  // A temporary tensor to store intermediate output data when doing in-place
  // rolling.
  int cache_tensor_id = kTensorNotAllocated;
  int32_t cache_index = kTensorNotAllocated;
  bool need_cache = false;
};

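// The cache temporary is only allocated when more than one shift element is
// given: the first rolled dimension writes into the output, and each later
// pass must read the previous result, which Pool first copies into `cache`.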
TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
                                                TfLiteNode* node,
                                                OpData* opdata,
                                                const TfLiteTensor* input,
                                                const TfLiteTensor* shift) {
  int temporaries_count = 0;
  opdata->need_cache = (NumElements(shift) > 1);
  if (opdata->need_cache) {
    if (opdata->cache_tensor_id == kTensorNotAllocated) {
      TF_LITE_ENSURE_OK(
          context, context->AddTensors(context, 1, &opdata->cache_tensor_id));
    }
    opdata->cache_index = temporaries_count++;
  }

  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(temporaries_count);

  if (opdata->need_cache) {
    node->temporaries->data[opdata->cache_index] = opdata->cache_tensor_id;
    TfLiteTensor* cache;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, opdata->cache_index, &cache));
    cache->type = input->type;
    cache->allocation_type = kTfLiteArenaRw;
    TfLiteIntArray* cache_shape = TfLiteIntArrayCopy(input->dims);
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, cache, cache_shape));
  }
  return kTfLiteOk;
}

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* opdata = new OpData;
  return opdata;
}

void Free(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* shift;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShiftTensor, &shift));
  const TfLiteTensor* axis;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Check tensor types.
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TF_LITE_ENSURE(
      context, (shift->type == kTfLiteInt32) || (shift->type == kTfLiteInt64));
  TF_LITE_ENSURE(context,
                 (axis->type == kTfLiteInt32) || (axis->type == kTfLiteInt64));

  // Make sure shift and axis are scalars or 1-D tensors.
  TF_LITE_ENSURE(context,
                 (NumDimensions(shift) == 0) || (NumDimensions(shift) == 1));
  TF_LITE_ENSURE(context,
                 (NumDimensions(axis) == 0) || (NumDimensions(axis) == 1));
  TF_LITE_ENSURE_EQ(context, NumElements(shift), NumElements(axis));

  TF_LITE_ENSURE_OK(context, AllocateTemporaryTensorsIfRequired(
                                 context, node, opdata, input, shift));

  // The output shape always equals the input shape.
  TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
  return context->ResizeTensor(context, output, output_shape);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* shift;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShiftTensor, &shift));
  const TfLiteTensor* axis;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));

  TfLiteTensor* cache = GetTemporary(context, node, opdata->cache_index);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Extract the shift and axis information.
  std::vector<int32_t> shift_data = ExtractIntegerVector(shift);
  std::vector<int32_t> axis_data = ExtractIntegerVector(axis);

  // Maps each axis index to its accumulated shift value.
  const int input_rank = NumDimensions(input);
  std::vector<int32_t> shift_map(input_rank, 0);

  // Make sure each axis is in range [0, rank(input)).
  for (int i = 0; i < axis_data.size(); ++i) {
    int32_t axis_i = axis_data[i];
    if (axis_i < 0) axis_i += input_rank;
    shift_map[axis_i] += shift_data[i];
  }
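  // Illustrative example (assumed values): for a rank-2 input, axis {-1}
  // maps to dimension 1, so shift_data[0] accumulates into shift_map[1].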

  // Normalize each shift into the range [0, SizeOfDimension(input, i)).
  for (int i = 0; i < input_rank; ++i) {
    const int32_t input_dims_i = SizeOfDimension(input, i);
    int32_t shift_i = shift_map[i] % input_dims_i;
    if (shift_i < 0) shift_i += input_dims_i;
    shift_map[i] = shift_i;
  }
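  // Illustrative example (assumed values): with a dimension of size 5, an
  // accumulated shift of -3 normalizes to (-3 % 5) + 5 == 2.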

#define TF_LITE_ROLL(type) \
  Pool<type>(shift_map, GetTensorShape(input), input, cache, output);

  // The type itself doesn't matter; we only care about the type size.
  switch (input->type) {
    case kTfLiteFloat32:
      TF_LITE_ROLL(float);
      break;
    case kTfLiteInt32:
      TF_LITE_ROLL(int32_t);
      break;
    case kTfLiteInt64:
      TF_LITE_ROLL(int64_t);
      break;
    case kTfLiteInt8:
      TF_LITE_ROLL(int8_t);
      break;
    case kTfLiteInt16:
      TF_LITE_ROLL(int16_t);
      break;
    case kTfLiteUInt8:
      TF_LITE_ROLL(uint8_t);
      break;
    case kTfLiteBool:
      TF_LITE_ROLL(bool);
      break;
    case kTfLiteString:
      TF_LITE_ROLL(string);
      break;
    default:
      TF_LITE_KERNEL_LOG(
          context, "Type %d is currently not supported by Roll.", input->type);
      return kTfLiteError;
  }
#undef TF_LITE_ROLL
  return kTfLiteOk;
}
}  // namespace roll

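// A minimal usage sketch (the op name "Roll" and the resolver are
// assumptions; any MutableOpResolver works the same way):
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   resolver.AddCustom("Roll", tflite::ops::custom::Register_ROLL());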
TfLiteRegistration* Register_ROLL() {
  static TfLiteRegistration r = {roll::Init, roll::Free, roll::Prepare,
                                 roll::Eval};
  return &r;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite