/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_

#include <algorithm>
#include <cstddef>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/types/variant.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif  // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/tensor.h"

#include "tensorflow/core/lib/core/stringpiece.h"

#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {

class EagerContext;

// Associates a Tensor and a Device, used in the eager runtime. Internal version
// of the TFE_TensorHandle struct and the python EagerTensor class
// (unrelated to python TensorHandle).
class TensorHandle : public ImmediateExecutionTensorHandle {
  // TensorHandle for dtype != DT_RESOURCE
  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
               Device* resource_device, EagerContext* ctx);
  // TensorHandle for dtype == DT_RESOURCE
  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
               EagerContext* ctx);
  TensorHandle(Device* d, Device* op_device, Device* resource_device,
               tensorflow::DataType dtype, EagerContext* ctx);

#if !defined(IS_MOBILE_PLATFORM)
  TensorHandle(int64_t op_id, int32_t output_num, const string& remote_task,
               tensorflow::DataType dtype, Device* device, EagerContext* ctx,
               const bool unknown_device);
  TensorHandle(int64_t op_id, int32_t output_num, tensorflow::DataType dtype,
               Device* device, const bool is_ready, EagerContext* ctx);
#endif  // IS_MOBILE_PLATFORM

 public:
  // TensorHandle with no assigned device
  static TensorHandle* CreateLocalHandle(const tensorflow::Tensor& t);
  static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                         Device* op_device, EagerContext* ctx);
  static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                         Device* op_device,
                                         Device* resource_device,
                                         EagerContext* ctx);
  static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device,
                                              Device* resource_device,
                                              tensorflow::DataType dtype,
                                              EagerContext* ctx);

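  // A minimal usage sketch (illustrative only; error handling and the
  // surrounding EagerContext setup are elided):
  //
  //   tensorflow::Tensor t(DT_FLOAT, TensorShape({2, 2}));
  //   t.flat<float>().setZero();
  //   // No assigned device: the tensor lives on the host.
  //   TensorHandle* h = TensorHandle::CreateLocalHandle(t);
  //   ...
  //   h->Release();  // Drop this reference when done.
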
  // Create a handle which packs the given handles of the same dtype and shape.
  // If the handles are on different devices, assign the packed handle to a
  // CompositeDevice.
  //
  // The new tensor handle shares ownership of the given handles: their
  // reference counts will be increased by one after a call to
  // `CreatePackedHandle`.
  // TODO(b/170414377): Use `TensorHandlePtr` instead.
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   const tensorflow::DataType dtype,
                                   const tensorflow::TensorShape& shape,
                                   const string& device_name, EagerContext* ctx,
                                   TensorHandle** packed_handle);
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   EagerContext* ctx,
                                   TensorHandle** packed_handle);

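  // Packing two same-shaped handles, as a sketch (hypothetical `h0`, `h1`,
  // and `ctx`; the packed handle takes an extra reference on each input):
  //
  //   std::vector<TensorHandle*> handles = {h0, h1};
  //   TensorHandle* packed = nullptr;
  //   TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
  //       std::move(handles), ctx, &packed));
  //   CHECK_EQ(packed->Type(), TensorHandle::PACKED);
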
#if !defined(IS_MOBILE_PLATFORM)
  // An unshaped remote handle refers to a tensor on a remote worker. It's not
  // ready until the shape is set. It controls the lifetime of the remote
  // tensor.
  static TensorHandle* CreateUnshapedRemoteHandle(
      int64_t op_id, int32_t output_num, const string& remote_task,
      tensorflow::DataType dtype, Device* d, EagerContext* ctx,
      const bool unknown_device = false);
  // A lazy remote handle refers to a tensor on a remote worker. The lifetime of
  // the remote tensor is controlled by the remote worker, but not by the lazy
  // remote handle. Lazy handles are normally created on a default function
  // device.
  static TensorHandle* CreateLazyRemoteHandle(int64_t op_id, int32_t output_num,
                                              tensorflow::DataType dtype,
                                              Device* d, const bool is_ready,
                                              EagerContext* ctx);
#endif  // IS_MOBILE_PLATFORM

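  // Rough lifecycle of an unshaped remote handle, as a sketch (`op_id`,
  // `output_num`, `remote_task`, `d`, `shape`, `view_id`, and `ctx` are
  // assumed to come from the surrounding remote-execution machinery):
  //
  //   TensorHandle* h = TensorHandle::CreateUnshapedRemoteHandle(
  //       op_id, output_num, remote_task, DT_FLOAT, d, ctx);
  //   // ... later, once the remote worker reports the output shape:
  //   TF_RETURN_IF_ERROR(h->SetRemoteShape(shape, d, view_id));
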
  void Release() override;

  tensorflow::DataType DataType() const override;
  Status Shape(tensorflow::PartialTensorShape* shape) const override;
  Status NumDims(int* num_dims) const override;
  Status NumElements(int64_t* num_elements) const override;
  Status Dim(int dim_index, int64_t* dim) const override;

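  // Shape inspection, as a sketch (`h` is an assumed ready TensorHandle*):
  //
  //   int num_dims = 0;
  //   TF_RETURN_IF_ERROR(h->NumDims(&num_dims));
  //   for (int i = 0; i < num_dims; ++i) {
  //     int64_t dim = 0;
  //     TF_RETURN_IF_ERROR(h->Dim(i, &dim));
  //     // ... use `dim` ...
  //   }
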
  const char* DeviceName(Status* status) const override;
  const char* BackingDeviceName(Status* status) const override;
  const char* DeviceType(Status* status) const override;
  int DeviceId(Status* status) const override;
  AbstractTensorInterface* Resolve(Status* status) override;

  ImmediateExecutionTensorHandle* Copy() override;

  // Subclasses may return true to instruct the string formatter
  // to use SummarizeValue instead of the NumPy formatter.
  bool PreferCustomSummarizer() const override {
    return dtype == DT_VARIANT || dtype == DT_RESOURCE;
  }

  // Return the Tensor from the default device.
  Status Tensor(const tensorflow::Tensor** t) const;
  // Return the Tensor from the specified device, which could be either the
  // default device or a local mirror. The device pointer should be nullptr if
  // requesting the HostCPU.
  Status TensorFromDevice(const Device* d, const tensorflow::Tensor** t) const;

  // Return the TensorValue from the specified device, which could be either
  // the default device or a local mirror. The device pointer should be nullptr
  // if requesting the HostCPU.
  Status TensorValue(const Device* d, tensorflow::TensorValue* t);

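  // Reading the underlying tensor, as a sketch (for an assumed local, ready
  // handle `h`; the returned pointer stays owned by the handle):
  //
  //   const tensorflow::Tensor* t = nullptr;
  //   TF_RETURN_IF_ERROR(h->Tensor(&t));
  //   float first = t->flat<float>()(0);
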
  Device* device() const { return device_; }
  Device* op_device() const { return op_device_; }
  Device* resource_device() const { return resource_device_; }
  int64_t resource_remote_device_incarnation() const {
    return resource_remote_device_incarnation_;
  }

  // If the devices are unknown at creation time, block until the actual
  // devices are set (data is ready).
  Status WaitUnknownDevice() const;

  Device* DeviceOrHostCPU(const EagerContext& ctx) const;

  Status Shape(tensorflow::TensorShape* shape);

  Status Unprotect(const Device* d);

  // Checks if a mirror tensor exists for the specified device. Mirrors are
  // only maintained for local devices, like CPUs & GPUs. Note that a mirror
  // may be empty, as it has yet to be set by an async operation.
  bool HasLocalMirror(const Device* d) const;
  // Add an empty mirror placeholder for the specified device. The expectation
  // is that this will be populated by a call to SetTensor.
  Status AddEmptyLocalMirror(const Device* d);
  // Add a local mirror. This will fail if an empty local mirror was previously
  // added. In that case, SetTensor should be used instead.
  Status AddLocalMirror(tensorflow::Tensor&& tensor, const Device* d);

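  // Typical async mirror flow, sketched (hypothetical device `d` and tensor
  // `t`; the empty placeholder is later populated by SetTensor):
  //
  //   TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d));
  //   // ... an async copy materializes `t` on `d` ...
  //   TF_RETURN_IF_ERROR(h->SetTensor(std::move(t), d));
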
#if !defined(IS_MOBILE_PLATFORM)
  bool HasRemoteMirror(const Device* d, uint64 context_view_id) const;
  bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const;

  Status AddUnshapedRemoteMirror(const Device* d, int64_t op_id, int output_num,
                                 const string& remote_task, EagerContext* ctx);
  Status AddResourceShapeMirror(const Device* d, int64_t op_id, int output_num,
                                EagerContext* ctx);

  // Return the op_id and output num if the handle refers to a remote tensor.
  // If wait_until_ready is true, block until the remote tensor is ready on the
  // given remote worker.
  Status RemoteAddress(const Device* d, const bool wait_until_ready,
                       int64_t* op_id, int32* output_num) const;

  // Called on an async remote tensor once its shape has been determined. This
  // transitions the tensor handle from a non-ready to a ready state by
  // replacing the backing data abstraction to allow for the shape to be
  // queried.
  // This method or Poison must be called exactly once for remote tensors that
  // were created without a known shape (e.g. a remote output of a remote
  // function).
  Status SetRemoteShape(const TensorShape& shape, const Device* d,
                        uint64 context_view_id);
  // If op_device is not empty, reset the devices of a remote tensor which is
  // created without known devices (e.g. function outputs).
  Status SetRemoteShapeAndDevice(const TensorShape& shape, const Device* d,
                                 uint64 context_view_id, string op_device);

  // Poisons either this handle or a remote mirror with error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the remote shape will return this error `status`.
  // Exactly one of the SetRemoteShape or PoisonRemote methods must be called
  // on an unshaped handle on a remote device.
  void PoisonRemote(Status status, const Device* d, uint64 context_view_id);
#endif

  // Sets the `tensor` for this async non-ready handle, making it ready.
  // This method or Poison must be called exactly once for non-ready async
  // handles to make them ready.
  Status SetTensor(tensorflow::Tensor&& tensor, const Device* d);

  // Poisons either this handle or a local mirror with error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the actual tensor or shape will return this error `status`.
  // Exactly one of the SetTensor or Poison methods must be called on a
  // non-ready tensor for a specific device.
  void Poison(Status status, const Device* d);

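  // The ready-or-poisoned protocol for an async handle, sketched (exactly one
  // of the two calls below is made, depending on whether the producing op
  // succeeded; `h`, `t`, `d`, and status `s` are assumed from context):
  //
  //   if (s.ok()) {
  //     TF_RETURN_IF_ERROR(h->SetTensor(std::move(t), d));
  //   } else {
  //     h->Poison(s, d);
  //   }
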
  // TODO(b/154282629): Consider moving it to EagerContext.
  // Copies the tensor to the given device `d`, or to the host iff `d` is null.
  Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* d,
                      tensorflow::Tensor* output) const;

  Status InferenceShape(
      shape_inference::InferenceContext* const inference_context,
      shape_inference::ShapeHandle* shape_handle);
  void SetInferenceShape(
      shape_inference::InferenceContext* const inference_context,
      const shape_inference::ShapeHandle& shape_handle);
  Status CopyInferenceShape(TensorHandle* other);

  // dtype for the handle. It must be the same as t.dtype() once the handle is
  // ready.
  const tensorflow::DataType dtype;

  enum HandleType { LOCAL = 0, PACKED = 1, REMOTE = 2 };

  HandleType Type() const;
  string TypeString() const;

  void SetResourceHandleDtypeAndShape(
      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);

  // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
  // return data types and shapes of the underlying resource.
  Status GetResourceHandleDtypesAndShapes(
      std::vector<DtypeAndPartialTensorShape>* result);

  // Returns the number of packed handles, or 0 if the handle type is not
  // PACKED.
  int NumPackedHandles() const;
  // Must be called on a packed TensorHandle: extracts the handle at the given
  // index.
  Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
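
  // Iterating a packed handle, as a sketch (only meaningful when Type() is
  // PACKED; `packed` is an assumed packed TensorHandle*):
  //
  //   for (int i = 0; i < packed->NumPackedHandles(); ++i) {
  //     TensorHandle* component = nullptr;
  //     TF_RETURN_IF_ERROR(packed->ExtractPackedHandle(i, &component));
  //     // ... inspect `component` ...
  //   }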

  // For LLVM style RTTI.
  static bool classof(const AbstractTensorHandle* ptr) {
    return ptr->getKind() == kEager;
  }

 private:
  friend class PackedTensorHandleTest;

  TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
               const tensorflow::DataType dtype,
               const tensorflow::TensorShape& shape, EagerContext* ctx);

  ~TensorHandle() override;

  // The TensorHandleData can either represent a local or remote tensor handle.
  // Further, it can be in a non-ready state. It would become ready with a call
  // to either SetTensor or SetRemoteShape which replaces the underlying data
  // with a ready version of the tensor handle data.
  bool IsReady() const;
  Status WaitReady(const char* caller) const;

  tensorflow::Device* device_;

  // Device on which the op producing this tensor was executed. Equals
  // device_ for constant tensors.
  // Can be nullptr if the op producing this tensor was a function executed
  // with the function library runtime.
  tensorflow::Device* op_device_;

  // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
  // backing the resource. Otherwise resource_device_ is nullptr.
  tensorflow::Device* resource_device_;
  // Incarnation ID of the resource device if it is located on a remote device,
  // or 0 if it is located on a local device.
  int64_t resource_remote_device_incarnation_;

  // If true, the handle refers to a remote tensor which is created without
  // known devices. The actual devices are set by SetRemoteShape. The devices
  // should be accessed once the handle is ready.
  const bool unknown_device_ = false;

  mutable mutex mu_;

  // Map of local mirrors. This can include both ready and non-ready mirrors.
  std::unordered_map<const tensorflow::Device*, LocalTensorHandleData>
      local_mirrors_ TF_GUARDED_BY(mu_);
#if !defined(IS_MOBILE_PLATFORM)
  // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
  // variable is ready, since we could get the shape locally without remote copy
  // then.
  std::unordered_map<string, RemoteTensorHandleData> resource_shape_mirrors_
      TF_GUARDED_BY(mu_);
  // TODO(gjn): Is std::unordered_map the optimal choice here? Perhaps this
  // should be a fixed-size map.
  std::unordered_map<string, RemoteTensorHandleData> remote_mirrors_
      TF_GUARDED_BY(mu_);
#endif

  // `ctx` is only guaranteed to be set if the handle is not "ready". This is
  // typically true when the handle was produced during async execution.
  // The `ctx` object is not owned and should outlive this handle.
  //
  // TODO(b/150614042): Reference count EagerContext to ensure that 'device_' of
  // a TensorHandle does not outlive the EagerContext from which it came?
  EagerContext* const ctx_;

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, is_poisoned_ is immutable.
  Status is_poisoned_;

  // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
  // refers to a remote resource handle, we store data types and shapes for
  // the underlying resource.
  std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;

  // Handle data which refers to multiple TensorHandles of the same dtype and
  // shape.
  class PackedTensorHandleData {
   public:
    // Initialize handle data from a list of tensor handles.
    // Ownership of the tensor handles is shared between the
    // `PackedTensorHandleData` and the caller (the reference count for the
    // given handles is incremented).
    // TODO(b/170414377): Use `TensorHandlePtr` instead.
    PackedTensorHandleData(std::vector<TensorHandle*>&& handles,
                           const TensorShape& shape);

    ~PackedTensorHandleData();

    Status Shape(TensorShape* shape) const;
    Status NumDims(int* num_dims) const;
    Status Dim(int dim_index, int64_t* dim) const;
    Status NumElements(int64_t* num_elements) const;
    Status Unprotect();
    bool IsReady() const;
    Status WaitReady(const char* caller) const;
    void Poison(Status status);
    string DebugString() const;

    // Number of packed handles.
    int NumPackedHandles() const;
    // Extract the handle at the given index.
    Status ExtractPackedHandle(const int index, TensorHandle** handle) const;

   private:
    // TODO(b/170414377): Use `TensorHandlePtr` instead.
    const std::vector<TensorHandle*> handles_;
    const TensorShape shape_;

    mutable mutex mu_;
    Status is_poisoned_ TF_GUARDED_BY(mu_);
  };

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, data_ is immutable.
#if !defined(IS_MOBILE_PLATFORM)
  absl::variant<LocalTensorHandleData, PackedTensorHandleData,
                RemoteTensorHandleData>
      data_;
#else
  absl::variant<LocalTensorHandleData, PackedTensorHandleData> data_;
#endif

  PartialTensorShape inference_shape_;
};

// Returns the device backing the resource; otherwise returns nullptr.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);

class TensorHandleInterface : public ImmediateExecutionTensorHandle {
 public:
};

template <typename T>
inline TensorHandle* TensorHandleFromInterface(T* handle) {
  return down_cast<TensorHandle*>(handle);
}
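
// For example, as an illustrative sketch (`handle` is an assumed
// ImmediateExecutionTensorHandle* already known to wrap a TensorHandle,
// e.g. checked via TensorHandle::classof):
//
//   TensorHandle* th = TensorHandleFromInterface(handle);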

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_