/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/batch_dataset_op.h"

#include <algorithm>
#include <utility>

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const BatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const BatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const BatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const BatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const BatchDatasetOp::kParallelCopy;
/* static */ constexpr const char* const BatchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const BatchDatasetOp::kOutputShapes;

constexpr char kInputImplEmpty[] = "input_impl_empty";
constexpr char kBatchDataset[] = "BatchDataset";

class BatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64_t batch_size, bool drop_remainder,
          bool parallel_copy, const DatasetBase* input, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        batch_size_(batch_size),
        // Dataset batch is sometimes used to stack all elements in the
        // dataset. In such cases, a very large batch size (e.g., INT32_MAX)
        // is passed with drop_remainder set to false. Avoid OOM in such
        // cases by capping the `reserve()` size at 2**16.
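        // A sketch of the arithmetic: with batch_size == INT32_MAX and
        // drop_remainder == false, reserve_size_ is
        // min(INT32_MAX, 1 << 16) == 65536, so at most 2**16 slots are
        // reserved up front rather than one per requested batch element.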
        reserve_size_(drop_remainder ? batch_size
                                     : std::min<int64_t>(batch_size, 1 << 16)),
        drop_remainder_(drop_remainder),
        parallel_copy_(parallel_copy),
        input_(input),
        op_version_(op_version),
        traceme_metadata_(
            {{"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"},
             {"parallel_copy", parallel_copy ? "true" : "false"}}) {
    input_->Ref();

    // NOTE(mrry): Currently we implement "batch up to" semantics. If
    // we could tell statically that the input dataset is infinite,
    // then we could always report `batch_size` as the 0th dimension.
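    // For example, with an input component of shape [3] and batch_size 4:
    // if drop_remainder is set (or the input is infinite) the reported
    // component shape is [4, 3]; otherwise the leading dimension is
    // unknown and the shape is [?, 3].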
    const auto& input_shapes = input_->output_shapes();
    output_shapes_.reserve(input_shapes.size());
    for (const auto& input_shape : input_shapes) {
      if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
        output_shapes_.emplace_back(
            PartialTensorShape({batch_size_}).Concatenate(input_shape));
      } else {
        output_shapes_.emplace_back(
            PartialTensorShape({-1}).Concatenate(input_shape));
      }
    }
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    params.set_args(batch_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64_t CardinalityInternal() const override {
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  int64_t CardinalityInternal(CardinalityOptions options) const override {
    int64_t n = input_->Cardinality(options);
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }
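  // A worked example of the cardinality arithmetic above: with n == 10 and
  // batch_size_ == 3, drop_remainder == false gives 10 / 3 + 1 == 4 batches
  // (the last holding a single element), while drop_remainder == true gives
  // 10 / 3 == 3 full batches.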

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

  Status Get(OpKernelContext* ctx, int64_t index,
             std::vector<Tensor>* out_tensors) const override {
    const int64_t cardinality = Cardinality();
    if (index < 0 || index >= cardinality) {
      return errors::OutOfRange("Index out of range [0, ", cardinality,
                                "):", index);
    }
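    // E.g. (a sketch of the arithmetic): index == 2 with batch_size_ == 5
    // reads input elements [10, 15), clamped to the input cardinality.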
    // Use int64_t arithmetic to avoid truncating large start indices.
    const int64_t batch_start_index = batch_size_ * index;
    std::vector<std::vector<Tensor>> batch_elements;
    const int64_t input_cardinality = input_->Cardinality();
    for (int64_t i = batch_start_index;
         i < batch_start_index + batch_size_ && i < input_cardinality; ++i) {
      std::vector<Tensor> batch_element_tuple;
      TF_RETURN_IF_ERROR(input_->Get(ctx, i, &batch_element_tuple));
      batch_elements.emplace_back(std::move(batch_element_tuple));
    }
    TF_RETURN_IF_ERROR(CopyBatch(CopyBatchParams(ctx), batch_elements,
                                 parallel_copy_,
                                 /*allocation_callback=*/nullptr, out_tensors));
    return OkStatus();
  }

 protected:
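  // Serialization sketch (as implemented below): the dataset round-trips
  // through GraphDef as a node with inputs (input_dataset, batch_size,
  // drop_remainder) and the parallel_copy attr, matching the V2 op.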
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
    Node* drop_remainder = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
    AttrValue parallel_copy;
    b->BuildAttrValue(parallel_copy_, &parallel_copy);
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, batch_size, drop_remainder},
                      {{kParallelCopy, parallel_copy}}, output));
    return OkStatus();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Each row of `batch_elements` is a tuple of tensors from the
      // input iterator.
      std::vector<std::vector<Tensor>> batch_elements;
      {
        mutex_lock l(mu_);
        if (!input_impl_) {
          *end_of_sequence = true;
          return OkStatus();
        }
        batch_elements.reserve(dataset()->reserve_size_);
        *end_of_sequence = false;
        for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence; ++i) {
          std::vector<Tensor> batch_element_tuple;
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &batch_element_tuple, end_of_sequence));
          if (!*end_of_sequence) {
            batch_elements.emplace_back(std::move(batch_element_tuple));
          } else {
            input_impl_.reset();
          }
        }
      }
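      // `mu_` is released here; CopyBatch below only touches the locally
      // owned `batch_elements`, so the (possibly parallel) copy does not
      // hold up concurrent Save/Restore calls.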

      if (batch_elements.empty()) {
        DCHECK(*end_of_sequence);
        return OkStatus();
      }

      if (dataset()->drop_remainder_ &&
          batch_elements.size() < dataset()->batch_size_) {
        *end_of_sequence = true;
        return OkStatus();
      }

      // Copy the retrieved batch elements into one output tensor per tuple
      // component.
      //
      // NOTE(mrry): If the input or output sizes are statically known, we
      // could potentially read the input values in-place into their
      // respective slice locations. This would require a different GetNext()
      // overload that supports zero-copy, and might make sense in an
      // optimization pass.
      TF_RETURN_IF_ERROR(CopyBatch(
          CopyBatchParams(ctx), batch_elements, dataset()->parallel_copy_,
          /*allocation_callback=*/nullptr, out_tensors));

      *end_of_sequence = false;
      return OkStatus();
    }

   protected:
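    // The autotuning model sees this iterator as a known-ratio node: each
    // output batch is modeled as consuming batch_size_ input elements.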
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
    }

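    // Checkpoint protocol (as implemented below): if the input iterator has
    // been exhausted and reset, write the kInputImplEmpty marker; otherwise
    // delegate to SaveInput(). RestoreInternal() inverts the check.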
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (!input_impl_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
      } else {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (!reader->Contains(full_name(kInputImplEmpty))) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      } else {
        input_impl_.reset();
      }
      return OkStatus();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const int64_t batch_size_;
  const int64_t reserve_size_;
  const bool drop_remainder_;
  const bool parallel_copy_;
  const DatasetBase* const input_;
  const int op_version_;
  std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};

BatchDatasetOp::BatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      // "BatchDataset" is the V1 op; "BatchDatasetV2" adds the
      // drop_remainder input.
      op_version_(ctx->def().op() == kBatchDataset ? 1 : 2) {
  if (ctx->HasAttr(kParallelCopy)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
  }
}

void BatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                 DatasetBase** output) {
  int64_t batch_size = 0;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64_t>(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("Batch size must be greater than zero."));

  // The V1 op ("BatchDataset") has no drop_remainder input, so it defaults
  // to false.
  bool drop_remainder = false;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
  }

  *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_, input,
                        op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("BatchDataset").Device(DEVICE_CPU),
                        BatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU),
                        BatchDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow