/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_
#define TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_

#include <memory>
#include <vector>

#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class Device;
class OpKernelContext;
class ResourceMgr;

namespace data {

class CapturedFunction;
class InstantiatedCapturedFunction;

// Creates an iterator for a dataset which is created by applying the given
// function to the given input element.
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator);

// Creates an iterator for a dataset which is created by applying the given
// function to the given input element. Pass a non-null `node` to record
// processing time for modeling the iterator's GetNext() resource usage.
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator,
    const std::shared_ptr<model::Node>& node);

// Creates an iterator context appropriate for a nested dataset's iterator. A
// nested dataset is a dataset created within another dataset, e.g. by the
// function passed to `interleave` or `flat_map`.
IteratorContext MakeNestedIteratorContext(IteratorContext* ctx);

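// Illustrative usage sketch (not part of this API): a `flat_map`-style
// iterator might create a nested iterator from one input element roughly as
// follows. `input_impl_`, `instantiated_captured_func_`, and
// `current_element_iterator_` are hypothetical member variables.
//
//   std::vector<Tensor> input_element;
//   bool end_of_input = false;
//   TF_RETURN_IF_ERROR(
//       input_impl_->GetNext(ctx, &input_element, &end_of_input));
//   if (!end_of_input) {
//     TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
//         ctx, this, input_element, /*thread_index=*/0,
//         *instantiated_captured_func_, prefix(),
//         &current_element_iterator_));
//   }
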
struct ShortCircuitInfo {
  std::vector<int> indices;
  std::vector<bool> can_move;
};

// Metadata shared across all captures of the same function.
class FunctionMetadata {
 public:
  struct Params {
    bool use_inter_op_parallelism = true;
    bool use_default_device = true;
  };

  // Creates a new instance of the `FunctionMetadata` class, fetching the
  // function from a context argument.
  static Status Create(tensorflow::OpKernelConstruction* ctx,
                       const string& func_name, Params params,
                       std::shared_ptr<FunctionMetadata>* out_metadata);

  // Creates a new instance of the `FunctionMetadata` class, using the provided
  // function.
  static Status Create(tensorflow::OpKernelConstruction* ctx,
                       NameAttrList&& func, Params params,
                       std::shared_ptr<FunctionMetadata>* out_metadata);

  // Returns the named list of function arguments.
  const NameAttrList& func() const { return func_; }

  // Returns a borrowed pointer to the function library that contains the
  // transitive closure of definitions used by the function.
  const FunctionLibraryDefinition* lib_def() const { return lib_def_.get(); }

  // Returns short-circuit information.
  const ShortCircuitInfo& short_circuit_info() const {
    return short_circuit_info_;
  }

  // Indicates whether a default device should be used for executing function
  // ops.
  bool use_default_device() const { return use_default_device_; }

  // Indicates whether to use inter-op parallelism for execution of the
  // function.
  bool use_inter_op_parallelism() const { return use_inter_op_parallelism_; }

  // Indicates whether the function should use a multi-device function backend.
  bool use_multi_device_function() const { return use_multi_device_function_; }

 private:
  FunctionMetadata(NameAttrList&& func, Params params)
      : func_(std::move(func)),
        use_default_device_(params.use_default_device),
        use_inter_op_parallelism_(params.use_inter_op_parallelism) {}

  NameAttrList func_;
  std::unique_ptr<FunctionLibraryDefinition> lib_def_ = nullptr;
  ShortCircuitInfo short_circuit_info_;
  bool use_default_device_ = true;
  bool use_inter_op_parallelism_ = true;
  bool use_multi_device_function_ = true;
};

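// Illustrative sketch of creating shared metadata in an op kernel
// constructor. The attr name "f" and the `func_metadata_` member are
// hypothetical, not defined by this header.
//
//   explicit MyDatasetOp(OpKernelConstruction* ctx)
//       : UnaryDatasetOpKernel(ctx) {
//     FunctionMetadata::Params params;
//     params.use_inter_op_parallelism = true;
//     OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, /*func_name=*/"f",
//                                                  params, &func_metadata_));
//   }
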
// Constructs and stores the parameters for `CapturedFunction::Instantiate`.
struct InstantiateCapturedFunctionParams {
  explicit InstantiateCapturedFunctionParams(IteratorContext* ctx) {
    flr = ctx->flr();
    function_handle_cache = ctx->function_handle_cache();
    runner = ctx->runner();
  }

  explicit InstantiateCapturedFunctionParams(OpKernelContext* ctx) {
    flr = ctx->function_library();
    function_handle_cache = nullptr;
    runner = ctx->runner();
  }

  FunctionLibraryRuntime* flr;
  FunctionHandleCache* function_handle_cache;
  std::function<void(std::function<void()>)>* runner;
};

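// Illustrative sketch: parameters can be built from either context type and
// passed to `CapturedFunction::Instantiate`. `captured_func_` and
// `instantiated_func` are hypothetical names.
//
//   InstantiateCapturedFunctionParams params(ctx);  // IteratorContext* or
//                                                   // OpKernelContext*
//   std::unique_ptr<InstantiatedCapturedFunction> instantiated_func;
//   TF_RETURN_IF_ERROR(
//       captured_func_->Instantiate(params, &instantiated_func));
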
// A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured"
// arguments that it closed over in the user program.
class CapturedFunction {
 public:
  // Creates a new instance using a list of named attributes, fetching captured
  // inputs from a context argument.
  static Status Create(OpKernelContext* ctx,
                       std::shared_ptr<const FunctionMetadata> metadata,
                       const string& argument_name,
                       std::unique_ptr<CapturedFunction>* out_function);

  // Creates a new instance using a list of named attributes, using the
  // provided captured inputs.
  static Status Create(OpKernelContext* ctx,
                       std::shared_ptr<const FunctionMetadata> metadata,
                       std::vector<Tensor>&& captured_inputs,
                       std::unique_ptr<CapturedFunction>* out_function);

  // Adds the definition of this captured function into the given graph,
  // returning its captured inputs and types through the respective output
  // arguments.
  Status AddToGraph(SerializationContext* ctx,
                    DatasetBase::DatasetGraphDefBuilder* b,
                    std::vector<Node*>* other_arguments,
                    DataTypeVector* other_arguments_types) const;

  // Instantiates this function for use in the given context, providing an
  // `InstantiatedCapturedFunction` that can be used to execute it.
  Status Instantiate(IteratorContext* ctx,
                     std::unique_ptr<InstantiatedCapturedFunction>*
                         instantiated_captured_function);

  Status Instantiate(InstantiateCapturedFunctionParams params,
                     std::unique_ptr<InstantiatedCapturedFunction>*
                         instantiated_captured_function);

  // Determines whether the captured function is stateful.
  Status CheckExternalState() const;

  // Returns the additional captured inputs that will be passed to the
  // function.
  const std::vector<Tensor>& captured_inputs() const {
    return captured_inputs_;
  }

  // Returns the named list of function arguments.
  const NameAttrList& func() const { return metadata_->func(); }

  // Returns the transitive set of function definitions required to instantiate
  // this function.
  const FunctionLibraryDefinition* lib_def() const {
    return metadata_->lib_def();
  }

  // If every function output corresponds to one of its inputs, the method
  // returns the mapping from output indices to input indices. Otherwise, it
  // returns an empty list.
  const ShortCircuitInfo& short_circuit_info() const {
    return metadata_->short_circuit_info();
  }

  // Indicates whether the function should use inter-op parallelism.
  bool use_inter_op_parallelism() const {
    return metadata_->use_inter_op_parallelism();
  }

 private:
  CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
                   std::vector<Tensor> captured_inputs);

  Status IsMultiDevice(FunctionLibraryRuntime* flr,
                       bool* is_multi_device) const;

  const std::shared_ptr<const FunctionMetadata> metadata_;
  const std::vector<Tensor> captured_inputs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
};

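// Illustrative lifecycle sketch with hypothetical names (`func_metadata_`,
// `captured_func`, `instantiated_captured_func_`): the kernel creates the
// captured function from its inputs, and the iterator later instantiates it.
//
//   // In the dataset op kernel:
//   std::unique_ptr<CapturedFunction> captured_func;
//   OP_REQUIRES_OK(
//       ctx, CapturedFunction::Create(ctx, func_metadata_,
//                                     /*argument_name=*/"other_arguments",
//                                     &captured_func));
//
//   // In Iterator::Initialize():
//   TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
//       ctx, &instantiated_captured_func_));
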
// `InstantiatedCapturedFunction` encapsulates all the runtime support needed
// to execute a TensorFlow function.
//
// While `CapturedFunction` encapsulates constant attributes of the function,
// such as its name and captured arguments, `InstantiatedCapturedFunction`
// encapsulates runtime aspects, such as the `FunctionLibraryRuntime` and
// function handle.
//
// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute
// functions outside of the normal `OpKernel::Compute()` context.
class InstantiatedCapturedFunction {
 public:
  // Runs the instantiated captured function. This method takes ownership of
  // the tensors in `args`, in order to be able to deallocate them as early as
  // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
  // ownership of the `args`.
  Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
             std::vector<Tensor>* rets) const;

  // Runs the instantiated captured function. This method takes ownership of
  // the tensors in `args`, in order to be able to deallocate them as early as
  // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
  // ownership of the `args`. Pass a non-null `node` to record processing time
  // for modeling the iterator's GetNext() resource usage. When a non-null
  // `node` is provided, the prerequisite is that the calling thread has
  // previously called `DatasetBaseIterator::RecordStart()`.
  Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
             std::vector<Tensor>* rets,
             const std::shared_ptr<model::Node>& node) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible.
  Status RunWithBorrowedArgs(IteratorContext* ctx,
                             const std::vector<Tensor>& args,
                             std::vector<Tensor>* rets) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible. Pass a non-null `node` to record processing time for modeling
  // the iterator's GetNext() resource usage. When a non-null `node` is
  // provided, the prerequisite is that the calling thread has previously
  // called `DatasetBaseIterator::RecordStart()`.
  Status RunWithBorrowedArgs(IteratorContext* ctx,
                             const std::vector<Tensor>& args,
                             std::vector<Tensor>* rets,
                             const std::shared_ptr<model::Node>& node) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible. This can be useful for calling a captured function in cases
  // where an `IteratorContext*` is not available (such as in a destructor).
  //
  // TODO(b/144278100): Avoid running functions without IteratorContext.
  Status RunInstantiated(const std::vector<Tensor>& args,
                         std::vector<Tensor>* rets);

  // Asynchronously runs the captured function on the given `args`, stores the
  // results in `*rets`, and calls the given `done` callback when the function
  // returns. This method takes ownership of the tensors in `args`, in order to
  // be able to deallocate them as early as possible. Pass a non-null `node` to
  // record processing time for modeling the iterator's GetNext() resource
  // usage. When a non-null `node` is provided, the prerequisite is that the
  // calling thread has previously called `DatasetBaseIterator::RecordStart()`.
  void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
                std::vector<Tensor>* rets,
                FunctionLibraryRuntime::DoneCallback done,
                const std::shared_ptr<model::Node>& node) const;

 private:
  friend class CapturedFunction;

  InstantiatedCapturedFunction(
      FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
      DataTypeVector ret_types,
      std::function<void(std::function<void()>)> runner,
      CapturedFunction* captured_func, bool is_multi_device);

  // Determines whether a rendezvous object should be created when running the
  // instantiated function.
  bool ShouldCreateRendezvous() const;

  FunctionLibraryRuntime* const lib_;  // Not owned.
  const FunctionLibraryRuntime::Handle f_handle_;
  const DataTypeVector ret_types_;
  // Note: We capture the runner at function instantiation time to be able to
  // run the function without `IteratorContext` via `RunInstantiated`.
  std::function<void(std::function<void()>)> captured_runner_;
  CapturedFunction* const captured_func_;  // Not owned.
  const bool is_multi_device_;

  TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
};

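// Illustrative sketch (hypothetical `input_impl_` and
// `instantiated_captured_func_`): running the instantiated function on one
// input element inside an iterator's GetNextInternal().
//
//   std::vector<Tensor> args;
//   TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence));
//   if (!*end_of_sequence) {
//     // `Run` consumes `args`; use `RunWithBorrowedArgs` to keep ownership.
//     TF_RETURN_IF_ERROR(instantiated_captured_func_->Run(
//         ctx, std::move(args), out_tensors));
//   }
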
}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_