xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/window_dataset_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/window_dataset_op.h"
16 
17 #include "tensorflow/core/data/name_utils.h"
18 #include "tensorflow/core/framework/dataset.h"
19 #include "tensorflow/core/kernels/data/window_dataset.h"
20 #include "tensorflow/core/platform/stringprintf.h"
21 
22 namespace tensorflow {
23 namespace data {
24 
25 // See documentation in ../../ops/dataset_ops.cc for a high-level
26 // description of the following op.
27 
// Out-of-line definitions of the op's static constexpr string constants.
// Required for ODR-use under pre-C++17 rules (the header only declares them).
/* static */ constexpr const char* const WindowDatasetOp::kDatasetType;
/* static */ constexpr const char* const WindowDatasetOp::kInputDataset;
/* static */ constexpr const char* const WindowDatasetOp::kSize;
/* static */ constexpr const char* const WindowDatasetOp::kShift;
/* static */ constexpr const char* const WindowDatasetOp::kStride;
/* static */ constexpr const char* const WindowDatasetOp::kDropRemainder;
/* static */ constexpr const char* const WindowDatasetOp::kOutputTypes;
/* static */ constexpr const char* const WindowDatasetOp::kOutputShapes;

// Keys used by the iterator's SaveInternal/RestoreInternal checkpoint logic.
// Changing any of these would break restoring from existing checkpoints.
constexpr char kInputImplEmpty[] = "input_impl_empty";
constexpr char kBufferSize[] = "buffer_size";
constexpr char kBuffer[] = "buffer";
constexpr char kSizeSuffix[] = ".size";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessage[] = ".error_message";
43 
44 class WindowDatasetOp::Dataset : public DatasetBase {
45  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64_t window_size,int64_t window_shift,int64_t window_stride,bool drop_remainder)46   Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t window_size,
47           int64_t window_shift, int64_t window_stride, bool drop_remainder)
48       : DatasetBase(DatasetContext(ctx)),
49         input_(input),
50         window_size_(window_size),
51         window_shift_(window_shift),
52         window_stride_(window_stride),
53         drop_remainder_(drop_remainder),
54         output_dtypes_(input_->output_dtypes().size(), {DT_VARIANT}),
55         output_shapes_(input_->output_shapes().size(), TensorShape({})),
56         traceme_metadata_(
57             {{"window_size",
58               strings::Printf("%lld", static_cast<long long>(window_size))},
59              {"window_shift",
60               strings::Printf("%lld", static_cast<long long>(window_shift))},
61              {"window_stride", strings::Printf("%lld", static_cast<long long>(
62                                                            window_stride))}}) {
63     input_->Ref();
64   }
65 
~Dataset()66   ~Dataset() override { input_->Unref(); }
67 
MakeIteratorInternal(const string & prefix) const68   std::unique_ptr<IteratorBase> MakeIteratorInternal(
69       const string& prefix) const override {
70     return std::make_unique<Iterator>(Iterator::Params{
71         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
72   }
73 
output_dtypes() const74   const DataTypeVector& output_dtypes() const override {
75     return output_dtypes_;
76   }
77 
output_shapes() const78   const std::vector<PartialTensorShape>& output_shapes() const override {
79     return output_shapes_;
80   }
81 
DebugString() const82   string DebugString() const override {
83     name_utils::DatasetDebugStringParams params;
84     params.set_args(window_size_, window_shift_, window_stride_,
85                     drop_remainder_);
86     return name_utils::DatasetDebugString(kDatasetType, params);
87   }
88 
CardinalityInternal() const89   int64_t CardinalityInternal() const override {
90     int64_t n = input_->Cardinality();
91     if (n == kInfiniteCardinality || n == kUnknownCardinality) {
92       return n;
93     }
94     int64_t cardinality = 0;
95     if (drop_remainder_) {
96       // Compute rest_elements, the number of elements after the last element
97       // of the initial window. If it is negative, we know that the
98       // cardinality is 0. Otherwise, it will be the number of valid shifts
99       // over the rest_elements.
100       int64_t rest_elements = n - ((window_size_ - 1) * window_stride_ + 1);
101       cardinality = rest_elements < 0 ? 0 : rest_elements / window_shift_ + 1;
102     } else {
103       cardinality = n / window_shift_ + (n % window_shift_ == 0 ? 0 : 1);
104     }
105     return cardinality;
106   }
107 
InputDatasets(std::vector<const DatasetBase * > * inputs) const108   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
109     inputs->push_back(input_);
110     return OkStatus();
111   }
112 
CheckExternalState() const113   Status CheckExternalState() const override {
114     return input_->CheckExternalState();
115   }
116 
117  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const118   Status AsGraphDefInternal(SerializationContext* ctx,
119                             DatasetGraphDefBuilder* b,
120                             Node** output) const override {
121     Node* input_graph_node = nullptr;
122     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
123     Node* window_size_node = nullptr;
124     TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node));
125     Node* window_shift_node = nullptr;
126     TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node));
127     Node* window_stride_node = nullptr;
128     TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node));
129     Node* drop_remainder_node = nullptr;
130     TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
131     TF_RETURN_IF_ERROR(
132         b->AddDataset(this,
133                       {input_graph_node, window_size_node, window_shift_node,
134                        window_stride_node, drop_remainder_node},
135                       output));
136     return OkStatus();
137   }
138 
139  private:
140   class Iterator : public DatasetIterator<Dataset> {
141    public:
Iterator(const Params & params)142     explicit Iterator(const Params& params)
143         : DatasetIterator<Dataset>(params) {}
144 
Initialize(IteratorContext * ctx)145     Status Initialize(IteratorContext* ctx) override {
146       return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
147     }
148 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)149     Status GetNextInternal(IteratorContext* ctx,
150                            std::vector<Tensor>* out_tensors,
151                            bool* end_of_sequence) override {
152       const int64_t window_size = dataset()->window_size_;
153       const int64_t window_shift = dataset()->window_shift_;
154       const int64_t window_stride = dataset()->window_stride_;
155       std::vector<std::vector<Tensor>> window_elements;
156       Status status = OkStatus();
157       {
158         const size_t target_size = TargetBufferSize(window_size, window_stride);
159 
160         mutex_lock l(mu_);
161         if (!input_impl_ &&
162             (buffer_.empty() ||
163              (dataset()->drop_remainder_ && buffer_.size() < target_size))) {
164           *end_of_sequence = true;
165           return OkStatus();
166         }
167 
168         // Add elements to the buffer.
169         if (input_impl_) {
170           *end_of_sequence = false;
171           for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence;
172                ++i) {
173             std::vector<Tensor> element;
174             Status status =
175                 input_impl_->GetNext(ctx, &element, end_of_sequence);
176             if (!*end_of_sequence) {
177               RecordBufferEnqueue(ctx, element);
178               buffer_.emplace_back(std::move(element), status);
179             } else {
180               input_impl_.reset();
181             }
182           }
183         }
184 
185         // If there are not enough elements and `drop_remainder` is set, we do
186         // not wish to return a smaller window.
187         if (buffer_.empty() ||
188             (dataset()->drop_remainder_ && buffer_.size() < target_size)) {
189           DCHECK(*end_of_sequence);
190           return OkStatus();
191         }
192 
193         int num_elements = 1 + (buffer_.size() - 1) / window_stride;
194         window_elements.reserve(num_elements);
195         for (size_t i = 0; i < num_elements; ++i) {
196           status.Update(buffer_[window_stride * i].status);
197           if (!status.ok()) {
198             break;
199           }
200           window_elements.emplace_back(buffer_[window_stride * i].result);
201         }
202 
203         // Shift the window, discarding elements if necessary.
204         int buffer_size = buffer_.size();
205         if (window_shift >= buffer_size) {
206           for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) {
207             bool end_of_input;
208             std::vector<Tensor> element;
209             // Ignore non-error status of discarded elements.
210             input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError();
211             if (end_of_input) {
212               input_impl_.reset();
213             }
214           }
215           for (size_t i = 0; i < buffer_.size(); ++i) {
216             RecordBufferDequeue(ctx, buffer_.at(i).result);
217           }
218           buffer_.clear();
219         } else {
220           for (size_t i = 0; i < window_shift; ++i) {
221             RecordBufferDequeue(ctx, buffer_.at(i).result);
222           }
223           buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
224         }
225       }
226 
227       if (!status.ok()) {
228         return status;
229       }
230 
231       // Construct output tensors.
232       const size_t num_tuple_components = window_elements[0].size();
233       const int64_t num_window_elements = window_elements.size();
234       *end_of_sequence = false;
235       for (size_t idx = 0; idx < num_tuple_components; ++idx) {
236         DatasetBase* window_dataset;
237         std::vector<std::vector<Tensor>> window_component_elements;
238         window_component_elements.reserve(num_window_elements);
239         // Build the output tuple component by copying one slice
240         // from each input element in the window.
241         for (size_t i = 0; i < num_window_elements; ++i) {
242           std::vector<Tensor> component_element;
243           component_element.push_back(std::move(window_elements[i][idx]));
244           window_component_elements.push_back(component_element);
245         }
246         DataTypeVector output_types({dataset()->input_->output_dtypes()[idx]});
247         std::vector<PartialTensorShape> output_shapes(
248             {dataset()->input_->output_shapes()[idx]});
249         TF_RETURN_IF_ERROR(NewWindow(window_component_elements, output_types,
250                                      output_shapes, &window_dataset));
251         out_tensors->emplace_back(DT_VARIANT, TensorShape({}));
252         TF_RETURN_IF_ERROR(
253             StoreDatasetInVariantTensor(window_dataset, &out_tensors->back()));
254       }
255       return OkStatus();
256     }
257 
258    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const259     std::shared_ptr<model::Node> CreateNode(
260         IteratorContext* ctx, model::Node::Args args) const override {
261       return model::MakeKnownRatioNode(std::move(args),
262                                        dataset()->window_shift_);
263     }
264 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)265     Status SaveInternal(SerializationContext* ctx,
266                         IteratorStateWriter* writer) override {
267       mutex_lock l(mu_);
268       if (!input_impl_) {
269         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
270       } else {
271         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
272       }
273       // Save buffer.
274       TF_RETURN_IF_ERROR(
275           writer->WriteScalar(full_name(kBufferSize), buffer_.size()));
276       for (int64_t i = 0; i < buffer_.size(); i++) {
277         TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status));
278         TF_RETURN_IF_ERROR(writer->WriteScalar(
279             full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)),
280             buffer_[i].result.size()));
281         for (int64_t j = 0; j < buffer_[i].result.size(); j++) {
282           TF_RETURN_IF_ERROR(writer->WriteTensor(
283               full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")),
284               buffer_[i].result[j]));
285         }
286       }
287       return OkStatus();
288     }
289 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)290     Status RestoreInternal(IteratorContext* ctx,
291                            IteratorStateReader* reader) override {
292       mutex_lock l(mu_);
293       if (!reader->Contains(full_name(kInputImplEmpty))) {
294         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
295       } else {
296         input_impl_.reset();
297       }
298       // Restore buffer.
299       int64_t buffer_size = 0;
300       TF_RETURN_IF_ERROR(
301           reader->ReadScalar(full_name(kBufferSize), &buffer_size));
302       buffer_.resize(buffer_size);
303       for (int64_t i = 0; i < buffer_size; i++) {
304         int64_t vector_size;
305         TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status));
306         TF_RETURN_IF_ERROR(reader->ReadScalar(
307             full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)),
308             &vector_size));
309         buffer_[i].result.resize(vector_size);
310         for (int64_t j = 0; j < vector_size; j++) {
311           TF_RETURN_IF_ERROR(reader->ReadTensor(
312               ctx->flr(),
313               full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")),
314               &buffer_[i].result[j]));
315         }
316       }
317       return OkStatus();
318     }
319 
GetTraceMeMetadata() const320     TraceMeMetadata GetTraceMeMetadata() const override {
321       return dataset()->traceme_metadata_;
322     }
323 
324    private:
325     struct InvocationResult {
326       InvocationResult() = default;
InvocationResulttensorflow::data::WindowDatasetOp::Dataset::Iterator::InvocationResult327       InvocationResult(std::vector<Tensor>&& result, const Status& status)
328           : result(result), status(status) {}
329 
330       std::vector<Tensor> result;
331       Status status;
332     };
333 
WriteStatusLocked(IteratorStateWriter * writer,size_t index,const Status & status)334     Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
335                              const Status& status)
336         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
337       TF_RETURN_IF_ERROR(writer->WriteScalar(
338           CodeKey(index), static_cast<int64_t>(status.code())));
339       if (!status.ok()) {
340         TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
341                                                status.error_message()));
342       }
343       return OkStatus();
344     }
345 
ReadStatusLocked(IteratorStateReader * reader,size_t index,Status * status)346     Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
347                             Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
348       int64_t code_int;
349       TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
350       error::Code code = static_cast<error::Code>(code_int);
351 
352       if (code != error::Code::OK) {
353         tstring error_message;
354         TF_RETURN_IF_ERROR(
355             reader->ReadScalar(ErrorMessageKey(index), &error_message));
356         *status = Status(code, error_message);
357       } else {
358         *status = OkStatus();
359       }
360       return OkStatus();
361     }
362 
CodeKey(size_t index)363     string CodeKey(size_t index) {
364       return full_name(strings::StrCat(kBuffer, "[", index, "]", kCodeSuffix));
365     }
366 
ErrorMessageKey(size_t index)367     string ErrorMessageKey(size_t index) {
368       return full_name(
369           strings::StrCat(kBuffer, "[", index, "]", kErrorMessage));
370     }
371 
TargetBufferSize(int64_t window_size,int64_t window_stride)372     size_t TargetBufferSize(int64_t window_size, int64_t window_stride) {
373       return (window_size - 1) * window_stride + 1;
374     }
375 
376     mutex mu_;
377     std::deque<InvocationResult> buffer_ TF_GUARDED_BY(mu_);
378     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
379   };
380 
381   const DatasetBase* const input_;
382   const int64_t window_size_;
383   const int64_t window_shift_;
384   const int64_t window_stride_;
385   const bool drop_remainder_;
386   const DataTypeVector output_dtypes_;
387   const std::vector<PartialTensorShape> output_shapes_;
388   const TraceMeMetadata traceme_metadata_;
389 };
390 
// The op has no attributes to parse; all configuration arrives as runtime
// scalar inputs, handled in MakeDataset.
WindowDatasetOp::WindowDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {}
393 
// Parses and validates the four scalar inputs (size, shift, stride,
// drop_remainder), then constructs the Dataset. On validation failure the
// OP_REQUIRES* macros record the error on `ctx` and return early, leaving
// `*output` untouched.
void WindowDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                  DatasetBase** output) {
  int64_t window_size = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSize, &window_size));
  OP_REQUIRES(
      ctx, window_size > 0,
      errors::InvalidArgument("Window size must be greater than zero."));

  int64_t window_shift = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kShift, &window_shift));
  OP_REQUIRES(
      ctx, window_shift > 0,
      errors::InvalidArgument("Window shift must be greater than zero."));

  int64_t window_stride = 0;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64_t>(ctx, kStride, &window_stride));
  OP_REQUIRES(
      ctx, window_stride > 0,
      errors::InvalidArgument("Window stride must be greater than zero."));

  bool drop_remainder;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));

  // Caller (UnaryDatasetOpKernel) takes ownership of the new dataset.
  *output = new Dataset(ctx, input, window_size, window_shift, window_stride,
                        drop_remainder);
}
422 
namespace {
// Registers the CPU kernel for the "WindowDataset" op declared in
// ../../ops/dataset_ops.cc.
REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
                        WindowDatasetOp);
}  // namespace
427 }  // namespace data
428 }  // namespace tensorflow
429