/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace parallel_device {

// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
 public:
  void operator()(TFE_TensorHandle* to_delete) const {
    TFE_DeleteTensorHandle(to_delete);
  }
};

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
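
// Illustrative sketch of TensorHandlePtr (not prescriptive; `tensor` is a
// hypothetical valid TF_Tensor* and `status` a TF_Status*): wrapping a freshly
// created handle so it is deleted automatically when the pointer goes out of
// scope.
//
//   TensorHandlePtr handle(TFE_NewTensorHandle(tensor, status));
//   if (TF_GetCode(status) != TF_OK) { /* handle the error */ }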

class ParallelTensor;
class DeviceThread;

// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
class ParallelDevice {
 public:
  // Eager async execution is only supported when remote eager is not in use
  // (b/157523095).
  explicit ParallelDevice(const std::vector<std::string>& devices,
                          const bool is_async = false);

  ~ParallelDevice();

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
  // Sets a bad status and returns a nullptr if `tensor` is already on the
  // ParallelDevice, or if the individual copies fail.
  std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
                                                       TFE_TensorHandle* tensor,
                                                       TF_Status* status) const;
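
  // Illustrative sketch of CopyToParallelDevice (`context`, `cpu_tensor`, and
  // `status` here are hypothetical, assumed to be valid):
  //
  //   std::unique_ptr<ParallelTensor> broadcast =
  //       parallel_device.CopyToParallelDevice(context, cpu_tensor, status);
  //   if (TF_GetCode(status) != TF_OK) { /* copy failed */ }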

  // Construct a parallel tensor consisting of the scalar values from `values`.
  template <typename DataType>
  std::unique_ptr<ParallelTensor> ScalarsFromSequence(
      absl::Span<const DataType> values, TFE_Context* context,
      TF_Status* status) const;

  // A parallel tensor with scalar integers numbering component devices.
  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
                                            TF_Status* status) const;

  // The number of devices operations run on.
  size_t num_underlying_devices() const { return underlying_devices_.size(); }

  // The devices operations run on.
  const std::vector<std::string>& underlying_devices() const {
    return underlying_devices_;
  }

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors. Wraps the
  // resulting per-device and per-output TFE_TensorHandles into one
  // ParallelTensor per output of the original operation.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
  // The returned optional has a value if and only if `status` evaluates to
  // TF_OK. A bad status is forwarded from the underlying `TFE_Execute` calls,
  // or set if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
      TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;
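
  // Illustrative sketch of Execute (`x` and `y` here are hypothetical
  // std::unique_ptr<ParallelTensor> inputs already on this device; `context`,
  // `attributes`, and `status` are assumed to be valid):
  //
  //   std::vector<ParallelTensor*> inputs{x.get(), y.get()};
  //   auto outputs = parallel_device.Execute(context, inputs, "AddV2",
  //                                          attributes,
  //                                          /*expected_max_outputs=*/1,
  //                                          status);
  //   if (TF_GetCode(status) == TF_OK) {
  //     ParallelTensor* sum = (*outputs)[0].get();
  //   }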

  // A non-blocking version of `Execute`. After each call, `Join` must be
  // called before `StartExecute` is called again. Using `StartExecute` with
  // `Join` allows the caller to schedule computation on multiple
  // ParallelDevices without sequencing those operations (first call
  // `StartExecute` on each parallel device, then call `Join` on each; even if
  // some of the `Join`s return a bad status the caller must run all of the
  // `Join`s or any future `StartExecute`s will deadlock). See the sketch after
  // `Join` below.
  //
  // If `is_async=false` (constructor argument), `cancellation_manager` must
  // live until `Join` finishes. If `is_async=true` it must live until `Join`
  // is followed by `TFE_ContextAsyncWait` to clear pending operations. It will
  // be used to cancel all other operations if any fails.
  //
  // Set `step_id` to configure the step ID used for rendezvous creation. A
  // step ID of -1 is reserved for the global rendezvous and should not be
  // passed here.
  void StartExecute(TFE_Context* context,
                    const std::vector<ParallelTensor*>& inputs,
                    const char* operation_name, const TFE_OpAttrs* attributes,
                    int expected_max_outputs,
                    CancellationManager& cancellation_manager,
                    absl::optional<int64_t> step_id = absl::nullopt) const;

  // Blocks until the previous `StartExecute` has run `TFE_Execute` on each
  // device. If is_async=false (constructor argument) this means the ops have
  // run and have results. If is_async=true it means that all of the
  // device-specific executors have scheduled the op.
  //
  // Accepts inferred shapes for outputs (`expected_output_shapes`), which if
  // fully defined will avoid querying the shapes of the underlying
  // TensorHandles when ParallelTensor::Shape is called. This allows async
  // computation to continue without blocking.
  //
  // The return status and value are the same as for `Execute`.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Join(
      const std::vector<PartialTensorShape>& expected_output_shapes,
      TF_Status* status) const;
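
  // Illustrative sketch of the StartExecute/Join pattern across two
  // ParallelDevices (`first_device`, `second_device`, their inputs, and
  // `attributes` are hypothetical; `context` and `status` are assumed valid):
  //
  //   CancellationManager cancellation_manager;
  //   first_device.StartExecute(context, first_inputs, "AddV2", attributes,
  //                             /*expected_max_outputs=*/1,
  //                             cancellation_manager);
  //   second_device.StartExecute(context, second_inputs, "AddV2", attributes,
  //                              /*expected_max_outputs=*/1,
  //                              cancellation_manager);
  //   // Every StartExecute must be matched with a Join, even after errors.
  //   auto first_outputs = first_device.Join(
  //       /*expected_output_shapes=*/{PartialTensorShape()}, status);
  //   auto second_outputs = second_device.Join(
  //       /*expected_output_shapes=*/{PartialTensorShape()}, status);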

  void AsyncWait(TFE_Context* context, TF_Status* status) const;

  // Device strings for component devices that only include a
  // worker/task/replica if any of those differ across components. Useful for
  // printing debug messages.
  std::vector<std::string> SummarizeDeviceNames() const;

 private:
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
  // A sequence of thread wrappers, one per device, for executing operations in
  // parallel.
  //
  // Conceptually this is a thread pool with one thread per device. It requires
  // less synchronization than a thread pool would for this task, since Execute
  // acquires each thread in order (and so only one Execute will schedule
  // blocking collective operations at a time), and avoids some dynamic
  // allocation/scheduling.
  //
  // TODO(allenl): Keep a map from outer thread to list of inner threads rather
  // than a single list of threads so aliased nested parallel devices don't
  // re-use a thread.
  std::vector<std::unique_ptr<DeviceThread>> device_threads_;
  // A cancellation manager to use if the caller does not provide one. When ops
  // are executed asynchronously this must outlive the queued op, so it can't
  // be function-local to Execute.
  mutable std::unique_ptr<CancellationManager> default_cancellation_manager_;
};

// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
// ParallelDevice.
class ParallelTensor {
 public:
  // Construct a ParallelTensor from TensorHandles placed on the component
  // devices of a ParallelDevice. If called, ParallelTensor::Shape inspects
  // `components` to determine a shape.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, TF_Status* status);
  // Uses the provided shape without additional checks, which avoids blocking
  // when ParallelTensor::Shape is called.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, absl::Span<const int64_t> shape,
      TF_Status* status);

  size_t num_tensors() const { return tensors_.size(); }
  TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }

  // If the `shape` argument to `FromTensorHandles` is specified, returns that.
  //
  // Otherwise if all of the tensors have the same shape, returns that via the
  // `shape` output argument. This blocks waiting for async tensors, may return
  // a delayed bad status encountered during async execution, and will return a
  // bad status unless all tensors have the same shape.
  Status Shape(const std::vector<int64_t>** shape) const;
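
  // Illustrative sketch of Shape (the pointed-to vector remains owned by this
  // ParallelTensor):
  //
  //   const std::vector<int64_t>* shape = nullptr;
  //   Status s = parallel_tensor.Shape(&shape);
  //   if (s.ok()) { int64_t rank = shape->size(); }
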
  TF_DataType dtype() const { return dtype_; }

  // Sets its output argument to a summary of the values of this tensor on
  // every component device.
  Status SummarizeValue(std::string& summary);

 private:
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors,
                 absl::Span<const int64_t> shape, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(std::vector<int64_t>(shape.begin(), shape.end())),
        dtype_(dtype) {}
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(absl::nullopt),
        dtype_(dtype) {}

  const ParallelDevice& device_;
  const std::vector<TensorHandlePtr> tensors_;
  // Parallel tensors are immutable but compute their shape lazily unless it is
  // provided on construction. The optional has a value if the lazy computation
  // has been completed or the shape was provided on construction.
  mutable absl::optional<std::vector<int64_t>> shape_;
  const TF_DataType dtype_;
};

template <typename DataType>
std::unique_ptr<ParallelTensor> ParallelDevice::ScalarsFromSequence(
    absl::Span<DataType const> values, TFE_Context* context,
    TF_Status* status) const {
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());

  if (values.size() != num_underlying_devices()) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        "Number of values did not match number of underlying devices.");
    return nullptr;
  }
  TF_DataType datatype_enum(
      static_cast<TF_DataType>(DataTypeToEnum<DataType>().value));
  for (int device_index = 0; device_index < num_underlying_devices();
       ++device_index) {
    auto device_value = absl::make_unique<DataType>();
    *device_value = values[device_index];
    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
        TF_NewTensor(
            datatype_enum, /*dims=*/nullptr, /*num_dims=*/0,
            device_value.release(), sizeof(DataType),
            [](void* data, size_t, void* arg) {
              delete reinterpret_cast<DataType*>(data);
            },
            nullptr),
        TF_DeleteTensor);
    // TODO(allenl): Here and when executing regular operations, we could hold
    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
    // device names repeatedly.
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> const_op(
        TFE_NewOp(context, "Const", status), TFE_DeleteOp);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
                    status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrType(const_op.get(), "dtype", datatype_enum);
    TFE_TensorHandle* device_handle;
    int num_outputs = 1;
    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(device_handle);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}
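
// Illustrative sketch of ScalarsFromSequence (assumes a hypothetical
// `parallel_device` wrapping two devices, plus valid `context` and `status`;
// one scalar value is required per underlying device):
//
//   std::vector<float> values{0.0f, 1.0f};
//   std::unique_ptr<ParallelTensor> scalars =
//       parallel_device.ScalarsFromSequence<float>(values, context, status);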

}  // namespace parallel_device
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_