/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device.h"

#include <cstring>
#include <memory>

#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"

namespace tensorflow {
namespace parallel_device {
namespace {

class OpDeleter {
 public:
  void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

using MaybeParallelTensorOwned =
    absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;

using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;

// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
// placed on the parallel device.
class NamedParallelDevice {
 public:
  NamedParallelDevice(const std::string& name,
                      std::unique_ptr<ParallelDevice> parallel_device)
      : device_name_(name), parallel_device_(std::move(parallel_device)) {}
  const std::string& name() const { return device_name_; }
  const ParallelDevice& device() const { return *parallel_device_; }

 private:
  std::string device_name_;
  std::unique_ptr<ParallelDevice> parallel_device_;
};
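
// A minimal construction sketch (the device names here are illustrative only,
// not taken from this file): the wrapper simply pairs a registered device name
// with the ParallelDevice backing it.
//
//   std::unique_ptr<ParallelDevice> underlying(new ParallelDevice(
//       {"/job:localhost/replica:0/task:0/device:CPU:0",
//        "/job:localhost/replica:0/task:0/device:CPU:1"}));
//   NamedParallelDevice named_device(
//       "/job:localhost/replica:0/task:0/device:CUSTOM:0",
//       std::move(underlying));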

absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
    const ParallelDevice& parallel_device,
    const std::string& parallel_device_name, TFE_Context* context,
    std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
    const TFE_OpAttrs* attributes, int expected_max_outputs,
    TF_Status* status) {
  absl::optional<std::vector<MaybeParallelTensorOwned>> result;
  // TODO(allenl): We should remove "TPU" from these op names at the very
  // least, or consider other ways of packing/unpacking parallel tensors.
  if (operation_name == std::string("TPUReplicatedInput")) {
    // Special-cased operation for packing per-device tensors into one parallel
    // tensor.
    if (inputs.size() != parallel_device.num_underlying_devices()) {
      std::string message(absl::StrCat(
          "The parallel device ", parallel_device_name, " expected ",
          parallel_device.num_underlying_devices(),
          " inputs to TPUReplicatedInput, but got ", inputs.size()));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }
    std::vector<TensorHandlePtr> components;
    components.reserve(inputs.size());
    for (int i = 0; i < inputs.size(); ++i) {
      if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
        std::string message(absl::StrCat(
            "Expected all inputs to TPUReplicatedInput to be non-parallel "
            "TensorHandles. The input ",
            i,
            " was a parallel tensor (already "
            "placed on the parallel device)."));
        TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
        return result;
      }
      components.emplace_back(TFE_TensorHandleCopySharingTensor(
          absl::get<TFE_TensorHandle*>(inputs[i]), status));
    }
    std::vector<MaybeParallelTensorOwned> result_content;
    result_content.reserve(1);
    result_content.push_back(ParallelTensor::FromTensorHandles(
        parallel_device, std::move(components), status));
    if (TF_GetCode(status) != TF_OK) return result;
    result.emplace(std::move(result_content));
    return result;
  } else if (operation_name == std::string("TPUReplicatedOutput")) {
    // Special-cased operation for un-packing one parallel tensor into
    // per-device tensors.
    OpPtr op(TFE_NewOp(context, operation_name, status));
    TFE_OpAddAttrs(op.get(), attributes);
    int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
    if (TF_GetCode(status) != TF_OK) return result;
    if (expected_outputs != parallel_device.num_underlying_devices()) {
      std::string message(absl::StrCat(
          "The parallel device ", parallel_device_name, " expected ",
          parallel_device.num_underlying_devices(),
          " outputs for TPUReplicatedOutput, but got ", expected_outputs));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }
    if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
      TF_SetStatus(status, TF_INVALID_ARGUMENT,
                   "Expected the input to "
                   "TPUReplicatedOutput to be a parallel tensor (placed on the "
                   "parallel device).");
      return result;
    }
    ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
    std::vector<MaybeParallelTensorOwned> outputs;
    outputs.reserve(t->num_tensors());
    for (int i = 0; i < t->num_tensors(); ++i) {
      TensorHandlePtr this_output(
          TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
      outputs.emplace_back(std::move(this_output));
      if (TF_GetCode(status) != TF_OK) return result;
    }
    result.emplace(std::move(outputs));
    return result;
  }
  std::vector<ParallelTensor*> parallel_inputs;
  std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
  parallel_inputs.reserve(inputs.size());
  implicitly_broadcast_tensors.reserve(inputs.size());  // not tight
  for (const auto& input : inputs) {
    if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
      if (operation_name == std::string("_EagerConst")) {
        // Non-parallel tensors from _EagerConst/tf.constant are implicitly
        // broadcast, i.e. set as the input to each parallel operation. This
        // allows code like "tf.constant(1.)" or "tf.reduce_sum(..., axis=1)"
        // (where the value starts on the host), without allowing other
        // implicit copies/broadcasts. Other implicit copies may be supported
        // eventually, but need special handling for gradients (gradient of
        // copy-on is not just copy-off but includes a sum) and consideration
        // of performance.
        //
        // TODO(allenl): There may be smarter ways to do this copy in some
        // cases, i.e. with a collective broadcast. We'll need to be careful
        // about things that are taken as inputs on the host or on their
        // existing device (for multi-device functions).
        std::unique_ptr<ParallelTensor> parallel_tensor(
            parallel_device.CopyToParallelDevice(
                context, absl::get<TFE_TensorHandle*>(input), status));
        if (TF_GetCode(status) != TF_OK) return absl::nullopt;
        parallel_inputs.push_back(parallel_tensor.get());
        implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
      } else {
        TF_SetStatus(
            status, TF_INVALID_ARGUMENT,
            absl::StrCat(
                "Got a non-parallel tensor ",
                tensorflow::unwrap(absl::get<TFE_TensorHandle*>(input))
                    ->DebugString(),
                " as input to a parallel operation. First pack non-parallel "
                "tensors for each device into a parallel tensor explicitly.")
                .c_str());
        return absl::nullopt;
      }
    } else {
      parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
    }
  }
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
      maybe_parallel_results(
          parallel_device.Execute(context, parallel_inputs, operation_name,
                                  attributes, expected_max_outputs, status));
  if (!maybe_parallel_results.has_value()) return result;
  std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
      std::move(maybe_parallel_results.value()));
  std::vector<MaybeParallelTensorOwned> result_content;
  result_content.reserve(parallel_results.size());
  for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
    result_content.push_back(
        MaybeParallelTensorOwned(std::move(parallel_result)));
  }
  result.emplace(std::move(result_content));
  return result;
}
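
// A rough sketch of how the TPUReplicatedInput special case above is driven
// through the eager C API once the parallel device is registered. The device
// name and the two-device setup are illustrative only, and error handling is
// omitted; TPUReplicatedOutput reverses the packing on the same device.
//
//   TFE_Op* pack = TFE_NewOp(context, "TPUReplicatedInput", status);
//   TFE_OpSetDevice(pack, "/job:localhost/replica:0/task:0/device:CUSTOM:0",
//                   status);
//   TFE_OpAddInputList(pack, per_device_handles, /*num_inputs=*/2, status);
//   TFE_TensorHandle* packed = nullptr;
//   int num_retvals = 1;
//   TFE_Execute(pack, &packed, &num_retvals, status);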

// Used as an argument to TFE_NewCustomDeviceTensorHandle, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
void ParallelTensorDeallocator(void* data) {
  delete reinterpret_cast<ParallelTensor*>(data);
}

// Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing the
// number of dimensions of a parallel tensor.
int ParallelTensorNumDims(void* data, TF_Status* status) {
  const std::vector<int64_t>* shape;
  Status s = reinterpret_cast<ParallelTensor*>(data)->Shape(&shape);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return -1;
  }
  return shape->size();
}

// Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing a
// dimension of a parallel tensor.
int64_t ParallelTensorDim(void* data, int dim_index, TF_Status* status) {
  const std::vector<int64_t>* shape;
  Status s = reinterpret_cast<ParallelTensor*>(data)->Shape(&shape);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return -1;
  }
  return (*shape)[dim_index];
}

// Used as an argument to TFE_NewCustomDeviceTensorHandle, for summarizing the
// value of a parallel tensor (e.g. when it is printed for debugging).
TF_Buffer* ParallelTensorSummarize(void* data, TF_Status* status) {
  ParallelTensor* parallel_tensor = reinterpret_cast<ParallelTensor*>(data);
  std::string summary;
  Status cpp_status = parallel_tensor->SummarizeValue(summary);
  if (!cpp_status.ok()) {
    Set_TF_Status_from_Status(status, cpp_status);
    return nullptr;
  }
  return TF_NewBufferFromString(summary.data(), summary.size());
}

TensorHandlePtr ParallelTensorToTensorHandle(
    const std::string& parallel_device_name, TFE_Context* context,
    std::unique_ptr<ParallelTensor> t, TF_Status* status) {
  // The resulting TensorHandle owns an opaque pointer to "device memory",
  // which for a ParallelDevice is really a ParallelTensor. When the
  // TensorHandle is deleted, it will call ParallelTensorDeallocator to free
  // the struct.
  ParallelTensor* t_released = t.release();
  TFE_CustomDeviceTensorHandleMethods handle_methods;
  handle_methods.num_dims = &ParallelTensorNumDims;
  handle_methods.dim = &ParallelTensorDim;
  handle_methods.deallocator = &ParallelTensorDeallocator;
  handle_methods.summarize = &ParallelTensorSummarize;
  return TensorHandlePtr(TFE_NewCustomDeviceTensorHandle(
      context, parallel_device_name.c_str(), t_released->dtype(), t_released,
      handle_methods, status));
}
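
// Shape queries on the wrapped handle (e.g. TFE_TensorHandleNumDims and
// TFE_TensorHandleDim) and debug summaries are answered through the
// handle_methods callbacks above, so they describe the packed ParallelTensor
// rather than any single component.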

// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
// registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// NamedParallelDevice.
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
                                       TFE_TensorHandle* tensor,
                                       TF_Status* status, void* device_info) {
  TF_SetStatus(
      status, TF_UNIMPLEMENTED,
      absl::StrCat("Trying to copy a tensor ",
                   tensorflow::unwrap(tensor)->DebugString(),
                   " on to a parallel device. Pack non-parallel "
                   "tensors for each device into a parallel tensor explicitly.")
          .c_str());
  return nullptr;
}

// For TFE_CustomDevice::copy_tensor_from_device in the parallel device
// registration.
//
// Currently this is an error, and un-packing ParallelTensors must be performed
// explicitly by running a TPUReplicatedOutput operation on the parallel
// device.
//
// TODO(allenl): There are some use-cases that are only supported by copying to
// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
// need to return something here or address these use-cases one by one.
TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
                                               TFE_TensorHandle* tensor,
                                               const char* target_device_name,
                                               TF_Status* status,
                                               void* device_info) {
  ParallelTensor* parallel_tensor = reinterpret_cast<ParallelTensor*>(
      TFE_TensorHandleDevicePointer(tensor, status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  if (parallel_tensor->num_tensors() == 1) {
    // Copy-off for single-device tensors is allowed to make debugging dynamic
    // control flow easier.
    return TFE_TensorHandleCopySharingTensor(parallel_tensor->tensor(0),
                                             status);
  } else {
    TF_SetStatus(
        status, TF_UNIMPLEMENTED,
        absl::StrCat(
            "Trying to copy a tensor out of a parallel device. Since there "
            "are multiple components to parallel tensors, they must be "
            "unpacked explicitly.\n",
            tensorflow::unwrap(tensor)->DebugString())
            .c_str());
    return nullptr;
  }
}

// For TFE_CustomDevice::execute in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// NamedParallelDevice.
void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
                           TFE_TensorHandle** outputs, TF_Status* status,
                           void* device_info) {
  const char* requested_placement = TFE_OpGetDevice(original_op, status);
  if (*requested_placement == '\0') {
    TF_SetStatus(
        status, TF_INTERNAL,
        "Ops must be placed on the parallel device explicitly, or their inputs "
        "first un-packed. Got an un-placed op with an input placed on the "
        "parallel device.");
    return;
  }
  TFE_Context* context = TFE_OpGetContext(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  const char* operation_name = TFE_OpGetName(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);

  NamedParallelDevice* named_device =
      reinterpret_cast<NamedParallelDevice*>(device_info);
  std::vector<MaybeParallelTensorUnowned> typed_inputs;
  int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  typed_inputs.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
    if (TF_GetCode(status) != TF_OK) return;
    const char* tensor_handle_device =
        TFE_TensorHandleDeviceName(input, status);
    if (TF_GetCode(status) != TF_OK) return;
    if (named_device->name() == tensor_handle_device) {
      // We assume that any tensors already placed on this device are
      // ParallelTensors.
      typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
          TFE_TensorHandleDevicePointer(input, status)));
      if (TF_GetCode(status) != TF_OK) return;
    } else {
      typed_inputs.emplace_back(input);
    }
  }

  absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
      ExecuteWithSpecialOps(named_device->device(), named_device->name(),
                            context, std::move(typed_inputs), operation_name,
                            attributes, *num_outputs, status));
  if (TF_GetCode(status) != TF_OK) return;
  if (!maybe_typed_outputs.has_value()) {
    TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
    return;
  }

  std::vector<MaybeParallelTensorOwned> typed_outputs(
      std::move(maybe_typed_outputs.value()));

  if (typed_outputs.size() > *num_outputs) {
    TF_SetStatus(status, TF_INTERNAL,
                 "The allocated output buffer was too small.");
    return;
  }

  for (int i = 0; i < typed_outputs.size(); ++i) {
    MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i]));
    if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
      outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
    } else {
      outputs[i] = ParallelTensorToTensorHandle(
                       named_device->name(), context,
                       std::move(absl::get<std::unique_ptr<ParallelTensor>>(
                           typed_output)),
                       status)
                       .release();
      if (TF_GetCode(status) != TF_OK) return;
    }
  }
  *num_outputs = typed_outputs.size();
}

// For TFE_CustomDevice::delete_device in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// NamedParallelDevice.
void DeleteParallelDevice(void* device_info) {
  delete reinterpret_cast<NamedParallelDevice*>(device_info);
}

}  // namespace

void AllocateParallelDevice(const char* device_name,
                            const char* const* underlying_devices,
                            int num_underlying_devices,
                            TFE_CustomDevice* device, void** device_info) {
  device->copy_tensor_to_device = &CopyToParallelDevice;
  device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
  device->delete_device = &DeleteParallelDevice;
  device->execute = &ParallelDeviceExecute;
  std::vector<std::string> underlying_devices_vector;
  underlying_devices_vector.reserve(num_underlying_devices);
  for (int device_index = 0; device_index < num_underlying_devices;
       ++device_index) {
    underlying_devices_vector.push_back(underlying_devices[device_index]);
  }
  std::unique_ptr<ParallelDevice> parallel_device(
      new ParallelDevice(underlying_devices_vector));
  *device_info =
      new NamedParallelDevice{device_name, std::move(parallel_device)};
}
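
// A minimal registration sketch (names are illustrative; error handling
// omitted). AllocateParallelDevice fills in the TFE_CustomDevice callbacks and
// device_info, which the caller then hands to the eager runtime:
//
//   TFE_CustomDevice device;
//   void* device_info;
//   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
//   const char* underlying[] = {
//       "/job:localhost/replica:0/task:0/device:CPU:0",
//       "/job:localhost/replica:0/task:0/device:CPU:1"};
//   AllocateParallelDevice(name, underlying, 2, &device, &device_info);
//   TFE_RegisterCustomDevice(context, device, name, device_info, status);
//
// Once registered, the runtime owns device_info and releases it via the
// delete_device callback (DeleteParallelDevice above).
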
}  // namespace parallel_device
}  // namespace tensorflow