xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/window_dataset.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/window_dataset.h"
16 
17 #include <string>
18 #include <utility>
19 
20 #include "tensorflow/core/data/name_utils.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/graph/graph.h"
23 #include "tensorflow/core/platform/errors.h"
24 
25 namespace tensorflow {
26 namespace data {
27 namespace {
28 
// Op type and dataset name handed to `DatasetContext` by `Window`, and the
// name used for its iterator prefix and debug string.
constexpr char kWindowOp[] = "WindowOp";
constexpr char kWindow[] = "Window";

// Input-list and attribute names read by the `WindowOp` kernel.
constexpr char kInputs[] = "inputs";
constexpr char kOutputTypes[] = "output_types";
constexpr char kOutputShapes[] = "output_shapes";

// Key (passed through `full_name()`) under which a `Window` iterator saves and
// restores its current position.
constexpr char kCurIndex[] = "i";
35 
36 class Window : public DatasetBase {
37  public:
Window(std::vector<std::vector<Tensor>> elements,DataTypeVector output_types,std::vector<PartialTensorShape> output_shapes)38   Window(std::vector<std::vector<Tensor>> elements, DataTypeVector output_types,
39          std::vector<PartialTensorShape> output_shapes)
40       : DatasetBase(DatasetContext({kWindowOp, kWindow})),
41         elements_(std::move(elements)),
42         output_types_(std::move(output_types)),
43         output_shapes_(std::move(output_shapes)) {}
44 
MakeIteratorInternal(const string & prefix) const45   std::unique_ptr<IteratorBase> MakeIteratorInternal(
46       const string& prefix) const override {
47     return std::make_unique<Iterator>(
48         Iterator::Params{this, name_utils::IteratorPrefix(kWindow, prefix)});
49   }
50 
output_dtypes() const51   const DataTypeVector& output_dtypes() const override { return output_types_; }
52 
output_shapes() const53   const std::vector<PartialTensorShape>& output_shapes() const override {
54     return output_shapes_;
55   }
56 
AllocatedBytes() const57   int64_t AllocatedBytes() const override {
58     int64_t allocated_bytes = 0;
59     for (auto& element : elements_) {
60       allocated_bytes += GetAllocatedBytes(element);
61     }
62     return allocated_bytes;
63   }
64 
TotalBytes() const65   int64_t TotalBytes() const override {
66     int64_t total_bytes = 0;
67     for (auto& element : elements_) {
68       total_bytes += GetTotalBytes(element);
69     }
70     return total_bytes;
71   }
72 
CardinalityInternal() const73   int64_t CardinalityInternal() const override { return elements_.size(); }
74 
DebugString() const75   string DebugString() const override { return kWindow; }
76 
InputDatasets(std::vector<const DatasetBase * > * inputs) const77   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
78     return OkStatus();
79   }
80 
CheckExternalState() const81   Status CheckExternalState() const override { return OkStatus(); }
82 
83  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const84   Status AsGraphDefInternal(SerializationContext* ctx,
85                             DatasetGraphDefBuilder* b,
86                             Node** output) const override {
87     if (ctx->is_graph_rewrite()) {
88       // If data tensors are not to be serialized (e.g. when the serialization
89       // is done for the sake of graph optimizations), we return
90       // `errors::Unimplemented` to short-circuit the computation.
91       return errors::Unimplemented(DebugString(),
92                                    " does not support serialization");
93     }
94     std::vector<Node*> input_nodes;
95     for (const auto& element : elements_) {
96       for (const auto& t : element) {
97         Node* node;
98         TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
99         input_nodes.emplace_back(node);
100       }
101     }
102     TF_RETURN_IF_ERROR(
103         b->AddDataset(this, {}, {std::make_pair(0, input_nodes)}, {}, output));
104     return OkStatus();
105   }
106 
107  private:
108   class Iterator : public DatasetIterator<Window> {
109    public:
Iterator(const Params & params)110     explicit Iterator(const Params& params) : DatasetIterator<Window>(params) {}
111 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)112     Status GetNextInternal(IteratorContext* ctx,
113                            std::vector<Tensor>* out_tensors,
114                            bool* end_of_sequence) override {
115       mutex_lock l(mu_);
116       if (i_ == dataset()->elements_.size()) {
117         *end_of_sequence = true;
118       } else {
119         *end_of_sequence = false;
120         *out_tensors = dataset()->elements_[i_++];
121       }
122       return OkStatus();
123     }
124 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)125     Status SaveInternal(SerializationContext* ctx,
126                         IteratorStateWriter* writer) override {
127       mutex_lock l(mu_);
128       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
129       return OkStatus();
130     }
131 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)132     Status RestoreInternal(IteratorContext* ctx,
133                            IteratorStateReader* reader) override {
134       mutex_lock l(mu_);
135       int64_t i;
136       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &i));
137       i_ = size_t(i);
138       return OkStatus();
139     }
140 
141     mutex mu_;
142     size_t i_ TF_GUARDED_BY(mu_) = 0;
143   };
144 
145   const std::vector<std::vector<Tensor>> elements_;
146   const DataTypeVector output_types_;
147   const std::vector<PartialTensorShape> output_shapes_;
148 };
149 
150 class WindowOp : public DatasetOpKernel {
151  public:
WindowOp(OpKernelConstruction * ctx)152   explicit WindowOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
153     OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
154     OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
155   }
156 
157  protected:
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)158   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
159     OpInputList inputs;
160     OP_REQUIRES_OK(ctx, ctx->input_list(kInputs, &inputs));
161     auto element_size = output_shapes_.size();
162     auto num_elements = ctx->num_inputs() / element_size;
163     std::vector<std::vector<Tensor>> elements;
164     for (size_t i = 0; i < num_elements; ++i) {
165       std::vector<Tensor> element;
166       for (size_t j = 0; j < element_size; ++j) {
167         element.push_back(std::move(inputs[i * element_size + j]));
168       }
169       elements.push_back(std::move(element));
170     }
171     *output = new Window(std::move(elements), output_types_, output_shapes_);
172   }
173 
174  private:
175   DataTypeVector output_types_;
176   std::vector<PartialTensorShape> output_shapes_;
177 };
178 
179 REGISTER_KERNEL_BUILDER(Name("WindowOp").Device(DEVICE_CPU), WindowOp);
180 
181 }  // namespace
182 
NewWindow(std::vector<std::vector<Tensor>> elements,DataTypeVector output_types,std::vector<PartialTensorShape> output_shapes,DatasetBase ** out_dataset)183 Status NewWindow(std::vector<std::vector<Tensor>> elements,
184                  DataTypeVector output_types,
185                  std::vector<PartialTensorShape> output_shapes,
186                  DatasetBase** out_dataset) {
187   // TODO(mrry): If this becomes more public, we must validate that
188   // the elements match the output_types and output_shapes.
189   *out_dataset = new Window(std::move(elements), std::move(output_types),
190                             std::move(output_shapes));
191   (*out_dataset)->Initialize(/*metadata=*/{});
192   return OkStatus();
193 }
194 
195 }  // namespace data
196 }  // namespace tensorflow
197