/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/approx_topk_shape.h"

#include <algorithm>
#include <cmath>
#include <utility>

#include "tensorflow/compiler/xla/util.h"

// Used by rank 2+ operands.
const uint64_t kTpuLaneTiling = 128;
// Used by rank 1 operands.
const uint64_t kTpuChunkTiling = 1024;

namespace xla {

// Log2 helpers that map an input of 0 to 0, so callers do not need to
// special-case empty sizes.
inline uint32_t log2_floor(uint64_t value) {
  return value == 0 ? 0 : Log2Floor(value);
}

inline uint32_t log2_ceil(uint64_t value) {
  return value == 0 ? 0 : Log2Ceiling(value);
}

// LINT.IfChange
StatusOr<std::pair<int64_t, int64_t>> ApproxTopKReductionOutputSize(
    int64_t input_size, int64_t rank, int64_t top_k, float recall_target,
    bool aggregate_to_topk, int64_t input_size_override) {
  if (aggregate_to_topk) {
    return std::pair<int64_t, int64_t>(top_k, -1);
  }

  uint64_t tpu_tiling = rank == 1 ? kTpuChunkTiling : kTpuLaneTiling;

  if (input_size <= tpu_tiling) {
    return std::pair<int64_t, int64_t>(input_size, 0);
  }

  if (input_size_override >= 0) {
    if (input_size > input_size_override) {
      return InvalidArgument(
          "reduction_input_size_override: %d should be greater than or "
          "equal to operands[reduction_dim]: %d",
          input_size_override, input_size);
    }
  }
  uint64_t logical_input_size =
      input_size_override >= 0 ? input_size_override : input_size;

  // Reduce to the tiling size when k == 1.
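  // For example, a rank-2 operand (tpu_tiling = 128) with
  // logical_input_size = 4096 gets log2_reduction = log2_ceil(4096 / 128) = 5
  // and an output of a single 128-element tile.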
  if (top_k == 1) {
    uint32_t log2_reduction =
        log2_ceil(CeilOfRatio(logical_input_size, tpu_tiling));
    return std::pair<int64_t, int64_t>(tpu_tiling, log2_reduction);
  }

  // Handle recall_target == 1.0 explicitly; otherwise we would divide by
  // log(1.0) = 0 below.
  if (recall_target == 1.0) {
    return std::pair<int64_t, int64_t>(input_size, 0);
  }

  if (recall_target <= 0. || recall_target > 1.0) {
    return InvalidArgument("recall_target should be in the range (0, 1]");
  }

  // Given N data points, K for the top-k elements, and W for the size of the
  // reduce window, let M = Ceil(N / W) be the number of windows. The expected
  // number of top-k elements that do not collide with another top-k element
  // in the same window is
  //
  //   K * ((M - 1) / M)^{K - 1}
  //
  // The recall is the expected number of non-colliding top-k elements divided
  // by K:
  //
  //   recall = ((M - 1) / M)^{K - 1}
  //          = (1 - 1/M)^{K - 1}
  //          = (1 - 1/M)^{-M * (K - 1)/(-M)}
  //          ~= EXP((1 - K) / M)    for large M
  //
  //   => M = (1 - K)/LOG(recall)
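  //
  // For example, with N = 2^20, K = 100, recall_target = 0.95, and a rank-2
  // operand (tpu_tiling = 128): M = (1 - 100) / log(0.95) ~= 1930, which
  // already lies in [tpu_tiling, input_size], so the clamp below leaves it
  // unchanged and log2_reduction = log2_floor(2^20 / 1930) = log2_floor(543)
  // = 9.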
  uint64_t m = std::min<uint64_t>(
      std::max(
          static_cast<uint64_t>((1.0 - top_k) /
                                std::log(static_cast<double>(recall_target))),
          tpu_tiling),
      input_size);
  uint32_t log2_reduction = log2_floor(logical_input_size / m);
  if (log2_reduction == 0) {
    return std::pair<int64_t, int64_t>(input_size, 0);
  }

  // Do not reduce too much when logical_input_size is large: cap the
  // reduction so that the actual input is not reduced below a single tile.
  log2_reduction =
      std::min<uint32_t>(log2_reduction, log2_ceil(input_size / tpu_tiling));

  int64_t approx_output_size =
      CeilOfRatio<int64_t>(CeilOfRatio<int64_t>(input_size, tpu_tiling),
                           (1 << log2_reduction)) *
      tpu_tiling;

  return std::pair<int64_t, int64_t>(approx_output_size, log2_reduction);
}
// LINT.ThenChange(//tensorflow/core/ops/nn_ops.cc)
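
// Usage sketch (hypothetical caller, for illustration only):
//
//   TF_ASSIGN_OR_RETURN(
//       auto reduction,
//       ApproxTopKReductionOutputSize(/*input_size=*/1 << 20, /*rank=*/2,
//                                     /*top_k=*/100, /*recall_target=*/0.95f,
//                                     /*aggregate_to_topk=*/false,
//                                     /*input_size_override=*/-1));
//   int64_t approx_output_size = reduction.first;   // 2048 for these inputs.
//   int64_t log2_reduction = reduction.second;      // 9 for these inputs.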

}  // namespace xla