1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cinttypes>
10 #include <cstdint>
11 #include <cstring>
12
13 #include <executorch/kernels/portable/cpu/util/index_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15
16 namespace torch {
17 namespace executor {
18 namespace native {
19
20 using Tensor = exec_aten::Tensor;
21 using ScalarType = exec_aten::ScalarType;
22
23 namespace {
24
25 template <typename CTYPE>
gather_helper(const Tensor & in,const Tensor & index,Tensor & out,int64_t dim)26 void gather_helper(
27 const Tensor& in,
28 const Tensor& index,
29 Tensor& out,
30 int64_t dim) {
31 const CTYPE* in_data = in.const_data_ptr<CTYPE>();
32 const long* index_data = index.const_data_ptr<long>();
33 CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
34
35 if (index.dim() == 0) {
36 out_data[0] = in_data[index_data[0]];
37 return;
38 }
39
40 for (size_t ix = 0; ix < index.numel(); ++ix) {
41 size_t ix_coord[kTensorDimensionLimit];
42 indexToCoordinate(index, ix, ix_coord);
43
44 size_t in_coord[kTensorDimensionLimit];
45 for (size_t i = 0; i < out.dim(); ++i) {
46 if (i == dim) {
47 in_coord[i] = index_data[ix];
48 } else {
49 in_coord[i] = ix_coord[i];
50 }
51 }
52
53 size_t in_ix = coordinateToIndex(in, in_coord);
54 size_t out_ix = coordinateToIndex(out, ix_coord);
55
56 out_data[out_ix] = in_data[in_ix];
57 }
58 }
59
60 } // namespace
61
gather_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,const Tensor & index,bool sparse_grad,Tensor & out)62 Tensor& gather_out(
63 KernelRuntimeContext& ctx,
64 const Tensor& in,
65 int64_t dim,
66 const Tensor& index,
67 bool sparse_grad,
68 Tensor& out) {
69 (void)ctx;
70
71 ET_KERNEL_CHECK(
72 ctx,
73 check_gather_args(in, dim, index, sparse_grad, out),
74 InvalidArgument,
75 out);
76
77 if (dim < 0) {
78 dim += nonzero_dim(in);
79 }
80
81 ET_KERNEL_CHECK(
82 ctx,
83 resize_tensor(out, index.sizes()) == Error::Ok,
84 InvalidArgument,
85 out);
86
87 constexpr auto name = "gather.out";
88
89 ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
90 gather_helper<CTYPE>(in, index, out, dim);
91 });
92
93 return out;
94 }
95
96 } // namespace native
97 } // namespace executor
98 } // namespace torch
99