xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/count_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <limits>
18 
19 #define EIGEN_USE_THREADS
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/op_requires.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_types.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/types.h"
29 
30 namespace tensorflow {
31 
// Don't allocate too large `BatchedMap<T>` objects: this bounds the number of
// batches parsed from a dense-shape input. `constexpr` because the value is a
// compile-time constant and is never mutated at runtime.
static constexpr int kMaxBatches = std::numeric_limits<int>::max();
34 
// One hash map per batch element, mapping a counted value to its count (or
// accumulated weight) of type T.
template <class T>
using BatchedMap = std::vector<absl::flat_hash_map<int64_t, T>>;
37 
38 namespace {
39 // TODO(momernick): Extend this function to work with outputs of rank > 2.
40 template <class T>
OutputSparse(const BatchedMap<T> & per_batch_counts,int64_t num_values,bool is_1d,OpKernelContext * context)41 Status OutputSparse(const BatchedMap<T>& per_batch_counts, int64_t num_values,
42                     bool is_1d, OpKernelContext* context) {
43   int total_values = 0;
44   int num_batches = per_batch_counts.size();
45   for (const auto& per_batch_count : per_batch_counts) {
46     total_values += per_batch_count.size();
47   }
48 
49   Tensor* indices;
50   int inner_dim = is_1d ? 1 : 2;
51   TF_RETURN_IF_ERROR(context->allocate_output(
52       0, TensorShape({total_values, inner_dim}), &indices));
53 
54   Tensor* values;
55   TF_RETURN_IF_ERROR(
56       context->allocate_output(1, TensorShape({total_values}), &values));
57 
58   auto output_indices = indices->matrix<int64_t>();
59   auto output_values = values->flat<T>();
60   int64_t value_loc = 0;
61   for (int b = 0; b < num_batches; ++b) {
62     const auto& per_batch_count = per_batch_counts[b];
63     std::vector<std::pair<int64_t, T>> pairs(per_batch_count.begin(),
64                                              per_batch_count.end());
65     std::sort(pairs.begin(), pairs.end());
66     for (const auto& x : pairs) {
67       if (is_1d) {
68         output_indices(value_loc, 0) = x.first;
69       } else {
70         output_indices(value_loc, 0) = b;
71         output_indices(value_loc, 1) = x.first;
72       }
73       output_values(value_loc) = x.second;
74       ++value_loc;
75     }
76   }
77   Tensor* dense_shape;
78   if (is_1d) {
79     TF_RETURN_IF_ERROR(
80         context->allocate_output(2, TensorShape({1}), &dense_shape));
81     dense_shape->flat<int64_t>().data()[0] = num_values;
82   } else {
83     TF_RETURN_IF_ERROR(
84         context->allocate_output(2, TensorShape({2}), &dense_shape));
85     dense_shape->flat<int64_t>().data()[0] = num_batches;
86     dense_shape->flat<int64_t>().data()[1] = num_values;
87   }
88 
89   return OkStatus();
90 }
91 
// Size of the dense value axis of the output: an explicit non-negative
// `max_length` wins outright; otherwise the axis must cover the largest value
// seen (hence max_seen + 1) but never shrink below `min_length`.
int64_t GetOutputSize(int64_t max_seen, int64_t max_length,
                      int64_t min_length) {
  if (max_length >= 0) {
    return max_length;
  }
  return std::max(max_seen + 1, min_length);
}
96 
97 }  // namespace
98 
// Kernel for DenseCountSparseOutput: counts occurrences of each value in a
// dense 1-D or 2-D input (optionally weighted, optionally binarized) and
// emits the per-batch counts as a SparseTensor via OutputSparse.
// T is the input value type; W is the output count/weight type.
template <class T, class W>
class DenseCount : public OpKernel {
 public:
  explicit DenseCount(OpKernelConstruction* context) : OpKernel(context) {
    // minlength/maxlength bound the value axis of the output (-1 disables
    // the bound); binary_output records presence (1) instead of counts.
    OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
    OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
    OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& weights = context->input(1);
    // An empty weights tensor means "unweighted".
    bool use_weights = weights.NumElements() > 0;

    OP_REQUIRES(context,
                TensorShapeUtils::IsVector(data.shape()) ||
                    TensorShapeUtils::IsMatrix(data.shape()),
                errors::InvalidArgument(
                    "Input must be a 1 or 2-dimensional tensor. Got: ",
                    data.shape().DebugString()));

    // Ensure all values are non-negative.
    // (Negative values would be used as map keys and sparse indices below.)
    const auto data_values = data.flat<T>();
    Eigen::TensorFixedSize<bool, Eigen::Sizes<>, Eigen::RowMajor> nonnegative;
    nonnegative.device(context->eigen_cpu_device()) =
        (data_values >= static_cast<T>(0)).all();
    OP_REQUIRES(
        context, nonnegative(),
        errors::InvalidArgument("Input values must all be non-negative"));

    if (use_weights) {
      OP_REQUIRES(
          context, weights.shape() == data.shape(),
          errors::InvalidArgument(
              "Weights and data must have the same shape. Weight shape: ",
              weights.shape().DebugString(),
              "; data shape: ", data.shape().DebugString()));
    }

    bool is_1d = TensorShapeUtils::IsVector(data.shape());
    // Counting runs along the innermost axis (axis = -1); every leading
    // dimension is treated as a batch dimension.
    int negative_valued_axis = -1;
    int num_batch_dimensions = (data.shape().dims() + negative_valued_axis);

    int num_batch_elements = 1;
    for (int i = 0; i < num_batch_dimensions; ++i) {
      // A zero-sized batch dimension would make the division below divide
      // by zero.
      OP_REQUIRES(context, data.shape().dim_size(i) != 0,
                  errors::InvalidArgument(
                      "Invalid input: Shapes dimension cannot be 0."));
      num_batch_elements *= data.shape().dim_size(i);
    }
    int num_value_elements = data.shape().num_elements() / num_batch_elements;
    auto per_batch_counts = BatchedMap<W>(num_batch_elements);

    T max_value = 0;

    const auto weight_values = weights.flat<W>();
    // `i` walks the flattened (row-major) input; (b, v) addresses the same
    // element as (batch, position-within-batch).
    int i = 0;
    for (int b = 0; b < num_batch_elements; ++b) {
      for (int v = 0; v < num_value_elements; ++v) {
        const auto& value = data_values(i);
        // Values >= maxlength_ are skipped entirely when maxlength_ >= 0.
        if (maxlength_ < 0 || value < maxlength_) {
          if (binary_output_) {
            per_batch_counts[b][value] = 1;
          } else if (use_weights) {
            per_batch_counts[b][value] += weight_values(i);
          } else {
            per_batch_counts[b][value]++;
          }
          if (value > max_value) {
            max_value = value;
          }
        }
        ++i;
      }
    }

    // The dense value axis covers [0, num_output_values).
    int64_t num_output_values =
        GetOutputSize(max_value, maxlength_, minlength_);
    OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
                                            is_1d, context));
  }

 private:
  int64_t maxlength_;
  int64_t minlength_;
  bool binary_output_;
};
186 
187 template <class T, class W>
188 class SparseCount : public OpKernel {
189  public:
SparseCount(OpKernelConstruction * context)190   explicit SparseCount(OpKernelConstruction* context) : OpKernel(context) {
191     OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
192     OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
193     OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
194   }
195 
Compute(OpKernelContext * context)196   void Compute(OpKernelContext* context) override {
197     const Tensor& indices = context->input(0);
198     const Tensor& values = context->input(1);
199     const Tensor& shape = context->input(2);
200     const Tensor& weights = context->input(3);
201     bool use_weights = weights.NumElements() > 0;
202 
203     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
204                 errors::InvalidArgument(
205                     "Input indices must be a 2-dimensional tensor. Got: ",
206                     indices.shape().DebugString()));
207     OP_REQUIRES(context, TensorShapeUtils::IsVector(values.shape()),
208                 errors::InvalidArgument("Input values must be a vector. Got: ",
209                                         values.shape().DebugString()));
210     OP_REQUIRES(context, TensorShapeUtils::IsVector(shape.shape()),
211                 errors::InvalidArgument("Input shape must be a vector. Got: ",
212                                         shape.shape().DebugString()));
213     OP_REQUIRES(context,
214                 values.shape().dim_size(0) == indices.shape().dim_size(0),
215                 errors::InvalidArgument(
216                     "Number of values must match first dimension of indices.",
217                     "Got ", values.shape().dim_size(0),
218                     " values, indices shape: ", indices.shape().DebugString()));
219     OP_REQUIRES(
220         context, shape.shape().dim_size(0) == indices.shape().dim_size(1),
221         errors::InvalidArgument(
222             "Number of dimensions must match second dimension of indices.",
223             "Got ", shape.shape().dim_size(0),
224             " dimensions, indices shape: ", indices.shape().DebugString()));
225     OP_REQUIRES(context, shape.NumElements() > 0,
226                 errors::InvalidArgument(
227                     "The shape argument requires at least one element."));
228     // Validate indices: each index must be valid for the corresponding
229     // dimension. This could be possibly done better.
230     const auto indices_values = indices.matrix<int64_t>();
231     const auto shape_vector = shape.vec<int64_t>();
232     int num_values = values.NumElements();  // same as first dim of indices
233     int rank = indices.shape().dim_size(1);
234     for (int i = 0; i < num_values; ++i) {
235       for (int j = 0; j < rank; ++j) {
236         OP_REQUIRES(
237             context,
238             indices_values(i, j) >= 0 && indices_values(i, j) < shape_vector(j),
239             errors::InvalidArgument(
240                 "Invalid index value at ", i, ": dimension ", j, " has value ",
241                 indices_values(i, j), " which is not in [0, ", shape_vector(j),
242                 ") (as given by dense shape ", shape.DebugString()));
243       }
244     }
245 
246     // Ensure all values are non-negative.
247     const auto values_values = values.flat<T>();
248     Eigen::TensorFixedSize<bool, Eigen::Sizes<>, Eigen::RowMajor> nonnegative;
249     nonnegative.device(context->eigen_cpu_device()) =
250         (values_values >= static_cast<T>(0)).all();
251     OP_REQUIRES(
252         context, nonnegative(),
253         errors::InvalidArgument("Input values must all be non-negative"));
254 
255     if (use_weights) {
256       OP_REQUIRES(
257           context, weights.shape() == values.shape(),
258           errors::InvalidArgument(
259               "Weights and values must have the same shape. Weight shape: ",
260               weights.shape().DebugString(),
261               "; values shape: ", values.shape().DebugString()));
262     }
263 
264     bool is_1d = shape.NumElements() == 1;
265     int num_batches = is_1d ? 1 : shape_vector(0);
266     OP_REQUIRES(
267         context, 0 < num_batches && num_batches < kMaxBatches,
268         errors::InvalidArgument("Cannot allocate ", num_batches,
269                                 " batches, is the dense shape too wide?"));
270 
271     const auto weight_values = weights.flat<W>();
272 
273     auto per_batch_counts = BatchedMap<W>(num_batches);
274 
275     T max_value = 0;
276 
277     for (int idx = 0; idx < num_values; ++idx) {
278       int batch = is_1d ? 0 : indices_values(idx, 0);
279       if (batch >= num_batches) {
280         OP_REQUIRES(context, batch < num_batches,
281                     errors::InvalidArgument(
282                         "Indices value along the first dimension must be ",
283                         "lower than the first index of the shape.", "Got ",
284                         batch, " as batch and ", num_batches,
285                         " as the first dimension of the shape."));
286       }
287       const auto& value = values_values(idx);
288       if (maxlength_ < 0 || value < maxlength_) {
289         if (binary_output_) {
290           per_batch_counts[batch][value] = 1;
291         } else if (use_weights) {
292           per_batch_counts[batch][value] += weight_values(idx);
293         } else {
294           per_batch_counts[batch][value]++;
295         }
296         if (value > max_value) {
297           max_value = value;
298         }
299       }
300     }
301 
302     int64_t num_output_values =
303         GetOutputSize(max_value, maxlength_, minlength_);
304     OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
305                                             is_1d, context));
306   }
307 
308  private:
309   int64_t maxlength_;
310   int64_t minlength_;
311   bool binary_output_;
312   bool validate_;
313 };
314 
315 template <class T, class W>
316 class RaggedCount : public OpKernel {
317  public:
RaggedCount(OpKernelConstruction * context)318   explicit RaggedCount(OpKernelConstruction* context) : OpKernel(context) {
319     OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
320     OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
321     OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
322   }
323 
Compute(OpKernelContext * context)324   void Compute(OpKernelContext* context) override {
325     const Tensor& splits = context->input(0);
326     const Tensor& values = context->input(1);
327     const Tensor& weights = context->input(2);
328     bool use_weights = weights.NumElements() > 0;
329     bool is_1d = false;
330 
331     if (use_weights) {
332       OP_REQUIRES(
333           context, weights.shape() == values.shape(),
334           errors::InvalidArgument(
335               "Weights and values must have the same shape. Weight shape: ",
336               weights.shape().DebugString(),
337               "; values shape: ", values.shape().DebugString()));
338     }
339 
340     const auto splits_values = splits.flat<int64_t>();
341     const auto values_values = values.flat<T>();
342     const auto weight_values = weights.flat<W>();
343     int num_batches = splits.NumElements() - 1;
344     int num_values = values.NumElements();
345 
346     OP_REQUIRES(
347         context, num_batches > 0,
348         errors::InvalidArgument(
349             "Must provide at least 2 elements for the splits argument"));
350     OP_REQUIRES(context, splits_values(0) == 0,
351                 errors::InvalidArgument("Splits must start with 0, not with ",
352                                         splits_values(0)));
353     OP_REQUIRES(context, splits_values(num_batches) == num_values,
354                 errors::InvalidArgument(
355                     "Splits must end with the number of values, got ",
356                     splits_values(num_batches), " instead of ", num_values));
357 
358     // Ensure all values are non-negative.
359     Eigen::TensorFixedSize<bool, Eigen::Sizes<>, Eigen::RowMajor> nonnegative;
360     nonnegative.device(context->eigen_cpu_device()) =
361         (values_values >= static_cast<T>(0)).all();
362     OP_REQUIRES(
363         context, nonnegative(),
364         errors::InvalidArgument("Input values must all be non-negative"));
365 
366     auto per_batch_counts = BatchedMap<W>(num_batches);
367     T max_value = 0;
368     int batch_idx = 0;
369 
370     for (int idx = 0; idx < num_values; ++idx) {
371       while (idx >= splits_values(batch_idx)) {
372         batch_idx++;
373       }
374       const auto& value = values_values(idx);
375       if (maxlength_ < 0 || value < maxlength_) {
376         if (binary_output_) {
377           per_batch_counts[batch_idx - 1][value] = 1;
378         } else if (use_weights) {
379           per_batch_counts[batch_idx - 1][value] += weight_values(idx);
380         } else {
381           per_batch_counts[batch_idx - 1][value]++;
382         }
383         if (value > max_value) {
384           max_value = value;
385         }
386       }
387     }
388 
389     int64_t num_output_values =
390         GetOutputSize(max_value, maxlength_, minlength_);
391     OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
392                                             is_1d, context));
393   }
394 
395  private:
396   int64_t maxlength_;
397   int64_t minlength_;
398   bool binary_output_;
399   bool validate_;
400 };
401 
// Instantiates REGISTER for both supported input index types (int32/int64_t)
// with the given output/weight type W_TYPE.
#define REGISTER_W(W_TYPE) \
  REGISTER(int32, W_TYPE)  \
  REGISTER(int64_t, W_TYPE)

// Registers the CPU kernels for the three *CountSparseOutput ops for one
// (input type I_TYPE, output type W_TYPE) pair.
#define REGISTER(I_TYPE, W_TYPE)                                     \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("DenseCountSparseOutput")             \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          DenseCount<I_TYPE, W_TYPE>)                \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("SparseCountSparseOutput")            \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          SparseCount<I_TYPE, W_TYPE>)               \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("RaggedCountSparseOutput")            \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                          RaggedCount<I_TYPE, W_TYPE>)

// Output/weight types: all integral types plus float and double.
TF_CALL_INTEGRAL_TYPES(REGISTER_W);
TF_CALL_float(REGISTER_W);
TF_CALL_double(REGISTER_W);

#undef REGISTER_W
#undef REGISTER
432 
433 }  // namespace tensorflow
434