xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_softmax.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cmath>
10 
11 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
12 #include <executorch/kernels/portable/cpu/util/functional_util.h>
13 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15 
16 namespace torch {
17 namespace executor {
18 namespace native {
19 
20 using Tensor = exec_aten::Tensor;
21 
softmax_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,bool half_to_float,Tensor & out)22 Tensor& softmax_out(
23     KernelRuntimeContext& ctx,
24     const Tensor& in,
25     int64_t dim,
26     bool half_to_float,
27     Tensor& out) {
28   (void)ctx;
29 
30   ET_KERNEL_CHECK(
31       ctx,
32       check_softmax_args(in, dim, half_to_float, out),
33       InvalidArgument,
34       out);
35 
36   ET_KERNEL_CHECK(
37       ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
38 
39   ET_KERNEL_CHECK(
40       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
41 
42   // Adjust for negative dim
43   dim = dim < 0 ? dim + nonzero_dim(in) : dim;
44 
45   ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() {
46     const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
47     CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
48 
49     apply_over_dim(
50         [in_data, out_data](
51             const size_t size, const size_t stride, const size_t base) {
52           // calculate max in softmax dim. During softmax computation each
53           // value is subtracted by the maximum in value before calling exp
54           // to preserve numerical stability.
55           const CTYPE max_in = apply_unary_reduce_fn(
56               [](const CTYPE val_in, CTYPE val_accum) {
57                 return std::max(val_in, val_accum);
58               },
59               in_data + base,
60               size,
61               stride);
62 
63           const CTYPE temp_sum = apply_unary_map_reduce_fn<CTYPE, CTYPE>(
64               [max_in](const CTYPE val_in) {
65                 return std::exp(val_in - max_in);
66               },
67               [](const CTYPE mapped_in, CTYPE val_accum) {
68                 return val_accum + mapped_in;
69               },
70               in_data + base,
71               size,
72               stride);
73 
74           apply_unary_map_fn(
75               [max_in, temp_sum](const CTYPE val_in) {
76                 return std::exp(val_in - max_in) / temp_sum;
77               },
78               in_data + base,
79               out_data + base,
80               size,
81               stride);
82         },
83         in,
84         dim);
85   });
86 
87   return out;
88 }
89 
90 } // namespace native
91 } // namespace executor
92 } // namespace torch
93