xref: /aosp_15_r20/external/tensorflow/tensorflow/c/eager/parallel_device/parallel_device.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/parallel_device/parallel_device.h"
17 
18 #include <cstring>
19 #include <memory>
20 
21 #include "absl/strings/str_cat.h"
22 #include "absl/types/optional.h"
23 #include "absl/types/variant.h"
24 #include "tensorflow/c/c_api.h"
25 #include "tensorflow/c/eager/c_api.h"
26 #include "tensorflow/c/eager/c_api_experimental.h"
27 #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
28 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
29 #include "tensorflow/c/tf_status.h"
30 #include "tensorflow/c/tf_status_helper.h"
31 
32 namespace tensorflow {
33 namespace parallel_device {
34 namespace {
35 
36 class OpDeleter {
37  public:
operator ()(TFE_Op * to_delete) const38   void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
39 };
40 
// Owning pointer to a TFE_Op, destroyed via TFE_DeleteOp.
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

// An owned op output: either a whole parallel tensor or a single (per-device)
// tensor handle.
using MaybeParallelTensorOwned =
    absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;

// A borrowed op input: either a parallel tensor or a single tensor handle.
using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;
48 
49 // A ParallelDevice on its own is not registered with a TFE_Context, and so has
50 // no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
51 // name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
52 // placed on the parallel device.
53 class NamedParallelDevice {
54  public:
NamedParallelDevice(const std::string & name,std::unique_ptr<ParallelDevice> parallel_device)55   NamedParallelDevice(const std::string& name,
56                       std::unique_ptr<ParallelDevice> parallel_device)
57       : device_name_(name), parallel_device_(std::move(parallel_device)) {}
name() const58   const std::string& name() const { return device_name_; }
device() const59   const ParallelDevice& device() const { return *parallel_device_; }
60 
61  private:
62   std::string device_name_;
63   std::unique_ptr<ParallelDevice> parallel_device_;
64 };
65 
ExecuteWithSpecialOps(const ParallelDevice & parallel_device,const std::string & parallel_device_name,TFE_Context * context,std::vector<MaybeParallelTensorUnowned> inputs,const char * operation_name,const TFE_OpAttrs * attributes,int expected_max_outputs,TF_Status * status)66 absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
67     const ParallelDevice& parallel_device,
68     const std::string& parallel_device_name, TFE_Context* context,
69     std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
70     const TFE_OpAttrs* attributes, int expected_max_outputs,
71     TF_Status* status) {
72   absl::optional<std::vector<MaybeParallelTensorOwned>> result;
73   // TODO(allenl): We should remove "TPU" from these op names at the very least,
74   // or consider other ways of packing/unpacking parallel tensors.
75   if (operation_name == std::string("TPUReplicatedInput")) {
76     // Special-cased operation for packing per-device tensors into one parallel
77     // tensor.
78     if (inputs.size() != parallel_device.num_underlying_devices()) {
79       std::string message(absl::StrCat(
80           "The parallel device ", parallel_device_name, " expected ",
81           parallel_device.num_underlying_devices(),
82           " inputs to TPUReplicatedInput, but got ", inputs.size()));
83       TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
84       return result;
85     }
86     std::vector<TensorHandlePtr> components;
87     components.reserve(inputs.size());
88     for (int i = 0; i < inputs.size(); ++i) {
89       if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
90         std::string message(absl::StrCat(
91             "Expected all inputs to TPUReplicatedInput to be non-parallel "
92             "TensorHandles. The input ",
93             i,
94             " was a parallel tensor (already "
95             "placed on the parallel device)."));
96         TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
97         return result;
98       }
99       components.emplace_back(TFE_TensorHandleCopySharingTensor(
100           absl::get<TFE_TensorHandle*>(inputs[i]), status));
101     }
102     std::vector<MaybeParallelTensorOwned> result_content;
103     result_content.reserve(1);
104     result_content.push_back(ParallelTensor::FromTensorHandles(
105         parallel_device, std::move(components), status));
106     if (TF_GetCode(status) != TF_OK) return result;
107     result.emplace(std::move(result_content));
108     return result;
109   } else if (operation_name == std::string("TPUReplicatedOutput")) {
110     // Special-cased operation for un-packing one parallel tensor into
111     // per-device tensors.
112     OpPtr op(TFE_NewOp(context, operation_name, status));
113     TFE_OpAddAttrs(op.get(), attributes);
114     int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
115     if (TF_GetCode(status) != TF_OK) return result;
116     if (expected_outputs != parallel_device.num_underlying_devices()) {
117       std::string message(absl::StrCat(
118           "The parallel device ", parallel_device_name, " expected ",
119           parallel_device.num_underlying_devices(),
120           " outputs for TPUReplicatedOutput, but got ", expected_outputs));
121       TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
122       return result;
123     }
124     if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
125       TF_SetStatus(status, TF_INVALID_ARGUMENT,
126                    "Expected the input to "
127                    "TPUReplicatedOutput to be a parallel tensor (placed on the "
128                    "parallel device).");
129       return result;
130     }
131     ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
132     std::vector<MaybeParallelTensorOwned> outputs;
133     outputs.reserve(t->num_tensors());
134     for (int i = 0; i < t->num_tensors(); ++i) {
135       TensorHandlePtr this_output(
136           TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
137       outputs.emplace_back(std::move(this_output));
138       if (TF_GetCode(status) != TF_OK) return result;
139     }
140     result.emplace(std::move(outputs));
141     return result;
142   }
143   std::vector<ParallelTensor*> parallel_inputs;
144   std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
145   parallel_inputs.reserve(inputs.size());
146   implicitly_broadcast_tensors.reserve(inputs.size());  // not tight
147   for (const auto& input : inputs) {
148     if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
149       if (operation_name == std::string("_EagerConst")) {
150         // Non-parallel tensors from _EagerConst/tf.constant are implicitly
151         // broadcast, i.e. set as the input to each parallel operation. This
152         // allows code like "tf.constant(1.)" or "tf.reduce_sum(..., axis=1)"
153         // (where the value starts on the host), without allowing other implicit
154         // copies/broadcasts. Other implicit copies may be supported eventually,
155         // but need special handling for gradients (gradient of copy-on is not
156         // just copy-off but includes a sum) and consideration of performance.
157         //
158         // TODO(allenl): There may be smarter ways to do this copy in some
159         // cases, i.e. with a collective broadcast. We'll need to be careful
160         // about things that are taken as inputs on the host or on their
161         // existing device (for multi-device functions).
162         std::unique_ptr<ParallelTensor> parallel_tensor(
163             parallel_device.CopyToParallelDevice(
164                 context, absl::get<TFE_TensorHandle*>(input), status));
165         if (TF_GetCode(status) != TF_OK) return absl::nullopt;
166         parallel_inputs.push_back(parallel_tensor.get());
167         implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
168       } else {
169         TF_SetStatus(
170             status, TF_INVALID_ARGUMENT,
171             absl::StrCat(
172                 "Got a non-parallel tensor ",
173                 tensorflow::unwrap(absl::get<TFE_TensorHandle*>(input))
174                     ->DebugString(),
175                 " as input to a parallel operation. First pack non-parallel "
176                 "tensors for each device into a parallel tensor explicitly.")
177                 .c_str());
178         return absl::nullopt;
179       }
180     } else {
181       parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
182     }
183   }
184   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
185       maybe_parallel_results(
186           parallel_device.Execute(context, parallel_inputs, operation_name,
187                                   attributes, expected_max_outputs, status));
188   if (!maybe_parallel_results.has_value()) return result;
189   std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
190       std::move(maybe_parallel_results.value()));
191   std::vector<MaybeParallelTensorOwned> result_content;
192   result_content.reserve(parallel_results.size());
193   for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
194     result_content.push_back(
195         MaybeParallelTensorOwned(std::move(parallel_result)));
196   }
197   result.emplace(std::move(result_content));
198   return result;
199 }
200 
201 // Used as an argument to TFE_NewCustomDeviceTensorHandle, indicating how
202 // ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
203 // reference counts drop to zero.
ParallelTensorDeallocator(void * data)204 void ParallelTensorDeallocator(void* data) {
205   delete reinterpret_cast<ParallelTensor*>(data);
206 }
207 
208 // Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing the
209 // number of dimensions of a parallel tensor.
ParallelTensorNumDims(void * data,TF_Status * status)210 int ParallelTensorNumDims(void* data, TF_Status* status) {
211   const std::vector<int64_t>* shape;
212   Status s = reinterpret_cast<ParallelTensor*>(data)->Shape(&shape);
213   if (!s.ok()) {
214     Set_TF_Status_from_Status(status, s);
215     return -1;
216   }
217   return shape->size();
218 }
219 
220 // Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing a
221 // dimension of a parallel tensor.
ParallelTensorDim(void * data,int dim_index,TF_Status * status)222 int64_t ParallelTensorDim(void* data, int dim_index, TF_Status* status) {
223   const std::vector<int64_t>* shape;
224   Status s = reinterpret_cast<ParallelTensor*>(data)->Shape(&shape);
225   if (!s.ok()) {
226     Set_TF_Status_from_Status(status, s);
227     return -1;
228   }
229   return (*shape)[dim_index];
230 }
231 
ParallelTensorSummarize(void * data,TF_Status * status)232 TF_Buffer* ParallelTensorSummarize(void* data, TF_Status* status) {
233   ParallelTensor* parallel_tensor = reinterpret_cast<ParallelTensor*>(data);
234   std::string summary;
235   Status cpp_status = parallel_tensor->SummarizeValue(summary);
236   if (!cpp_status.ok()) {
237     Set_TF_Status_from_Status(status, cpp_status);
238     return nullptr;
239   }
240   return TF_NewBufferFromString(summary.data(), summary.size());
241 }
242 
ParallelTensorToTensorHandle(const std::string & parallel_device_name,TFE_Context * context,std::unique_ptr<ParallelTensor> t,TF_Status * status)243 TensorHandlePtr ParallelTensorToTensorHandle(
244     const std::string& parallel_device_name, TFE_Context* context,
245     std::unique_ptr<ParallelTensor> t, TF_Status* status) {
246   // The resulting TensorHandle owns an opaque pointer to "device memory", which
247   // for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
248   // deleted, it will call ParallelTensorDeallocator to free the struct.
249   ParallelTensor* t_released = t.release();
250   TFE_CustomDeviceTensorHandleMethods handle_methods;
251   handle_methods.num_dims = &ParallelTensorNumDims;
252   handle_methods.dim = &ParallelTensorDim;
253   handle_methods.deallocator = &ParallelTensorDeallocator;
254   handle_methods.summarize = &ParallelTensorSummarize;
255   return TensorHandlePtr(TFE_NewCustomDeviceTensorHandle(
256       context, parallel_device_name.c_str(), t_released->dtype(), t_released,
257       handle_methods, status));
258 }
259 
260 // For TFE_CustomDevice::copy_tensor_to_device in the parallel device
261 // registration.
262 //
263 // Since this function is used to satisfy the TFE_CustomDevice C API,
264 // device_info is passed in using a C-style generic. It must always be a
265 // ParallelDevice.
CopyToParallelDevice(TFE_Context * context,TFE_TensorHandle * tensor,TF_Status * status,void * device_info)266 TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
267                                        TFE_TensorHandle* tensor,
268                                        TF_Status* status, void* device_info) {
269   TF_SetStatus(
270       status, TF_UNIMPLEMENTED,
271       absl::StrCat("Trying to copy a tensor ",
272                    tensorflow::unwrap(tensor)->DebugString(),
273                    " on to a parallel device. Pack non-parallel "
274                    "tensors for each device into a parallel tensor explicitly.")
275           .c_str());
276   return nullptr;
277 }
278 
279 // For TFE_CustomDevice::copy_tensor_from_device in the parallel device
280 // registration.
281 //
282 // Currently this is an error, and un-packing ParallelTensors must be performed
283 // explicitly by running a TPUReplicatedOutput operation on the parallel device.
284 //
285 // TODO(allenl): There are some use-cases that are only supported by copying to
286 // host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
287 // need to return something here or address these use-cases one by one.
CopyTensorFromParallelDevice(TFE_Context * context,TFE_TensorHandle * tensor,const char * target_device_name,TF_Status * status,void * device_info)288 TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
289                                                TFE_TensorHandle* tensor,
290                                                const char* target_device_name,
291                                                TF_Status* status,
292                                                void* device_info) {
293   ParallelTensor* parallel_tensor = reinterpret_cast<ParallelTensor*>(
294       TFE_TensorHandleDevicePointer(tensor, status));
295   if (TF_GetCode(status) != TF_OK) return nullptr;
296   if (parallel_tensor->num_tensors() == 1) {
297     // Copy-off for single-device tensors is allowed to make debugging dynamic
298     // control flow easier.
299     return TFE_TensorHandleCopySharingTensor(parallel_tensor->tensor(0),
300                                              status);
301   } else {
302     TF_SetStatus(
303         status, TF_UNIMPLEMENTED,
304         absl::StrCat(
305             "Trying to copy a tensor out of a parallel device. Since there "
306             "are multiple components to parallel tensors, they must be "
307             "unpacked explicitly.\n",
308             tensorflow::unwrap(tensor)->DebugString())
309             .c_str());
310     return nullptr;
311   }
312 }
313 
314 // For TFE_CustomDevice::execute in the parallel device registration.
315 //
316 // Since this function is used to satisfy the TFE_CustomDevice C API,
317 // device_info is passed in using a C-style generic. It must always be a
318 // ParallelDevice.
ParallelDeviceExecute(const TFE_Op * original_op,int * num_outputs,TFE_TensorHandle ** outputs,TF_Status * status,void * device_info)319 void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
320                            TFE_TensorHandle** outputs, TF_Status* status,
321                            void* device_info) {
322   const char* requested_placement = TFE_OpGetDevice(original_op, status);
323   if (*requested_placement == '\0') {
324     TF_SetStatus(
325         status, TF_INTERNAL,
326         "Ops must be placed on the parallel device explicitly, or their inputs "
327         "first un-packed. Got an un-placed op with an input placed on the "
328         "parallel device.");
329     return;
330   }
331   TFE_Context* context = TFE_OpGetContext(original_op, status);
332   if (TF_GetCode(status) != TF_OK) return;
333   const char* operation_name = TFE_OpGetName(original_op, status);
334   if (TF_GetCode(status) != TF_OK) return;
335   const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
336 
337   NamedParallelDevice* named_device =
338       reinterpret_cast<NamedParallelDevice*>(device_info);
339   std::vector<MaybeParallelTensorUnowned> typed_inputs;
340   int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
341   if (TF_GetCode(status) != TF_OK) return;
342   typed_inputs.reserve(num_inputs);
343   for (int i = 0; i < num_inputs; ++i) {
344     TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
345     if (TF_GetCode(status) != TF_OK) return;
346     const char* tensor_handle_device =
347         TFE_TensorHandleDeviceName(input, status);
348     if (TF_GetCode(status) != TF_OK) return;
349     if (named_device->name() == tensor_handle_device) {
350       // We assume that any tensors already placed on this device are
351       // ParallelTensors.
352       typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
353           TFE_TensorHandleDevicePointer(input, status)));
354       if (TF_GetCode(status) != TF_OK) return;
355     } else {
356       typed_inputs.emplace_back(input);
357     }
358   }
359 
360   absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
361       ExecuteWithSpecialOps(named_device->device(), named_device->name(),
362                             context, std::move(typed_inputs), operation_name,
363                             attributes, *num_outputs, status));
364   if (TF_GetCode(status) != TF_OK) return;
365   if (!maybe_typed_outputs.has_value()) {
366     TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
367     return;
368   }
369 
370   std::vector<MaybeParallelTensorOwned> typed_outputs(
371       std::move(maybe_typed_outputs.value()));
372 
373   if (typed_outputs.size() > *num_outputs) {
374     TF_SetStatus(status, TF_INTERNAL,
375                  "The allocated output buffer was too small.");
376     return;
377   }
378 
379   for (int i = 0; i < typed_outputs.size(); ++i) {
380     MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i]));
381     if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
382       outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
383     } else {
384       outputs[i] = ParallelTensorToTensorHandle(
385                        named_device->name(), context,
386                        std::move(absl::get<std::unique_ptr<ParallelTensor>>(
387                            typed_output)),
388                        status)
389                        .release();
390       if (TF_GetCode(status) != TF_OK) return;
391     }
392   }
393   *num_outputs = typed_outputs.size();
394 }
395 
396 // For TFE_CustomDevice::delete_device in the parallel device registration.
397 //
398 // Since this function is used to satisfy the TFE_CustomDevice C API,
399 // device_info is passed in using a C-style generic. It must always be a
400 // ParallelDevice.
DeleteParallelDevice(void * device_info)401 void DeleteParallelDevice(void* device_info) {
402   delete reinterpret_cast<NamedParallelDevice*>(device_info);
403 }
404 
405 }  // namespace
406 
AllocateParallelDevice(const char * device_name,const char * const * underlying_devices,int num_underlying_devices,TFE_CustomDevice * device,void ** device_info)407 void AllocateParallelDevice(const char* device_name,
408                             const char* const* underlying_devices,
409                             int num_underlying_devices,
410                             TFE_CustomDevice* device, void** device_info) {
411   device->copy_tensor_to_device = &CopyToParallelDevice;
412   device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
413   device->delete_device = &DeleteParallelDevice;
414   device->execute = &ParallelDeviceExecute;
415   std::vector<std::string> underlying_devices_vector;
416   underlying_devices_vector.reserve(num_underlying_devices);
417   for (int device_index = 0; device_index < num_underlying_devices;
418        ++device_index) {
419     underlying_devices_vector.push_back(underlying_devices[device_index]);
420   }
421   std::unique_ptr<ParallelDevice> parallel_device(
422       new ParallelDevice(underlying_devices_vector));
423   *device_info =
424       new NamedParallelDevice{device_name, std::move(parallel_device)};
425 }
426 }  // namespace parallel_device
427 }  // namespace tensorflow
428