/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/slicing.h"

#include <algorithm>
#include <limits>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

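// Extracts a window of `window_sizes` from `input`, positioned at runtime by
// the scalar offsets in `base_indices`, then applies `strides` with a static
// Slice when any stride differs from 1.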
XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices,
                          absl::Span<const int64_t> window_sizes,
                          absl::Span<const int64_t> strides) {
  XlaOp sliced_input = DynamicSlice(input, base_indices, window_sizes);
  if (std::any_of(strides.begin(), strides.end(),
                  [](int64_t stride) { return stride != 1; })) {
    // The window has already been positioned by DynamicSlice, so the strided
    // Slice starts from all-zero indices (a default-initialized vector).
    sliced_input =
        Slice(sliced_input, std::vector<int64_t>(window_sizes.size()),
              window_sizes, strides);
  }
  return sliced_input;
}

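// Slices the `start.size()` minor-most dimensions of `x` over the half-open
// ranges [start[i], end[i]); all major dimensions are kept in full.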
XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64_t> start,
                       absl::Span<const int64_t> end) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RET_CHECK(start.size() == end.size());
    int64_t n_minor_dims = start.size();

    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));

    const int64_t n_dims = shape.rank();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = shape.dimensions().subspan(
        /*pos=*/0,
        /*len=*/n_dims - n_minor_dims);

    // Prepends 0s in the major dims.
    std::vector<int64_t> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + major_dims.size());

    // Prepends the shape of the major dims.
    std::vector<int64_t> padded_end(n_dims);
    std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
    std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());

    std::vector<int64_t> strides(n_dims, 1);
    return Slice(x, padded_start, padded_end, strides);
  });
}

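// Writes `update` into `x` at the static offsets in `start`, which must
// provide one offset per dimension of `x`.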
XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64_t> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64_t n_dims = shape.rank();
    const int64_t start_size = start.size();
    TF_RET_CHECK(start_size == n_dims);

    // TODO(phawkins): make int64_t work on all backends, remove the int32_t
    // cast.
    std::vector<int32_t> start_as_int32(start.begin(), start.end());
    std::vector<XlaOp> start_ops(start.size());
    for (int i = 0, end = start.size(); i < end; ++i) {
      start_ops[i] = ConstantR0(builder, start_as_int32[i]);
    }
    return DynamicUpdateSlice(x, update, start_ops);
  });
}

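// Writes `update` into the minor-most dimensions of `x` at the static offsets
// in `start`; the offsets for the major dimensions are implicitly zero.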
XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
                             absl::Span<const int64_t> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64_t n_dims = shape.rank();
    const int64_t n_minor_dims = start.size();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    std::vector<int64_t> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + (n_dims - n_minor_dims));
    return UpdateSlice(x, update, padded_start);
  });
}

namespace {

std::vector<int64_t> ConcatVectors(absl::Span<const int64_t> xs,
                                   absl::Span<const int64_t> ys) {
  std::vector<int64_t> output(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), output.begin());
  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
  return output;
}

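// Pads `starts` on the left with scalar zeros so that the result holds one
// start index per dimension of `x`.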
StatusOr<std::vector<XlaOp>> PrependZerosInMajorDims(
    XlaOp x, absl::Span<const XlaOp> starts) {
  XlaBuilder* builder = x.builder();
  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
  const int64_t n_dims = shape.rank();
  auto zero = ConstantR0<int32_t>(builder, 0);
  std::vector<XlaOp> padded_starts(n_dims, zero);
  for (int i = 0; i < starts.size(); ++i) {
    padded_starts[n_dims - starts.size() + i] = starts[i];
  }
  return padded_starts;
}

}  // namespace

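// Dynamically slices `sizes` elements from the minor-most dimensions of `x`,
// starting at the runtime offsets in `starts`; major dimensions are taken in
// full. For example, for a rank-3 `x` of shape [b, m, n],
// DynamicSliceInMinorDims(x, {i, j}, {1, 1}) extracts x[:, i:i+1, j:j+1].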
XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
                              absl::Span<const int64_t> sizes) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64_t n_dims = shape.rank();
    int64_t n_minor_dims = starts.size();
    TF_RET_CHECK(n_minor_dims == sizes.size());
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = shape.dimensions().subspan(
        /*pos=*/0,
        /*len=*/n_dims - sizes.size());
    TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
    auto padded_sizes = ConcatVectors(major_dims, sizes);
    return DynamicSlice(x, padded_starts, padded_sizes);
  });
}

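// Writes `update` into the minor-most dimensions of `x` at the runtime
// offsets in `starts`; the offsets for the major dimensions are zero.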
XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
                                    absl::Span<const XlaOp> starts) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
    return DynamicUpdateSlice(x, update, padded_starts);
  });
}

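// Gathers elements of `input` along `dim` at the positions given by `index`,
// following the semantics of torch.gather. When `sparse` is false, a dense
// mask-and-reduce formulation is emitted instead of a Gather op.
//
// Illustrative caller-side sketch (assumed usage, not part of this file):
//   XlaBuilder b("torch_gather_example");
//   XlaOp input = ConstantR2<float>(&b, {{1, 2}, {3, 4}});
//   XlaOp index = ConstantR2<int32_t>(&b, {{0, 0}, {1, 0}});
//   XlaOp out = TorchGather(input, index, /*dim=*/1, /*sparse=*/true);
//   // out[i][j] == input[i][index[i][j]], i.e. {{1, 1}, {4, 3}}.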
XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    // Use 32-bit indices whenever the gathered dimension is small enough.
    if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
        input_shape.dimensions(dim) < std::numeric_limits<uint32_t>::max()) {
      index = ConvertElementType(index, U32);
      index_shape.set_element_type(U32);
    }
    if (index_shape.rank() == 1) {
      return TorchIndexSelect(input, index, 0);
    }
    if (!sparse) {
      // Dense path: compare the broadcast index against an iota along `dim`
      // to build a one-hot mask, zero out the unselected input elements, and
      // reduce away the gathered dimension.
      std::vector<int64_t> index_broadcast_dims;
      std::vector<int64_t> input_broadcast_dims;
      std::vector<int64_t> sizes;
      sizes.reserve(index_shape.rank());
      for (int64_t i = 0; i < index_shape.rank(); ++i) {
        if (i < dim) {
          input_broadcast_dims.push_back(i);
          index_broadcast_dims.push_back(i);
        } else if (i == dim) {
          sizes.push_back(input_shape.dimensions(i));
          input_broadcast_dims.push_back(i);
          index_broadcast_dims.push_back(i + 1);
        } else {
          input_broadcast_dims.push_back(i + 1);
          index_broadcast_dims.push_back(i + 1);
        }
        sizes.push_back(index_shape.dimensions(i));
      }
      auto mask = Eq(
          BroadcastInDim(index, sizes, index_broadcast_dims),
          Iota(builder, ShapeUtil::MakeShape(index_shape.element_type(), sizes),
               dim));
      auto masked_input = Select(
          mask, BroadcastInDim(input, sizes, input_broadcast_dims),
          Zeros(builder,
                ShapeUtil::MakeShape(input_shape.element_type(), sizes)));
      return Reduce(masked_input, Zero(builder, input_shape.element_type()),
                    CreateScalarIdentityWithZeroComputation(
                        input_shape.element_type(), builder),
                    {dim});
    }

    // Sparse path: build full gather indices by concatenating an iota for
    // every dimension except `dim`, which takes the user-provided `index`,
    // then emit a single Gather.
    ShapeUtil::AppendMajorDimension(1, &index_shape);
    std::vector<XlaOp> to_concat;

    to_concat.reserve(input_shape.rank());
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      if (i == dim) {
        to_concat.push_back(Reshape(index, index_shape.dimensions()));
      } else {
        to_concat.push_back(Iota(builder, index_shape, i));
      }
    }
    XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank());
    std::vector<int64_t> slice_sizes(input_shape.rank(), 1);
    GatherDimensionNumbers gather_dnums;
    gather_dnums.set_index_vector_dim(input_shape.rank());
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      gather_dnums.add_collapsed_slice_dims(i);
      gather_dnums.add_start_index_map(i);
    }
    return Gather(input, gather_indices, gather_dnums, slice_sizes);
  });
}

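// Dense torch.scatter-style update: for every position of `input` along
// `dim`, the `src` values whose `index` entries point at that position are
// folded together with `combiner`, and the result is combined with `input`
// using the same `combiner`.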
XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim,
                        const std::function<XlaOp(XlaOp, XlaOp)>& combiner) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    std::vector<int64_t> index_broadcast_dims;
    std::vector<int64_t> sizes;
    const auto rank = index_shape.rank();
    sizes.reserve(rank + 1);
    for (int64_t i = 0; i < index_shape.rank(); ++i) {
      if (i < dim) {
        index_broadcast_dims.push_back(i);
      } else {
        if (i == dim) {
          sizes.push_back(input_shape.dimensions(i));
        }
        index_broadcast_dims.push_back(i + 1);
      }
      sizes.push_back(index_shape.dimensions(i));
    }
    auto mask =
        Eq(BroadcastInDim(index, sizes, index_broadcast_dims),
           Iota(builder,
                ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim));
    auto masked_src =
        Select(mask, BroadcastInDim(src, sizes, index_broadcast_dims),
               Zeros(builder,
                     ShapeUtil::MakeShape(input_shape.element_type(), sizes)));

    return combiner(
        input,
        Reduce(masked_src, Zero(builder, input_shape.element_type()),
               CreateScalarComputation("reducer", input_shape.element_type(),
                                       builder, combiner),
               {dim + 1}));
  });
}

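// Selects slices of `input` along `dim` at the positions given by `index`,
// following the semantics of torch.index_select; the leading `batch_dims`
// dimensions of `input` and `index` are treated as batch dimensions.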
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim,
                       int64_t batch_dims) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    if (dim < batch_dims) {
      return InvalidArgument(
          "Gather dim must be greater than or equal to the number of batch "
          "dims");
    }
    if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
        input_shape.dimensions(dim) < std::numeric_limits<uint32_t>::max()) {
      index = ConvertElementType(index, U32);
      index_shape.set_element_type(U32);
    }
    std::vector<int64_t> slice_sizes = SpanToVector(input_shape.dimensions());
    GatherDimensionNumbers gather_dnums;
    gather_dnums.set_index_vector_dim(index_shape.rank());
    if (batch_dims > 0) {
      ShapeUtil::AppendMajorDimension(1, &index_shape);
      std::vector<XlaOp> to_concat;
      to_concat.reserve(batch_dims + 1);
      for (int64_t batch_dim = 0; batch_dim < batch_dims; ++batch_dim) {
        to_concat.push_back(Iota(builder, index_shape, batch_dim));
      }
      to_concat.push_back(Reshape(index, index_shape.dimensions()));
      index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim());
    }
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      if (i < batch_dims || i == dim) {
        slice_sizes[i] = std::min<int64_t>(slice_sizes[i], 1);
        gather_dnums.add_collapsed_slice_dims(i);
        gather_dnums.add_start_index_map(i);
      } else {
        if (i < dim) {
          gather_dnums.add_offset_dims(i);
        } else {
          gather_dnums.add_offset_dims(i + gather_dnums.index_vector_dim() -
                                       (1 + batch_dims));
        }
      }
    }
    return Gather(input, index, gather_dnums, slice_sizes);
  });
}

}  // namespace xla