xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/kernels/image_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "absl/types/span.h"
19 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
20 #include "tensorflow/compiler/tf2xla/lib/util.h"
21 #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
22 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
23 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
25 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
26 #include "tensorflow/compiler/xla/client/lib/comparators.h"
27 #include "tensorflow/compiler/xla/client/lib/constants.h"
28 #include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
29 #include "tensorflow/compiler/xla/client/lib/loops.h"
30 #include "tensorflow/compiler/xla/client/lib/sorting.h"
31 #include "tensorflow/compiler/xla/client/xla_builder.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/status.h"
39 
40 namespace tensorflow {
41 namespace {
42 
43 // Converts 'input' from RGB format to HSV format.
44 // 'shape' is the shape of the red/green/blue tensors.
RGBToHSV(XlaOpKernelContext * ctx,xla::XlaBuilder * b,const std::array<xla::XlaOp,3> & rgb,DataType dtype,const TensorShape & shape)45 std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b,
46                                    const std::array<xla::XlaOp, 3>& rgb,
47                                    DataType dtype, const TensorShape& shape) {
48   auto zero = XlaHelpers::Zero(b, dtype);
49   auto one = XlaHelpers::One(b, dtype);
50 
51   auto red = rgb[0];
52   auto green = rgb[1];
53   auto blue = rgb[2];
54   auto value = xla::Max(xla::Max(red, green), blue);
55   auto minimum = xla::Min(xla::Min(red, green), blue);
56   auto range = xla::Sub(value, minimum);
57 
58   auto zeros = xla::Broadcast(zero, shape.dim_sizes());
59   auto saturation =
60       xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros);
61 
62   auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);
63 
64   auto hue =
65       xla::Select(xla::Eq(green, value),
66                   xla::Add(xla::Mul(norm, xla::Sub(blue, red)),
67                            XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
68                   xla::Add(xla::Mul(norm, xla::Sub(red, green)),
69                            XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
70   hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)),
71                     hue);
72   hue = xla::Select(xla::Gt(range, zero), hue, zeros);
73   hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue);
74   return {hue, saturation, value};
75 }
76 
77 // Converts 'input' from HSV format to RGB format.
HSVToRGB(xla::XlaBuilder * b,const std::array<xla::XlaOp,3> & hsv,DataType dtype)78 std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b,
79                                    const std::array<xla::XlaOp, 3>& hsv,
80                                    DataType dtype) {
81   xla::XlaOp hue = hsv[0];
82   xla::XlaOp saturation = hsv[1];
83   xla::XlaOp value = hsv[2];
84   auto zero = XlaHelpers::Zero(b, dtype);
85   auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
86   auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
87   auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0);
88   auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0);
89   auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0);
90 
91   auto dh = xla::Mul(hue, six);
92   auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one);
93   auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one);
94   auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one);
95   auto one_minus_s = xla::Sub(one, saturation);
96 
97   auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value);
98   auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value);
99   auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value);
100   return {red, green, blue};
101 }
102 
103 class RGBToHSVOp : public XlaOpKernel {
104  public:
RGBToHSVOp(OpKernelConstruction * context)105   explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
106 
Compile(XlaOpKernelContext * context)107   void Compile(XlaOpKernelContext* context) override {
108     const TensorShape input_shape = context->InputShape(0);
109     OP_REQUIRES(context, input_shape.dims() >= 1,
110                 errors::InvalidArgument("input must be at least 1D",
111                                         input_shape.DebugString()));
112     int channel_dim = input_shape.dims() - 1;
113     int64_t channels = input_shape.dim_size(channel_dim);
114     OP_REQUIRES(
115         context, channels == 3,
116         errors::FailedPrecondition("input must have 3 channels but input has ",
117                                    channels, " channels."));
118 
119     xla::XlaBuilder* b = context->builder();
120     xla::XlaOp input = context->Input(0);
121 
122     xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
123                                      /*limit_index=*/1, /*stride=*/1,
124                                      /*dimno=*/channel_dim);
125     xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
126                                        /*limit_index=*/2, /*stride=*/1,
127                                        /*dimno=*/channel_dim);
128     xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
129                                       /*limit_index=*/3, /*stride=*/1,
130                                       /*dimno=*/channel_dim);
131     TensorShape channel_shape = input_shape;
132     channel_shape.set_dim(channel_dim, 1);
133     auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
134                         channel_shape);
135 
136     context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim));
137   }
138 };
139 REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp);
140 
141 class HSVToRGBOp : public XlaOpKernel {
142  public:
HSVToRGBOp(OpKernelConstruction * context)143   explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
144 
Compile(XlaOpKernelContext * context)145   void Compile(XlaOpKernelContext* context) override {
146     const TensorShape input_shape = context->InputShape(0);
147     OP_REQUIRES(context, input_shape.dims() >= 1,
148                 errors::InvalidArgument("input must be at least 1D",
149                                         input_shape.DebugString()));
150     int channel_dim = input_shape.dims() - 1;
151     int64_t channels = input_shape.dim_size(channel_dim);
152     OP_REQUIRES(
153         context, channels == 3,
154         errors::FailedPrecondition("input must have 3 channels but input has ",
155                                    channels, " channels."));
156 
157     xla::XlaBuilder* b = context->builder();
158     xla::XlaOp input = context->Input(0);
159     xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0,
160                                      /*limit_index=*/1, /*stride=*/1,
161                                      /*dimno=*/channel_dim);
162     xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1,
163                                             /*limit_index=*/2, /*stride=*/1,
164                                             /*dimno=*/channel_dim);
165     xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2,
166                                        /*limit_index=*/3, /*stride=*/1,
167                                        /*dimno=*/channel_dim);
168 
169     auto rgb = HSVToRGB(context->builder(), {hue, saturation, value},
170                         context->input_type(0));
171 
172     context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
173   }
174 };
175 REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp);
176 
177 class AdjustContrastOpV2 : public XlaOpKernel {
178  public:
AdjustContrastOpV2(OpKernelConstruction * context)179   explicit AdjustContrastOpV2(OpKernelConstruction* context)
180       : XlaOpKernel(context) {}
181 
Compile(XlaOpKernelContext * context)182   void Compile(XlaOpKernelContext* context) override {
183     const TensorShape& input_shape = context->InputShape(0);
184     const TensorShape& factor_shape = context->InputShape(1);
185     OP_REQUIRES(context, input_shape.dims() >= 3,
186                 errors::InvalidArgument("input must be at least 3-D, got shape",
187                                         input_shape.DebugString()));
188     int height_dim = input_shape.dims() - 3;
189     int width_dim = input_shape.dims() - 2;
190     int channel_dim = input_shape.dims() - 1;
191     const int64_t height = input_shape.dim_size(height_dim);
192     const int64_t width = input_shape.dim_size(width_dim);
193 
194     OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape),
195                 errors::InvalidArgument("contrast_factor must be scalar: ",
196                                         factor_shape.DebugString()));
197 
198     xla::XlaBuilder* b = context->builder();
199     DataType type = context->input_type(0);
200 
201     xla::XlaOp input = context->Input(0);
202     xla::XlaOp factor = XlaHelpers::ConvertElementType(context->Input(1), type);
203 
204     const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
205     auto converted = XlaHelpers::ConvertElementType(input, accumulation_type);
206     auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
207                               *context->GetOrCreateAdd(accumulation_type),
208                               {height_dim, width_dim});
209 
210     auto output = xla::Div(
211         reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width));
212     output = XlaHelpers::ConvertElementType(output, type);
213 
214     std::vector<int64_t> broadcast_dims(input_shape.dims() - 2);
215     std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
216     broadcast_dims.back() = channel_dim;
217     output =
218         xla::Add(xla::Mul(input, factor),
219                  xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)),
220                  broadcast_dims);
221     context->SetOutput(0, output);
222   }
223 };
224 REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2);
225 
226 class AdjustSaturationOp : public XlaOpKernel {
227  public:
AdjustSaturationOp(OpKernelConstruction * context)228   explicit AdjustSaturationOp(OpKernelConstruction* context)
229       : XlaOpKernel(context) {}
230 
Compile(XlaOpKernelContext * context)231   void Compile(XlaOpKernelContext* context) override {
232     const TensorShape& input_shape = context->InputShape(0);
233     const TensorShape& scale_shape = context->InputShape(1);
234     OP_REQUIRES(context, input_shape.dims() >= 3,
235                 errors::InvalidArgument("input must be at least 3-D, got shape",
236                                         input_shape.DebugString()));
237     OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape),
238                 errors::InvalidArgument("scale must be scalar: ",
239                                         scale_shape.DebugString()));
240     const int channel_dim = input_shape.dims() - 1;
241     const int64_t channels = input_shape.dim_size(channel_dim);
242     OP_REQUIRES(
243         context, channels == 3,
244         errors::InvalidArgument("input must have 3 channels but instead has ",
245                                 channels, " channels."));
246 
247     xla::XlaBuilder* b = context->builder();
248     xla::XlaOp input =
249         XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
250     xla::XlaOp scale =
251         XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);
252 
253     DataType type = context->input_type(0);
254 
255     xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
256                                      /*limit_index=*/1, /*stride=*/1,
257                                      /*dimno=*/channel_dim);
258     xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
259                                        /*limit_index=*/2, /*stride=*/1,
260                                        /*dimno=*/channel_dim);
261     xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
262                                       /*limit_index=*/3, /*stride=*/1,
263                                       /*dimno=*/channel_dim);
264     TensorShape channel_shape = input_shape;
265     channel_shape.set_dim(channel_dim, 1);
266     auto hsv =
267         RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);
268 
269     hsv[1] = xla::Clamp(XlaHelpers::Zero(b, DT_FLOAT), xla::Mul(hsv[1], scale),
270                         XlaHelpers::One(b, DT_FLOAT));
271 
272     auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);
273 
274     auto output = XlaHelpers::ConvertElementType(
275         xla::ConcatInDim(b, rgb, channel_dim), type);
276     context->SetOutput(0, output);
277   }
278 };
279 REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp);
280 
281 class AdjustHueOp : public XlaOpKernel {
282  public:
AdjustHueOp(OpKernelConstruction * context)283   explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
284 
Compile(XlaOpKernelContext * context)285   void Compile(XlaOpKernelContext* context) override {
286     const TensorShape& input_shape = context->InputShape(0);
287     const TensorShape& delta_shape = context->InputShape(1);
288     OP_REQUIRES(context, input_shape.dims() >= 3,
289                 errors::InvalidArgument("input must be at least 3-D, got shape",
290                                         input_shape.DebugString()));
291     OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape),
292                 errors::InvalidArgument("delta must be scalar: ",
293                                         delta_shape.DebugString()));
294     const int channel_dim = input_shape.dims() - 1;
295     const int64_t channels = input_shape.dim_size(channel_dim);
296     OP_REQUIRES(
297         context, channels == 3,
298         errors::InvalidArgument("input must have 3 channels but instead has ",
299                                 channels, " channels."));
300 
301     xla::XlaBuilder* b = context->builder();
302     xla::XlaOp input =
303         XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT);
304     xla::XlaOp delta =
305         XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT);
306 
307     DataType type = context->input_type(0);
308 
309     xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
310                                      /*limit_index=*/1, /*stride=*/1,
311                                      /*dimno=*/channel_dim);
312     xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
313                                        /*limit_index=*/2, /*stride=*/1,
314                                        /*dimno=*/channel_dim);
315     xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
316                                       /*limit_index=*/3, /*stride=*/1,
317                                       /*dimno=*/channel_dim);
318     TensorShape channel_shape = input_shape;
319     channel_shape.set_dim(channel_dim, 1);
320     auto hsv =
321         RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape);
322 
323     auto zero = XlaHelpers::Zero(b, DT_FLOAT);
324     auto one = XlaHelpers::One(b, DT_FLOAT);
325 
326     auto& hue = hsv[0];
327     hue = xla::Rem(xla::Add(hsv[0], delta), one);
328     hue =
329         xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue);
330 
331     auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT);
332 
333     auto output = XlaHelpers::ConvertElementType(
334         xla::ConcatInDim(b, rgb, channel_dim), type);
335     context->SetOutput(0, output);
336   }
337 };
338 REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);
339 
340 struct WhileCondFn {
341   const int64_t num_boxes;
342   const int64_t output_size;
343 
WhileCondFntensorflow::__anonabc971000111::WhileCondFn344   explicit WhileCondFn(int64_t num_boxes, int64_t output_size)
345       : num_boxes(num_boxes), output_size(output_size) {}
346 
operator ()tensorflow::__anonabc971000111::WhileCondFn347   StatusOr<xla::XlaOp> operator()(absl::Span<const xla::XlaOp> values,
348                                   xla::XlaBuilder* cond_builder) const {
349     xla::XlaOp row_idx = values[0];
350     xla::XlaOp row_in_bounds =
351         xla::Lt(row_idx, xla::ConstantR0<int32>(cond_builder, num_boxes));
352     xla::XlaOp num_outputs_so_far = values[1];
353     xla::XlaOp results_not_full = xla::Lt(
354         num_outputs_so_far, xla::ConstantR0<int32>(cond_builder, output_size));
355     return xla::And(row_in_bounds, results_not_full);
356   }
357 };
358 
359 // Process the boxes one-by-one using the iou matrix mask.
360 // This implementation uses a correct, but greedy, sequential algorithm
361 // to ensure that suppressed boxes cannot themselves suppress other
362 // boxes.
363 struct SuppressBodyFn {
364   const int64_t num_boxes;
365 
SuppressBodyFntensorflow::__anonabc971000111::SuppressBodyFn366   explicit SuppressBodyFn(int64_t num_boxes) : num_boxes(num_boxes) {}
367 
operator ()tensorflow::__anonabc971000111::SuppressBodyFn368   StatusOr<std::vector<xla::XlaOp>> operator()(
369       absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) const {
370     auto row_idx = values[0];
371     auto num_outputs_so_far = values[1];
372     auto iou_mask = values[2];
373     auto included_iou = values[3];
374     auto zero = xla::ConstantR0<int32>(builder, 0);
375     // Determine if current elem is active using a slice.
376     // TODO(b/118437727): The only reason we need an explicit vector is because
377     // some old GCCs can't deduce the right type for MakeConstSpan, and
378     // providing a single-value initializer list directly uses the wrong
379     // overload. Delete this once the deprecated overload is gone.
380     std::vector<xla::XlaOp> row_idx_vector = {row_idx};
381     auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1});
382     active_elem = xla::Reshape(active_elem, {});
383     // Increment output count iff current elem is not suppressed.
384     num_outputs_so_far = xla::Select(
385         active_elem, num_outputs_so_far + xla::ConstantR0<int32>(builder, 1),
386         num_outputs_so_far);
387     // Slice out the row_idx.
388     auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes});
389     // Remove the diagonal from consideration. An elem cannot suppress
390     // itself.
391     row_iou = xla::DynamicUpdateSlice(
392         row_iou, xla::ConstantR2FromArray2D<bool>(builder, {{false}}),
393         {zero, row_idx});
394     // Create a suppression by inverting polarity.
395     row_iou = xla::Reshape(row_iou, {num_boxes});
396     auto supp_mask = xla::Not(row_iou);
397     // Update mask iff current elem is not suppressed.
398     included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}),
399                                xla::And(included_iou, supp_mask), included_iou);
400     row_idx = row_idx + xla::ConstantR0<int32>(builder, 1);
401     return std::vector<xla::XlaOp>{row_idx, num_outputs_so_far, iou_mask,
402                                    included_iou};
403   }
404 };
405 
406 class NonMaxSuppressionOp : public XlaOpKernel {
407  public:
NonMaxSuppressionOp(OpKernelConstruction * context)408   explicit NonMaxSuppressionOp(OpKernelConstruction* context)
409       : XlaOpKernel(context) {
410     OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
411                                              &pad_to_max_output_size_));
412   }
413 
Compile(XlaOpKernelContext * context)414   void Compile(XlaOpKernelContext* context) override {
415     // TODO(b/111646731): Improve scalability of this op, using blocking.
416     OP_REQUIRES(context, pad_to_max_output_size_,
417                 errors::Unimplemented(
418                     "XLA compilation requires pad_to_max_output_size == True"));
419 
420     xla::XlaOp selected_indices, num_valid;
421     ComputeResult(context, pad_to_max_output_size_);
422   }
ComputeResult(XlaOpKernelContext * context,bool pad_to_max_output_size=false)423   static void ComputeResult(XlaOpKernelContext* context,
424                             bool pad_to_max_output_size = false) {
425     const TensorShape& boxes_shape = context->InputShape("boxes");
426     OP_REQUIRES(
427         context, TensorShapeUtils::IsMatrix(boxes_shape),
428         errors::InvalidArgument("boxes must be 2-D, currently: [",
429                                 std::to_string(boxes_shape.dim_size(0)), ",",
430                                 std::to_string(boxes_shape.dim_size(1)), "]"));
431     const int64_t num_boxes = boxes_shape.dim_size(0);
432     OP_REQUIRES(
433         context, boxes_shape.dim_size(1) == 4,
434         errors::InvalidArgument("boxes must have 4 columns, currently: ",
435                                 std::to_string(boxes_shape.dim_size(1))));
436     const TensorShape& scores_shape = context->InputShape("scores");
437     OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape),
438                 errors::InvalidArgument("scores must be 1-D, currently: ",
439                                         scores_shape.DebugString()));
440     OP_REQUIRES(context, scores_shape.dim_size(0) == num_boxes,
441                 errors::InvalidArgument(
442                     "scores size ", std::to_string(scores_shape.dim_size(0)),
443                     " must equal number of boxes ", std::to_string(num_boxes)));
444     OP_REQUIRES(context, num_boxes <= kint32max,
445                 errors::InvalidArgument("XLA compilation requires number of "
446                                         "boxes to be <= kint32max, got ",
447                                         num_boxes));
448     xla::PrimitiveType boxes_xla_type = context->InputXlaType("boxes");
449     xla::PrimitiveType scores_xla_type = context->InputXlaType("scores");
450     const xla::XlaOp boxes_input = context->Input("boxes");
451     const xla::XlaOp scores_input = context->Input("scores");
452     int64_t output_size;
453     OP_REQUIRES(
454         context,
455         TensorShapeUtils::IsScalar(context->InputShape("max_output_size")),
456         errors::InvalidArgument("Max Output Size isn't a scalar"));
457     OP_REQUIRES(
458         context,
459         TensorShapeUtils::IsScalar(context->InputShape("iou_threshold")),
460         errors::InvalidArgument("IOU Threshold isn't a scalar"));
461     OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size));
462     OP_REQUIRES(
463         context, output_size >= 0,
464         errors::InvalidArgument("Need output_size >= 0, got ", output_size));
465     OP_REQUIRES(context, output_size <= kint32max,
466                 errors::InvalidArgument("Need output_size <= kint32Max, got ",
467                                         output_size));
468     const xla::XlaOp score_thresh = context->Input("score_threshold");
469     const xla::XlaOp iou_thresh = context->Input("iou_threshold");
470     xla::XlaBuilder* const builder = context->builder();
471 
472     // Choose a more convenient layout.
473     const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0});
474     const xla::XlaOp boxes_sorted = xla::GetTupleElement(
475         xla::Sort({xla::Broadcast(scores_input, {4}), boxes},
476                   xla::CreateScalarGtComputation(
477                       {scores_xla_type, boxes_xla_type}, builder),
478                   /*dimension=*/1),
479         1);
480     // Track the mapping of indices into sorted domain.
481     const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes);
482     const xla::XlaOp indices_sort = xla::Sort(
483         {scores_input, iota_indices},
484         xla::CreateScalarGtComputation({scores_xla_type, xla::S32}, builder));
485     const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1);
486     const xla::XlaOp scores = xla::GetTupleElement(indices_sort, 0);
487 
488     // Shapes are henceforth [1, num_boxes]. 'c_y0' denotes 'coordinate' y0.
489     const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
490                                                          /*start_index=*/0,
491                                                          /*limit_index=*/1,
492                                                          /*stride=*/1,
493                                                          /*dimno=*/0),
494                                          {num_boxes});
495     const xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted,
496                                                          /*start_index=*/1,
497                                                          /*limit_index=*/2,
498                                                          /*stride=*/1,
499                                                          /*dimno=*/0),
500                                          {num_boxes});
501     const xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
502                                                          /*start_index=*/2,
503                                                          /*limit_index=*/3,
504                                                          /*stride=*/1,
505                                                          /*dimno=*/0),
506                                          {num_boxes});
507     const xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted,
508                                                          /*start_index=*/3,
509                                                          /*limit_index=*/4,
510                                                          /*stride=*/1,
511                                                          /*dimno=*/0),
512                                          {num_boxes});
513 
514     xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1);
515     xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0);
516     xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1);
517     xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0);
518     xla::XlaOp area = (y2 - y1) * (x2 - x1);
519 
520     // Shapes are henceforth [1, num_boxes].
521     y1 = xla::Broadcast(y1, {1});
522     y2 = xla::Broadcast(y2, {1});
523     x1 = xla::Broadcast(x1, {1});
524     x2 = xla::Broadcast(x2, {1});
525     area = xla::Broadcast(area, {1});
526 
527     // Shapes are henceforth [num_boxes, num_boxes].
528     xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0}));
529     xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0}));
530     xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0}));
531     xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0}));
532     auto square_zero = xla::ZerosLike(i_xmin);
533 
534     xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) *
535                         xla::Max(i_ymax - i_ymin, square_zero);
536     xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area;
537     xla::XlaOp iou = i_area / u_area;
538 
539     xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero);
540     xla::XlaOp included_iou =
541         xla::Broadcast(xla::ConstantR0<bool>(builder, true), {num_boxes});
542 
543     std::vector<xla::XlaOp> init_values;
544     init_values.reserve(4);
545     init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // col_idx
546     init_values.push_back(xla::ConstantR0<int32>(builder, 0));  // num_outputs
547     init_values.push_back(iou_thresh_mask);
548     init_values.push_back(included_iou);
549 
550     auto suppress_loop_result =
551         xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
552                              SuppressBodyFn(num_boxes), init_values,
553                              "suppress_loop", builder)
554             .ValueOrDie();
555 
556     xla::XlaOp included_score =
557         xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes}));
558     xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]);
559 
560     // Only consider boxes over which we have iterated. This allows for accurate
561     // counting. DynamicSlice would require knowledge of the size of the output.
562     auto valid_elem = xla::Lt(
563         iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes}));
564     included = xla::And(included, valid_elem);
565 
566     xla::XlaOp neg_inf =
567         xla::Broadcast(xla::MinValue(builder, boxes_xla_type), {num_boxes});
568     xla::XlaOp scores_included = xla::Select(included, scores, neg_inf);
569     xla::XlaOp output_tuple = TopK(scores_included, output_size);
570     xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1);
571     // Calculate num_valid.
572     // Note: num_valid cannot be taken from the loop outputs, because outputs
573     // can be suppressed by score threshold.
574     xla::XlaOp ones_included = xla::Select(
575         included,
576         xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
577         xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
578     // num_valid is scalar. Value should be bound by output_size.
579 
580     xla::XlaOp num_valid_total = xla::Reduce(
581         ones_included,
582         /*init_value=*/xla::ConstantR0<int>(builder, 0),
583         /*computation=*/CreateScalarAddComputation(xla::S32, builder),
584         /*dimensions_to_reduce=*/{0});
585     xla::XlaOp num_valid =
586         xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));
587 
588     // Re-index into the original scores input tensor, using a Gather.
589     // Boxes were suppressed in the sorted domain.
590     xla::XlaOp selected_indices;
591     DataType gather_type = context->expected_output_dtype(0);
592     OP_REQUIRES_OK(
593         context,
594         XlaGather(indices_sorted, scores_shape, selected_indices_sorted,
595                   TensorShape({output_size}),
596                   /*axis=*/0,
597                   /*indices_are_nd=*/false,
598                   /*dtype=*/gather_type, DT_INT32, builder, &selected_indices));
599 
600     if (!pad_to_max_output_size) {
601       StatusOr<xla::XlaOp> rebounded_result = xla::SetDimensionSizeWithRebound(
602           &context->value_inference(), selected_indices, num_valid, 0);
603       if (rebounded_result.ok()) {
604         selected_indices = *rebounded_result;
605       } else {
606         // TODO(b/207187072): Remove special handling once dynamic reshape
607         // can also be handled.
608         selected_indices =
609             xla::SetDimensionSize(selected_indices, num_valid, 0);
610       }
611     }
612     context->SetOutput(0, selected_indices);
613     if (pad_to_max_output_size) context->SetOutput(1, num_valid);
614   }
615 
616  private:
617   bool pad_to_max_output_size_;
618 };
619 
620 REGISTER_XLA_OP(
621     Name("NonMaxSuppressionV4").CompileTimeConstantInput("max_output_size"),
622     NonMaxSuppressionOp);
623 
624 class NonMaxSuppressionV3Op : public XlaOpKernel {
625  public:
NonMaxSuppressionV3Op(OpKernelConstruction * context)626   explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
627       : XlaOpKernel(context) {}
628 
Compile(XlaOpKernelContext * context)629   void Compile(XlaOpKernelContext* context) override {
630     xla::XlaOp selected_indices, num_valid;
631     NonMaxSuppressionOp::ComputeResult(context);
632   }
633 };
634 
635 REGISTER_XLA_OP(
636     Name("NonMaxSuppressionV3").CompileTimeConstantInput("max_output_size"),
637     NonMaxSuppressionV3Op);
638 
639 }  // namespace
640 }  // namespace tensorflow
641