/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;
using TensorOptList = exec_aten::ArrayRef<exec_aten::optional<Tensor>>;

namespace {

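// Checks that every non-null index tensor has dtype Long, Int, Byte or Bool.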
bool check_indices_dtypes(TensorOptList indices) {
  for (auto i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      ScalarType ix_type = index.scalar_type();
      ET_LOG_MSG_AND_RETURN_IF_FALSE(
          ix_type == ScalarType::Long || ix_type == ScalarType::Int ||
              ix_type == ScalarType::Byte || ix_type == ScalarType::Bool,
          "Index tensors should be Long, Int, Byte or Bool");
    }
  }
  return true;
}

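// An index tensor is a mask if it holds booleans (Bool or Byte dtype) that
// mark the selected elements, rather than integer positions.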
bool is_mask_index(const Tensor& index) {
  if (index.scalar_type() == ScalarType::Bool ||
      index.scalar_type() == ScalarType::Byte) {
    return true;
  }
  return false;
}

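// Checks that each mask index is at least one-dimensional and that each of
// its dimensions matches the size of the input dimension it covers.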
bool check_mask_indices(const Tensor& in, TensorOptList indices) {
  size_t in_i = 0;
  for (auto i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      if (is_mask_index(index)) {
        ET_LOG_MSG_AND_RETURN_IF_FALSE(
            index.dim() > 0, "Zero-dimensional mask index not allowed");
        for (auto j = 0; j < index.dim(); j++) {
          ET_LOG_MSG_AND_RETURN_IF_FALSE(
              index.size(j) == in.size(in_i + j),
              "The shape of mask index must match the sizes of the corresponding input dimensions.");
        }
        in_i += index.dim();
      } else {
        in_i += 1;
      }
    } else {
      in_i += 1;
    }
  }
  return true;
}

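// Counts the number of elements of a mask index that are set to true.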
template <typename CTYPE_IX>
size_t _count_trues_in_mask_index(const Tensor& index) {
  const CTYPE_IX* const index_ptr = index.const_data_ptr<CTYPE_IX>();
  size_t sum = 0;
  for (size_t i = 0; i < index.numel(); ++i) {
    if (index_ptr[i]) {
      sum += 1;
    }
  }
  return sum;
}

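// Dtype-dispatching wrapper around _count_trues_in_mask_index.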
size_t count_trues_in_mask_index(const Tensor& index) {
  if (index.scalar_type() == ScalarType::Bool) {
    return _count_trues_in_mask_index<bool>(index);
  } else {
    return _count_trues_in_mask_index<uint8_t>(index);
  }
}

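// Writes to res the coordinate, within the mask, of the query_idx-th element
// that is set to true. A mask containing a single true element is broadcast:
// every query returns that element's coordinate.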
template <typename CTYPE_IX>
void _query_mask_index(const Tensor& index, size_t query_idx, size_t* res) {
  const CTYPE_IX* const index_ptr = index.const_data_ptr<CTYPE_IX>();
  // Broadcasting for mask index tensors
  size_t num_true = _count_trues_in_mask_index<CTYPE_IX>(index);
  if (num_true == 1) {
    query_idx = 0;
  }
  // Extract the index value by finding the idx-th element that is set to
  // true.
  size_t count = 0;
  size_t flat_ix = 0;
  for (size_t i = 0; i < index.numel(); ++i) {
    if (index_ptr[i]) {
      if (count == query_idx) {
        flat_ix = i;
        break;
      } else {
        count++;
      }
    }
  }
  delinearize_index(flat_ix, index, res, kTensorDimensionLimit);
}

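// Dtype-dispatching wrapper around _query_mask_index.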
void query_mask_index(const Tensor& index, size_t query_idx, size_t* res) {
  if (index.scalar_type() == ScalarType::Bool) {
    _query_mask_index<bool>(index, query_idx, res);
  } else {
    _query_mask_index<uint8_t>(index, query_idx, res);
  }
}

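// Reads the value stored in an Int or Long index tensor at the given
// broadcast coordinate and returns it as an int64_t.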
int64_t query_integral_index(
    const Tensor& index,
    size_t* ix_coord,
    size_t broadcast_ndim) {
  size_t flat_ix = linearize_access_indexes(
      {ix_coord, broadcast_ndim}, broadcast_ndim, index);

  ScalarType idx_type = index.scalar_type();
  int64_t index_val = 0;
  // Extract the index value
  if (idx_type == ScalarType::Int) {
    const int32_t* const index_ptr = index.const_data_ptr<int32_t>();
    index_val = static_cast<int64_t>(index_ptr[flat_ix]);
  } else {
    const int64_t* const index_ptr = index.const_data_ptr<int64_t>();
    index_val = index_ptr[flat_ix];
  }
  return index_val;
}

} // namespace

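// Validates the index arguments: out must share in's dtype, index dtypes must
// be allowed, at most in.dim() dimensions may be indexed, and mask index
// shapes must match the corresponding input dimensions.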
bool check_index_args(const Tensor& in, TensorOptList indices, Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(check_indices_dtypes(indices));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      indices.size() <= in.dim(), "Indexing too many dimensions");
  ET_LOG_AND_RETURN_IF_FALSE(check_mask_indices(in, indices));
  return true;
}

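// Counts the number of contiguous blocks of non-null entries in the index
// list, e.g. {null, ix, null, ix, ix} contains two blocks.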
size_t count_index_blocks(TensorOptList indices) {
  size_t block_count = 0;
  bool in_block = false;
  for (size_t i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      if (!in_block) {
        in_block = true;
        block_count++;
      }
    } else {
      in_block = false;
    }
  }
  return block_count;
}

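// Computes the common shape that the index tensors broadcast to. A mask index
// contributes a single trailing dimension whose size is its number of true
// elements; integer index tensors broadcast together from their trailing
// dimensions. Returns false if the shapes cannot be broadcast.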
bool get_indices_broadcast_shape(
    TensorOptList indices,
    Tensor::SizesType* ix_sizes,
    size_t* ix_ndim) {
  // Holds the (reversed) broadcasted shape of the indices.
  Tensor::SizesType rev_ix_sizes[kTensorDimensionLimit];
  size_t curr_ndim = 0;

  for (size_t i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      if (is_mask_index(index)) {
        size_t len = count_trues_in_mask_index(index);
        if (curr_ndim == 0) {
          curr_ndim = 1;
          rev_ix_sizes[0] = len;
        } else if (rev_ix_sizes[0] == 1) {
          rev_ix_sizes[0] = len;
        } else if (len != 1 && rev_ix_sizes[0] != len) {
          ET_LOG_MSG_AND_RETURN_IF_FALSE(
              false, "Broadcast of mask index failed.");
        }
      } else {
        for (size_t j = 0; j < index.dim(); j++) {
          size_t rev_j_size = index.size(index.dim() - j - 1);
          if (j >= curr_ndim) {
            curr_ndim = j + 1;
            rev_ix_sizes[j] = rev_j_size;
          } else if (rev_ix_sizes[j] == 1) {
            rev_ix_sizes[j] = rev_j_size;
          } else if (rev_j_size != 1 && rev_ix_sizes[j] != rev_j_size) {
            ET_LOG_MSG_AND_RETURN_IF_FALSE(false, "Broadcast of index failed.");
          }
        }
      }
    }
  }

  for (size_t i = 0; i < curr_ndim; i++) {
    ix_sizes[i] = rev_ix_sizes[curr_ndim - i - 1];
  }
  (*ix_ndim) = curr_ndim;
  return true;
}

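// Returns the number of dimensions of the broadcasted index shape without
// computing the sizes; a mask index counts as a single dimension.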
size_t get_indices_broadcast_ndim(TensorOptList indices) {
  size_t ndim = 0;
  for (size_t i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      if (is_mask_index(index)) {
        if (ndim == 0) {
          ndim = 1;
        }
      } else {
        if (ndim < index.dim()) {
          ndim = index.dim();
        }
      }
    }
  }
  return ndim;
}

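// Number of input dimensions consumed by the indices: one per integer index,
// and one per dimension of each mask index.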
size_t get_num_indexed_dims(TensorOptList indices) {
  size_t num_indexed_dims = 0;
  for (size_t i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      if (is_mask_index(index)) {
        num_indexed_dims += index.dim();
      } else {
        num_indexed_dims += 1;
      }
    }
  }
  return num_indexed_dims;
}

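// Number of null (unspecified) entries in the index list.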
size_t get_num_null_indices(TensorOptList indices) {
  size_t num_null_indices = 0;
  for (size_t i = 0; i < indices.size(); i++) {
    if (!indices[i].has_value()) {
      num_null_indices += 1;
    }
  }
  return num_null_indices;
}

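// Number of null entries before the first non-null index. Assumes the list
// contains at least one non-null index.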
size_t get_num_leading_null_indices(TensorOptList indices) {
  size_t start = 0;
  while (!indices[start].has_value()) {
    start += 1;
  }
  return start;
}

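// Computes the sizes of the output tensor. The broadcasted index dimensions
// are inserted after the leading null indices when adjacent is true, and at
// the front of the output otherwise; the remaining non-indexed input
// dimensions keep their sizes.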
bool get_index_out_target_size(
    const Tensor& in,
    TensorOptList indices,
    bool adjacent,
    Tensor::SizesType* out_sizes,
    size_t* out_ndim) {
  Tensor::SizesType broadcast_sizes[kTensorDimensionLimit];
  size_t broadcast_ndim = 0;
  if (!get_indices_broadcast_shape(indices, broadcast_sizes, &broadcast_ndim)) {
    return false;
  }

  size_t num_null_indices = get_num_null_indices(indices);
  size_t num_indexed_dims = get_num_indexed_dims(indices);

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      num_null_indices + num_indexed_dims <= in.dim(),
      "Indexing too many dimensions");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      in.dim() + broadcast_ndim - num_indexed_dims <= kTensorDimensionLimit,
      "Out tensor would exceed number of allowed dimensions");

  (*out_ndim) = in.dim() + broadcast_ndim - num_indexed_dims;

  if (adjacent) {
    size_t start = get_num_leading_null_indices(indices);
    for (size_t i = 0; i < start; i++) {
      out_sizes[i] = in.size(i);
    }
    for (size_t i = 0; i < broadcast_ndim; i++) {
      out_sizes[i + start] = broadcast_sizes[i];
    }
    for (size_t i = num_indexed_dims + start; i < in.dim(); i++) {
      out_sizes[i + broadcast_ndim - num_indexed_dims] = in.size(i);
    }
  } else {
    for (size_t i = 0; i < broadcast_ndim; i++) {
      out_sizes[i] = broadcast_sizes[i];
    }
    size_t in_i = 0;
    size_t out_i = broadcast_ndim;
    for (size_t i = 0; i < indices.size(); i++) {
      if (!indices[i].has_value()) {
        out_sizes[out_i++] = in.size(in_i++);
      } else {
        const Tensor& index = indices[i].value();
        if (is_mask_index(index)) {
          in_i += index.dim();
        } else {
          in_i += 1;
        }
      }
    }
    for (size_t i = num_indexed_dims + num_null_indices; i < in.dim(); i++) {
      out_sizes[i + broadcast_ndim - num_indexed_dims] = in.size(i);
    }
  }
  return true;
}

// dim_map maps non-indexed input dimensions to the corresponding output
// dimensions. Indexed dimensions are mapped to -1.
void compute_dim_map(
    const Tensor& in,
    TensorOptList indices,
    int32_t* dim_map,
    bool adjacent) {
  size_t broadcast_ndim = get_indices_broadcast_ndim(indices);
  size_t start = get_num_leading_null_indices(indices);
  size_t num_indexed_dims = get_num_indexed_dims(indices);
  size_t num_null_indices = get_num_null_indices(indices);

  if (adjacent) {
    for (auto i = 0; i < start; i++) {
      dim_map[i] = i;
    }
    for (auto i = start; i < start + num_indexed_dims; i++) {
      dim_map[i] = -1;
    }
    for (auto i = start + num_indexed_dims; i < in.dim(); i++) {
      dim_map[i] = i - num_indexed_dims + broadcast_ndim;
    }
  } else {
    size_t in_i = 0;
    size_t out_i = broadcast_ndim;
    for (size_t i = 0; i < indices.size(); i++) {
      if (!indices[i].has_value()) {
        dim_map[in_i++] = out_i++;
      } else {
        const Tensor& index = indices[i].value();
        if (is_mask_index(index)) {
          for (auto j = 0; j < index.dim(); j++) {
            dim_map[in_i++] = -1;
          }
        } else {
          dim_map[in_i++] = -1;
        }
      }
    }
    for (size_t i = num_indexed_dims + num_null_indices; i < in.dim(); i++) {
      dim_map[i] = i - num_indexed_dims + broadcast_ndim;
    }
  }
}

// ix_map maps indexed input dimensions to the corresponding index.
// Non-indexed dimensions are mapped to -1.
void compute_index_map(
    const Tensor& in,
    TensorOptList indices,
    int32_t* ix_map) {
  for (size_t i = 0; i < in.dim(); i++) {
    ix_map[i] = -1;
  }
  size_t in_i = 0;
  for (size_t i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      if (is_mask_index(index)) {
        for (auto j = 0; j < index.dim(); j++) {
          ix_map[in_i++] = i;
        }
      } else {
        ix_map[in_i++] = i;
      }
    } else {
      in_i++;
    }
  }
}

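// Translates an output coordinate into the corresponding input coordinate:
// non-indexed dimensions are copied through dim_map, while indexed dimensions
// are obtained by querying the relevant index tensor at the broadcast
// coordinate out_coord[start, start + broadcast_ndim). Negative integer index
// values are wrapped and bounds-checked.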
bool get_in_coord(
    const Tensor& in,
    TensorOptList indices,
    size_t start,
    size_t broadcast_ndim,
    int32_t* dim_map,
    int32_t* ix_map,
    size_t* out_coord,
    size_t* in_coord) {
  for (ssize_t i = 0; i < in.dim(); i++) {
    if (dim_map[i] >= 0) {
      in_coord[i] = out_coord[dim_map[i]];
    } else {
      const Tensor& index = indices[ix_map[i]].value();

      size_t ix_coord[kTensorDimensionLimit];
      for (auto j = 0; j < broadcast_ndim; j++) {
        ix_coord[j] = out_coord[j + start];
      }

      if (is_mask_index(index)) {
        size_t query_ix = ix_coord[broadcast_ndim - 1];
        size_t query_result[kTensorDimensionLimit];
        query_mask_index(index, query_ix, query_result);
        for (auto j = 0; j < index.dim(); j++) {
          in_coord[i + j] = query_result[j];
        }
        i += index.dim() - 1;
      } else {
        int64_t index_val =
            query_integral_index(index, ix_coord, broadcast_ndim);
        if (index_val < 0) {
          index_val += in.size(i);
        }
        ET_LOG_MSG_AND_RETURN_IF_FALSE(
            index_val >= 0 && index_val < in.size(i),
            "Index %" PRId64
            " is out of bounds for input dimension %zd with size %zd.",
            index_val,
            i,
            in.size(i));
        in_coord[i] = static_cast<size_t>(index_val);
      }
    }
  }
  return true;
}

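// Maps a flat output index to the corresponding flat input index. Returns
// {in_ix, true} on success, or {0, false} if an index value was out of bounds.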
std::pair<size_t, bool> get_in_ix(
    const Tensor& in,
    TensorOptList indices,
    Tensor& out,
    size_t out_ix,
    size_t start,
    size_t broadcast_ndim,
    int32_t* dim_map,
    int32_t* ix_map) {
  size_t out_coord[kTensorDimensionLimit];
  delinearize_index(out_ix, out, out_coord, kTensorDimensionLimit);

  size_t in_coord[kTensorDimensionLimit];
  bool success = get_in_coord(
      in, indices, start, broadcast_ndim, dim_map, ix_map, out_coord, in_coord);
  if (!success) {
    return std::make_pair(0, false);
  }
  return std::make_pair(coordinateToIndex(in, in_coord), true);
}

} // namespace executor
} // namespace torch