1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ 17 #define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ 18 19 #include <utility> 20 21 #include "tensorflow/core/common_runtime/function.h" 22 #include "tensorflow/core/data/dataset_utils.h" 23 #include "tensorflow/core/data/metric_utils.h" 24 #include "tensorflow/core/data/unbounded_thread_pool.h" 25 #include "tensorflow/core/framework/dataset.h" 26 #include "tensorflow/core/framework/function_handle_cache.h" 27 #include "tensorflow/core/framework/metrics.h" 28 #include "tensorflow/core/framework/op_kernel.h" 29 #include "tensorflow/core/framework/tensor_shape.h" 30 #include "tensorflow/core/framework/types.h" 31 #include "tensorflow/core/kernels/ops_util.h" 32 #include "tensorflow/core/platform/refcount.h" 33 34 namespace tensorflow { 35 namespace data { 36 37 class IteratorResource : public ResourceBase { 38 public: 39 IteratorResource(Env* env, const DataTypeVector& output_dtypes, 40 const std::vector<PartialTensorShape>& output_shapes, 41 std::unique_ptr<DeviceMgr> device_mgr, 42 std::unique_ptr<FunctionLibraryDefinition> flib_def, 43 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, 44 FunctionLibraryRuntime* flr); 45 46 ~IteratorResource() override; 47 48 // Gets the next output from the iterator managed by this iterator resource. 49 // 50 // If at least one output remains, that output will be stored in 51 // `*out_tensors` and `false` will be stored in `*end_of_sequence`. 52 // 53 // If no more outputs remain, `true` will be stored in `*end_of_sequence`, and 54 // the content of `*out_tensors` will be undefined. 55 Status GetNext(OpKernelContext* ctx, std::vector<Tensor>* out_tensors, 56 bool* end_of_sequence); 57 58 // Saves a checkpoint of the state of the iterator through the given `writer`. 59 Status Save(SerializationContext* ctx, IteratorStateWriter* writer); 60 61 // Restores the state of the iterator from a checkpoint created by `Save`. 62 Status Restore(OpKernelContext* ctx, IteratorStateReader* reader); 63 64 // Creates an iterator for `dataset`, and associates the iterator with this 65 // iterator resource. 66 // 67 // `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`, 68 // or `Restore`. 69 Status SetIteratorFromDataset(OpKernelContext* ctx, 70 const DatasetBase* dataset); 71 DebugString()72 string DebugString() const override { return "Iterator resource"; } 73 output_dtypes()74 const DataTypeVector& output_dtypes() const { return output_dtypes_; } 75 output_shapes()76 const std::vector<PartialTensorShape>& output_shapes() const { 77 return output_shapes_; 78 } 79 80 private: 81 class State { 82 public: State(std::shared_ptr<FunctionLibraryDefinition> flib_def,std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,FunctionLibraryRuntime * flr,std::unique_ptr<DatasetBaseIterator> iterator)83 State(std::shared_ptr<FunctionLibraryDefinition> flib_def, 84 std::shared_ptr<ProcessFunctionLibraryRuntime> pflr, 85 FunctionLibraryRuntime* flr, 86 std::unique_ptr<DatasetBaseIterator> iterator) 87 : flib_def_(std::move(flib_def)), 88 flr_(flr), 89 pflr_(std::move(pflr)), 90 function_handle_cache_(std::make_unique<FunctionHandleCache>(flr)), 91 iterator_(std::move(iterator)) {} 92 ~State()93 ~State() { cancellation_manager_.StartCancel(); } 94 95 // Downcasts the given `IteratorBase` to a `DatasetBaseIterator`, and uses 96 // it to set the `iterator` and the `dataset` field. DowncastAndSetIteratorAndDataset(std::unique_ptr<IteratorBase> it,const DatasetBase * dataset)97 void DowncastAndSetIteratorAndDataset(std::unique_ptr<IteratorBase> it, 98 const DatasetBase* dataset) { 99 iterator_.reset(static_cast<DatasetBaseIterator*>(it.release())); 100 if (dataset) { 101 dataset->Ref(); 102 dataset_.reset(const_cast<DatasetBase*>(dataset)); 103 } 104 } 105 flib_def()106 std::shared_ptr<FunctionLibraryDefinition> flib_def() { return flib_def_; } 107 flr()108 FunctionLibraryRuntime* flr() { return flr_; } 109 pflr()110 std::shared_ptr<ProcessFunctionLibraryRuntime> pflr() { return pflr_; } 111 function_handle_cache()112 FunctionHandleCache* function_handle_cache() { 113 return function_handle_cache_.get(); 114 } 115 resource_mgr()116 ResourceMgr* resource_mgr() { return &resource_mgr_; } 117 cancellation_manager()118 CancellationManager* cancellation_manager() { 119 return &cancellation_manager_; 120 } 121 iterator()122 DatasetBaseIterator* iterator() { return iterator_.get(); } 123 dataset()124 DatasetBase* dataset() { return dataset_.get(); } 125 126 private: 127 std::shared_ptr<FunctionLibraryDefinition> flib_def_; 128 FunctionLibraryRuntime* flr_ = nullptr; // not owned 129 std::shared_ptr<ProcessFunctionLibraryRuntime> pflr_; 130 std::unique_ptr<FunctionHandleCache> function_handle_cache_; 131 ResourceMgr resource_mgr_; 132 CancellationManager cancellation_manager_; 133 std::unique_ptr<DatasetBaseIterator> iterator_; 134 core::RefCountPtr<DatasetBase> dataset_; 135 }; 136 137 IteratorMetricsCollector metrics_collector_; 138 UnboundedThreadPool unbounded_thread_pool_; 139 140 mutex mu_; 141 const std::unique_ptr<DeviceMgr> device_mgr_ TF_GUARDED_BY(mu_); 142 std::shared_ptr<State> iterator_state_ TF_GUARDED_BY(mu_); 143 const DataTypeVector output_dtypes_; 144 const std::vector<PartialTensorShape> output_shapes_; 145 }; 146 147 class IteratorHandleOp : public OpKernel { 148 public: 149 explicit IteratorHandleOp(OpKernelConstruction* ctx); 150 151 // The resource is deleted from the resource manager only when it is private 152 // to kernel. Ideally the resource should be deleted when it is no longer held 153 // by anyone, but it would break backward compatibility. 154 ~IteratorHandleOp() override; 155 156 void Compute(OpKernelContext* context) override TF_LOCKS_EXCLUDED(mu_); 157 158 private: 159 // During the first Compute(), resource is either created or looked up using 160 // shared_name. In the latter case, the resource found should be verified if 161 // it is compatible with this op's configuration. The verification may fail in 162 // cases such as two graphs asking queues of the same shared name to have 163 // inconsistent capacities. 164 Status VerifyResource(IteratorResource* resource); 165 166 FunctionLibraryRuntime* CreatePrivateFLR( 167 OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr, 168 std::unique_ptr<FunctionLibraryDefinition>* flib_def, 169 std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr); 170 171 mutex mu_; 172 ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. 173 IteratorResource* resource_ TF_GUARDED_BY(mu_) = nullptr; 174 DataTypeVector output_dtypes_; 175 std::vector<PartialTensorShape> output_shapes_; 176 const int graph_def_version_; 177 string name_; 178 }; 179 180 // Like IteratorHandleOp, but creates handles which are never shared, and does 181 // not hold a reference to these handles. The latter is important for eager 182 // execution, since OpKernel instances generally live as long as the program 183 // running them. 184 class AnonymousIteratorHandleOp : public AnonymousResourceOp<IteratorResource> { 185 public: 186 explicit AnonymousIteratorHandleOp(OpKernelConstruction* context); 187 188 private: 189 string name() override; 190 191 Status CreateResource(OpKernelContext* ctx, 192 std::unique_ptr<FunctionLibraryDefinition> flib_def, 193 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, 194 FunctionLibraryRuntime* lib, 195 IteratorResource** resource) override; 196 197 DataTypeVector output_dtypes_; 198 std::vector<PartialTensorShape> output_shapes_; 199 const int graph_def_version_; 200 }; 201 202 // A hybrid asynchronous-and-synchronous OpKernel with efficient support for 203 // both modes. 204 // 205 // Inherit from this class when the application logic of the kernel (i) is 206 // implemented synchronously, (ii) must run on a background thread when the 207 // kernel executes in the inter-op threadpool (typically because it depends on 208 // inter-op threadpool threads, e.g. for function execution), and (iii) can run 209 // synchronously on the calling thread when the caller donates a thread 210 // (typically in eager execution). The implementation avoids a thread-hop in 211 // case (iii). 212 // 213 // NOTE: Unlike typical OpKernel subclasses, the application logic is 214 // implemented in a method (DoCompute()) that returns Status. Use 215 // TF_RETURN_IF_ERROR for error-related control flow rather than 216 // OP_REQUIRES_OK(). 217 class HybridAsyncOpKernel : public AsyncOpKernel { 218 public: 219 HybridAsyncOpKernel(OpKernelConstruction* ctx, 220 const char* background_worker_name); 221 222 void Compute(OpKernelContext* ctx) final; 223 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) final; 224 225 protected: 226 virtual Status DoCompute(OpKernelContext* ctx) = 0; 227 228 private: 229 BackgroundWorker background_worker_; 230 }; 231 232 class MakeIteratorOp : public HybridAsyncOpKernel { 233 public: MakeIteratorOp(OpKernelConstruction * ctx)234 explicit MakeIteratorOp(OpKernelConstruction* ctx) 235 : HybridAsyncOpKernel(ctx, "tf_data_make_iterator") {} 236 237 protected: 238 Status DoCompute(OpKernelContext* ctx) override; 239 }; 240 241 class IteratorGetNextOp : public HybridAsyncOpKernel { 242 public: IteratorGetNextOp(OpKernelConstruction * ctx)243 explicit IteratorGetNextOp(OpKernelConstruction* ctx) 244 : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") { 245 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 246 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 247 } 248 249 AsyncOpKernel* AsAsync() override; 250 251 protected: 252 Status DoCompute(OpKernelContext* ctx) override; 253 254 private: 255 DataTypeVector output_types_; 256 std::vector<PartialTensorShape> output_shapes_; 257 }; 258 259 class DeleteIteratorOp : public HybridAsyncOpKernel { 260 public: DeleteIteratorOp(OpKernelConstruction * ctx)261 explicit DeleteIteratorOp(OpKernelConstruction* ctx) 262 : HybridAsyncOpKernel(ctx, "tf_data_delete_iterator") {} 263 264 protected: 265 Status DoCompute(OpKernelContext* ctx) override; 266 }; 267 268 class IteratorGetNextAsOptionalOp : public HybridAsyncOpKernel { 269 public: IteratorGetNextAsOptionalOp(OpKernelConstruction * ctx)270 explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx) 271 : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next_as_optional") { 272 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 273 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 274 } 275 276 protected: 277 Status DoCompute(OpKernelContext* ctx) override; 278 279 private: 280 DataTypeVector output_types_; 281 std::vector<PartialTensorShape> output_shapes_; 282 }; 283 284 class IteratorToStringHandleOp : public OpKernel { 285 public: IteratorToStringHandleOp(OpKernelConstruction * ctx)286 explicit IteratorToStringHandleOp(OpKernelConstruction* ctx) 287 : OpKernel(ctx) {} 288 289 void Compute(OpKernelContext* ctx) override; 290 }; 291 292 class IteratorFromStringHandleOp : public OpKernel { 293 public: 294 explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx); 295 296 void Compute(OpKernelContext* ctx) override; 297 298 private: 299 DataTypeVector output_dtypes_; 300 std::vector<PartialTensorShape> output_shapes_; 301 }; 302 303 class SerializeIteratorOp : public OpKernel { 304 public: 305 static constexpr const char* const kExternalStatePolicy = 306 "external_state_policy"; 307 308 explicit SerializeIteratorOp(OpKernelConstruction* ctx); 309 310 void Compute(OpKernelContext* ctx) override; 311 312 private: 313 SerializationContext::ExternalStatePolicy external_state_policy_ = 314 SerializationContext::ExternalStatePolicy::kWarn; 315 }; 316 317 class DeserializeIteratorOp : public OpKernel { 318 public: DeserializeIteratorOp(OpKernelConstruction * ctx)319 explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 320 321 void Compute(OpKernelContext* ctx) override; 322 }; 323 324 } // namespace data 325 } // namespace tensorflow 326 327 #endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ 328