1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cmath>
10
11 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
12 #include <executorch/kernels/portable/cpu/util/functional_util.h>
13 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15
16 namespace torch {
17 namespace executor {
18 namespace native {
19
20 using Tensor = exec_aten::Tensor;
21
softmax_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,bool half_to_float,Tensor & out)22 Tensor& softmax_out(
23 KernelRuntimeContext& ctx,
24 const Tensor& in,
25 int64_t dim,
26 bool half_to_float,
27 Tensor& out) {
28 (void)ctx;
29
30 ET_KERNEL_CHECK(
31 ctx,
32 check_softmax_args(in, dim, half_to_float, out),
33 InvalidArgument,
34 out);
35
36 ET_KERNEL_CHECK(
37 ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
38
39 ET_KERNEL_CHECK(
40 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
41
42 // Adjust for negative dim
43 dim = dim < 0 ? dim + nonzero_dim(in) : dim;
44
45 ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() {
46 const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
47 CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
48
49 apply_over_dim(
50 [in_data, out_data](
51 const size_t size, const size_t stride, const size_t base) {
52 // calculate max in softmax dim. During softmax computation each
53 // value is subtracted by the maximum in value before calling exp
54 // to preserve numerical stability.
55 const CTYPE max_in = apply_unary_reduce_fn(
56 [](const CTYPE val_in, CTYPE val_accum) {
57 return std::max(val_in, val_accum);
58 },
59 in_data + base,
60 size,
61 stride);
62
63 const CTYPE temp_sum = apply_unary_map_reduce_fn<CTYPE, CTYPE>(
64 [max_in](const CTYPE val_in) {
65 return std::exp(val_in - max_in);
66 },
67 [](const CTYPE mapped_in, CTYPE val_accum) {
68 return val_accum + mapped_in;
69 },
70 in_data + base,
71 size,
72 stride);
73
74 apply_unary_map_fn(
75 [max_in, temp_sum](const CTYPE val_in) {
76 return std::exp(val_in - max_in) / temp_sum;
77 },
78 in_data + base,
79 out_data + base,
80 size,
81 stride);
82 },
83 in,
84 dim);
85 });
86
87 return out;
88 }
89
90 } // namespace native
91 } // namespace executor
92 } // namespace torch
93