xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/approx_topk_shape.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_
18 
#include <cstdint>
#include <utility>

#include "tensorflow/compiler/xla/statusor.h"
20 
21 namespace xla {
22 
23 // Determine the output size of the reduction dimension. This is useful for jax
24 // abstract eval to determine the output size.
25 //
26 // input_size: Input size of the reduction dimension.
27 // rank: Rank of the input operand.
28 // top_k: Determines the k in top-k operation.
29 // recall_target: Valid range (0, 1]. User can trade-off quality and performance
30 //   with this knob.
31 // aggregate_to_topk: When true, sorts the set of approximate top-k elements and
32 //   only keep the final k elements on TPU. This option is useful when user
33 //   wanted to forward the approximate results to host and aggregate the results
34 //   on CPU for better throughput.
35 //
36 // Returns a pair of
37 //   1. Reduction output size
38 //   2. Reduction amount in log2 form.
39 //
40 // 2. is invalid and set to -1 when the approximate output is disabled, i.e.
41 //   top_k = 1 or aggregate_to_topk = true.
42 StatusOr<std::pair<int64_t, int64_t>> ApproxTopKReductionOutputSize(
43     int64_t input_size, int64_t rank, int64_t top_k, float recall_target,
44     bool aggregate_to_topk, int64_t input_size_override = -1);
45 
46 }  // namespace xla
47 
48 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_
49