/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/batch_dataset_op.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
/* static */ constexpr const char* const BatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const BatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const BatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const BatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const BatchDatasetOp::kParallelCopy;
/* static */ constexpr const char* const BatchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const BatchDatasetOp::kOutputShapes;

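// Key used in checkpoints to record that the input iterator was exhausted.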
constexpr char kInputImplEmpty[] = "input_impl_empty";
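// Name of the V1 op; the kernel constructor uses it to infer the op version.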
constexpr char kBatchDataset[] = "BatchDataset";

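// Dataset that groups up to `batch_size` consecutive elements of its input
// into a single batch.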
class BatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64_t batch_size, bool drop_remainder,
          bool parallel_copy, const DatasetBase* input, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        batch_size_(batch_size),
        // `Dataset.batch()` is sometimes used to stack all elements in the
        // dataset. In such cases, a very large batch size (e.g., INT32_MAX)
        // is passed with `drop_remainder` set to false. To avoid OOM in that
        // case, cap the `reserve()` size at 2**16.
        reserve_size_(drop_remainder ? batch_size
                                     : std::min<int64_t>(batch_size, 1 << 16)),
        drop_remainder_(drop_remainder),
        parallel_copy_(parallel_copy),
        input_(input),
        op_version_(op_version),
        traceme_metadata_(
            {{"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"},
             {"parallel_copy", parallel_copy ? "true" : "false"}}) {
    input_->Ref();

    // NOTE(mrry): Currently we implement "batch up to" semantics: the 0th
    // output dimension is only statically set to `batch_size` when the
    // remainder is dropped or the input is known to be infinite (so every
    // batch is full); otherwise it is reported as unknown (-1).
    const auto& input_shapes = input_->output_shapes();
    output_shapes_.reserve(input_shapes.size());
    for (const auto& input_shape : input_shapes) {
      if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
        output_shapes_.emplace_back(
            PartialTensorShape({batch_size_}).Concatenate(input_shape));
      } else {
        output_shapes_.emplace_back(
            PartialTensorShape({-1}).Concatenate(input_shape));
      }
    }
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    params.set_args(batch_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

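  // With `drop_remainder` the cardinality is floor(n / batch_size); otherwise
  // a final partial batch adds one, i.e. ceil(n / batch_size). For example,
  // 10 input elements with batch_size 3 yield 4 batches (3, 3, 3, 1), or 3
  // batches when the remainder is dropped.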
  int64_t CardinalityInternal() const override {
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  int64_t CardinalityInternal(CardinalityOptions options) const override {
    int64_t n = input_->Cardinality(options);
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

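  // Random-access lookup: assembles the `index`-th batch by fetching input
  // elements [index * batch_size_, min((index + 1) * batch_size_, n)) and
  // copying them into one output tensor per tuple component.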
  Status Get(OpKernelContext* ctx, int64_t index,
             std::vector<Tensor>* out_tensors) const override {
    const int64_t cardinality = Cardinality();
    if (index < 0 || index >= cardinality) {
      return errors::OutOfRange("Index out of range [0, ", cardinality,
                                "):", index);
    }
    // Use int64_t index arithmetic to avoid overflow for large datasets.
    int64_t batch_start_index = batch_size_ * index;
    std::vector<std::vector<Tensor>> batch_elements;
    int64_t input_cardinality = input_->Cardinality();
    for (int64_t i = batch_start_index;
         i < batch_start_index + batch_size_ && i < input_cardinality; ++i) {
      std::vector<Tensor> batch_element_tuple;
      TF_RETURN_IF_ERROR(input_->Get(ctx, i, &batch_element_tuple));
      batch_elements.emplace_back(std::move(batch_element_tuple));
    }
    TF_RETURN_IF_ERROR(CopyBatch(CopyBatchParams(ctx), batch_elements,
                                 parallel_copy_,
                                 /*allocation_callback=*/nullptr, out_tensors));
    return OkStatus();
  }

 protected:
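  // Rebuilds the graph node for this dataset: the input dataset, batch_size,
  // and drop_remainder become inputs, and parallel_copy is attached as an
  // attr.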
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
    Node* drop_remainder = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
    AttrValue parallel_copy;
    b->BuildAttrValue(parallel_copy_, &parallel_copy);
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, batch_size, drop_remainder},
                      {{kParallelCopy, parallel_copy}}, output));
    return OkStatus();
  }

 private:
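  // Sequential iterator that pulls up to `batch_size_` elements at a time
  // from the input iterator. `input_impl_` is reset once the input is
  // exhausted.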
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Each row of `batch_elements` is a tuple of tensors from the
      // input iterator.
      std::vector<std::vector<Tensor>> batch_elements;
      {
        mutex_lock l(mu_);
        if (!input_impl_) {
          *end_of_sequence = true;
          return OkStatus();
        }
        batch_elements.reserve(dataset()->reserve_size_);
        *end_of_sequence = false;
        for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence; ++i) {
          std::vector<Tensor> batch_element_tuple;
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &batch_element_tuple, end_of_sequence));
          if (!*end_of_sequence) {
            batch_elements.emplace_back(std::move(batch_element_tuple));
          } else {
            input_impl_.reset();
          }
        }
      }
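      // `mu_` is released here: it only guards `input_impl_`, and the copy
      // below operates on locally owned tensors, so checkpointing is not
      // blocked while the batch is materialized.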

      if (batch_elements.empty()) {
        DCHECK(*end_of_sequence);
        return OkStatus();
      }

      if (dataset()->drop_remainder_ &&
          batch_elements.size() < dataset()->batch_size_) {
        *end_of_sequence = true;
        return OkStatus();
      }

      // Copy the retrieved batch elements into one output tensor per tuple
      // component.
      //
      // NOTE(mrry): If the input or output sizes are statically known, we
      // could potentially read the input values in-place into their
      // respective slice locations. This would require a different GetNext()
      // overload that supports zero-copy, and might make sense in an
      // optimization pass.
      TF_RETURN_IF_ERROR(CopyBatch(
          CopyBatchParams(ctx), batch_elements, dataset()->parallel_copy_,
          /*allocation_callback=*/nullptr, out_tensors));

      *end_of_sequence = false;
      return OkStatus();
    }

   protected:
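    // Each output element consumes exactly `batch_size_` input elements, so
    // batching is modeled as a known-ratio node.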
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (!input_impl_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
      } else {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (!reader->Contains(full_name(kInputImplEmpty))) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      } else {
        input_impl_.reset();
      }
      return OkStatus();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const int64_t batch_size_;
  const int64_t reserve_size_;
  const bool drop_remainder_;
  const bool parallel_copy_;
  const DatasetBase* const input_;
  const int op_version_;
  std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};
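// The op version is inferred from the node's op name: "BatchDataset" is V1,
// anything else ("BatchDatasetV2") is V2. `parallel_copy` is read only if
// the op defines the attr.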
BatchDatasetOp::BatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == kBatchDataset ? 1 : 2) {
  if (ctx->HasAttr(kParallelCopy)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
  }
}

void BatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                 DatasetBase** output) {
  int64_t batch_size = 0;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64_t>(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("Batch size must be greater than zero."));

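  // Only BatchDatasetV2 has a `drop_remainder` input; the V1 op always
  // keeps the final partial batch.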
  bool drop_remainder = false;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
  }

  *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_, input,
                        op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("BatchDataset").Device(DEVICE_CPU),
                        BatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU),
                        BatchDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow