xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/iterator_ops.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
18 
19 #include <utility>
20 
21 #include "tensorflow/core/common_runtime/function.h"
22 #include "tensorflow/core/data/dataset_utils.h"
23 #include "tensorflow/core/data/metric_utils.h"
24 #include "tensorflow/core/data/unbounded_thread_pool.h"
25 #include "tensorflow/core/framework/dataset.h"
26 #include "tensorflow/core/framework/function_handle_cache.h"
27 #include "tensorflow/core/framework/metrics.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/tensor_shape.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/kernels/ops_util.h"
32 #include "tensorflow/core/platform/refcount.h"
33 
34 namespace tensorflow {
35 namespace data {
36 
37 class IteratorResource : public ResourceBase {
38  public:
39   IteratorResource(Env* env, const DataTypeVector& output_dtypes,
40                    const std::vector<PartialTensorShape>& output_shapes,
41                    std::unique_ptr<DeviceMgr> device_mgr,
42                    std::unique_ptr<FunctionLibraryDefinition> flib_def,
43                    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
44                    FunctionLibraryRuntime* flr);
45 
46   ~IteratorResource() override;
47 
48   // Gets the next output from the iterator managed by this iterator resource.
49   //
50   // If at least one output remains, that output will be stored in
51   // `*out_tensors` and `false` will be stored in `*end_of_sequence`.
52   //
53   // If no more outputs remain, `true` will be stored in `*end_of_sequence`, and
54   // the content of `*out_tensors` will be undefined.
55   Status GetNext(OpKernelContext* ctx, std::vector<Tensor>* out_tensors,
56                  bool* end_of_sequence);
57 
58   // Saves a checkpoint of the state of the iterator through the given `writer`.
59   Status Save(SerializationContext* ctx, IteratorStateWriter* writer);
60 
61   // Restores the state of the iterator from a checkpoint created by `Save`.
62   Status Restore(OpKernelContext* ctx, IteratorStateReader* reader);
63 
64   // Creates an iterator for `dataset`, and associates the iterator with this
65   // iterator resource.
66   //
67   // `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`,
68   // or `Restore`.
69   Status SetIteratorFromDataset(OpKernelContext* ctx,
70                                 const DatasetBase* dataset);
71 
DebugString()72   string DebugString() const override { return "Iterator resource"; }
73 
output_dtypes()74   const DataTypeVector& output_dtypes() const { return output_dtypes_; }
75 
output_shapes()76   const std::vector<PartialTensorShape>& output_shapes() const {
77     return output_shapes_;
78   }
79 
80  private:
81   class State {
82    public:
State(std::shared_ptr<FunctionLibraryDefinition> flib_def,std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,FunctionLibraryRuntime * flr,std::unique_ptr<DatasetBaseIterator> iterator)83     State(std::shared_ptr<FunctionLibraryDefinition> flib_def,
84           std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,
85           FunctionLibraryRuntime* flr,
86           std::unique_ptr<DatasetBaseIterator> iterator)
87         : flib_def_(std::move(flib_def)),
88           flr_(flr),
89           pflr_(std::move(pflr)),
90           function_handle_cache_(std::make_unique<FunctionHandleCache>(flr)),
91           iterator_(std::move(iterator)) {}
92 
~State()93     ~State() { cancellation_manager_.StartCancel(); }
94 
95     // Downcasts the given `IteratorBase` to a `DatasetBaseIterator`, and uses
96     // it to set the `iterator` and the `dataset` field.
DowncastAndSetIteratorAndDataset(std::unique_ptr<IteratorBase> it,const DatasetBase * dataset)97     void DowncastAndSetIteratorAndDataset(std::unique_ptr<IteratorBase> it,
98                                           const DatasetBase* dataset) {
99       iterator_.reset(static_cast<DatasetBaseIterator*>(it.release()));
100       if (dataset) {
101         dataset->Ref();
102         dataset_.reset(const_cast<DatasetBase*>(dataset));
103       }
104     }
105 
flib_def()106     std::shared_ptr<FunctionLibraryDefinition> flib_def() { return flib_def_; }
107 
flr()108     FunctionLibraryRuntime* flr() { return flr_; }
109 
pflr()110     std::shared_ptr<ProcessFunctionLibraryRuntime> pflr() { return pflr_; }
111 
function_handle_cache()112     FunctionHandleCache* function_handle_cache() {
113       return function_handle_cache_.get();
114     }
115 
resource_mgr()116     ResourceMgr* resource_mgr() { return &resource_mgr_; }
117 
cancellation_manager()118     CancellationManager* cancellation_manager() {
119       return &cancellation_manager_;
120     }
121 
iterator()122     DatasetBaseIterator* iterator() { return iterator_.get(); }
123 
dataset()124     DatasetBase* dataset() { return dataset_.get(); }
125 
126    private:
127     std::shared_ptr<FunctionLibraryDefinition> flib_def_;
128     FunctionLibraryRuntime* flr_ = nullptr;  // not owned
129     std::shared_ptr<ProcessFunctionLibraryRuntime> pflr_;
130     std::unique_ptr<FunctionHandleCache> function_handle_cache_;
131     ResourceMgr resource_mgr_;
132     CancellationManager cancellation_manager_;
133     std::unique_ptr<DatasetBaseIterator> iterator_;
134     core::RefCountPtr<DatasetBase> dataset_;
135   };
136 
137   IteratorMetricsCollector metrics_collector_;
138   UnboundedThreadPool unbounded_thread_pool_;
139 
140   mutex mu_;
141   const std::unique_ptr<DeviceMgr> device_mgr_ TF_GUARDED_BY(mu_);
142   std::shared_ptr<State> iterator_state_ TF_GUARDED_BY(mu_);
143   const DataTypeVector output_dtypes_;
144   const std::vector<PartialTensorShape> output_shapes_;
145 };
146 
147 class IteratorHandleOp : public OpKernel {
148  public:
149   explicit IteratorHandleOp(OpKernelConstruction* ctx);
150 
151   // The resource is deleted from the resource manager only when it is private
152   // to kernel. Ideally the resource should be deleted when it is no longer held
153   // by anyone, but it would break backward compatibility.
154   ~IteratorHandleOp() override;
155 
156   void Compute(OpKernelContext* context) override TF_LOCKS_EXCLUDED(mu_);
157 
158  private:
159   // During the first Compute(), resource is either created or looked up using
160   // shared_name. In the latter case, the resource found should be verified if
161   // it is compatible with this op's configuration. The verification may fail in
162   // cases such as two graphs asking queues of the same shared name to have
163   // inconsistent capacities.
164   Status VerifyResource(IteratorResource* resource);
165 
166   FunctionLibraryRuntime* CreatePrivateFLR(
167       OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
168       std::unique_ptr<FunctionLibraryDefinition>* flib_def,
169       std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr);
170 
171   mutex mu_;
172   ContainerInfo cinfo_;  // Written once under mu_ then constant afterwards.
173   IteratorResource* resource_ TF_GUARDED_BY(mu_) = nullptr;
174   DataTypeVector output_dtypes_;
175   std::vector<PartialTensorShape> output_shapes_;
176   const int graph_def_version_;
177   string name_;
178 };
179 
180 // Like IteratorHandleOp, but creates handles which are never shared, and does
181 // not hold a reference to these handles. The latter is important for eager
182 // execution, since OpKernel instances generally live as long as the program
183 // running them.
184 class AnonymousIteratorHandleOp : public AnonymousResourceOp<IteratorResource> {
185  public:
186   explicit AnonymousIteratorHandleOp(OpKernelConstruction* context);
187 
188  private:
189   string name() override;
190 
191   Status CreateResource(OpKernelContext* ctx,
192                         std::unique_ptr<FunctionLibraryDefinition> flib_def,
193                         std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
194                         FunctionLibraryRuntime* lib,
195                         IteratorResource** resource) override;
196 
197   DataTypeVector output_dtypes_;
198   std::vector<PartialTensorShape> output_shapes_;
199   const int graph_def_version_;
200 };
201 
202 // A hybrid asynchronous-and-synchronous OpKernel with efficient support for
203 // both modes.
204 //
205 // Inherit from this class when the application logic of the kernel (i) is
206 // implemented synchronously, (ii) must run on a background thread when the
207 // kernel executes in the inter-op threadpool (typically because it depends on
208 // inter-op threadpool threads, e.g. for function execution), and (iii) can run
209 // synchronously on the calling thread when the caller donates a thread
210 // (typically in eager execution). The implementation avoids a thread-hop in
211 // case (iii).
212 //
213 // NOTE: Unlike typical OpKernel subclasses, the application logic is
214 // implemented in a method (DoCompute()) that returns Status. Use
215 // TF_RETURN_IF_ERROR for error-related control flow rather than
216 // OP_REQUIRES_OK().
217 class HybridAsyncOpKernel : public AsyncOpKernel {
218  public:
219   HybridAsyncOpKernel(OpKernelConstruction* ctx,
220                       const char* background_worker_name);
221 
222   void Compute(OpKernelContext* ctx) final;
223   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) final;
224 
225  protected:
226   virtual Status DoCompute(OpKernelContext* ctx) = 0;
227 
228  private:
229   BackgroundWorker background_worker_;
230 };
231 
232 class MakeIteratorOp : public HybridAsyncOpKernel {
233  public:
MakeIteratorOp(OpKernelConstruction * ctx)234   explicit MakeIteratorOp(OpKernelConstruction* ctx)
235       : HybridAsyncOpKernel(ctx, "tf_data_make_iterator") {}
236 
237  protected:
238   Status DoCompute(OpKernelContext* ctx) override;
239 };
240 
241 class IteratorGetNextOp : public HybridAsyncOpKernel {
242  public:
IteratorGetNextOp(OpKernelConstruction * ctx)243   explicit IteratorGetNextOp(OpKernelConstruction* ctx)
244       : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {
245     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
246     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
247   }
248 
249   AsyncOpKernel* AsAsync() override;
250 
251  protected:
252   Status DoCompute(OpKernelContext* ctx) override;
253 
254  private:
255   DataTypeVector output_types_;
256   std::vector<PartialTensorShape> output_shapes_;
257 };
258 
259 class DeleteIteratorOp : public HybridAsyncOpKernel {
260  public:
DeleteIteratorOp(OpKernelConstruction * ctx)261   explicit DeleteIteratorOp(OpKernelConstruction* ctx)
262       : HybridAsyncOpKernel(ctx, "tf_data_delete_iterator") {}
263 
264  protected:
265   Status DoCompute(OpKernelContext* ctx) override;
266 };
267 
268 class IteratorGetNextAsOptionalOp : public HybridAsyncOpKernel {
269  public:
IteratorGetNextAsOptionalOp(OpKernelConstruction * ctx)270   explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
271       : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next_as_optional") {
272     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
273     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
274   }
275 
276  protected:
277   Status DoCompute(OpKernelContext* ctx) override;
278 
279  private:
280   DataTypeVector output_types_;
281   std::vector<PartialTensorShape> output_shapes_;
282 };
283 
284 class IteratorToStringHandleOp : public OpKernel {
285  public:
IteratorToStringHandleOp(OpKernelConstruction * ctx)286   explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
287       : OpKernel(ctx) {}
288 
289   void Compute(OpKernelContext* ctx) override;
290 };
291 
292 class IteratorFromStringHandleOp : public OpKernel {
293  public:
294   explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx);
295 
296   void Compute(OpKernelContext* ctx) override;
297 
298  private:
299   DataTypeVector output_dtypes_;
300   std::vector<PartialTensorShape> output_shapes_;
301 };
302 
303 class SerializeIteratorOp : public OpKernel {
304  public:
305   static constexpr const char* const kExternalStatePolicy =
306       "external_state_policy";
307 
308   explicit SerializeIteratorOp(OpKernelConstruction* ctx);
309 
310   void Compute(OpKernelContext* ctx) override;
311 
312  private:
313   SerializationContext::ExternalStatePolicy external_state_policy_ =
314       SerializationContext::ExternalStatePolicy::kWarn;
315 };
316 
317 class DeserializeIteratorOp : public OpKernel {
318  public:
DeserializeIteratorOp(OpKernelConstruction * ctx)319   explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
320 
321   void Compute(OpKernelContext* ctx) override;
322 };
323 
324 }  // namespace data
325 }  // namespace tensorflow
326 
327 #endif  // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
328