/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <numeric>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace {

// Converts the RGB channels in 'rgb' to HSV format.
// 'shape' is the shape of each of the red/green/blue tensors.
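// The conversion follows the usual piecewise definition. With
// V = max(R, G, B), m = min(R, G, B), and range = V - m:
//   S = range / V  (or 0 when V == 0), and
//   H = (G - B) / (6 * range)          if V == R,
//       (B - R) / (6 * range) + 2/6    if V == G,
//       (R - G) / (6 * range) + 4/6    if V == B,
// with negative hues wrapped up by one so that H lies in [0, 1).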
std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b,
                                   const std::array<xla::XlaOp, 3>& rgb,
                                   DataType dtype, const TensorShape& shape) {
  auto zero = XlaHelpers::Zero(b, dtype);
  auto one = XlaHelpers::One(b, dtype);

  auto red = rgb[0];
  auto green = rgb[1];
  auto blue = rgb[2];
  auto value = xla::Max(xla::Max(red, green), blue);
  auto minimum = xla::Min(xla::Min(red, green), blue);
  auto range = xla::Sub(value, minimum);

  auto zeros = xla::Broadcast(zero, shape.dim_sizes());
  auto saturation =
      xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros);

  auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);

  auto hue =
      xla::Select(xla::Eq(green, value),
                  xla::Add(xla::Mul(norm, xla::Sub(blue, red)),
                           XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
                  xla::Add(xla::Mul(norm, xla::Sub(red, green)),
                           XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
  hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)),
                    hue);
  hue = xla::Select(xla::Gt(range, zero), hue, zeros);
  hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue);
  return {hue, saturation, value};
}

// Converts the HSV channels in 'hsv' to RGB format.
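// With dh = 6 * H, each channel is a triangular ramp clamped to [0, 1]:
//   dr = clamp(|dh - 3| - 1, 0, 1)
//   dg = clamp(2 - |dh - 2|, 0, 1)
//   db = clamp(2 - |dh - 4|, 0, 1)
// and the result is channel = V * ((1 - S) + S * ramp), which reduces to V
// when S == 0 (gray) and to V * ramp when S == 1 (fully saturated).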
std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b,
                                   const std::array<xla::XlaOp, 3>& hsv,
                                   DataType dtype) {
  xla::XlaOp hue = hsv[0];
  xla::XlaOp saturation = hsv[1];
  xla::XlaOp value = hsv[2];
  auto zero = XlaHelpers::Zero(b, dtype);
  auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
  auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
  auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0);
  auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0);
  auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0);

  auto dh = xla::Mul(hue, six);
  auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one);
  auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one);
  auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one);
  auto one_minus_s = xla::Sub(one, saturation);

  auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value);
  auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value);
  auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value);
  return {red, green, blue};
}

class RGBToHSVOp : public XlaOpKernel {
 public:
  explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    OP_REQUIRES(context, input_shape.dims() >= 1,
                errors::InvalidArgument("input must be at least 1D: ",
                                        input_shape.DebugString()));
    int channel_dim = input_shape.dims() - 1;
    int64_t channels = input_shape.dim_size(channel_dim);
    OP_REQUIRES(
        context, channels == 3,
        errors::FailedPrecondition("input must have 3 channels but has ",
                                   channels, " channels."));

    xla::XlaBuilder* b = context->builder();
    xla::XlaOp input = context->Input(0);

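    // Split out the R, G, B channels along the last dimension as [..., 1]
    // slices, convert each pixel to HSV, and concatenate the results back
    // along the channel dimension.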
    xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
                                     /*limit_index=*/1, /*stride=*/1,
                                     /*dimno=*/channel_dim);
    xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
                                       /*limit_index=*/2, /*stride=*/1,
                                       /*dimno=*/channel_dim);
    xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
                                      /*limit_index=*/3, /*stride=*/1,
                                      /*dimno=*/channel_dim);
    TensorShape channel_shape = input_shape;
    channel_shape.set_dim(channel_dim, 1);
    auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
                        channel_shape);

    context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim));
  }
};
REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp);

class HSVToRGBOp : public XlaOpKernel {
 public:
  explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    OP_REQUIRES(context, input_shape.dims() >= 1,
                errors::InvalidArgument("input must be at least 1D: ",
                                        input_shape.DebugString()));
    int channel_dim = input_shape.dims() - 1;
    int64_t channels = input_shape.dim_size(channel_dim);
    OP_REQUIRES(
        context, channels == 3,
        errors::FailedPrecondition("input must have 3 channels but has ",
                                   channels, " channels."));

    xla::XlaBuilder* b = context->builder();
    xla::XlaOp input = context->Input(0);
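    // Split out the H, S, V channels as [..., 1] slices, convert back to
    // RGB, and concatenate along the channel dimension.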
    xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0,
                                     /*limit_index=*/1, /*stride=*/1,
                                     /*dimno=*/channel_dim);
    xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1,
                                            /*limit_index=*/2, /*stride=*/1,
                                            /*dimno=*/channel_dim);
    xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2,
                                       /*limit_index=*/3, /*stride=*/1,
                                       /*dimno=*/channel_dim);

    auto rgb = HSVToRGB(context->builder(), {hue, saturation, value},
                        context->input_type(0));

    context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
  }
};
REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp);

class AdjustContrastOpV2 : public XlaOpKernel {
 public:
  explicit AdjustContrastOpV2(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape& input_shape = context->InputShape(0);
    const TensorShape& factor_shape = context->InputShape(1);
    OP_REQUIRES(context, input_shape.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape ",
                                        input_shape.DebugString()));
    int height_dim = input_shape.dims() - 3;
    int width_dim = input_shape.dims() - 2;
    int channel_dim = input_shape.dims() - 1;
    const int64_t height = input_shape.dim_size(height_dim);
    const int64_t width = input_shape.dim_size(width_dim);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape),
                errors::InvalidArgument("contrast_factor must be scalar: ",
                                        factor_shape.DebugString()));

    xla::XlaBuilder* b = context->builder();
    DataType type = context->input_type(0);

    xla::XlaOp input = context->Input(0);
    xla::XlaOp factor = XlaHelpers::ConvertElementType(context->Input(1), type);

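    // Contrast adjustment blends each pixel with the per-channel spatial
    // mean: output = factor * input + (1 - factor) * mean. The mean is
    // accumulated in a wider type to avoid precision loss on large images.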
    const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
    auto converted = XlaHelpers::ConvertElementType(input, accumulation_type);
    auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
                              *context->GetOrCreateAdd(accumulation_type),
                              {height_dim, width_dim});

    auto output = xla::Div(
        reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width));
    output = XlaHelpers::ConvertElementType(output, type);

    std::vector<int64_t> broadcast_dims(input_shape.dims() - 2);
    std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
    broadcast_dims.back() = channel_dim;
    output =
        xla::Add(xla::Mul(input, factor),
                 xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)),
                 broadcast_dims);
    context->SetOutput(0, output);
  }
};
REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2);

class AdjustSaturationOp : public XlaOpKernel {
 public:
  explicit AdjustSaturationOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape& input_shape = context->InputShape(0);
    const TensorShape& scale_shape = context->InputShape(1);
    OP_REQUIRES(context, input_shape.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape ",
                                        input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape),
                errors::InvalidArgument("scale must be scalar: ",
                                        scale_shape.DebugString()));
    const int channel_dim = input_shape.dims() - 1;
    const int64_t channels = input_shape.dim_size(channel_dim);
    OP_REQUIRES(
        context, channels == 3,
        errors::InvalidArgument("input must have 3 channels but instead has ",
                                channels, " channels."));

    xla::XlaBuilder* b = context->builder();
    xla::XlaOp input =
        XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
    xla::XlaOp scale =
        XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);

    DataType type = context->input_type(0);

    xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
                                     /*limit_index=*/1, /*stride=*/1,
                                     /*dimno=*/channel_dim);
    xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
                                       /*limit_index=*/2, /*stride=*/1,
                                       /*dimno=*/channel_dim);
    xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
                                      /*limit_index=*/3, /*stride=*/1,
                                      /*dimno=*/channel_dim);
    TensorShape channel_shape = input_shape;
    channel_shape.set_dim(channel_dim, 1);
    auto hsv =
        RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);

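    // Adjust saturation in HSV space: scale the S channel and clamp it back
    // into the valid range [0, 1] before converting to RGB.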
    hsv[1] = xla::Clamp(XlaHelpers::Zero(b, DT_FLOAT), xla::Mul(hsv[1], scale),
                        XlaHelpers::One(b, DT_FLOAT));

    auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);

    auto output = XlaHelpers::ConvertElementType(
        xla::ConcatInDim(b, rgb, channel_dim), type);
    context->SetOutput(0, output);
  }
};
REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp);

class AdjustHueOp : public XlaOpKernel {
 public:
  explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape& input_shape = context->InputShape(0);
    const TensorShape& delta_shape = context->InputShape(1);
    OP_REQUIRES(context, input_shape.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape ",
                                        input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape),
                errors::InvalidArgument("delta must be scalar: ",
                                        delta_shape.DebugString()));
    const int channel_dim = input_shape.dims() - 1;
    const int64_t channels = input_shape.dim_size(channel_dim);
    OP_REQUIRES(
        context, channels == 3,
        errors::InvalidArgument("input must have 3 channels but instead has ",
                                channels, " channels."));

    xla::XlaBuilder* b = context->builder();
    xla::XlaOp input =
        XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
    xla::XlaOp delta =
        XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);

    DataType type = context->input_type(0);

    xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
                                     /*limit_index=*/1, /*stride=*/1,
                                     /*dimno=*/channel_dim);
    xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
                                       /*limit_index=*/2, /*stride=*/1,
                                       /*dimno=*/channel_dim);
    xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
                                      /*limit_index=*/3, /*stride=*/1,
                                      /*dimno=*/channel_dim);
    TensorShape channel_shape = input_shape;
    channel_shape.set_dim(channel_dim, 1);
    auto hsv =
        RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);

    auto zero = XlaHelpers::Zero(b, DT_FLOAT);
    auto one = XlaHelpers::One(b, DT_FLOAT);

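    // Rotate the hue channel by 'delta' modulo 1.0. xla::Rem can return a
    // negative remainder, so negative results are wrapped back into [0, 1).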
    auto& hue = hsv[0];
    hue = xla::Rem(xla::Add(hsv[0], delta), one);
    hue =
        xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue);

    auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);

    auto output = XlaHelpers::ConvertElementType(
        xla::ConcatInDim(b, rgb, channel_dim), type);
    context->SetOutput(0, output);
  }
};
REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);

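// Loop condition for the non-max-suppression loop below: keep iterating
// while the current row index is in bounds and fewer than 'output_size'
// boxes have been selected.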
struct WhileCondFn {
  const int64_t num_boxes;
  const int64_t output_size;

  explicit WhileCondFn(int64_t num_boxes, int64_t output_size)
      : num_boxes(num_boxes), output_size(output_size) {}

  StatusOr<xla::XlaOp> operator()(absl::Span<const xla::XlaOp> values,
                                  xla::XlaBuilder* cond_builder) const {
    xla::XlaOp row_idx = values[0];
    xla::XlaOp row_in_bounds =
        xla::Lt(row_idx, xla::ConstantR0<int32>(cond_builder, num_boxes));
    xla::XlaOp num_outputs_so_far = values[1];
    xla::XlaOp results_not_full = xla::Lt(
        num_outputs_so_far, xla::ConstantR0<int32>(cond_builder, output_size));
    return xla::And(row_in_bounds, results_not_full);
  }
};

// Process the boxes one-by-one using the IOU matrix mask.
// This implementation uses a correct, but greedy, sequential algorithm
// to ensure that suppressed boxes cannot themselves suppress other
// boxes.
struct SuppressBodyFn {
  const int64_t num_boxes;

  explicit SuppressBodyFn(int64_t num_boxes) : num_boxes(num_boxes) {}

  StatusOr<std::vector<xla::XlaOp>> operator()(
      absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) const {
    auto row_idx = values[0];
    auto num_outputs_so_far = values[1];
    auto iou_mask = values[2];
    auto included_iou = values[3];
    auto zero = xla::ConstantR0<int32>(builder, 0);
    // Determine if the current element is active using a slice.
    // TODO(b/118437727): The only reason we need an explicit vector is because
    // some old GCCs can't deduce the right type for MakeConstSpan, and
    // providing a single-value initializer list directly uses the wrong
    // overload. Delete this once the deprecated overload is gone.
    std::vector<xla::XlaOp> row_idx_vector = {row_idx};
    auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1});
    active_elem = xla::Reshape(active_elem, {});
    // Increment the output count iff the current element is not suppressed.
    num_outputs_so_far = xla::Select(
        active_elem, num_outputs_so_far + xla::ConstantR0<int32>(builder, 1),
        num_outputs_so_far);
    // Slice out the row at row_idx.
    auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes});
    // Remove the diagonal from consideration: an element cannot suppress
    // itself.
    row_iou = xla::DynamicUpdateSlice(
        row_iou, xla::ConstantR2FromArray2D<bool>(builder, {{false}}),
        {zero, row_idx});
    // Create the suppression mask by inverting the polarity.
    row_iou = xla::Reshape(row_iou, {num_boxes});
    auto supp_mask = xla::Not(row_iou);
    // Update the mask iff the current element is not suppressed.
    included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}),
                               xla::And(included_iou, supp_mask), included_iou);
    row_idx = row_idx + xla::ConstantR0<int32>(builder, 1);
    return std::vector<xla::XlaOp>{row_idx, num_outputs_so_far, iou_mask,
                                   included_iou};
  }
};

class NonMaxSuppressionOp : public XlaOpKernel {
 public:
  explicit NonMaxSuppressionOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
                                             &pad_to_max_output_size_));
  }

  void Compile(XlaOpKernelContext* context) override {
    // TODO(b/111646731): Improve scalability of this op, using blocking.
    OP_REQUIRES(context, pad_to_max_output_size_,
                errors::Unimplemented(
                    "XLA compilation requires pad_to_max_output_size == True"));

    ComputeResult(context, pad_to_max_output_size_);
  }
  static void ComputeResult(XlaOpKernelContext* context,
                            bool pad_to_max_output_size = false) {
    const TensorShape& boxes_shape = context->InputShape("boxes");
    OP_REQUIRES(
        context, TensorShapeUtils::IsMatrix(boxes_shape),
        errors::InvalidArgument("boxes must be 2-D, currently: [",
                                std::to_string(boxes_shape.dim_size(0)), ",",
                                std::to_string(boxes_shape.dim_size(1)), "]"));
    const int64_t num_boxes = boxes_shape.dim_size(0);
    OP_REQUIRES(
        context, boxes_shape.dim_size(1) == 4,
        errors::InvalidArgument("boxes must have 4 columns, currently: ",
                                std::to_string(boxes_shape.dim_size(1))));
    const TensorShape& scores_shape = context->InputShape("scores");
    OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape),
                errors::InvalidArgument("scores must be 1-D, currently: ",
                                        scores_shape.DebugString()));
    OP_REQUIRES(context, scores_shape.dim_size(0) == num_boxes,
                errors::InvalidArgument(
                    "scores size ", std::to_string(scores_shape.dim_size(0)),
                    " must equal number of boxes ", std::to_string(num_boxes)));
    OP_REQUIRES(context, num_boxes <= kint32max,
                errors::InvalidArgument("XLA compilation requires number of "
                                        "boxes to be <= kint32max, got ",
                                        num_boxes));
    xla::PrimitiveType boxes_xla_type = context->InputXlaType("boxes");
    xla::PrimitiveType scores_xla_type = context->InputXlaType("scores");
    const xla::XlaOp boxes_input = context->Input("boxes");
    const xla::XlaOp scores_input = context->Input("scores");
    int64_t output_size;
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsScalar(context->InputShape("max_output_size")),
        errors::InvalidArgument("max_output_size must be a scalar"));
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsScalar(context->InputShape("iou_threshold")),
        errors::InvalidArgument("iou_threshold must be a scalar"));
    OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size));
    OP_REQUIRES(
        context, output_size >= 0,
        errors::InvalidArgument("Need output_size >= 0, got ", output_size));
    OP_REQUIRES(context, output_size <= kint32max,
                errors::InvalidArgument("Need output_size <= kint32max, got ",
                                        output_size));
    const xla::XlaOp score_thresh = context->Input("score_threshold");
    const xla::XlaOp iou_thresh = context->Input("iou_threshold");
    xla::XlaBuilder* const builder = context->builder();

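    // Sort boxes by descending score. The boxes are transposed to
    // [4, num_boxes] and the scores broadcast across the four coordinate
    // rows, so xla::Sort reorders every row by the same score key; a second
    // sort over an iota records the permutation of the original indices.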
    // Choose a more convenient layout.
    const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0});
    const xla::XlaOp boxes_sorted = xla::GetTupleElement(
        xla::Sort({xla::Broadcast(scores_input, {4}), boxes},
                  xla::CreateScalarGtComputation(
                      {scores_xla_type, boxes_xla_type}, builder),
                  /*dimension=*/1),
        1);
    // Track the mapping of indices into sorted domain.
    const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes);
    const xla::XlaOp indices_sort = xla::Sort(
        {scores_input, iota_indices},
        xla::CreateScalarGtComputation({scores_xla_type, xla::S32}, builder));
    const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1);
    const xla::XlaOp scores = xla::GetTupleElement(indices_sort, 0);

    // Shapes are henceforth [num_boxes]. 'c_y0' denotes 'coordinate' y0.
    const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/0,
                                                         /*limit_index=*/1,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/1,
                                                         /*limit_index=*/2,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/2,
                                                         /*limit_index=*/3,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});
    const xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
                                                         /*start_index=*/3,
                                                         /*limit_index=*/4,
                                                         /*stride=*/1,
                                                         /*dimno=*/0),
                                         {num_boxes});

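    // Canonicalize the corners so that (y1, x1) is the lower corner and
    // (y2, x2) the upper corner of each box; the area
    // (y2 - y1) * (x2 - x1) is then always non-negative.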
    xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1);
    xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0);
    xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1);
    xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0);
    xla::XlaOp area = (y2 - y1) * (x2 - x1);

    // Shapes are henceforth [1, num_boxes].
    y1 = xla::Broadcast(y1, {1});
    y2 = xla::Broadcast(y2, {1});
    x1 = xla::Broadcast(x1, {1});
    x2 = xla::Broadcast(x2, {1});
    area = xla::Broadcast(area, {1});

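    // Pairwise IOU: combining the [1, num_boxes] operands with their
    // [num_boxes, 1] transposes yields [num_boxes, num_boxes] matrices. The
    // intersection is the overlap rectangle (max of the mins, min of the
    // maxes, clamped at zero), and the union is area_i + area_j minus the
    // intersection.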
    // Shapes are henceforth [num_boxes, num_boxes].
    xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0}));
    xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0}));
    xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0}));
    xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0}));
    auto square_zero = xla::ZerosLike(i_xmin);

    xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) *
                        xla::Max(i_ymax - i_ymin, square_zero);
    xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area;
    xla::XlaOp iou = i_area / u_area;

    xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero);
    xla::XlaOp included_iou =
        xla::Broadcast(xla::ConstantR0<bool>(builder, true), {num_boxes});

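    // Loop state: {row_idx, num_outputs_so_far, iou_mask, included_iou},
    // matching the order in which WhileCondFn and SuppressBodyFn read
    // 'values'.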
    std::vector<xla::XlaOp> init_values;
    init_values.reserve(4);
    init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // row_idx
    init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // num_outputs
    init_values.push_back(iou_thresh_mask);
    init_values.push_back(included_iou);

    auto suppress_loop_result =
        xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
                             SuppressBodyFn(num_boxes), init_values,
                             "suppress_loop", builder)
            .ValueOrDie();

    xla::XlaOp included_score =
        xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes}));
    xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]);

    // Only consider boxes over which we have iterated. This allows for
    // accurate counting; DynamicSlice would require knowledge of the size of
    // the output.
    auto valid_elem = xla::Lt(
        iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes}));
    included = xla::And(included, valid_elem);

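    // Select the surviving boxes with a TopK: scores of excluded boxes are
    // masked to the smallest representable value, so the top 'output_size'
    // entries are exactly the selected boxes, in descending score order.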
    xla::XlaOp neg_inf =
        xla::Broadcast(xla::MinValue(builder, boxes_xla_type), {num_boxes});
    xla::XlaOp scores_included = xla::Select(included, scores, neg_inf);
    xla::XlaOp output_tuple = TopK(scores_included, output_size);
    xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1);
    // Calculate num_valid.
    // Note: num_valid cannot be taken from the loop outputs, because outputs
    // can be suppressed by the score threshold.
    xla::XlaOp ones_included = xla::Select(
        included,
        xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
        xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
    // num_valid is scalar. Its value should be bounded by output_size.
    xla::XlaOp num_valid_total = xla::Reduce(
        ones_included,
        /*init_value=*/xla::ConstantR0<int>(builder, 0),
        /*computation=*/CreateScalarAddComputation(xla::S32, builder),
        /*dimensions_to_reduce=*/{0});
    xla::XlaOp num_valid =
        xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));

    // Re-index into the original scores input tensor, using a Gather.
    // Boxes were suppressed in the sorted domain.
    xla::XlaOp selected_indices;
    DataType gather_type = context->expected_output_dtype(0);
    OP_REQUIRES_OK(
        context,
        XlaGather(indices_sorted, scores_shape, selected_indices_sorted,
                  TensorShape({output_size}),
                  /*axis=*/0,
                  /*indices_are_nd=*/false,
                  /*dtype=*/gather_type, DT_INT32, builder, &selected_indices));

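    // Without padding, mark dimension 0 of the result as dynamic with size
    // num_valid so that consumers only see the valid entries.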
    if (!pad_to_max_output_size) {
      StatusOr<xla::XlaOp> rebounded_result = xla::SetDimensionSizeWithRebound(
          &context->value_inference(), selected_indices, num_valid, 0);
      if (rebounded_result.ok()) {
        selected_indices = *rebounded_result;
      } else {
        // TODO(b/207187072): Remove special handling once dynamic reshape
        // can also be handled.
        selected_indices =
            xla::SetDimensionSize(selected_indices, num_valid, 0);
      }
    }
    context->SetOutput(0, selected_indices);
    if (pad_to_max_output_size) context->SetOutput(1, num_valid);
  }

 private:
  bool pad_to_max_output_size_;
};

REGISTER_XLA_OP(
    Name("NonMaxSuppressionV4").CompileTimeConstantInput("max_output_size"),
    NonMaxSuppressionOp);

class NonMaxSuppressionV3Op : public XlaOpKernel {
 public:
  explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    NonMaxSuppressionOp::ComputeResult(context);
  }
};

REGISTER_XLA_OP(
    Name("NonMaxSuppressionV3").CompileTimeConstantInput("max_output_size"),
    NonMaxSuppressionV3Op);

}  // namespace
}  // namespace tensorflow