1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ 16 #define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ 17 18 #include <memory> 19 #include <vector> 20 21 #include "absl/types/optional.h" 22 #include "absl/types/span.h" 23 #include "tensorflow/c/eager/abstract_context.h" 24 #include "tensorflow/c/eager/immediate_execution_distributed_manager.h" 25 #include "tensorflow/c/eager/immediate_execution_operation.h" 26 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" 27 #include "tensorflow/c/tensor_interface.h" 28 #include "tensorflow/core/framework/function.h" 29 #include "tensorflow/core/framework/function.pb.h" 30 #include "tensorflow/core/framework/numeric_types.h" 31 #include "tensorflow/core/framework/tensor.h" 32 #include "tensorflow/core/framework/types.pb.h" 33 #include "tensorflow/core/platform/platform.h" 34 #include "tensorflow/core/platform/status.h" 35 #include "tensorflow/core/platform/tstring.h" 36 #include "tensorflow/core/protobuf/config.pb.h" 37 #include "tensorflow/core/util/device_name_utils.h" 38 39 namespace tensorflow { 40 class EagerExecutor; 41 class EagerContext; 42 class CustomDevice; 43 class CustomDeviceOpHandler; 44 class Device; 45 46 // LINT.IfChange 47 // Note: Keep in sync with exported copy of enum in eager/c_api.h. 48 enum ContextDevicePlacementPolicy { 49 // Running operations with input tensors on the wrong device will fail. 50 DEVICE_PLACEMENT_EXPLICIT = 0, 51 // Copy the tensor to the right device but log a warning. 52 DEVICE_PLACEMENT_WARN = 1, 53 // Silently copy the tensor, which has a performance cost since the operation 54 // will be blocked till the copy completes. This is the default policy. 55 DEVICE_PLACEMENT_SILENT = 2, 56 // Placement policy which silently copies int32 tensors but not other dtypes. 57 DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, 58 }; 59 // LINT.ThenChange(//tensorflow/c/eager/c_api.h) 60 61 // Abstract interface to a context. 62 // 63 // A context is responsible for creating key objects such as Tensors, 64 // TensorHandles & Operations. 65 class ImmediateExecutionContext : public AbstractContext { 66 public: 67 // Optimized scalar creation functions 68 virtual AbstractTensorInterface* CreateInt64Scalar(int64_t value) = 0; 69 virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0; 70 virtual AbstractTensorInterface* CreateInt32Scalar(int32_t value) = 0; 71 virtual AbstractTensorInterface* CreateFloatScalar(float value) = 0; 72 virtual AbstractTensorInterface* CreateDoubleScalar(double value) = 0; 73 virtual AbstractTensorInterface* CreateHalfScalar(Eigen::half value) = 0; 74 virtual AbstractTensorInterface* CreateStringScalar(tstring value) = 0; 75 virtual AbstractTensorInterface* CreateComplex128Scalar(complex128 value) = 0; 76 virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0; 77 78 // Tensor creation functions 79 virtual AbstractTensorInterface* CreateTensor( 80 DataType dtype, absl::Span<const int64_t> dim_sizes) = 0; 81 82 typedef void (*MemoryReleaser)(void* data, size_t len, void* arg); 83 84 // Create a tensor instance from the given data buffer and description. 85 // `memory_releaser` will be called on destruction, and it's responsible for 86 // cleaning up the underlying buffer. 87 virtual AbstractTensorInterface* CreateTensor( 88 DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len, 89 MemoryReleaser memory_releaser, void* memory_releaser_arg) = 0; 90 91 // Create a handle to wrap and manage a Tensor 92 virtual ImmediateExecutionTensorHandle* CreateLocalHandle( 93 AbstractTensorInterface* t) = 0; 94 // Copy the handle to another device. 95 virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice( 96 ImmediateExecutionTensorHandle* handle, const char* device_name, 97 Status* status) = 0; 98 99 // Create an operation to perform op execution 100 ImmediateExecutionOperation* CreateOperation() override = 0; 101 102 // Returns whether the runtime is backed by TFRT or the legacy TF Eager 103 // Runtime. This is necessary to decouple runtime-dependent 104 // code that is layered on top of the runtime. 105 virtual bool UsesTFRT() = 0; 106 107 // List attributes of available devices 108 virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0; 109 110 // Add `devices` into context's device manager. Context's device manager 111 // will take ownership and maintain devices' lifetime. 112 virtual Status AddDevices(std::vector<std::unique_ptr<Device>> devices) = 0; 113 114 // Block until all pending nodes are finished. 115 virtual Status AsyncWait() = 0; 116 117 // Add a function (serialized FunctionDef protocol buffer) so that it can 118 // be executed as an op. Return error if the function with the same name 119 // already exists. 120 virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; 121 122 // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under 123 // the key of the function definition name (to be retrieved during function 124 // instantiation). 125 virtual Status AddFunctionDefWithStackTraces( 126 const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0; 127 128 // Find and return a added function by its name. 129 virtual const FunctionDef* FindFunctionDef(const string& name) const = 0; 130 131 // Return the ParsedName of Host CPU device. 132 virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0; 133 virtual const string& HostCPUName() const = 0; 134 135 // Configure soft device placement policy. 136 virtual void SetAllowSoftPlacement(bool enable) = 0; 137 138 // Configure device placement policy logging. 139 virtual void SetLogDevicePlacement(bool enable) = 0; 140 141 // Enables running eager ops as functions. 142 virtual void SetRunEagerOpAsFunction(bool enable) = 0; 143 144 // Enables rewriting jit_compile functions. 145 virtual void SetJitCompileRewrite(bool enable) = 0; 146 147 // Sets the device placement policy for the current thread. 148 virtual void SetThreadLocalDevicePlacementPolicy( 149 ContextDevicePlacementPolicy policy) = 0; 150 // Returns the device placement policy for the current thread. 151 virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0; 152 153 // Configure graph collection in RunMetadata. 154 virtual void SetShouldStoreGraphs(bool value) = 0; 155 156 // Return the collected RunMetadata. This method will transfer the ownership 157 // to the caller. 158 virtual std::unique_ptr<RunMetadata> ExportRunMetadata() = 0; 159 160 // For LLVM style RTTI. classof(const AbstractContext * ptr)161 static bool classof(const AbstractContext* ptr) { 162 return ptr->getKind() == kEager || ptr->getKind() == kTfrt; 163 } 164 165 //===--------------------------------------------------------------------===// 166 // Experimental Custom Device. 167 //===--------------------------------------------------------------------===// 168 virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0; 169 170 // Register a custom device. It will return error is the device name is 171 // already registered. 172 // TODO(tfrt-devs): Remove this method. Let caller register it directly into 173 // CustomDeviceOpHandler. 174 virtual Status RegisterCustomDevice(const string& name, 175 std::unique_ptr<CustomDevice> device) = 0; 176 177 // Return FunctionLibraryDefinition. Transformations need to use it to use it 178 // to invoke MLIR compiler passes. 179 virtual FunctionLibraryDefinition* FuncLibDef() = 0; 180 181 // When tensor transfer across functions/eager executions using send/recv ops 182 // are required, `reuse_rendezvous_for_functions_` can be set to true so that 183 // function executions and eager executions use the same rendezvous instance, 184 // instead of creating new instance per function calls. 185 virtual void SetReuseRendezvousForFunctions( 186 bool reuse_rendezvous_for_functions) = 0; 187 188 // Resets the global rendezvous used for functions. 189 virtual void ResetGlobalRendezvousForFunction() = 0; 190 191 //===--------------------------------------------------------------------===// 192 // Following are features in current TF Eager Runtime. 193 // TODO(tfrt-devs): Figure out a way to deprecate following features after 194 // migrated to TFRT. 195 //===--------------------------------------------------------------------===// 196 // Clear pending nodes in thread executors and kernel caches. 197 virtual void ClearCachesAndThreadExecutors() = 0; 198 199 // Initialize the step resource container for a training step. This is used 200 // in current TF runtime. For tfrt, it is used by fallback op handler. 201 virtual void StartStep() = 0; 202 // Destroy the step resource container for a training step. 203 virtual void EndStep() = 0; 204 205 // Return the Eager Executor for current thread. Please note that Eager 206 // Executor is only used in current TF but not in TFRT. 207 virtual EagerExecutor& Executor() = 0; 208 // Update the Eager Executor for current thread. 209 virtual void SetExecutorForThread(EagerExecutor* executor) = 0; 210 211 // Return a list of local tensorflow::Device*. 212 // TODO(tfrt-devs): We shouldn't expose legacy device in this API. 213 virtual std::vector<tensorflow::Device*> ListLocalTfDevices() = 0; 214 215 // Return a list of all tensorflow::Device*. 216 virtual std::vector<tensorflow::Device*> ListAllTfDevices() = 0; 217 218 //===--------------------------------------------------------------------===// 219 // Following are helper functions to assist integrating TFRT with current 220 // TF eager runtime. 221 // TODO(b/172877902): These helper functions are currently used to support 222 // PyFuncOp on TFRT, and might be useful for ops that directly use low 223 // level TF APIs. Remove/replace the following functions when TFRT native 224 // ops are implemented. 225 //===--------------------------------------------------------------------===// 226 // Create an abstract tensor handle from tensorflow::Tensor. 227 virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor( 228 tensorflow::Tensor& t, const char* d_name) = 0; 229 230 // Convert a TFRT TensorHandle to tensorflow::TensorHandle. 231 virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface( 232 ImmediateExecutionTensorHandle* handle) = 0; 233 GetLoggedOpsTestonly()234 virtual std::vector<std::string> GetLoggedOpsTestonly() { return {}; } 235 236 // Get a list of the names of functions that have been registered. 237 virtual std::vector<string> ListFunctionNames() = 0; 238 239 //===--------------------------------------------------------------------===// 240 // Distributed runtime related functions. 241 //===--------------------------------------------------------------------===// 242 #if !defined(IS_MOBILE_PLATFORM) 243 // Set up a multi-client distributed execution environment. Must be called on 244 // all tasks in the cluster. 245 // This call internally coordinates with other tasks to initialize the eager 246 // context and TF server for multi-client execution. 247 virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0; 248 249 // Set a distributed manager that helps set up, update, and check liveness 250 // of member tasks in the cluster. 251 virtual void SetDistributedManager( 252 std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0; 253 254 virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0; 255 #endif // !IS_MOBILE_PLATFORM 256 257 protected: ImmediateExecutionContext(AbstractContextKind kind)258 explicit ImmediateExecutionContext(AbstractContextKind kind) 259 : AbstractContext(kind) {} ~ImmediateExecutionContext()260 ~ImmediateExecutionContext() override {} 261 }; 262 263 namespace internal { 264 struct ImmediateExecutionContextDeleter { operatorImmediateExecutionContextDeleter265 void operator()(ImmediateExecutionContext* p) const { 266 if (p != nullptr) { 267 p->Release(); 268 } 269 } 270 }; 271 } // namespace internal 272 273 using ImmediateContextPtr = 274 std::unique_ptr<ImmediateExecutionContext, 275 internal::ImmediateExecutionContextDeleter>; 276 277 } // namespace tensorflow 278 279 #endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ 280