/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_
#define TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class Device;
class OpKernelContext;
class ResourceMgr;

namespace data {

class CapturedFunction;
class InstantiatedCapturedFunction;

// Creates an iterator for a dataset which is created by applying the given
// function to the given input element.
45 Status MakeIteratorFromInputElement( 46 IteratorContext* ctx, const IteratorBase* parent, 47 const std::vector<Tensor>& input_element, int64_t thread_index, 48 const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, 49 std::unique_ptr<IteratorBase>* out_iterator); 50 51 // Creates an iterator for a dataset which is created by applying the given 52 // function to the given input element. Pass non-null `node` to record 53 // processing time for modeling Iterator's GetNext() resource usage. 54 Status MakeIteratorFromInputElement( 55 IteratorContext* ctx, const IteratorBase* parent, 56 const std::vector<Tensor>& input_element, int64_t thread_index, 57 const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, 58 std::unique_ptr<IteratorBase>* out_iterator, 59 const std::shared_ptr<model::Node>& node); 60 61 // Creates an iterator context appropriate for a nested dataset's iterator. A 62 // nested dataset is a dataset created within another dataset, e.g. by the 63 // function passed to `interleave` or `flat_map`. 64 IteratorContext MakeNestedIteratorContext(IteratorContext* ctx); 65 66 struct ShortCircuitInfo { 67 std::vector<int> indices; 68 std::vector<bool> can_move; 69 }; 70 71 // Metadata shared across all captures of the same function. 72 class FunctionMetadata { 73 public: 74 struct Params { 75 bool use_inter_op_parallelism = true; 76 bool use_default_device = true; 77 }; 78 79 // Creates a new instance of the `FunctionMetadata` class, fetching function 80 // from a context argument. 81 static Status Create(tensorflow::OpKernelConstruction* ctx, 82 const string& func_name, Params params, 83 std::shared_ptr<FunctionMetadata>* out_metadata); 84 85 // Creates a new instance of the `FunctionMetadata` class, using the provided 86 // function. 
87 static Status Create(tensorflow::OpKernelConstruction* ctx, 88 NameAttrList&& func, Params params, 89 std::shared_ptr<FunctionMetadata>* out_metadata); 90 91 // Returns the named list of function arguments. func()92 const NameAttrList& func() const { return func_; } 93 94 // Returns a borrowed pointer to the function library that contains the 95 // transitive closure of definitions used by the function. lib_def()96 const FunctionLibraryDefinition* lib_def() const { return lib_def_.get(); } 97 98 // Returns short-circuit information. short_circuit_info()99 const ShortCircuitInfo& short_circuit_info() const { 100 return short_circuit_info_; 101 } 102 103 // Indicates whether a default device should be used for executing function 104 // ops. use_default_device()105 bool use_default_device() const { return use_default_device_; } 106 107 // Indicates whether to use inter-op parallelism for execution of the 108 // function. use_inter_op_parallelism()109 bool use_inter_op_parallelism() const { return use_inter_op_parallelism_; } 110 111 // Indicates whether the function should a multi-device function backend. use_multi_device_function()112 bool use_multi_device_function() const { return use_multi_device_function_; } 113 114 private: FunctionMetadata(NameAttrList && func,Params params)115 FunctionMetadata(NameAttrList&& func, Params params) 116 : func_(std::move(func)), 117 use_default_device_(params.use_default_device), 118 use_inter_op_parallelism_(params.use_inter_op_parallelism) {} 119 120 NameAttrList func_; 121 std::unique_ptr<FunctionLibraryDefinition> lib_def_ = nullptr; 122 ShortCircuitInfo short_circuit_info_; 123 bool use_default_device_ = true; 124 bool use_inter_op_parallelism_ = true; 125 bool use_multi_device_function_ = true; 126 }; 127 128 // Constructs and stores the parameters for the CapturedFunction Instantiate 129 // function. 
130 struct InstantiateCapturedFunctionParams { InstantiateCapturedFunctionParamsInstantiateCapturedFunctionParams131 explicit InstantiateCapturedFunctionParams(IteratorContext* ctx) { 132 flr = ctx->flr(); 133 function_handle_cache = ctx->function_handle_cache(); 134 runner = ctx->runner(); 135 } 136 InstantiateCapturedFunctionParamsInstantiateCapturedFunctionParams137 explicit InstantiateCapturedFunctionParams(OpKernelContext* ctx) { 138 flr = ctx->function_library(); 139 function_handle_cache = nullptr; 140 runner = ctx->runner(); 141 } 142 143 FunctionLibraryRuntime* flr; 144 FunctionHandleCache* function_handle_cache; 145 std::function<void(std::function<void()>)>* runner; 146 }; 147 148 // A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured" 149 // arguments that it closed over in the user program. 150 class CapturedFunction { 151 public: 152 // Creates a new instance using a list of named attributes, fetching captured 153 // inputs from a context argument. 154 static Status Create(OpKernelContext* ctx, 155 std::shared_ptr<const FunctionMetadata> metadata, 156 const string& argument_name, 157 std::unique_ptr<CapturedFunction>* out_function); 158 159 // Creates a new instance using a list of named attributes, using provided 160 // captured inputs. 161 static Status Create(OpKernelContext* ctx, 162 std::shared_ptr<const FunctionMetadata> metadata, 163 std::vector<Tensor>&& captured_inputs, 164 std::unique_ptr<CapturedFunction>* out_function); 165 166 // Adds the definition of this captured function into the given graph, 167 // returning its captured inputs and types through the respective output 168 // arguments. 
169 Status AddToGraph(SerializationContext* ctx, 170 DatasetBase::DatasetGraphDefBuilder* b, 171 std::vector<Node*>* other_arguments, 172 DataTypeVector* other_arguments_types) const; 173 174 // Instantiates this function for use in the given context, providing an 175 // InstantiatedCapturedFunction that can be used to execute functions. 176 Status Instantiate(IteratorContext* ctx, 177 std::unique_ptr<InstantiatedCapturedFunction>* 178 instantiated_captured_function); 179 180 Status Instantiate(InstantiateCapturedFunctionParams params, 181 std::unique_ptr<InstantiatedCapturedFunction>* 182 instantiated_captured_function); 183 184 // Determines whether the captured function is stateful. 185 Status CheckExternalState() const; 186 187 // Returns the additional captured inputs that will be passed to the function. captured_inputs()188 const std::vector<Tensor>& captured_inputs() const { 189 return captured_inputs_; 190 } 191 192 // Returns the named list of function arguments. func()193 const NameAttrList& func() const { return metadata_->func(); } 194 195 // Returns the transitive set of function definition required to instantiate 196 // this function. lib_def()197 const FunctionLibraryDefinition* lib_def() const { 198 return metadata_->lib_def(); 199 } 200 201 // If every function output corresponds to one of its inputs, the method 202 // returns the mapping from output indices to input indices. Otherwise, it 203 // returns an empty list. short_circuit_info()204 const ShortCircuitInfo& short_circuit_info() const { 205 return metadata_->short_circuit_info(); 206 } 207 208 // Indicates whether the function should use inter op parallelism. 
use_inter_op_parallelism()209 bool use_inter_op_parallelism() const { 210 return metadata_->use_inter_op_parallelism(); 211 } 212 213 private: 214 CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata, 215 std::vector<Tensor> captured_inputs); 216 217 Status IsMultiDevice(FunctionLibraryRuntime* flr, 218 bool* is_multi_device) const; 219 220 const std::shared_ptr<const FunctionMetadata> metadata_; 221 const std::vector<Tensor> captured_inputs_; 222 223 TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction); 224 }; 225 226 // `InstantiatedCapturedFunction` encapsulates all the runtime support needed 227 // to execute a tensorflow function. 228 // 229 // While `CapturedFunction` encapsulates constant attributes of the function, 230 // such as its name and captured arguments, `InstantiatedCapturedFunction` 231 // encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function 232 // handle. 233 // 234 // The `Iterator` related classes use `InstantiatedCapturedFunction` to execute 235 // functions outside of the normal `OpKernel::Compute()` context. 236 class InstantiatedCapturedFunction { 237 public: 238 // Runs the instantiated captured function. This method takes ownership of 239 // the tensors in `args`, in order to be able to deallocate them as early as 240 // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain 241 // ownership of the `args`. 242 Status Run(IteratorContext* ctx, std::vector<Tensor>&& args, 243 std::vector<Tensor>* rets) const; 244 245 // Runs the instantiated captured function. This method takes ownership of 246 // the tensors in `args`, in order to be able to deallocate them as early as 247 // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain 248 // ownership of the `args`. Pass non-null `node` to record processing time 249 // for modeling Iterator's GetNext() resource usage. 
When non-null node is 250 // provided, the pre-requisite is that the calling thread has previously 251 // called `DatasetBaseIterator::RecordStart(). 252 Status Run(IteratorContext* ctx, std::vector<Tensor>&& args, 253 std::vector<Tensor>* rets, 254 const std::shared_ptr<model::Node>& node) const; 255 256 // Synchronously runs the captured function on the given `args`, and stores 257 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when 258 // possible. 259 Status RunWithBorrowedArgs(IteratorContext* ctx, 260 const std::vector<Tensor>& args, 261 std::vector<Tensor>* rets) const; 262 263 // Synchronously runs the captured function on the given `args`, and stores 264 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when 265 // possible. Pass non-null `node` to record processing time for modeling 266 // Iterator's GetNext() resource usage. When non-null node is provided, the 267 // pre-requisite is that the calling thread has previously called 268 // `DatasetBaseIterator::RecordStart(). 269 Status RunWithBorrowedArgs(IteratorContext* ctx, 270 const std::vector<Tensor>& args, 271 std::vector<Tensor>* rets, 272 const std::shared_ptr<model::Node>& node) const; 273 274 // Synchronously runs the captured function on the given `args`, and stores 275 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when 276 // possible. This can be useful for calling a captured function in cases where 277 // an `IteratorContext*` is not available (such as a destructor). 278 // 279 // TODO(b/144278100): Avoid running functions without IteratorContext. 280 Status RunInstantiated(const std::vector<Tensor>& args, 281 std::vector<Tensor>* rets); 282 283 // Asynchronously runs the captured function on the given `args`, stores the 284 // results in `*rets`, and calls the given `done` callback when the function 285 // returns. This method takes ownership of the tensors in `args`, in order to 286 // be able to deallocate them as early as possible. 
Pass non-null `node` to 287 // record processing time for modeling Iterator's GetNext() resource usage. 288 // When non-null node is provided, the pre-requisite is that the calling 289 // thread has previously called `DatasetBaseIterator::RecordStart(). 290 void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args, 291 std::vector<Tensor>* rets, 292 FunctionLibraryRuntime::DoneCallback done, 293 const std::shared_ptr<model::Node>& node) const; 294 295 private: 296 friend class CapturedFunction; 297 298 InstantiatedCapturedFunction( 299 FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, 300 DataTypeVector ret_types, 301 std::function<void(std::function<void()>)> runner, 302 CapturedFunction* captured_func, bool is_multi_device); 303 304 // Determines whether a rendezvous object should be created when running the 305 // instantiated function. 306 bool ShouldCreateRendezvous() const; 307 308 FunctionLibraryRuntime* const lib_; // Not owned. 309 const FunctionLibraryRuntime::Handle f_handle_; 310 const DataTypeVector ret_types_; 311 // Note: We capture the runner at function instantiation time to be able to 312 // run the function without `IteratorContext` via `RunInstantiated`. 313 std::function<void(std::function<void()>)> captured_runner_; 314 CapturedFunction* const captured_func_; // Not owned. 315 const bool is_multi_device_; 316 317 TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction); 318 }; 319 320 } // namespace data 321 } // namespace tensorflow 322 323 #endif // TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_ 324