xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/tile.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <stdint.h>

#include <algorithm>
#include <limits>
#include <tuple>
#include <utility>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
27 
28 namespace tflite {
29 namespace ops {
30 namespace builtin {
31 namespace tile {
32 
// Positions in the node's input list: the tensor to tile and the tensor of
// per-dimension repeat counts.
constexpr int kInputTensor = 0;
constexpr int kInputMultipliers = 1;

// Position in the node's output list of the tiled result.
constexpr int kOutputTensor = 0;
36 
37 namespace {
38 template <typename T>
MultiplyShapeDims(const TfLiteIntArray & shape,const TfLiteTensor * multipliers,int num_dimensions)39 TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape,
40                                   const TfLiteTensor* multipliers,
41                                   int num_dimensions) {
42   const T* multipliers_v = GetTensorData<T>(multipliers);
43 
44   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
45   for (int i = 0; i < num_dimensions; ++i) {
46     output_shape->data[i] = shape.data[i] * multipliers_v[i];
47   }
48   return output_shape;
49 }
50 
ResizeOutput(TfLiteContext * context,TfLiteNode * node)51 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
52   const TfLiteTensor* input;
53   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
54   TfLiteTensor* output;
55   TF_LITE_ENSURE_OK(context,
56                     GetOutputSafe(context, node, kOutputTensor, &output));
57   const TfLiteTensor* multipliers;
58   TF_LITE_ENSURE_OK(
59       context, GetInputSafe(context, node, kInputMultipliers, &multipliers));
60 
61   const int num_dimensions = NumDimensions(input);
62   const int num_multipliers = NumElements(multipliers);
63   TF_LITE_ENSURE_EQ(context, num_dimensions, num_multipliers);
64   switch (multipliers->type) {
65     case kTfLiteInt32:
66       return context->ResizeTensor(
67           context, output,
68           MultiplyShapeDims<int32_t>(*input->dims, multipliers,
69                                      num_dimensions));
70     case kTfLiteInt64:
71       return context->ResizeTensor(
72           context, output,
73           MultiplyShapeDims<int64_t>(*input->dims, multipliers,
74                                      num_dimensions));
75     default:
76       TF_LITE_KERNEL_LOG(context,
77                          "Multipliers of type '%s' are not supported by tile.",
78                          TfLiteTypeGetName(multipliers->type));
79       return kTfLiteError;
80   }
81 }
82 
// Writes `multiplier` back-to-back copies of the `in_size` elements starting
// at `in_data` into `out_data`. After the first copy, each subsequent copy is
// sourced from the chunk that was just written rather than from `in_data`.
// A non-positive `multiplier` writes nothing.
template <typename T, typename M>
void CopyMultipleTimes(const T* in_data, int32_t in_size, M multiplier,
                       T* out_data) {
  const T* source = in_data;
  T* destination = out_data;
  M remaining = multiplier;
  while (remaining > 0) {
    T* const chunk_end = std::copy(source, source + in_size, destination);
    source = destination;
    destination = chunk_end;
    --remaining;
  }
}
93 
94 template <typename M>
CopyStringMultipleTimes(const TfLiteTensor * in_data,int in_data_index,const int dimension_size,M multiplier,DynamicBuffer * buffer)95 void CopyStringMultipleTimes(const TfLiteTensor* in_data, int in_data_index,
96                              const int dimension_size, M multiplier,
97                              DynamicBuffer* buffer) {
98   for (M i = 0; i < multiplier; ++i) {
99     for (int j = 0; j < dimension_size; ++j) {
100       const auto string_ref = GetString(in_data, in_data_index + j);
101       buffer->AddString(string_ref.str, string_ref.len);
102     }
103   }
104 }
105 
106 template <typename T, typename M>
TileOneDimension(const TfLiteIntArray & in_dimensions,const T * in_data,const M * multipliers,T * out_data,int dimension)107 std::pair<int, int> TileOneDimension(const TfLiteIntArray& in_dimensions,
108                                      const T* in_data, const M* multipliers,
109                                      T* out_data, int dimension) {
110   if (in_dimensions.size == 0) {
111     // If input tensor is a scalar, then just copy it to output (no need to
112     // multiply).
113     *out_data = *in_data;
114     return std::make_pair(0, 0);
115   }
116 
117   const int dimension_size = in_dimensions.data[dimension];
118   if (dimension == in_dimensions.size - 1) {
119     CopyMultipleTimes(in_data, dimension_size, multipliers[dimension],
120                       out_data);
121     return std::make_pair(
122         dimension_size,
123         dimension_size * static_cast<int>(multipliers[dimension]));
124   }
125   int total_stride_size = 0, total_tiled_stride_size = 0;
126   const T* copy_from_data = in_data;
127   T* copy_to_data = out_data;
128   for (int i = 0; i < dimension_size; ++i) {
129     int stride_size = 0, tiled_stride_size = 0;
130     std::tie(stride_size, tiled_stride_size) =
131         TileOneDimension(in_dimensions, copy_from_data, multipliers,
132                          copy_to_data, dimension + 1);
133     copy_from_data += stride_size;
134     copy_to_data += tiled_stride_size;
135     total_stride_size += stride_size;
136     total_tiled_stride_size += tiled_stride_size;
137   }
138   CopyMultipleTimes(out_data, total_tiled_stride_size,
139                     multipliers[dimension] - 1,
140                     out_data + total_tiled_stride_size);
141   return std::make_pair(
142       total_stride_size,
143       static_cast<int>(total_tiled_stride_size * multipliers[dimension]));
144 }
145 
146 template <typename M>
TileStringOneDimension(const TfLiteIntArray & in_dimensions,const TfLiteTensor * in_data,int in_data_index,const M * multipliers,DynamicBuffer * buffer,int buffer_index,int dimension,TfLiteTensor * out_data)147 std::pair<int, int> TileStringOneDimension(
148     const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data,
149     int in_data_index, const M* multipliers, DynamicBuffer* buffer,
150     int buffer_index, int dimension, TfLiteTensor* out_data) {
151   const int dimension_size = in_dimensions.data[dimension];
152   if (dimension == in_dimensions.size - 1) {
153     CopyStringMultipleTimes(in_data, in_data_index, dimension_size,
154                             multipliers[dimension], buffer);
155     return {dimension_size,
156             dimension_size * static_cast<int>(multipliers[dimension])};
157   }
158 
159   int total_stride_size = 0, total_tiled_stride_size = 0;
160   for (int i = 0; i < dimension_size; ++i) {
161     int stride_size, tiled_stride_size;
162     std::tie(stride_size, tiled_stride_size) = TileStringOneDimension(
163         in_dimensions, in_data, in_data_index + total_stride_size, multipliers,
164         buffer, buffer_index + total_tiled_stride_size, dimension + 1,
165         out_data);
166     total_stride_size += stride_size;
167     total_tiled_stride_size += tiled_stride_size;
168   }
169 
170   buffer->WriteToTensor(out_data, /*new_shape=*/nullptr);
171   CopyStringMultipleTimes(out_data, buffer_index, total_tiled_stride_size,
172                           multipliers[dimension] - 1, buffer);
173 
174   return {total_stride_size,
175           total_tiled_stride_size * static_cast<int>(multipliers[dimension])};
176 }
177 
178 template <typename T>
Tile(const TfLiteIntArray & in_dimensions,const TfLiteTensor * in_data,const TfLiteTensor * multipliers,TfLiteTensor * out_data)179 void Tile(const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data,
180           const TfLiteTensor* multipliers, TfLiteTensor* out_data) {
181   // Doing recursively tiling from top to down dimension.
182   switch (multipliers->type) {
183     case kTfLiteInt32:
184       TileOneDimension(in_dimensions, GetTensorData<T>(in_data),
185                        GetTensorData<int32_t>(multipliers),
186                        GetTensorData<T>(out_data), 0);
187       break;
188     case kTfLiteInt64:
189       TileOneDimension(in_dimensions, GetTensorData<T>(in_data),
190                        GetTensorData<int64_t>(multipliers),
191                        GetTensorData<T>(out_data), 0);
192       break;
193     default:
194       break;
195   }
196 }
197 
TileString(const TfLiteIntArray & in_dimensions,const TfLiteTensor * in_data,const TfLiteTensor * multipliers,DynamicBuffer * buffer,TfLiteTensor * out_data)198 void TileString(const TfLiteIntArray& in_dimensions,
199                 const TfLiteTensor* in_data, const TfLiteTensor* multipliers,
200                 DynamicBuffer* buffer, TfLiteTensor* out_data) {
201   // Doing recursively tiling from top to down dimension.
202   switch (multipliers->type) {
203     case kTfLiteInt32:
204       TileStringOneDimension(in_dimensions, in_data, 0,
205                              GetTensorData<int32_t>(multipliers), buffer, 0, 0,
206                              out_data);
207       break;
208     case kTfLiteInt64:
209       TileStringOneDimension(in_dimensions, in_data, 0,
210                              GetTensorData<int64_t>(multipliers), buffer, 0, 0,
211                              out_data);
212       break;
213     default:
214       break;
215   }
216 }
217 }  // namespace
218 
Prepare(TfLiteContext * context,TfLiteNode * node)219 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
220   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
221   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
222 
223   const TfLiteTensor* input;
224   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
225 
226   TfLiteTensor* output;
227   TF_LITE_ENSURE_OK(context,
228                     GetOutputSafe(context, node, kOutputTensor, &output));
229   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
230 
231   const TfLiteTensor* multipliers;
232   TF_LITE_ENSURE_OK(
233       context, GetInputSafe(context, node, kInputMultipliers, &multipliers));
234   // Only int32 and int64 multipliers type is supported.
235   if (multipliers->type != kTfLiteInt32 && multipliers->type != kTfLiteInt64) {
236     TF_LITE_KERNEL_LOG(context,
237                        "Multipliers of type '%s' are not supported by tile.",
238                        TfLiteTypeGetName(multipliers->type));
239     return kTfLiteError;
240   }
241 
242   if (IsConstantTensor(multipliers)) {
243     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
244   } else {
245     SetTensorToDynamic(output);
246   }
247   return kTfLiteOk;
248 }
249 
Eval(TfLiteContext * context,TfLiteNode * node)250 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
251   const TfLiteTensor* input;
252   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
253   TfLiteTensor* output;
254   TF_LITE_ENSURE_OK(context,
255                     GetOutputSafe(context, node, kOutputTensor, &output));
256   const TfLiteTensor* multipliers;
257   TF_LITE_ENSURE_OK(
258       context, GetInputSafe(context, node, kInputMultipliers, &multipliers));
259 
260   if (IsDynamicTensor(output)) {
261     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
262   }
263   if (GetTensorShape(output).FlatSize() == 0) {
264     return kTfLiteOk;
265   }
266 
267   switch (output->type) {
268     case kTfLiteFloat32:
269       Tile<float>(*(input->dims), input, multipliers, output);
270       break;
271     case kTfLiteInt8:
272       Tile<int8_t>(*(input->dims), input, multipliers, output);
273       break;
274     case kTfLiteUInt8:
275       Tile<uint8_t>(*(input->dims), input, multipliers, output);
276       break;
277     case kTfLiteInt32:
278       Tile<int32_t>(*(input->dims), input, multipliers, output);
279       break;
280     case kTfLiteInt64:
281       Tile<int64_t>(*(input->dims), input, multipliers, output);
282       break;
283     case kTfLiteString: {
284       DynamicBuffer buffer;
285       TileString(*(input->dims), input, multipliers, &buffer, output);
286       buffer.WriteToTensor(output, /*new_shape=*/nullptr);
287       break;
288     }
289     case kTfLiteBool:
290       Tile<bool>(*(input->dims), input, multipliers, output);
291       break;
292     default:
293       TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by tile.",
294                          TfLiteTypeGetName(output->type));
295       return kTfLiteError;
296   }
297   return kTfLiteOk;
298 }
299 
300 }  // namespace tile
Register_TILE()301 TfLiteRegistration* Register_TILE() {
302   static TfLiteRegistration r = {nullptr, nullptr, tile::Prepare, tile::Eval};
303   return &r;
304 }
305 }  // namespace builtin
306 }  // namespace ops
307 }  // namespace tflite
308