/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/padded_batch_dataset_op.h"

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const PaddedBatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddedShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddingValues;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kParallelCopy;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kToutputTypes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kNumPaddedShapes;

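// Iterator checkpoint key written when the input iterator has been exhausted,
// so that RestoreInternal can skip recreating it.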
constexpr char kExhausted[] = "exhausted";

class PaddedBatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64_t batch_size, bool drop_remainder,
          bool parallel_copy, std::vector<PartialTensorShape> padded_shapes,
          std::vector<Tensor> padding_values, const DatasetBase* input,
          int op_version)
      : DatasetBase(DatasetContext(ctx)),
        batch_size_(batch_size),
        drop_remainder_(drop_remainder),
        parallel_copy_(parallel_copy),
        padded_shapes_(std::move(padded_shapes)),
        padding_values_(std::move(padding_values)),
        input_(input),
        op_version_(op_version),
        traceme_metadata_(
            {{"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"},
             {"parallel_copy", parallel_copy ? "true" : "false"}}) {
    input_->Ref();

    // NOTE(mrry): Currently we implement "batch up to" semantics. If we could
    // tell statically that the input dataset is infinite, then we could
    // always report `batch_size` as the 0th dimension.
    //
    // TODO(mrry): Need to validate that the input shape and the padded shape
    // are "compatible" (i.e. that padded shape is >= input shape, with both
    // static and dynamic checks as appropriate).
    const auto& input_shapes = input_->output_shapes();
    output_shapes_.reserve(input_shapes.size());
    for (size_t i = 0; i < input_shapes.size(); ++i) {
      if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
        output_shapes_.push_back(
            PartialTensorShape({batch_size_}).Concatenate(padded_shapes_[i]));
      } else {
        output_shapes_.push_back(
            PartialTensorShape({-1}).Concatenate(padded_shapes_[i]));
      }
    }
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    params.set_args(batch_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

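  // The cardinality is the number of full input batches, plus one partial
  // batch when there is a remainder and `drop_remainder_` is false. Infinite
  // and unknown input cardinalities are passed through unchanged.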
  int64_t CardinalityInternal() const override {
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));

    std::vector<Node*> padded_shapes;
    padded_shapes.reserve(padded_shapes_.size());
    for (int i = 0; i < padded_shapes_.size(); i++) {
      Node* node;
      Tensor t(DT_INT64, TensorShape({padded_shapes_[i].dims()}));
      for (int j = 0; j < padded_shapes_[i].dims(); j++) {
        t.vec<int64_t>()(j) = padded_shapes_[i].dim_size(j);
      }
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padded_shapes.emplace_back(node);
    }

    std::vector<Node*> padding_values;
    padding_values.reserve(padding_values_.size());
    for (const Tensor& t : padding_values_) {
      Node* node;
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padding_values.emplace_back(node);
    }

    Node* drop_remainder = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));

    AttrValue parallel_copy;
    b->BuildAttrValue(parallel_copy_, &parallel_copy);

    AttrValue output_types;
    b->BuildAttrValue(output_dtypes(), &output_types);

    AttrValue N;
    b->BuildAttrValue<int64_t>(padded_shapes_.size(), &N);

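    // Inputs are wired up by argument position: 0 is the input dataset, 1 is
    // batch_size, 4 is drop_remainder, and the list inputs at positions 2 and
    // 3 are padded_shapes and padding_values respectively.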
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {{0, input_graph_node}, {1, batch_size}, {4, drop_remainder}},
        {{2, padded_shapes}, {3, padding_values}},
        {{kParallelCopy, parallel_copy},
         {kToutputTypes, output_types},
         {kNumPaddedShapes, N}},
        output));
    return OkStatus();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Each row of `batch_elements` is a tuple of tensors from the
      // input iterator.
      std::vector<std::vector<Tensor>> batch_elements;
      {
        mutex_lock l(mu_);
        if (!input_impl_) {
          *end_of_sequence = true;
          return OkStatus();
        } else {
          *end_of_sequence = false;
          batch_elements.reserve(dataset()->batch_size_);
          for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
               ++i) {
            std::vector<Tensor> batch_element_tuple;
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
                                                    end_of_sequence));
            if (!*end_of_sequence) {
              batch_elements.push_back(std::move(batch_element_tuple));
            }
          }
          if (*end_of_sequence) {
            input_impl_.reset();
          }
        }
      }

      if (batch_elements.empty()) {
        DCHECK(*end_of_sequence);
        return OkStatus();
      }

      if (dataset()->drop_remainder_ &&
          batch_elements.size() < dataset()->batch_size_) {
        *end_of_sequence = true;
        return OkStatus();
      }

      TF_RETURN_IF_ERROR(CopyBatch(ctx, batch_elements, out_tensors));
      *end_of_sequence = false;
      return OkStatus();
    }

   protected:
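    // Each output batch consumes exactly `batch_size_` input elements, so the
    // iterator is modeled as a known-ratio node for autotuning.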
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (input_impl_)
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      else
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), ""));
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (reader->Contains(full_name(kExhausted))) {
        input_impl_.reset();
      } else {
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      }
      return OkStatus();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    // Copies the retrieved batch elements into one output tensor per tuple
    // component.
    //
    // NOTE(mrry): If the input or output sizes are statically known, we could
    // potentially read the input values in-place into their respective slice
    // locations. This would require a different GetNext() overload that
    // supports zero-copy, and might make sense in an optimization pass.
    Status CopyBatch(IteratorContext* ctx,
                     const std::vector<std::vector<Tensor>>& batch_elements,
                     std::vector<Tensor>* out_tensors) {
      const size_t num_tuple_components = batch_elements[0].size();
      const int64_t num_batch_elements = batch_elements.size();
      for (size_t component_index = 0; component_index < num_tuple_components;
           ++component_index) {
        // 1. Determine the shape of the padded tensor.
        TensorShape batch_component_shape({num_batch_elements});
        const PartialTensorShape& padded_shape =
            dataset()->padded_shapes_[component_index];

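        // A padded dimension of -1 means "pad to the longest element in this
        // batch": start that dimension at 0 here and grow it to the maximum
        // element size while scanning the batch below.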
        for (int dim = 0; dim < padded_shape.dims(); ++dim) {
          if (padded_shape.dim_size(dim) == -1) {
            batch_component_shape.AddDim(0);
          } else {
            batch_component_shape.AddDim(padded_shape.dim_size(dim));
          }
        }

        for (int64_t i = 0; i < num_batch_elements; ++i) {
          const TensorShape& element_shape =
              batch_elements[i][component_index].shape();
          // TODO(mrry): Perform this check in the shape function if
          // enough static information is available to do so.
          if (element_shape.dims() != padded_shape.dims()) {
            return errors::InvalidArgument(
                "All elements in a batch must have the same rank as the "
                "padded shape for component",
                component_index, ": expected rank ", padded_shape.dims(),
                " but got element with rank ", element_shape.dims());
          }
          for (int dim = 0; dim < padded_shape.dims(); ++dim) {
            if (padded_shape.dim_size(dim) == -1) {
              // Take the max of all batch elements in this dimension.
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                batch_component_shape.set_dim(
                    dim + 1,
                    batch_elements[i][component_index].shape().dim_size(dim));
              }
            } else {
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                return errors::DataLoss(
                    "Attempted to pad to a smaller size than the input "
                    "element.");
              }
            }
          }
        }

        // 2. Copy each batch element to the appropriate location in
        // the output component tensor.
        out_tensors->emplace_back(ctx->allocator({}),
                                  output_dtypes()[component_index],
                                  batch_component_shape);
        Tensor& batch_component = out_tensors->back();
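        // Fill the entire output component with the padding value first; the
        // per-element copies below overwrite only the non-padded regions.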
        TF_RETURN_IF_ERROR(batch_util::SetElementZero(
            &batch_component, dataset()->padding_values_[component_index]));

        // Build the output tuple component by copying one slice from each
        // input element in the batch.
        TensorShape component_shape({});
        for (int i = 1; i < batch_component_shape.dims(); ++i) {
          component_shape.AddDim(batch_component_shape.dim_size(i));
        }
        auto copy_element_fn = [component_index, &batch_elements,
                                &batch_component, &component_shape](int index) {
          // Take the fast path if possible.
          if (batch_elements[index][component_index].shape() ==
              component_shape) {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          } else {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          }
          return OkStatus();
        };

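        // Parallelize the copy only when it was requested and the average
        // number of bytes copied per batch element is at least 32 KiB
        // (1 << 15); the work is then sharded across the runner threadpool.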
        if (dataset()->parallel_copy_ && (batch_component.AllocatedBytes() /
                                          num_batch_elements) >= (1 << 15)) {
          BlockingCounter counter(num_batch_elements);
          Status status;
          mutex status_mu;
          const auto num_threads = ctx->runner_threadpool_size();
          const auto slice_size = num_batch_elements / num_threads;
          int64_t offset = 0;
          for (size_t i = 0; i < num_threads; ++i) {
            int64_t length = slice_size;
            // When the number of threads does not divide the number of
            // elements evenly, the size of some slices is incremented to
            // guarantee their sizes add up to the total number of elements.
            if (i < num_batch_elements % num_threads) ++length;
            (*ctx->runner())([offset, length, &status, &status_mu, &counter,
                              &copy_element_fn]() {
              for (size_t j = offset; j < offset + length; ++j) {
                {
                  Status s = copy_element_fn(j);
                  mutex_lock l(status_mu);
                  status.Update(s);
                }
                counter.DecrementCount();
              }
            });
            offset += length;
          }
          counter.Wait();
          TF_RETURN_IF_ERROR(status);
        } else {
          for (size_t i = 0; i < num_batch_elements; ++i) {
            TF_RETURN_IF_ERROR(copy_element_fn(i));
          }
        }
      }
      return OkStatus();
    }

    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const int64_t batch_size_;
  const bool drop_remainder_;
  const bool parallel_copy_;
  const std::vector<PartialTensorShape> padded_shapes_;
  const std::vector<Tensor> padding_values_;
  const DatasetBase* const input_;
  const int op_version_;
  std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};

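// This kernel backs both "PaddedBatchDataset" (v1) and "PaddedBatchDatasetV2"
// (v2); the version is inferred from the registered op name, and the
// `parallel_copy` attr is read only when it is present.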
PaddedBatchDatasetOp::PaddedBatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == "PaddedBatchDataset" ? 1 : 2) {
  if (ctx->HasAttr(kParallelCopy)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
  }
}

void PaddedBatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                       DatasetBase** output) {
  int64_t batch_size;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64_t>(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("Batch size must be greater than zero."));

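  // The `drop_remainder` input exists only in v2 of the op; v1 behaves as if
  // it were false.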
  bool drop_remainder = false;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
  }

  OpInputList padded_shape_tensors;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddedShapes, &padded_shape_tensors));
  std::vector<PartialTensorShape> padded_shapes;
  padded_shapes.reserve(padded_shape_tensors.size());
  OP_REQUIRES(ctx, padded_shape_tensors.size() == input->output_shapes().size(),
              errors::InvalidArgument("Number of padded shapes (",
                                      padded_shape_tensors.size(),
                                      ") must match the number of components "
                                      "in the input dataset's elements (",
                                      input->output_shapes().size(), ")"));
  for (const Tensor& padded_shape_t : padded_shape_tensors) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(padded_shape_t.shape()),
                errors::InvalidArgument("All padded shapes must be vectors"));
    PartialTensorShape padded_shape;
    OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
                            padded_shape_t.vec<int64_t>().data(),
                            padded_shape_t.NumElements(), &padded_shape));
    padded_shapes.push_back(std::move(padded_shape));
  }
  OpInputList padding_values_list;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddingValues, &padding_values_list));
  std::vector<Tensor> padding_values;
  OP_REQUIRES(ctx, padding_values_list.size() == input->output_shapes().size(),
              errors::InvalidArgument(
                  "Number of padding values (", padding_values_list.size(),
                  ") must match the number of components in the input "
                  "dataset's elements (",
                  input->output_shapes().size(), ")"));
  for (int i = 0; i < padding_values_list.size(); ++i) {
    const Tensor& padding_value_t = padding_values_list[i];
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(padding_value_t.shape()),
                errors::InvalidArgument("All padding values must be scalars"));
    OP_REQUIRES(ctx, padding_value_t.dtype() == input->output_dtypes()[i],
                errors::InvalidArgument(
                    "Mismatched type between padding value ", i,
                    " and input dataset's component ", i, ": ",
                    DataTypeString(padding_value_t.dtype()), " vs. ",
                    DataTypeString(input->output_dtypes()[i])));
    padding_values.push_back(tensor::DeepCopy(padding_value_t));
  }

  *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_,
                        std::move(padded_shapes), std::move(padding_values),
                        input, op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("PaddedBatchDataset").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow