1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "absl/numeric/bits.h"
17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
19 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
20 #include "tensorflow/compiler/xla/client/xla_builder.h"
21 #include "tensorflow/core/tpu/tpu_defs.h"
22
23 namespace tensorflow {
24 namespace {
25
26 using ::tensorflow::errors::InvalidArgument;
27
28 // Computes the Kth order statistic of a data set. The current
29 // implementation uses a binary search requiring exactly 32 passes
30 // over the input data. The running time is linear with respect to
31 // input size. The median-of-medians algorithm is probably faster, but
32 // is difficult to implement efficiently in XLA. The implementation
33 // imposes a total ordering on floats. The ordering is consistent with
34 // the usual partial order. Positive NaNs are greater than positive
35 // infinity. Negative NaNs are less than negative infinity. NaNs with
36 // distinct payloads are treated as distinct. Subnormal numbers are
37 // preserved (not flushed to zero). Positive infinity is greater than
38 // all numbers. Negative infinity is less than all numbers. Positive
// zero is greater than negative zero. There are fewer than k values greater
// than the kth order statistic. There are at least k values greater
41 // than or equal to the Kth order statistic. The semantics are not the
42 // same as TopKUnique.
// Computes, per row, the kth largest element of `input` (a [height, width]
// f32 matrix) via a 32-pass bitwise binary search over the IEEE754 bit
// patterns. Returns a [height] f32 vector. See the comment above for the
// total order imposed on floats.
xla::XlaOp CreateKthOrderStatisticComputation(xla::XlaBuilder* builder,
                                              const TensorShape& input_shape,
                                              const xla::XlaOp input,
                                              const xla::XlaOp k) {
  const int64_t height = input_shape.dim_size(0);  // batch dimension
  const int64_t width = input_shape.dim_size(1);   // dimension searched over

  // Reinterpret the f32 bits as s32 so the search can manipulate them
  // directly; the sign-handling below relies on signed comparisons.
  xla::XlaOp input_sm32 = xla::BitcastConvertType(input, xla::S32);
  xla::XlaOp zero_r0 = xla::ConstantR0<int32>(builder, 0);
  xla::XlaOp zero_r1 = xla::Broadcast(zero_r0, {height});
  xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, {height, width});

  // 0x7FFFFFFF is +QNAN with an all-ones payload: the maximum value under
  // the total order.
  xla::XlaOp max_r0 = xla::ConstantR0<int32>(builder, 0x7FFFFFFF);
  xla::XlaOp max_r1 = xla::Broadcast(max_r0, {height});

  // 0x80000000 is the bit pattern of -0.
  // Start at positive zero, so that pivot is always less than top.
  xla::XlaOp negative_zero_r0 = xla::ConstantR0<int32>(builder, 0x80000000);
  xla::XlaOp negative_zero_r1 = xla::Broadcast(negative_zero_r0, {height});
  xla::XlaOp top_r1 = zero_r1;

  // One pass per bit, from the sign bit down to the lowest mantissa bit.
  for (uint32 mask = 1U << 31; mask; mask >>= 1) {
    xla::XlaOp broadcast_mask_r1 =
        xla::Broadcast(xla::ConstantR0<int32>(builder, mask), {height});

    // The first iteration of the loop determines if the kth element is
    // positive or negative: if at least k elements are greater than -0
    // (i.e. positive under the total order), the search restarts from
    // +QNAN (0x7FFFFFFF); otherwise it restarts from -0 (0x80000000).
    // NOTE(review): the original comment stated the opposite polarity;
    // the Select at the bottom of the loop picks max_r1 exactly when the
    // pivot is too low — confirm against op tests.
    // On later iterations the pivot is always less than the top and half
    // way between the top and the implicit bottom in IEEE754 bit space.
    xla::XlaOp pivot_r1 = xla::Xor(top_r1, broadcast_mask_r1);
    // Broadcast the per-row pivot across the width dimension.
    xla::XlaOp pivot_r2 = xla::Add(pivot_r1, zero_r2, {0});
    // When both the element and the pivot are negative, the s32 ordering
    // of the bit patterns is reversed relative to the float ordering, so
    // swap the comparison operands.
    xla::XlaOp both_negative_r2 =
        xla::Lt(xla::And(input_sm32, pivot_r2), zero_r0);
    xla::XlaOp left_r2 = xla::Select(both_negative_r2, pivot_r2, input_sm32);
    xla::XlaOp right_r2 = xla::Select(both_negative_r2, input_sm32, pivot_r2);
    xla::XlaOp pred_r2 = xla::Gt(left_r2, right_r2);
    xla::XlaOp conv_r2 = xla::ConvertElementType(pred_r2, xla::S32);

    // Count, per row, how many elements are strictly greater than pivot.
    xla::XlaComputation add = CreateScalarAddComputation(xla::S32, builder);
    xla::XlaOp sum_r1 = xla::Reduce(conv_r2, zero_r0, add, {1});

    // The pivot is too low when at least k elements exceed it.
    xla::XlaOp pivot_too_low_r1 = xla::Le(k, sum_r1, {});

    if (mask == (1U << 31)) {
      top_r1 = xla::Select(pivot_too_low_r1, max_r1, negative_zero_r1);
    } else {
      top_r1 = xla::Select(pivot_too_low_r1, top_r1, pivot_r1);
    }
  }
  // Reinterpret the resulting bit pattern back as f32.
  return xla::BitcastConvertType(top_r1, xla::F32);
}
95
96 class KthOrderStatistic : public XlaOpKernel {
97 public:
KthOrderStatistic(OpKernelConstruction * ctx)98 explicit KthOrderStatistic(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
99 OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
100 OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
101 }
102
Compile(XlaOpKernelContext * ctx)103 void Compile(XlaOpKernelContext* ctx) override {
104 xla::XlaBuilder* builder = ctx->builder();
105 xla::XlaOp input = ctx->Input(0);
106 const TensorShape& input_shape = ctx->InputShape(0);
107 OP_REQUIRES(
108 ctx, input_shape.dims() == 2,
109 InvalidArgument("input must be rank-2: ", input_shape.DebugString()));
110
111 xla::XlaOp k = xla::ConstantR0<int32>(builder, k_);
112 xla::XlaOp kth_order_statistics =
113 CreateKthOrderStatisticComputation(builder, input_shape, input, k);
114 ctx->SetOutput(0, kth_order_statistics);
115 }
116
117 private:
118 int32 k_;
119 };
120
121 REGISTER_XLA_OP(Name("KthOrderStatistic"), KthOrderStatistic);
122
123 // Returns the TopK unique values in the array in sorted order and the
124 // indices of those elements. The running time is proportional to the
125 // product of K and the input size. Sorting the whole array is more
126 // efficient for sufficiently large values of K. The median-of-medians
127 // algorithm is probably faster, but difficult to implement
128 // efficiently in XLA. If there are fewer than K unique values, the
129 // results are padded with negative infinity. NaNs are never
130 // returned. Subnormal numbers are flushed to zero.
131 //
132 // If an element appears at multiple indices, the highest index is
133 // returned. If a TopK element never appears in the input due to
134 // padding values, the indices are padded with negative one. If a
135 // padding value appears in the input and padding is needed, the
136 // highest index of the padding value will be returned.
137 //
138 // The semantics are not the same as KthOrderStatistic.
139 //
140 // If masked_with_iota is true, the index is already encoded in the lower bits
141 // of the mantissa, which will be extracted as the index in the output.
142 // Otherwise, every iteration will use the following algorithm to get the index:
143 // index = max([i if data[i] == max else -1 for i in size])
144 //
145 // TODO(b/74994968): Replace TopKUnique with an LLO implementation of
146 // TopK with reasonable semantics.
// Returns, per row of the [height, width] f32 `input`, the top-k unique
// values in sorted order paired with their indices (see the comment above
// for the full padding/tie-breaking semantics).
std::pair<xla::XlaOp, xla::XlaOp> CreateTopKUnique(
    xla::XlaBuilder* builder, const xla::XlaOp input,
    const TensorShape& input_shape, int64_t k, bool masked_with_iota) {
  const int64_t height = input_shape.dim_size(0);  // batch dimension
  const int64_t width = input_shape.dim_size(1);   // dimension searched over

  // Per-row index vector [0, width), used to recover argmax indices.
  xla::XlaOp iota_r1 = xla::Iota(builder, xla::S32, width);
  xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {height});

  xla::XlaOp negative_one_r0 = xla::ConstantR0<int>(builder, -1);
  xla::XlaOp negative_one_r2 = xla::Broadcast(negative_one_r0, {height, width});

  // -inf serves both as the padding value and as the identity element for
  // the max reductions below.
  xla::XlaOp negative_infinity_r0 = xla::ConstantR0<float>(builder, -INFINITY);
  xla::XlaOp negative_infinity_r2 =
      xla::Broadcast(negative_infinity_r0, {height, width});

  // scratch_pad_r2 holds the input with every element >= the maxima found
  // so far replaced by -inf, so each iteration extracts the next distinct
  // maximum.
  xla::XlaOp scratch_pad_r2 = input;
  std::vector<xla::XlaOp> topk_r1s;      // k per-row maxima, largest first
  std::vector<xla::XlaOp> topk_indices;  // matching indices (iota mode off)
  for (int i = 0; i < k; ++i) {
    // Per-row maximum of the not-yet-extracted elements.
    xla::XlaOp kth_order_statistic_r1 =
        xla::Reduce(scratch_pad_r2, negative_infinity_r0,
                    CreateScalarMaxComputation(xla::F32, builder), {1});
    topk_r1s.push_back(kth_order_statistic_r1);

    // Mask out every element >= the current max so duplicates of an
    // already-extracted value are never returned again ("unique").
    xla::XlaOp ge_r2 = xla::Ge(input, kth_order_statistic_r1, {0});
    scratch_pad_r2 = xla::Select(ge_r2, negative_infinity_r2, input);

    if (!masked_with_iota) {
      // index = max([i if data[i] == max else -1 for i in size])
      xla::XlaOp eq_r2 = xla::Eq(input, kth_order_statistic_r1, {0});
      xla::XlaOp indices_r2 = xla::Select(eq_r2, iota_r2, negative_one_r2);
      xla::XlaOp topk_index_r1 =
          xla::Reduce(indices_r2, negative_one_r0,
                      CreateScalarMaxComputation(xla::S32, builder), {1});
      topk_indices.push_back(topk_index_r1);
    }
  }
  // Stack the k per-row maxima into a [height, k] matrix.
  xla::XlaOp topk_r1_concat = xla::ConcatInDim(builder, topk_r1s, 0);
  xla::XlaOp topk_r2 =
      xla::Transpose(xla::Reshape(topk_r1_concat, {k, height}), {1, 0});

  xla::XlaOp topk_indices_r2;
  if (masked_with_iota) {
    // The index was encoded into the low-order mantissa bits (see
    // CreateMakeUnique); extract it by masking off everything above the
    // bits needed to hold a column index.
    int32_t next_power_of_two = absl::bit_ceil<uint64_t>(width);
    int32_t count_mask = next_power_of_two - 1;
    xla::XlaOp mask_r0 = xla::ConstantR0(builder, count_mask);
    xla::XlaOp mask_r2 = xla::Broadcast(mask_r0, {height, k});
    xla::XlaOp topk_r2_s32 = xla::BitcastConvertType(topk_r2, xla::S32);
    topk_indices_r2 = xla::And(topk_r2_s32, mask_r2);
  } else {
    xla::XlaOp topk_indices_concat = xla::ConcatInDim(builder, topk_indices, 0);
    topk_indices_r2 =
        xla::Transpose(xla::Reshape(topk_indices_concat, {k, height}), {1, 0});
  }
  return std::make_pair(topk_r2, topk_indices_r2);
}
203
204 class TopKUnique : public XlaOpKernel {
205 public:
TopKUnique(OpKernelConstruction * ctx)206 explicit TopKUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
207 OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
208 OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
209 }
210
Compile(XlaOpKernelContext * ctx)211 void Compile(XlaOpKernelContext* ctx) override {
212 xla::XlaBuilder* builder = ctx->builder();
213 xla::XlaOp input = ctx->Input(0);
214 const TensorShape& input_shape = ctx->InputShape(0);
215 OP_REQUIRES(
216 ctx, input_shape.dims() == 2,
217 InvalidArgument("input must be rank-2: ", input_shape.DebugString()));
218
219 auto topk = CreateTopKUnique(builder, input, input_shape, k_, false);
220 ctx->SetOutput(0, topk.first);
221 ctx->SetOutput(1, topk.second);
222 }
223
224 private:
225 int k_;
226 };
227 REGISTER_XLA_OP(Name("TopKUnique"), TopKUnique);
228
229 // Make all elements in the non-Batch dimension unique and close to
// their initial value on a relative scale, but potentially far from
231 // their initial value in an absolute scale.
232 //
233 // This operation is meant to be combined with TopKUnique to avoid
234 // suppressing identical elements. For most TopK users, the indices of
235 // the TopK elements are important but the relative order of the TopK
// elements and their exact values is not so important. Ideally, the
// indices of the TopK elements of the output of MakeUnique are
238 // the same as the indices of the TopK elements of the inputs.
239 //
// It's an open question whether it is better to accept the risk that two
241 // elements in the input to TopK have exactly the same value or the
242 // risk that MakeUnique will alter the indices of the TopK
243 // elements. Model owners are encouraged to experiment!
244 //
245 // Never returns a sub-normal number. Never returns zero. The sign of
246 // each input element is always identical to the sign of the
247 // corresponding output element. Behavior for infinite elements is
248 // undefined. Behavior for subnormal elements is undefined.
249 //
250 // Algorithm:
251 // 1. Replace zeros with the smallest representable normal floating
252 // point number with the same sign.
253 // 2. Mask away enough low order bits that every value can be distinct.
254 // 3. Replace the low order bits with iota.
255 //
256 // TODO(b/74994968): Replace MakeUnique with an LLO implementation of
257 // TopK with reasonable semantics.
// Perturbs every element of the [height, width] f32 `input` so that the
// elements of each row are all distinct: zeros are replaced with the
// smallest normal of the same sign, then the low-order mantissa bits are
// replaced with the element's column index. See the comment above for the
// contract.
xla::XlaOp CreateMakeUnique(xla::XlaBuilder* builder, const xla::XlaOp input,
                            const TensorShape& input_shape) {
  const int64_t height = input_shape.dim_size(0);
  const int64_t width = input_shape.dim_size(1);

  xla::XlaOp zero_r0 = xla::ConstantR0(builder, 0U);
  xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, {height, width});

  // count_mask is used to mask away the low order bits to ensure
  // that every element is distinct: it clears just enough bits to hold
  // any column index in [0, width).
  uint32_t next_power_of_two = absl::bit_ceil<uint64_t>(width);
  uint32 count_mask = ~(next_power_of_two - 1);
  xla::XlaOp count_mask_r0 = xla::ConstantR0(builder, count_mask);
  xla::XlaOp count_mask_r2 = xla::Broadcast(count_mask_r0, {height, width});

  // smallest_normal is the bit representation of the smallest
  // positive normal floating point number. The sign is zero,
  // exponent is one, and the fraction is zero.
  uint32 smallest_normal = 1U << 23;
  xla::XlaOp smallest_normal_r0 = xla::ConstantR0(builder, smallest_normal);
  xla::XlaOp smallest_normal_r2 =
      xla::Broadcast(smallest_normal_r0, {height, width});

  // Used to mask away the sign bit when computing the absolute
  // value.
  uint32 low_bit_mask = ~(1U << 31);
  xla::XlaOp low_bit_mask_r0 = xla::ConstantR0(builder, low_bit_mask);
  xla::XlaOp low_bit_mask_r2 = xla::Broadcast(low_bit_mask_r0, {height, width});

  // Per-row column indices [0, width) that become the new low-order bits.
  xla::XlaOp iota_r1 = xla::Iota(builder, xla::U32, width);
  xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {height});

  // Compare the absolute value with positive zero to handle
  // negative zero.
  //
  // Pseudocode: input_no_zeros = abs(input) == 0 ? FLT_MIN : input
  xla::XlaOp input_u32_r2 = xla::BitcastConvertType(input, xla::U32);
  xla::XlaOp abs_r2 = xla::And(input_u32_r2, low_bit_mask_r2);
  xla::XlaOp if_zero_r2 = xla::Eq(abs_r2, zero_r2);
  // OR-ing in the smallest-normal bits keeps the original sign bit, so
  // -0 becomes the smallest negative normal and +0 the smallest positive.
  xla::XlaOp smallest_normal_preserving_sign_r2 =
      xla::Or(input_u32_r2, smallest_normal_r2);
  xla::XlaOp input_no_zeros_r2 =
      xla::Select(if_zero_r2, smallest_normal_preserving_sign_r2, input_u32_r2);

  // Discard the low-order bits and replace with iota.
  xla::XlaOp and_r2 = xla::And(input_no_zeros_r2, count_mask_r2);
  xla::XlaOp or_r2 = xla::Or(and_r2, iota_r2);
  return xla::BitcastConvertType(or_r2, xla::F32);
}
307
308 class MakeUnique : public XlaOpKernel {
309 public:
MakeUnique(OpKernelConstruction * ctx)310 explicit MakeUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
311
Compile(XlaOpKernelContext * ctx)312 void Compile(XlaOpKernelContext* ctx) override {
313 xla::XlaBuilder* builder = ctx->builder();
314 xla::XlaOp input = ctx->Input(0);
315 const TensorShape& input_shape = ctx->InputShape(0);
316 OP_REQUIRES(
317 ctx, input_shape.dims() == 2,
318 InvalidArgument("input must be rank-2: ", input_shape.DebugString()));
319
320 ctx->SetOutput(0, CreateMakeUnique(builder, input, input_shape));
321 }
322 };
323 REGISTER_XLA_OP(Name("MakeUnique"), MakeUnique);
324
325 // Returns the TopK approximate values in the array in sorted order and the
326 // indices of those elements. The running time is proportional to the
327 // product of K and the input size.
328 //
329 // The algorithm first updates the lower bits of each element with iota,
330 // which is used to derive the index. The iota also serves the purpose to
331 // make each element unique so that each iteration, we are guaranteed to
332 // get one and only one unique top-1 element.
333 class TopKWithUnique : public XlaOpKernel {
334 public:
TopKWithUnique(OpKernelConstruction * ctx)335 explicit TopKWithUnique(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
336 OP_REQUIRES_OK(ctx, ctx->GetAttr("k", &k_));
337 OP_REQUIRES(ctx, k_ >= 0, errors::InvalidArgument("Need k >= 0, got ", k_));
338 }
339
Compile(XlaOpKernelContext * ctx)340 void Compile(XlaOpKernelContext* ctx) override {
341 xla::XlaBuilder* builder = ctx->builder();
342 xla::XlaOp input = ctx->Input(0);
343 const TensorShape& input_shape = ctx->InputShape(0);
344 OP_REQUIRES(
345 ctx, input_shape.dims() == 2,
346 InvalidArgument("input must be rank-2: ", input_shape.DebugString()));
347
348 xla::XlaOp unique = CreateMakeUnique(builder, input, input_shape);
349 auto topk = CreateTopKUnique(builder, unique, input_shape, k_, true);
350 ctx->SetOutput(0, topk.first);
351 ctx->SetOutput(1, topk.second);
352 }
353
354 private:
355 int k_;
356 };
357 REGISTER_XLA_OP(Name("TopKWithUnique"), TopKWithUnique);
358 } // namespace
359 } // namespace tensorflow
360