1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/experimental/sampling_dataset_op.h"
16
17 #include "tensorflow/core/data/dataset_utils.h"
18 #include "tensorflow/core/data/name_utils.h"
19 #include "tensorflow/core/framework/dataset.h"
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/random/philox_random.h"
23 #include "tensorflow/core/lib/random/random.h"
24 #include "tensorflow/core/lib/random/random_distributions.h"
25 #include "tensorflow/core/lib/random/simple_philox.h"
26
27 namespace tensorflow {
28 namespace data {
29 namespace experimental {
30
31 // Constants declared in sampling_dataset_op.h and used both here and in test
32 // cases.
33 /* static */ constexpr const char* const SamplingDatasetOp::kDatasetType;
34 /* static */ constexpr const char* const SamplingDatasetOp::kInputDataset;
35 /* static */ constexpr const char* const SamplingDatasetOp::kRate;
36 /* static */ constexpr const char* const SamplingDatasetOp::kSeed;
37 /* static */ constexpr const char* const SamplingDatasetOp::kSeed2;
38 /* static */ constexpr const char* const SamplingDatasetOp::kOutputTypes;
39 /* static */ constexpr const char* const SamplingDatasetOp::kOutputShapes;
40
41 class SamplingDatasetOp::Dataset : public DatasetBase {
42 public:
Dataset(OpKernelContext * ctx,float rate,int64_t seed,int64_t seed2,const DatasetBase * input)43 Dataset(OpKernelContext* ctx, float rate, int64_t seed, int64_t seed2,
44 const DatasetBase* input)
45 : DatasetBase(DatasetContext(ctx)),
46 rate_(rate),
47 seeds_(seed, seed2),
48 input_(input) {
49 input_->Ref();
50 }
51
~Dataset()52 ~Dataset() override { input_->Unref(); }
53
MakeIteratorInternal(const string & prefix) const54 std::unique_ptr<IteratorBase> MakeIteratorInternal(
55 const string& prefix) const override {
56 return std::unique_ptr<IteratorBase>(
57 new Iterator({this, name_utils::IteratorPrefix(kDatasetType, prefix)},
58 seeds_.first, seeds_.second));
59 }
60
output_dtypes() const61 const DataTypeVector& output_dtypes() const override {
62 return input_->output_dtypes();
63 }
64
output_shapes() const65 const std::vector<PartialTensorShape>& output_shapes() const override {
66 return input_->output_shapes();
67 }
68
DebugString() const69 string DebugString() const override {
70 return name_utils::DatasetDebugString(kDatasetType);
71 }
72
InputDatasets(std::vector<const DatasetBase * > * inputs) const73 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
74 inputs->push_back(input_);
75 return OkStatus();
76 }
77
CheckExternalState() const78 Status CheckExternalState() const override {
79 return input_->CheckExternalState();
80 }
81
82 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const83 Status AsGraphDefInternal(SerializationContext* ctx,
84 DatasetGraphDefBuilder* b,
85 Node** output) const override {
86 Node* input_graph_node = nullptr;
87 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
88 Node* rate = nullptr;
89 Node* seed = nullptr;
90 Node* seed2 = nullptr;
91 TF_RETURN_IF_ERROR(b->AddScalar(rate_, &rate));
92 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.first, &seed));
93 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.second, &seed2));
94 TF_RETURN_IF_ERROR(
95 b->AddDataset(this, {input_graph_node, rate, seed, seed2}, output));
96 return OkStatus();
97 }
98
99 private:
100 class Iterator : public DatasetIterator<Dataset> {
101 public:
Iterator(const Params & params,int64_t seed,int64_t seed2)102 explicit Iterator(const Params& params, int64_t seed, int64_t seed2)
103 : DatasetIterator<Dataset>(params),
104 seeds_(MaybeOverrideSeeds({seed, seed2})),
105 parent_generator_(seeds_.first, seeds_.second),
106 generator_(&parent_generator_) {}
107
Initialize(IteratorContext * ctx)108 Status Initialize(IteratorContext* ctx) override {
109 return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
110 }
111
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)112 Status GetNextInternal(IteratorContext* ctx,
113 std::vector<Tensor>* out_tensors,
114 bool* end_of_sequence) override {
115 bool rand_val_hit;
116 do {
117 {
118 tf_shared_lock l(mu_);
119 if (!input_impl_) {
120 *end_of_sequence = true;
121 return OkStatus();
122 }
123 TF_RETURN_IF_ERROR(
124 input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
125 }
126 if (*end_of_sequence) {
127 mutex_lock l(mu_);
128 input_impl_.reset();
129 return OkStatus();
130 }
131
132 // generate a number from random uniform [0, 1)
133 float rand_val = Random();
134 rand_val_hit = rand_val < dataset()->rate_;
135 if (!rand_val_hit) {
136 // Clear the output tensor list since it doesn't match.
137 out_tensors->clear();
138 }
139 } while (!rand_val_hit);
140 *end_of_sequence = false;
141 return OkStatus();
142 }
143
144 protected:
ResetRngs()145 void ResetRngs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
146 // Reset the generators based on the current iterator seeds.
147 parent_generator_ = random::PhiloxRandom(seeds_.first, seeds_.second);
148 generator_ =
149 random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
150 generator_.Skip(num_random_samples_);
151 }
152
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)153 Status SaveInternal(SerializationContext* ctx,
154 IteratorStateWriter* writer) override {
155 mutex_lock l(mu_);
156 // Save state needed to restore the random number generators.
157 TF_RETURN_IF_ERROR(writer->WriteScalar(
158 this->full_name("num_random_samples"), num_random_samples_));
159 TF_RETURN_IF_ERROR(
160 writer->WriteScalar(this->full_name("seed"), seeds_.first));
161 TF_RETURN_IF_ERROR(
162 writer->WriteScalar(this->full_name("seed2"), seeds_.second));
163
164 if (input_impl_) {
165 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
166 } else {
167 TF_RETURN_IF_ERROR(
168 writer->WriteScalar(full_name("input_impl_empty"), ""));
169 }
170 return OkStatus();
171 }
172
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)173 Status RestoreInternal(IteratorContext* ctx,
174 IteratorStateReader* reader) override {
175 mutex_lock l(mu_);
176 // Restore the random number generators.
177 TF_RETURN_IF_ERROR(reader->ReadScalar(
178 this->full_name("num_random_samples"), &num_random_samples_));
179 int64_t seed;
180 TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed"), &seed));
181 int64_t seed2;
182 TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed2"), &seed2));
183 seeds_ = {seed, seed2};
184 ResetRngs();
185
186 if (!reader->Contains(full_name("input_impl_empty"))) {
187 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
188 } else {
189 input_impl_.reset();
190 }
191 return OkStatus();
192 }
193
194 mutex mu_;
195 std::pair<int64_t, int64_t> seeds_ TF_GUARDED_BY(mu_);
196
197 private:
198 std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
199
Random()200 float Random() {
201 mutex_lock l(mu_);
202 num_random_samples_++;
203 uint32 random_uint = generator_();
204
205 // PhiloxRandom returns 32-bit unsigned ints. Convert to float in [0,1)
206 // using the same method that the RandomUniform op uses.
207 return random::Uint32ToFloat(random_uint);
208 }
209
210 // random util
211 random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_);
212 random::SingleSampleAdapter<random::PhiloxRandom> generator_
213 TF_GUARDED_BY(mu_);
214 int64_t num_random_samples_ TF_GUARDED_BY(mu_) = 0;
215 };
216
217 const float rate_;
218 const std::pair<int64_t, int64_t> seeds_;
219 const DatasetBase* const input_;
220 }; // SamplingDatasetOp::Dataset
221
SamplingDatasetOp(OpKernelConstruction * ctx)222 SamplingDatasetOp::SamplingDatasetOp(OpKernelConstruction* ctx)
223 : UnaryDatasetOpKernel(ctx) {}
224
225 // Create a new SamplingDatasetOp::Dataset, and return it as the output.
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)226 void SamplingDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
227 DatasetBase** output) {
228 float rate;
229 int64_t seed;
230 int64_t seed2;
231 OP_REQUIRES_OK(ctx, ParseScalarArgument<float>(ctx, kRate, &rate));
232 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed, &seed));
233 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed2, &seed2));
234
235 *output = new Dataset(ctx, rate, seed, seed2, input);
236 }
237
238 namespace {
239 REGISTER_KERNEL_BUILDER(Name("SamplingDataset").Device(DEVICE_CPU),
240 SamplingDatasetOp);
241 } // namespace
242 } // namespace experimental
243 } // namespace data
244 } // namespace tensorflow
245