/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <cmath>
#include <cstddef>
#include <cstdint>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

namespace {
/**
 * Returns the cumulative sum of elements of input in the dimension dim.
 *
 * Given a self tensor of size (d1, d2, .., d_dim, ..., dm) on which we do
 * cumsum along dim, we first copy all values in self[d1, d2, .., 0, ..., dm]
 * to out[d1, d2, .., 0, ..., dm], since no accumulation is needed for the
 * first element. Then we compute each out[d1, d2, .., i, ..., dm] by adding
 * out[d1, d2, .., i-1, ..., dm] and self[d1, d2, .., i, ..., dm].
 * This approach keeps memory accesses sequential rather than strided,
 * improving memory I/O throughput and reducing cache misses.
 */
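// For example, for a self tensor of shape (2, 2, 3) with dim = 1:
// leading_dims = 2, dim_size = 2, and trailing_dims = 3. For each of the two
// leading slices, the 3 contiguous elements at index 0 along dim are copied
// (and converted) into out, and the 3 elements at index 1 are written as
// out[1][idx] = self[1][idx] + out[0][idx], so reads and writes stay
// contiguous in memory.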
template <typename CTYPE_OUT, typename LoadFn = CTYPE_OUT (*)(const void*)>
void cumsum_tensors(
    const Tensor& self,
    LoadFn load_self,
    int64_t dim,
    Tensor& out) {
  if (self.numel() == 0) {
    return;
  }

  const char* const input_data_base =
      reinterpret_cast<const char*>(self.const_data_ptr());
  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();

  if (self.dim() == 0) {
    output_data_base[0] = load_self(&input_data_base[0]);
    return;
  }

  const size_t dim_size = static_cast<size_t>(self.size(dim));
  const size_t leading_dims = getLeadingDims(self, dim);
  const size_t trailing_dims = getTrailingDims(self, dim);

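  // Process each leading slice independently; within a slice the trailing
  // elements are contiguous in memory and are handled together.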
  for (size_t i = 0; i < leading_dims; i++) {
    size_t start_loc = i * (trailing_dims * dim_size);

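    // Copy (and convert) the first slice along dim as-is; there is nothing to
    // accumulate for index 0.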
    for (size_t idx = 0; idx < trailing_dims; idx++) {
      output_data_base[start_loc + idx] =
          load_self(&input_data_base[(start_loc + idx) * self.element_size()]);
    }

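    // Each subsequent slice along dim adds the previously accumulated output
    // slice to the freshly loaded input slice.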
    for (size_t j = 1; j < dim_size; j++) {
      size_t cur_round_base = start_loc + j * trailing_dims;
      size_t prev_round_base = start_loc + (j - 1) * trailing_dims;
      for (size_t idx = 0; idx < trailing_dims; idx++) {
        output_data_base[cur_round_base + idx] =
            load_self(&input_data_base
                          [(cur_round_base + idx) * self.element_size()]) +
            output_data_base[prev_round_base + idx];
      }
    }
  }
}

} // namespace

/**
 * Returns the cumulative sum of elements of input in the dimension dim.
 * If dtype is specified, the input tensor is cast to dtype before the
 * operation is performed. This is useful for preventing data type overflows.
 */
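// For example, cumsum over a Char (int8) tensor can easily overflow; with an
// enforced dtype of Long, the running sums are accumulated and stored as
// int64 instead.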
Tensor& cumsum_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    int64_t dim,
    optional<ScalarType> enforced_dtype,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx,
      check_cumsum_args(self, dim, enforced_dtype, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, self.sizes()) == Error::Ok, InvalidArgument, out);

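  // Normalize a negative dim into the valid range; for a 0-D tensor, dim is
  // clamped to 0.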
  dim = (self.dim() == 0) ? 0 : dim < 0 ? dim + self.dim() : dim;

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "cumsum.out";

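  // Dispatch on the output dtype: each input element is loaded and converted
  // to CTYPE_OUT before being accumulated into the output.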
  ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
    const auto load_self =
        utils::internal::get_load_to_common_fn<CTYPE_OUT, op_name>(
            self, utils::SupportedTensorDtypes::REALHBBF16);
    cumsum_tensors<CTYPE_OUT>(self, load_self, dim, out);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch