1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include "tensorflow/core/kernels/sparse_reorder_op.h" 19 20 #include <algorithm> 21 #include <numeric> 22 #include <unordered_map> 23 #include <utility> 24 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/register_types.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/framework/tensor_util.h" 29 #include "tensorflow/core/framework/types.h" 30 #include "tensorflow/core/lib/gtl/inlined_vector.h" 31 #include "tensorflow/core/util/sparse/sparse_tensor.h" 32 33 namespace tensorflow { 34 35 using CPUDevice = Eigen::ThreadPoolDevice; 36 37 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 38 using GPUDevice = Eigen::GpuDevice; 39 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 40 41 namespace functor { 42 43 template <typename T> 44 struct SparseReorderFunctor<CPUDevice, T> { operator ()tensorflow::functor::SparseReorderFunctor45 void operator()(OpKernelContext* context, const Tensor& input_ind, 46 const Tensor& input_val, const Tensor& input_shape_in) { 47 gtl::ArraySlice<int64_t> input_shape(input_shape_in.vec<int64_t>().data(), 48 input_shape_in.NumElements()); 49 50 gtl::InlinedVector<int64_t, 8> std_order(input_shape.size()); 51 std::iota(std_order.begin(), std_order.end(), 0); 52 53 // Check if the sparse tensor is already ordered correctly 54 sparse::SparseTensor input_sp; 55 OP_REQUIRES_OK( 56 context, sparse::SparseTensor::Create(input_ind, input_val, input_shape, 57 std_order, &input_sp)); 58 59 if (input_sp.IndicesValid().ok()) { 60 context->set_output(0, input_sp.indices()); 61 context->set_output(1, input_sp.values()); 62 } else { 63 // Deep-copy the input Tensors, then reorder in-place 64 sparse::SparseTensor reordered_sp; 65 OP_REQUIRES_OK(context, 66 sparse::SparseTensor::Create(tensor::DeepCopy(input_ind), 67 tensor::DeepCopy(input_val), 68 input_shape, &reordered_sp)); 69 reordered_sp.Reorder<T>(std_order); 70 context->set_output(0, reordered_sp.indices()); 71 context->set_output(1, reordered_sp.values()); 72 } 73 } 74 }; 75 76 } // namespace functor 77 78 template <typename Device, typename T> 79 class SparseReorderOp : public OpKernel { 80 public: SparseReorderOp(OpKernelConstruction * context)81 explicit SparseReorderOp(OpKernelConstruction* context) : OpKernel(context) {} 82 Compute(OpKernelContext * context)83 void Compute(OpKernelContext* context) override { 84 const Tensor& input_ind = context->input(0); 85 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_ind.shape()), 86 errors::InvalidArgument( 87 "Input indices should be a matrix but received shape ", 88 input_ind.shape().DebugString())); 89 90 const Tensor& input_val = context->input(1); 91 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_val.shape()), 92 errors::InvalidArgument( 93 "Input values should be a vector but received shape ", 94 input_val.shape().DebugString())); 95 96 const Tensor& input_shape_in = context->input(2); 97 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()), 98 errors::InvalidArgument( 99 "Input shape should be a vector but received shape ", 100 input_shape_in.shape().DebugString())); 101 102 functor::SparseReorderFunctor<Device, T>()(context, input_ind, input_val, 103 input_shape_in); 104 } 105 }; 106 107 #define REGISTER_KERNELS(type) \ 108 REGISTER_KERNEL_BUILDER( \ 109 Name("SparseReorder").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 110 SparseReorderOp<CPUDevice, type>) 111 112 TF_CALL_ALL_TYPES(REGISTER_KERNELS); 113 #undef REGISTER_KERNELS 114 115 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 116 117 #define REGISTER_GPU_KERNELS(type) \ 118 REGISTER_KERNEL_BUILDER( \ 119 Name("SparseReorder").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 120 SparseReorderOp<GPUDevice, type>) 121 122 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); 123 TF_CALL_INTEGRAL_TYPES(REGISTER_GPU_KERNELS); 124 REGISTER_GPU_KERNELS(bool); 125 #undef REGISTER_GPU_KERNELS 126 127 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 128 129 } // namespace tensorflow 130