/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/padded_batch_dataset_op.h"

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const PaddedBatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddedShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddingValues;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kParallelCopy;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kToutputTypes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kNumPaddedShapes;

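// Key written into iterator checkpoints to record that the input iterator has
// already been exhausted (see SaveInternal/RestoreInternal below).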
constexpr char kExhausted[] = "exhausted";

class PaddedBatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64_t batch_size, bool drop_remainder,
          bool parallel_copy, std::vector<PartialTensorShape> padded_shapes,
          std::vector<Tensor> padding_values, const DatasetBase* input,
          int op_version)
      : DatasetBase(DatasetContext(ctx)),
        batch_size_(batch_size),
        drop_remainder_(drop_remainder),
        parallel_copy_(parallel_copy),
        padded_shapes_(std::move(padded_shapes)),
        padding_values_(std::move(padding_values)),
        input_(input),
        op_version_(op_version),
        traceme_metadata_(
            {{"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"},
             {"parallel_copy", parallel_copy ? "true" : "false"}}) {
    input_->Ref();

    // NOTE(mrry): Currently we implement "batch up to" semantics. If we could
    // tell statically that the input dataset is infinite, then we could
    // always report `batch_size` as the 0th dimension.
    //
    // TODO(mrry): Need to validate that the input shape and the padded shape
    // are "compatible" (i.e. that padded shape is >= input shape, with both
    // static and dynamic checks as appropriate).
    const auto& input_shapes = input_->output_shapes();
    output_shapes_.reserve(input_shapes.size());
    for (size_t i = 0; i < input_shapes.size(); ++i) {
      if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
        output_shapes_.push_back(
            PartialTensorShape({batch_size_}).Concatenate(padded_shapes_[i]));
      } else {
        output_shapes_.push_back(
            PartialTensorShape({-1}).Concatenate(padded_shapes_[i]));
      }
    }
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    params.set_args(batch_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

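  // For a finite input of n elements, the cardinality is ceil(n / batch_size),
  // or floor(n / batch_size) when `drop_remainder` is set; infinite and
  // unknown cardinalities are propagated unchanged.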
  int64_t CardinalityInternal() const override {
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
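  // Serializes this dataset as a graph node: each padded shape becomes a 1-D
  // int64 constant tensor, padding values become constant tensors, and the
  // `kParallelCopy`, `kToutputTypes`, and `kNumPaddedShapes` attrs are set.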
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));

    std::vector<Node*> padded_shapes;
    padded_shapes.reserve(padded_shapes_.size());
    for (int i = 0; i < padded_shapes_.size(); i++) {
      Node* node;
      Tensor t(DT_INT64, TensorShape({padded_shapes_[i].dims()}));
      for (int j = 0; j < padded_shapes_[i].dims(); j++) {
        t.vec<int64_t>()(j) = padded_shapes_[i].dim_size(j);
      }
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padded_shapes.emplace_back(node);
    }

    std::vector<Node*> padding_values;
    padding_values.reserve(padding_values_.size());
    for (const Tensor& t : padding_values_) {
      Node* node;
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padding_values.emplace_back(node);
    }

    Node* drop_remainder = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));

    AttrValue parallel_copy;
    b->BuildAttrValue(parallel_copy_, &parallel_copy);

    AttrValue output_types;
    b->BuildAttrValue(output_dtypes(), &output_types);

    AttrValue N;
    b->BuildAttrValue<int64_t>(padded_shapes_.size(), &N);

    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {{0, input_graph_node}, {1, batch_size}, {4, drop_remainder}},
        {{2, padded_shapes}, {3, padding_values}},
        {{kParallelCopy, parallel_copy},
         {kToutputTypes, output_types},
         {kNumPaddedShapes, N}},
        output));
    return OkStatus();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

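    // Reads up to `batch_size` elements from the input iterator while holding
    // `mu_`, then pads and copies them into the output tensors outside the
    // lock. A final partial batch is emitted unless `drop_remainder` is set.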
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Each row of `batch_elements` is a tuple of tensors from the
      // input iterator.
      std::vector<std::vector<Tensor>> batch_elements;
      {
        mutex_lock l(mu_);
        if (!input_impl_) {
          *end_of_sequence = true;
          return OkStatus();
        } else {
          *end_of_sequence = false;
          batch_elements.reserve(dataset()->batch_size_);
          for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
               ++i) {
            std::vector<Tensor> batch_element_tuple;
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
                                                    end_of_sequence));
            if (!*end_of_sequence) {
              batch_elements.push_back(std::move(batch_element_tuple));
            }
          }
          if (*end_of_sequence) {
            input_impl_.reset();
          }
        }
      }

      if (batch_elements.empty()) {
        DCHECK(*end_of_sequence);
        return OkStatus();
      }

      if (dataset()->drop_remainder_ &&
          batch_elements.size() < dataset()->batch_size_) {
        *end_of_sequence = true;
        return OkStatus();
      }

      TF_RETURN_IF_ERROR(CopyBatch(ctx, batch_elements, out_tensors));
      *end_of_sequence = false;
      return OkStatus();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
    }

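    // Saves the input iterator's state, or records `kExhausted` when the input
    // has already been fully consumed so that restore skips recreating it.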
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (input_impl_)
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      else
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), ""));
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (reader->Contains(full_name(kExhausted))) {
        input_impl_.reset();
      } else {
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      }
      return OkStatus();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    // Copies the retrieved batch elements into one output tensor per tuple
    // component.
    //
    // NOTE(mrry): If the input or output sizes are statically known, we could
    // potentially read the input values in-place into their respective slice
    // locations. This would require a different GetNext() overload that
    // supports zero-copy, and might make sense in an optimization pass.
    Status CopyBatch(IteratorContext* ctx,
                     const std::vector<std::vector<Tensor>>& batch_elements,
                     std::vector<Tensor>* out_tensors) {
      const size_t num_tuple_components = batch_elements[0].size();
      const int64_t num_batch_elements = batch_elements.size();
      for (size_t component_index = 0; component_index < num_tuple_components;
           ++component_index) {
        // 1. Determine the shape of the padded tensor.
        TensorShape batch_component_shape({num_batch_elements});
        const PartialTensorShape& padded_shape =
            dataset()->padded_shapes_[component_index];

        for (int dim = 0; dim < padded_shape.dims(); ++dim) {
          if (padded_shape.dim_size(dim) == -1) {
            batch_component_shape.AddDim(0);
          } else {
            batch_component_shape.AddDim(padded_shape.dim_size(dim));
          }
        }

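        // Validate every element against the padded shape: ranks must match,
        // unknown (-1) padded dimensions grow to the largest element seen, and
        // fixed padded dimensions must be at least as large as each element.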
        for (int64_t i = 0; i < num_batch_elements; ++i) {
          const TensorShape& element_shape =
              batch_elements[i][component_index].shape();
          // TODO(mrry): Perform this check in the shape function if
          // enough static information is available to do so.
          if (element_shape.dims() != padded_shape.dims()) {
            return errors::InvalidArgument(
                "All elements in a batch must have the same rank as the "
                "padded shape for component ",
                component_index, ": expected rank ", padded_shape.dims(),
                " but got element with rank ", element_shape.dims());
          }
          for (int dim = 0; dim < padded_shape.dims(); ++dim) {
            if (padded_shape.dim_size(dim) == -1) {
              // Take the max of all batch elements in this dimension.
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                batch_component_shape.set_dim(
                    dim + 1,
                    batch_elements[i][component_index].shape().dim_size(dim));
              }
            } else {
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                return errors::DataLoss(
                    "Attempted to pad to a smaller size than the input "
                    "element.");
              }
            }
          }
        }

        // 2. Copy each batch element to the appropriate location in
        // the output component tensor.
        out_tensors->emplace_back(ctx->allocator({}),
                                  output_dtypes()[component_index],
                                  batch_component_shape);
        Tensor& batch_component = out_tensors->back();
        TF_RETURN_IF_ERROR(batch_util::SetElementZero(
            &batch_component, dataset()->padding_values_[component_index]));

        // Build the output tuple component by copying one slice from each input
        // element in the batch.
        TensorShape component_shape({});
        for (int i = 1; i < batch_component_shape.dims(); ++i) {
          component_shape.AddDim(batch_component_shape.dim_size(i));
        }
        auto copy_element_fn = [component_index, &batch_elements,
                                &batch_component, &component_shape](int index) {
          // Take the fast path if possible.
          if (batch_elements[index][component_index].shape() ==
              component_shape) {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          } else {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          }
          return OkStatus();
        };

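        // Use the parallel path only when the average slice is at least 32 KiB
        // (1 << 15 bytes); presumably below this size the cost of dispatching
        // work to the runner threadpool outweighs the benefit of parallelism.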
        if (dataset()->parallel_copy_ && (batch_component.AllocatedBytes() /
                                          num_batch_elements) >= (1 << 15)) {
          BlockingCounter counter(num_batch_elements);
          Status status;
          mutex status_mu;
          const auto num_threads = ctx->runner_threadpool_size();
          const auto slice_size = num_batch_elements / num_threads;
          int64_t offset = 0;
          for (size_t i = 0; i < num_threads; ++i) {
            int64_t length = slice_size;
            // When the number of threads does not divide the number of elements
            // evenly, the size of some slices is incremented to guarantee their
            // sizes add up to the total number of elements.
            if (i < num_batch_elements % num_threads) ++length;
            (*ctx->runner())([offset, length, &status, &status_mu, &counter,
                              &copy_element_fn]() {
              for (size_t j = offset; j < offset + length; ++j) {
                {
                  Status s = copy_element_fn(j);
                  mutex_lock l(status_mu);
                  status.Update(s);
                }
                counter.DecrementCount();
              }
            });
            offset += length;
          }
          counter.Wait();
          TF_RETURN_IF_ERROR(status);
        } else {
          for (size_t i = 0; i < num_batch_elements; ++i) {
            TF_RETURN_IF_ERROR(copy_element_fn(i));
          }
        }
      }
      return OkStatus();
    }

    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const int64_t batch_size_;
  const bool drop_remainder_;
  const bool parallel_copy_;
  const std::vector<PartialTensorShape> padded_shapes_;
  const std::vector<Tensor> padding_values_;
  const DatasetBase* const input_;
  const int op_version_;
  std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};

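// The original `PaddedBatchDataset` op has no `drop_remainder` input, while
// `PaddedBatchDatasetV2` does; `op_version_` records which variant created
// this kernel so that MakeDataset parses its inputs accordingly.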
PaddedBatchDatasetOp::PaddedBatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == "PaddedBatchDataset" ? 1 : 2) {
  if (ctx->HasAttr(kParallelCopy)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
  }
}

void PaddedBatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                       DatasetBase** output) {
  int64_t batch_size;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64_t>(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("Batch size must be greater than zero."));

  bool drop_remainder = false;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
  }

  OpInputList padded_shape_tensors;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddedShapes, &padded_shape_tensors));
  std::vector<PartialTensorShape> padded_shapes;
  padded_shapes.reserve(padded_shape_tensors.size());
  OP_REQUIRES(ctx, padded_shape_tensors.size() == input->output_shapes().size(),
              errors::InvalidArgument("Number of padded shapes (",
                                      padded_shape_tensors.size(),
                                      ") must match the number of components "
                                      "in the input dataset's elements (",
                                      input->output_shapes().size(), ")"));
  for (const Tensor& padded_shape_t : padded_shape_tensors) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(padded_shape_t.shape()),
                errors::InvalidArgument("All padded shapes must be vectors"));
    PartialTensorShape padded_shape;
    OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
                            padded_shape_t.vec<int64_t>().data(),
                            padded_shape_t.NumElements(), &padded_shape));
    padded_shapes.push_back(std::move(padded_shape));
  }
  OpInputList padding_values_list;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddingValues, &padding_values_list));
  std::vector<Tensor> padding_values;
  OP_REQUIRES(ctx, padding_values_list.size() == input->output_shapes().size(),
              errors::InvalidArgument(
                  "Number of padding values (", padding_values_list.size(),
                  ") must match the number of components in the input "
                  "dataset's elements (",
                  input->output_shapes().size(), ")"));
  for (int i = 0; i < padding_values_list.size(); ++i) {
    const Tensor& padding_value_t = padding_values_list[i];
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(padding_value_t.shape()),
                errors::InvalidArgument("All padding values must be scalars"));
    OP_REQUIRES(ctx, padding_value_t.dtype() == input->output_dtypes()[i],
                errors::InvalidArgument(
                    "Mismatched type between padding value ", i,
                    " and input dataset's component ", i, ": ",
                    DataTypeString(padding_value_t.dtype()), " vs. ",
                    DataTypeString(input->output_dtypes()[i])));
    padding_values.push_back(tensor::DeepCopy(padding_value_t));
  }

  *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_,
                        std::move(padded_shapes), std::move(padding_values),
                        input, op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("PaddedBatchDataset").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow