/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/window_dataset.h"

#include <string>
#include <utility>

#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kInputs[] = "inputs";
constexpr char kOutputTypes[] = "output_types";
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kWindow[] = "Window";
constexpr char kWindowOp[] = "WindowOp";
constexpr char kCurIndex[] = "i";

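// A dataset backed by a fixed, in-memory vector of elements. Each element is
// itself a vector of tensors, and the dataset yields the elements in order.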
class Window : public DatasetBase {
 public:
  Window(std::vector<std::vector<Tensor>> elements, DataTypeVector output_types,
         std::vector<PartialTensorShape> output_shapes)
      : DatasetBase(DatasetContext({kWindowOp, kWindow})),
        elements_(std::move(elements)),
        output_types_(std::move(output_types)),
        output_shapes_(std::move(output_shapes)) {}

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(
        Iterator::Params{this, name_utils::IteratorPrefix(kWindow, prefix)});
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  int64_t AllocatedBytes() const override {
    int64_t allocated_bytes = 0;
    for (auto& element : elements_) {
      allocated_bytes += GetAllocatedBytes(element);
    }
    return allocated_bytes;
  }

  int64_t TotalBytes() const override {
    int64_t total_bytes = 0;
    for (auto& element : elements_) {
      total_bytes += GetTotalBytes(element);
    }
    return total_bytes;
  }

  int64_t CardinalityInternal() const override { return elements_.size(); }

  string DebugString() const override { return kWindow; }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    return OkStatus();
  }

  Status CheckExternalState() const override { return OkStatus(); }

 protected:
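  // Serializes the dataset by adding every tensor of every element as an
  // input node of the WindowOp; graph rewrites skip this and receive
  // `errors::Unimplemented` instead.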
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    if (ctx->is_graph_rewrite()) {
      // If data tensors are not to be serialized (e.g. when the serialization
      // is done for the sake of graph optimizations), we return
      // `errors::Unimplemented` to short-circuit the computation.
      return errors::Unimplemented(DebugString(),
                                   " does not support serialization");
    }
    std::vector<Node*> input_nodes;
    for (const auto& element : elements_) {
      for (const auto& t : element) {
        Node* node;
        TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
        input_nodes.emplace_back(node);
      }
    }
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {}, {std::make_pair(0, input_nodes)}, {}, output));
    return OkStatus();
  }

 private:
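  // Iterates over the dataset's in-memory elements in order, tracking the
  // current position in `i_` so that it can be saved and restored.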
  class Iterator : public DatasetIterator<Window> {
   public:
    explicit Iterator(const Params& params) : DatasetIterator<Window>(params) {}

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      if (i_ == dataset()->elements_.size()) {
        *end_of_sequence = true;
      } else {
        *end_of_sequence = false;
        *out_tensors = dataset()->elements_[i_++];
      }
      return OkStatus();
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      int64_t i;
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &i));
      i_ = size_t(i);
      return OkStatus();
    }

    mutex mu_;
    size_t i_ TF_GUARDED_BY(mu_) = 0;
  };

  const std::vector<std::vector<Tensor>> elements_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
};

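// Kernel that materializes its input tensors into a `Window` dataset.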
class WindowOp : public DatasetOpKernel {
 public:
  explicit WindowOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

 protected:
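  // Groups the flat `inputs` list into consecutive elements of
  // `output_shapes_.size()` tensors each and wraps them in a `Window` dataset.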
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    OpInputList inputs;
    OP_REQUIRES_OK(ctx, ctx->input_list(kInputs, &inputs));
    auto element_size = output_shapes_.size();
    auto num_elements = ctx->num_inputs() / element_size;
    std::vector<std::vector<Tensor>> elements;
    for (size_t i = 0; i < num_elements; ++i) {
      std::vector<Tensor> element;
      for (size_t j = 0; j < element_size; ++j) {
        element.push_back(std::move(inputs[i * element_size + j]));
      }
      elements.push_back(std::move(element));
    }
    *output = new Window(std::move(elements), output_types_, output_shapes_);
  }

 private:
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

REGISTER_KERNEL_BUILDER(Name("WindowOp").Device(DEVICE_CPU), WindowOp);

}  // namespace

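// Creates a `Window` dataset that takes ownership of `elements` and returns
// it through `out_dataset`.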
Status NewWindow(std::vector<std::vector<Tensor>> elements,
                 DataTypeVector output_types,
                 std::vector<PartialTensorShape> output_shapes,
                 DatasetBase** out_dataset) {
  // TODO(mrry): If this becomes more public, we must validate that
  // the elements match the output_types and output_shapes.
  *out_dataset = new Window(std::move(elements), std::move(output_types),
                            std::move(output_shapes));
  (*out_dataset)->Initialize(/*metadata=*/{});
  return OkStatus();
}

}  // namespace data
}  // namespace tensorflow