1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "utils/math/softmax.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <limits>
20*993b0882SAndroid Build Coastguard Worker
21*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
22*993b0882SAndroid Build Coastguard Worker #include "utils/math/fastexp.h"
23*993b0882SAndroid Build Coastguard Worker
24*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
25*993b0882SAndroid Build Coastguard Worker
ComputeSoftmaxProbability(const std::vector<float> & scores,int label)26*993b0882SAndroid Build Coastguard Worker float ComputeSoftmaxProbability(const std::vector<float> &scores, int label) {
27*993b0882SAndroid Build Coastguard Worker if ((label < 0) || (label >= scores.size())) {
28*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "label " << label << " outside range "
29*993b0882SAndroid Build Coastguard Worker << "[0, " << scores.size() << ")";
30*993b0882SAndroid Build Coastguard Worker return 0.0f;
31*993b0882SAndroid Build Coastguard Worker }
32*993b0882SAndroid Build Coastguard Worker
33*993b0882SAndroid Build Coastguard Worker // Standard softmax formula for label's probability is
34*993b0882SAndroid Build Coastguard Worker //
35*993b0882SAndroid Build Coastguard Worker // exp(scores[label]) / sum_i exp(scores[i])
36*993b0882SAndroid Build Coastguard Worker //
37*993b0882SAndroid Build Coastguard Worker // We compute the mathematically equivalent
38*993b0882SAndroid Build Coastguard Worker //
39*993b0882SAndroid Build Coastguard Worker // 1 / (1 + sum_{i != label} exp(scores[i] - scores[label]))
40*993b0882SAndroid Build Coastguard Worker //
41*993b0882SAndroid Build Coastguard Worker // which saves two calls to exp().
42*993b0882SAndroid Build Coastguard Worker const float label_score = scores[label];
43*993b0882SAndroid Build Coastguard Worker float denominator = 1.0f; // Contribution of i == label.
44*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < scores.size(); ++i) {
45*993b0882SAndroid Build Coastguard Worker if (i == label) continue;
46*993b0882SAndroid Build Coastguard Worker const float delta_score = scores[i] - label_score;
47*993b0882SAndroid Build Coastguard Worker
48*993b0882SAndroid Build Coastguard Worker // TODO(salcianu): one can optimize the test below, to avoid any float
49*993b0882SAndroid Build Coastguard Worker // operation: extract exponent (via bit mask + shift) and check it's >= 4.
50*993b0882SAndroid Build Coastguard Worker if (fabs(delta_score) >= 16.0f) {
51*993b0882SAndroid Build Coastguard Worker if (delta_score > 0.0f) {
52*993b0882SAndroid Build Coastguard Worker // If delta_score >= 16, the denominator (e^delta_score + other positive
53*993b0882SAndroid Build Coastguard Worker // terms) is very big and its inverse can be approximated with 0.
54*993b0882SAndroid Build Coastguard Worker return 0.0f;
55*993b0882SAndroid Build Coastguard Worker } else {
56*993b0882SAndroid Build Coastguard Worker // If delta_score <= -16, then e^delta_score < 1.2e-7. Even if we have
57*993b0882SAndroid Build Coastguard Worker // 1000 such labels i, their sum is < 1.2e-4 (which gets summed with
58*993b0882SAndroid Build Coastguard Worker // 1.0f for i == label). Hence, we can approximate each such label with
59*993b0882SAndroid Build Coastguard Worker // 0 and skip the call to VeryFastExp and the update to denominator.
60*993b0882SAndroid Build Coastguard Worker continue;
61*993b0882SAndroid Build Coastguard Worker }
62*993b0882SAndroid Build Coastguard Worker }
63*993b0882SAndroid Build Coastguard Worker
64*993b0882SAndroid Build Coastguard Worker // At this point, delta_score is in (-16.0, 16.0). For such values, vfexp
65*993b0882SAndroid Build Coastguard Worker // works fine: no under/overflows (we have tests for that in fastexp_test).
66*993b0882SAndroid Build Coastguard Worker // Also, even for 1000 labels, denominator will not overflow.
67*993b0882SAndroid Build Coastguard Worker denominator += VeryFastExp(delta_score);
68*993b0882SAndroid Build Coastguard Worker }
69*993b0882SAndroid Build Coastguard Worker return 1.0f / denominator;
70*993b0882SAndroid Build Coastguard Worker }
71*993b0882SAndroid Build Coastguard Worker
ComputeSoftmax(const std::vector<float> & scores)72*993b0882SAndroid Build Coastguard Worker std::vector<float> ComputeSoftmax(const std::vector<float> &scores) {
73*993b0882SAndroid Build Coastguard Worker return ComputeSoftmax(scores.data(), scores.size());
74*993b0882SAndroid Build Coastguard Worker }
75*993b0882SAndroid Build Coastguard Worker
ComputeSoftmax(const float * scores,int scores_size)76*993b0882SAndroid Build Coastguard Worker std::vector<float> ComputeSoftmax(const float *scores, int scores_size) {
77*993b0882SAndroid Build Coastguard Worker std::vector<float> softmax;
78*993b0882SAndroid Build Coastguard Worker std::vector<float> exp_scores;
79*993b0882SAndroid Build Coastguard Worker exp_scores.reserve(scores_size);
80*993b0882SAndroid Build Coastguard Worker softmax.reserve(scores_size);
81*993b0882SAndroid Build Coastguard Worker
82*993b0882SAndroid Build Coastguard Worker // Find max value in "scores" vector and rescale to avoid overflows.
83*993b0882SAndroid Build Coastguard Worker float max = std::numeric_limits<float>::min();
84*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < scores_size; ++i) {
85*993b0882SAndroid Build Coastguard Worker const float score = scores[i];
86*993b0882SAndroid Build Coastguard Worker if (score > max) max = score;
87*993b0882SAndroid Build Coastguard Worker }
88*993b0882SAndroid Build Coastguard Worker float denominator = 0;
89*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < scores_size; ++i) {
90*993b0882SAndroid Build Coastguard Worker const float score = scores[i];
91*993b0882SAndroid Build Coastguard Worker // See comments above in ComputeSoftmaxProbability for the reasoning behind
92*993b0882SAndroid Build Coastguard Worker // this approximation.
93*993b0882SAndroid Build Coastguard Worker const float exp_score = score - max < -16.0f ? 0 : VeryFastExp(score - max);
94*993b0882SAndroid Build Coastguard Worker exp_scores.push_back(exp_score);
95*993b0882SAndroid Build Coastguard Worker denominator += exp_score;
96*993b0882SAndroid Build Coastguard Worker }
97*993b0882SAndroid Build Coastguard Worker
98*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < scores_size; ++i) {
99*993b0882SAndroid Build Coastguard Worker softmax.push_back(exp_scores[i] / denominator);
100*993b0882SAndroid Build Coastguard Worker }
101*993b0882SAndroid Build Coastguard Worker return softmax;
102*993b0882SAndroid Build Coastguard Worker }
103*993b0882SAndroid Build Coastguard Worker
104*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
105