xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_gather.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

9 #include <cinttypes>
10 #include <cstdint>
11 #include <cstring>
12 
13 #include <executorch/kernels/portable/cpu/util/index_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15 
16 namespace torch {
17 namespace executor {
18 namespace native {
19 
20 using Tensor = exec_aten::Tensor;
21 using ScalarType = exec_aten::ScalarType;
22 
23 namespace {
24 
25 template <typename CTYPE>
gather_helper(const Tensor & in,const Tensor & index,Tensor & out,int64_t dim)26 void gather_helper(
27     const Tensor& in,
28     const Tensor& index,
29     Tensor& out,
30     int64_t dim) {
31   const CTYPE* in_data = in.const_data_ptr<CTYPE>();
32   const long* index_data = index.const_data_ptr<long>();
33   CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
34 
35   if (index.dim() == 0) {
36     out_data[0] = in_data[index_data[0]];
37     return;
38   }
39 
40   for (size_t ix = 0; ix < index.numel(); ++ix) {
41     size_t ix_coord[kTensorDimensionLimit];
42     indexToCoordinate(index, ix, ix_coord);
43 
44     size_t in_coord[kTensorDimensionLimit];
45     for (size_t i = 0; i < out.dim(); ++i) {
46       if (i == dim) {
47         in_coord[i] = index_data[ix];
48       } else {
49         in_coord[i] = ix_coord[i];
50       }
51     }
52 
53     size_t in_ix = coordinateToIndex(in, in_coord);
54     size_t out_ix = coordinateToIndex(out, ix_coord);
55 
56     out_data[out_ix] = in_data[in_ix];
57   }
58 }
59 
60 } // namespace
61 
gather_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,const Tensor & index,bool sparse_grad,Tensor & out)62 Tensor& gather_out(
63     KernelRuntimeContext& ctx,
64     const Tensor& in,
65     int64_t dim,
66     const Tensor& index,
67     bool sparse_grad,
68     Tensor& out) {
69   (void)ctx;
70 
71   ET_KERNEL_CHECK(
72       ctx,
73       check_gather_args(in, dim, index, sparse_grad, out),
74       InvalidArgument,
75       out);
76 
77   if (dim < 0) {
78     dim += nonzero_dim(in);
79   }
80 
81   ET_KERNEL_CHECK(
82       ctx,
83       resize_tensor(out, index.sizes()) == Error::Ok,
84       InvalidArgument,
85       out);
86 
87   constexpr auto name = "gather.out";
88 
89   ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
90     gather_helper<CTYPE>(in, index, out, dim);
91   });
92 
93   return out;
94 }
95 
96 } // namespace native
97 } // namespace executor
98 } // namespace torch
99