// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License.  You may obtain a copy
// of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
// License for the specific language governing permissions and limitations under
// the License.
// ==============================================================================

#define EIGEN_USE_THREADS

#include <algorithm>
#include <cmath>
#include <memory>
#include <numeric>
#include <tuple>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace {
using errors::InvalidArgument;

template <typename Scalar>
using RowMajorMatrix =
    Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

using MatrixXfRowMajor = RowMajorMatrix<float>;
using MatrixXi64RowMajor = RowMajorMatrix<int64_t>;

// Ideally this should be computed by dividing L3 cache size by the number of
// physical CPUs. Since there isn't a portable method to do this, we are using
// a conservative estimate here.
const int64_t kDefaultL3CachePerCpu = 1 << 20;

// These values were determined by performing a parameter sweep on the
// NearestNeighborsOp benchmark.
const int64_t kNearestNeighborsCentersMaxBlockSize = 1024;
const int64_t kNearestNeighborsPointsMinBlockSize = 16;

// Returns the smallest multiple of a that is not smaller than b.
int64_t NextMultiple(int64_t a, int64_t b) {
  const int64_t remainder = b % a;
  return remainder == 0 ? b : (b + a - remainder);
}

// Returns a / b rounded up to the next higher integer.
int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }
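
// For example (illustrative values):
//   NextMultiple(/*a=*/4, /*b=*/10) == 12   (10 % 4 == 2, so 10 + 4 - 2)
//   CeilOfRatio(/*a=*/10, /*b=*/4) == 3     ((10 + 4 - 1) / 4 == 13 / 4)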

}  // namespace

// Implementation of K-means++ initialization. Samples the next point
// iteratively with probability proportional to its squared distance to the
// nearest already-selected point.
// TODO(ands): Add support for other distance metrics.
class KmeansPlusPlusInitializationOp : public OpKernel {
 public:
  explicit KmeansPlusPlusInitializationOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->MatchSignature(
                       {DT_FLOAT, DT_INT64, DT_INT64, DT_INT64}, {DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& points_tensor = context->input(0);
    const Tensor& num_to_sample_tensor = context->input(1);
    const Tensor& seed_tensor = context->input(2);
    const Tensor& num_retries_per_sample_tensor = context->input(3);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
                InvalidArgument("Input points should be a matrix."));
    OP_REQUIRES(context,
                TensorShapeUtils::IsScalar(num_to_sample_tensor.shape()),
                InvalidArgument("Input num_to_sample should be a scalar."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
                InvalidArgument("Input seed should be a scalar."));
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsScalar(num_retries_per_sample_tensor.shape()),
        InvalidArgument("Input num_retries_per_sample should be a scalar."));

    const int64_t num_points = points_tensor.dim_size(0);
    const int64_t point_dimensions = points_tensor.dim_size(1);
    const int64_t num_to_sample = num_to_sample_tensor.scalar<int64_t>()();
    const int64_t seed = seed_tensor.scalar<int64_t>()();
    const int64_t num_retries_per_sample = [&]() {
      const int64_t value = num_retries_per_sample_tensor.scalar<int64_t>()();
      return value >= 0 ? value
                        : 2 + static_cast<int64_t>(std::log(num_to_sample));
    }();
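    // A negative value defaults num_retries_per_sample to
    // 2 + log(num_to_sample), following the common greedy k-means++ heuristic
    // of drawing several candidates per step and keeping the one that reduces
    // the total potential the most (see sample_one_point_with_retries below).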

    OP_REQUIRES(context, num_points > 0,
                InvalidArgument("Expected points.rows() > 0."));
    OP_REQUIRES(context, num_to_sample > 0,
                InvalidArgument("Expected num_to_sample > 0."));
    OP_REQUIRES(context, num_to_sample <= num_points,
                InvalidArgument("Expected num_to_sample <= points.rows(). ",
                                num_to_sample, " vs ", num_points, "."));

    Tensor* output_sampled_points_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({num_to_sample, point_dimensions}),
                       &output_sampled_points_tensor));

    const Eigen::Map<const MatrixXfRowMajor> points(
        points_tensor.matrix<float>().data(), num_points, point_dimensions);
    const Eigen::VectorXf points_half_squared_norm =
        0.5 * points.rowwise().squaredNorm();

    Eigen::Map<MatrixXfRowMajor> sampled_points(
        output_sampled_points_tensor->matrix<float>().data(), num_to_sample,
        point_dimensions);
    std::unordered_set<int64_t> sampled_indices;

    random::PhiloxRandom random(seed);
    random::SimplePhilox rng(&random);

    auto add_one_point = [&](int64_t from, int64_t to) {
      from = std::min(from, num_points - 1);
      sampled_points.row(to) = points.row(from);
      sampled_indices.insert(from);
    };

    // Distance from each point to the nearest selected point. Initialized to
    // infinity and updated after each point is selected; the first point is
    // drawn uniformly at random.
    Eigen::VectorXf min_distances(num_points);
    min_distances.fill(std::numeric_limits<float>::infinity());
    Eigen::VectorXf min_distances_cumsum(num_points);

    auto draw_one_sample = [&]() -> int64_t {
      if (sampled_indices.empty()) return rng.Uniform64(num_points);
      int64_t index = 0;
      do {
        // If v is drawn from Uniform[0, distances.sum()), then
        // Prob[cumsum(distances)(i - 1) <= v < cumsum(distances)(i)] is
        // proportional to distances(i).
        index = std::upper_bound(
                    min_distances_cumsum.data(),
                    min_distances_cumsum.data() + num_points,
                    rng.RandFloat() * min_distances_cumsum(num_points - 1)) -
                min_distances_cumsum.data();
      } while (sampled_indices.find(index) != sampled_indices.end());
      return index;
    };
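
    // Illustrative example of the draw above (values assumed for exposition):
    // if min_distances == {1, 3, 2} then min_distances_cumsum == {1, 4, 6}.
    // A uniform draw v in [0, 6) lands in [1, 4) with probability 3/6, and
    // std::upper_bound maps any such v to index 1, so point 1 is selected in
    // proportion to its distance.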

    auto sample_one_point = [&]() {
      const int64_t sampled_index = draw_one_sample();
      min_distances = min_distances.cwiseMin(GetHalfSquaredDistancesToY(
          points, points_half_squared_norm, points.row(sampled_index),
          points_half_squared_norm(sampled_index)));
      return sampled_index;
    };

    auto sample_one_point_with_retries = [&]() {
      Eigen::VectorXf best_new_min_distances(num_points);
      float best_potential = std::numeric_limits<float>::infinity();
      int64_t best_sampled_index = 0;
      for (int i = 1 + num_retries_per_sample; i > 0; --i) {
        const int64_t sampled_index = draw_one_sample();
        Eigen::VectorXf new_min_distances =
            min_distances.cwiseMin(GetHalfSquaredDistancesToY(
                points, points_half_squared_norm, points.row(sampled_index),
                points_half_squared_norm(sampled_index)));
        const float potential = new_min_distances.sum();
        if (potential < best_potential) {
          best_potential = potential;
          best_sampled_index = sampled_index;
          best_new_min_distances.swap(new_min_distances);
        }
      }
      min_distances.swap(best_new_min_distances);
      return best_sampled_index;
    };

    for (int64_t i = 0; i < num_to_sample; ++i) {
      if (i > 0) {
        std::partial_sum(min_distances.data(),
                         min_distances.data() + num_points,
                         min_distances_cumsum.data());
      }
      int64_t next = num_retries_per_sample == 0
                         ? sample_one_point()
                         : sample_one_point_with_retries();
      add_one_point(next, i);
    }
  }

 private:
  // Returns a column vector whose i-th element is half the squared Euclidean
  // distance between the i-th row of xs and the row vector y. Precomputed
  // half squared norms of the rows of xs and of y must be provided for
  // efficiency.
  // TODO(ands): Parallelize this for large xs.
  static Eigen::VectorXf GetHalfSquaredDistancesToY(
      const Eigen::Ref<const MatrixXfRowMajor>& xs,
      const Eigen::Ref<const Eigen::VectorXf>& xs_half_squared_norm,
      const Eigen::Ref<const Eigen::RowVectorXf>& y,
      float y_half_squared_norm) {
    // Squared distance between points xs_i and y is:
    //   || xs_i ||^2 - 2 <xs_i, y> + || y ||^2
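    // so half of it is || xs_i ||^2 / 2 - <xs_i, y> + || y ||^2 / 2, which is
    // what the expression below assembles from the precomputed half squared
    // norms with a single matrix-vector product.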
    return (xs_half_squared_norm - xs * y.transpose()).array() +
           y_half_squared_norm;
  }
};

REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
                        KmeansPlusPlusInitializationOp);

// Implementation of a single Markov chain for the k-MC^2 algorithm.
class KMC2ChainInitializationOp : public OpKernel {
 public:
  explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& distances_tensor = context->input(0);
    const Tensor& seed_tensor = context->input(1);
    OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
                InvalidArgument("Input distances should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
                InvalidArgument("Input seed should be a scalar."));
    const int64_t num_points = distances_tensor.dim_size(0);
    const int64_t seed = seed_tensor.scalar<int64_t>()();
    OP_REQUIRES(context, num_points > 0,
                InvalidArgument("Expected distances_tensor.size() > 0."));

    random::PhiloxRandom random(seed);
    random::SimplePhilox rng(&random);

    auto distances = distances_tensor.flat<float>();
    // Set the initial state of the Markov chain to be the first candidate.
    int64_t selected_index = 0;
    float selected_distance = distances(selected_index);
    // Build a Markov chain of length num_points.
    for (int64_t i = 1; i < num_points; ++i) {
      const float candidate_distance = distances(i);
      // Set the next state of the Markov chain to be the candidate with
      // probability min(1, candidate_distance/selected_distance).
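      // Since rng.RandFloat() is uniform in [0, 1), the condition below holds
      // with exactly that probability: e.g. (illustrative values) with
      // selected_distance == 2.0 and candidate_distance == 1.0 the candidate
      // is accepted half of the time, while a candidate at least as distant
      // as the current state is always accepted.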
      if (candidate_distance > rng.RandFloat() * selected_distance) {
        selected_index = i;
        selected_distance = candidate_distance;
      }
    }

    Tensor* output_sampled_index_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}),
                                            &output_sampled_index_tensor));
    auto output = output_sampled_index_tensor->scalar<int64_t>();
    // Return the last state of the Markov chain as the new center.
    output() = selected_index;
  }
};

REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
                        KMC2ChainInitializationOp);

// Operator for computing the nearest neighbors for a set of points.
class NearestNeighborsOp : public OpKernel {
 public:
  explicit NearestNeighborsOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->MatchSignature({DT_FLOAT, DT_FLOAT, DT_INT64},
                                           {DT_INT64, DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& points_tensor = context->input(0);
    const Tensor& centers_tensor = context->input(1);
    const Tensor& k_tensor = context->input(2);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
                InvalidArgument("Input points should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(centers_tensor.shape()),
                InvalidArgument("Input centers should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_tensor.shape()),
                InvalidArgument("Input k should be a scalar."));

    const int64_t num_points = points_tensor.dim_size(0);
    const int64_t point_dimensions = points_tensor.dim_size(1);
    const int64_t num_centers = centers_tensor.dim_size(0);
    const int64_t center_dimensions = centers_tensor.dim_size(1);

    OP_REQUIRES(context, num_points > 0,
                InvalidArgument("Expected points.rows() > 0."));
    OP_REQUIRES(
        context, point_dimensions == center_dimensions,
        InvalidArgument("Expected point_dimensions == center_dimensions: ",
                        point_dimensions, " vs ", center_dimensions, "."));

    const Eigen::Map<const MatrixXfRowMajor> points(
        points_tensor.matrix<float>().data(), num_points, point_dimensions);
    const Eigen::Map<const MatrixXfRowMajor> centers(
        centers_tensor.matrix<float>().data(), num_centers, center_dimensions);
    const int64_t k =
        std::min<int64_t>(num_centers, k_tensor.scalar<int64_t>()());

    Tensor* output_nearest_center_indices_tensor;
    Tensor* output_nearest_center_distances_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, TensorShape({num_points, k}),
                                &output_nearest_center_indices_tensor));
    OP_REQUIRES_OK(context, context->allocate_output(
                                1, TensorShape({num_points, k}),
                                &output_nearest_center_distances_tensor));

    if (k == 0) return;

    Eigen::Map<MatrixXi64RowMajor> nearest_center_indices(
        output_nearest_center_indices_tensor->matrix<int64_t>().data(),
        num_points, k);
    Eigen::Map<MatrixXfRowMajor> nearest_center_distances(
        output_nearest_center_distances_tensor->matrix<float>().data(),
        num_points, k);

    const Eigen::VectorXf centers_half_squared_norm =
        0.5 * centers.rowwise().squaredNorm();

    // The distance computation is sharded to take advantage of multiple cores
    // and to allow intermediate values to reside in L3 cache. This is done by
    // sharding the points and centers as follows:
    //
    // 1. Centers are sharded such that each block of centers has at most
    //    kNearestNeighborsCentersMaxBlockSize rows.
    // 2. Points are sharded, and each block of points is multiplied with each
    //    block of centers. The block size of points is chosen such that the
    //    point coordinates (point_dimensions) and the matrix of distances to
    //    each center in one block -- the intermediate data -- fits in L3 cache.
    // 3. After performing each block-block distance computation, the results
    //    are reduced to a set of k nearest centers as soon as possible. This
    //    decreases total memory I/O.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    const int64_t num_threads = worker_threads.num_threads;
    // This kernel might be configured to use fewer than the total number of
    // available CPUs on the host machine. To avoid destructive interference
    // with other jobs running on the host machine, we must only use a fraction
    // of total available L3 cache. Unfortunately, we cannot query the host
    // machine to get the number of physical CPUs. So, we use a fixed per-CPU
    // budget and scale it by the number of CPUs available to this operation.
    const int64_t total_memory_budget =
        kDefaultL3CachePerCpu * port::NumSchedulableCPUs();
    // Compute the number of blocks into which rows of points must be split so
    // that the distance matrix and the block of points can fit in cache. One
    // row of points will yield a vector of distances to each center in a block.
    const int64_t bytes_per_row =
        (std::min(kNearestNeighborsCentersMaxBlockSize,
                  num_centers) /* centers in a block */
         + point_dimensions /* coordinates of one point */) *
        sizeof(float);
    // The memory needed for storing the centers being processed. This is
    // shared by all workers. We add slack to the number of threads to avoid
    // premature cache eviction when a new block of centers is loaded.
    const int64_t bytes_for_centers =
        std::min(num_centers,
                 (num_threads + 2) * kNearestNeighborsCentersMaxBlockSize) *
        point_dimensions * sizeof(float);
    // The memory budget available for workers to store their distance matrices.
    const int64_t available_memory_budget =
        total_memory_budget - bytes_for_centers;
    // That memory budget is shared by all threads.
    const int64_t rows_per_block = std::max<int64_t>(
        kNearestNeighborsPointsMinBlockSize,
        available_memory_budget / num_threads / bytes_per_row);
    // Divide rows into almost uniformly-sized units of work that are small
    // enough for the memory budget (rows_per_block). Round up to a multiple of
    // the number of threads.
    const int64_t num_units =
        NextMultiple(num_threads, CeilOfRatio(num_points, rows_per_block));
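
    // Illustrative sizing (assumed example values): with 8 schedulable CPUs
    // (and num_threads == 8) the budget is 8 MiB; for 1,000,000 points and
    // 10,000 centers of dimension 100:
    //   bytes_per_row     == (1024 + 100) * 4 == 4,496
    //   bytes_for_centers == 10,000 * 100 * 4 == 4,000,000
    //   rows_per_block    == max(16, 4,388,608 / 8 / 4,496) == 122
    //   num_units         == NextMultiple(8, CeilOfRatio(1,000,000, 122)) == 8,200
    // so each unit of work scores roughly 122 points against the centers.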
    auto work = [&](int64_t start, int64_t limit) {
      for (; start < limit; ++start) {
        const int64_t start_row = num_points * start / num_units;
        const int64_t limit_row = num_points * (start + 1) / num_units;
        DCHECK_LE(limit_row, num_points);
        const int64_t num_rows = limit_row - start_row;
        auto points_shard = points.middleRows(start_row, num_rows);
        const Eigen::VectorXf points_half_squared_norm =
            0.5 * points_shard.rowwise().squaredNorm();
        auto nearest_center_indices_shard =
            nearest_center_indices.middleRows(start_row, num_rows);
        auto nearest_center_distances_shard =
            nearest_center_distances.middleRows(start_row, num_rows);
        FindKNearestCenters(k, points_shard, points_half_squared_norm, centers,
                            centers_half_squared_norm,
                            nearest_center_indices_shard,
                            nearest_center_distances_shard);
      }
    };

    const int64_t units_per_thread = num_units / num_threads;
    BlockingCounter counter(num_threads - 1);
    for (int64_t i = 1; i < num_threads; ++i) {
      const int64_t start = i * units_per_thread;
      const int64_t limit = start + units_per_thread;
      worker_threads.workers->Schedule([work, &counter, start, limit]() {
        work(start, limit);
        counter.DecrementCount();
      });
    }
    work(0, units_per_thread);
    counter.Wait();
  }

 private:
  static void FindKNearestCenters(
      int64_t k, const Eigen::Ref<const MatrixXfRowMajor>& points,
      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
      const Eigen::Ref<const MatrixXfRowMajor>& centers,
      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
      const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
      const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
    DCHECK_LE(k, centers.rows());
    if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
      FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
                                  centers_half_squared_norm,
                                  nearest_center_indices,
                                  nearest_center_distances);
    } else {
      FindKNearestCentersBlockwise(k, points, points_half_squared_norm, centers,
                                   centers_half_squared_norm,
                                   nearest_center_indices,
                                   nearest_center_distances);
    }
  }

  static void FindKNearestCentersOneBlock(
      int64_t k, const Eigen::Ref<const MatrixXfRowMajor>& points,
      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
      const Eigen::Ref<const MatrixXfRowMajor>& centers,
      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
      Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
      Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
    DCHECK_LE(k, centers.rows());
    const int64_t num_points = points.rows();
    const MatrixXfRowMajor inner_product = points * centers.transpose();
    // Find nearest neighbors.
    if (k == 1) {
      for (int i = 0; i < num_points; ++i) {
        int64_t index;
        nearest_center_distances(i, 0) =
            2.0 *
            (points_half_squared_norm(i) +
             (centers_half_squared_norm.transpose() - inner_product.row(i))
                 .minCoeff(&index));
        nearest_center_indices(i, 0) = index;
      }
    } else {
      // Select k nearest centers for each point.
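      // Ranking centers by centers_half_squared_norm(j) - <point_i, center_j>
      // yields the same order as ranking by the full squared distance, since
      // the point's own norm is constant across centers; it is added back
      // (and the sum doubled) only when the distances are written out below.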
      using Center = std::pair<float, int64_t>;
      const int64_t num_centers = centers.rows();
      gtl::TopN<Center, std::less<Center>> selector(k);
      std::unique_ptr<std::vector<Center>> nearest_centers;
      for (int i = 0; i < num_points; ++i) {
        selector.reserve(num_centers);
        for (int j = 0; j < num_centers; ++j) {
          const float partial_distance =
              centers_half_squared_norm(j) - inner_product(i, j);
          selector.push(Center(partial_distance, j));
        }
        nearest_centers.reset(selector.Extract());
        selector.Reset();
        const float point_half_squared_norm = points_half_squared_norm(i);
        for (int j = 0; j < k; ++j) {
          const Center& center = (*nearest_centers)[j];
          nearest_center_distances(i, j) =
              2.0 * (point_half_squared_norm + center.first);
          nearest_center_indices(i, j) = center.second;
        }
      }
    }
  }

  static void FindKNearestCentersBlockwise(
      int64_t k, const Eigen::Ref<const MatrixXfRowMajor>& points,
      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
      const Eigen::Ref<const MatrixXfRowMajor>& centers,
      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
      Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
      Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
    const int64_t num_points = points.rows();
    const int64_t num_centers = centers.rows();
    DCHECK_LE(k, num_centers);
    DCHECK_GT(num_centers, kNearestNeighborsCentersMaxBlockSize);
    // Store nearest neighbors with first block of centers directly into the
    // output matrices. out_k tracks how many valid entries each output row
    // currently holds.
    int64_t out_k = std::min(k, kNearestNeighborsCentersMaxBlockSize);
    FindKNearestCentersOneBlock(
        out_k, points, points_half_squared_norm,
        centers.topRows(kNearestNeighborsCentersMaxBlockSize),
        centers_half_squared_norm.head(kNearestNeighborsCentersMaxBlockSize),
        nearest_center_indices, nearest_center_distances);
    // Iteratively compute nearest neighbors with other blocks of centers, and
    // update the output matrices.
    MatrixXi64RowMajor block_nearest_center_indices(num_points, k);
    MatrixXfRowMajor block_nearest_center_distances(num_points, k);
    Eigen::Matrix<int64_t, 1, Eigen::Dynamic> merged_indices(k);
    Eigen::Matrix<float, 1, Eigen::Dynamic> merged_distances(k);
    for (int64_t centers_start = kNearestNeighborsCentersMaxBlockSize;
         centers_start < num_centers;
         centers_start += kNearestNeighborsCentersMaxBlockSize) {
      const int64_t centers_block_size = std::min(
          kNearestNeighborsCentersMaxBlockSize, num_centers - centers_start);
      const int64_t block_k = std::min(k, centers_block_size);
      FindKNearestCentersOneBlock(
          block_k, points, points_half_squared_norm,
          centers.middleRows(centers_start, centers_block_size),
          centers_half_squared_norm.segment(centers_start, centers_block_size),
          block_nearest_center_indices, block_nearest_center_distances);
      if (k == 1) {
        for (int i = 0; i < num_points; ++i) {
          if (block_nearest_center_distances(i, 0) <
              nearest_center_distances(i, 0)) {
            nearest_center_indices(i, 0) =
                block_nearest_center_indices(i, 0) + centers_start;
            nearest_center_distances(i, 0) =
                block_nearest_center_distances(i, 0);
          }
        }
      } else {
        for (int i = 0; i < num_points; ++i) {
          // Merge and accumulate top-k list from block_nearest_center_indices
          // into nearest_center_indices.
          for (int64_t j_out = 0, j_block = 0, j_merged = 0;
               (j_out < out_k || j_block < block_k) && j_merged < k;
               ++j_merged) {
            const float distance_out =
                j_out < out_k ? nearest_center_distances(i, j_out)
                              : std::numeric_limits<float>::infinity();
            const float distance_block =
                j_block < block_k ? block_nearest_center_distances(i, j_block)
                                  : std::numeric_limits<float>::infinity();
            if (distance_out <= distance_block) {
              merged_indices(j_merged) = nearest_center_indices(i, j_out);
              merged_distances(j_merged) = distance_out;
              ++j_out;
            } else {
              merged_indices(j_merged) =
                  block_nearest_center_indices(i, j_block) + centers_start;
              merged_distances(j_merged) = distance_block;
              ++j_block;
            }
          }
          nearest_center_indices.row(i) = merged_indices;
          nearest_center_distances.row(i) = merged_distances;
        }
      }
      // Update the count of valid entries per row once per block of centers
      // (not per point), so that every row is treated consistently.
      out_k = std::min(k, out_k + block_k);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("NearestNeighbors").Device(DEVICE_CPU),
                        NearestNeighborsOp);

}  // namespace tensorflow