xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/approx_topk.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/approx_topk.h"
17 
18 #include <limits>
19 #include <string>
20 
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/compiler/xla/client/lib/approx_topk_shape.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/shape.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 
// Tiling threshold applied when the operands have rank 2 or higher.
constexpr uint64_t kTpuLaneTiling = 128;
// Tiling threshold applied when the operands have rank 1.
constexpr uint64_t kTpuChunkTiling = 1024;
35 
36 namespace xla {
37 
38 namespace {
GetOperandTypes(XlaBuilder * builder,absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values)39 StatusOr<std::vector<PrimitiveType>> GetOperandTypes(
40     XlaBuilder* builder, absl::Span<const XlaOp> operands,
41     absl::Span<const XlaOp> init_values) {
42   std::vector<PrimitiveType> op_types;
43   auto num_operands = operands.size();
44   auto operands_shapes = builder->GetOperandShapes(operands).ValueOrDie();
45   auto init_values_shapes = builder->GetOperandShapes(init_values).ValueOrDie();
46   for (int i = 0; i < num_operands; ++i) {
47     const auto& op_shape = operands_shapes[i];
48     const auto& init_shape = init_values_shapes[i];
49     if (op_shape.rank() == 0) {
50       return InvalidArgument("ApproxTopK operands must have rank 1+.");
51     }
52     if (!ShapeUtil::CompatibleIgnoringElementType(operands_shapes[0],
53                                                   op_shape)) {
54       return InvalidArgument("operands shape mismatch: %s vs %s",
55                              operands_shapes[0].DebugString(),
56                              op_shape.DebugString());
57     }
58     if (op_shape.element_type() != init_shape.element_type()) {
59       return InvalidArgument("operands type mismatch: %s vs %s",
60                              op_shape.DebugString(), init_shape.DebugString());
61     }
62     op_types.push_back(op_shape.element_type());
63   }
64   return op_types;
65 }
66 }  // namespace
67 
68 // Converts a comparator to a combiner computation that can be fed to reduce or
69 // partial reduce ops.
BuildReductionComputation(XlaBuilder * builder,absl::Span<const PrimitiveType> op_types,const XlaComputation & comparator)70 XlaComputation BuildReductionComputation(
71     XlaBuilder* builder, absl::Span<const PrimitiveType> op_types,
72     const XlaComputation& comparator) {
73   auto num_operands = op_types.size();
74   std::vector<XlaOp> lhs_params;
75   std::vector<XlaOp> rhs_params;
76   int64_t param_number = 0;
77   lhs_params.reserve(num_operands);
78   rhs_params.reserve(num_operands);
79   auto reduction_builder = builder->CreateSubBuilder("ReductionFn");
80   for (const auto& op_type : op_types) {
81     lhs_params.push_back(Parameter(reduction_builder.get(), param_number,
82                                    ShapeUtil::MakeScalarShape(op_type),
83                                    absl::StrFormat("lhs.%d", param_number)));
84     param_number++;
85   }
86   for (const auto& op_type : op_types) {
87     rhs_params.push_back(Parameter(reduction_builder.get(), param_number,
88                                    ShapeUtil::MakeScalarShape(op_type),
89                                    absl::StrFormat("rhs.%d", param_number)));
90     param_number++;
91   }
92 
93   std::vector<XlaOp> comparator_args;
94   comparator_args.reserve(num_operands * 2);
95   for (int i = 0; i < num_operands; ++i) {
96     comparator_args.push_back(lhs_params[i]);
97     comparator_args.push_back(rhs_params[i]);
98   }
99   auto pred = Call(reduction_builder.get(), comparator, comparator_args);
100   std::vector<XlaOp> results;
101   results.reserve(num_operands);
102   for (int i = 0; i < num_operands; ++i) {
103     results.push_back(Select(pred, lhs_params[i], rhs_params[i]));
104   }
105   Tuple(reduction_builder.get(), results);
106   return reduction_builder->BuildAndNoteError();
107 }
108 
AggregateToTopKBuilder(XlaBuilder * builder,absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,int64_t top_k,int64_t reduction_dim,const XlaComputation & comparator)109 XlaOp AggregateToTopKBuilder(XlaBuilder* builder,
110                              absl::Span<const XlaOp> operands,
111                              absl::Span<const XlaOp> init_values, int64_t top_k,
112                              int64_t reduction_dim,
113                              const XlaComputation& comparator) {
114   auto operands_shapes = builder->GetOperandShapes(operands).ValueOrDie();
115   int64_t rank = operands_shapes[0].rank();
116   int64_t num_operands = operands.size();
117 
118   if (top_k == 1) {
119     auto status_or_optypes = GetOperandTypes(builder, operands, init_values);
120     if (!status_or_optypes.ok()) {
121       return builder->ReportError(status_or_optypes.status());
122     }
123     auto op_types = status_or_optypes.value();
124 
125     auto reduction_computation =
126         BuildReductionComputation(builder, op_types, comparator);
127     auto val_args = Reduce(builder, operands, init_values,
128                            reduction_computation, {reduction_dim});
129     Shape op_shape = operands_shapes[0];
130     op_shape.mutable_dimensions()[reduction_dim] = 1;
131     auto top1_vals =
132         Reshape(GetTupleElement(val_args, 0), op_shape.dimensions());
133     auto top1_args =
134         Reshape(GetTupleElement(val_args, 1), op_shape.dimensions());
135     return Tuple(builder, {top1_vals, top1_args});
136   }
137 
138   auto sorted_results = Sort(operands, comparator, reduction_dim);
139   std::vector<int64_t> slice_start_indices(rank, 0);
140   std::vector<int64_t> slice_limit_indices;
141   std::vector<int64_t> slice_strides(rank, 1);
142   slice_limit_indices.insert(slice_limit_indices.begin(),
143                              operands_shapes[0].dimensions().begin(),
144                              operands_shapes[0].dimensions().end());
145   slice_limit_indices[reduction_dim] = top_k;
146 
147   std::vector<XlaOp> sliced_results;
148   sliced_results.reserve(num_operands);
149   for (int i = 0; i < num_operands; ++i) {
150     sliced_results.push_back(Slice(GetTupleElement(sorted_results, i),
151                                    slice_start_indices, slice_limit_indices,
152                                    slice_strides));
153   }
154   return Tuple(builder, sliced_results);
155 }
156 
ApproxTopK(XlaBuilder * builder,absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,int64_t top_k,int64_t reduction_dim,const XlaComputation & comparator,float recall_target,bool aggregate_to_topk,int64_t reduction_input_size_override)157 XlaOp ApproxTopK(XlaBuilder* builder, absl::Span<const XlaOp> operands,
158                  absl::Span<const XlaOp> init_values, int64_t top_k,
159                  int64_t reduction_dim, const XlaComputation& comparator,
160                  float recall_target, bool aggregate_to_topk,
161                  int64_t reduction_input_size_override) {
162   // Validates shapes and ranks
163   if (operands.size() != init_values.size()) {
164     return builder->ReportError(
165         InvalidArgument("operands and init_values size mismatch: %d vs %d",
166                         operands.size(), init_values.size()));
167   }
168   auto num_operands = operands.size();
169   auto operands_shapes = builder->GetOperandShapes(operands).ValueOrDie();
170   auto init_values_shapes = builder->GetOperandShapes(init_values).ValueOrDie();
171   auto status_or_optypes = GetOperandTypes(builder, operands, init_values);
172   if (!status_or_optypes.ok()) {
173     return builder->ReportError(status_or_optypes.status());
174   }
175   auto op_types = status_or_optypes.value();
176   int64_t rank = operands_shapes[0].rank();
177   if (reduction_dim < 0 || reduction_dim >= rank) {
178     return builder->ReportError(
179         InvalidArgument("reduction_dim should range in [0,%d)", rank));
180   }
181 
182   auto reduction_computation =
183       BuildReductionComputation(builder, op_types, comparator);
184 
185   uint64_t tpu_tiling = rank == 1 ? kTpuChunkTiling : kTpuLaneTiling;
186   uint64_t n = operands_shapes[0].dimensions(reduction_dim);
187   // ApproxTopK can only reduce elements larger than the tiling.
188   if (n <= tpu_tiling) {
189     if (aggregate_to_topk) {
190       return AggregateToTopKBuilder(builder, operands, init_values, top_k,
191                                     reduction_dim, comparator);
192     }
193     return Tuple(builder, operands);
194   }
195 
196   auto status_or_approx_output_size = ApproxTopKReductionOutputSize(
197       n, rank, top_k, recall_target, /*aggregate_to_topk=*/false,
198       reduction_input_size_override);
199   if (!status_or_approx_output_size.status().ok()) {
200     return builder->ReportError(status_or_approx_output_size.status());
201   }
202 
203   int64_t approx_output_size, log2_reduction;
204   std::tie(approx_output_size, log2_reduction) =
205       status_or_approx_output_size.ValueOrDie();
206 
207   if (log2_reduction == 0) {
208     if (aggregate_to_topk) {
209       return AggregateToTopKBuilder(builder, operands, init_values, top_k,
210                                     reduction_dim, comparator);
211     }
212     return Tuple(builder, operands);
213   }
214 
215   std::vector<XlaOp> partial_reduce_args;
216   partial_reduce_args.reserve(operands.size() + init_values.size());
217   for (const auto& op : operands) {
218     partial_reduce_args.push_back(op);
219   }
220   for (const auto& op : init_values) {
221     partial_reduce_args.push_back(op);
222   }
223   std::vector<const Shape*> approx_output_shapes;
224   approx_output_shapes.reserve(operands_shapes.size());
225   for (auto& op_shape : operands_shapes) {
226     op_shape.mutable_dimensions()[reduction_dim] = approx_output_size;
227     approx_output_shapes.push_back(&op_shape);
228   }
229   auto approx_output_shape =
230       ShapeUtil::MakeTupleShapeWithPtrs(approx_output_shapes);
231   // PartialReduce option in JSON form
232   std::string partial_reduce_option = absl::StrFormat(
233       "{\"log2_reduction\": %d, \"reduction_dim\": %d, \"to_apply_type\": "
234       "\"comparator\"}",
235       log2_reduction, reduction_dim);
236 
237   auto approx_topk = CustomCallWithComputation(
238       builder, "PartialReduce", partial_reduce_args, comparator,
239       approx_output_shape, partial_reduce_option);
240 
241   if (aggregate_to_topk) {
242     std::vector<XlaOp> approx_topk_results;
243     approx_topk_results.reserve(num_operands);
244     for (int i = 0; i < num_operands; ++i) {
245       approx_topk_results.push_back(GetTupleElement(approx_topk, i));
246     }
247     return AggregateToTopKBuilder(builder, approx_topk_results, init_values,
248                                   top_k, reduction_dim, comparator);
249   }
250   return approx_topk;
251 }
252 
ApproxTopKFallback(XlaBuilder * builder,absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,int64_t top_k,int64_t reduction_dim,const XlaComputation & comparator,float recall_target,bool aggregate_to_topk,int64_t reduction_input_size_override)253 XlaOp ApproxTopKFallback(XlaBuilder* builder, absl::Span<const XlaOp> operands,
254                          absl::Span<const XlaOp> init_values, int64_t top_k,
255                          int64_t reduction_dim,
256                          const XlaComputation& comparator, float recall_target,
257                          bool aggregate_to_topk,
258                          int64_t reduction_input_size_override) {
259   auto operands_shapes = builder->GetOperandShapes(operands).ValueOrDie();
260   int64_t rank = operands_shapes[0].rank();
261   uint64_t n = operands_shapes[0].dimensions(reduction_dim);
262   // Align the output size with ApproxTopK.
263   auto status_or_approx_output_size = ApproxTopKReductionOutputSize(
264       n, rank, top_k, recall_target, aggregate_to_topk,
265       reduction_input_size_override);
266   if (!status_or_approx_output_size.ok()) {
267     return builder->ReportError(status_or_approx_output_size.status());
268   }
269   auto output_size = status_or_approx_output_size.value().first;
270   return AggregateToTopKBuilder(builder, operands, init_values, output_size,
271                                 reduction_dim, comparator);
272 }
273 
274 }  // namespace xla
275