1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <algorithm> 17 #include <cstdint> 18 #include <string> 19 20 #include "tensorflow/core/framework/op_kernel.h" 21 #include "tensorflow/core/framework/register_types.h" 22 #include "tensorflow/core/platform/errors.h" 23 24 // Please use the appropriate namespace for your project 25 namespace tensorflow { 26 namespace custom_op_examples { 27 28 using CPUDevice = Eigen::ThreadPoolDevice; 29 using ::tensorflow::errors::InvalidArgument; 30 31 template <typename T> 32 class MultiplexSparseOp : public OpKernel { 33 public: MultiplexSparseOp(OpKernelConstruction * ctx)34 explicit MultiplexSparseOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 35 MultiplexSparseOp(const MultiplexSparseOp& other) = delete; 36 MultiplexSparseOp& operator=(const MultiplexSparseOp& other) = delete; 37 ~MultiplexSparseOp() override = default; 38 Compute(OpKernelContext * ctx)39 void Compute(OpKernelContext* ctx) override { 40 const auto& cond_indices_tensor = ctx->input(0); 41 const auto& cond_values_tensor = ctx->input(1); 42 const auto& cond_shape_tensor = ctx->input(2); 43 const auto& a_indices_tensor = ctx->input(3); 44 const auto& a_values_tensor = ctx->input(4); 45 const auto& a_shape_tensor = ctx->input(5); 46 const auto& b_indices_tensor = ctx->input(6); 47 const auto& b_values_tensor = ctx->input(7); 48 const auto& b_shape_tensor = ctx->input(8); 49 OP_REQUIRES_OK(ctx, 50 ValidateSparseTensor(cond_indices_tensor, cond_values_tensor, 51 cond_shape_tensor, "cond")); 52 OP_REQUIRES_OK(ctx, ValidateSparseTensor(a_indices_tensor, a_values_tensor, 53 a_shape_tensor, "a")); 54 OP_REQUIRES_OK(ctx, ValidateSparseTensor(b_indices_tensor, b_values_tensor, 55 b_shape_tensor, "b")); 56 OP_REQUIRES( 57 ctx, cond_shape_tensor.shape() == a_shape_tensor.shape(), 58 InvalidArgument("Sparse tensors must be the same shape. cond_shape: ", 59 cond_shape_tensor.shape().DebugString(), 60 " vs a_shape: ", a_shape_tensor.shape().DebugString())); 61 OP_REQUIRES( 62 ctx, a_shape_tensor.shape() == b_shape_tensor.shape(), 63 InvalidArgument("Sparse tensors must be the same shape. a_shape: ", 64 a_shape_tensor.shape().DebugString(), 65 " vs b_shape: ", b_shape_tensor.shape().DebugString())); 66 const int rank = a_shape_tensor.dim_size(0); 67 OP_REQUIRES( 68 ctx, rank == 1, 69 InvalidArgument("Sorry, multiplex for sparse tensors only " 70 "supports rank 1 tensors to simplify this example.")); 71 const int cond_elements = cond_indices_tensor.dim_size(0); 72 const int a_elements = a_indices_tensor.dim_size(0); 73 const int b_elements = b_indices_tensor.dim_size(0); 74 const auto cond_indices = cond_indices_tensor.matrix<int64_t>(); 75 const auto cond_values = cond_values_tensor.flat<bool>(); 76 const auto cond_shape = cond_shape_tensor.flat<int64_t>(); 77 const auto a_indices = a_indices_tensor.matrix<int64_t>(); 78 const auto a_values = a_values_tensor.flat<T>(); 79 const auto a_shape = a_shape_tensor.flat<int64_t>(); 80 const auto b_indices = b_indices_tensor.matrix<int64_t>(); 81 const auto b_values = b_values_tensor.flat<T>(); 82 const auto b_shape = b_shape_tensor.flat<int64_t>(); 83 int cond_index = 0; 84 int a_index = 0; 85 int b_index = 0; 86 // This vector is a list of source tensors (a = true, b = false) and source 87 // indices. 88 std::vector<std::pair<bool, int>> merged_output; 89 merged_output.reserve(std::min(cond_elements, a_elements) + b_elements); 90 while (a_index < a_elements || b_index < b_elements) { 91 // Determine the whether the current location with values has a value 92 // for `a`, for `b` or for both `a` and `b`. 93 int64_t cur_row; 94 bool is_a_at_cur = false; 95 bool is_b_at_cur = false; 96 if (a_index < a_elements && b_index < b_elements) { 97 const int64_t a_row = a_indices(a_index, 0); 98 const int64_t b_row = b_indices(b_index, 0); 99 cur_row = std::min(a_row, b_row); 100 if (a_row == cur_row) { 101 is_a_at_cur = true; 102 } 103 if (b_row == cur_row) { 104 is_b_at_cur = true; 105 } 106 } else if (a_index < a_elements) { 107 cur_row = a_indices(a_index, 0); 108 is_a_at_cur = true; 109 } else { // b_index < b_elements 110 cur_row = b_indices(b_index, 0); 111 is_b_at_cur = true; 112 } 113 // Deterimine if `cond` has a value at the current location 114 bool cond_flag = false; 115 while (cond_index < cond_elements) { 116 const int64_t cond_row = cond_indices(cond_index, 0); 117 if (cond_row > cur_row) { 118 break; 119 } 120 if (cond_row == cur_row) { 121 cond_flag = cond_values(cond_index); 122 break; 123 } 124 ++cond_index; 125 } 126 // Add `a` or `b` to the merged output based on the condition 127 if (is_a_at_cur) { 128 if (cond_flag) { 129 merged_output.emplace_back(true, a_index); 130 } 131 ++a_index; 132 } 133 if (is_b_at_cur) { 134 if (!cond_flag) { 135 merged_output.emplace_back(false, b_index); 136 } 137 ++b_index; 138 } 139 } 140 141 // Allocate output tensors. 142 Tensor* output_indices_tensor; 143 Tensor* output_values_tensor; 144 Tensor* output_dense_shape_tensor; 145 const int num_values = merged_output.size(); 146 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({num_values, rank}), 147 &output_indices_tensor)); 148 OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({num_values}), 149 &output_values_tensor)); 150 OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({rank}), 151 &output_dense_shape_tensor)); 152 auto output_indices = output_indices_tensor->matrix<int64_t>(); 153 auto output_values = output_values_tensor->flat<T>(); 154 auto output_shape = output_dense_shape_tensor->flat<int64_t>(); 155 for (int row = 0; row < num_values; ++row) { 156 const auto& source_flag = merged_output[row].first; 157 const auto& source_row = merged_output[row].second; 158 const auto& indices = source_flag ? a_indices : b_indices; 159 const auto& values = source_flag ? a_values : b_values; 160 for (int column = 0; column < rank; ++column) { 161 output_indices(row, column) = indices(source_row, column); 162 } 163 output_values(row) = values(source_row); 164 } 165 // Expand the shape of the output sparse tensor so that it is as large 166 // as the shape of the largest input in each dimension. 167 // An alternative behavoir would be to require that the shapes be the 168 // same and implement error checking that all the corresponding values 169 // in the shape tensors are the same (e.g. 170 // `cond_shape(i) == a_shape(i)` and `a_shape(i) == b_shape(i)` in 171 // OP_REQUIRES above and `output_shape(i) = a_shape(i)` here). 172 for (int i = 0; i < rank; ++i) { 173 output_shape(i) = 174 std::max(cond_shape(i), std::max(a_shape(i), b_shape(i))); 175 } 176 } 177 178 private: ValidateSparseTensor(const::tensorflow::Tensor & indices_tensor,const::tensorflow::Tensor & values_tensor,const::tensorflow::Tensor & shape_tensor,const string label)179 Status ValidateSparseTensor(const ::tensorflow::Tensor& indices_tensor, 180 const ::tensorflow::Tensor& values_tensor, 181 const ::tensorflow::Tensor& shape_tensor, 182 const string label) { 183 if (!TensorShapeUtils::IsMatrix(indices_tensor.shape())) { 184 return InvalidArgument( 185 "Sparse indices for ", label, 186 " must be rank 2, not shape: ", indices_tensor.shape().DebugString()); 187 } 188 if (!TensorShapeUtils::IsVector(values_tensor.shape())) { 189 return InvalidArgument("Sparse values for ", label, 190 " must be a vector, not shape: ", 191 values_tensor.shape().DebugString()); 192 } 193 if (!TensorShapeUtils::IsVector(shape_tensor.shape())) { 194 return InvalidArgument( 195 "Sparse shape for ", label, 196 " must be a vector, not shape: ", shape_tensor.shape().DebugString()); 197 } 198 if (indices_tensor.dim_size(0) != values_tensor.dim_size(0)) { 199 return InvalidArgument("Sparse indices and values for " + label + 200 " must have the same " 201 "number of rows. indices: ", 202 indices_tensor.shape().DebugString(), 203 " values: ", values_tensor.shape().DebugString()); 204 } 205 return OkStatus(); 206 } 207 }; 208 209 // To support tensors containing different types (e.g. int32, float), one 210 // kernel per type is registered and is templatized by the "T" attr value. 211 // See go/tf-custom-ops-guide 212 #define REGISTER_KERNELS_CPU(type) \ 213 REGISTER_KERNEL_BUILDER(Name("Examples>MultiplexSparse") \ 214 .Device(::tensorflow::DEVICE_CPU) \ 215 .TypeConstraint<type>("T"), \ 216 MultiplexSparseOp<type>) 217 TF_CALL_ALL_TYPES(REGISTER_KERNELS_CPU); 218 219 #undef REGISTER_KERNELS_CPU 220 221 } // namespace custom_op_examples 222 } // namespace tensorflow 223