1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/runtime/kernel/kernel_includes.h> 12 13 namespace torch { 14 namespace executor { 15 16 using Tensor = exec_aten::Tensor; 17 using TensorOptList = exec_aten::ArrayRef<exec_aten::optional<Tensor>>; 18 19 /** 20 * Performs preliminary checks on the arguments. However, it doesn't check that 21 * the values of integer indices are within the right bounds for the given 22 * input tensor 23 */ 24 bool check_index_args(const Tensor& in, TensorOptList indices, Tensor& out); 25 26 /** 27 * The output shape depends on whether there are null indices in between 28 * non-null indices or not. So, conceptually, the list of indices can be 29 * divided in alternating segments of non-null indices and null indices. 30 * We refer to the segments of non-null indices as blocks. If the indices list 31 * has 0 blocks, it means that the list is empty, or all its elements are null. 32 * If the list has exactly 1 block, it means that all the non-null indices are 33 * contiguous, and there are possibly some null indices at the beginning of the 34 * list and some of at the end. If the list has more than 1 block, it means 35 * there are null indices in between the non-null inidces. 36 * This functions simplu counts the number of blocks (i.e. non-null segments) in 37 * the indices list. 38 */ 39 size_t count_index_blocks(TensorOptList indices); 40 41 /** 42 * Counts the number of true values in a mask index 43 */ 44 size_t count_trues_in_mask_index(const Tensor& index); 45 46 /** 47 * Compute the broadcast shape between the indices 48 */ 49 bool get_indices_broadcast_shape( 50 TensorOptList indices, 51 Tensor::SizesType* ix_sizes, 52 size_t* ix_ndim); 53 54 /** 55 * Compute the dimension of the broadcast shape between the indices 56 */ 57 size_t get_indices_broadcast_ndim(TensorOptList indices); 58 59 /** 60 * Computes the number of dimensions that are being indexed by some non-null 61 * index. 62 */ 63 size_t get_num_indexed_dims(TensorOptList indices); 64 65 /** 66 * Computes the number of null indices 67 */ 68 size_t get_num_null_indices(TensorOptList indices); 69 70 /** 71 * Computes the number of null indices at the beginning of the list 72 */ 73 size_t get_num_leading_null_indices(TensorOptList indices); 74 75 /** 76 * Compute the expected size for the out tensor 77 */ 78 bool get_index_out_target_size( 79 const Tensor& in, 80 TensorOptList indices, 81 bool adjacent, 82 Tensor::SizesType* out_sizes, 83 size_t* out_ndim); 84 85 /** 86 * dim_map maps non-indexed input dimensions to the corresponding output 87 * dimensions. Indexed dimensions are mapped to -1. 88 */ 89 void compute_dim_map( 90 const Tensor& in, 91 TensorOptList indices, 92 int32_t* dim_map, 93 bool adjacent); 94 95 /** 96 * ix_map maps indexed input dimensions to the corresponding index. 97 * Non-indexed dimensions are mapped to -1. 98 */ 99 void compute_index_map( 100 const Tensor& in, 101 TensorOptList indices, 102 int32_t* ix_map); 103 104 /** 105 * Computes the input coordinate corresponding to a given output coordinate 106 */ 107 bool get_in_coord( 108 const Tensor& in, 109 TensorOptList indices, 110 size_t start, 111 size_t broadcast_ndim, 112 int32_t* dim_map, 113 int32_t* ix_map, 114 size_t* out_coord, 115 size_t* in_coord); 116 117 /** 118 * Computes input flat index corresponding to a given output flat index 119 */ 120 std::pair<size_t, bool> get_in_ix( 121 const Tensor& in, 122 TensorOptList indices, 123 Tensor& out, 124 size_t out_ix, 125 size_t start, 126 size_t broadcast_ndim, 127 int32_t* dim_map, 128 int32_t* ix_map); 129 130 } // namespace executor 131 } // namespace torch 132