/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_

#include <cstdint>

#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/platform/types.h"
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/util.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "tensorflow/core/util/gpu_solvers.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"

using stream_executor::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
#include "tensorflow/core/util/gpu_solvers.h"
using stream_executor::rocm::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace internal {
Status ValidateSegmentReduction(OpKernelContext* c, const Tensor& input,
                                const Tensor& segment_ids);
Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel,
                                        OpKernelContext* context,
                                        const Tensor& data,
                                        const Tensor& segment_ids,
                                        const Tensor& num_segments);
Status ValidateSparseSegmentReduction(OpKernelContext* context,
                                      const Tensor& input,
                                      const Tensor& indices,
                                      const Tensor& segment_ids,
                                      bool has_num_segments);
}  // namespace internal

// This operator handles reducing segments along the first dimension.
// See core/ops/math_ops.cc for more details.
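//
// Worked example (illustrative values only, not taken from the original
// documentation): for a SegmentSum kernel with
//   input       = [[1, 2], [3, 4], [5, 6]]
//   segment_ids = [0, 0, 1]
// the output is [[1 + 3, 2 + 4], [5, 6]] = [[4, 6], [5, 6]]: rows sharing a
// segment id are reduced together, and the output has
// (last segment id + 1) = 2 rows.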
template <typename Device, class T, class Index, typename Reducer,
          int default_value>
class SegmentReductionOp : public OpKernel {
 public:
  explicit SegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_OK(context, internal::ValidateSegmentReduction(context, input,
                                                               segment_ids));

    const int64_t num_indices = segment_ids.NumElements();
    auto input_flat = input.flat_outer_dims<T>();
    const int64_t num_col = input_flat.dimension(1);

    const auto segment_vec = segment_ids.vec<Index>();
    // Note that the current implementation assumes that segment_vec values
    // are sorted.
    const Index output_rows =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    OP_REQUIRES(context, input.dims() >= 1,
                errors::InvalidArgument("Shape must be at least rank 1"));

    TensorShape output_shape = input.shape();
    // Since we're changing the first dimension of the shape, we need to make
    // sure the new shape won't overflow.
    OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, output_rows));

    // Note that we do not initialize the output buffer with a default value,
    // so we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) return;
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

    Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce;
    Index start = 0, end = 1;

    Index uninitialized_index = 0;  // Index from which the output is not set.
    Index out_index = internal::SubtleMustCopy(segment_vec(start));

    // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
    // across threads.
    Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
    while (end <= num_indices) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      Index next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are
        // growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      // Process segment [start, end)
      const T* in_slice_ptr = &input_flat(start, 0);
      typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                               Eigen::Unaligned>
          OutT;

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
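      // Illustrative example (hypothetical values, not from the original
      // comments): with sorted segment_ids = [2, 2, 4], no input row maps to
      // output rows 0, 1 or 3, so those rows are filled with T(default_value)
      // by the gap-filling logic below.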
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                         Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(T(default_value));
      }

      T* out_slice_ptr = &output_flat(out_index, 0);
      OutT out_slice(out_slice_ptr, out_slice_shape);
      // We don't use out_slice.device(context->eigen_device<Device>)
      // because these pieces of work are likely to be very small and
      // the context switching overhead dwarfs any benefit we get from
      // using another thread to do this work.
      if (start == end - 1) {
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, out_slice_shape);
        out_slice = in_slice;
      } else {
        Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
                                                           num_col);
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, in_slice_shape);

        out_slice = in_slice.reduce(dims_to_reduce, Reducer());
      }
      if (end >= num_indices) break;
      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// SegmentReductionGPUOp is a segment reduction operator implemented for GPU
// only.
// TODO: This implementation of SegmentReductionGPUOp is sometimes slower than
// its unsorted counterpart (mostly when problem size is small).
// This is due to the following two main reasons and a cost-effective way
// to resolve these problems is desirable.
// 1. Sorted segment reduction requires a memory transfer from device to host
//    in order to know the size of the output dimension whereas unsorted
//    segment reduction receives the size of the output dimension as an input
//    parameter.
// 2. Sorted segment reduction is essentially a tiled version of unsorted
//    segment reduction and therefore such optimization comes at an inherent
//    cost. However such cost may not be justified when the problem size is
//    small. When to use the tiled version or the untiled version depends on
//    many factors including data alignments, ratio of calculation to memory
//    traffic and obviously, the problem sizes.
template <class T, class Index, class SegmentReductionFunctor, bool IsMean>
class SegmentReductionGPUOp : public AsyncOpKernel {
 public:
  explicit SegmentReductionGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_ASYNC(
        context, TensorShapeUtils::IsVector(segment_ids.shape()),
        errors::InvalidArgument("segment_ids should be a vector."), done);

    OP_REQUIRES_ASYNC(context, input.dims() >= 1,
                      errors::InvalidArgument("Shape must be at least rank 1"),
                      done);

    const int64_t num_indices = segment_ids.NumElements();
    OP_REQUIRES_ASYNC(
        context, num_indices == input.dim_size(0),
        errors::InvalidArgument(
            "segment_ids should be the same size as dimension 0 of"
            " input."),
        done);

    if (num_indices == 0) {
      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, 0);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);
      done();
      return;
    }

    se::DeviceMemoryBase output_rows_device(
        const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
        (num_indices - 1));
    ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
                         sizeof(Index))
            .ok(),
        errors::Internal(type_string() +
                         ": failed to copy output_rows from device"),
        done);

    SegmentReductionFunctor functor_;
    auto create_and_check_output = [context, output_rows_host, &input,
                                    &segment_ids, &functor_, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Index output_rows = *output_rows_host.data();
      output_rows++;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      // Since we're changing the first dimension of the shape, we need to
      // make sure the new shape won't overflow.
      OP_REQUIRES_OK_ASYNC(
          context, output_shape.SetDimWithStatus(0, output_rows), done);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      bool use_deterministic_kernels =
#if defined(PLATFORM_WINDOWS)
          // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows
          // CI build error.
          false;
#else
          UseDeterministicSegmentReductions() ||
          (!SegmentReductionFunctor::atomic_reduction_is_associative &&
           OpDeterminismRequired());
#endif

      // The determinism check is here, rather than inside the functor (as it
      // is for the unsorted segment reduction ops) because the done callback
      // (required for OP_REQUIRES_ASYNC) is not available inside the functor.
      bool determinism_requirement_met =
          use_deterministic_kernels ||
          SegmentReductionFunctor::atomic_reduction_is_associative ||
          !OpDeterminismRequired() ||
          DisableSegmentReductionOpDeterminismExceptions();
      OP_REQUIRES_ASYNC(
          context, determinism_requirement_met,
          errors::Unimplemented(
              "Deterministic GPU implementation of sorted segment reduction op"
              " not available."),
          done);

      auto output_flat = output->flat_outer_dims<T>();
      auto data_ptr = input.template flat<T>().data();
      auto segment_flat = segment_ids.flat<Index>();
      functor_(context, context->eigen_device<GPUDevice>(), output_rows,
               segment_ids.shape(), IsMean, segment_flat, input.NumElements(),
               data_ptr, output_flat);

      done();
    };

    context->device()
        ->tensorflow_accelerator_device_info()
        ->event_mgr->ThenExecute(stream, create_and_check_output);
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// ____________________________________________________________________________
// Unsorted segment reduction ops.

namespace functor {

// The ReductionFunctor implementation for CPU.
template <typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
  void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  typename TTypes<T, 2>::ConstTensor data,
                  typename TTypes<T, 2>::Tensor output) {
    auto cpu_device = ctx->eigen_cpu_device();
    output.device(cpu_device) = output.constant(InitialValueF()());
    if (data.size() == 0) {
      return;
    }

    // This functor reduces `N` input rows to `num_segments` output rows.
    const int64_t N = segment_ids.dimension(0);
    const int64_t num_segments = output.dimension(0);
    const int64_t inner_dim = data.dimension(1);
    ReductionF reduction;

    // `num_real_segment` counts the input rows that are actually reduced;
    // rows with a negative segment index are excluded. It is used for the
    // cost model.
    int64_t num_real_segment = N;
    // `num_reductions` counts the output rows that are actually reduced into;
    // rows that only hold InitialValueF() are excluded.
    int64_t num_reductions = 0;
    // `row_counter` records how many input rows are reduced into each output
    // row; rows that only hold InitialValueF() keep the value 0. The number
    // of non-zero entries is `num_reductions`.
    std::vector<Index> row_counter(num_segments, 0);

    for (int64_t i = 0; i < N; ++i) {
      Index j = internal::SubtleMustCopy(segment_ids(i));
      if (j < 0) {
        --num_real_segment;
        continue;
      }
      OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments),
                  errors::InvalidArgument(
                      "segment_ids", SliceDebugString(segment_ids_shape, i),
                      " = ", j, " is out of range [0, ", num_segments, ")"));
      if (row_counter[j] == 0) num_reductions++;
      row_counter[j]++;
    }

    // Nothing to reduce. All output values equal `InitialValueF()`.
    if (num_reductions == 0) return;

    // Parallelize by `num_segments`.
    // It's simple, efficient and safe (no data dependency):
    //
    //        input    segment_ids            num_segments   operation
    //      |  a0  |     |  0  |      worker 1:   |0|        f(a0, a1)
    //      |  b0  |     |  1  |      worker 2:   |1|        f(b0, b1)
    //   N  |  c0  |     |  2  |  --> worker 3:   |2|        f(c0)
    //      |  b1  |     |  1  |
    //      |  a1  |     |  0  |
    //
    // TODO(intel-tf): Balance workload in `row_counter` to make parallelism
    // more efficient.
    auto reductionWorker = [&](int64_t begin, int64_t end) -> void {
      for (int64_t i = 0; i < N; i++) {
        Index j = internal::SubtleMustCopy(segment_ids(i));
        // If `j` is in the work scope of this worker, do the reduction.
        if (j >= begin && j < end) {
          reduction(data.template chip<0>(i), output.template chip<0>(j));
        }
      }
    };

    // Reduction functors include Sum, Max, Min, etc. We simply assume each
    // operation costs 5 cycles.
    const int64_t kAverTaskSize = num_real_segment / num_segments;
    const int64_t compute_cycles = 5 * inner_dim * kAverTaskSize;
    const int64_t input_bytes = sizeof(T) * inner_dim * kAverTaskSize;
    const int64_t output_bytes = sizeof(T) * inner_dim * kAverTaskSize;
    const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
    cpu_device.parallelFor(num_segments, cost, reductionWorker);
  }
};

template <typename T>
using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>;

template <typename T>
using constMatrixChip =
    Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>;

// reduction functors
template <typename T>
struct SumOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output += data;
  }
};

template <typename T>
struct MaxOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMax(output);
  }
};

template <typename T>
struct MinOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMin(output);
  }
};

template <typename T>
struct ProdOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output *= data;
  }
};
}  // namespace functor

// The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor
// is the device specific implementation of the reduction. These device
// specific implementations are templated themselves with the corresponding
// initial value functors and reduction functors.
template <typename T, typename Index, typename DeviceReductionFunctor>
class UnsortedSegmentReductionOp : public OpKernel {
 public:
  explicit UnsortedSegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& segment_ids = context->input(1);
    const Tensor& num_segments = context->input(2);
    OP_REQUIRES_OK(context,
                   internal::ValidateUnsortedSegmentReduction(
                       this, context, data, segment_ids, num_segments));
    const auto segment_flat = segment_ids.flat<Index>();
    const int64_t output_rows = internal::SubtleMustCopy(static_cast<int64_t>(
        num_segments.dtype() == DT_INT32
            ? num_segments.scalar<int32>()()
            : num_segments.scalar<int64_t>()()));
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("Input num_segments == ", output_rows,
                                        " must not be negative."));
    TensorShape output_shape;
    output_shape.AddDim(output_rows);
    for (int i = segment_ids.dims(); i < data.dims(); i++) {
      output_shape.AddDim(data.dim_size(i));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
    auto output_flat = output->flat_outer_dims<T>();
    auto data_flat = data.flat_inner_outer_dims<T, 2>(segment_ids.dims() - 1);
    reduction_functor_(context, segment_ids.shape(), segment_flat, data_flat,
                       output_flat);
  }

 protected:
  DeviceReductionFunctor reduction_functor_;
};

// ____________________________________________________________________________
// Sparse segment reduction ops.

// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
// by two dense tensors, one containing the data, and the other containing
// indices into the data.
//
// The template parameters are:
// * Device: An Eigen device object, on which the kernel will execute.
// * T: The value type.
// * Index: The element type of the indices tensor (int32 or int64).
// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionOpBase : public OpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : OpKernel(context),
        dtidx_(DataTypeToEnum<Index>::v()),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    OP_REQUIRES_OK(
        context, internal::ValidateSparseSegmentReduction(
                     context, input, indices, segment_ids, has_num_segments_));

    Index output_rows = -1;
    if (has_num_segments_) {
      const Tensor& num_segments = context->input(3);
      // Note that there is a Tnumsegments parameter on the op, but it is not
      // plumbed through to here and so always takes its default value of
      // int32.
      output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
    }
    const int64_t num_indices = indices.NumElements();

    auto input_flat = input.flat_outer_dims<T>();
    const int64_t num_col = input_flat.dimension(1);
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();
    // Note that the current implementation assumes that segment_vec values
    // are sorted.
    const SegmentId last_segment_id =
        num_indices > 0 ? segment_vec(num_indices - 1) : 0;
    int64_t limit = dtidx_ == DataType::DT_INT32 ? kint32max : kint64max;

    OP_REQUIRES(
        context, last_segment_id < limit,
        errors::InvalidArgument("Last segment id must be < kintmax, got ",
                                last_segment_id, " limit ", limit));

    const SegmentId last_segment_id_plus_one =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;

    if (has_num_segments_) {
      OP_REQUIRES(
          context, output_rows >= last_segment_id_plus_one,
          errors::InvalidArgument("segment ids must be < num_segments"));
    } else {
      output_rows = last_segment_id_plus_one;
    }
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, output_rows));

    // Note that we do not initialize the output buffer with a default value,
    // so we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) {
      if (output_rows > 0) {
        output->flat_outer_dims<T>().setConstant(default_value_);
      }
      return;
    }
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

    Tensor temp;
    if (input.dtype() == DT_BFLOAT16 || input.dtype() == DT_HALF) {
      temp = tensorflow::Tensor(DT_FLOAT, output_shape);
    }
    auto temp_flat = temp.flat_outer_dims<float>();

    int64_t start = 0, end = 1;
    // Index from which the output is not initialized.
    SegmentId uninitialized_index = 0;
    SegmentId out_index = internal::SubtleMustCopy(segment_vec(start));

    while (true) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      SegmentId next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are
        // growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                         Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(default_value_);
      }

      auto out = output_flat.template chip<0>(out_index);
      auto temp = temp_flat.template chip<0>(out_index);
      const int bad_offset = Reduce<T, Index>(input_flat, indices_vec, start,
                                              end - start, out, temp);
      OP_REQUIRES(context, bad_offset < 0,
                  errors::InvalidArgument(
                      "Bad: indices[", start + bad_offset,
                      "] == ", indices_vec(start + bad_offset),
                      " out of range [0, ", input_flat.dimension(0), ")"));

      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
      if (end > num_indices) break;
    }

    // Fill the gap at the end with the default value.
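    // Illustrative example (hypothetical values): if num_segments = 5 was
    // supplied but the largest segment id seen above was 2, then
    // uninitialized_index == 3 here and output rows 3 and 4 are set to
    // default_value_.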
    if (uninitialized_index < output_rows) {
      Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
          output_rows - uninitialized_index, num_col);
      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
          gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
      gap_slice.setConstant(default_value_);
    }
  }

 private:
  const DataType dtidx_;
  template <typename Tin>
  using EnableIfBfloat16OrHalf =
      typename std::enable_if<std::is_same<Tin, bfloat16>::value ||
                                  std::is_same<Tin, Eigen::half>::value,
                              int>::type;
  template <typename Tin>
  using EnableIfNotBfloat16OrHalf =
      typename std::enable_if<!std::is_same<Tin, bfloat16>::value &&
                                  !std::is_same<Tin, Eigen::half>::value,
                              int>::type;

  template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0>
  EIGEN_ALWAYS_INLINE auto fetch_val(
      const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
    return input_flat.template chip<0>(index);
  }

  template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0>
  EIGEN_ALWAYS_INLINE auto fetch_val(
      const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
    return input_flat.template chip<0>(index).template cast<float>();
  }

  template <typename Tout>
  EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64_t num) {
    Tout m(1);
    if (is_mean_ && (num < 10)) {
      m = Tout(num);
    }
    if (is_sqrtn_ && (num < 10)) {
      m = Tout(sqrt(num));
    }
    return Tout(1) / m;
  }

  template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0>
  int64_t Reduce(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num,
      Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
      Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
    return ReduceImpl<Tin, Tindex, Tin>(input_flat, indices_vec, start, num,
                                        out, get_scaling_factor<Tin>(num));
  }

  template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0>
  int64_t Reduce(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num,
      Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
      Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
    int64_t res =
        ReduceImpl<Tin, Tindex, float>(input_flat, indices_vec, start, num,
                                       temp, get_scaling_factor<float>(num));
    out = temp.template cast<Tin>();
    return res;
  }

  template <typename Tin, typename Tindex, typename Tout>
  int64_t ReduceImpl(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num,
      Eigen::TensorChippingOp<0, typename TTypes<Tout>::Matrix> out,
      const Tout scaling_factor) {
#define INDEX(n, i)                               \
  const auto index##n = indices_vec(start + (i)); \
  if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);

#define L(n) fetch_val<Tin, Tindex>(input_flat, index##n)

    if (num == 1) {
      INDEX(0, 0);
      out = L(0);
    } else {
      // Handle the first (num mod 8) indices up front (treating a remainder
      // of 0 as 8 and a remainder of 1 as 9) so that `out` is always
      // initialized here and the remaining count is a multiple of 8, which
      // the loop after the switch consumes in unrolled blocks.
      int64_t r = num & 7;
      switch (r) {
        case 2: {
          INDEX(0, 0);
          INDEX(1, 1);
          out = (L(0) + L(1)) * scaling_factor;
          break;
        }
        case 3: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          out = (L(0) + L(1) + L(2)) * scaling_factor;
          break;
        }
        case 4: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor;
          break;
        }
        case 5: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor;
          break;
        }
        case 6: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor;
          break;
        }
        case 7: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          out =
              (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) * scaling_factor;
          break;
        }
        case 0: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) *
                scaling_factor;
          r = 8;
          break;
        }
        case 1: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          INDEX(8, 8);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) *
                scaling_factor;
          r = 9;
          break;
        }
      }
      for (; r < num; r += 8) {
        INDEX(0, r);
        INDEX(1, r + 1);
        INDEX(2, r + 2);
        INDEX(3, r + 3);
        INDEX(4, r + 4);
        INDEX(5, r + 5);
        INDEX(6, r + 6);
        INDEX(7, r + 7);
        out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
      }
      if (is_mean_ && num >= 10) {
        out = out / static_cast<Tout>(num);
      }
      if (is_sqrtn_ && num >= 10) {
        out = out / static_cast<Tout>(sqrt(num));
      }
    }

    return -1;
#undef L
#undef INDEX
  }

  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Specialization for GPU. Must be Async because it may need to wait for a
// device-to-host memcpy before allocating the output.
template <class T, typename Index, typename SegmentId>
class SparseSegmentReductionOpBase<GPUDevice, T, Index, SegmentId>
    : public AsyncOpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : AsyncOpKernel(context),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    OP_REQUIRES_OK_ASYNC(
        context,
        internal::ValidateSparseSegmentReduction(
            context, input, indices, segment_ids, has_num_segments_),
        done);

    ScratchSpace<SegmentId> last_segment_id_host(context, 1, /*on_host=*/true);

    auto create_and_check_output = [this, context, input, indices, segment_ids,
                                    last_segment_id_host, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      SegmentId last_segment_id = *last_segment_id_host.data();
      SegmentId output_rows = last_segment_id + 1;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, output_rows);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      auto input_flat = input.flat_outer_dims<T>();
      const auto indices_vec = indices.vec<Index>();
      const auto segment_ids_vec = segment_ids.vec<SegmentId>();
      auto output_flat = output->flat_outer_dims<T>();

      functor::SparseSegmentReductionFunctor<T, Index, SegmentId> functor;
      OP_REQUIRES_OK_ASYNC(
          context,
          functor(context, is_mean_, is_sqrtn_, default_value_, input_flat,
                  indices_vec, segment_ids_vec, output_flat),
          done);
      done();
    };

    if (has_num_segments_) {
      // No need to do any device to host memcpy, just compute synchronously.
      const Tensor& num_segments_t = context->input(3);
      SegmentId num_segments =
          internal::SubtleMustCopy(num_segments_t.dtype() == DT_INT32
                                       ? num_segments_t.scalar<int32>()()
                                       : num_segments_t.scalar<int64_t>()());
      *last_segment_id_host.mutable_data() = num_segments - 1;
      create_and_check_output();
    } else {
      const int64_t num_indices = indices.NumElements();
      // Need to copy last element of segment_ids from device to host, and
      // then asynchronously allocate the output and finish the computation.
      se::DeviceMemoryBase last_segment_id_device(
          const_cast<Tensor&>(segment_ids).template flat<SegmentId>().data() +
          (num_indices - 1));
      auto stream = context->op_device_context()->stream();
      OP_REQUIRES_ASYNC(
          context,
          stream
              ->ThenMemcpy(last_segment_id_host.mutable_data(),
                           last_segment_id_device, sizeof(SegmentId))
              .ok(),
          errors::Internal(type_string() +
                           ": failed to copy last_segment_id from device"),
          done);
      context->device()
          ->tensorflow_accelerator_device_info()
          ->event_mgr->ThenExecute(stream, create_and_check_output);
    }
  }

 private:
  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionMeanOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionMeanWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionMeanWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSqrtNOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSqrtNWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSumOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSumWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSumWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

namespace functor {

template <typename T, typename Index, typename SegmentId>
struct SparseSegmentGradFunctor<CPUDevice, T, Index, SegmentId> {
  void operator()(OpKernelContext* context,
                  SparseSegmentReductionOperation operation,
                  typename TTypes<T>::ConstMatrix input_flat,
                  typename TTypes<Index>::ConstVec indices_vec,
                  typename TTypes<SegmentId>::ConstVec segment_vec,
                  typename TTypes<T>::Matrix output_flat) {
    const int64_t N = indices_vec.size();
    const SegmentId M = output_flat.dimension(0);

    // Note that similar to SparseSegmentMean, we assume that segment_vec is
    // already sorted and has non-negative values.
    const SegmentId num_segments = input_flat.dimension(0);
    const SegmentId last_segment_id_plus_one =
        internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
    OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
                errors::InvalidArgument("Invalid number of segments"));

    // Compute scaling factors for input.
    std::vector<double> scaling(
        (operation == SparseSegmentReductionOperation::kSum
             ? 0
             : num_segments),
        0.0);
    if (operation != SparseSegmentReductionOperation::kSum) {
      for (int64_t i = 0; i < N; ++i) {
        const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
        OP_REQUIRES(
            context, FastBoundsCheck(idx, num_segments),
            errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                    num_segments, ")."));
        scaling[idx] += 1;
      }
      for (size_t i = 0; i < scaling.size(); ++i) {
        switch (operation) {
          case SparseSegmentReductionOperation::kSum: {
            OP_REQUIRES(
                context, false,
                errors::Internal(
                    "Should not happen: sum inside SparseSegmentReductionOp "
                    "scaling generation."));
          }
          case SparseSegmentReductionOperation::kMean: {
            scaling[i] = 1.0 / std::max(scaling[i], 1.0);
            break;
          }
          case SparseSegmentReductionOperation::kSqrtN: {
            scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
            break;
          }
            // No default to get compiler warnings for missing cases.
        }
      }
    }

    output_flat.setZero();
    std::vector<bool> is_modified(M, false);

    for (int64_t i = 0; i < N; ++i) {
      const Index output_idx = internal::SubtleMustCopy(indices_vec(i));
      OP_REQUIRES(context, FastBoundsCheck(output_idx, M),
                  errors::InvalidArgument("Index ", output_idx,
                                          " out of range [0, ", M, ")."));

      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));

      const T scale = (operation == SparseSegmentReductionOperation::kSum
                           ? static_cast<T>(1)
                           : static_cast<T>(scaling[idx]));
      if (is_modified[output_idx]) {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx) * scale;
        }
      } else {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx) * scale;
        }
      }
      is_modified[output_idx] = true;
    }
  }
};

}  // namespace functor

// Implements the common logic for the gradients of SparseSegmentReduction
// kernels.
//
// The template parameters are:
// * Device: An Eigen device object, on which the kernel will execute.
// * T: The value type.
// * Index: The element type of the indices tensor (int32 or int64).
// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
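//
// Worked example (hypothetical values, for illustration only): for the sum
// gradient with
//   input (incoming gradient) = [[g0], [g1]], indices = [3, 1],
//   segment_ids = [0, 1], output_dim0 = 4
// the result is [[0], [g1], [0], [g0]]: gradient row segment_ids[i] is
// scattered to output row indices[i]. For the Mean/SqrtN variants the row is
// additionally scaled by 1/count or 1/sqrt(count) of its segment, as computed
// in the CPU functor above.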
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentGradOpBase : public OpKernel {
 public:
  explicit SparseSegmentGradOpBase(OpKernelConstruction* context,
                                   SparseSegmentReductionOperation operation)
      : OpKernel(context), operation_(operation) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);
    const Tensor& output_dim0 = context->input(3);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_dim0.shape()),
                errors::InvalidArgument("output_dim0 should be a scalar."));

    const int64_t N = indices.NumElements();
    OP_REQUIRES(context, N == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));
    const SegmentId M = internal::SubtleMustCopy(output_dim0.scalar<int32>()());

    auto input_flat = input.flat_outer_dims<T>();
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();

    TensorShape output_shape = input.shape();
    OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, M));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
    if (M == 0 || N == 0) return;

    auto output_flat = output->flat_outer_dims<T>();
    functor::SparseSegmentGradFunctor<Device, T, Index, SegmentId>()(
        context, operation_, input_flat, indices_vec, segment_vec,
        output_flat);
  }

 private:
  const SparseSegmentReductionOperation operation_;
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentSumGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentSumGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kSum) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentMeanGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kMean) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentSqrtNGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kSqrtN) {}
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_