/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/numeric/bits.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

using ::tensorflow::errors::InvalidArgument;

// Computes the kth order statistic of a data set. The current
// implementation uses a binary search requiring exactly 32 passes
// over the input data. The running time is linear with respect to
// the input size. The median-of-medians algorithm is probably faster,
// but is difficult to implement efficiently in XLA. The implementation
// imposes a total ordering on floats. The ordering is consistent with
// the usual partial order. Positive NaNs are greater than positive
// infinity. Negative NaNs are less than negative infinity. NaNs with
// distinct payloads are treated as distinct. Subnormal numbers are
// preserved (not flushed to zero). Positive infinity is greater than
// all numbers. Negative infinity is less than all numbers. Positive
// zero is greater than negative zero. There are fewer than k values
// greater than the kth order statistic. There are at least k values
// greater than or equal to the kth order statistic. The semantics are
// not the same as TopKUnique.
xla::XlaOp CreateKthOrderStatisticComputation(xla::XlaBuilder* builder,
                                              const TensorShape& input_shape,
                                              const xla::XlaOp input,
                                              const xla::XlaOp k) {
  const int64_t height = input_shape.dim_size(0);
  const int64_t width = input_shape.dim_size(1);

  xla::XlaOp input_sm32 = xla::BitcastConvertType(input, xla::S32);
  xla::XlaOp zero_r0 = xla::ConstantR0<int32>(builder, 0);
  xla::XlaOp zero_r1 = xla::Broadcast(zero_r0, {height});
  xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, {height, width});

  xla::XlaOp max_r0 = xla::ConstantR0<int32>(builder, 0x7FFFFFFF);
  xla::XlaOp max_r1 = xla::Broadcast(max_r0, {height});

  // Start at positive zero, so that pivot is always less than top.
  xla::XlaOp negative_zero_r0 = xla::ConstantR0<int32>(builder, 0x80000000);
  xla::XlaOp negative_zero_r1 = xla::Broadcast(negative_zero_r0, {height});
  xla::XlaOp top_r1 = zero_r1;

  for (uint32 mask = 1U << 31; mask; mask >>= 1) {
    xla::XlaOp broadcast_mask_r1 =
        xla::Broadcast(xla::ConstantR0<int32>(builder, mask), {height});

    // The first iteration of the loop determines whether the kth
    // element is positive or negative. If the kth element is
    // positive, we start the search from +QNAN (0x7FFFFFFF); if it is
    // negative, we start from -0 (0x80000000). The pivot is less than
    // the top and is always halfway between the top and the implicit
    // bottom in IEEE754 space.
    xla::XlaOp pivot_r1 = xla::Xor(top_r1, broadcast_mask_r1);
    xla::XlaOp pivot_r2 = xla::Add(pivot_r1, zero_r2, {0});
    xla::XlaOp both_negative_r2 =
        xla::Lt(xla::And(input_sm32, pivot_r2), zero_r0);
    xla::XlaOp left_r2 = xla::Select(both_negative_r2, pivot_r2, input_sm32);
    xla::XlaOp right_r2 = xla::Select(both_negative_r2, input_sm32, pivot_r2);
    xla::XlaOp pred_r2 = xla::Gt(left_r2, right_r2);
    xla::XlaOp conv_r2 = xla::ConvertElementType(pred_r2, xla::S32);

    xla::XlaComputation add = CreateScalarAddComputation(xla::S32, builder);
    xla::XlaOp sum_r1 = xla::Reduce(conv_r2, zero_r0, add, {1});

    xla::XlaOp pivot_too_low_r1 = xla::Le(k, sum_r1, {});

    if (mask == (1U << 31)) {
      top_r1 = xla::Select(pivot_too_low_r1, max_r1, negative_zero_r1);
    } else {
      top_r1 = xla::Select(pivot_too_low_r1, top_r1, pivot_r1);
    }
  }
  return xla::BitcastConvertType(top_r1, xla::F32);
}
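
// For intuition, a minimal host-side sketch of the same per-row binary
// search, operating on raw bit patterns. This helper is hypothetical and
// purely illustrative (it is not used by the kernel); for example, with
// k == 1 it returns the bit pattern of the row maximum under the total
// order described above.
[[maybe_unused]] uint32_t ReferenceKthOrderStatisticBits(const uint32_t* bits,
                                                         int64_t size,
                                                         int64_t k) {
  uint32_t top = 0;
  for (uint32_t mask = 1U << 31; mask != 0; mask >>= 1) {
    const uint32_t pivot = top ^ mask;
    int64_t num_greater = 0;
    for (int64_t i = 0; i < size; ++i) {
      int32_t left = static_cast<int32_t>(bits[i]);
      int32_t right = static_cast<int32_t>(pivot);
      // When both operands are negative, the two's-complement order is the
      // reverse of the float order, so swap them before comparing.
      if ((left & right) < 0) {
        const int32_t tmp = left;
        left = right;
        right = tmp;
      }
      if (left > right) ++num_greater;
    }
    const bool pivot_too_low = k <= num_greater;
    if (mask == (1U << 31)) {
      // The first pass fixes the sign: search down from the +QNAN bits if
      // the kth element is positive, otherwise down from the -0 bits.
      top = pivot_too_low ? 0x7FFFFFFFU : 0x80000000U;
    } else {
      top = pivot_too_low ? top : pivot;
    }
  }
  return top;
}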

class KthOrderStatistic : public XlaOpKernel {
 public:
  explicit KthOrderStatistic(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
    OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const TensorShape& input_shape = ctx->InputShape(0);
    OP_REQUIRES(
        ctx, input_shape.dims() == 2,
        InvalidArgument("input must be rank-2: ", input_shape.DebugString()));

    xla::XlaOp k = xla::ConstantR0<int32>(builder, k_);
    xla::XlaOp kth_order_statistics =
        CreateKthOrderStatisticComputation(builder, input_shape, input, k);
    ctx->SetOutput(0, kth_order_statistics);
  }

 private:
  int32 k_;
};

REGISTER_XLA_OP(Name("KthOrderStatistic"), KthOrderStatistic);

// Returns the TopK unique values in the array in sorted order and the
// indices of those elements. The running time is proportional to the
// product of K and the input size. Sorting the whole array is more
// efficient for sufficiently large values of K. The median-of-medians
// algorithm is probably faster, but difficult to implement
// efficiently in XLA. If there are fewer than K unique values, the
// results are padded with negative infinity. NaNs are never
// returned. Subnormal numbers are flushed to zero.
//
// If an element appears at multiple indices, the highest index is
// returned. If a TopK element never appears in the input due to
// padding values, the indices are padded with negative one. If a
// padding value appears in the input and padding is needed, the
// highest index of the padding value will be returned.
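//
// For example, for the row [3, 1, 3, 2] with k = 2 the values are [3, 2]
// and the indices are [2, 3]: the duplicate 3 reports its highest index.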
//
// The semantics are not the same as KthOrderStatistic.
//
// If masked_with_iota is true, the index is already encoded in the lower bits
// of the mantissa, which will be extracted as the index in the output.
// Otherwise, every iteration will use the following algorithm to get the index:
//   index = max([i if data[i] == max else -1 for i in range(size)])
//
// TODO(b/74994968): Replace TopKUnique with an LLO implementation of
// TopK with reasonable semantics.
std::pair<xla::XlaOp, xla::XlaOp> CreateTopKUnique(
    xla::XlaBuilder* builder, const xla::XlaOp input,
    const TensorShape& input_shape, int64_t k, bool masked_with_iota) {
  const int64_t height = input_shape.dim_size(0);
  const int64_t width = input_shape.dim_size(1);

  xla::XlaOp iota_r1 = xla::Iota(builder, xla::S32, width);
  xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {height});

  xla::XlaOp negative_one_r0 = xla::ConstantR0<int>(builder, -1);
  xla::XlaOp negative_one_r2 = xla::Broadcast(negative_one_r0, {height, width});

  xla::XlaOp negative_infinity_r0 = xla::ConstantR0<float>(builder, -INFINITY);
  xla::XlaOp negative_infinity_r2 =
      xla::Broadcast(negative_infinity_r0, {height, width});

  xla::XlaOp scratch_pad_r2 = input;
  std::vector<xla::XlaOp> topk_r1s;
  std::vector<xla::XlaOp> topk_indices;
  for (int i = 0; i < k; ++i) {
    xla::XlaOp kth_order_statistic_r1 =
        xla::Reduce(scratch_pad_r2, negative_infinity_r0,
                    CreateScalarMaxComputation(xla::F32, builder), {1});
    topk_r1s.push_back(kth_order_statistic_r1);

    xla::XlaOp ge_r2 = xla::Ge(input, kth_order_statistic_r1, {0});
    scratch_pad_r2 = xla::Select(ge_r2, negative_infinity_r2, input);

    if (!masked_with_iota) {
      xla::XlaOp eq_r2 = xla::Eq(input, kth_order_statistic_r1, {0});
      xla::XlaOp indices_r2 = xla::Select(eq_r2, iota_r2, negative_one_r2);
      xla::XlaOp topk_index_r1 =
          xla::Reduce(indices_r2, negative_one_r0,
                      CreateScalarMaxComputation(xla::S32, builder), {1});
      topk_indices.push_back(topk_index_r1);
    }
  }
  xla::XlaOp topk_r1_concat = xla::ConcatInDim(builder, topk_r1s, 0);
  xla::XlaOp topk_r2 =
      xla::Transpose(xla::Reshape(topk_r1_concat, {k, height}), {1, 0});

  xla::XlaOp topk_indices_r2;
  if (masked_with_iota) {
    int32_t next_power_of_two = absl::bit_ceil<uint64_t>(width);
    int32_t count_mask = next_power_of_two - 1;
    xla::XlaOp mask_r0 = xla::ConstantR0(builder, count_mask);
    xla::XlaOp mask_r2 = xla::Broadcast(mask_r0, {height, k});
    xla::XlaOp topk_r2_s32 = xla::BitcastConvertType(topk_r2, xla::S32);
    topk_indices_r2 = xla::And(topk_r2_s32, mask_r2);
  } else {
    xla::XlaOp topk_indices_concat = xla::ConcatInDim(builder, topk_indices, 0);
    topk_indices_r2 =
        xla::Transpose(xla::Reshape(topk_indices_concat, {k, height}), {1, 0});
  }
  return std::make_pair(topk_r2, topk_indices_r2);
}
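
// For intuition, a minimal host-side sketch of one row of the loop above.
// This helper is hypothetical and purely illustrative (it is not used by
// the kernel): each pass records the row maximum, takes the highest index
// at which it occurs, then masks out everything >= that maximum so the
// next pass finds the next strictly smaller unique value.
[[maybe_unused]] void ReferenceTopKUniqueRow(const float* row, int64_t width,
                                             int64_t k, float* out_values,
                                             int32_t* out_indices) {
  std::vector<float> scratch(row, row + width);
  for (int64_t i = 0; i < k; ++i) {
    float max_value = -INFINITY;
    for (int64_t j = 0; j < width; ++j) {
      if (scratch[j] > max_value) max_value = scratch[j];
    }
    int32_t index = -1;
    for (int64_t j = 0; j < width; ++j) {
      if (row[j] == max_value) index = static_cast<int32_t>(j);
      if (row[j] >= max_value) scratch[j] = -INFINITY;
    }
    out_values[i] = max_value;  // -INFINITY once unique values run out.
    out_indices[i] = index;     // -1 if max_value never occurs in the row.
  }
}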

class TopKUnique : public XlaOpKernel {
 public:
  explicit TopKUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
    OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const TensorShape& input_shape = ctx->InputShape(0);
    OP_REQUIRES(
        ctx, input_shape.dims() == 2,
        InvalidArgument("input must be rank-2: ", input_shape.DebugString()));

    auto topk = CreateTopKUnique(builder, input, input_shape, k_, false);
    ctx->SetOutput(0, topk.first);
    ctx->SetOutput(1, topk.second);
  }

 private:
  int k_;
};
REGISTER_XLA_OP(Name("TopKUnique"), TopKUnique);

// Makes all elements in the non-batch dimension unique and close to
// their initial value on a relative scale, but potentially far from
// their initial value on an absolute scale.
//
// This operation is meant to be combined with TopKUnique to avoid
// suppressing identical elements. For most TopK users, the indices of
// the TopK elements are important, but the relative order of the TopK
// elements and their exact values is not so important. Ideally, the
// indices of the TopK elements of the output of MakeUnique are the
// same as the indices of the TopK elements of the inputs.
//
// It is an open question whether it is better to accept the risk that
// two elements in the input to TopK have exactly the same value or
// the risk that MakeUnique will alter the indices of the TopK
// elements. Model owners are encouraged to experiment!
//
// Never returns a subnormal number. Never returns zero. The sign of
// each input element is always identical to the sign of the
// corresponding output element. Behavior for infinite elements is
// undefined. Behavior for subnormal elements is undefined.
//
// Algorithm:
// 1. Replace zeros with the smallest representable normal floating
//    point number with the same sign.
// 2. Mask away enough low order bits that every value can be distinct.
// 3. Replace the low order bits with iota.
//
// TODO(b/74994968): Replace MakeUnique with an LLO implementation of
// TopK with reasonable semantics.
xla::XlaOp CreateMakeUnique(xla::XlaBuilder* builder, const xla::XlaOp input,
                            const TensorShape& input_shape) {
  const int64_t height = input_shape.dim_size(0);
  const int64_t width = input_shape.dim_size(1);

  xla::XlaOp zero_r0 = xla::ConstantR0(builder, 0U);
  xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, {height, width});

  // count_mask is used to mask away the low order bits to ensure
  // that every element is distinct.
  uint32_t next_power_of_two = absl::bit_ceil<uint64_t>(width);
  uint32 count_mask = ~(next_power_of_two - 1);
  xla::XlaOp count_mask_r0 = xla::ConstantR0(builder, count_mask);
  xla::XlaOp count_mask_r2 = xla::Broadcast(count_mask_r0, {height, width});

  // smallest_normal is the bit representation of the smallest
  // positive normal floating point number. The sign is zero,
  // the exponent is one, and the fraction is zero.
  uint32 smallest_normal = 1U << 23;
  xla::XlaOp smallest_normal_r0 = xla::ConstantR0(builder, smallest_normal);
  xla::XlaOp smallest_normal_r2 =
      xla::Broadcast(smallest_normal_r0, {height, width});

  // Used to mask away the sign bit when computing the absolute
  // value.
  uint32 low_bit_mask = ~(1U << 31);
  xla::XlaOp low_bit_mask_r0 = xla::ConstantR0(builder, low_bit_mask);
  xla::XlaOp low_bit_mask_r2 = xla::Broadcast(low_bit_mask_r0, {height, width});

  xla::XlaOp iota_r1 = xla::Iota(builder, xla::U32, width);
  xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {height});

  // Compare the absolute value with positive zero to handle
  // negative zero.
  //
  // Pseudocode: input_no_zeros = abs(input) == 0 ? FLT_MIN : input
  xla::XlaOp input_u32_r2 = xla::BitcastConvertType(input, xla::U32);
  xla::XlaOp abs_r2 = xla::And(input_u32_r2, low_bit_mask_r2);
  xla::XlaOp if_zero_r2 = xla::Eq(abs_r2, zero_r2);
  xla::XlaOp smallest_normal_preserving_sign_r2 =
      xla::Or(input_u32_r2, smallest_normal_r2);
  xla::XlaOp input_no_zeros_r2 =
      xla::Select(if_zero_r2, smallest_normal_preserving_sign_r2, input_u32_r2);

  // Discard the low-order bits and replace with iota.
  xla::XlaOp and_r2 = xla::And(input_no_zeros_r2, count_mask_r2);
  xla::XlaOp or_r2 = xla::Or(and_r2, iota_r2);
  return xla::BitcastConvertType(or_r2, xla::F32);
}
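
// For intuition, a minimal host-side sketch of MakeUnique on a single
// element, operating on raw bit patterns. This helper is hypothetical and
// purely illustrative (it is not used by the kernel). For example, with
// width = 5, bit_ceil rounds up to 8 and the mask keeps all but the low 3
// bits, so 1.0f (0x3F800000) in column 3 becomes 0x3F800003, about
// 1.0000004f.
[[maybe_unused]] uint32_t ReferenceMakeUniqueBits(uint32_t bits,
                                                  uint32_t column,
                                                  int64_t width) {
  const uint32_t next_power_of_two =
      static_cast<uint32_t>(absl::bit_ceil<uint64_t>(width));
  const uint32_t count_mask = ~(next_power_of_two - 1);
  // Step 1: replace +/-0 with the smallest normal number of the same sign.
  if ((bits & ~(1U << 31)) == 0) bits |= (1U << 23);
  // Steps 2 and 3: mask away the low-order bits, then stamp in the column
  // index so every element in the row is distinct.
  return (bits & count_mask) | column;
}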

class MakeUnique : public XlaOpKernel {
 public:
  explicit MakeUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const TensorShape& input_shape = ctx->InputShape(0);
    OP_REQUIRES(
        ctx, input_shape.dims() == 2,
        InvalidArgument("input must be rank-2: ", input_shape.DebugString()));

    ctx->SetOutput(0, CreateMakeUnique(builder, input, input_shape));
  }
};
REGISTER_XLA_OP(Name("MakeUnique"), MakeUnique);

// Returns the TopK approximate values in the array in sorted order and the
// indices of those elements. The running time is proportional to the
// product of K and the input size.
//
// The algorithm first updates the lower bits of each element with iota,
// which is used to derive the index. The iota also serves to make each
// element unique, so that on each iteration we are guaranteed to get
// exactly one unique top-1 element.
class TopKWithUnique : public XlaOpKernel {
 public:
  explicit TopKWithUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
    OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    const TensorShape& input_shape = ctx->InputShape(0);
    OP_REQUIRES(
        ctx, input_shape.dims() == 2,
        InvalidArgument("input must be rank-2: ", input_shape.DebugString()));

    xla::XlaOp unique = CreateMakeUnique(builder, input, input_shape);
    auto topk = CreateTopKUnique(builder, unique, input_shape, k_, true);
    ctx->SetOutput(0, topk.first);
    ctx->SetOutput(1, topk.second);
  }

 private:
  int k_;
};
REGISTER_XLA_OP(Name("TopKWithUnique"), TopKWithUnique);
}  // namespace
}  // namespace tensorflow