xref: /aosp_15_r20/external/tensorflow/tensorflow/c/eager/immediate_execution_context.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
16 #define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/types/optional.h"
22 #include "absl/types/span.h"
23 #include "tensorflow/c/eager/abstract_context.h"
24 #include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
25 #include "tensorflow/c/eager/immediate_execution_operation.h"
26 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
27 #include "tensorflow/c/tensor_interface.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/function.pb.h"
30 #include "tensorflow/core/framework/numeric_types.h"
31 #include "tensorflow/core/framework/tensor.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/platform/platform.h"
34 #include "tensorflow/core/platform/status.h"
35 #include "tensorflow/core/platform/tstring.h"
36 #include "tensorflow/core/protobuf/config.pb.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 
39 namespace tensorflow {
// Forward declarations: these types are only used by pointer/reference in
// this header, so their full definitions are not included here.
class EagerExecutor;
class EagerContext;
class CustomDevice;
class CustomDeviceOpHandler;
class Device;
45 
46 // LINT.IfChange
47 // Note: Keep in sync with exported copy of enum in eager/c_api.h.
48 enum ContextDevicePlacementPolicy {
49   // Running operations with input tensors on the wrong device will fail.
50   DEVICE_PLACEMENT_EXPLICIT = 0,
51   // Copy the tensor to the right device but log a warning.
52   DEVICE_PLACEMENT_WARN = 1,
53   // Silently copy the tensor, which has a performance cost since the operation
54   // will be blocked till the copy completes. This is the default policy.
55   DEVICE_PLACEMENT_SILENT = 2,
56   // Placement policy which silently copies int32 tensors but not other dtypes.
57   DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
58 };
59 // LINT.ThenChange(//tensorflow/c/eager/c_api.h)
60 
61 // Abstract interface to a context.
62 //
63 // A context is responsible for creating key objects such as Tensors,
64 // TensorHandles & Operations.
65 class ImmediateExecutionContext : public AbstractContext {
66  public:
67   // Optimized scalar creation functions
68   virtual AbstractTensorInterface* CreateInt64Scalar(int64_t value) = 0;
69   virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
70   virtual AbstractTensorInterface* CreateInt32Scalar(int32_t value) = 0;
71   virtual AbstractTensorInterface* CreateFloatScalar(float value) = 0;
72   virtual AbstractTensorInterface* CreateDoubleScalar(double value) = 0;
73   virtual AbstractTensorInterface* CreateHalfScalar(Eigen::half value) = 0;
74   virtual AbstractTensorInterface* CreateStringScalar(tstring value) = 0;
75   virtual AbstractTensorInterface* CreateComplex128Scalar(complex128 value) = 0;
76   virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0;
77 
78   // Tensor creation functions
79   virtual AbstractTensorInterface* CreateTensor(
80       DataType dtype, absl::Span<const int64_t> dim_sizes) = 0;
81 
82   typedef void (*MemoryReleaser)(void* data, size_t len, void* arg);
83 
84   // Create a tensor instance from the given data buffer and description.
85   // `memory_releaser` will be called on destruction, and it's responsible for
86   // cleaning up the underlying buffer.
87   virtual AbstractTensorInterface* CreateTensor(
88       DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
89       MemoryReleaser memory_releaser, void* memory_releaser_arg) = 0;
90 
91   // Create a handle to wrap and manage a Tensor
92   virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
93       AbstractTensorInterface* t) = 0;
94   // Copy the handle to another device.
95   virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
96       ImmediateExecutionTensorHandle* handle, const char* device_name,
97       Status* status) = 0;
98 
99   // Create an operation to perform op execution
100   ImmediateExecutionOperation* CreateOperation() override = 0;
101 
102   // Returns whether the runtime is backed by TFRT or the legacy TF Eager
103   // Runtime. This is necessary to decouple runtime-dependent
104   // code that is layered on top of the runtime.
105   virtual bool UsesTFRT() = 0;
106 
107   // List attributes of available devices
108   virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
109 
110   // Add `devices` into context's device manager. Context's device manager
111   // will take ownership and maintain devices' lifetime.
112   virtual Status AddDevices(std::vector<std::unique_ptr<Device>> devices) = 0;
113 
114   // Block until all pending nodes are finished.
115   virtual Status AsyncWait() = 0;
116 
117   // Add a function (serialized FunctionDef protocol buffer) so that it can
118   // be executed as an op. Return error if the function with the same name
119   // already exists.
120   virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
121 
122   // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under
123   // the key of the function definition name (to be retrieved during function
124   // instantiation).
125   virtual Status AddFunctionDefWithStackTraces(
126       const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0;
127 
128   // Find and return a added function by its name.
129   virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
130 
131   // Return the ParsedName of Host CPU device.
132   virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
133   virtual const string& HostCPUName() const = 0;
134 
135   // Configure soft device placement policy.
136   virtual void SetAllowSoftPlacement(bool enable) = 0;
137 
138   // Configure device placement policy logging.
139   virtual void SetLogDevicePlacement(bool enable) = 0;
140 
141   // Enables running eager ops as functions.
142   virtual void SetRunEagerOpAsFunction(bool enable) = 0;
143 
144   // Enables rewriting jit_compile functions.
145   virtual void SetJitCompileRewrite(bool enable) = 0;
146 
147   // Sets the device placement policy for the current thread.
148   virtual void SetThreadLocalDevicePlacementPolicy(
149       ContextDevicePlacementPolicy policy) = 0;
150   // Returns the device placement policy for the current thread.
151   virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
152 
153   // Configure graph collection in RunMetadata.
154   virtual void SetShouldStoreGraphs(bool value) = 0;
155 
156   // Return the collected RunMetadata. This method will transfer the ownership
157   // to the caller.
158   virtual std::unique_ptr<RunMetadata> ExportRunMetadata() = 0;
159 
160   // For LLVM style RTTI.
classof(const AbstractContext * ptr)161   static bool classof(const AbstractContext* ptr) {
162     return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
163   }
164 
165   //===--------------------------------------------------------------------===//
166   // Experimental Custom Device.
167   //===--------------------------------------------------------------------===//
168   virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0;
169 
170   // Register a custom device. It will return error is the device name is
171   // already registered.
172   // TODO(tfrt-devs): Remove this method. Let caller register it directly into
173   // CustomDeviceOpHandler.
174   virtual Status RegisterCustomDevice(const string& name,
175                                       std::unique_ptr<CustomDevice> device) = 0;
176 
177   // Return FunctionLibraryDefinition. Transformations need to use it to use it
178   // to invoke MLIR compiler passes.
179   virtual FunctionLibraryDefinition* FuncLibDef() = 0;
180 
181   // When tensor transfer across functions/eager executions using send/recv ops
182   // are required, `reuse_rendezvous_for_functions_` can be set to true so that
183   // function executions and eager executions use the same rendezvous instance,
184   // instead of creating new instance per function calls.
185   virtual void SetReuseRendezvousForFunctions(
186       bool reuse_rendezvous_for_functions) = 0;
187 
188   // Resets the global rendezvous used for functions.
189   virtual void ResetGlobalRendezvousForFunction() = 0;
190 
191   //===--------------------------------------------------------------------===//
192   // Following are features in current TF Eager Runtime.
193   // TODO(tfrt-devs): Figure out a way to deprecate following features after
194   // migrated to TFRT.
195   //===--------------------------------------------------------------------===//
196   // Clear pending nodes in thread executors and kernel caches.
197   virtual void ClearCachesAndThreadExecutors() = 0;
198 
199   // Initialize the step resource container for a training step. This is used
200   // in current TF runtime. For tfrt, it is used by fallback op handler.
201   virtual void StartStep() = 0;
202   // Destroy the step resource container for a training step.
203   virtual void EndStep() = 0;
204 
205   // Return the Eager Executor for current thread. Please note that Eager
206   // Executor is only used in current TF but not in TFRT.
207   virtual EagerExecutor& Executor() = 0;
208   // Update the Eager Executor for current thread.
209   virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
210 
211   // Return a list of local tensorflow::Device*.
212   // TODO(tfrt-devs): We shouldn't expose legacy device in this API.
213   virtual std::vector<tensorflow::Device*> ListLocalTfDevices() = 0;
214 
215   // Return a list of all tensorflow::Device*.
216   virtual std::vector<tensorflow::Device*> ListAllTfDevices() = 0;
217 
218   //===--------------------------------------------------------------------===//
219   // Following are helper functions to assist integrating TFRT with current
220   // TF eager runtime.
221   // TODO(b/172877902): These helper functions are currently used to support
222   // PyFuncOp on TFRT, and might be useful for ops that directly use low
223   // level TF APIs. Remove/replace the following functions when TFRT native
224   // ops are implemented.
225   //===--------------------------------------------------------------------===//
226   // Create an abstract tensor handle from tensorflow::Tensor.
227   virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
228       tensorflow::Tensor& t, const char* d_name) = 0;
229 
230   // Convert a TFRT TensorHandle to tensorflow::TensorHandle.
231   virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
232       ImmediateExecutionTensorHandle* handle) = 0;
233 
GetLoggedOpsTestonly()234   virtual std::vector<std::string> GetLoggedOpsTestonly() { return {}; }
235 
236   // Get a list of the names of functions that have been registered.
237   virtual std::vector<string> ListFunctionNames() = 0;
238 
239   //===--------------------------------------------------------------------===//
240   // Distributed runtime related functions.
241   //===--------------------------------------------------------------------===//
242 #if !defined(IS_MOBILE_PLATFORM)
243   // Set up a multi-client distributed execution environment. Must be called on
244   // all tasks in the cluster.
245   // This call internally coordinates with other tasks to initialize the eager
246   // context and TF server for multi-client execution.
247   virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0;
248 
249   // Set a distributed manager that helps set up, update, and check liveness
250   // of member tasks in the cluster.
251   virtual void SetDistributedManager(
252       std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0;
253 
254   virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0;
255 #endif  // !IS_MOBILE_PLATFORM
256 
257  protected:
ImmediateExecutionContext(AbstractContextKind kind)258   explicit ImmediateExecutionContext(AbstractContextKind kind)
259       : AbstractContext(kind) {}
~ImmediateExecutionContext()260   ~ImmediateExecutionContext() override {}
261 };
262 
263 namespace internal {
264 struct ImmediateExecutionContextDeleter {
operatorImmediateExecutionContextDeleter265   void operator()(ImmediateExecutionContext* p) const {
266     if (p != nullptr) {
267       p->Release();
268     }
269   }
270 };
271 }  // namespace internal
272 
273 using ImmediateContextPtr =
274     std::unique_ptr<ImmediateExecutionContext,
275                     internal::ImmediateExecutionContextDeleter>;
276 
277 }  // namespace tensorflow
278 
279 #endif  // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
280