xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/advanced_index_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 
13 namespace torch {
14 namespace executor {
15 
// Shorthands used by the declarations below. TensorOptList is a list of
// optional tensors; the comments in this header call an entry that holds no
// tensor a "null index".
16 using Tensor = exec_aten::Tensor;
17 using TensorOptList = exec_aten::ArrayRef<exec_aten::optional<Tensor>>;
18 
19 /**
20  * Performs preliminary checks on the arguments. However, it doesn't check that
21  * the values of integer indices are within the right bounds for the given
22  * input tensor.
 *
 * NOTE(review): the bool result presumably reports whether every check
 * passed; confirm against the definition.
23  */
24 bool check_index_args(const Tensor& in, TensorOptList indices, Tensor& out);
25 
26 /**
27  * The output shape depends on whether there are null indices in between
28  * non-null indices or not. So, conceptually, the list of indices can be
29  * divided into alternating segments of non-null indices and null indices.
30  * We refer to the segments of non-null indices as blocks. If the indices list
31  * has 0 blocks, it means that the list is empty, or all its elements are null.
32  * If the list has exactly 1 block, it means that all the non-null indices are
33  * contiguous, and there are possibly some null indices at the beginning of the
34  * list and some at the end. If the list has more than 1 block, it means
35  * there are null indices in between the non-null indices.
 * For example, the list [null, A, B, null, C] has 2 blocks: {A, B} and {C}.
36  * This function simply counts the number of blocks (i.e. non-null segments) in
37  * the indices list.
38  */
39 size_t count_index_blocks(TensorOptList indices);
40 
41 /**
42  * Counts the number of true values in a mask index.
 *
 * NOTE(review): `index` is presumably expected to be a boolean (mask) tensor;
 * the behavior for other dtypes is not visible from this declaration.
43  */
44 size_t count_trues_in_mask_index(const Tensor& index);
45 
46 /**
47  * Computes the broadcast shape between the indices.
 *
 * NOTE(review): `ix_sizes` and `ix_ndim` look like output parameters that
 * receive the broadcast sizes and their count, and the bool result presumably
 * reports whether the indices could be broadcast together. The capacity
 * required of the `ix_sizes` buffer is not visible here; confirm against the
 * definition.
48  */
49 bool get_indices_broadcast_shape(
50     TensorOptList indices,
51     Tensor::SizesType* ix_sizes,
52     size_t* ix_ndim);
53 
54 /**
55  * Computes the number of dimensions of the broadcast shape between the
 * indices.
56  */
57 size_t get_indices_broadcast_ndim(TensorOptList indices);
58 
59 /**
60  * Computes the number of dimensions that are being indexed by some non-null
61  * index.
 *
 * NOTE(review): how many dimensions a mask index contributes to this count is
 * not visible from this declaration; confirm against the definition.
62  */
63 size_t get_num_indexed_dims(TensorOptList indices);
64 
65 /**
66  * Computes the number of null indices, i.e. how many entries of the indices
 * list are null.
67  */
68 size_t get_num_null_indices(TensorOptList indices);
69 
70 /**
71  * Computes the number of null indices at the beginning of the list, i.e. the
 * null entries that come before the first non-null index.
72  */
73 size_t get_num_leading_null_indices(TensorOptList indices);
74 
75 /**
76  * Computes the expected size for the out tensor.
 *
 * NOTE(review): `adjacent` presumably tells whether the non-null indices are
 * all contiguous (a single block, in the terminology of count_index_blocks).
 * `out_sizes` and `out_ndim` look like output parameters receiving the sizes
 * and their count, and the bool result presumably reports success. Confirm
 * against the definition.
77  */
78 bool get_index_out_target_size(
79     const Tensor& in,
80     TensorOptList indices,
81     bool adjacent,
82     Tensor::SizesType* out_sizes,
83     size_t* out_ndim);
84 
85 /**
86  * dim_map maps non-indexed input dimensions to the corresponding output
87  * dimensions. Indexed dimensions are mapped to -1.
 *
 * NOTE(review): `dim_map` is presumably a caller-provided buffer with one
 * entry per dimension of `in`, and `adjacent` presumably tells whether the
 * non-null indices form a single contiguous block; confirm against the
 * definition.
88  */
89 void compute_dim_map(
90     const Tensor& in,
91     TensorOptList indices,
92     int32_t* dim_map,
93     bool adjacent);
94 
95 /**
96  * ix_map maps indexed input dimensions to the corresponding index.
97  * Non-indexed dimensions are mapped to -1.
 *
 * NOTE(review): `ix_map` is presumably a caller-provided buffer with one
 * entry per dimension of `in`; confirm against the definition.
98  */
99 void compute_index_map(
100     const Tensor& in,
101     TensorOptList indices,
102     int32_t* ix_map);
103 
104 /**
105  * Computes the input coordinate corresponding to a given output coordinate.
 *
 * NOTE(review): judging by the names, `out_coord` is the coordinate to
 * translate and `in_coord` receives the result; `dim_map` and `ix_map` are
 * presumably the tables produced by compute_dim_map and compute_index_map;
 * `start` and `broadcast_ndim` presumably locate the dimensions of the
 * indices' broadcast shape within the output. The meaning of the bool result
 * is not visible from this declaration. Confirm against the definition.
106  */
107 bool get_in_coord(
108     const Tensor& in,
109     TensorOptList indices,
110     size_t start,
111     size_t broadcast_ndim,
112     int32_t* dim_map,
113     int32_t* ix_map,
114     size_t* out_coord,
115     size_t* in_coord);
116 
117 /**
118  * Computes input flat index corresponding to a given output flat index.
 *
 * NOTE(review): the returned pair presumably holds the input flat index and
 * a flag reporting whether the computation succeeded; confirm against the
 * definition. The remaining parameters mirror those of get_in_coord.
 *
 * NOTE(review): std::pair is used here, but this header does not include
 * <utility> directly; it relies on kernel_includes.h to provide it.
119  */
120 std::pair<size_t, bool> get_in_ix(
121     const Tensor& in,
122     TensorOptList indices,
123     Tensor& out,
124     size_t out_ix,
125     size_t start,
126     size_t broadcast_ndim,
127     int32_t* dim_map,
128     int32_t* ix_map);
129 
130 } // namespace executor
131 } // namespace torch
132