/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace parallel_device {

// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
 public:
  void operator()(TFE_TensorHandle* to_delete) const {
    TFE_DeleteTensorHandle(to_delete);
  }
};

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
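
// Example usage (illustrative sketch, not part of this header's API): a
// TensorHandlePtr can wrap any TFE_TensorHandle so it is deleted
// automatically. `tensor` and `status` are assumed to exist; they are not
// defined in this header.
//
//   TensorHandlePtr handle(TFE_NewTensorHandle(tensor, status));
//   // TFE_DeleteTensorHandle runs when `handle` goes out of scope.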

class ParallelTensor;
class DeviceThread;

// Forwards operations to `devices`, maintaining ParallelTensors with
// components placed on each underlying device.
class ParallelDevice {
 public:
  // Eager async execution is only supported when remote eager is not in use
  // (b/157523095).
  explicit ParallelDevice(const std::vector<std::string>& devices,
                          const bool is_async = false);
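
  // Example (illustrative sketch): constructing a ParallelDevice over two
  // local CPU devices. The device names are assumptions; real names depend
  // on the devices visible to the TFE_Context.
  //
  //   std::vector<std::string> devices{
  //       "/job:localhost/replica:0/task:0/device:CPU:0",
  //       "/job:localhost/replica:0/task:0/device:CPU:1"};
  //   ParallelDevice parallel_device(devices);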

  ~ParallelDevice();

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
  // Sets a bad status and returns a nullptr if `tensor` is already on the
  // ParallelDevice, or if the individual copies fail.
  std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
                                                       TFE_TensorHandle* tensor,
                                                       TF_Status* status) const;
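
  // Example (illustrative sketch): broadcasting a single host tensor handle
  // to every component device. `context`, `cpu_handle`, and `status` are
  // assumed to exist and are not defined here.
  //
  //   std::unique_ptr<ParallelTensor> broadcast =
  //       parallel_device.CopyToParallelDevice(context, cpu_handle, status);
  //   if (TF_GetCode(status) != TF_OK) { /* handle the error */ }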

  // Construct a parallel tensor consisting of the scalar values from `values`.
  template <typename DataType>
  std::unique_ptr<ParallelTensor> ScalarsFromSequence(
      absl::Span<const DataType> values, TFE_Context* context,
      TF_Status* status) const;

  // A parallel tensor with scalar integers numbering component devices.
  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
                                            TF_Status* status) const;
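
  // Example (illustrative sketch): one scalar per component device, plus the
  // built-in device-id tensor. `context` and `status` are assumed to exist;
  // the values 3 and 5 are arbitrary and assume two underlying devices.
  //
  //   auto scalars = parallel_device.ScalarsFromSequence<int32_t>(
  //       {3, 5}, context, status);
  //   auto ids = parallel_device.DeviceIDs(context, status);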

  // The number of devices operations run on.
  size_t num_underlying_devices() const { return underlying_devices_.size(); }

  // The devices operations run on.
  const std::vector<std::string>& underlying_devices() const {
    return underlying_devices_;
  }

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors. Wraps the
  // resulting per-device and per-output TFE_TensorHandles into one
  // ParallelTensor per output of the original operation.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
  // The returned optional has a value if and only if `status` evaluates to
  // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
  // set if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
      TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;
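
  // Example (illustrative sketch): running one op per component device.
  // `context`, `status`, `attributes` (e.g. from TFE_OpGetAttrs), and the
  // ParallelTensors `x` and `y` are assumed to exist; "AddV2" is only an
  // illustrative op name.
  //
  //   std::vector<ParallelTensor*> inputs{x.get(), y.get()};
  //   auto outputs = parallel_device.Execute(
  //       context, inputs, "AddV2", attributes,
  //       /*expected_max_outputs=*/1, status);
  //   if (outputs.has_value()) {
  //     ParallelTensor* sum = (*outputs)[0].get();
  //   }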

  // A non-blocking version of `Execute`. After each call, `Join` must be called
  // before `StartExecute` is called again. Using `StartExecute` with `Join`
  // allows the caller to schedule computation on multiple ParallelDevices
  // without sequencing those operations (first call `StartExecute` on each
  // parallel device, then call `Join` on each; even if some of the `Join`s
  // return a bad status the caller must run all of the `Join`s or any future
  // `StartExecute`s will deadlock).
  //
  // If `is_async=false` (constructor argument), `cancellation_manager` must
  // live until `Join` finishes. If `is_async=true` it must live until `Join` is
  // followed by `TFE_ContextAsyncWait` to clear pending operations. It will be
  // used to cancel all other operations if any fails.
  //
  // Set `step_id` to configure the step id used for rendezvous creation. A
  // step id of -1 is reserved for the global rendezvous and should not be set
  // here.
  void StartExecute(TFE_Context* context,
                    const std::vector<ParallelTensor*>& inputs,
                    const char* operation_name, const TFE_OpAttrs* attributes,
                    int expected_max_outputs,
                    CancellationManager& cancellation_manager,
                    absl::optional<int64_t> step_id = absl::nullopt) const;

  // Blocks until the previous `StartExecute` has run `TFE_Execute` on each
  // device. If is_async=false (constructor argument) this means the ops have
  // run and have results. If is_async=true it means that all of the
  // device-specific executors have scheduled the op.
  //
  // Accepts inferred shapes for outputs (`expected_output_shapes`), which if
  // fully defined will avoid querying the shapes of the underlying
  // TensorHandles when ParallelTensor::Shape is called. This allows async
  // computation to continue without blocking.
  //
  // The return status and value are the same as for `Execute`.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Join(
      const std::vector<PartialTensorShape>& expected_output_shapes,
      TF_Status* status) const;
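
  // Example (illustrative sketch): the non-blocking pattern across two
  // ParallelDevices. `context`, `status`, `attributes`, the input vectors,
  // and both devices are assumed to exist; unknown output shapes are passed
  // so that Join may query them lazily.
  //
  //   CancellationManager cancellation_manager;
  //   first_device.StartExecute(context, first_inputs, "AddV2", attributes,
  //                             /*expected_max_outputs=*/1,
  //                             cancellation_manager);
  //   second_device.StartExecute(context, second_inputs, "AddV2", attributes,
  //                              /*expected_max_outputs=*/1,
  //                              cancellation_manager);
  //   // Always run both Joins, even if the first reports an error.
  //   auto first_outputs =
  //       first_device.Join({PartialTensorShape()}, status);
  //   auto second_outputs =
  //       second_device.Join({PartialTensorShape()}, status);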

  void AsyncWait(TFE_Context* context, TF_Status* status) const;

  // Device strings for component devices that only include a
  // worker/task/replica if any of those differ across components. Useful for
  // printing debug messages.
  std::vector<std::string> SummarizeDeviceNames() const;

 private:
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
  // A sequence of thread wrappers, one per device, for executing operations in
  // parallel.
  //
  // Conceptually this is a thread pool with one thread per device. It requires
  // less synchronization than a thread pool would for this task, since Execute
  // acquires each thread in order (and so only one Execute will schedule
  // blocking collective operations at a time), and avoids some dynamic
  // allocation/scheduling.
  //
  // TODO(allenl): Keep a map from outer thread to list of inner threads rather
  // than a single list of threads so aliased nested parallel devices don't
  // re-use a thread.
  std::vector<std::unique_ptr<DeviceThread>> device_threads_;
  // A cancellation manager to use if the caller does not provide one. When ops
  // are executed asynchronously this must outlive the queued op, so it can't be
  // function-local to Execute.
  mutable std::unique_ptr<CancellationManager> default_cancellation_manager_;
};

// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
// ParallelDevice.
class ParallelTensor {
 public:
  // Construct a ParallelTensor from TensorHandles placed on the component
  // devices of a ParallelDevice. If called, ParallelTensor::Shape inspects
  // `components` to determine a shape.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, TF_Status* status);
  // Uses the provided shape without additional checks, which avoids blocking
  // when ParallelTensor::Shape is called.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, absl::Span<const int64_t> shape,
      TF_Status* status);

  size_t num_tensors() const { return tensors_.size(); }
  TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }

  // If the `shape` argument to `FromTensorHandles` is specified, returns that.
  //
  // Otherwise if all of the tensors have the same shape, returns that via the
  // `shape` output argument. This blocks waiting for async tensors, may return
  // a delayed bad status encountered during async execution, and will return a
  // bad status unless all tensors have the same shape.
  Status Shape(const std::vector<int64_t>** shape) const;
  TF_DataType dtype() const { return dtype_; }

  // Sets its output argument to a summary of the values of this tensor on
  // every component device.
  Status SummarizeValue(std::string& summary);
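
  // Example (illustrative sketch): inspecting a ParallelTensor produced by
  // Execute/Join. `sum` is assumed to be a ParallelTensor*; status handling
  // is omitted for brevity.
  //
  //   const std::vector<int64_t>* shape;
  //   Status shape_status = sum->Shape(&shape);
  //   std::string summary;
  //   Status summary_status = sum->SummarizeValue(summary);
  //   TFE_TensorHandle* first_component = sum->tensor(0);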

 private:
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors,
                 absl::Span<const int64_t> shape, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(std::vector<int64_t>(shape.begin(), shape.end())),
        dtype_(dtype) {}
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(absl::nullopt),
        dtype_(dtype) {}

  const ParallelDevice& device_;
  const std::vector<TensorHandlePtr> tensors_;
  // Parallel tensors are immutable but compute their shape lazily unless it is
  // provided on construction. The optional has a value if the lazy computation
  // has been completed or the shape was provided on construction.
  mutable absl::optional<std::vector<int64_t>> shape_;
  const TF_DataType dtype_;
};

template <typename DataType>
std::unique_ptr<ParallelTensor> ParallelDevice::ScalarsFromSequence(
    absl::Span<DataType const> values, TFE_Context* context,
    TF_Status* status) const {
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());

  if (values.size() != num_underlying_devices()) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        "Number of values did not match number of underlying devices.");
    return nullptr;
  }
  TF_DataType datatype_enum(
      static_cast<TF_DataType>(DataTypeToEnum<DataType>().value));
  for (int device_index = 0; device_index < num_underlying_devices();
       ++device_index) {
    auto device_value = absl::make_unique<DataType>();
    *device_value = values[device_index];
    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
        TF_NewTensor(
            datatype_enum, /*dims=*/nullptr, /*num_dims=*/0,
            device_value.release(), sizeof(DataType),
            [](void* data, size_t, void* arg) {
              delete reinterpret_cast<DataType*>(data);
            },
            nullptr),
        TF_DeleteTensor);
    // TODO(allenl): Here and when executing regular operations, we could hold
    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
    // device names repeatedly.
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> const_op(
        TFE_NewOp(context, "Const", status), TFE_DeleteOp);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
                    status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrType(const_op.get(), "dtype", datatype_enum);
    TFE_TensorHandle* device_handle;
    int num_outputs = 1;
    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(device_handle);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

}  // namespace parallel_device
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_