xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/range_sampler.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/range_sampler.h"
17 
18 #include <cmath>
19 #include <unordered_set>
20 #include <vector>
21 
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/io/inputbuffer.h"
25 #include "tensorflow/core/lib/strings/numbers.h"
26 #include "tensorflow/core/lib/strings/str_util.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/types.h"
30 
31 namespace tensorflow {
32 
33 using gtl::ArraySlice;
34 using gtl::MutableArraySlice;
35 
~RangeSampler()36 RangeSampler::~RangeSampler() {}
37 
SampleBatch(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64_t> batch) const38 void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique,
39                                gtl::MutableArraySlice<int64_t> batch) const {
40   SampleBatchGetExpectedCount(
41       rnd, unique, batch, gtl::MutableArraySlice<float>(),
42       gtl::ArraySlice<int64_t>(), gtl::MutableArraySlice<float>());
43 }
44 
SampleBatchGetExpectedCount(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64_t> batch,gtl::MutableArraySlice<float> batch_expected_count,gtl::ArraySlice<int64_t> extras,gtl::MutableArraySlice<float> extras_expected_count) const45 void RangeSampler::SampleBatchGetExpectedCount(
46     random::SimplePhilox* rnd, bool unique,
47     gtl::MutableArraySlice<int64_t> batch,
48     gtl::MutableArraySlice<float> batch_expected_count,
49     gtl::ArraySlice<int64_t> extras,
50     gtl::MutableArraySlice<float> extras_expected_count) const {
51   SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count,
52                                    extras, extras_expected_count,
53                                    gtl::ArraySlice<int64_t>());
54 }
55 
namespace {

// Approximate expected count of a value with sampling probability `p` in the
// output of SampleBatch.
//
// When num_tries == batch_size (always the case for unique=false) the result
// is simply p * batch_size.
//
// Otherwise num_tries is the observed number of draws it took to collect
// batch_size unique values. Pretending (falsely) that it _always_ takes
// exactly num_tries draws, the probability that the value appears in the
// batch is 1 - (1-p)^num_tries, which is what is returned.
float ExpectedCountHelper(float p, int batch_size, int num_tries) {
  const bool every_try_kept = (num_tries == batch_size);
  if (!every_try_kept) {
    // Numerically stable evaluation of 1 - (1-p)^num_tries.
    const float log_miss_once = std::log1p(-p);
    return -std::expm1(num_tries * log_miss_once);
  }
  return p * batch_size;
}

}  // namespace
78 
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64_t> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64_t> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64_t> avoided_values) const79 void RangeSampler::SampleBatchGetExpectedCountAvoid(
80     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch,
81     MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> extras,
82     MutableArraySlice<float> extras_expected_count,
83     ArraySlice<int64_t> avoided_values) const {
84   const int batch_size = batch.size();
85   int num_tries;
86 
87   if (unique) {
88     CHECK_LE(static_cast<int64_t>(batch_size + avoided_values.size()), range_);
89     std::unordered_set<int64_t> used(batch_size);
90     used.insert(avoided_values.begin(), avoided_values.end());
91     int num_picked = 0;
92     num_tries = 0;
93     while (num_picked < batch_size) {
94       num_tries++;
95       CHECK_LT(num_tries, kint32max);
96       int64_t value = Sample(rnd);
97       if (gtl::InsertIfNotPresent(&used, value)) {
98         batch[num_picked++] = value;
99       }
100     }
101   } else {
102     CHECK_EQ(avoided_values.size(), size_t{0})
103         << "avoided_values only supported with unique=true";
104     for (int i = 0; i < batch_size; i++) {
105       batch[i] = Sample(rnd);
106     }
107     num_tries = batch_size;
108   }
109   // Compute the expected counts of the batch and the extra values
110   if (!batch_expected_count.empty()) {
111     CHECK_EQ(batch_size, batch_expected_count.size());
112     for (int i = 0; i < batch_size; i++) {
113       batch_expected_count[i] =
114           ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries);
115     }
116   }
117   CHECK_EQ(extras.size(), extras_expected_count.size());
118   for (size_t i = 0; i < extras.size(); i++) {
119     extras_expected_count[i] =
120         ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries);
121   }
122 }
123 
AllSampler(int64_t range)124 AllSampler::AllSampler(int64_t range) : RangeSampler(range) {}
125 
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64_t> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64_t> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64_t> avoided_values) const126 void AllSampler::SampleBatchGetExpectedCountAvoid(
127     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch,
128     MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> extras,
129     MutableArraySlice<float> extras_expected_count,
130     ArraySlice<int64_t> avoided_values) const {
131   const int batch_size = batch.size();
132   CHECK_EQ(range_, batch_size);
133   for (int i = 0; i < batch_size; i++) {
134     batch[i] = i;
135   }
136   if (!batch_expected_count.empty()) {
137     CHECK_EQ(batch_size, batch_expected_count.size());
138     for (int i = 0; i < batch_size; i++) {
139       batch_expected_count[i] = 1;
140     }
141   }
142   CHECK_EQ(size_t{0}, avoided_values.size());
143   CHECK_EQ(extras.size(), extras_expected_count.size());
144   for (size_t i = 0; i < extras.size(); i++) {
145     extras_expected_count[i] = 1;
146   }
147 }
148 
UniformSampler(int64_t range)149 UniformSampler::UniformSampler(int64_t range)
150     : RangeSampler(range), inv_range_(1.0 / range) {}
151 
Sample(random::SimplePhilox * rnd) const152 int64_t UniformSampler::Sample(random::SimplePhilox* rnd) const {
153   return rnd->Uniform64(range_);
154 }
155 
Probability(int64_t value) const156 float UniformSampler::Probability(int64_t value) const { return inv_range_; }
157 
LogUniformSampler(int64_t range)158 LogUniformSampler::LogUniformSampler(int64_t range)
159     : RangeSampler(range), log_range_(log1p(range)) {}
160 
Sample(random::SimplePhilox * rnd) const161 int64_t LogUniformSampler::Sample(random::SimplePhilox* rnd) const {
162   const int64_t value =
163       static_cast<int64_t>(exp(rnd->RandDouble() * log_range_)) - 1;
164   DCHECK_GE(value, 0);
165   // Mathematically, value should be <= range_, but might not be due to some
166   // floating point roundoff, so we mod by range_.  In practice this case
167   // happens never regardless of the value of range_, including and up to
168   // DBL_MAX.  But we include it as a guarantee of the function's output.
169   return value % range_;
170 }
171 
Probability(int64_t value) const172 float LogUniformSampler::Probability(int64_t value) const {
173   // value is returned iff the call to UniformDouble(log_range_) in the
174   // Sample() function returns a value between log(value + 1)
175   // and log(value + 2).   The probability of this is:
176   // (log(value + 2) - log(value + 1)) / log_range
177   // To avoid two calls to log(), we compute this as follows:
178   return (log((value + 2.0) / (value + 1.0))) / log_range_;
179 }
180 
ThreadUnsafeUnigramSampler(int64_t range)181 ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64_t range)
182     : RangeSampler(range), picker_(range) {
183   CHECK_LT(range, kint32max);
184 }
185 
Sample(random::SimplePhilox * rnd) const186 int64_t ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const {
187   return picker_.Pick(rnd);
188 }
189 
Probability(int64_t value) const190 float ThreadUnsafeUnigramSampler::Probability(int64_t value) const {
191   return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight();
192 }
193 
Update(ArraySlice<int64_t> values)194 void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64_t> values) {
195   int num_updates = std::min(static_cast<int>(values.size()),
196                              kint32max - picker_.total_weight());
197   for (int i = 0; i < num_updates; i++) {
198     const int64_t value = values[i];
199     picker_.set_weight(value, picker_.get_weight(value) + 1);
200   }
201 }
202 
203 // Thread-safe unigram sampler
UnigramSampler(int64_t range)204 UnigramSampler::UnigramSampler(int64_t range)
205     : RangeSampler(range), unsafe_sampler_(range) {
206   CHECK_LT(range, kint32max);
207 }
208 
Sample(random::SimplePhilox * rnd) const209 int64_t UnigramSampler::Sample(random::SimplePhilox* rnd) const {
210   tf_shared_lock lock(mu_);
211   return unsafe_sampler_.Sample(rnd);
212 }
213 
Probability(int64_t value) const214 float UnigramSampler::Probability(int64_t value) const {
215   tf_shared_lock lock(mu_);
216   return unsafe_sampler_.Probability(value);
217 }
218 
219 // Overriding at a high level results in far fewer lock acquisitions.
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64_t> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64_t> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64_t> avoided_values) const220 void UnigramSampler::SampleBatchGetExpectedCountAvoid(
221     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch,
222     MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> extras,
223     MutableArraySlice<float> extras_expected_count,
224     ArraySlice<int64_t> avoided_values) const {
225   tf_shared_lock lock(mu_);
226   unsafe_sampler_.SampleBatchGetExpectedCountAvoid(
227       rnd, unique, batch, batch_expected_count, extras, extras_expected_count,
228       avoided_values);
229 }
230 
Update(ArraySlice<int64_t> values)231 void UnigramSampler::Update(ArraySlice<int64_t> values) {
232   mutex_lock lock(mu_);
233   unsafe_sampler_.Update(values);
234 }
235 
FixedUnigramSampler(Env * env,int64_t range,const string & vocab_file,float distortion,int32_t num_reserved_ids,int32_t num_shards,int32_t shard)236 FixedUnigramSampler::FixedUnigramSampler(Env* env, int64_t range,
237                                          const string& vocab_file,
238                                          float distortion,
239                                          int32_t num_reserved_ids,
240                                          int32_t num_shards, int32_t shard)
241     : RangeSampler(range),
242       total_weight_(0.0),
243       num_shards_(num_shards),
244       shard_(shard) {
245   FillReservedIds(num_reserved_ids);
246   // TODO(vanhoucke): make this non-crashing.
247   TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion));
248   CHECK_EQ(range, weights_.size());
249   dist_sampler_.reset(new random::DistributionSampler(weights_));
250 }
251 
FixedUnigramSampler(int64_t range,const std::vector<float> & unigrams,float distortion,int32_t num_reserved_ids,int32_t num_shards,int32_t shard)252 FixedUnigramSampler::FixedUnigramSampler(int64_t range,
253                                          const std::vector<float>& unigrams,
254                                          float distortion,
255                                          int32_t num_reserved_ids,
256                                          int32_t num_shards, int32_t shard)
257     : RangeSampler(range),
258       total_weight_(0.0),
259       num_shards_(num_shards),
260       shard_(shard) {
261   FillReservedIds(num_reserved_ids);
262   LoadFromUnigrams(unigrams, distortion);
263   // TODO(vanhoucke): make this non-crashing.
264   CHECK_EQ(range, weights_.size());
265   dist_sampler_.reset(new random::DistributionSampler(weights_));
266 }
267 
Probability(int64_t value) const268 float FixedUnigramSampler::Probability(int64_t value) const {
269   if (value < 0 || static_cast<size_t>(value) >= weights_.size()) {
270     return 0.0;
271   }
272   return weights_.at(value) / total_weight_;
273 }
274 
Sample(random::SimplePhilox * rnd) const275 int64_t FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const {
276   return dist_sampler_->Sample(rnd);
277 }
278 
FillReservedIds(int32_t num_reserved_ids)279 void FixedUnigramSampler::FillReservedIds(int32_t num_reserved_ids) {
280   for (int32_t word_id = 0; word_id < num_reserved_ids; ++word_id) {
281     if (word_id % num_shards_ == shard_) weights_.push_back(0.0);
282   }
283 }
284 
LoadFromFile(Env * env,const string & vocab_file,float distortion)285 Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file,
286                                          float distortion) {
287   std::unique_ptr<RandomAccessFile> file;
288   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
289 
290   io::InputBuffer in(file.get(), 262144 /*bytes*/);
291   string line;
292   int32_t word_id = weights_.size();
293   while (in.ReadLine(&line).ok()) {
294     // The vocabulary file should be in csv like format, with the last
295     // field the weight associated with the word.
296     std::vector<string> cols = str_util::Split(line, ',');
297     if (cols.empty()) continue;
298     // Skip entries that do not belong to this shard.
299     if (word_id % num_shards_ == shard_) {
300       float w = 0.0;
301       if (!strings::safe_strtof(cols.at(cols.size() - 1), &w)) {
302         return errors::InvalidArgument("Wrong vocabulary format at line: ",
303                                        line);
304       }
305       w = std::pow(w, distortion);
306       total_weight_ += w;
307       weights_.push_back(w);
308     }
309     ++word_id;
310   }
311   return OkStatus();
312 }
313 
LoadFromUnigrams(const std::vector<float> & unigrams,float distortion)314 void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams,
315                                            float distortion) {
316   int32_t word_id = weights_.size();
317   for (float w : unigrams) {
318     // Skip entries that do not belong to this shard.
319     if (word_id % num_shards_ == shard_) {
320       w = std::pow(w, distortion);
321       total_weight_ += w;
322       weights_.push_back(w);
323     }
324     ++word_id;
325   }
326 }
327 
328 }  // namespace tensorflow
329