/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#ifdef __aarch64__
#include <arm_neon.h>
#include <sleef.h>
#endif

#include <cinttypes> // for PRId8 used in error messages below
#include <cmath>
#include <type_traits>

#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

// `_log_softmax_out` applies the LogSoftmax function to an n-dimensional
// input Tensor, writing an n-dimensional output Tensor of the same shape in
// which each slice along `dim` equals x - max(x) - log(sum(exp(x - max(x)))).
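//
// A worked example: for x = [0, 1, 2], max(x) = 2 and
// exp(x - max) = [0.135, 0.368, 1.0], whose sum is 1.503 with
// log(1.503) = 0.408, so log_softmax(x) = [-2.408, -1.408, -0.408].
// Subtracting the max first keeps exp() from overflowing for large inputs
// without changing the result.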

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
namespace {

template <typename IN_T, typename OUT_T>
void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
  const IN_T* __restrict__ input_data_base = input.const_data_ptr<IN_T>();
  OUT_T* __restrict__ output_data_base = out.mutable_data_ptr<OUT_T>();

  if (input.dim() == 0) {
    // A zero-dimensional tensor holds a single element, and log_softmax of a
    // single element is always log(1) = 0.
    output_data_base[0] = 0;
    return;
  }

  int64_t dim_size = input.size(dim);

  // View the input as [outer_size, dim_size, inner_size], with `dim` as the
  // middle axis that the softmax is computed over.
  int64_t outer_size = 1;
  int64_t inner_size = 1;
  for (int64_t i = 0; i < dim; ++i) {
    outer_size *= input.size(i);
  }
  for (int64_t i = dim + 1; i < input.dim(); ++i) {
    inner_size *= input.size(i);
  }

  int64_t dim_stride = inner_size;
  int64_t outer_stride = dim_size * dim_stride;
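  // For example, an input of shape [2, 3, 4] with dim = 1 gives
  // outer_size = 2, dim_size = 3, inner_size = 4, so dim_stride = 4 and
  // outer_stride = 12: element (o, d, i) lives at flat offset
  // o * 12 + d * 4 + i.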

  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    for (int64_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
      const IN_T* input_data =
          input_data_base + outer_idx * outer_stride + inner_idx;
      OUT_T* output_data =
          output_data_base + outer_idx * outer_stride + inner_idx;

      // calculate max in softmax dim
      IN_T max_input = input_data[0];
      for (auto d = 0; d < dim_size; ++d) {
        max_input = std::max(max_input, input_data[d * dim_stride]);
      }
      // calculate sum and exponential in softmax dim
      OUT_T temp_sum = 0;
#ifndef __aarch64__
      for (auto d = 0; d < dim_size; ++d) {
        output_data[d * dim_stride] =
            std::exp(input_data[d * dim_stride] - max_input);
        temp_sum += output_data[d * dim_stride];
      }
#else
      auto d = 0;
      // The NEON fast path loads four consecutive floats per iteration, so it
      // only walks the softmax dimension when that dimension is contiguous
      // (dim_stride == 1). Strided cases, and the remaining tail elements,
      // fall through to the scalar loop below.
      for (; dim_stride == 1 && d + 4 < dim_size; d += 4) {
        auto index = d * dim_stride;
        float32x4_t in =
            vld1q_f32(static_cast<const float*>(&input_data[index]));
        // exp(x - max) on four lanes at once via SLEEF.
        float32x4_t out_ =
            Sleef_expf4_u10(vsubq_f32(in, vmovq_n_f32(max_input)));
        vst1q_f32(static_cast<float*>(&output_data[index]), out_);
        // Horizontal add of the four lanes into the running sum.
        temp_sum += vaddvq_f32(out_);
      }

      for (; d < dim_size; ++d) {
        output_data[d * dim_stride] =
            std::exp(input_data[d * dim_stride] - max_input);
        temp_sum += output_data[d * dim_stride];
      }
#endif // __aarch64__

      // log(sum(exp(x - max))) completes the numerically stable log-softmax;
      // the final pass computes x - max - log_sum directly from the input.
      temp_sum = std::log(temp_sum);

      for (auto dd = 0; dd < dim_size; ++dd) {
        output_data[dd * dim_stride] =
            input_data[dd * dim_stride] - max_input - temp_sum;
      }
    }
  }
}

// OUT_T is the corresponding C++ type for out.scalar_type(). Only takes float
// or double.
template <
    typename OUT_T,
    std::enable_if_t<std::is_floating_point<OUT_T>::value, bool> = true>
void log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
  auto input_scalar_type = X.scalar_type();
  switch (input_scalar_type) {
    // TODO: support Double as well
    case ScalarType::Float:
      log_softmax_kernel<float, OUT_T>(X, dim, out);
      break;
    default:
      ET_CHECK_MSG(
          false,
          "Unhandled input dtype %" PRId8,
          static_cast<int8_t>(input_scalar_type));
  }
}
} // namespace

// _log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out)
// -> Tensor(a!)
Tensor& opt_log_softmax_out(
    KernelRuntimeContext& context,
    const Tensor& self,
    int64_t dim,
    bool half_to_float,
    Tensor& out) {
  ET_KERNEL_CHECK(
      context,
      check_log_softmax_args(self, dim, half_to_float, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      context,
      resize_tensor(out, self.sizes()) == Error::Ok,
      InvalidArgument,
      out);

  // Normalize a negative `dim` to its positive equivalent.
  dim = dim < 0 ? dim + nonzero_dim(self) : dim;

  auto out_scalar_type = out.scalar_type();
  switch (out_scalar_type) {
    // TODO: support Double as well
    case ScalarType::Float:
      log_softmax_wrapper<float>(self, dim, out);
      break;
    default:
      ET_CHECK_MSG(
          false,
          "Unhandled out dtype %" PRId8,
          static_cast<int8_t>(out_scalar_type));
  }
  return out;
}
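
// Illustrative usage (names other than the function itself are hypothetical):
// the runtime invokes this kernel through the registered op, roughly as
//
//   Tensor& result = opt_log_softmax_out(
//       ctx, self, /*dim=*/1, /*half_to_float=*/false, out);
//
// For `self` of shape [2, 5] and dim = 1, `out` is resized to [2, 5] and each
// row r satisfies
// out[r][c] = self[r][c] - max(self[r]) - log(sum(exp(self[r] - max(self[r])))).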

} // namespace native
} // namespace executor
} // namespace torch