xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/kernels/unique_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <sys/types.h>

#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

class UniqueOpBase : public XlaOpKernel {
 public:
  explicit UniqueOpBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    DataType dtype;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_idx", &dtype));
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &idx_type_));
  }

  // Transpose a tensor by swapping axes `from` and `to`.
  xla::XlaOp MoveAxis(xla::XlaOp a, int64_t from, int64_t to,
                      const xla::Shape& input_shape) {
    std::vector<int64_t> permutation;
    permutation.reserve(input_shape.rank());
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      permutation.push_back(i);
    }
    std::swap(permutation[from], permutation[to]);
    return xla::Transpose(a, permutation);
  }

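  // Computes an inclusive prefix sum of an R1 S32 `input` of length `size`,
  // via a reduce-window of width `size` with low padding `size - 1`, so that
  // output[i] = input[0] + ... + input[i].
  // For example, [1, 0, 1, 1] yields [1, 1, 2, 3].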
  xla::XlaOp CumSumR1(XlaOpKernelContext* ctx, xla::XlaOp input, int64_t size) {
    auto init = xla::Zero(ctx->builder(), xla::S32);
    auto reducer = xla::CreateScalarAddComputation(xla::S32, ctx->builder());

    return xla::ReduceWindowWithGeneralPadding(
        input, init, reducer, {size}, {1},
        /*base_dilations=*/{}, /*window_dilations=*/{}, {{size - 1, 0}});
  }

  // RollingSelectR1 takes two arrays, `data` and `mask`. It scans these two
  // arrays in parallel and accumulates outputs into `accum`.
  //
  // For each position i, accum[i] = data[i] if mask[i] = 1,
  // or accum[i - 1] if mask[i] = 0.
  //
  // Requires mask[0] = 1, so that accum[-1] is never accessed.
  //
  // This is implemented as an HLO while loop.
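  // For example, data = [5, 7, 2, 9] and mask = [1, 0, 1, 0] produce
  // accum = [5, 5, 2, 2].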
  xla::XlaOp RollingSelectR1(XlaOpKernelContext* ctx, xla::XlaOp data,
                             xla::XlaOp mask, int64_t size) {
    xla::XlaComputation cond, body;
    const xla::Shape r1_shape = xla::ShapeUtil::MakeShape(xla::S32, {size});
    const xla::Shape counter_shape = xla::ShapeUtil::MakeScalarShape(xla::S32);
    const xla::Shape& single_element_shape = counter_shape;

    auto loop_shape = xla::ShapeUtil::MakeTupleShape(
        {counter_shape, r1_shape, r1_shape, r1_shape});
    {
      std::unique_ptr<xla::XlaBuilder> builder =
          ctx->builder()->CreateSubBuilder("loop_cond");
      auto param = xla::Parameter(builder.get(), 0, loop_shape, "param");
      auto counter = xla::GetTupleElement(param, 0);
      auto limit = xla::ConstantR0<int32_t>(builder.get(), size);
      xla::Lt(counter, limit);

      cond = builder->Build().value();
    }

    {
      std::unique_ptr<xla::XlaBuilder> builder =
          ctx->builder()->CreateSubBuilder("loop_body");
      auto param = xla::Parameter(builder.get(), 0, loop_shape, "param");
      auto counter = xla::GetTupleElement(param, 0);

      auto data_stack = xla::GetTupleElement(param, 1);
      auto data = xla::DynamicSlice(data_stack, {counter}, {1});
      data = xla::Reshape(single_element_shape, data);

      auto mask_stack = xla::GetTupleElement(param, 2);
      auto mask = xla::DynamicSlice(mask_stack, {counter}, {1});
      mask = xla::Reshape(single_element_shape, mask);

      auto counter_minus = counter - xla::One(builder.get(), xla::S32);
      // If counter = 0, then counter_minus = 0.
      auto zero = xla::Zero(builder.get(), xla::S32);
      counter_minus = xla::Select(xla::Eq(counter, zero), zero, counter_minus);

      auto accum_stack = xla::GetTupleElement(param, 3);
      auto accum_minus = xla::DynamicSlice(accum_stack, {counter_minus}, {1});
      accum_minus = xla::Reshape(single_element_shape, accum_minus);

      auto accum = xla::Select(xla::ConvertElementType(mask, xla::PRED), data,
                               accum_minus);
      accum_stack = xla::DynamicUpdateSlice(
          accum_stack, xla::Reshape(accum, {1}), {counter});
      counter = counter + xla::One(builder.get(), xla::S32);

      xla::Tuple(builder.get(), {counter, data_stack, mask_stack, accum_stack});
      body = builder->Build().value();
    }

    auto zero = xla::Zero(ctx->builder(), xla::S32);
    auto zero_broadcast = xla::Broadcast(zero, {size});
    auto init = xla::Tuple(ctx->builder(), {zero, data, mask, zero_broadcast});
    return xla::GetTupleElement(xla::While(cond, body, init), 3);
  }

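  // Implements Unique along `axis`:
  //  1. Move `axis` to the front and flatten the remaining dimensions,
  //     producing a 2-D array of rows to uniquify.
  //  2. Stably sort the rows lexicographically, carrying an iota along to
  //     recover the permutation.
  //  3. Mark the start of each run of equal rows, then rearrange so the first
  //     occurrence of every unique row comes first, ordered by first
  //     appearance in the input.
  //  4. Gather the unique rows, restore the original shape with the
  //     uniquified dimension marked dynamic, and derive per-element indices
  //     from a prefix sum over the run-start mask.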
  void CompileWithAxis(XlaOpKernelContext* ctx, int64_t axis) {
    xla::XlaOp input = ctx->Input(0);
    StatusOr<xla::Shape> input_shape_or = ctx->builder()->GetShape(input);
    OP_REQUIRES_OK(ctx, input_shape_or.status());
    auto input_shape = input_shape_or.ValueOrDie();
    auto aux = MoveAxis(input, axis, 0, input_shape);
    auto aux_shape = ctx->builder()->GetShape(aux).ValueOrDie();
    int64_t leading_size = aux_shape.dimensions(0);
    int64_t product = 1;
    for (int64_t i = 1; i < aux_shape.rank(); ++i) {
      product *= aux_shape.dimensions(i);
    }
    aux = xla::Reshape(aux, {leading_size, product});
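    // Degenerate case: an empty axis has no duplicates, so return the input
    // unchanged together with an empty index vector.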
    if (leading_size == 0) {
      auto result_data = xla::Reshape(aux, aux_shape.dimensions());
      result_data = MoveAxis(result_data, 0, axis, aux_shape);
      ctx->SetOutput(0, result_data);
      ctx->SetOutput(1, xla::Iota(ctx->builder(), xla::S32, leading_size));
      return;
    }
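    // Build one sort key per column plus a trailing iota, so a single stable
    // lexicographic sort both orders the rows and yields the permutation that
    // produced that order.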
    std::vector<xla::XlaOp> sort_keys;
    sort_keys.reserve(product + 1);
    std::vector<xla::PrimitiveType> sort_types;
    sort_types.reserve(product + 1);
    for (int64_t i = 0; i < product; ++i) {
      xla::XlaOp slice = xla::SliceInDim(aux, i, i + 1, 1, 1);
      sort_keys.push_back(xla::Reshape(slice, {leading_size}));
      sort_types.push_back(input_shape.element_type());
    }
    auto iota = xla::Iota(ctx->builder(), xla::S32, leading_size);
    sort_keys.push_back(iota);
    sort_types.push_back(xla::S32);

    std::vector<std::optional<xla::XlaOp (*)(xla::XlaOp, xla::XlaOp,
                                             absl::Span<const int64_t>)>>
        generators(sort_types.size(), xla::LtTotalOrder);
    auto lt_chain = xla::CreateScalarComparisonComputation(
        "UniqueV2Lt", sort_types, generators, ctx->builder());

    auto sorted = xla::Sort(sort_keys, lt_chain, 0, /*is_stable=*/true);
    // The iota is the last sort key, so the last element of the result is the
    // permutation. When the minor dimension is empty (product == 0), the iota
    // is the only key and Sort returns the array itself rather than a tuple.
    xla::XlaOp perm;
    if (sort_keys.size() == 1) {
      perm = sorted;
    } else {
      perm = xla::GetTupleElement(sorted, sort_keys.size() - 1);
    }

    // Use gather to reorder the rows according to `perm`.
    xla::GatherDimensionNumbers gather_dim_numbers;
    gather_dim_numbers.add_offset_dims(1);
    // The dimension to rewrite is the index dim.
    gather_dim_numbers.add_start_index_map(0);
    gather_dim_numbers.set_index_vector_dim(1);
    gather_dim_numbers.add_collapsed_slice_dims(0);
    auto permuted = xla::Gather(aux, perm, gather_dim_numbers, {1, product});
    // Tail is everything except the first element.
    auto tail = xla::SliceInDim(permuted, 1, leading_size, 1, 0);
    // Init is everything except the last element.
    auto init = xla::SliceInDim(permuted, 0, leading_size - 1, 1, 0);
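    // ne[i] is true iff sorted row i + 1 differs from sorted row i in any
    // column.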
    auto ne = xla::Compare(tail, init, xla::ComparisonDirection::kNe);
    auto reduce =
        xla::Reduce(ne, xla::ConstantR0(ctx->builder(), false),
                    CreateScalarOrComputation(xla::PRED, ctx->builder()), {1});
    auto mask = xla::ConvertElementType(reduce, xla::S32);
    mask = xla::PadInDim(mask, xla::One(ctx->builder(), xla::S32), 0, 1, 0);
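    // Ties are broken by the iota key, so each run of equal rows starts with
    // its smallest original index; iperm[i] is therefore the original index
    // of the first occurrence of the value at sorted position i.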
    auto iperm = RollingSelectR1(ctx, perm, mask, leading_size);

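    // Sorting by iperm groups duplicates together, with the groups ordered by
    // each value's first appearance in the input.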
    auto sort_by_iperm =
        xla::Sort({iperm, mask, perm},
                  xla::CreateScalarLtComputation({xla::S32, xla::S32, xla::S32},
                                                 ctx->builder()),
                  0,
                  /*is_stable=*/true);
    mask = xla::GetTupleElement(sort_by_iperm, 1);
    // perm_sort is used later to revert the indices back to input order.
    auto perm_sort = xla::GetTupleElement(sort_by_iperm, 2);

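    // Each unique value contributes exactly one run start, so summing the
    // mask counts the unique elements along the axis.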
    auto dynamic_size = xla::ReduceAll(
        mask, xla::Zero(ctx->builder(), xla::S32),
        xla::CreateScalarAddComputation(xla::S32, ctx->builder()));
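    // Sort by mask in descending order: the stable sort moves the first
    // occurrence of every unique value to the front while preserving
    // first-appearance order.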
    auto mask_sort = xla::Sort(
        {mask, perm_sort},
        xla::CreateScalarGtComputation({xla::S32, xla::S32}, ctx->builder()), 0,
        /*is_stable=*/true);
    auto mask_permute = xla::GetTupleElement(mask_sort, 1);
    permuted = xla::Gather(aux, mask_permute, gather_dim_numbers, {1, product});
    auto result_data = xla::Reshape(permuted, aux_shape.dimensions());
    result_data = MoveAxis(result_data, 0, axis, aux_shape);
    result_data = xla::SetDimensionSize(result_data, dynamic_size, axis);
    ctx->SetOutput(0, result_data);
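    // A prefix sum over the run-start mask assigns every element the
    // zero-based id of its unique value; sorting those ids by perm_sort puts
    // them back in input order to form the `idx` output.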
    auto imask = CumSumR1(ctx, mask, leading_size);
    imask = xla::Sub(imask, xla::One(ctx->builder(), xla::S32), {});
    auto idx = xla::GetTupleElement(
        xla::Sort({perm_sort, imask},
                  xla::CreateScalarLtComputation({xla::S32, xla::S32},
                                                 ctx->builder())),
        1);
    idx = xla::ConvertElementType(idx, idx_type_);
    ctx->SetOutput(1, idx);
  }

 private:
  xla::PrimitiveType idx_type_;
};

class UniqueOp : public UniqueOpBase {
 public:
  explicit UniqueOp(OpKernelConstruction* ctx) : UniqueOpBase(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    CompileWithAxis(ctx, /*axis=*/0);
  }
};

REGISTER_XLA_OP(Name("Unique"), UniqueOp);

class UniqueV2Op : public UniqueOpBase {
 public:
  explicit UniqueV2Op(OpKernelConstruction* ctx) : UniqueOpBase(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    std::vector<int64_t> axes;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes));
    OP_REQUIRES(
        ctx, axes.size() <= 1,
        xla::InvalidArgument("Only single axis unique op is supported"));
    int64_t axis;
    if (axes.empty()) {
      axis = 0;
    } else {
      axis = axes.front();
    }
    CompileWithAxis(ctx, /*axis=*/axis);
  }
};

REGISTER_XLA_OP(Name("UniqueV2").CompileTimeConstantInput("axis"), UniqueV2Op);

}  // namespace
}  // namespace tensorflow