// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/sparse_tensors_map_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
#include <algorithm>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
23 
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/resource_mgr.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_util.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/gtl/inlined_vector.h"
31 #include "tensorflow/core/util/overflow.h"
32 #include "tensorflow/core/util/sparse/sparse_tensor.h"
33 
34 namespace tensorflow {
35 
// NOTE(review): CPUDevice appears unused in this file's visible code —
// confirm before removing (EIGEN_USE_THREADS is defined above).
typedef Eigen::ThreadPoolDevice CPUDevice;

using sparse::SparseTensor;
39 
40 class SparseTensorsMap : public ResourceBase {
41  public:
SparseTensorsMap(const string & name)42   explicit SparseTensorsMap(const string& name) : name_(name), counter_(0) {}
43 
DebugString() const44   string DebugString() const override { return "A SparseTensorsMap"; }
45 
46   typedef struct {
47     Tensor indices;
48     Tensor values;
49     gtl::InlinedVector<int64_t, 8> shape;
50   } PersistentSparseTensor;
51 
AddSparseTensor(OpKernelContext * ctx,const SparseTensor & sp,int64_t * handle)52   Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp,
53                          int64_t* handle) {
54     Tensor ix;
55     TF_RETURN_IF_ERROR(
56         ctx->allocate_temp(sp.indices().dtype(), sp.indices().shape(), &ix));
57     ix = sp.indices();
58 
59     Tensor values;
60     TF_RETURN_IF_ERROR(ctx->allocate_temp(sp.indices().dtype(),
61                                           sp.indices().shape(), &values));
62     values = sp.values();
63     {
64       mutex_lock l(mu_);
65       int64_t unique_st_handle = counter_++;  // increment is guarded on purpose
66       sp_tensors_[unique_st_handle] = PersistentSparseTensor{
67           ix, values,
68           gtl::InlinedVector<int64_t, 8>(sp.shape().begin(), sp.shape().end())};
69       *handle = unique_st_handle;
70     }
71     return OkStatus();
72   }
73 
RetrieveAndClearSparseTensors(OpKernelContext * ctx,const TTypes<int64_t>::ConstVec & handles,std::vector<SparseTensor> * sparse_tensors)74   Status RetrieveAndClearSparseTensors(
75       OpKernelContext* ctx, const TTypes<int64_t>::ConstVec& handles,
76       std::vector<SparseTensor>* sparse_tensors) {
77     sparse_tensors->clear();
78     sparse_tensors->reserve(handles.size());
79     {
80       mutex_lock l(mu_);
81       for (size_t i = 0; i < handles.size(); ++i) {
82         const int64_t handle = handles(i);
83         auto sp_iter = sp_tensors_.find(handle);
84         if (sp_iter == sp_tensors_.end()) {
85           return errors::InvalidArgument(
86               "Unable to find SparseTensor: ", handle, " in map: ", name_);
87         }
88         const Tensor* ix = &sp_iter->second.indices;
89         const Tensor* values = &sp_iter->second.values;
90         const auto& shape = sp_iter->second.shape;
91         SparseTensor tensor;
92         TF_RETURN_IF_ERROR(SparseTensor::Create(*ix, *values, shape, &tensor));
93         sparse_tensors->push_back(std::move(tensor));
94         sp_tensors_.erase(sp_iter);
95       }
96     }
97 
98     return OkStatus();
99   }
100 
101  protected:
~SparseTensorsMap()102   ~SparseTensorsMap() override {}
103 
104  private:
105   string name_;
106 
107   mutex mu_;
108   int64_t counter_ TF_GUARDED_BY(mu_);
109   std::unordered_map<int64_t, PersistentSparseTensor> sp_tensors_
110       TF_GUARDED_BY(mu_);
111 };
112 
113 class SparseTensorAccessingOp : public OpKernel {
114  public:
115   typedef std::function<Status(SparseTensorsMap**)> CreatorCallback;
116 
SparseTensorAccessingOp(OpKernelConstruction * context)117   explicit SparseTensorAccessingOp(OpKernelConstruction* context)
118       : OpKernel(context), sparse_tensors_map_(nullptr) {}
119 
120  protected:
~SparseTensorAccessingOp()121   ~SparseTensorAccessingOp() override {
122     if (sparse_tensors_map_) sparse_tensors_map_->Unref();
123   }
124 
GetMap(OpKernelContext * ctx,bool is_writing,SparseTensorsMap ** sparse_tensors_map)125   Status GetMap(OpKernelContext* ctx, bool is_writing,
126                 SparseTensorsMap** sparse_tensors_map) {
127     mutex_lock l(mu_);
128 
129     if (sparse_tensors_map_) {
130       *sparse_tensors_map = sparse_tensors_map_;
131       return OkStatus();
132     }
133 
134     TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def(),
135                                    is_writing /* use_node_name_as_default */));
136 
137     CreatorCallback sparse_tensors_map_creator = [this](SparseTensorsMap** c) {
138       SparseTensorsMap* map = new SparseTensorsMap(cinfo_.name());
139       *c = map;
140       return OkStatus();
141     };
142 
143     TF_RETURN_IF_ERROR(
144         cinfo_.resource_manager()->LookupOrCreate<SparseTensorsMap>(
145             cinfo_.container(), cinfo_.name(), &sparse_tensors_map_,
146             sparse_tensors_map_creator));
147 
148     *sparse_tensors_map = sparse_tensors_map_;
149     return OkStatus();
150   }
151 
152  private:
153   ContainerInfo cinfo_;
154 
155   mutex mu_;
156   SparseTensorsMap* sparse_tensors_map_ TF_PT_GUARDED_BY(mu_);
157 };
158 
159 class AddSparseToTensorsMapOp : public SparseTensorAccessingOp {
160  public:
AddSparseToTensorsMapOp(OpKernelConstruction * context)161   explicit AddSparseToTensorsMapOp(OpKernelConstruction* context)
162       : SparseTensorAccessingOp(context) {}
163 
Compute(OpKernelContext * context)164   void Compute(OpKernelContext* context) override {
165     const Tensor* input_indices;
166     const Tensor* input_values;
167     const Tensor* input_shape;
168     SparseTensorsMap* map;
169 
170     OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
171     OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
172     OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
173     OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));
174 
175     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
176                 errors::InvalidArgument(
177                     "Input indices should be a matrix but received shape ",
178                     input_indices->shape().DebugString()));
179 
180     OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
181                 errors::InvalidArgument(
182                     "Input values should be a vector but received shape ",
183                     input_values->shape().DebugString()));
184 
185     OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
186                 errors::InvalidArgument(
187                     "Input shape should be a vector but received shape ",
188                     input_shape->shape().DebugString()));
189 
190     TensorShape input_shape_object;
191     OP_REQUIRES_OK(
192         context, TensorShapeUtils::MakeShape(input_shape->vec<int64_t>().data(),
193                                              input_shape->NumElements(),
194                                              &input_shape_object));
195     SparseTensor st;
196     OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
197                                                  input_shape_object, &st));
198     int64_t handle;
199     OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle));
200 
201     Tensor sparse_handle(DT_INT64, TensorShape({}));
202     auto sparse_handle_t = sparse_handle.scalar<int64_t>();
203 
204     sparse_handle_t() = handle;
205 
206     context->set_output(0, sparse_handle);
207   }
208 };
209 
// Register the single-tensor insert op on CPU (no type constraint needed).
REGISTER_KERNEL_BUILDER(Name("AddSparseToTensorsMap").Device(DEVICE_CPU),
                        AddSparseToTensorsMapOp);
212 
// Op that splits a rank-R (R > 1) SparseTensor along its first (minibatch)
// dimension into N rank-(R-1) SparseTensors, stores each in the shared map,
// and outputs the vector of N int64 handles (one per minibatch entry).
template <typename T>
class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddManySparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    // Validate component ranks, then cross-check that indices/values/shape
    // are mutually consistent before constructing the SparseTensor.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));
    OP_REQUIRES(
        context,
        input_values->shape().dim_size(0) == input_indices->shape().dim_size(0),
        errors::InvalidArgument(
            "Number of values must match first dimension of indices. ", "Got ",
            input_values->shape().dim_size(0),
            " values, indices shape: ", input_indices->shape().DebugString()));
    OP_REQUIRES(
        context,
        input_shape->shape().dim_size(0) == input_indices->shape().dim_size(1),
        errors::InvalidArgument(
            "Number of dimensions must match second dimension of indices. ",
            "Got ", input_shape->shape().dim_size(0),
            " dimensions, indices shape: ",
            input_indices->shape().DebugString()));

    int rank = input_shape->NumElements();

    // Need at least a minibatch dimension plus one data dimension.
    OP_REQUIRES(
        context, rank > 1,
        errors::InvalidArgument(
            "Rank of input SparseTensor should be > 1, but saw rank: ", rank));

    auto input_shape_vec = input_shape->vec<int64_t>();

    TensorShape tensor_input_shape;
    OP_REQUIRES_OK(context, TensorShape::BuildTensorShape(input_shape_vec,
                                                          &tensor_input_shape));
    // Standard (row-major) dimension order, used for grouping below.
    gtl::InlinedVector<int64_t, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);
    SparseTensor input_st;
    OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
                                                 tensor_input_shape, std_order,
                                                 &input_st));

    // N = minibatch size (extent of dimension 0).
    const int64_t N = input_shape_vec(0);

    Tensor sparse_handles(DT_INT64, TensorShape({N}));
    auto sparse_handles_t = sparse_handles.vec<int64_t>();

    // Grouping below requires lexicographically valid indices.
    OP_REQUIRES_OK(context, input_st.IndicesValid());

    // We can generate the output shape proto string now, for all
    // minibatch entries: it is the input shape minus the leading dimension.
    TensorShape output_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_shape_vec.data() + 1,
                                input_shape->NumElements() - 1, &output_shape));

    // Get groups by minibatch dimension
    std::unordered_set<int64_t> visited;
    sparse::GroupIterable minibatch = input_st.group({0});
    for (const auto& subset : minibatch) {
      const int64_t b = subset.group()[0];
      visited.insert(b);
      OP_REQUIRES(
          context, b > -1 && b < N,
          errors::InvalidArgument(
              "Received unexpected column 0 value in input SparseTensor: ", b,
              " < 0 or >= N (= ", N, ")"));

      const auto indices = subset.indices();
      const auto values = subset.values<T>();
      const int64_t num_entries = values.size();

      // Build this entry's components by stripping the leading (minibatch)
      // index column from the group's slice.
      Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
      Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});

      auto output_indices_t = output_indices.matrix<int64_t>();
      auto output_values_t = output_values.vec<T>();

      for (int i = 0; i < num_entries; ++i) {
        for (int d = 1; d < rank; ++d) {
          output_indices_t(i, d - 1) = indices(i, d);
        }
        output_values_t(i) = values(i);
      }

      SparseTensor st_i;
      OP_REQUIRES_OK(context,
                     SparseTensor::Create(output_indices, output_values,
                                          output_shape, &st_i));
      int64_t handle;
      OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle));
      sparse_handles_t(b) = handle;
    }

    // Fill in any gaps; we must provide an empty ST for batch entries
    // the grouper didn't find (rows with no nonzero entries).
    if (visited.size() < N) {
      Tensor empty_indices(DT_INT64, {0, rank - 1});
      Tensor empty_values(DataTypeToEnum<T>::value, {0});
      SparseTensor empty_st;
      OP_REQUIRES_OK(context, SparseTensor::Create(empty_indices, empty_values,
                                                   output_shape, &empty_st));

      for (int64_t b = 0; b < N; ++b) {
        // We skipped this batch entry.
        if (visited.find(b) == visited.end()) {
          int64_t handle;
          OP_REQUIRES_OK(context,
                         map->AddSparseTensor(context, empty_st, &handle));
          sparse_handles_t(b) = handle;
        }
      }
    }

    context->set_output(0, sparse_handles);
  }
};
352 
// Register AddManySparseToTensorsMap on CPU, once per supported value type T.
#define REGISTER_KERNELS(type)                              \
  REGISTER_KERNEL_BUILDER(Name("AddManySparseToTensorsMap") \
                              .Device(DEVICE_CPU)           \
                              .TypeConstraint<type>("T"),   \
                          AddManySparseToTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
361 
362 template <typename T>
363 class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp {
364  public:
TakeManySparseFromTensorsMapOp(OpKernelConstruction * context)365   explicit TakeManySparseFromTensorsMapOp(OpKernelConstruction* context)
366       : SparseTensorAccessingOp(context) {}
367 
Compute(OpKernelContext * context)368   void Compute(OpKernelContext* context) override {
369     SparseTensorsMap* map = nullptr;
370     OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map));
371 
372     const Tensor& sparse_handles = context->input(0);
373 
374     OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_handles.shape()),
375                 errors::InvalidArgument(
376                     "sparse_handles should be a vector but received shape ",
377                     sparse_handles.shape().DebugString()));
378 
379     int64_t N = sparse_handles.shape().dim_size(0);
380 
381     OP_REQUIRES(
382         context, N > 0,
383         errors::InvalidArgument("Must have at least 1 serialized SparseTensor, "
384                                 "but input matrix has 0 rows"));
385 
386     std::vector<Tensor> indices_to_concat;
387     std::vector<Tensor> values_to_concat;
388     std::vector<TensorShape> shapes_to_concat;
389 
390     const auto& sparse_handles_t = sparse_handles.vec<int64_t>();
391 
392     std::vector<SparseTensor> sparse_tensors;
393 
394     OP_REQUIRES_OK(context, map->RetrieveAndClearSparseTensors(
395                                 context, sparse_handles_t, &sparse_tensors));
396 
397     for (int64_t i = 0; i < N; ++i) {
398       const SparseTensor& st = sparse_tensors[i];
399       const Tensor& output_indices = st.indices();
400       const Tensor& output_values = st.values();
401       const auto output_shape = st.shape();
402 
403       OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
404                   errors::InvalidArgument(
405                       "Expected sparse_handles[", i,
406                       "] to represent an index matrix but received shape ",
407                       output_indices.shape().DebugString()));
408       OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
409                   errors::InvalidArgument(
410                       "Expected sparse_handles[", i,
411                       "] to represent a values vector but received shape ",
412                       output_values.shape().DebugString()));
413       OP_REQUIRES(
414           context, DataTypeToEnum<T>::value == output_values.dtype(),
415           errors::InvalidArgument(
416               "Requested SparseTensor of type ",
417               DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i,
418               "].values.dtype() == ", DataTypeString(output_values.dtype())));
419 
420       int64_t num_entries = output_indices.dim_size(0);
421       OP_REQUIRES(context, num_entries == output_values.dim_size(0),
422                   errors::InvalidArgument(
423                       "Expected row counts of SparseTensor[", i,
424                       "].indices and SparseTensor[", i,
425                       "].values to match but they do not: ", num_entries,
426                       " vs. ", output_values.dim_size(0)));
427       int rank = output_indices.dim_size(1);
428       OP_REQUIRES(
429           context, rank == output_shape.size(),
430           errors::InvalidArgument("Expected column counts of SparseTensor[", i,
431                                   "].indices to match size of SparseTensor[", i,
432                                   "].shape "
433                                   "but they do not: ",
434                                   rank, " vs. ", output_shape.size()));
435 
436       // Now we expand each SparseTensors' indices and shape by
437       // prefixing a dimension
438       Tensor expanded_indices(
439           DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)}));
440       Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
441       const auto& output_indices_t = output_indices.matrix<int64_t>();
442       auto expanded_indices_t = expanded_indices.matrix<int64_t>();
443       auto expanded_shape_t = expanded_shape.vec<int64_t>();
444       expanded_indices_t.chip<1>(0).setZero();
445       Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
446       Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
447       expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
448       expanded_shape_t(0) = 1;
449       // TODO: copy shape from TensorShape to &expanded_shape_t(1)
450       // std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
451       for (int i = 0; i < rank; ++i) {
452         expanded_shape_t(i + 1) = output_shape[i];
453       }
454       TensorShape expanded_tensor_shape(expanded_shape_t);
455 
456       indices_to_concat.push_back(std::move(expanded_indices));
457       values_to_concat.push_back(output_values);
458       shapes_to_concat.push_back(std::move(expanded_tensor_shape));
459     }
460 
461     int rank = -1;
462     for (int i = 0; i < N; ++i) {
463       if (rank < 0) rank = shapes_to_concat[i].dims();
464       OP_REQUIRES(context, rank == shapes_to_concat[i].dims(),
465                   errors::InvalidArgument(
466                       "Inconsistent rank across SparseTensors: rank prior to "
467                       "SparseTensor[",
468                       i, "] was: ", rank, " but rank of SparseTensor[", i,
469                       "] is: ", shapes_to_concat[i].dims()));
470     }
471 
472     // SparseTensor::Concat requires consistent shape for all but the
473     // primary order dimension (dimension 0 in this case).  So we get
474     // the maximum value across all the input SparseTensors for each
475     // dimension and use that.
476     TensorShape preconcat_shape(shapes_to_concat[0]);
477     for (int i = 0; i < N; ++i) {
478       for (int d = 0; d < rank; ++d) {
479         preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d),
480                                             shapes_to_concat[i].dim_size(d)));
481       }
482     }
483 
484     // Dimension 0 is the primary dimension.
485     gtl::InlinedVector<int64_t, 8> std_order(rank);
486     std::iota(std_order.begin(), std_order.end(), 0);
487 
488     std::vector<SparseTensor> tensors_to_concat;
489     tensors_to_concat.reserve(N);
490     for (int i = 0; i < N; ++i) {
491       SparseTensor tensor;
492       OP_REQUIRES_OK(context,
493                      SparseTensor::Create(std::move(indices_to_concat[i]),
494                                           std::move(values_to_concat[i]),
495                                           preconcat_shape, std_order, &tensor));
496       tensors_to_concat.push_back(std::move(tensor));
497     }
498 
499     auto output = SparseTensor::Concat<T>(tensors_to_concat);
500     Tensor final_output_shape(DT_INT64, TensorShape({output.dims()}));
501 
502     std::copy_n(output.shape().data(), output.dims(),
503                 final_output_shape.vec<int64_t>().data());
504 
505     context->set_output(0, output.indices());
506     context->set_output(1, output.values());
507     context->set_output(2, final_output_shape);
508   }
509 };
510 
// Register TakeManySparseFromTensorsMap on CPU, once per supported dtype.
#define REGISTER_KERNELS(type)                                 \
  REGISTER_KERNEL_BUILDER(Name("TakeManySparseFromTensorsMap") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<type>("dtype"),  \
                          TakeManySparseFromTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
519 
520 }  // namespace tensorflow
521