/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

Tensor& index_put_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
    const Tensor& values,
    const bool accumulate,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, check_index_args(in, indices, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dtype(in, values), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
  size_t block_count = count_index_blocks(indices);

  // If the indices list is empty or all indices are null, the operation is
  // performed over the entire input tensor. So it is equivalent to
  // out = values when accumulate is false, and to out = in + values when
  // accumulate is true.
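  // For example, with a 1-D in = {1, 2, 3} and values = {10, 20, 30}, an empty
  // indices list gives out = {10, 20, 30} when accumulate is false, and
  // out = {11, 22, 33} when accumulate is true.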
  if (block_count == 0) {
    ET_KERNEL_CHECK(
        ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

    // Check that the values tensor can be broadcast to out
    ET_KERNEL_CHECK(
        ctx, tensor_is_broadcastable_to(values, out), InvalidArgument, out);

    ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() {
      apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
          [accumulate](const CTYPE val_in, const CTYPE val) {
            return accumulate ? val_in + val : val;
          },
          in,
          values,
          out);
    });
    return out;
  }

  // The index output shape depends on whether all the non-null indices are
  // adjacent or not.
  bool adjacent = (block_count == 1);
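  // For example, take in of shape (3, 4, 5, 6) and an index tensor ix of shape
  // (2,). With indices = (null, ix, ix, null) the non-null indices are
  // adjacent and the indexed result has shape (3, 2, 6); with
  // indices = (ix, null, ix, null) they are not, and the indexed dimensions
  // move to the front, giving shape (2, 4, 6).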

  // Compute the expected index output shape.
  Tensor::SizesType x_sizes[kTensorDimensionLimit];
  size_t x_dim = 0;
  ET_KERNEL_CHECK(
      ctx,
      get_index_out_target_size(in, indices, adjacent, x_sizes, &x_dim),
      InvalidArgument,
      out);

  // Check that the values tensor can be broadcast to the indexing result
  ET_KERNEL_CHECK(
      ctx,
      tensor_is_broadcastable_to(values.sizes(), {x_sizes, x_dim}),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

  // No further action is needed if the input is empty
  if (in.numel() == 0) {
    return out;
  }

  // To start, copy the input data into the out tensor
  memcpy(out.mutable_data_ptr<char>(), in.const_data_ptr<char>(), in.nbytes());

  // In what follows, `x = in[indices]`. This tensor is implicit, and it would
  // be much easier to allocate memory and then call index.Tensor to compute
  // `x`. But since we can't do that, we have to keep track of its shape,
  // number of dimensions, and number of elements, and use them to translate
  // coordinates from `x` to `in`.
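  // For example, if in has shape (5, 4) and indices = (ix,) with ix of shape
  // (3,), then x = in[ix] has shape (3, 4), and element x[i, j] corresponds
  // to in[ix[i], j].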

  // Compute the dim_map and ix_map needed for `x -> in` coordinate translation
  int32_t dim_map[kTensorDimensionLimit];
  int32_t ix_map[kTensorDimensionLimit];
  size_t start = 0;

  if (adjacent) {
    start = get_num_leading_null_indices(indices);
  }
  size_t bc_ndim = get_indices_broadcast_ndim(indices);
  compute_dim_map(in, indices, dim_map, adjacent);
  compute_index_map(in, indices, ix_map);

  // Compute the number of elements in the indexed space
  size_t x_numel = 1;
  for (size_t i = 0; i < x_dim; i++) {
    x_numel *= x_sizes[i];
  }
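  // For example, x_sizes = {3, 2, 6} gives x_numel = 3 * 2 * 6 = 36.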

  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() {
    const CTYPE* const values_data = values.const_data_ptr<CTYPE>();
    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

    for (size_t x_ix = 0; x_ix < x_numel; x_ix++) {
      size_t in_ix = 0;

      size_t x_coord[kTensorDimensionLimit];
      delinearize_index(x_ix, {x_sizes, x_dim}, x_coord, kTensorDimensionLimit);
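      // delinearize_index converts the flat index x_ix into a per-dimension
      // coordinate within the indexed space, which is treated as contiguous
      // in row-major order; e.g., with x_sizes = {3, 2, 6}, x_ix = 7 maps to
      // x_coord = {0, 1, 1}.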

      size_t in_coord[kTensorDimensionLimit];

      ET_KERNEL_CHECK(
          ctx,
          get_in_coord(
              in, indices, start, bc_ndim, dim_map, ix_map, x_coord, in_coord),
          InvalidArgument, );

      in_ix = coordinateToIndex(in, in_coord);

      // Broadcast values: map the x coordinate to the corresponding element of
      // the values tensor
      size_t val_ix = linearize_access_indexes(x_coord, x_dim, values);
      if (accumulate) {
        out_data[in_ix] += values_data[val_ix];
      } else {
        out_data[in_ix] = values_data[val_ix];
      }
    }
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch