/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>
#include <stdint.h>
#include <stdlib.h>

#include <cstring>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace roll {
namespace {

// A helper function to extract int32_t or int64_t tensor data.
std::vector<int32_t> ExtractIntegerVector(const TfLiteTensor* t) {
  TFLITE_DCHECK(t->type == kTfLiteInt32 || t->type == kTfLiteInt64);
  const RuntimeShape& shape = GetTensorShape(t);
  std::vector<int32_t> result(shape.FlatSize());
  if (t->type == kTfLiteInt32) {
    memcpy(result.data(), t->data.raw_const, t->bytes);
  } else {
    const int64_t* data = GetTensorData<int64_t>(t);
    for (int i = 0; i < result.size(); ++i) {
      result[i] = static_cast<int32_t>(data[i]);
    }
  }
  return result;
}

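// Rolls the input along every axis that has a non-zero shift, working from the
// innermost dimension outwards. The first rolled axis reads directly from
// `input`; each subsequent axis first snapshots the partially rolled `output`
// into `cache` and reads from that, so the rolling is effectively done in
// place.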
template <typename T>
inline void Pool(const std::vector<int32_t>& shift_map,
                 const RuntimeShape& shape, const TfLiteTensor* input,
                 TfLiteTensor* cache, TfLiteTensor* output) {
  int stride = 1, outer_size, next_stride;
  bool in_place_rolling = false;
  for (int i = shift_map.size() - 1; i >= 0; --i, stride = next_stride) {
    next_stride = stride * shape.Dims(i);
    if (shift_map[i] == 0) continue;

    TFLITE_DCHECK_EQ(shape.FlatSize() % next_stride, 0);
    outer_size = shape.FlatSize() / next_stride;
    const TfLiteTensor* source = input;
    if (in_place_rolling) {
      // Snapshots the current output into the cache and rolls from there.
      SequentialTensorWriter<T> writer(output, cache);
      writer.WriteN(0, shape.FlatSize());
      source = cache;
    }
    SequentialTensorWriter<T> writer(source, output);
    for (int j = 0; j < outer_size; ++j) {
      // Copies the last shift_map[i] slices, which wrap around to the front.
      const int begin_1 =
          j * next_stride + (shape.Dims(i) - shift_map[i]) * stride;
      const int size_1 = shift_map[i] * stride;
      writer.WriteN(begin_1, size_1);
      // Copies the remaining leading slices.
      const int begin_2 = j * next_stride;
      const int size_2 = (shape.Dims(i) - shift_map[i]) * stride;
      writer.WriteN(begin_2, size_2);
    }
    in_place_rolling = true;
  }

  // Copies input to output directly if no rolling is needed.
  if (!in_place_rolling) {
    SequentialTensorWriter<T> writer(input, output);
    writer.WriteN(0, shape.FlatSize());
  }
}

}  // namespace

constexpr int kInputTensor = 0;
constexpr int kShiftTensor = 1;
constexpr int kAxisTensor = 2;
constexpr int kOutputTensor = 0;
constexpr int kTensorNotAllocated = -1;

struct OpData {
  // A temporary tensor to store intermediate output data when doing in-place
  // rolling.
  int cache_tensor_id = kTensorNotAllocated;
  int32_t cache_index = kTensorNotAllocated;
  bool need_cache = false;
};

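// The cache tensor is only needed when more than one axis may be rolled
// (i.e. `shift` has more than one element), because rolling the second and
// later axes has to read from the partially rolled output.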
TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
                                                TfLiteNode* node,
                                                OpData* opdata,
                                                const TfLiteTensor* input,
                                                const TfLiteTensor* shift) {
  int temporaries_count = 0;
  opdata->need_cache = (NumElements(shift) > 1);
  if (opdata->need_cache) {
    if (opdata->cache_tensor_id == kTensorNotAllocated) {
      TF_LITE_ENSURE_OK(
          context, context->AddTensors(context, 1, &opdata->cache_tensor_id));
    }
    opdata->cache_index = temporaries_count++;
  }

  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(temporaries_count);

  if (opdata->need_cache) {
    node->temporaries->data[opdata->cache_index] = opdata->cache_tensor_id;
    TfLiteTensor* cache;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, opdata->cache_index, &cache));
    cache->type = input->type;
    cache->allocation_type = kTfLiteArenaRw;
    TfLiteIntArray* cache_shape = TfLiteIntArrayCopy(input->dims);
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, cache, cache_shape));
  }
  return kTfLiteOk;
}

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* opdata = new OpData;
  return opdata;
}

void Free(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* shift;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShiftTensor, &shift));
  const TfLiteTensor* axis;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Check tensor type.
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TF_LITE_ENSURE(
      context, (shift->type == kTfLiteInt32) || (shift->type == kTfLiteInt64));
  TF_LITE_ENSURE(context,
                 (axis->type == kTfLiteInt32) || (axis->type == kTfLiteInt64));

  // Make sure shift and axis are scalars or 1-D tensors.
  TF_LITE_ENSURE(context,
                 (NumDimensions(shift) == 0) || (NumDimensions(shift) == 1));
  TF_LITE_ENSURE(context,
                 (NumDimensions(axis) == 0) || (NumDimensions(axis) == 1));
  TF_LITE_ENSURE_EQ(context, NumElements(shift), NumElements(axis));

  TF_LITE_ENSURE_OK(context, AllocateTemporaryTensorsIfRequired(
                                 context, node, opdata, input, shift));

  // The output shape always equals the input shape.
  TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
  return context->ResizeTensor(context, output, output_shape);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* shift;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShiftTensor, &shift));
  const TfLiteTensor* axis;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));

  TfLiteTensor* cache = GetTemporary(context, node, opdata->cache_index);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Extract the shift and axis information.
  std::vector<int32_t> shift_data = ExtractIntegerVector(shift);
  std::vector<int32_t> axis_data = ExtractIntegerVector(axis);

  // Maps each axis of the input to its accumulated shift value.
  const int input_rank = NumDimensions(input);
  std::vector<int32_t> shift_map(input_rank, 0);

  // Normalizes negative axes into the range [0, rank(input)).
  for (int i = 0; i < axis_data.size(); ++i) {
    int32_t axis_i = axis_data[i];
    if (axis_i < 0) axis_i += input_rank;
    shift_map[axis_i] += shift_data[i];
  }

  // Normalizes each shift into the range [0, size of the rolled dimension).
  for (int i = 0; i < input_rank; ++i) {
    const int32_t input_dims_i = SizeOfDimension(input, i);
    int32_t shift_i = shift_map[i] % input_dims_i;
    if (shift_i < 0) shift_i += input_dims_i;
    shift_map[i] = shift_i;
  }
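  // For example, an input of shape [5] with shift = [2] and axis = [0] yields
  // shift_map = {2}, and the rolled output is a right rotation by two:
  // [0, 1, 2, 3, 4] -> [3, 4, 0, 1, 2].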

#define TF_LITE_ROLL(type) \
  Pool<type>(shift_map, GetTensorShape(input), input, cache, output);

  // The type itself doesn't matter; we only care about its size.
  switch (input->type) {
    case kTfLiteFloat32:
      TF_LITE_ROLL(float);
      break;
    case kTfLiteInt32:
      TF_LITE_ROLL(int32_t);
      break;
    case kTfLiteInt64:
      TF_LITE_ROLL(int64_t);
      break;
    case kTfLiteInt8:
      TF_LITE_ROLL(int8_t);
      break;
    case kTfLiteInt16:
      TF_LITE_ROLL(int16_t);
      break;
    case kTfLiteUInt8:
      TF_LITE_ROLL(uint8_t);
      break;
    case kTfLiteBool:
      TF_LITE_ROLL(bool);
      break;
    case kTfLiteString:
      TF_LITE_ROLL(string);
      break;
    default:
      TF_LITE_KERNEL_LOG(
          context, "Type %d is currently not supported by Roll.", input->type);
      return kTfLiteError;
  }
#undef TF_LITE_ROLL
  return kTfLiteOk;
}
}  // namespace roll

TfLiteRegistration* Register_ROLL() {
  static TfLiteRegistration r = {roll::Init, roll::Free, roll::Prepare,
                                 roll::Eval};
  return &r;
}
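
// A minimal usage sketch (hypothetical client code, not part of this kernel):
// assuming the custom op is exported under the name "Roll", a client would
// register it with an op resolver before building the interpreter:
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom("Roll", tflite::ops::custom::Register_ROLL());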

}  // namespace custom
}  // namespace ops
}  // namespace tflite