xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/substr_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstddef>
17 #include <cstdlib>
18 #include <string>
19 
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/bounds_check.h"
22 #include "tensorflow/core/framework/kernel_def_builder.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/kernels/string_util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/util/bcast.h"
34 
35 namespace tensorflow {
36 
37 // Position/length can be 32 or 64-bit integers
38 template <typename T>
39 class SubstrOp : public OpKernel {
40  public:
SubstrOp(OpKernelConstruction * ctx)41   explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
42     string unit;
43     OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
44     OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
45   }
46 
Compute(OpKernelContext * context)47   void Compute(OpKernelContext* context) override {
48     // Get inputs
49     const Tensor& input_tensor = context->input(0);
50     const Tensor& pos_tensor = context->input(1);
51     const Tensor& len_tensor = context->input(2);
52     const TensorShape& input_shape = input_tensor.shape();
53     const TensorShape& pos_shape = pos_tensor.shape();
54     const TensorShape& len_shape = len_tensor.shape();
55     OP_REQUIRES(context, (pos_shape == len_shape),
56                 errors::InvalidArgument(
57                     "pos and len should have the same shape, got: ",
58                     pos_shape.DebugString(), " vs. ", len_shape.DebugString()));
59 
60     bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);
61 
62     if (is_scalar || input_shape == pos_shape) {
63       // pos/len are either scalar or match the shape of input_tensor
64       // Do not need to do broadcasting
65 
66       // Reshape input
67       auto input = input_tensor.flat<tstring>();
68       // Allocate output
69       Tensor* output_tensor = nullptr;
70       OP_REQUIRES_OK(context,
71                      context->allocate_output("output", input_tensor.shape(),
72                                               &output_tensor));
73       auto output = output_tensor->flat<tstring>();
74       if (is_scalar) {
75         // Perform Op with scalar pos/len
76         const T pos =
77             tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()());
78         const T len =
79             tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
80         for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
81           StringPiece in(input(i));
82           T byte_pos = pos;
83           T byte_len = len;
84           switch (unit_) {
85             case CharUnit::UTF8_CHAR:
86               OP_REQUIRES(
87                   context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
88                   errors::InvalidArgument("pos ", pos, " out of range for ",
89                                           "string at index ", i));
90               break;
91             case CharUnit::BYTE:
92               byte_pos = AdjustedPosIndex(byte_pos, in);
93               OP_REQUIRES(
94                   context, FastBoundsCheck(byte_pos, in.size() + 1),
95                   errors::InvalidArgument("pos ", pos, " out of range for ",
96                                           "string b'", in, "' at index ", i));
97           }
98           StringPiece sub_in = in.substr(byte_pos, byte_len);
99           output(i).assign(sub_in.data(), sub_in.size());
100         }
101       } else {
102         // Perform Op element-wise with tensor pos/len
103         auto pos_flat = pos_tensor.flat<T>();
104         auto len_flat = len_tensor.flat<T>();
105         for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
106           StringPiece in(input(i));
107           const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
108           const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
109           T byte_pos = pos;
110           T byte_len = len;
111           switch (unit_) {
112             case CharUnit::UTF8_CHAR:
113               OP_REQUIRES(
114                   context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
115                   errors::InvalidArgument("pos ", pos, " out of range for ",
116                                           "string at index ", i));
117               break;
118             case CharUnit::BYTE:
119               byte_pos = AdjustedPosIndex(byte_pos, in);
120               OP_REQUIRES(
121                   context, FastBoundsCheck(byte_pos, in.size() + 1),
122                   errors::InvalidArgument("pos ", pos, " out of range for ",
123                                           "string b'", in, "' at index ", i));
124           }
125           StringPiece sub_in = in.substr(byte_pos, byte_len);
126           output(i).assign(sub_in.data(), sub_in.size());
127         }
128       }
129     } else {
130       // Perform op with broadcasting
131       // TODO: Use ternary broadcasting for once available in Eigen. Current
132       //       implementation iterates through broadcasted ops element-wise;
133       //       this should be parallelized.
134 
135       // Create BCast helper with shape of input and pos/len
136       BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape),
137                   /*fewer_dims_optimization*/ false);
138       OP_REQUIRES(context, bcast.IsValid(),
139                   errors::InvalidArgument(
140                       "Incompatible shapes: ", input_shape.DebugString(),
141                       " vs. ", pos_shape.DebugString()));
142       TensorShape output_shape = BCast::ToShape(bcast.result_shape());
143       int ndims = output_shape.dims();
144       Tensor* output_tensor = nullptr;
145       OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
146                                                        &output_tensor));
147       switch (ndims) {
148         case 1: {
149           // Reshape tensors according to BCast results
150           auto input = input_tensor.shaped<tstring, 1>(bcast.x_reshape());
151           auto output = output_tensor->shaped<tstring, 1>(bcast.result_shape());
152           auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape());
153           auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape());
154 
155           // Allocate temporary buffer for broadcasted position tensor
156           Tensor pos_buffer;
157           OP_REQUIRES_OK(context,
158                          context->allocate_temp(DataTypeToEnum<T>::v(),
159                                                 output_shape, &pos_buffer));
160           typename TTypes<T, 1>::Tensor pos_bcast(
161               pos_buffer.shaped<T, 1>(bcast.result_shape()));
162           pos_bcast =
163               pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
164 
165           // Allocate temporary buffer for broadcasted length tensor
166           Tensor len_buffer;
167           OP_REQUIRES_OK(context,
168                          context->allocate_temp(DataTypeToEnum<T>::v(),
169                                                 output_shape, &len_buffer));
170           typename TTypes<T, 1>::Tensor len_bcast(
171               len_buffer.shaped<T, 1>(bcast.result_shape()));
172           len_bcast =
173               len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
174 
175           // Iterate through broadcasted tensors and perform substr
176           for (int i = 0; i < output_shape.dim_size(0); ++i) {
177             StringPiece in(input(input.dimension(0) > 1 ? i : 0));
178             const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
179             const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
180             T byte_pos = pos;
181             T byte_len = len;
182             switch (unit_) {
183               case CharUnit::UTF8_CHAR:
184                 OP_REQUIRES(
185                     context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
186                     errors::InvalidArgument("pos ", pos, " out of range for ",
187                                             "string at index ", i));
188                 break;
189               case CharUnit::BYTE:
190                 byte_pos = AdjustedPosIndex(byte_pos, in);
191                 OP_REQUIRES(
192                     context, FastBoundsCheck(byte_pos, in.size() + 1),
193                     errors::InvalidArgument("pos ", pos, " out of range for ",
194                                             "string b'", in, "' at index ", i));
195             }
196             StringPiece sub_in = in.substr(byte_pos, byte_len);
197             output(i).assign(sub_in.data(), sub_in.size());
198           }
199           break;
200         }
201         case 2: {
202           // Reshape tensors according to BCast results
203           auto input = input_tensor.shaped<tstring, 2>(bcast.x_reshape());
204           auto output = output_tensor->shaped<tstring, 2>(bcast.result_shape());
205           auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape());
206           auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape());
207 
208           // Allocate temporary buffer for broadcasted position tensor
209           Tensor pos_buffer;
210           OP_REQUIRES_OK(context,
211                          context->allocate_temp(DataTypeToEnum<T>::v(),
212                                                 output_shape, &pos_buffer));
213           typename TTypes<T, 2>::Tensor pos_bcast(
214               pos_buffer.shaped<T, 2>(bcast.result_shape()));
215           pos_bcast =
216               pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
217 
218           // Allocate temporary buffer for broadcasted length tensor
219           Tensor len_buffer;
220           OP_REQUIRES_OK(context,
221                          context->allocate_temp(DataTypeToEnum<T>::v(),
222                                                 output_shape, &len_buffer));
223           typename TTypes<T, 2>::Tensor len_bcast(
224               len_buffer.shaped<T, 2>(bcast.result_shape()));
225           len_bcast =
226               len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
227 
228           // Iterate through broadcasted tensors and perform substr
229           for (int i = 0; i < output_shape.dim_size(0); ++i) {
230             for (int j = 0; j < output_shape.dim_size(1); ++j) {
231               StringPiece in(input(input.dimension(0) > 1 ? i : 0,
232                                    input.dimension(1) > 1 ? j : 0));
233               const T pos =
234                   tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
235               const T len =
236                   tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
237               T byte_pos = pos;
238               T byte_len = len;
239               switch (unit_) {
240                 case CharUnit::UTF8_CHAR:
241                   OP_REQUIRES(
242                       context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
243                       errors::InvalidArgument("pos ", pos, " out of range for ",
244                                               "string at index ", i));
245                   break;
246                 case CharUnit::BYTE:
247                   byte_pos = AdjustedPosIndex(byte_pos, in);
248                   OP_REQUIRES(
249                       context, FastBoundsCheck(byte_pos, in.size() + 1),
250                       errors::InvalidArgument("pos ", pos, " out of range for ",
251                                               "string b'", in, "' at index (",
252                                               i, ", ", j, ")"));
253               }
254               StringPiece sub_in = in.substr(byte_pos, byte_len);
255               output(i, j).assign(sub_in.data(), sub_in.size());
256             }
257           }
258           break;
259         }
260         default: {
261           context->SetStatus(errors::Unimplemented(
262               "Substr broadcast not implemented for ", ndims, " dimensions"));
263         }
264       }
265     }
266   }
267 
268  private:
269   // This adjusts the requested position. Note it does not perform any bound
270   // checks.
AdjustedPosIndex(const T pos_requested,const StringPiece s)271   static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
272     if (pos_requested < 0) {
273       return s.size() + pos_requested;
274     }
275     return pos_requested;
276   }
277 
278   // Return true if successful; otherwise, return false if the `pos` argument
279   // is out of range in the string.
UpdatePosAndLenForUtf8(const StringPiece in,T * pos,T * len)280   static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos,
281                                             T* len) {
282     if (*pos >= 0) {
283       return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len);
284     } else {
285       return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len);
286     }
287   }
288 
UpdatePositivePosAndLenForUtf8(const StringPiece in,const T pos,const T len,T * char_pos,T * char_len)289   static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos,
290                                              const T len, T* char_pos,
291                                              T* char_len) {
292     *char_pos = 0;
293     // Determine byte position of the substring start.
294     if (!ForwardNUTF8CharPositions(in, pos, char_pos)) {
295       return false;
296     }
297     // Determine position of the end of the substring.
298     // The length will be capped at the end of the string, and we ignore whether
299     // the string had enough characters to handle it or not.
300     *char_len = *char_pos;
301     ForwardNUTF8CharPositions(in, len, char_len);
302     // The length in bytes is the position end of the substring less the start.
303     *char_len = *char_len - *char_pos;
304     return true;
305   }
306 
307   // This function expects a negative position relative to the end of the
308   // string, but will update the character position to a positive number
309   // relative to the beginning of the string.
UpdateNegativePosAndLenForUtf8(const StringPiece in,const T pos,const T len,T * char_pos,T * char_len)310   static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos,
311                                              const T len, T* char_pos,
312                                              T* char_len) {
313     // Initially treat the length as position of the end of the substring.
314     *char_len = in.size();
315     // This is the number of character to skip from the end of the string to
316     // arrive at the position where the substring should end.
317     T utf8_chars_to_skip = -pos - len;
318     if (utf8_chars_to_skip < 0) {
319       utf8_chars_to_skip = 0;
320     }
321     // Find the byte position where the substring should end using the computed
322     // number of characters to skip.
323     if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) {
324       return false;
325     }
326     // Next, determine where the substring should begin. The number of chars to
327     // skip is the requested position minus the chars we've previously skipped.
328     *char_pos = *char_len;
329     if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) {
330       return false;
331     }
332     // The length in bytes is the position end of the substring less the start.
333     *char_len = *char_len - *char_pos;
334     return true;
335   }
336 
337   CharUnit unit_ = CharUnit::BYTE;
338 };
339 
340 #define REGISTER_SUBSTR(type)                                      \
341   REGISTER_KERNEL_BUILDER(                                         \
342       Name("Substr").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
343       SubstrOp<type>);
344 REGISTER_SUBSTR(int32);
345 REGISTER_SUBSTR(int64_t);
346 }  // namespace tensorflow
347