/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/approx_topk.h"
17
18 #include <limits>
19 #include <string>
20
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/compiler/xla/client/lib/approx_topk_shape.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/shape.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30
// TPU tiling size along the reduction dimension for rank 2+ operands.
constexpr uint64_t kTpuLaneTiling = 128;
// TPU tiling size along the reduction dimension for rank 1 operands.
constexpr uint64_t kTpuChunkTiling = 1024;
35
36 namespace xla {
37
38 namespace {
GetOperandTypes(XlaBuilder * builder,absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values)39 StatusOr<std::vector<PrimitiveType>> GetOperandTypes(
40 XlaBuilder* builder, absl::Span<const XlaOp> operands,
41 absl::Span<const XlaOp> init_values) {
42 std::vector<PrimitiveType> op_types;
43 auto num_operands = operands.size();
44 auto operands_shapes = builder->GetOperandShapes(operands).ValueOrDie();
45 auto init_values_shapes = builder->GetOperandShapes(init_values).ValueOrDie();
46 for (int i = 0; i < num_operands; ++i) {
47 const auto& op_shape = operands_shapes[i];
48 const auto& init_shape = init_values_shapes[i];
49 if (op_shape.rank() == 0) {
50 return InvalidArgument("ApproxTopK operands must have rank 1+.");
51 }
52 if (!ShapeUtil::CompatibleIgnoringElementType(operands_shapes[0],
53 op_shape)) {
54 return InvalidArgument("operands shape mismatch: %s vs %s",
55 operands_shapes[0].DebugString(),
56 op_shape.DebugString());
57 }
58 if (op_shape.element_type() != init_shape.element_type()) {
59 return InvalidArgument("operands type mismatch: %s vs %s",
60 op_shape.DebugString(), init_shape.DebugString());
61 }
62 op_types.push_back(op_shape.element_type());
63 }
64 return op_types;
65 }
66 } // namespace
67
68 // Converts a comparator to a combiner computation that can be fed to reduce or
69 // partial reduce ops.
BuildReductionComputation(XlaBuilder * builder,absl::Span<const PrimitiveType> op_types,const XlaComputation & comparator)70 XlaComputation BuildReductionComputation(
71 XlaBuilder* builder, absl::Span<const PrimitiveType> op_types,
72 const XlaComputation& comparator) {
73 auto num_operands = op_types.size();
74 std::vector<XlaOp> lhs_params;
75 std::vector<XlaOp> rhs_params;
76 int64_t param_number = 0;
77 lhs_params.reserve(num_operands);
78 rhs_params.reserve(num_operands);
79 auto reduction_builder = builder->CreateSubBuilder("ReductionFn");
80 for (const auto& op_type : op_types) {
81 lhs_params.push_back(Parameter(reduction_builder.get(), param_number,
82 ShapeUtil::MakeScalarShape(op_type),
83 absl::StrFormat("lhs.%d", param_number)));
84 param_number++;
85 }
86 for (const auto& op_type : op_types) {
87 rhs_params.push_back(Parameter(reduction_builder.get(), param_number,
88 ShapeUtil::MakeScalarShape(op_type),
89 absl::StrFormat("rhs.%d", param_number)));
90 param_number++;
91 }
92
93 std::vector<XlaOp> comparator_args;
94 comparator_args.reserve(num_operands * 2);
95 for (int i = 0; i < num_operands; ++i) {
96 comparator_args.push_back(lhs_params[i]);
97 comparator_args.push_back(rhs_params[i]);
98 }
99 auto pred = Call(reduction_builder.get(), comparator, comparator_args);
100 std::vector<XlaOp> results;
101 results.reserve(num_operands);
102 for (int i = 0; i < num_operands; ++i) {
103 results.push_back(Select(pred, lhs_params[i], rhs_params[i]));
104 }
105 Tuple(reduction_builder.get(), results);
106 return reduction_builder->BuildAndNoteError();
107 }
108
AggregateToTopKBuilder(XlaBuilder * builder,absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,int64_t top_k,int64_t reduction_dim,const XlaComputation & comparator)109 XlaOp AggregateToTopKBuilder(XlaBuilder* builder,
110 absl::Span<const XlaOp> operands,
111 absl::Span<const XlaOp> init_values, int64_t top_k,
112 int64_t reduction_dim,
113 const XlaComputation& comparator) {
114 auto operands_shapes = builder->GetOperandShapes(operands).ValueOrDie();
115 int64_t rank = operands_shapes[0].rank();
116 int64_t num_operands = operands.size();
117
118 if (top_k == 1) {
119 auto status_or_optypes = GetOperandTypes(builder, operands, init_values);
120 if (!status_or_optypes.ok()) {
121 return builder->ReportError(status_or_optypes.status());
122 }
123 auto op_types = status_or_optypes.value();
124
125 auto reduction_computation =
126 BuildReductionComputation(builder, op_types, comparator);
127 auto val_args = Reduce(builder, operands, init_values,
128 reduction_computation, {reduction_dim});
129 Shape op_shape = operands_shapes[0];
130 op_shape.mutable_dimensions()[reduction_dim] = 1;
131 auto top1_vals =
132 Reshape(GetTupleElement(val_args, 0), op_shape.dimensions());
133 auto top1_args =
134 Reshape(GetTupleElement(val_args, 1), op_shape.dimensions());
135 return Tuple(builder, {top1_vals, top1_args});
136 }
137
138 auto sorted_results = Sort(operands, comparator, reduction_dim);
139 std::vector<int64_t> slice_start_indices(rank, 0);
140 std::vector<int64_t> slice_limit_indices;
141 std::vector<int64_t> slice_strides(rank, 1);
142 slice_limit_indices.insert(slice_limit_indices.begin(),
143 operands_shapes[0].dimensions().begin(),
144 operands_shapes[0].dimensions().end());
145 slice_limit_indices[reduction_dim] = top_k;
146
147 std::vector<XlaOp> sliced_results;
148 sliced_results.reserve(num_operands);
149 for (int i = 0; i < num_operands; ++i) {
150 sliced_results.push_back(Slice(GetTupleElement(sorted_results, i),
151 slice_start_indices, slice_limit_indices,
152 slice_strides));
153 }
154 return Tuple(builder, sliced_results);
155 }
156
ApproxTopK(XlaBuilder * builder,absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,int64_t top_k,int64_t reduction_dim,const XlaComputation & comparator,float recall_target,bool aggregate_to_topk,int64_t reduction_input_size_override)157 XlaOp ApproxTopK(XlaBuilder* builder, absl::Span<const XlaOp> operands,
158 absl::Span<const XlaOp> init_values, int64_t top_k,
159 int64_t reduction_dim, const XlaComputation& comparator,
160 float recall_target, bool aggregate_to_topk,
161 int64_t reduction_input_size_override) {
162 // Validates shapes and ranks
163 if (operands.size() != init_values.size()) {
164 return builder->ReportError(
165 InvalidArgument("operands and init_values size mismatch: %d vs %d",
166 operands.size(), init_values.size()));
167 }
168 auto num_operands = operands.size();
169 auto operands_shapes = builder->GetOperandShapes(operands).ValueOrDie();
170 auto init_values_shapes = builder->GetOperandShapes(init_values).ValueOrDie();
171 auto status_or_optypes = GetOperandTypes(builder, operands, init_values);
172 if (!status_or_optypes.ok()) {
173 return builder->ReportError(status_or_optypes.status());
174 }
175 auto op_types = status_or_optypes.value();
176 int64_t rank = operands_shapes[0].rank();
177 if (reduction_dim < 0 || reduction_dim >= rank) {
178 return builder->ReportError(
179 InvalidArgument("reduction_dim should range in [0,%d)", rank));
180 }
181
182 auto reduction_computation =
183 BuildReductionComputation(builder, op_types, comparator);
184
185 uint64_t tpu_tiling = rank == 1 ? kTpuChunkTiling : kTpuLaneTiling;
186 uint64_t n = operands_shapes[0].dimensions(reduction_dim);
187 // ApproxTopK can only reduce elements larger than the tiling.
188 if (n <= tpu_tiling) {
189 if (aggregate_to_topk) {
190 return AggregateToTopKBuilder(builder, operands, init_values, top_k,
191 reduction_dim, comparator);
192 }
193 return Tuple(builder, operands);
194 }
195
196 auto status_or_approx_output_size = ApproxTopKReductionOutputSize(
197 n, rank, top_k, recall_target, /*aggregate_to_topk=*/false,
198 reduction_input_size_override);
199 if (!status_or_approx_output_size.status().ok()) {
200 return builder->ReportError(status_or_approx_output_size.status());
201 }
202
203 int64_t approx_output_size, log2_reduction;
204 std::tie(approx_output_size, log2_reduction) =
205 status_or_approx_output_size.ValueOrDie();
206
207 if (log2_reduction == 0) {
208 if (aggregate_to_topk) {
209 return AggregateToTopKBuilder(builder, operands, init_values, top_k,
210 reduction_dim, comparator);
211 }
212 return Tuple(builder, operands);
213 }
214
215 std::vector<XlaOp> partial_reduce_args;
216 partial_reduce_args.reserve(operands.size() + init_values.size());
217 for (const auto& op : operands) {
218 partial_reduce_args.push_back(op);
219 }
220 for (const auto& op : init_values) {
221 partial_reduce_args.push_back(op);
222 }
223 std::vector<const Shape*> approx_output_shapes;
224 approx_output_shapes.reserve(operands_shapes.size());
225 for (auto& op_shape : operands_shapes) {
226 op_shape.mutable_dimensions()[reduction_dim] = approx_output_size;
227 approx_output_shapes.push_back(&op_shape);
228 }
229 auto approx_output_shape =
230 ShapeUtil::MakeTupleShapeWithPtrs(approx_output_shapes);
231 // PartialReduce option in JSON form
232 std::string partial_reduce_option = absl::StrFormat(
233 "{\"log2_reduction\": %d, \"reduction_dim\": %d, \"to_apply_type\": "
234 "\"comparator\"}",
235 log2_reduction, reduction_dim);
236
237 auto approx_topk = CustomCallWithComputation(
238 builder, "PartialReduce", partial_reduce_args, comparator,
239 approx_output_shape, partial_reduce_option);
240
241 if (aggregate_to_topk) {
242 std::vector<XlaOp> approx_topk_results;
243 approx_topk_results.reserve(num_operands);
244 for (int i = 0; i < num_operands; ++i) {
245 approx_topk_results.push_back(GetTupleElement(approx_topk, i));
246 }
247 return AggregateToTopKBuilder(builder, approx_topk_results, init_values,
248 top_k, reduction_dim, comparator);
249 }
250 return approx_topk;
251 }
252
ApproxTopKFallback(XlaBuilder * builder,absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,int64_t top_k,int64_t reduction_dim,const XlaComputation & comparator,float recall_target,bool aggregate_to_topk,int64_t reduction_input_size_override)253 XlaOp ApproxTopKFallback(XlaBuilder* builder, absl::Span<const XlaOp> operands,
254 absl::Span<const XlaOp> init_values, int64_t top_k,
255 int64_t reduction_dim,
256 const XlaComputation& comparator, float recall_target,
257 bool aggregate_to_topk,
258 int64_t reduction_input_size_override) {
259 auto operands_shapes = builder->GetOperandShapes(operands).ValueOrDie();
260 int64_t rank = operands_shapes[0].rank();
261 uint64_t n = operands_shapes[0].dimensions(reduction_dim);
262 // Align the output size with ApproxTopK.
263 auto status_or_approx_output_size = ApproxTopKReductionOutputSize(
264 n, rank, top_k, recall_target, aggregate_to_topk,
265 reduction_input_size_override);
266 if (!status_or_approx_output_size.ok()) {
267 return builder->ReportError(status_or_approx_output_size.status());
268 }
269 auto output_size = status_or_approx_output_size.value().first;
270 return AggregateToTopKBuilder(builder, operands, init_values, output_size,
271 reduction_dim, comparator);
272 }
273
274 } // namespace xla
275