1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <cstddef> 17 #include <cstdlib> 18 #include <string> 19 20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 21 #include "tensorflow/core/framework/bounds_check.h" 22 #include "tensorflow/core/framework/kernel_def_builder.h" 23 #include "tensorflow/core/framework/op.h" 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 #include "tensorflow/core/framework/tensor_types.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/kernels/string_util.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/core/stringpiece.h" 32 #include "tensorflow/core/platform/types.h" 33 #include "tensorflow/core/util/bcast.h" 34 35 namespace tensorflow { 36 37 // Position/length can be 32 or 64-bit integers 38 template <typename T> 39 class SubstrOp : public OpKernel { 40 public: SubstrOp(OpKernelConstruction * ctx)41 explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 42 string unit; 43 OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit)); 44 OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_)); 45 } 46 Compute(OpKernelContext * context)47 void Compute(OpKernelContext* context) override { 48 // Get inputs 49 const Tensor& input_tensor = context->input(0); 50 const Tensor& pos_tensor = context->input(1); 51 const Tensor& len_tensor = context->input(2); 52 const TensorShape& input_shape = input_tensor.shape(); 53 const TensorShape& pos_shape = pos_tensor.shape(); 54 const TensorShape& len_shape = len_tensor.shape(); 55 OP_REQUIRES(context, (pos_shape == len_shape), 56 errors::InvalidArgument( 57 "pos and len should have the same shape, got: ", 58 pos_shape.DebugString(), " vs. ", len_shape.DebugString())); 59 60 bool is_scalar = TensorShapeUtils::IsScalar(pos_shape); 61 62 if (is_scalar || input_shape == pos_shape) { 63 // pos/len are either scalar or match the shape of input_tensor 64 // Do not need to do broadcasting 65 66 // Reshape input 67 auto input = input_tensor.flat<tstring>(); 68 // Allocate output 69 Tensor* output_tensor = nullptr; 70 OP_REQUIRES_OK(context, 71 context->allocate_output("output", input_tensor.shape(), 72 &output_tensor)); 73 auto output = output_tensor->flat<tstring>(); 74 if (is_scalar) { 75 // Perform Op with scalar pos/len 76 const T pos = 77 tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()()); 78 const T len = 79 tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); 80 for (size_t i = 0; i < input_tensor.NumElements(); ++i) { 81 StringPiece in(input(i)); 82 T byte_pos = pos; 83 T byte_len = len; 84 switch (unit_) { 85 case CharUnit::UTF8_CHAR: 86 OP_REQUIRES( 87 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), 88 errors::InvalidArgument("pos ", pos, " out of range for ", 89 "string at index ", i)); 90 break; 91 case CharUnit::BYTE: 92 byte_pos = AdjustedPosIndex(byte_pos, in); 93 OP_REQUIRES( 94 context, FastBoundsCheck(byte_pos, in.size() + 1), 95 errors::InvalidArgument("pos ", pos, " out of range for ", 96 "string b'", in, "' at index ", i)); 97 } 98 StringPiece sub_in = in.substr(byte_pos, byte_len); 99 output(i).assign(sub_in.data(), sub_in.size()); 100 } 101 } else { 102 // Perform Op element-wise with tensor pos/len 103 auto pos_flat = pos_tensor.flat<T>(); 104 auto len_flat = len_tensor.flat<T>(); 105 for (size_t i = 0; i < input_tensor.NumElements(); ++i) { 106 StringPiece in(input(i)); 107 const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); 108 const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); 109 T byte_pos = pos; 110 T byte_len = len; 111 switch (unit_) { 112 case CharUnit::UTF8_CHAR: 113 OP_REQUIRES( 114 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), 115 errors::InvalidArgument("pos ", pos, " out of range for ", 116 "string at index ", i)); 117 break; 118 case CharUnit::BYTE: 119 byte_pos = AdjustedPosIndex(byte_pos, in); 120 OP_REQUIRES( 121 context, FastBoundsCheck(byte_pos, in.size() + 1), 122 errors::InvalidArgument("pos ", pos, " out of range for ", 123 "string b'", in, "' at index ", i)); 124 } 125 StringPiece sub_in = in.substr(byte_pos, byte_len); 126 output(i).assign(sub_in.data(), sub_in.size()); 127 } 128 } 129 } else { 130 // Perform op with broadcasting 131 // TODO: Use ternary broadcasting for once available in Eigen. Current 132 // implementation iterates through broadcasted ops element-wise; 133 // this should be parallelized. 134 135 // Create BCast helper with shape of input and pos/len 136 BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape), 137 /*fewer_dims_optimization*/ false); 138 OP_REQUIRES(context, bcast.IsValid(), 139 errors::InvalidArgument( 140 "Incompatible shapes: ", input_shape.DebugString(), 141 " vs. ", pos_shape.DebugString())); 142 TensorShape output_shape = BCast::ToShape(bcast.result_shape()); 143 int ndims = output_shape.dims(); 144 Tensor* output_tensor = nullptr; 145 OP_REQUIRES_OK(context, context->allocate_output("output", output_shape, 146 &output_tensor)); 147 switch (ndims) { 148 case 1: { 149 // Reshape tensors according to BCast results 150 auto input = input_tensor.shaped<tstring, 1>(bcast.x_reshape()); 151 auto output = output_tensor->shaped<tstring, 1>(bcast.result_shape()); 152 auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape()); 153 auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape()); 154 155 // Allocate temporary buffer for broadcasted position tensor 156 Tensor pos_buffer; 157 OP_REQUIRES_OK(context, 158 context->allocate_temp(DataTypeToEnum<T>::v(), 159 output_shape, &pos_buffer)); 160 typename TTypes<T, 1>::Tensor pos_bcast( 161 pos_buffer.shaped<T, 1>(bcast.result_shape())); 162 pos_bcast = 163 pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); 164 165 // Allocate temporary buffer for broadcasted length tensor 166 Tensor len_buffer; 167 OP_REQUIRES_OK(context, 168 context->allocate_temp(DataTypeToEnum<T>::v(), 169 output_shape, &len_buffer)); 170 typename TTypes<T, 1>::Tensor len_bcast( 171 len_buffer.shaped<T, 1>(bcast.result_shape())); 172 len_bcast = 173 len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); 174 175 // Iterate through broadcasted tensors and perform substr 176 for (int i = 0; i < output_shape.dim_size(0); ++i) { 177 StringPiece in(input(input.dimension(0) > 1 ? i : 0)); 178 const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); 179 const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); 180 T byte_pos = pos; 181 T byte_len = len; 182 switch (unit_) { 183 case CharUnit::UTF8_CHAR: 184 OP_REQUIRES( 185 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), 186 errors::InvalidArgument("pos ", pos, " out of range for ", 187 "string at index ", i)); 188 break; 189 case CharUnit::BYTE: 190 byte_pos = AdjustedPosIndex(byte_pos, in); 191 OP_REQUIRES( 192 context, FastBoundsCheck(byte_pos, in.size() + 1), 193 errors::InvalidArgument("pos ", pos, " out of range for ", 194 "string b'", in, "' at index ", i)); 195 } 196 StringPiece sub_in = in.substr(byte_pos, byte_len); 197 output(i).assign(sub_in.data(), sub_in.size()); 198 } 199 break; 200 } 201 case 2: { 202 // Reshape tensors according to BCast results 203 auto input = input_tensor.shaped<tstring, 2>(bcast.x_reshape()); 204 auto output = output_tensor->shaped<tstring, 2>(bcast.result_shape()); 205 auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape()); 206 auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape()); 207 208 // Allocate temporary buffer for broadcasted position tensor 209 Tensor pos_buffer; 210 OP_REQUIRES_OK(context, 211 context->allocate_temp(DataTypeToEnum<T>::v(), 212 output_shape, &pos_buffer)); 213 typename TTypes<T, 2>::Tensor pos_bcast( 214 pos_buffer.shaped<T, 2>(bcast.result_shape())); 215 pos_bcast = 216 pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); 217 218 // Allocate temporary buffer for broadcasted length tensor 219 Tensor len_buffer; 220 OP_REQUIRES_OK(context, 221 context->allocate_temp(DataTypeToEnum<T>::v(), 222 output_shape, &len_buffer)); 223 typename TTypes<T, 2>::Tensor len_bcast( 224 len_buffer.shaped<T, 2>(bcast.result_shape())); 225 len_bcast = 226 len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); 227 228 // Iterate through broadcasted tensors and perform substr 229 for (int i = 0; i < output_shape.dim_size(0); ++i) { 230 for (int j = 0; j < output_shape.dim_size(1); ++j) { 231 StringPiece in(input(input.dimension(0) > 1 ? i : 0, 232 input.dimension(1) > 1 ? j : 0)); 233 const T pos = 234 tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); 235 const T len = 236 tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); 237 T byte_pos = pos; 238 T byte_len = len; 239 switch (unit_) { 240 case CharUnit::UTF8_CHAR: 241 OP_REQUIRES( 242 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), 243 errors::InvalidArgument("pos ", pos, " out of range for ", 244 "string at index ", i)); 245 break; 246 case CharUnit::BYTE: 247 byte_pos = AdjustedPosIndex(byte_pos, in); 248 OP_REQUIRES( 249 context, FastBoundsCheck(byte_pos, in.size() + 1), 250 errors::InvalidArgument("pos ", pos, " out of range for ", 251 "string b'", in, "' at index (", 252 i, ", ", j, ")")); 253 } 254 StringPiece sub_in = in.substr(byte_pos, byte_len); 255 output(i, j).assign(sub_in.data(), sub_in.size()); 256 } 257 } 258 break; 259 } 260 default: { 261 context->SetStatus(errors::Unimplemented( 262 "Substr broadcast not implemented for ", ndims, " dimensions")); 263 } 264 } 265 } 266 } 267 268 private: 269 // This adjusts the requested position. Note it does not perform any bound 270 // checks. AdjustedPosIndex(const T pos_requested,const StringPiece s)271 static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) { 272 if (pos_requested < 0) { 273 return s.size() + pos_requested; 274 } 275 return pos_requested; 276 } 277 278 // Return true if successful; otherwise, return false if the `pos` argument 279 // is out of range in the string. UpdatePosAndLenForUtf8(const StringPiece in,T * pos,T * len)280 static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos, 281 T* len) { 282 if (*pos >= 0) { 283 return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len); 284 } else { 285 return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len); 286 } 287 } 288 UpdatePositivePosAndLenForUtf8(const StringPiece in,const T pos,const T len,T * char_pos,T * char_len)289 static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos, 290 const T len, T* char_pos, 291 T* char_len) { 292 *char_pos = 0; 293 // Determine byte position of the substring start. 294 if (!ForwardNUTF8CharPositions(in, pos, char_pos)) { 295 return false; 296 } 297 // Determine position of the end of the substring. 298 // The length will be capped at the end of the string, and we ignore whether 299 // the string had enough characters to handle it or not. 300 *char_len = *char_pos; 301 ForwardNUTF8CharPositions(in, len, char_len); 302 // The length in bytes is the position end of the substring less the start. 303 *char_len = *char_len - *char_pos; 304 return true; 305 } 306 307 // This function expects a negative position relative to the end of the 308 // string, but will update the character position to a positive number 309 // relative to the beginning of the string. UpdateNegativePosAndLenForUtf8(const StringPiece in,const T pos,const T len,T * char_pos,T * char_len)310 static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos, 311 const T len, T* char_pos, 312 T* char_len) { 313 // Initially treat the length as position of the end of the substring. 314 *char_len = in.size(); 315 // This is the number of character to skip from the end of the string to 316 // arrive at the position where the substring should end. 317 T utf8_chars_to_skip = -pos - len; 318 if (utf8_chars_to_skip < 0) { 319 utf8_chars_to_skip = 0; 320 } 321 // Find the byte position where the substring should end using the computed 322 // number of characters to skip. 323 if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) { 324 return false; 325 } 326 // Next, determine where the substring should begin. The number of chars to 327 // skip is the requested position minus the chars we've previously skipped. 328 *char_pos = *char_len; 329 if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) { 330 return false; 331 } 332 // The length in bytes is the position end of the substring less the start. 333 *char_len = *char_len - *char_pos; 334 return true; 335 } 336 337 CharUnit unit_ = CharUnit::BYTE; 338 }; 339 340 #define REGISTER_SUBSTR(type) \ 341 REGISTER_KERNEL_BUILDER( \ 342 Name("Substr").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 343 SubstrOp<type>); 344 REGISTER_SUBSTR(int32); 345 REGISTER_SUBSTR(int64_t); 346 } // namespace tensorflow 347