/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;
using TensorOptList = exec_aten::ArrayRef<exec_aten::optional<Tensor>>;

namespace {

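// Returns true if every non-null index tensor has dtype Long, Int, Byte, or
// Bool; otherwise logs a message and returns false.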
bool check_indices_dtypes(TensorOptList indices) {
  for (auto i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      ScalarType ix_type = index.scalar_type();
      ET_LOG_MSG_AND_RETURN_IF_FALSE(
          ix_type == ScalarType::Long || ix_type == ScalarType::Int ||
              ix_type == ScalarType::Byte || ix_type == ScalarType::Bool,
          "Index tensors should be Long, Int, Byte or Bool");
    }
  }
  return true;
}

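// A mask index is a boolean (or byte) tensor whose true elements mark the
// positions to gather along the input dimensions it spans.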
bool is_mask_index(const Tensor& index) {
  if (index.scalar_type() == ScalarType::Bool ||
      index.scalar_type() == ScalarType::Byte) {
    return true;
  }
  return false;
}

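// Checks that every mask index is at least one-dimensional and that its sizes
// match the sizes of the input dimensions it covers. A d-dimensional mask
// consumes d consecutive input dimensions; every other index consumes one.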
bool check_mask_indices(const Tensor& in, TensorOptList indices) {
  size_t in_i = 0;
  for (auto i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      if (is_mask_index(index)) {
        ET_LOG_MSG_AND_RETURN_IF_FALSE(
            index.dim() > 0, "Zero-dimensional mask index not allowed");
        for (auto j = 0; j < index.dim(); j++) {
          ET_LOG_MSG_AND_RETURN_IF_FALSE(
              index.size(j) == in.size(in_i + j),
              "The shape of mask index must match the sizes of the corresponding input dimensions.");
        }
        in_i += index.dim();
      } else {
        in_i += 1;
      }
    } else {
      in_i += 1;
    }
  }
  return true;
}

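// Counts the elements of a mask index tensor that evaluate to true, reading
// the data as CTYPE_IX.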
template <typename CTYPE_IX>
size_t _count_trues_in_mask_index(const Tensor& index) {
  const CTYPE_IX* const index_ptr = index.const_data_ptr<CTYPE_IX>();
  size_t sum = 0;
  for (size_t i = 0; i < index.numel(); ++i) {
    if (index_ptr[i]) {
      sum += 1;
    }
  }
  return sum;
}

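// Dtype dispatch for counting the true elements of a Bool or Byte mask.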
size_t count_trues_in_mask_index(const Tensor& index) {
  if (index.scalar_type() == ScalarType::Bool) {
    return _count_trues_in_mask_index<bool>(index);
  } else {
    return _count_trues_in_mask_index<uint8_t>(index);
  }
}

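// Locates the query_idx-th true element of a mask index and writes its
// multi-dimensional coordinate into res. A mask with a single true element is
// broadcast: every query resolves to that element.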
template <typename CTYPE_IX>
void _query_mask_index(const Tensor& index, size_t query_idx, size_t* res) {
  const CTYPE_IX* const index_ptr = index.const_data_ptr<CTYPE_IX>();
  // Broadcasting for mask index tensors
  size_t num_true = _count_trues_in_mask_index<CTYPE_IX>(index);
  if (num_true == 1) {
    query_idx = 0;
  }
  // Extract the index value by finding the idx-th element that is set to
  // true.
  size_t count = 0;
  size_t flat_ix = 0;
  for (size_t i = 0; i < index.numel(); ++i) {
    if (index_ptr[i]) {
      if (count == query_idx) {
        flat_ix = i;
        break;
      } else {
        count++;
      }
    }
  }
  delinearize_index(flat_ix, index, res, kTensorDimensionLimit);
}

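// Dtype dispatch wrapper around _query_mask_index for Bool and Byte masks.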
void query_mask_index(const Tensor& index, size_t query_idx, size_t* res) {
  if (index.scalar_type() == ScalarType::Bool) {
    _query_mask_index<bool>(index, query_idx, res);
  } else {
    _query_mask_index<uint8_t>(index, query_idx, res);
  }
}

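// Reads the value stored in an integral (Int or Long) index tensor at the
// position selected by ix_coord, applying broadcasting over broadcast_ndim
// dimensions, and widens it to int64_t.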
int64_t query_integral_index(
    const Tensor& index,
    size_t* ix_coord,
    size_t broadcast_ndim) {
  size_t flat_ix = linearize_access_indexes(
      {ix_coord, broadcast_ndim}, broadcast_ndim, index);

  ScalarType idx_type = index.scalar_type();
  int64_t index_val = 0;
  // Extract the index value
  if (idx_type == ScalarType::Int) {
    const int32_t* const index_ptr = index.const_data_ptr<int32_t>();
    index_val = static_cast<int64_t>(index_ptr[flat_ix]);
  } else {
    const int64_t* const index_ptr = index.const_data_ptr<int64_t>();
    index_val = index_ptr[flat_ix];
  }
  return index_val;
}

} // namespace

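// Validates the arguments of an index operation: in and out must share a
// dtype, every index must have a legal dtype, there must be no more indices
// than input dimensions, and mask index shapes must line up with the input.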
bool check_index_args(const Tensor& in, TensorOptList indices, Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(check_indices_dtypes(indices));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      indices.size() <= in.dim(), "Indexing too many dimensions");
  ET_LOG_AND_RETURN_IF_FALSE(check_mask_indices(in, indices));
  return true;
}

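// Counts contiguous runs ("blocks") of non-null indices. A single block means
// all indexed input dimensions are adjacent.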
size_t count_index_blocks(TensorOptList indices) {
  size_t block_count = 0;
  bool in_block = false;
  for (size_t i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      if (!in_block) {
        in_block = true;
        block_count++;
      }
    } else {
      in_block = false;
    }
  }
  return block_count;
}

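// Computes the shape that the index tensors broadcast to. A mask index
// contributes a single dimension whose length is its number of true elements;
// other indices contribute their own sizes under standard broadcasting rules.
// Returns false if the indices cannot be broadcast together.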
bool get_indices_broadcast_shape(
    TensorOptList indices,
    Tensor::SizesType* ix_sizes,
    size_t* ix_ndim) {
  // Holds the (reversed) broadcasted shape of the indices.
  Tensor::SizesType rev_ix_sizes[kTensorDimensionLimit];
  size_t curr_ndim = 0;

  for (size_t i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      if (is_mask_index(index)) {
        size_t len = count_trues_in_mask_index(index);
        if (curr_ndim == 0) {
          curr_ndim = 1;
          rev_ix_sizes[0] = len;
        } else if (rev_ix_sizes[0] == 1) {
          rev_ix_sizes[0] = len;
        } else if (len != 1 && rev_ix_sizes[0] != len) {
          ET_LOG_MSG_AND_RETURN_IF_FALSE(
              false, "Broadcast of mask index failed.");
        }
      } else {
        for (size_t j = 0; j < index.dim(); j++) {
          size_t rev_j_size = index.size(index.dim() - j - 1);
          if (j >= curr_ndim) {
            curr_ndim = j + 1;
            rev_ix_sizes[j] = rev_j_size;
          } else if (rev_ix_sizes[j] == 1) {
            rev_ix_sizes[j] = rev_j_size;
          } else if (rev_j_size != 1 && rev_ix_sizes[j] != rev_j_size) {
            ET_LOG_MSG_AND_RETURN_IF_FALSE(false, "Broadcast of index failed.");
          }
        }
      }
    }
  }

  for (size_t i = 0; i < curr_ndim; i++) {
    ix_sizes[i] = rev_ix_sizes[curr_ndim - i - 1];
  }
  (*ix_ndim) = curr_ndim;
  return true;
}

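// Returns the number of dimensions of the broadcast index shape. Mask indices
// count as one dimension; other indices contribute their own dimensionality.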
size_t get_indices_broadcast_ndim(TensorOptList indices) {
  size_t ndim = 0;
  for (size_t i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      if (is_mask_index(index)) {
        if (ndim == 0) {
          ndim = 1;
        }
      } else {
        if (ndim < index.dim()) {
          ndim = index.dim();
        }
      }
    }
  }
  return ndim;
}

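// Returns how many input dimensions are consumed by the indices: a
// d-dimensional mask index consumes d dimensions, every other non-null index
// consumes one.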
size_t get_num_indexed_dims(TensorOptList indices) {
  size_t num_indexed_dims = 0;
  for (size_t i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      if (is_mask_index(index)) {
        num_indexed_dims += index.dim();
      } else {
        num_indexed_dims += 1;
      }
    }
  }
  return num_indexed_dims;
}

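// Counts the null (missing) entries in the index list.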
size_t get_num_null_indices(TensorOptList indices) {
  size_t num_null_indices = 0;
  for (size_t i = 0; i < indices.size(); i++) {
    if (!indices[i].has_value()) {
      num_null_indices += 1;
    }
  }
  return num_null_indices;
}

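// Returns the number of leading null indices. Note that this assumes the list
// contains at least one non-null index; with an all-null list the loop would
// read past the end.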
size_t get_num_leading_null_indices(TensorOptList indices) {
  size_t start = 0;
  while (!indices[start].has_value()) {
    start += 1;
  }
  return start;
}

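// Computes the expected output shape of the index operation. The broadcast
// index shape replaces the indexed input dimensions: if the non-null indices
// are adjacent it is inserted at the position of the first one, otherwise it
// is moved to the front of the output shape.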
bool get_index_out_target_size(
    const Tensor& in,
    TensorOptList indices,
    bool adjacent,
    Tensor::SizesType* out_sizes,
    size_t* out_ndim) {
  Tensor::SizesType broadcast_sizes[kTensorDimensionLimit];
  size_t broadcast_ndim = 0;
  if (!get_indices_broadcast_shape(indices, broadcast_sizes, &broadcast_ndim)) {
    return false;
  }

  size_t num_null_indices = get_num_null_indices(indices);
  size_t num_indexed_dims = get_num_indexed_dims(indices);

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      num_null_indices + num_indexed_dims <= in.dim(),
      "Indexing too many dimensions");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      in.dim() + broadcast_ndim - num_indexed_dims <= kTensorDimensionLimit,
      "Out tensor would exceed number of allowed dimensions");

  (*out_ndim) = in.dim() + broadcast_ndim - num_indexed_dims;

  if (adjacent) {
    size_t start = get_num_leading_null_indices(indices);
    for (size_t i = 0; i < start; i++) {
      out_sizes[i] = in.size(i);
    }
    for (size_t i = 0; i < broadcast_ndim; i++) {
      out_sizes[i + start] = broadcast_sizes[i];
    }
    for (size_t i = num_indexed_dims + start; i < in.dim(); i++) {
      out_sizes[i + broadcast_ndim - num_indexed_dims] = in.size(i);
    }
  } else {
    for (size_t i = 0; i < broadcast_ndim; i++) {
      out_sizes[i] = broadcast_sizes[i];
    }
    size_t in_i = 0;
    size_t out_i = broadcast_ndim;
    for (size_t i = 0; i < indices.size(); i++) {
      if (!indices[i].has_value()) {
        out_sizes[out_i++] = in.size(in_i++);
      } else {
        const Tensor& index = indices[i].value();
        if (is_mask_index(index)) {
          in_i += index.dim();
        } else {
          in_i += 1;
        }
      }
    }
    for (size_t i = num_indexed_dims + num_null_indices; i < in.dim(); i++) {
      out_sizes[i + broadcast_ndim - num_indexed_dims] = in.size(i);
    }
  }
  return true;
}

// dim_map maps non-indexed input dimensions to the corresponding output
// dimensions. Indexed dimensions are mapped to -1.
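// Illustrative example (not part of the original source): for a 4-D input and
// indices = {null, ix, null} with ix a 1-D integral index (broadcast_ndim = 1)
// and adjacent == true, dim_map == {0, -1, 2, 3}: input dim 1 is indexed away
// and the single broadcast dimension takes its place in the output.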
void compute_dim_map(
    const Tensor& in,
    TensorOptList indices,
    int32_t* dim_map,
    bool adjacent) {
  size_t broadcast_ndim = get_indices_broadcast_ndim(indices);
  size_t start = get_num_leading_null_indices(indices);
  size_t num_indexed_dims = get_num_indexed_dims(indices);
  size_t num_null_indices = get_num_null_indices(indices);

  if (adjacent) {
    for (auto i = 0; i < start; i++) {
      dim_map[i] = i;
    }
    for (auto i = start; i < start + num_indexed_dims; i++) {
      dim_map[i] = -1;
    }
    for (auto i = start + num_indexed_dims; i < in.dim(); i++) {
      dim_map[i] = i - num_indexed_dims + broadcast_ndim;
    }
  } else {
    size_t in_i = 0;
    size_t out_i = broadcast_ndim;
    for (size_t i = 0; i < indices.size(); i++) {
      if (!indices[i].has_value()) {
        dim_map[in_i++] = out_i++;
      } else {
        const Tensor& index = indices[i].value();
        if (is_mask_index(index)) {
          for (auto j = 0; j < index.dim(); j++) {
            dim_map[in_i++] = -1;
          }
        } else {
          dim_map[in_i++] = -1;
        }
      }
    }
    for (size_t i = num_indexed_dims + num_null_indices; i < in.dim(); i++) {
      dim_map[i] = i - num_indexed_dims + broadcast_ndim;
    }
  }
}

// ix_map maps indexed input dimensions to the corresponding index.
// Non-indexed dimensions are mapped to -1.
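// Illustrative example (not part of the original source): for a 4-D input and
// indices = {null, ix, null} with ix a non-mask index, ix_map == {-1, 1, -1, -1}:
// only input dim 1 is indexed, and it is indexed by indices[1].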
void compute_index_map(
    const Tensor& in,
    TensorOptList indices,
    int32_t* ix_map) {
  for (size_t i = 0; i < in.dim(); i++) {
    ix_map[i] = -1;
  }
  size_t in_i = 0;
  for (size_t i = 0; i < indices.size(); i++) {
    if (indices[i].has_value()) {
      const Tensor& index = indices[i].value();
      if (is_mask_index(index)) {
        for (auto j = 0; j < index.dim(); j++) {
          ix_map[in_i++] = i;
        }
      } else {
        ix_map[in_i++] = i;
      }
    } else {
      in_i++;
    }
  }
}

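// Translates an output coordinate into the input coordinate it reads from.
// Non-indexed dimensions are copied through dim_map; indexed dimensions are
// resolved by querying the corresponding index tensor at the broadcast
// coordinate. Returns false if an integral index value is out of bounds.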
bool get_in_coord(
    const Tensor& in,
    TensorOptList indices,
    size_t start,
    size_t broadcast_ndim,
    int32_t* dim_map,
    int32_t* ix_map,
    size_t* out_coord,
    size_t* in_coord) {
  for (ssize_t i = 0; i < in.dim(); i++) {
    if (dim_map[i] >= 0) {
      in_coord[i] = out_coord[dim_map[i]];
    } else {
      const Tensor& index = indices[ix_map[i]].value();

      size_t ix_coord[kTensorDimensionLimit];
      for (auto j = 0; j < broadcast_ndim; j++) {
        ix_coord[j] = out_coord[j + start];
      }

      if (is_mask_index(index)) {
        size_t query_ix = ix_coord[broadcast_ndim - 1];
        size_t query_result[kTensorDimensionLimit];
        query_mask_index(index, query_ix, query_result);
        for (auto j = 0; j < index.dim(); j++) {
          in_coord[i + j] = query_result[j];
        }
        i += index.dim() - 1;
      } else {
        int64_t index_val =
            query_integral_index(index, ix_coord, broadcast_ndim);
        if (index_val < 0) {
          index_val += in.size(i);
        }
        ET_LOG_MSG_AND_RETURN_IF_FALSE(
            index_val >= 0 && index_val < in.size(i),
            "Index %" PRId64
            " is out of bounds for input dimension %zd with size %zd.",
            index_val,
            i,
            in.size(i));
        in_coord[i] = static_cast<size_t>(index_val);
      }
    }
  }
  return true;
}

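// Maps a flat output index to the flat input index it gathers from. The bool
// of the returned pair is false if the coordinate lookup failed (e.g. an
// out-of-bounds integral index).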
std::pair<size_t, bool> get_in_ix(
    const Tensor& in,
    TensorOptList indices,
    Tensor& out,
    size_t out_ix,
    size_t start,
    size_t broadcast_ndim,
    int32_t* dim_map,
    int32_t* ix_map) {
  size_t out_coord[kTensorDimensionLimit];
  delinearize_index(out_ix, out, out_coord, kTensorDimensionLimit);

  size_t in_coord[kTensorDimensionLimit];
  bool success = get_in_coord(
      in, indices, start, broadcast_ndim, dim_map, ix_map, out_coord, in_coord);
  if (!success) {
    return std::make_pair(0, false);
  }
  return std::make_pair(coordinateToIndex(in, in_coord), true);
}

} // namespace executor
} // namespace torch