/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <sys/types.h>

#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

class UniqueOpBase : public XlaOpKernel {
 public:
  explicit UniqueOpBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    DataType dtype;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_idx", &dtype));
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &idx_type_));
  }

  // Transpose a tensor by moving axis `from` into `to`.
  xla::XlaOp MoveAxis(xla::XlaOp a, int64_t from, int64_t to,
                      const xla::Shape& input_shape) {
    std::vector<int64_t> permutation;
    permutation.reserve(input_shape.rank());
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      permutation.push_back(i);
    }
    std::swap(permutation[from], permutation[to]);
    return xla::Transpose(a, permutation);
  }
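  // CumSumR1 (below) computes an inclusive prefix sum of a rank-1 S32 array:
  // each reduce-window spans `size` elements with `size - 1` elements of zero
  // low padding and a stride of 1, so output position i sums input[0..i].
  // Illustrative example (values invented for this comment):
  //   input = [1, 0, 1, 1], size = 4  =>  result = [1, 1, 2, 3]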
  xla::XlaOp CumSumR1(XlaOpKernelContext* ctx, xla::XlaOp input, int64_t size) {
    auto init = xla::Zero(ctx->builder(), xla::S32);
    auto reducer = xla::CreateScalarAddComputation(xla::S32, ctx->builder());

    return xla::ReduceWindowWithGeneralPadding(
        input, init, reducer, {size}, {1},
        /*base_dilations=*/{}, /*window_dilations=*/{}, {{size - 1, 0}});
  }

  // RollingSelectR1 takes two arrays, `data` and `mask`. It scans these two
  // arrays in parallel and accumulates outputs into `accum`:
  //
  //   accum[i] = data[i]      if mask[i] = 1
  //            = accum[i - 1] if mask[i] = 0
  //
  // Requires mask[0] = 1, so that an out-of-range accum[i - 1] is never
  // accessed.
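  //
  // Illustrative example (values invented for this comment):
  //   data = [5, 7, 9, 3], mask = [1, 0, 1, 0]  =>  accum = [5, 5, 9, 9]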
  //
  // This is implemented as an HLO while loop.
  xla::XlaOp RollingSelectR1(XlaOpKernelContext* ctx, xla::XlaOp data,
                             xla::XlaOp mask, int64_t size) {
    xla::XlaComputation cond, body;
    const xla::Shape r1_shape = xla::ShapeUtil::MakeShape(xla::S32, {size});
    const xla::Shape counter_shape = xla::ShapeUtil::MakeScalarShape(xla::S32);
    const xla::Shape& single_element_shape = counter_shape;

    auto loop_shape = xla::ShapeUtil::MakeTupleShape(
        {counter_shape, r1_shape, r1_shape, r1_shape});
    {
      std::unique_ptr<xla::XlaBuilder> builder =
          ctx->builder()->CreateSubBuilder("loop_cond");
      auto param = xla::Parameter(builder.get(), 0, loop_shape, "param");
      auto counter = xla::GetTupleElement(param, 0);
      auto limit = xla::ConstantR0<int32_t>(builder.get(), size);
      xla::Lt(counter, limit);

      cond = builder->Build().value();
    }

    {
      std::unique_ptr<xla::XlaBuilder> builder =
          ctx->builder()->CreateSubBuilder("loop_body");
      auto param = xla::Parameter(builder.get(), 0, loop_shape, "param");
      auto counter = xla::GetTupleElement(param, 0);

      auto data_stack = xla::GetTupleElement(param, 1);
      auto data = xla::DynamicSlice(data_stack, {counter}, {1});
      data = xla::Reshape(single_element_shape, data);

      auto mask_stack = xla::GetTupleElement(param, 2);
      auto mask = xla::DynamicSlice(mask_stack, {counter}, {1});
      mask = xla::Reshape(single_element_shape, mask);

      auto counter_minus = counter - xla::One(builder.get(), xla::S32);
      // If counter = 0, then counter_minus = 0.
      auto zero = xla::Zero(builder.get(), xla::S32);
      counter_minus = xla::Select(xla::Eq(counter, zero), zero, counter_minus);

      auto accum_stack = xla::GetTupleElement(param, 3);
      auto accum_minus = xla::DynamicSlice(accum_stack, {counter_minus}, {1});
      accum_minus = xla::Reshape(single_element_shape, accum_minus);

      auto accum = xla::Select(xla::ConvertElementType(mask, xla::PRED), data,
                               accum_minus);
      accum_stack = xla::DynamicUpdateSlice(
          accum_stack, xla::Reshape(accum, {1}), {counter});
      counter = counter + xla::One(builder.get(), xla::S32);

      xla::Tuple(builder.get(), {counter, data_stack, mask_stack, accum_stack});
      body = builder->Build().value();
    }

    auto zero = xla::Zero(ctx->builder(), xla::S32);
    auto zero_broadcast = xla::Broadcast(zero, {size});
    auto init = xla::Tuple(ctx->builder(), {zero, data, mask, zero_broadcast});
    return xla::GetTupleElement(xla::While(cond, body, init), 3);
  }
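  // Computes the unique values of the input along `axis`, plus an index
  // mapping from each input element to its position in the output. A sketch
  // of the approach used below:
  //   1. Move `axis` to the front and flatten the rest, giving a
  //      [leading_size, product] array.
  //   2. Stably sort the leading dimension (carrying an iota along as the
  //      permutation `perm`), so duplicates become adjacent.
  //   3. Compare adjacent rows to build a 0/1 `mask` marking the first
  //      occurrence of each distinct value.
  //   4. Use RollingSelectR1 and CumSumR1, plus further sorts, to order the
  //      unique values by first occurrence and to map every input element to
  //      its unique value's output position.
  // Illustrative example following tf.unique semantics (values invented for
  // this comment): input = [4, 5, 1, 2, 1, 3] should yield
  // output = [4, 5, 1, 2, 3] and idx = [0, 1, 2, 3, 2, 4].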
  void CompileWithAxis(XlaOpKernelContext* ctx, int64_t axis) {
    xla::XlaOp input = ctx->Input(0);
    StatusOr<xla::Shape> input_shape_or = ctx->builder()->GetShape(input);
    OP_REQUIRES_OK(ctx, input_shape_or.status());
    auto input_shape = input_shape_or.ValueOrDie();
    auto aux = MoveAxis(input, axis, 0, input_shape);
    auto aux_shape = ctx->builder()->GetShape(aux).ValueOrDie();
    int64_t leading_size = aux_shape.dimensions(0);
    int64_t product = 1;
    for (int64_t i = 1; i < aux_shape.rank(); ++i) {
      product *= aux_shape.dimensions(i);
    }
    aux = xla::Reshape(aux, {leading_size, product});
    if (leading_size == 0) {
      auto result_data = xla::Reshape(aux, aux_shape.dimensions());
      result_data = MoveAxis(result_data, 0, axis, aux_shape);
      ctx->SetOutput(0, result_data);
      ctx->SetOutput(1, xla::Iota(ctx->builder(), xla::S32, leading_size));
      return;
    }
    std::vector<xla::XlaOp> sort_keys;
    sort_keys.reserve(product + 1);
    std::vector<xla::PrimitiveType> sort_types;
    sort_types.reserve(product + 1);
    for (int64_t i = 0; i < product; ++i) {
      xla::XlaOp slice = xla::SliceInDim(aux, i, i + 1, 1, 1);
      sort_keys.push_back(xla::Reshape(slice, {leading_size}));
      sort_types.push_back(input_shape.element_type());
    }
    auto iota = xla::Iota(ctx->builder(), xla::S32, leading_size);
    sort_keys.push_back(iota);
    sort_types.push_back(xla::S32);

    std::vector<std::optional<xla::XlaOp (*)(xla::XlaOp, xla::XlaOp,
                                             absl::Span<const int64_t>)>>
        generators(sort_types.size(), xla::LtTotalOrder);
    auto lt_chain = xla::CreateScalarComparisonComputation(
        "UniqueV2Lt", sort_types, generators, ctx->builder());

    auto sorted = xla::Sort(sort_keys, lt_chain, 0, /*is_stable=*/true);
    // The last element is the permutation.
    xla::XlaOp perm;
    if (sort_keys.size() == 1) {
      perm = sorted;
    } else {
      perm = xla::GetTupleElement(sorted, sort_keys.size() - 1);
    }

    // Use gather to rearrange the minor dimension.
    xla::GatherDimensionNumbers gather_dim_numbers;
    gather_dim_numbers.add_offset_dims(1);
    // The dimension to rewrite is the index dim.
    gather_dim_numbers.add_start_index_map(0);
    gather_dim_numbers.set_index_vector_dim(1);
    gather_dim_numbers.add_collapsed_slice_dims(0);
    auto permuted = xla::Gather(aux, perm, gather_dim_numbers, {1, product});
    // The tail is everything except the first element.
    auto tail = xla::SliceInDim(permuted, 1, leading_size, 1, 0);
    // `init` is everything except the last element.
    auto init = xla::SliceInDim(permuted, 0, leading_size - 1, 1, 0);
    auto ne = xla::Compare(tail, init, xla::ComparisonDirection::kNe);
    auto reduce =
        xla::Reduce(ne, xla::ConstantR0(ctx->builder(), false),
                    CreateScalarOrComputation(xla::PRED, ctx->builder()), {1});
    auto mask = xla::ConvertElementType(reduce, xla::S32);
    mask = xla::PadInDim(mask, xla::One(ctx->builder(), xla::S32), 0, 1, 0);
    // iperm: for each sorted position, the original index of the first
    // occurrence (in input order) of that position's value.
    auto iperm = RollingSelectR1(ctx, perm, mask, leading_size);

    // Sorting by iperm groups duplicates together and orders the groups by
    // their first occurrence in the input.
    auto sort_by_iperm =
        xla::Sort({iperm, mask, perm},
                  xla::CreateScalarLtComputation({xla::S32, xla::S32, xla::S32},
                                                 ctx->builder()),
                  0,
                  /*is_stable=*/true);
    mask = xla::GetTupleElement(sort_by_iperm, 1);
    // perm_sort is used later to revert the indices back to input order.
    auto perm_sort = xla::GetTupleElement(sort_by_iperm, 2);

    auto dynamic_size = xla::ReduceAll(
        mask, xla::Zero(ctx->builder(), xla::S32),
        xla::CreateScalarAddComputation(xla::S32, ctx->builder()));
    // Sort the mask in descending order so the first-occurrence rows move to
    // the front.
    auto mask_sort = xla::Sort(
        {mask, perm_sort},
        xla::CreateScalarGtComputation({xla::S32, xla::S32}, ctx->builder()), 0,
        /*is_stable=*/true);
    auto mask_permute = xla::GetTupleElement(mask_sort, 1);
    permuted = xla::Gather(aux, mask_permute, gather_dim_numbers, {1, product});
    auto result_data = xla::Reshape(permuted, aux_shape.dimensions());
    result_data = MoveAxis(result_data, 0, axis, aux_shape);
    result_data = xla::SetDimensionSize(result_data, dynamic_size, axis);
    ctx->SetOutput(0, result_data);
    // cumsum(mask) - 1 numbers the groups, giving each element the output
    // position of its unique value.
    auto imask = CumSumR1(ctx, mask, leading_size);
    imask = xla::Sub(imask, xla::One(ctx->builder(), xla::S32), {});
    auto idx = xla::GetTupleElement(
        xla::Sort({perm_sort, imask},
                  xla::CreateScalarLtComputation({xla::S32, xla::S32},
                                                 ctx->builder())),
        1);
    idx = xla::ConvertElementType(idx, idx_type_);
    ctx->SetOutput(1, idx);
  }

 private:
  xla::PrimitiveType idx_type_;
};

class UniqueOp : public UniqueOpBase {
 public:
  explicit UniqueOp(OpKernelConstruction* ctx) : UniqueOpBase(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    CompileWithAxis(ctx, /*axis=*/0);
  }
};

REGISTER_XLA_OP(Name("Unique"), UniqueOp);

class UniqueV2Op : public UniqueOpBase {
 public:
  explicit UniqueV2Op(OpKernelConstruction* ctx) : UniqueOpBase(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    std::vector<int64_t> axises;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axises));
    OP_REQUIRES(
        ctx, axises.size() <= 1,
        xla::InvalidArgument("Only single axis unique op is supported"));
    int64_t axis;
    if (axises.empty()) {
      axis = 0;
    } else {
      axis = axises.front();
    }
    CompileWithAxis(ctx, /*axis=*/axis);
  }
};

REGISTER_XLA_OP(Name("UniqueV2").CompileTimeConstantInput("axis"), UniqueV2Op);

}  // namespace
}  // namespace tensorflow