/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/window_dataset_op.h"

#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/window_dataset.h"
#include "tensorflow/core/platform/stringprintf.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
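//
// Illustrative example: for an input dataset {0, ..., 9} with
// window_size = 3, window_shift = 2, window_stride = 1, and
// drop_remainder = false, the op produces the windows {0, 1, 2}, {2, 3, 4},
// {4, 5, 6}, {6, 7, 8}, and the partial window {8, 9}; with
// drop_remainder = true the final partial window is omitted.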

/* static */ constexpr const char* const WindowDatasetOp::kDatasetType;
/* static */ constexpr const char* const WindowDatasetOp::kInputDataset;
/* static */ constexpr const char* const WindowDatasetOp::kSize;
/* static */ constexpr const char* const WindowDatasetOp::kShift;
/* static */ constexpr const char* const WindowDatasetOp::kStride;
/* static */ constexpr const char* const WindowDatasetOp::kDropRemainder;
/* static */ constexpr const char* const WindowDatasetOp::kOutputTypes;
/* static */ constexpr const char* const WindowDatasetOp::kOutputShapes;

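// Keys and key suffixes used when checkpointing the iterator's buffered
// elements and their statuses in SaveInternal()/RestoreInternal().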
constexpr char kInputImplEmpty[] = "input_impl_empty";
constexpr char kBufferSize[] = "buffer_size";
constexpr char kBuffer[] = "buffer";
constexpr char kSizeSuffix[] = ".size";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessage[] = ".error_message";

class WindowDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t window_size,
          int64_t window_shift, int64_t window_stride, bool drop_remainder)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        window_size_(window_size),
        window_shift_(window_shift),
        window_stride_(window_stride),
        drop_remainder_(drop_remainder),
        output_dtypes_(input_->output_dtypes().size(), {DT_VARIANT}),
        output_shapes_(input_->output_shapes().size(), TensorShape({})),
        traceme_metadata_(
            {{"window_size",
              strings::Printf("%lld", static_cast<long long>(window_size))},
             {"window_shift",
              strings::Printf("%lld", static_cast<long long>(window_shift))},
             {"window_stride", strings::Printf("%lld", static_cast<long long>(
                                                   window_stride))}}) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override {
    return output_dtypes_;
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.set_args(window_size_, window_shift_, window_stride_,
                    drop_remainder_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64_t CardinalityInternal() const override {
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
    int64_t cardinality = 0;
    if (drop_remainder_) {
      // Compute rest_elements, the number of input elements remaining after
      // the last element of the initial (full) window. If it is negative, no
      // full window fits and the cardinality is 0. Otherwise, the cardinality
      // is one for the initial window plus one for every full shift that fits
      // within rest_elements.
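      // For example, with n = 10, window_size = 3, window_stride = 2, and
      // window_shift = 2, the initial window spans (3 - 1) * 2 + 1 = 5
      // elements, so rest_elements = 5 and the cardinality is 5 / 2 + 1 = 3.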
      int64_t rest_elements = n - ((window_size_ - 1) * window_stride_ + 1);
      cardinality = rest_elements < 0 ? 0 : rest_elements / window_shift_ + 1;
    } else {
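      // Without drop_remainder, every shift position that starts within the
      // input yields a (possibly partial) window, i.e. ceil(n / window_shift)
      // windows.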
      cardinality = n / window_shift_ + (n % window_shift_ == 0 ? 0 : 1);
    }
    return cardinality;
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* window_size_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node));
    Node* window_shift_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node));
    Node* window_stride_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node));
    Node* drop_remainder_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
    TF_RETURN_IF_ERROR(
        b->AddDataset(this,
                      {input_graph_node, window_size_node, window_shift_node,
                       window_stride_node, drop_remainder_node},
                      output));
    return OkStatus();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      const int64_t window_size = dataset()->window_size_;
      const int64_t window_shift = dataset()->window_shift_;
      const int64_t window_stride = dataset()->window_stride_;
      std::vector<std::vector<Tensor>> window_elements;
      Status status = OkStatus();
      {
        const size_t target_size = TargetBufferSize(window_size, window_stride);

        mutex_lock l(mu_);
        if (!input_impl_ &&
            (buffer_.empty() ||
             (dataset()->drop_remainder_ && buffer_.size() < target_size))) {
          *end_of_sequence = true;
          return OkStatus();
        }

        // Add elements to the buffer.
        if (input_impl_) {
          *end_of_sequence = false;
          for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence;
               ++i) {
            std::vector<Tensor> element;
            Status status =
                input_impl_->GetNext(ctx, &element, end_of_sequence);
            if (!*end_of_sequence) {
              RecordBufferEnqueue(ctx, element);
              buffer_.emplace_back(std::move(element), status);
            } else {
              input_impl_.reset();
            }
          }
        }

        // If there are not enough elements and `drop_remainder` is set, we do
        // not wish to return a smaller window.
        if (buffer_.empty() ||
            (dataset()->drop_remainder_ && buffer_.size() < target_size)) {
          DCHECK(*end_of_sequence);
          return OkStatus();
        }

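        // Assemble the next window from every window_stride-th element
        // currently in the buffer.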
        int num_elements = 1 + (buffer_.size() - 1) / window_stride;
        window_elements.reserve(num_elements);
        for (size_t i = 0; i < num_elements; ++i) {
          status.Update(buffer_[window_stride * i].status);
          if (!status.ok()) {
            break;
          }
          window_elements.emplace_back(buffer_[window_stride * i].result);
        }

        // Shift the window, discarding elements if necessary.
        int buffer_size = buffer_.size();
        if (window_shift >= buffer_size) {
          for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) {
            bool end_of_input;
            std::vector<Tensor> element;
            // Ignore non-error status of discarded elements.
            input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError();
            if (end_of_input) {
              input_impl_.reset();
            }
          }
          for (size_t i = 0; i < buffer_.size(); ++i) {
            RecordBufferDequeue(ctx, buffer_.at(i).result);
          }
          buffer_.clear();
        } else {
          for (size_t i = 0; i < window_shift; ++i) {
            RecordBufferDequeue(ctx, buffer_.at(i).result);
          }
          buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
        }
      }

      if (!status.ok()) {
        return status;
      }

      // Construct output tensors.
      const size_t num_tuple_components = window_elements[0].size();
      const int64_t num_window_elements = window_elements.size();
      *end_of_sequence = false;
      for (size_t idx = 0; idx < num_tuple_components; ++idx) {
        DatasetBase* window_dataset;
        std::vector<std::vector<Tensor>> window_component_elements;
        window_component_elements.reserve(num_window_elements);
        // Build the output tuple component by copying one slice
        // from each input element in the window.
        for (size_t i = 0; i < num_window_elements; ++i) {
          std::vector<Tensor> component_element;
          component_element.push_back(std::move(window_elements[i][idx]));
          window_component_elements.push_back(component_element);
        }
        DataTypeVector output_types({dataset()->input_->output_dtypes()[idx]});
        std::vector<PartialTensorShape> output_shapes(
            {dataset()->input_->output_shapes()[idx]});
        TF_RETURN_IF_ERROR(NewWindow(window_component_elements, output_types,
                                     output_shapes, &window_dataset));
        out_tensors->emplace_back(DT_VARIANT, TensorShape({}));
        TF_RETURN_IF_ERROR(
            StoreDatasetInVariantTensor(window_dataset, &out_tensors->back()));
      }
      return OkStatus();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       dataset()->window_shift_);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (!input_impl_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
      } else {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      }
      // Save buffer.
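      // Buffered element i is written under full_name() keys of the form
      // "buffer[i].size" (number of component tensors) and "buffer[i][j]"
      // (the j-th component); WriteStatusLocked() additionally records
      // "buffer[i].code" and, for non-OK statuses, "buffer[i].error_message".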
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name(kBufferSize), buffer_.size()));
      for (int64_t i = 0; i < buffer_.size(); i++) {
        TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status));
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)),
            buffer_[i].result.size()));
        for (int64_t j = 0; j < buffer_[i].result.size(); j++) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")),
              buffer_[i].result[j]));
        }
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (!reader->Contains(full_name(kInputImplEmpty))) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      } else {
        input_impl_.reset();
      }
      // Restore buffer.
      int64_t buffer_size = 0;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kBufferSize), &buffer_size));
      buffer_.resize(buffer_size);
      for (int64_t i = 0; i < buffer_size; i++) {
        int64_t vector_size;
        TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status));
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)),
            &vector_size));
        buffer_[i].result.resize(vector_size);
        for (int64_t j = 0; j < vector_size; j++) {
          TF_RETURN_IF_ERROR(reader->ReadTensor(
              ctx->flr(),
              full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")),
              &buffer_[i].result[j]));
        }
      }
      return OkStatus();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    struct InvocationResult {
      InvocationResult() = default;
      InvocationResult(std::vector<Tensor>&& result, const Status& status)
          : result(result), status(status) {}

      std::vector<Tensor> result;
      Status status;
    };

    Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
                             const Status& status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          CodeKey(index), static_cast<int64_t>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
                                               status.error_message()));
      }
      return OkStatus();
    }

    Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
                            Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      int64_t code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(ErrorMessageKey(index), &error_message));
        *status = Status(code, error_message);
      } else {
        *status = OkStatus();
      }
      return OkStatus();
    }

    string CodeKey(size_t index) {
      return full_name(strings::StrCat(kBuffer, "[", index, "]", kCodeSuffix));
    }

    string ErrorMessageKey(size_t index) {
      return full_name(
          strings::StrCat(kBuffer, "[", index, "]", kErrorMessage));
    }

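    // The buffer must hold every element the next window will read, i.e. the
    // elements at offsets 0, window_stride, ..., (window_size - 1) *
    // window_stride from its front, so the target size is
    // (window_size - 1) * window_stride + 1.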
    size_t TargetBufferSize(int64_t window_size, int64_t window_stride) {
      return (window_size - 1) * window_stride + 1;
    }

    mutex mu_;
    std::deque<InvocationResult> buffer_ TF_GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const DatasetBase* const input_;
  const int64_t window_size_;
  const int64_t window_shift_;
  const int64_t window_stride_;
  const bool drop_remainder_;
  const DataTypeVector output_dtypes_;
  const std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};

WindowDatasetOp::WindowDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {}

void WindowDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                  DatasetBase** output) {
  int64_t window_size = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSize, &window_size));
  OP_REQUIRES(
      ctx, window_size > 0,
      errors::InvalidArgument("Window size must be greater than zero."));

  int64_t window_shift = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kShift, &window_shift));
  OP_REQUIRES(
      ctx, window_shift > 0,
      errors::InvalidArgument("Window shift must be greater than zero."));

  int64_t window_stride = 0;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64_t>(ctx, kStride, &window_stride));
  OP_REQUIRES(
      ctx, window_stride > 0,
      errors::InvalidArgument("Window stride must be greater than zero."));

  bool drop_remainder;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));

  *output = new Dataset(ctx, input, window_size, window_shift, window_stride,
                        drop_remainder);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
                        WindowDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow