/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/slicing.h"

#include <algorithm>
#include <limits>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

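// Slices 'input' at the dynamic offsets 'base_indices' with the given
// 'window_sizes', then applies 'strides' via a static strided Slice when any
// stride is not 1.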
XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices,
                          absl::Span<const int64_t> window_sizes,
                          absl::Span<const int64_t> strides) {
  XlaOp sliced_input = DynamicSlice(input, base_indices, window_sizes);
  if (std::any_of(strides.begin(), strides.end(),
                  [](int64_t stride) { return stride != 1; })) {
    sliced_input =
        Slice(sliced_input, std::vector<int64_t>(window_sizes.size()),
              window_sizes, strides);
  }
  return sliced_input;
}

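// Slices the minor-most dimensions of 'x' over [start, end); the major
// dimensions are kept whole.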
XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64_t> start,
                       absl::Span<const int64_t> end) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RET_CHECK(start.size() == end.size());
    int64_t n_minor_dims = start.size();

    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));

    const int64_t n_dims = shape.rank();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = shape.dimensions().subspan(
        /*pos=*/0,
        /*len=*/n_dims - n_minor_dims);

    // Prepends 0s in the major dims.
    std::vector<int64_t> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + major_dims.size());

    // Prepends the shape of the major dims.
    std::vector<int64_t> padded_end(n_dims);
    std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
    std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());

    std::vector<int64_t> strides(n_dims, 1);
    return Slice(x, padded_start, padded_end, strides);
  });
}

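// Writes 'update' into 'x' at the constant per-dimension offsets 'start',
// which must cover every dimension of 'x'.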
XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64_t> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64_t n_dims = shape.rank();
    const int64_t start_size = start.size();
    TF_RET_CHECK(start_size == n_dims);

    // TODO(phawkins): make int64_t work on all backends, remove the int32_t
    // cast.
    std::vector<int32_t> start_as_int32(start.begin(), start.end());
    std::vector<XlaOp> start_ops(start.size());
    for (int i = 0, end = start.size(); i < end; ++i) {
      start_ops[i] = ConstantR0(builder, start_as_int32[i]);
    }
    return DynamicUpdateSlice(x, update, start_ops);
  });
}

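// Writes 'update' into 'x' with 'start' applied to the minor-most dimensions;
// the major dimensions are updated at offset 0.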
XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
                             absl::Span<const int64_t> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64_t n_dims = shape.rank();
    const int64_t n_minor_dims = start.size();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    std::vector<int64_t> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + (n_dims - n_minor_dims));
    return UpdateSlice(x, update, padded_start);
  });
}

namespace {

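// Concatenates 'xs' and 'ys' into a single vector.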
std::vector<int64_t> ConcatVectors(absl::Span<const int64_t> xs,
                                   absl::Span<const int64_t> ys) {
  std::vector<int64_t> output(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), output.begin());
  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
  return output;
}

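// Pads 'starts' on the left with zero scalars so that its length matches the
// rank of 'x'.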
StatusOr<std::vector<XlaOp>> PrependZerosInMajorDims(
    XlaOp x, absl::Span<const XlaOp> starts) {
  XlaBuilder* builder = x.builder();
  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
  const int64_t n_dims = shape.rank();
  auto zero = ConstantR0<int32_t>(builder, 0);
  std::vector<XlaOp> padded_starts(n_dims, zero);
  for (int i = 0; i < starts.size(); ++i) {
    padded_starts[n_dims - starts.size() + i] = starts[i];
  }
  return padded_starts;
}

}  // namespace

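// Dynamically slices the minor-most dimensions of 'x' at 'starts' with the
// given 'sizes'; the major dimensions are taken whole.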
XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
                              absl::Span<const int64_t> sizes) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64_t n_dims = shape.rank();
    int64_t n_minor_dims = starts.size();
    TF_RET_CHECK(n_minor_dims == sizes.size());
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = shape.dimensions().subspan(
        /*pos=*/0,
        /*len=*/n_dims - sizes.size());
    TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
    auto padded_sizes = ConcatVectors(major_dims, sizes);
    return DynamicSlice(x, padded_starts, padded_sizes);
  });
}

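// Dynamically writes 'update' into 'x' with 'starts' applied to the
// minor-most dimensions; the major dimensions are updated at offset 0.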
XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
                                    absl::Span<const XlaOp> starts) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
    return DynamicUpdateSlice(x, update, padded_starts);
  });
}

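// Gathers values of 'input' along dimension 'dim' at the positions given by
// 'index', mirroring torch.gather. When 'sparse' is false, the result is
// computed densely with a broadcast-compare-and-reduce; otherwise an XLA
// Gather is used.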
XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
        input_shape.dimensions(dim) < std::numeric_limits<uint32_t>::max()) {
      index = ConvertElementType(index, U32);
      index_shape.set_element_type(U32);
    }
    if (index_shape.rank() == 1) {
      return TorchIndexSelect(input, index, 0);
    }
    if (!sparse) {
      std::vector<int64_t> index_broadcast_dims;
      std::vector<int64_t> input_broadcast_dims;
      std::vector<int64_t> sizes;
      sizes.reserve(index_shape.rank());
      for (int64_t i = 0; i < index_shape.rank(); ++i) {
        if (i < dim) {
          input_broadcast_dims.push_back(i);
          index_broadcast_dims.push_back(i);
        } else if (i == dim) {
          sizes.push_back(input_shape.dimensions(i));
          input_broadcast_dims.push_back(i);
          index_broadcast_dims.push_back(i + 1);
        } else {
          input_broadcast_dims.push_back(i + 1);
          index_broadcast_dims.push_back(i + 1);
        }
        sizes.push_back(index_shape.dimensions(i));
      }
      auto mask = Eq(
          BroadcastInDim(index, sizes, index_broadcast_dims),
          Iota(builder, ShapeUtil::MakeShape(index_shape.element_type(), sizes),
               dim));
      auto masked_input = Select(
          mask, BroadcastInDim(input, sizes, input_broadcast_dims),
          Zeros(builder,
                ShapeUtil::MakeShape(input_shape.element_type(), sizes)));
      return Reduce(masked_input, Zero(builder, input_shape.element_type()),
                    CreateScalarIdentityWithZeroComputation(
                        input_shape.element_type(), builder),
                    {dim});
    }

    ShapeUtil::AppendMajorDimension(1, &index_shape);
    std::vector<XlaOp> to_concat;

    to_concat.reserve(input_shape.rank());
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      if (i == dim) {
        to_concat.push_back(Reshape(index, index_shape.dimensions()));
      } else {
        to_concat.push_back(Iota(builder, index_shape, i));
      }
    }
    XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank());
    std::vector<int64_t> slice_sizes(input_shape.rank(), 1);
    GatherDimensionNumbers gather_dnums;
    gather_dnums.set_index_vector_dim(input_shape.rank());
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      gather_dnums.add_collapsed_slice_dims(i);
      gather_dnums.add_start_index_map(i);
    }
    return Gather(input, gather_indices, gather_dnums, slice_sizes);
  });
}

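// Builds a dense, torch.scatter-style update of 'input' along dimension
// 'dim': the 'src' values selected by 'index' are reduced with 'combiner' and
// the result is combined with 'input' via the same 'combiner'.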
XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim,
                        const std::function<XlaOp(XlaOp, XlaOp)>& combiner) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    std::vector<int64_t> index_broadcast_dims;
    std::vector<int64_t> sizes;
    const auto rank = index_shape.rank();
    sizes.reserve(rank + 1);
    for (int64_t i = 0; i < index_shape.rank(); ++i) {
      if (i < dim) {
        index_broadcast_dims.push_back(i);
      } else {
        if (i == dim) {
          sizes.push_back(input_shape.dimensions(i));
        }
        index_broadcast_dims.push_back(i + 1);
      }
      sizes.push_back(index_shape.dimensions(i));
    }
    auto mask =
        Eq(BroadcastInDim(index, sizes, index_broadcast_dims),
           Iota(builder,
                ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim));
    auto masked_src =
        Select(mask, BroadcastInDim(src, sizes, index_broadcast_dims),
               Zeros(builder,
                     ShapeUtil::MakeShape(input_shape.element_type(), sizes)));

    return combiner(
        input,
        Reduce(masked_src, Zero(builder, input_shape.element_type()),
               CreateScalarComputation("reducer", input_shape.element_type(),
                                       builder, combiner),
               {dim + 1}));
  });
}

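// Selects slices of 'input' along dimension 'dim' using the entries of
// 'index' (torch.index_select semantics). The leading 'batch_dims' dimensions
// of 'input' and 'index' are treated as batch dimensions.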
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim,
                       int64_t batch_dims) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    if (dim < batch_dims) {
      return InvalidArgument(
          "Gather dim must be greater than or equal to the number of batch "
          "dims");
    }
    if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
        input_shape.dimensions(dim) < std::numeric_limits<uint32_t>::max()) {
      index = ConvertElementType(index, U32);
      index_shape.set_element_type(U32);
    }
    std::vector<int64_t> slice_sizes = SpanToVector(input_shape.dimensions());
    GatherDimensionNumbers gather_dnums;
    gather_dnums.set_index_vector_dim(index_shape.rank());
    if (batch_dims > 0) {
      ShapeUtil::AppendMajorDimension(1, &index_shape);
      std::vector<XlaOp> to_concat;
      to_concat.reserve(batch_dims + 1);
      for (int64_t batch_dim = 0; batch_dim < batch_dims; ++batch_dim) {
        to_concat.push_back(Iota(builder, index_shape, batch_dim));
      }
      to_concat.push_back(Reshape(index, index_shape.dimensions()));
      index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim());
    }
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      if (i < batch_dims || i == dim) {
        slice_sizes[i] = std::min<int64_t>(slice_sizes[i], 1);
        gather_dnums.add_collapsed_slice_dims(i);
        gather_dnums.add_start_index_map(i);
      } else {
        if (i < dim) {
          gather_dnums.add_offset_dims(i);
        } else {
          gather_dnums.add_offset_dims(i + gather_dnums.index_vector_dim() -
                                       (1 + batch_dims));
        }
      }
    }
    return Gather(input, index, gather_dnums, slice_sizes);
  });
}

}  // namespace xla