xref: /aosp_15_r20/external/executorch/extension/llm/sampler/sampler.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker // This is a modified version of https://github.com/karpathy/llama2.c.git
10*523fa7a6SAndroid Build Coastguard Worker // @lint-ignore-every LICENSELINT
11*523fa7a6SAndroid Build Coastguard Worker /**
12*523fa7a6SAndroid Build Coastguard Worker  * MIT License
13*523fa7a6SAndroid Build Coastguard Worker  *
14*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) 2023 Andrej
15*523fa7a6SAndroid Build Coastguard Worker  *
16*523fa7a6SAndroid Build Coastguard Worker  * Permission is hereby granted, free of charge, to any person obtaining a copy
17*523fa7a6SAndroid Build Coastguard Worker  * of this software and associated documentation files (the "Software"), to deal
18*523fa7a6SAndroid Build Coastguard Worker  * in the Software without restriction, including without limitation the rights
19*523fa7a6SAndroid Build Coastguard Worker  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20*523fa7a6SAndroid Build Coastguard Worker  * copies of the Software, and to permit persons to whom the Software is
21*523fa7a6SAndroid Build Coastguard Worker  * furnished to do so, subject to the following conditions:
22*523fa7a6SAndroid Build Coastguard Worker  *
23*523fa7a6SAndroid Build Coastguard Worker  * The above copyright notice and this permission notice shall be included in
24*523fa7a6SAndroid Build Coastguard Worker  * all copies or substantial portions of the Software.
25*523fa7a6SAndroid Build Coastguard Worker  *
26*523fa7a6SAndroid Build Coastguard Worker  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27*523fa7a6SAndroid Build Coastguard Worker  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28*523fa7a6SAndroid Build Coastguard Worker  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29*523fa7a6SAndroid Build Coastguard Worker  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30*523fa7a6SAndroid Build Coastguard Worker  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31*523fa7a6SAndroid Build Coastguard Worker  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32*523fa7a6SAndroid Build Coastguard Worker  * SOFTWARE.
33*523fa7a6SAndroid Build Coastguard Worker  */
34*523fa7a6SAndroid Build Coastguard Worker 
35*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/sampler/sampler.h>
36*523fa7a6SAndroid Build Coastguard Worker #include <algorithm>
37*523fa7a6SAndroid Build Coastguard Worker 
38*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
39*523fa7a6SAndroid Build Coastguard Worker namespace extension {
40*523fa7a6SAndroid Build Coastguard Worker namespace llm {
41*523fa7a6SAndroid Build Coastguard Worker 
42*523fa7a6SAndroid Build Coastguard Worker // sampler stuff
43*523fa7a6SAndroid Build Coastguard Worker template <typename T>
sample_argmax(T * probabilities)44*523fa7a6SAndroid Build Coastguard Worker int32_t Sampler::sample_argmax(T* probabilities) {
45*523fa7a6SAndroid Build Coastguard Worker   // return the index that has the highest probability
46*523fa7a6SAndroid Build Coastguard Worker   int max_i = 0;
47*523fa7a6SAndroid Build Coastguard Worker   T max_p = probabilities[0];
48*523fa7a6SAndroid Build Coastguard Worker   for (int i = 1; i < vocab_size_; i++) {
49*523fa7a6SAndroid Build Coastguard Worker     if (probabilities[i] > max_p) {
50*523fa7a6SAndroid Build Coastguard Worker       max_i = i;
51*523fa7a6SAndroid Build Coastguard Worker       max_p = probabilities[i];
52*523fa7a6SAndroid Build Coastguard Worker     }
53*523fa7a6SAndroid Build Coastguard Worker   }
54*523fa7a6SAndroid Build Coastguard Worker   return max_i;
55*523fa7a6SAndroid Build Coastguard Worker }
56*523fa7a6SAndroid Build Coastguard Worker 
57*523fa7a6SAndroid Build Coastguard Worker template <typename T>
sample_mult(T * probabilities,float coin)58*523fa7a6SAndroid Build Coastguard Worker int32_t Sampler::sample_mult(T* probabilities, float coin) {
59*523fa7a6SAndroid Build Coastguard Worker   // sample index from probabilities (they must sum to 1!)
60*523fa7a6SAndroid Build Coastguard Worker   // coin is a random number in [0, 1), usually from random_f32()
61*523fa7a6SAndroid Build Coastguard Worker   T cdf = 0.0;
62*523fa7a6SAndroid Build Coastguard Worker   for (int i = 0; i < vocab_size_; i++) {
63*523fa7a6SAndroid Build Coastguard Worker     cdf += probabilities[i];
64*523fa7a6SAndroid Build Coastguard Worker     if (coin < cdf) {
65*523fa7a6SAndroid Build Coastguard Worker       return i;
66*523fa7a6SAndroid Build Coastguard Worker     }
67*523fa7a6SAndroid Build Coastguard Worker   }
68*523fa7a6SAndroid Build Coastguard Worker   return vocab_size_ - 1; // in case of rounding errors
69*523fa7a6SAndroid Build Coastguard Worker }
70*523fa7a6SAndroid Build Coastguard Worker 
71*523fa7a6SAndroid Build Coastguard Worker template <typename T>
sample_topp(T * probabilities,float coin)72*523fa7a6SAndroid Build Coastguard Worker int32_t Sampler::sample_topp(T* probabilities, float coin) {
73*523fa7a6SAndroid Build Coastguard Worker   // top-p sampling (or "nucleus sampling") samples from the smallest set of
74*523fa7a6SAndroid Build Coastguard Worker   // tokens that exceed probability topp. This way we never sample tokens that
75*523fa7a6SAndroid Build Coastguard Worker   // have very low probabilities and are less likely to go "off the rails".
76*523fa7a6SAndroid Build Coastguard Worker   // coin is a random number in [0, 1), usually from random_f32()
77*523fa7a6SAndroid Build Coastguard Worker   int n = vocab_size_;
78*523fa7a6SAndroid Build Coastguard Worker   int n0 = 0;
79*523fa7a6SAndroid Build Coastguard Worker   // quicksort indices in descending order of probabilities
80*523fa7a6SAndroid Build Coastguard Worker   // values smaller than (1 - topp) / (n - 1) cannot be part of the result
81*523fa7a6SAndroid Build Coastguard Worker   // so for efficiency we crop these out as candidates before sorting
82*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<ProbIndex<T>[]> probindex =
83*523fa7a6SAndroid Build Coastguard Worker       std::make_unique<ProbIndex<T>[]>(vocab_size_);
84*523fa7a6SAndroid Build Coastguard Worker 
85*523fa7a6SAndroid Build Coastguard Worker   const float cutoff = (1.0f - topp_) / (n - 1);
86*523fa7a6SAndroid Build Coastguard Worker   for (int i = 0; i < n; i++) {
87*523fa7a6SAndroid Build Coastguard Worker     if (probabilities[i] >= cutoff) {
88*523fa7a6SAndroid Build Coastguard Worker       probindex[n0].index = i;
89*523fa7a6SAndroid Build Coastguard Worker       probindex[n0].prob = probabilities[i];
90*523fa7a6SAndroid Build Coastguard Worker       n0++;
91*523fa7a6SAndroid Build Coastguard Worker     }
92*523fa7a6SAndroid Build Coastguard Worker   }
93*523fa7a6SAndroid Build Coastguard Worker 
94*523fa7a6SAndroid Build Coastguard Worker   auto compare = [](const ProbIndex<T>& a, const ProbIndex<T>& b) {
95*523fa7a6SAndroid Build Coastguard Worker     return a.prob > b.prob;
96*523fa7a6SAndroid Build Coastguard Worker   };
97*523fa7a6SAndroid Build Coastguard Worker   std::sort(probindex.get(), probindex.get() + n0, compare);
98*523fa7a6SAndroid Build Coastguard Worker 
99*523fa7a6SAndroid Build Coastguard Worker   // truncate the list where cumulative probability exceeds topp
100*523fa7a6SAndroid Build Coastguard Worker   T cumulative_prob = 0;
101*523fa7a6SAndroid Build Coastguard Worker   int last_idx = n0 - 1; // in case of rounding errors consider all elements
102*523fa7a6SAndroid Build Coastguard Worker   for (int i = 0; i < n0; i++) {
103*523fa7a6SAndroid Build Coastguard Worker     cumulative_prob += probindex[i].prob;
104*523fa7a6SAndroid Build Coastguard Worker     if (cumulative_prob > topp_) {
105*523fa7a6SAndroid Build Coastguard Worker       last_idx = i;
106*523fa7a6SAndroid Build Coastguard Worker       break; // we've exceeded topp by including last_idx
107*523fa7a6SAndroid Build Coastguard Worker     }
108*523fa7a6SAndroid Build Coastguard Worker   }
109*523fa7a6SAndroid Build Coastguard Worker 
110*523fa7a6SAndroid Build Coastguard Worker   // sample from the truncated list
111*523fa7a6SAndroid Build Coastguard Worker   const T& r = coin * cumulative_prob;
112*523fa7a6SAndroid Build Coastguard Worker   T cdf = 0;
113*523fa7a6SAndroid Build Coastguard Worker   for (int i = 0; i <= last_idx; i++) {
114*523fa7a6SAndroid Build Coastguard Worker     cdf += probindex[i].prob;
115*523fa7a6SAndroid Build Coastguard Worker     if (r < cdf) {
116*523fa7a6SAndroid Build Coastguard Worker       return probindex[i].index;
117*523fa7a6SAndroid Build Coastguard Worker     }
118*523fa7a6SAndroid Build Coastguard Worker   }
119*523fa7a6SAndroid Build Coastguard Worker   return probindex[last_idx].index; // in case of rounding errors
120*523fa7a6SAndroid Build Coastguard Worker }
121*523fa7a6SAndroid Build Coastguard Worker 
Sampler(int vocab_size,float temperature,float topp,unsigned long long rng_seed)122*523fa7a6SAndroid Build Coastguard Worker Sampler::Sampler(
123*523fa7a6SAndroid Build Coastguard Worker     int vocab_size,
124*523fa7a6SAndroid Build Coastguard Worker     float temperature,
125*523fa7a6SAndroid Build Coastguard Worker     float topp,
126*523fa7a6SAndroid Build Coastguard Worker     unsigned long long rng_seed)
127*523fa7a6SAndroid Build Coastguard Worker     : vocab_size_(vocab_size),
128*523fa7a6SAndroid Build Coastguard Worker       inv_temperature_(static_cast<bool>(temperature) ? 1.0f / temperature : 0),
129*523fa7a6SAndroid Build Coastguard Worker       topp_(topp),
130*523fa7a6SAndroid Build Coastguard Worker       rng_state_(rng_seed) {}
131*523fa7a6SAndroid Build Coastguard Worker 
// In-place softmax over the first `size` elements of `x`.
//
// The maximum is subtracted before exponentiation for numerical stability.
// All intermediate arithmetic is done in float so that 16-bit element types
// do not lose precision while the normaliser is accumulated. A non-positive
// `size` leaves `x` untouched.
template <typename T>
static void softmax(T* x, int size) {
  if (size <= 0) {
    return; // nothing to normalise; also avoids reading x[0]
  }
  // find max value (for numerical stability)
  float max_val = static_cast<float>(x[0]);
  for (int i = 1; i < size; i++) {
    const float value = static_cast<float>(x[i]);
    if (value > max_val) {
      max_val = value;
    }
  }
  // exp and sum; the sum uses the stored (rounded) values so that the
  // normalised outputs add up to one as closely as T allows
  float sum = 0.0f;
  for (int i = 0; i < size; i++) {
    x[i] = static_cast<T>(expf(static_cast<float>(x[i]) - max_val));
    sum += static_cast<float>(x[i]);
  }
  // normalize
  for (int i = 0; i < size; i++) {
    x[i] = static_cast<T>(static_cast<float>(x[i]) / sum);
  }
}
152*523fa7a6SAndroid Build Coastguard Worker 
// Advances the 64-bit generator state in `*state` and returns 32 random bits.
//
// xorshift* rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
// The all-zero state is a fixed point of the xorshift step (every output
// would be 0), so a zero state is replaced by a fixed nonzero constant
// before stepping. Nonzero states produce the same sequence as before.
static unsigned int random_u32(unsigned long long* state) {
  unsigned long long s = *state;
  if (s == 0) {
    s = 0x9E3779B97F4A7C15ull;
  }
  s ^= s >> 12;
  s ^= s << 25;
  s ^= s >> 27;
  *state = s;
  // The high 32 bits of the scrambled product are the output.
  return static_cast<unsigned int>((s * 0x2545F4914F6CDD1Dull) >> 32);
}
160*523fa7a6SAndroid Build Coastguard Worker 
random_f32(unsigned long long * state)161*523fa7a6SAndroid Build Coastguard Worker static float random_f32(unsigned long long* state) { // random float32 in [0,1)
162*523fa7a6SAndroid Build Coastguard Worker   return (random_u32(state) >> 8) / 16777216.0f;
163*523fa7a6SAndroid Build Coastguard Worker }
164*523fa7a6SAndroid Build Coastguard Worker 
165*523fa7a6SAndroid Build Coastguard Worker template <typename T>
sample(T * logits)166*523fa7a6SAndroid Build Coastguard Worker int32_t Sampler::sample(T* logits) {
167*523fa7a6SAndroid Build Coastguard Worker   // sample the token given the logits and some hyperparameters
168*523fa7a6SAndroid Build Coastguard Worker   int next;
169*523fa7a6SAndroid Build Coastguard Worker   if (inv_temperature_ == 0.0f) {
170*523fa7a6SAndroid Build Coastguard Worker     // greedy argmax sampling: take the token with the highest probability
171*523fa7a6SAndroid Build Coastguard Worker     next = sample_argmax(logits);
172*523fa7a6SAndroid Build Coastguard Worker   } else {
173*523fa7a6SAndroid Build Coastguard Worker     // apply the temperature to the logits
174*523fa7a6SAndroid Build Coastguard Worker     for (int q = 0; q < vocab_size_; q++) {
175*523fa7a6SAndroid Build Coastguard Worker       logits[q] *= inv_temperature_;
176*523fa7a6SAndroid Build Coastguard Worker     }
177*523fa7a6SAndroid Build Coastguard Worker     // apply softmax to the logits to get the probabilities for next token
178*523fa7a6SAndroid Build Coastguard Worker     softmax(logits, vocab_size_);
179*523fa7a6SAndroid Build Coastguard Worker     // flip a (float) coin (this is our source of entropy for sampling)
180*523fa7a6SAndroid Build Coastguard Worker     float coin = random_f32(&rng_state_);
181*523fa7a6SAndroid Build Coastguard Worker     // we sample from this distribution to get the next token
182*523fa7a6SAndroid Build Coastguard Worker     if (topp_ <= 0 || topp_ >= 1) {
183*523fa7a6SAndroid Build Coastguard Worker       // simply sample from the predicted probability distribution
184*523fa7a6SAndroid Build Coastguard Worker       next = sample_mult(logits, coin);
185*523fa7a6SAndroid Build Coastguard Worker     } else {
186*523fa7a6SAndroid Build Coastguard Worker       // top-p (nucleus) sampling, clamping the least likely tokens to zero
187*523fa7a6SAndroid Build Coastguard Worker       next = sample_topp(logits, coin);
188*523fa7a6SAndroid Build Coastguard Worker     }
189*523fa7a6SAndroid Build Coastguard Worker   }
190*523fa7a6SAndroid Build Coastguard Worker   return next;
191*523fa7a6SAndroid Build Coastguard Worker }
192*523fa7a6SAndroid Build Coastguard Worker 
193*523fa7a6SAndroid Build Coastguard Worker template int32_t Sampler::sample<float>(float* logits);
194*523fa7a6SAndroid Build Coastguard Worker template int32_t Sampler::sample<exec_aten::Half>(exec_aten::Half* logits);
195*523fa7a6SAndroid Build Coastguard Worker template int32_t Sampler::sample<exec_aten::BFloat16>(
196*523fa7a6SAndroid Build Coastguard Worker     exec_aten::BFloat16* logits);
197*523fa7a6SAndroid Build Coastguard Worker 
198*523fa7a6SAndroid Build Coastguard Worker } // namespace llm
199*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
200*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
201