xref: /aosp_15_r20/external/libtextclassifier/native/utils/math/softmax.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "utils/math/softmax.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <limits>
20*993b0882SAndroid Build Coastguard Worker 
21*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
22*993b0882SAndroid Build Coastguard Worker #include "utils/math/fastexp.h"
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
25*993b0882SAndroid Build Coastguard Worker 
ComputeSoftmaxProbability(const std::vector<float> & scores,int label)26*993b0882SAndroid Build Coastguard Worker float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) {
27*993b0882SAndroid Build Coastguard Worker   if ((label < 0) || (label >= scores.size())) {
28*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "label " << label << " outside range "
29*993b0882SAndroid Build Coastguard Worker                    << "[0, " << scores.size() << ")";
30*993b0882SAndroid Build Coastguard Worker     return 0.0f;
31*993b0882SAndroid Build Coastguard Worker   }
32*993b0882SAndroid Build Coastguard Worker 
33*993b0882SAndroid Build Coastguard Worker   // Standard softmax formula for label's probability is
34*993b0882SAndroid Build Coastguard Worker   //
35*993b0882SAndroid Build Coastguard Worker   //   exp(scores[label]) / sum_i exp(scores[i])
36*993b0882SAndroid Build Coastguard Worker   //
37*993b0882SAndroid Build Coastguard Worker   // We compute the mathematically equivalent
38*993b0882SAndroid Build Coastguard Worker   //
39*993b0882SAndroid Build Coastguard Worker   //   1 / (1 + sum_{i != label} exp(scores[i] - scores[label]))
40*993b0882SAndroid Build Coastguard Worker   //
41*993b0882SAndroid Build Coastguard Worker   // which saves two calls to exp().
42*993b0882SAndroid Build Coastguard Worker   const float label_score = scores[label];
43*993b0882SAndroid Build Coastguard Worker   float denominator = 1.0f;  // Contribution of i == label.
44*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < scores.size(); ++i) {
45*993b0882SAndroid Build Coastguard Worker     if (i == label) continue;
46*993b0882SAndroid Build Coastguard Worker     const float delta_score = scores[i] - label_score;
47*993b0882SAndroid Build Coastguard Worker 
48*993b0882SAndroid Build Coastguard Worker     // TODO(salcianu): one can optimize the test below, to avoid any float
49*993b0882SAndroid Build Coastguard Worker     // operation: extract exponent (via bit mask + shift) and check it's >= 4.
50*993b0882SAndroid Build Coastguard Worker     if (fabs(delta_score) >= 16.0f) {
51*993b0882SAndroid Build Coastguard Worker       if (delta_score > 0.0f) {
52*993b0882SAndroid Build Coastguard Worker         // If delta_score >= 16, the denominator (e^delta_score + other positive
53*993b0882SAndroid Build Coastguard Worker         // terms) is very big and its inverse can be approximated with 0.
54*993b0882SAndroid Build Coastguard Worker         return 0.0f;
55*993b0882SAndroid Build Coastguard Worker       } else {
56*993b0882SAndroid Build Coastguard Worker         // If delta_score <= -16, then e^delta_score < 1.2e-7.  Even if we have
57*993b0882SAndroid Build Coastguard Worker         // 1000 such labels i, their sum is < 1.2e-4 (which gets summed with
58*993b0882SAndroid Build Coastguard Worker         // 1.0f for i == label).  Hence, we can approximate each such label with
59*993b0882SAndroid Build Coastguard Worker         // 0 and skip the call to VeryFastExp and the update to denominator.
60*993b0882SAndroid Build Coastguard Worker         continue;
61*993b0882SAndroid Build Coastguard Worker       }
62*993b0882SAndroid Build Coastguard Worker     }
63*993b0882SAndroid Build Coastguard Worker 
64*993b0882SAndroid Build Coastguard Worker     // At this point, delta_score is in (-16.0, 16.0).  For such values, vfexp
65*993b0882SAndroid Build Coastguard Worker     // works fine: no under/overflows (we have tests for that in fastexp_test).
66*993b0882SAndroid Build Coastguard Worker     // Also, even for 1000 labels, denominator will not overflow.
67*993b0882SAndroid Build Coastguard Worker     denominator += VeryFastExp(delta_score);
68*993b0882SAndroid Build Coastguard Worker   }
69*993b0882SAndroid Build Coastguard Worker   return 1.0f / denominator;
70*993b0882SAndroid Build Coastguard Worker }
71*993b0882SAndroid Build Coastguard Worker 
ComputeSoftmax(const std::vector<float> & scores)72*993b0882SAndroid Build Coastguard Worker std::vector<float> ComputeSoftmax(const std::vector<float> &scores) {
73*993b0882SAndroid Build Coastguard Worker   return ComputeSoftmax(scores.data(), scores.size());
74*993b0882SAndroid Build Coastguard Worker }
75*993b0882SAndroid Build Coastguard Worker 
ComputeSoftmax(const float * scores,int scores_size)76*993b0882SAndroid Build Coastguard Worker std::vector<float> ComputeSoftmax(const float *scores, int scores_size) {
77*993b0882SAndroid Build Coastguard Worker   std::vector<float> softmax;
78*993b0882SAndroid Build Coastguard Worker   std::vector<float> exp_scores;
79*993b0882SAndroid Build Coastguard Worker   exp_scores.reserve(scores_size);
80*993b0882SAndroid Build Coastguard Worker   softmax.reserve(scores_size);
81*993b0882SAndroid Build Coastguard Worker 
82*993b0882SAndroid Build Coastguard Worker   // Find max value in "scores" vector and rescale to avoid overflows.
83*993b0882SAndroid Build Coastguard Worker   float max = std::numeric_limits<float>::min();
84*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < scores_size; ++i) {
85*993b0882SAndroid Build Coastguard Worker     const float score = scores[i];
86*993b0882SAndroid Build Coastguard Worker     if (score > max) max = score;
87*993b0882SAndroid Build Coastguard Worker   }
88*993b0882SAndroid Build Coastguard Worker   float denominator = 0;
89*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < scores_size; ++i) {
90*993b0882SAndroid Build Coastguard Worker     const float score = scores[i];
91*993b0882SAndroid Build Coastguard Worker     // See comments above in ComputeSoftmaxProbability for the reasoning behind
92*993b0882SAndroid Build Coastguard Worker     // this approximation.
93*993b0882SAndroid Build Coastguard Worker     const float exp_score = score - max < -16.0f ? 0 : VeryFastExp(score - max);
94*993b0882SAndroid Build Coastguard Worker     exp_scores.push_back(exp_score);
95*993b0882SAndroid Build Coastguard Worker     denominator += exp_score;
96*993b0882SAndroid Build Coastguard Worker   }
97*993b0882SAndroid Build Coastguard Worker 
98*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < scores_size; ++i) {
99*993b0882SAndroid Build Coastguard Worker     softmax.push_back(exp_scores[i] / denominator);
100*993b0882SAndroid Build Coastguard Worker   }
101*993b0882SAndroid Build Coastguard Worker   return softmax;
102*993b0882SAndroid Build Coastguard Worker }
103*993b0882SAndroid Build Coastguard Worker 
104*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
105