xref: /aosp_15_r20/external/executorch/kernels/optimized/cpu/op_log_softmax.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #ifdef __aarch64__
10 #include <arm_neon.h>
11 #include <sleef.h>
12 #endif
13 
#include <algorithm>
#include <cmath>
#include <type_traits>

#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
19 
// `_log_softmax_out` applies the LogSoftmax function along dimension `dim` of
// an n-dimensional input Tensor, writing the result into the n-dimensional
// output Tensor (log(softmax(x)) computed in a numerically stable way).
23 
24 namespace torch {
25 namespace executor {
26 namespace native {
27 
28 using Tensor = exec_aten::Tensor;
29 namespace {
30 
31 template <typename IN_T, typename OUT_T>
log_softmax_kernel(const Tensor & input,int64_t dim,Tensor & out)32 void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
33   const IN_T* __restrict__ input_data_base = input.const_data_ptr<IN_T>();
34   OUT_T* __restrict__ output_data_base = out.mutable_data_ptr<OUT_T>();
35 
36   if (input.dim() == 0) {
37     output_data_base[0] = 0;
38     return;
39   }
40 
41   int64_t dim_size = input.size(dim);
42 
43   int64_t outer_size = 1;
44   int64_t inner_size = 1;
45   for (int64_t i = 0; i < dim; ++i) {
46     outer_size *= input.size(i);
47   }
48   for (int64_t i = dim + 1; i < input.dim(); ++i) {
49     inner_size *= input.size(i);
50   }
51 
52   int64_t dim_stride = inner_size;
53   int64_t outer_stride = dim_size * dim_stride;
54 
55   for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
56     for (size_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
57       const IN_T* input_data =
58           input_data_base + outer_idx * outer_stride + inner_idx;
59       OUT_T* output_data =
60           output_data_base + outer_idx * outer_stride + inner_idx;
61 
62       // calculate max in softmax dim
63       IN_T max_input = input_data[0];
64       for (auto d = 0; d < dim_size; ++d) {
65         max_input = std::max(max_input, input_data[d * dim_stride]);
66       }
67       // calculate sum and exponential in softmax dim
68       OUT_T temp_sum = 0;
69 #ifndef __aarch64__
70       for (auto d = 0; d < dim_size; ++d) {
71         output_data[d * dim_stride] =
72             std::exp(input_data[d * dim_stride] - max_input);
73         temp_sum += output_data[d * dim_stride];
74       }
75 #else
76       auto d = 0;
77       for (; d + 4 < dim_size; d += 4) {
78         auto index = d * dim_stride;
79         float32x4_t in =
80             vld1q_f32(static_cast<const float*>(&input_data[index]));
81         float32x4_t out_ =
82             Sleef_expf4_u10(vsubq_f32(in, vmovq_n_f32(max_input)));
83         vst1q_f32(static_cast<float*>(&output_data[index]), out_);
84         temp_sum += vaddvq_f32(out_);
85       }
86 
87       for (; d < dim_size; ++d) {
88         output_data[d * dim_stride] =
89             std::exp(input_data[d * dim_stride] - max_input);
90         temp_sum += output_data[d * dim_stride];
91       }
92 #endif // __aarch64__
93 
94       temp_sum = std::log(temp_sum);
95 
96       for (auto dd = 0; dd < dim_size; ++dd) {
97         output_data[dd * dim_stride] =
98             input_data[dd * dim_stride] - max_input - temp_sum;
99       }
100     }
101   }
102 }
103 
104 // OUT_T is the corresponding C++ type for out.scalar_type(). Only takes float
105 // or double.
106 template <
107     typename OUT_T,
108     std::enable_if_t<std::is_floating_point<OUT_T>::value, bool> = true>
log_softmax_wrapper(const Tensor & X,int64_t dim,Tensor & out)109 void log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
110   auto input_scalar_type = X.scalar_type();
111   switch (input_scalar_type) {
112     // TODO: support Double as well
113     case ScalarType::Float:
114       log_softmax_kernel<float, OUT_T>(X, dim, out);
115       break;
116     default:
117       ET_CHECK_MSG(
118           false,
119           "Unhandled input dtype %" PRId8,
120           static_cast<int8_t>(input_scalar_type));
121   }
122 }
123 } // namespace
124 
125 // _log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out)
126 // -> Tensor(a!)
opt_log_softmax_out(KernelRuntimeContext & context,const Tensor & self,int64_t dim,bool half_to_float,Tensor & out)127 Tensor& opt_log_softmax_out(
128     KernelRuntimeContext& context,
129     const Tensor& self,
130     int64_t dim,
131     bool half_to_float,
132     Tensor& out) {
133   (void)context;
134 
135   ET_KERNEL_CHECK(
136       context,
137       check_log_softmax_args(self, dim, half_to_float, out),
138       InvalidArgument,
139       out);
140 
141   ET_KERNEL_CHECK(
142       context,
143       resize_tensor(out, self.sizes()) == Error::Ok,
144       InvalidArgument,
145       out);
146 
147   dim = dim < 0 ? dim + nonzero_dim(self) : dim;
148 
149   auto out_scalar_type = out.scalar_type();
150   switch (out_scalar_type) {
151     // TODO: support Double as well
152     case ScalarType::Float:
153       log_softmax_wrapper<float>(self, dim, out);
154       break;
155     default:
156       ET_CHECK_MSG(
157           false,
158           "Unhandled out dtype %" PRId8,
159           static_cast<int8_t>(out_scalar_type));
160   }
161   return out;
162 }
163 
164 } // namespace native
165 } // namespace executor
166 } // namespace torch
167