/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/platform/errors.h"
23 
24 // Please use the appropriate namespace for your project
25 namespace tensorflow {
26 namespace custom_op_examples {
27 
28 using CPUDevice = Eigen::ThreadPoolDevice;
29 using ::tensorflow::errors::InvalidArgument;
30 
31 template <typename T>
32 class MultiplexSparseOp : public OpKernel {
33  public:
MultiplexSparseOp(OpKernelConstruction * ctx)34   explicit MultiplexSparseOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
35   MultiplexSparseOp(const MultiplexSparseOp& other) = delete;
36   MultiplexSparseOp& operator=(const MultiplexSparseOp& other) = delete;
37   ~MultiplexSparseOp() override = default;
38 
Compute(OpKernelContext * ctx)39   void Compute(OpKernelContext* ctx) override {
40     const auto& cond_indices_tensor = ctx->input(0);
41     const auto& cond_values_tensor = ctx->input(1);
42     const auto& cond_shape_tensor = ctx->input(2);
43     const auto& a_indices_tensor = ctx->input(3);
44     const auto& a_values_tensor = ctx->input(4);
45     const auto& a_shape_tensor = ctx->input(5);
46     const auto& b_indices_tensor = ctx->input(6);
47     const auto& b_values_tensor = ctx->input(7);
48     const auto& b_shape_tensor = ctx->input(8);
49     OP_REQUIRES_OK(ctx,
50                    ValidateSparseTensor(cond_indices_tensor, cond_values_tensor,
51                                         cond_shape_tensor, "cond"));
52     OP_REQUIRES_OK(ctx, ValidateSparseTensor(a_indices_tensor, a_values_tensor,
53                                              a_shape_tensor, "a"));
54     OP_REQUIRES_OK(ctx, ValidateSparseTensor(b_indices_tensor, b_values_tensor,
55                                              b_shape_tensor, "b"));
56     OP_REQUIRES(
57         ctx, cond_shape_tensor.shape() == a_shape_tensor.shape(),
58         InvalidArgument("Sparse tensors must be the same shape. cond_shape: ",
59                         cond_shape_tensor.shape().DebugString(),
60                         " vs a_shape: ", a_shape_tensor.shape().DebugString()));
61     OP_REQUIRES(
62         ctx, a_shape_tensor.shape() == b_shape_tensor.shape(),
63         InvalidArgument("Sparse tensors must be the same shape. a_shape: ",
64                         a_shape_tensor.shape().DebugString(),
65                         " vs b_shape: ", b_shape_tensor.shape().DebugString()));
66     const int rank = a_shape_tensor.dim_size(0);
67     OP_REQUIRES(
68         ctx, rank == 1,
69         InvalidArgument("Sorry, multiplex for sparse tensors only "
70                         "supports rank 1 tensors to simplify this example."));
71     const int cond_elements = cond_indices_tensor.dim_size(0);
72     const int a_elements = a_indices_tensor.dim_size(0);
73     const int b_elements = b_indices_tensor.dim_size(0);
74     const auto cond_indices = cond_indices_tensor.matrix<int64_t>();
75     const auto cond_values = cond_values_tensor.flat<bool>();
76     const auto cond_shape = cond_shape_tensor.flat<int64_t>();
77     const auto a_indices = a_indices_tensor.matrix<int64_t>();
78     const auto a_values = a_values_tensor.flat<T>();
79     const auto a_shape = a_shape_tensor.flat<int64_t>();
80     const auto b_indices = b_indices_tensor.matrix<int64_t>();
81     const auto b_values = b_values_tensor.flat<T>();
82     const auto b_shape = b_shape_tensor.flat<int64_t>();
83     int cond_index = 0;
84     int a_index = 0;
85     int b_index = 0;
86     // This vector is a list of source tensors (a = true, b = false) and source
87     // indices.
88     std::vector<std::pair<bool, int>> merged_output;
89     merged_output.reserve(std::min(cond_elements, a_elements) + b_elements);
90     while (a_index < a_elements || b_index < b_elements) {
91       // Determine the whether the current location with values has a value
92       // for `a`, for `b` or for both `a` and `b`.
93       int64_t cur_row;
94       bool is_a_at_cur = false;
95       bool is_b_at_cur = false;
96       if (a_index < a_elements && b_index < b_elements) {
97         const int64_t a_row = a_indices(a_index, 0);
98         const int64_t b_row = b_indices(b_index, 0);
99         cur_row = std::min(a_row, b_row);
100         if (a_row == cur_row) {
101           is_a_at_cur = true;
102         }
103         if (b_row == cur_row) {
104           is_b_at_cur = true;
105         }
106       } else if (a_index < a_elements) {
107         cur_row = a_indices(a_index, 0);
108         is_a_at_cur = true;
109       } else {  // b_index < b_elements
110         cur_row = b_indices(b_index, 0);
111         is_b_at_cur = true;
112       }
113       // Deterimine if `cond` has a value at the current location
114       bool cond_flag = false;
115       while (cond_index < cond_elements) {
116         const int64_t cond_row = cond_indices(cond_index, 0);
117         if (cond_row > cur_row) {
118           break;
119         }
120         if (cond_row == cur_row) {
121           cond_flag = cond_values(cond_index);
122           break;
123         }
124         ++cond_index;
125       }
126       // Add `a` or `b` to the merged output based on the condition
127       if (is_a_at_cur) {
128         if (cond_flag) {
129           merged_output.emplace_back(true, a_index);
130         }
131         ++a_index;
132       }
133       if (is_b_at_cur) {
134         if (!cond_flag) {
135           merged_output.emplace_back(false, b_index);
136         }
137         ++b_index;
138       }
139     }
140 
141     // Allocate output tensors.
142     Tensor* output_indices_tensor;
143     Tensor* output_values_tensor;
144     Tensor* output_dense_shape_tensor;
145     const int num_values = merged_output.size();
146     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({num_values, rank}),
147                                              &output_indices_tensor));
148     OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({num_values}),
149                                              &output_values_tensor));
150     OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({rank}),
151                                              &output_dense_shape_tensor));
152     auto output_indices = output_indices_tensor->matrix<int64_t>();
153     auto output_values = output_values_tensor->flat<T>();
154     auto output_shape = output_dense_shape_tensor->flat<int64_t>();
155     for (int row = 0; row < num_values; ++row) {
156       const auto& source_flag = merged_output[row].first;
157       const auto& source_row = merged_output[row].second;
158       const auto& indices = source_flag ? a_indices : b_indices;
159       const auto& values = source_flag ? a_values : b_values;
160       for (int column = 0; column < rank; ++column) {
161         output_indices(row, column) = indices(source_row, column);
162       }
163       output_values(row) = values(source_row);
164     }
165     // Expand the shape of the output sparse tensor so that it is as large
166     // as the shape of the largest input in each dimension.
167     // An alternative behavoir would be to require that the shapes be the
168     // same and implement error checking that all the corresponding values
169     // in the shape tensors are the same (e.g.
170     // `cond_shape(i) == a_shape(i)` and `a_shape(i) == b_shape(i)` in
171     // OP_REQUIRES above and `output_shape(i) = a_shape(i)` here).
172     for (int i = 0; i < rank; ++i) {
173       output_shape(i) =
174           std::max(cond_shape(i), std::max(a_shape(i), b_shape(i)));
175     }
176   }
177 
178  private:
ValidateSparseTensor(const::tensorflow::Tensor & indices_tensor,const::tensorflow::Tensor & values_tensor,const::tensorflow::Tensor & shape_tensor,const string label)179   Status ValidateSparseTensor(const ::tensorflow::Tensor& indices_tensor,
180                               const ::tensorflow::Tensor& values_tensor,
181                               const ::tensorflow::Tensor& shape_tensor,
182                               const string label) {
183     if (!TensorShapeUtils::IsMatrix(indices_tensor.shape())) {
184       return InvalidArgument(
185           "Sparse indices for ", label,
186           " must be rank 2, not shape: ", indices_tensor.shape().DebugString());
187     }
188     if (!TensorShapeUtils::IsVector(values_tensor.shape())) {
189       return InvalidArgument("Sparse values for ", label,
190                              " must be a vector, not shape: ",
191                              values_tensor.shape().DebugString());
192     }
193     if (!TensorShapeUtils::IsVector(shape_tensor.shape())) {
194       return InvalidArgument(
195           "Sparse shape for ", label,
196           " must be a vector, not shape: ", shape_tensor.shape().DebugString());
197     }
198     if (indices_tensor.dim_size(0) != values_tensor.dim_size(0)) {
199       return InvalidArgument("Sparse indices and values for " + label +
200                                  " must have the same "
201                                  "number of rows. indices: ",
202                              indices_tensor.shape().DebugString(),
203                              " values: ", values_tensor.shape().DebugString());
204     }
205     return OkStatus();
206   }
207 };
208 
209 // To support tensors containing different types (e.g. int32, float), one
210 // kernel per type is registered and is templatized by the "T" attr value.
211 // See go/tf-custom-ops-guide
212 #define REGISTER_KERNELS_CPU(type)                              \
213   REGISTER_KERNEL_BUILDER(Name("Examples>MultiplexSparse")      \
214                               .Device(::tensorflow::DEVICE_CPU) \
215                               .TypeConstraint<type>("T"),       \
216                           MultiplexSparseOp<type>)
217 TF_CALL_ALL_TYPES(REGISTER_KERNELS_CPU);
218 
219 #undef REGISTER_KERNELS_CPU
220 
221 }  // namespace custom_op_examples
222 }  // namespace tensorflow
223