1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/range_sampler.h"
17
18 #include <cmath>
19 #include <unordered_set>
20 #include <vector>
21
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/io/inputbuffer.h"
25 #include "tensorflow/core/lib/strings/numbers.h"
26 #include "tensorflow/core/lib/strings/str_util.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/types.h"
30
31 namespace tensorflow {
32
33 using gtl::ArraySlice;
34 using gtl::MutableArraySlice;
35
// Out-of-line destructor definition; it performs no work of its own.
RangeSampler::~RangeSampler() = default;
37
SampleBatch(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64_t> batch) const38 void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique,
39 gtl::MutableArraySlice<int64_t> batch) const {
40 SampleBatchGetExpectedCount(
41 rnd, unique, batch, gtl::MutableArraySlice<float>(),
42 gtl::ArraySlice<int64_t>(), gtl::MutableArraySlice<float>());
43 }
44
SampleBatchGetExpectedCount(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64_t> batch,gtl::MutableArraySlice<float> batch_expected_count,gtl::ArraySlice<int64_t> extras,gtl::MutableArraySlice<float> extras_expected_count) const45 void RangeSampler::SampleBatchGetExpectedCount(
46 random::SimplePhilox* rnd, bool unique,
47 gtl::MutableArraySlice<int64_t> batch,
48 gtl::MutableArraySlice<float> batch_expected_count,
49 gtl::ArraySlice<int64_t> extras,
50 gtl::MutableArraySlice<float> extras_expected_count) const {
51 SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count,
52 extras, extras_expected_count,
53 gtl::ArraySlice<int64_t>());
54 }
55
56 namespace {
57
// Approximates the expected count of a value in the output of SampleBatch.
//
// `p` is the probability of the value under the sampler, `batch_size` is the
// number of values requested, and `num_tries` is the observed number of draws
// it took to fill the batch.
//
// When num_tries == batch_size (always the case for unique=false) the result
// is simply p * batch_size.
//
// Otherwise we assume (falsely) that filling a batch of batch_size distinct
// values _always_ takes num_tries draws, in which case the probability that
// the value appears in the batch is (1 - (1-p)^num_tries).
static float ExpectedCountHelper(float p, int batch_size, int num_tries) {
  const bool needed_extra_tries = (num_tries != batch_size);
  if (needed_extra_tries) {
    // Numerically stable evaluation of (1 - (1-p)^num_tries):
    // (1-p)^num_tries == exp(num_tries * log1p(-p)).
    const float log_never_drawn = num_tries * std::log1p(-p);
    return -std::expm1(log_never_drawn);
  }
  // This path is always taken if unique=false.
  return p * batch_size;
}
76
77 } // namespace
78
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64_t> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64_t> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64_t> avoided_values) const79 void RangeSampler::SampleBatchGetExpectedCountAvoid(
80 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch,
81 MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> extras,
82 MutableArraySlice<float> extras_expected_count,
83 ArraySlice<int64_t> avoided_values) const {
84 const int batch_size = batch.size();
85 int num_tries;
86
87 if (unique) {
88 CHECK_LE(static_cast<int64_t>(batch_size + avoided_values.size()), range_);
89 std::unordered_set<int64_t> used(batch_size);
90 used.insert(avoided_values.begin(), avoided_values.end());
91 int num_picked = 0;
92 num_tries = 0;
93 while (num_picked < batch_size) {
94 num_tries++;
95 CHECK_LT(num_tries, kint32max);
96 int64_t value = Sample(rnd);
97 if (gtl::InsertIfNotPresent(&used, value)) {
98 batch[num_picked++] = value;
99 }
100 }
101 } else {
102 CHECK_EQ(avoided_values.size(), size_t{0})
103 << "avoided_values only supported with unique=true";
104 for (int i = 0; i < batch_size; i++) {
105 batch[i] = Sample(rnd);
106 }
107 num_tries = batch_size;
108 }
109 // Compute the expected counts of the batch and the extra values
110 if (!batch_expected_count.empty()) {
111 CHECK_EQ(batch_size, batch_expected_count.size());
112 for (int i = 0; i < batch_size; i++) {
113 batch_expected_count[i] =
114 ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries);
115 }
116 }
117 CHECK_EQ(extras.size(), extras_expected_count.size());
118 for (size_t i = 0; i < extras.size(); i++) {
119 extras_expected_count[i] =
120 ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries);
121 }
122 }
123
AllSampler(int64_t range)124 AllSampler::AllSampler(int64_t range) : RangeSampler(range) {}
125
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64_t> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64_t> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64_t> avoided_values) const126 void AllSampler::SampleBatchGetExpectedCountAvoid(
127 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch,
128 MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> extras,
129 MutableArraySlice<float> extras_expected_count,
130 ArraySlice<int64_t> avoided_values) const {
131 const int batch_size = batch.size();
132 CHECK_EQ(range_, batch_size);
133 for (int i = 0; i < batch_size; i++) {
134 batch[i] = i;
135 }
136 if (!batch_expected_count.empty()) {
137 CHECK_EQ(batch_size, batch_expected_count.size());
138 for (int i = 0; i < batch_size; i++) {
139 batch_expected_count[i] = 1;
140 }
141 }
142 CHECK_EQ(size_t{0}, avoided_values.size());
143 CHECK_EQ(extras.size(), extras_expected_count.size());
144 for (size_t i = 0; i < extras.size(); i++) {
145 extras_expected_count[i] = 1;
146 }
147 }
148
UniformSampler(int64_t range)149 UniformSampler::UniformSampler(int64_t range)
150 : RangeSampler(range), inv_range_(1.0 / range) {}
151
Sample(random::SimplePhilox * rnd) const152 int64_t UniformSampler::Sample(random::SimplePhilox* rnd) const {
153 return rnd->Uniform64(range_);
154 }
155
Probability(int64_t value) const156 float UniformSampler::Probability(int64_t value) const { return inv_range_; }
157
LogUniformSampler(int64_t range)158 LogUniformSampler::LogUniformSampler(int64_t range)
159 : RangeSampler(range), log_range_(log1p(range)) {}
160
Sample(random::SimplePhilox * rnd) const161 int64_t LogUniformSampler::Sample(random::SimplePhilox* rnd) const {
162 const int64_t value =
163 static_cast<int64_t>(exp(rnd->RandDouble() * log_range_)) - 1;
164 DCHECK_GE(value, 0);
165 // Mathematically, value should be <= range_, but might not be due to some
166 // floating point roundoff, so we mod by range_. In practice this case
167 // happens never regardless of the value of range_, including and up to
168 // DBL_MAX. But we include it as a guarantee of the function's output.
169 return value % range_;
170 }
171
Probability(int64_t value) const172 float LogUniformSampler::Probability(int64_t value) const {
173 // value is returned iff the call to UniformDouble(log_range_) in the
174 // Sample() function returns a value between log(value + 1)
175 // and log(value + 2). The probability of this is:
176 // (log(value + 2) - log(value + 1)) / log_range
177 // To avoid two calls to log(), we compute this as follows:
178 return (log((value + 2.0) / (value + 1.0))) / log_range_;
179 }
180
ThreadUnsafeUnigramSampler(int64_t range)181 ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64_t range)
182 : RangeSampler(range), picker_(range) {
183 CHECK_LT(range, kint32max);
184 }
185
Sample(random::SimplePhilox * rnd) const186 int64_t ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const {
187 return picker_.Pick(rnd);
188 }
189
Probability(int64_t value) const190 float ThreadUnsafeUnigramSampler::Probability(int64_t value) const {
191 return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight();
192 }
193
Update(ArraySlice<int64_t> values)194 void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64_t> values) {
195 int num_updates = std::min(static_cast<int>(values.size()),
196 kint32max - picker_.total_weight());
197 for (int i = 0; i < num_updates; i++) {
198 const int64_t value = values[i];
199 picker_.set_weight(value, picker_.get_weight(value) + 1);
200 }
201 }
202
203 // Thread-safe unigram sampler
UnigramSampler(int64_t range)204 UnigramSampler::UnigramSampler(int64_t range)
205 : RangeSampler(range), unsafe_sampler_(range) {
206 CHECK_LT(range, kint32max);
207 }
208
Sample(random::SimplePhilox * rnd) const209 int64_t UnigramSampler::Sample(random::SimplePhilox* rnd) const {
210 tf_shared_lock lock(mu_);
211 return unsafe_sampler_.Sample(rnd);
212 }
213
Probability(int64_t value) const214 float UnigramSampler::Probability(int64_t value) const {
215 tf_shared_lock lock(mu_);
216 return unsafe_sampler_.Probability(value);
217 }
218
219 // Overriding at a high level results in far fewer lock acquisitions.
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64_t> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64_t> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64_t> avoided_values) const220 void UnigramSampler::SampleBatchGetExpectedCountAvoid(
221 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch,
222 MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> extras,
223 MutableArraySlice<float> extras_expected_count,
224 ArraySlice<int64_t> avoided_values) const {
225 tf_shared_lock lock(mu_);
226 unsafe_sampler_.SampleBatchGetExpectedCountAvoid(
227 rnd, unique, batch, batch_expected_count, extras, extras_expected_count,
228 avoided_values);
229 }
230
Update(ArraySlice<int64_t> values)231 void UnigramSampler::Update(ArraySlice<int64_t> values) {
232 mutex_lock lock(mu_);
233 unsafe_sampler_.Update(values);
234 }
235
FixedUnigramSampler(Env * env,int64_t range,const string & vocab_file,float distortion,int32_t num_reserved_ids,int32_t num_shards,int32_t shard)236 FixedUnigramSampler::FixedUnigramSampler(Env* env, int64_t range,
237 const string& vocab_file,
238 float distortion,
239 int32_t num_reserved_ids,
240 int32_t num_shards, int32_t shard)
241 : RangeSampler(range),
242 total_weight_(0.0),
243 num_shards_(num_shards),
244 shard_(shard) {
245 FillReservedIds(num_reserved_ids);
246 // TODO(vanhoucke): make this non-crashing.
247 TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion));
248 CHECK_EQ(range, weights_.size());
249 dist_sampler_.reset(new random::DistributionSampler(weights_));
250 }
251
FixedUnigramSampler(int64_t range,const std::vector<float> & unigrams,float distortion,int32_t num_reserved_ids,int32_t num_shards,int32_t shard)252 FixedUnigramSampler::FixedUnigramSampler(int64_t range,
253 const std::vector<float>& unigrams,
254 float distortion,
255 int32_t num_reserved_ids,
256 int32_t num_shards, int32_t shard)
257 : RangeSampler(range),
258 total_weight_(0.0),
259 num_shards_(num_shards),
260 shard_(shard) {
261 FillReservedIds(num_reserved_ids);
262 LoadFromUnigrams(unigrams, distortion);
263 // TODO(vanhoucke): make this non-crashing.
264 CHECK_EQ(range, weights_.size());
265 dist_sampler_.reset(new random::DistributionSampler(weights_));
266 }
267
Probability(int64_t value) const268 float FixedUnigramSampler::Probability(int64_t value) const {
269 if (value < 0 || static_cast<size_t>(value) >= weights_.size()) {
270 return 0.0;
271 }
272 return weights_.at(value) / total_weight_;
273 }
274
Sample(random::SimplePhilox * rnd) const275 int64_t FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const {
276 return dist_sampler_->Sample(rnd);
277 }
278
FillReservedIds(int32_t num_reserved_ids)279 void FixedUnigramSampler::FillReservedIds(int32_t num_reserved_ids) {
280 for (int32_t word_id = 0; word_id < num_reserved_ids; ++word_id) {
281 if (word_id % num_shards_ == shard_) weights_.push_back(0.0);
282 }
283 }
284
LoadFromFile(Env * env,const string & vocab_file,float distortion)285 Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file,
286 float distortion) {
287 std::unique_ptr<RandomAccessFile> file;
288 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
289
290 io::InputBuffer in(file.get(), 262144 /*bytes*/);
291 string line;
292 int32_t word_id = weights_.size();
293 while (in.ReadLine(&line).ok()) {
294 // The vocabulary file should be in csv like format, with the last
295 // field the weight associated with the word.
296 std::vector<string> cols = str_util::Split(line, ',');
297 if (cols.empty()) continue;
298 // Skip entries that do not belong to this shard.
299 if (word_id % num_shards_ == shard_) {
300 float w = 0.0;
301 if (!strings::safe_strtof(cols.at(cols.size() - 1), &w)) {
302 return errors::InvalidArgument("Wrong vocabulary format at line: ",
303 line);
304 }
305 w = std::pow(w, distortion);
306 total_weight_ += w;
307 weights_.push_back(w);
308 }
309 ++word_id;
310 }
311 return OkStatus();
312 }
313
LoadFromUnigrams(const std::vector<float> & unigrams,float distortion)314 void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams,
315 float distortion) {
316 int32_t word_id = weights_.size();
317 for (float w : unigrams) {
318 // Skip entries that do not belong to this shard.
319 if (word_id % num_shards_ == shard_) {
320 w = std::pow(w, distortion);
321 total_weight_ += w;
322 weights_.push_back(w);
323 }
324 ++word_id;
325 }
326 }
327
328 } // namespace tensorflow
329