/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

Tensor& index_put_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
    const Tensor& values,
    const bool accumulate,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, check_index_args(in, indices, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dtype(in, values), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
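  // Count the contiguous blocks of non-null index tensors. As used below, a
  // count of 0 means no dimension is indexed, and a count of 1 means all
  // non-null indices are adjacent.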
  size_t block_count = count_index_blocks(indices);

  // If the indices list is empty or all indices are null, then the operation
  // is performed over the entire input tensor. In that case, out = values
  // when accumulate is false, and out = in + values when accumulate is true.
  if (block_count == 0) {
    ET_KERNEL_CHECK(
        ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

    // Check that the values tensor can be broadcast to out
    ET_KERNEL_CHECK(
        ctx, tensor_is_broadcastable_to(values, out), InvalidArgument, out);

    ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() {
      apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
          [accumulate](const CTYPE val_in, const CTYPE val) {
            return accumulate ? val_in + val : val;
          },
          in,
          values,
          out);
    });
    return out;
  }

  // The index output shape depends on whether all the non-null indices are
  // adjacent or not.
  bool adjacent = (block_count == 1);

  // Compute the expected index output shape.
  Tensor::SizesType x_sizes[kTensorDimensionLimit];
  size_t x_dim = 0;
  ET_KERNEL_CHECK(
      ctx,
      get_index_out_target_size(in, indices, adjacent, x_sizes, &x_dim),
      InvalidArgument,
      out);

  // Check that the values tensor can be broadcast to the indexing result
  ET_KERNEL_CHECK(
      ctx,
      tensor_is_broadcastable_to(values.sizes(), {x_sizes, x_dim}),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

  // No further action is needed if the input is empty
  if (in.numel() == 0) {
    return out;
  }

  // To start, copy the input data into the out tensor
  memcpy(out.mutable_data_ptr<char>(), in.const_data_ptr<char>(), in.nbytes());

  // In what follows, `x = in[indices]`. This tensor is implicit, and it would
  // be much easier to be able to allocate memory, and then call index.Tensor
  // to compute `x`. But since we can't do that, we have to keep track of its
  // shape, number of dimensions, number of elements, and use it to translate
  // coordinates from `x` to `in`.
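  //
  // For example (illustrative only): if `in` has shape (3, 4, 5) and
  // `indices` holds a single index tensor `idx` of shape (2,), then `x` has
  // shape (2, 4, 5), and the `x` coordinate (i, j, k) corresponds to the `in`
  // coordinate (idx[i], j, k).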

  // Compute the dim_map and ix_map needed for `x -> in` coordinate translation
  int32_t dim_map[kTensorDimensionLimit];
  int32_t ix_map[kTensorDimensionLimit];
  size_t start = 0;

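  // When the non-null indices are adjacent, `start` is the number of leading
  // null (i.e. unindexed) dimensions that precede the indexed block.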
  if (adjacent) {
    start = get_num_leading_null_indices(indices);
  }
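  // Number of dimensions of the shape that the index tensors broadcast to.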
  size_t bc_ndim = get_indices_broadcast_ndim(indices);
  compute_dim_map(in, indices, dim_map, block_count == 1);
  compute_index_map(in, indices, ix_map);

  // Compute the number of elements in the indexed space
  size_t x_numel = 1;
  for (size_t i = 0; i < x_dim; i++) {
    x_numel *= x_sizes[i];
  }

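  // For each element of the implicit indexed tensor `x`, translate its
  // coordinate into a coordinate of `in`, then write (or accumulate) the
  // corresponding element of `values` into `out` at that location.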
  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() {
    const CTYPE* const values_data = values.const_data_ptr<CTYPE>();
    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

    for (size_t x_ix = 0; x_ix < x_numel; x_ix++) {
      size_t in_ix = 0;

      size_t x_coord[kTensorDimensionLimit];
      delinearize_index(x_ix, {x_sizes, x_dim}, x_coord, kTensorDimensionLimit);
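      // x_coord now holds the multi-dimensional coordinate of element x_ix
      // within `x`.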

      size_t in_coord[kTensorDimensionLimit];

      ET_KERNEL_CHECK(
          ctx,
          get_in_coord(
              in, indices, start, bc_ndim, dim_map, ix_map, x_coord, in_coord),
          InvalidArgument, );

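      // Convert the `in` coordinate into a flat index into the data buffer.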
      in_ix = coordinateToIndex(in, in_coord);

      // Broadcast values
      size_t val_ix = linearize_access_indexes(x_coord, x_dim, values);
      if (accumulate) {
        out_data[in_ix] += values_data[val_ix];
      } else {
        out_data[in_ix] = values_data[val_ix];
      }
    }
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch