xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_cumsum.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <cmath>
#include <cstddef>
// #include <cstdint>
// #include <type_traits>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

namespace {
/**
 * Returns the cumulative sum of elements of input in the dimension dim.
 *
 * Given a self tensor of size (d1, d2, ..., d_dim, ..., dm) and a cumsum
 * along dim, we first copy every value self[d1, d2, ..., 0, ..., dm] to
 * out[d1, d2, ..., 0, ..., dm], since no accumulation is needed for the
 * first element along dim. We then compute each out[d1, d2, ..., i, ..., dm]
 * as the sum of self[d1, d2, ..., i, ..., dm] and
 * out[d1, d2, ..., i-1, ..., dm].
 * This approach keeps the memory accesses sequential rather than strided,
 * which improves memory IO throughput and reduces cache misses.
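 *
 * For example, for a self of size (2, 3, 4) with dim = 1, leading_dims is 2,
 * dim_size is 3 and trailing_dims is 4, so each of the 2 outer blocks is
 * filled in 3 passes over 4 contiguous elements: one copy pass followed by
 * 2 accumulation passes.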
 */
template <typename CTYPE_OUT, typename LoadFn = CTYPE_OUT (*)(const void*)>
void cumsum_tensors(
    const Tensor& self,
    LoadFn load_self,
    int64_t dim,
    Tensor& out) {
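  // Nothing to accumulate for an empty tensor.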
  if (self.numel() == 0) {
    return;
  }

  const char* const input_data_base =
      reinterpret_cast<const char*>(self.const_data_ptr());
  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();

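  // A 0-d tensor holds a single element; just load (and convert) it into out.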
  if (self.dim() == 0) {
    output_data_base[0] = load_self(&input_data_base[0]);
    return;
  }

  const size_t dim_size = static_cast<size_t>(self.size(dim));
  const size_t leading_dims = getLeadingDims(self, dim);
  const size_t trailing_dims = getTrailingDims(self, dim);

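  // The loops below address both tensors at the flat offset
  // i * (dim_size * trailing_dims) + j * trailing_dims + idx, where i indexes
  // the leading dims, j the cumsum dim, and idx the trailing dims. The input
  // offset is additionally scaled by element_size() because the input is
  // addressed as raw bytes and converted to CTYPE_OUT by load_self.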
  for (size_t i = 0; i < leading_dims; i++) {
    size_t start_loc = i * (trailing_dims * dim_size);

    for (size_t idx = 0; idx < trailing_dims; idx++) {
      output_data_base[start_loc + idx] =
          load_self(&input_data_base[(start_loc + idx) * self.element_size()]);
    }

    for (size_t j = 1; j < dim_size; j++) {
      size_t cur_round_base = start_loc + j * trailing_dims;
      size_t prev_round_base = start_loc + (j - 1) * trailing_dims;
      for (size_t idx = 0; idx < trailing_dims; idx++) {
        output_data_base[cur_round_base + idx] =
            load_self(&input_data_base
                          [(cur_round_base + idx) * self.element_size()]) +
            output_data_base[prev_round_base + idx];
      }
    }
  }
}

} // namespace

/**
 * Returns the cumulative sum of elements of input in the dimension dim.
 * If dtype is specified, the input tensor is cast to dtype before the
 * operation is performed. This is useful for preventing data type overflows.
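 *
 * For example, the cumulative sum of [1, 2, 3] along dim 0 is [1, 3, 6].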
 */
Tensor& cumsum_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    int64_t dim,
    optional<ScalarType> enforced_dtype,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx,
      check_cumsum_args(self, dim, enforced_dtype, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, self.sizes()) == Error::Ok, InvalidArgument, out);

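  // Canonicalize dim: a 0-d tensor is treated as if it had a single dim, and
  // negative dims wrap around (e.g. -1 refers to the last dim).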
  dim = (self.dim() == 0) ? 0 : dim < 0 ? dim + self.dim() : dim;

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "cumsum.out";

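  // Dispatch on the output dtype; each input element is loaded and converted
  // to CTYPE_OUT by load_self, so the accumulation itself runs in the output
  // (possibly enforced) dtype.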
  ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
    const auto load_self =
        utils::internal::get_load_to_common_fn<CTYPE_OUT, op_name>(
            self, utils::SupportedTensorDtypes::REALHBBF16);
    cumsum_tensors<CTYPE_OUT>(self, load_self, dim, out);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch