xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/experimental/sampling_dataset_op.h"
16 
17 #include "tensorflow/core/data/dataset_utils.h"
18 #include "tensorflow/core/data/name_utils.h"
19 #include "tensorflow/core/framework/dataset.h"
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/lib/random/philox_random.h"
23 #include "tensorflow/core/lib/random/random.h"
24 #include "tensorflow/core/lib/random/random_distributions.h"
25 #include "tensorflow/core/lib/random/simple_philox.h"
26 
27 namespace tensorflow {
28 namespace data {
29 namespace experimental {
30 
31 // Constants declared in sampling_dataset_op.h and used both here and in test
32 // cases.
33 /* static */ constexpr const char* const SamplingDatasetOp::kDatasetType;
34 /* static */ constexpr const char* const SamplingDatasetOp::kInputDataset;
35 /* static */ constexpr const char* const SamplingDatasetOp::kRate;
36 /* static */ constexpr const char* const SamplingDatasetOp::kSeed;
37 /* static */ constexpr const char* const SamplingDatasetOp::kSeed2;
38 /* static */ constexpr const char* const SamplingDatasetOp::kOutputTypes;
39 /* static */ constexpr const char* const SamplingDatasetOp::kOutputShapes;
40 
41 class SamplingDatasetOp::Dataset : public DatasetBase {
42  public:
Dataset(OpKernelContext * ctx,float rate,int64_t seed,int64_t seed2,const DatasetBase * input)43   Dataset(OpKernelContext* ctx, float rate, int64_t seed, int64_t seed2,
44           const DatasetBase* input)
45       : DatasetBase(DatasetContext(ctx)),
46         rate_(rate),
47         seeds_(seed, seed2),
48         input_(input) {
49     input_->Ref();
50   }
51 
~Dataset()52   ~Dataset() override { input_->Unref(); }
53 
MakeIteratorInternal(const string & prefix) const54   std::unique_ptr<IteratorBase> MakeIteratorInternal(
55       const string& prefix) const override {
56     return std::unique_ptr<IteratorBase>(
57         new Iterator({this, name_utils::IteratorPrefix(kDatasetType, prefix)},
58                      seeds_.first, seeds_.second));
59   }
60 
output_dtypes() const61   const DataTypeVector& output_dtypes() const override {
62     return input_->output_dtypes();
63   }
64 
output_shapes() const65   const std::vector<PartialTensorShape>& output_shapes() const override {
66     return input_->output_shapes();
67   }
68 
DebugString() const69   string DebugString() const override {
70     return name_utils::DatasetDebugString(kDatasetType);
71   }
72 
InputDatasets(std::vector<const DatasetBase * > * inputs) const73   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
74     inputs->push_back(input_);
75     return OkStatus();
76   }
77 
CheckExternalState() const78   Status CheckExternalState() const override {
79     return input_->CheckExternalState();
80   }
81 
82  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const83   Status AsGraphDefInternal(SerializationContext* ctx,
84                             DatasetGraphDefBuilder* b,
85                             Node** output) const override {
86     Node* input_graph_node = nullptr;
87     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
88     Node* rate = nullptr;
89     Node* seed = nullptr;
90     Node* seed2 = nullptr;
91     TF_RETURN_IF_ERROR(b->AddScalar(rate_, &rate));
92     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.first, &seed));
93     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.second, &seed2));
94     TF_RETURN_IF_ERROR(
95         b->AddDataset(this, {input_graph_node, rate, seed, seed2}, output));
96     return OkStatus();
97   }
98 
99  private:
100   class Iterator : public DatasetIterator<Dataset> {
101    public:
Iterator(const Params & params,int64_t seed,int64_t seed2)102     explicit Iterator(const Params& params, int64_t seed, int64_t seed2)
103         : DatasetIterator<Dataset>(params),
104           seeds_(MaybeOverrideSeeds({seed, seed2})),
105           parent_generator_(seeds_.first, seeds_.second),
106           generator_(&parent_generator_) {}
107 
Initialize(IteratorContext * ctx)108     Status Initialize(IteratorContext* ctx) override {
109       return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
110     }
111 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)112     Status GetNextInternal(IteratorContext* ctx,
113                            std::vector<Tensor>* out_tensors,
114                            bool* end_of_sequence) override {
115       bool rand_val_hit;
116       do {
117         {
118           tf_shared_lock l(mu_);
119           if (!input_impl_) {
120             *end_of_sequence = true;
121             return OkStatus();
122           }
123           TF_RETURN_IF_ERROR(
124               input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
125         }
126         if (*end_of_sequence) {
127           mutex_lock l(mu_);
128           input_impl_.reset();
129           return OkStatus();
130         }
131 
132         // generate a number from random uniform [0, 1)
133         float rand_val = Random();
134         rand_val_hit = rand_val < dataset()->rate_;
135         if (!rand_val_hit) {
136           // Clear the output tensor list since it doesn't match.
137           out_tensors->clear();
138         }
139       } while (!rand_val_hit);
140       *end_of_sequence = false;
141       return OkStatus();
142     }
143 
144    protected:
ResetRngs()145     void ResetRngs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
146       // Reset the generators based on the current iterator seeds.
147       parent_generator_ = random::PhiloxRandom(seeds_.first, seeds_.second);
148       generator_ =
149           random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
150       generator_.Skip(num_random_samples_);
151     }
152 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)153     Status SaveInternal(SerializationContext* ctx,
154                         IteratorStateWriter* writer) override {
155       mutex_lock l(mu_);
156       // Save state needed to restore the random number generators.
157       TF_RETURN_IF_ERROR(writer->WriteScalar(
158           this->full_name("num_random_samples"), num_random_samples_));
159       TF_RETURN_IF_ERROR(
160           writer->WriteScalar(this->full_name("seed"), seeds_.first));
161       TF_RETURN_IF_ERROR(
162           writer->WriteScalar(this->full_name("seed2"), seeds_.second));
163 
164       if (input_impl_) {
165         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
166       } else {
167         TF_RETURN_IF_ERROR(
168             writer->WriteScalar(full_name("input_impl_empty"), ""));
169       }
170       return OkStatus();
171     }
172 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)173     Status RestoreInternal(IteratorContext* ctx,
174                            IteratorStateReader* reader) override {
175       mutex_lock l(mu_);
176       // Restore the random number generators.
177       TF_RETURN_IF_ERROR(reader->ReadScalar(
178           this->full_name("num_random_samples"), &num_random_samples_));
179       int64_t seed;
180       TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed"), &seed));
181       int64_t seed2;
182       TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed2"), &seed2));
183       seeds_ = {seed, seed2};
184       ResetRngs();
185 
186       if (!reader->Contains(full_name("input_impl_empty"))) {
187         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
188       } else {
189         input_impl_.reset();
190       }
191       return OkStatus();
192     }
193 
194     mutex mu_;
195     std::pair<int64_t, int64_t> seeds_ TF_GUARDED_BY(mu_);
196 
197    private:
198     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
199 
Random()200     float Random() {
201       mutex_lock l(mu_);
202       num_random_samples_++;
203       uint32 random_uint = generator_();
204 
205       // PhiloxRandom returns 32-bit unsigned ints. Convert to float in [0,1)
206       // using the same method that the RandomUniform op uses.
207       return random::Uint32ToFloat(random_uint);
208     }
209 
210     // random util
211     random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_);
212     random::SingleSampleAdapter<random::PhiloxRandom> generator_
213         TF_GUARDED_BY(mu_);
214     int64_t num_random_samples_ TF_GUARDED_BY(mu_) = 0;
215   };
216 
217   const float rate_;
218   const std::pair<int64_t, int64_t> seeds_;
219   const DatasetBase* const input_;
220 };  // SamplingDatasetOp::Dataset
221 
SamplingDatasetOp(OpKernelConstruction * ctx)222 SamplingDatasetOp::SamplingDatasetOp(OpKernelConstruction* ctx)
223     : UnaryDatasetOpKernel(ctx) {}
224 
225 // Create a new SamplingDatasetOp::Dataset, and return it as the output.
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)226 void SamplingDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
227                                     DatasetBase** output) {
228   float rate;
229   int64_t seed;
230   int64_t seed2;
231   OP_REQUIRES_OK(ctx, ParseScalarArgument<float>(ctx, kRate, &rate));
232   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed, &seed));
233   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed2, &seed2));
234 
235   *output = new Dataset(ctx, rate, seed, seed2, input);
236 }
237 
238 namespace {
239 REGISTER_KERNEL_BUILDER(Name("SamplingDataset").Device(DEVICE_CPU),
240                         SamplingDatasetOp);
241 }  // namespace
242 }  // namespace experimental
243 }  // namespace data
244 }  // namespace tensorflow
245