xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
17 
18 #include "tensorflow/core/platform/errors.h"
19 
20 namespace tensorflow {
21 
Clear()22 void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); }
23 
RegisterCustomDevice(const string & device_name,std::unique_ptr<CustomDevice> device)24 Status CustomDeviceOpHandler::RegisterCustomDevice(
25     const string& device_name, std::unique_ptr<CustomDevice> device) {
26   DeviceNameUtils::ParsedName parsed;
27   if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
28       !parsed.has_job || !parsed.has_replica || !parsed.has_task ||
29       !parsed.has_type || !parsed.has_id) {
30     return errors::InvalidArgument(
31         device_name,
32         " could not be parsed as a device name. Use the full "
33         "/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
34         "format.");
35   }
36 
37   if (!custom_devices_.emplace(device_name, std::move(device)).second) {
38     return errors::AlreadyExists(device_name,
39                                  " already registered as a custom device.");
40   }
41   return OkStatus();
42 }
43 
FindCustomDeviceFromName(const string & name,CustomDevice ** device) const44 bool CustomDeviceOpHandler::FindCustomDeviceFromName(
45     const string& name, CustomDevice** device) const {
46   auto dev_it = custom_devices_.find(name);
47   if (dev_it == custom_devices_.end()) {
48     return false;
49   }
50   *device = dev_it->second.get();
51   return true;
52 }
53 
Execute(ImmediateExecutionOperation * op,ImmediateExecutionTensorHandle ** retvals,int * num_retvals)54 Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op,
55                                       ImmediateExecutionTensorHandle** retvals,
56                                       int* num_retvals) {
57   tensorflow::CustomDevice* custom_device = nullptr;
58 
59   TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op));
60 
61   if (custom_device != nullptr) {
62     return custom_device->Execute(op, retvals, num_retvals);
63   }
64 
65   // The op will be placed on physical device. However, it contains custom
66   // device tensor handles. The tensor handles will be copy to physical device
67   // first.
68   if (op->HasCustomDeviceInput()) {
69     auto inputs = op->GetInputs();
70     for (int i = 0; i < inputs.size(); ++i) {
71       auto target_device = op->DeviceName();
72       if (target_device.empty()) {
73         target_device = op->GetContext()->HostCPUName();
74       }
75       // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
76       // here.
77       if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) {
78         tensorflow::CustomDeviceTensorHandle* previous =
79             tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
80                 inputs[i]);
81         tensorflow::ImmediateExecutionTensorHandle* new_tensor;
82         TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
83             previous, target_device, &new_tensor));
84         Status s = op->SetInput(i, new_tensor);
85         new_tensor->Unref();
86         TF_RETURN_IF_ERROR(s);
87       }
88     }
89   }
90 
91   return op->Execute(
92       absl::MakeSpan(
93           reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals),
94           *num_retvals),
95       num_retvals);
96 }
97 
CopyTensorHandleToDevice(ImmediateExecutionContext * context,ImmediateExecutionTensorHandle * handle,const char * device_name,Status * status)98 ImmediateExecutionTensorHandle* CustomDeviceOpHandler::CopyTensorHandleToDevice(
99     ImmediateExecutionContext* context, ImmediateExecutionTensorHandle* handle,
100     const char* device_name, Status* status) {
101   *status = OkStatus();
102   ImmediateExecutionTensorHandle* result = nullptr;
103   tensorflow::CustomDevice* dev;
104 
105   if (FindCustomDeviceFromName(device_name, &dev)) {
106     *status = dev->CopyTensorToDevice(handle, &result);
107     if (status->ok()) {
108       return result;
109     }
110     return nullptr;
111   }
112 
113   // Target device is regular device. Check if the input is on custom
114   // device
115   const char* handle_device_name = handle->DeviceName(status);
116   if (!status->ok()) {
117     return nullptr;
118   }
119   if (FindCustomDeviceFromName(handle_device_name, &dev)) {
120     *status = dev->CopyTensorFromDevice(handle, device_name, &result);
121     if (status->ok()) {
122       return result;
123     }
124     return nullptr;
125   }
126 
127   // Both source and target device are regular device.
128   return context->CopyTensorHandleToDevice(handle, device_name, status);
129 }
130 
MaybePinToCustomDevice(CustomDevice ** device,const ImmediateExecutionOperation & op) const131 Status CustomDeviceOpHandler::MaybePinToCustomDevice(
132     CustomDevice** device, const ImmediateExecutionOperation& op) const {
133   *device = nullptr;
134   if (!FindCustomDeviceFromName(op.DeviceName(), device) &&
135       !op.HasCustomDeviceInput()) {
136     return OkStatus();
137   }
138 
139   // Ops are placed on a custom device if there's no other explicit requested
140   // placement and there is only one custom device in the op
141   // inputs.
142   //
143   // Resource-dtype inputs take precedence over non-resource inputs and explicit
144   // placements; this function pins ops with a resource-dtype custom device
145   // input to that custom device.
146   CustomDevice* first = nullptr;
147   if (!op.GetInputs().empty()) {
148     for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
149       // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
150       // here.
151       if (CustomDeviceTensorHandle::classof(generic_input)) {
152         const CustomDeviceTensorHandle* input =
153             down_cast<const CustomDeviceTensorHandle*>(generic_input);
154         CustomDevice* current = input->device();
155         if (first == nullptr) {
156           first = current;
157         } else if (first != current) {
158           return errors::InvalidArgument(absl::StrCat(
159               "If an operation has one of its inputs in a custom device, then "
160               "all inputs should be on that same custom device or another "
161               "physical device. Operation ",
162               op.Name(),
163               " has one input in custom "
164               "device ",
165               first->name(),
166               " and at least one input in a different custom device ",
167               current->name()));
168         }
169       }
170     }
171     for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
172       if (generic_input->DataType() == DT_RESOURCE) {
173         if (CustomDeviceTensorHandle::classof(generic_input)) {
174           const CustomDeviceTensorHandle* input =
175               down_cast<const CustomDeviceTensorHandle*>(generic_input);
176           // There's only one custom device input, and it's a resource input, so
177           // we'll force-place the op on to that custom device. As with physical
178           // devices, this overrides any explicit placement for the op.
179           *device = input->device();
180           return OkStatus();
181         } else {
182           // Don't set a custom device if there's a physical-device resource
183           // input.
184           return OkStatus();
185         }
186       }
187     }
188   }
189   // Since there are no resource-dtype inputs, we'll respect explicit placements
190   // before considering input-based placement.
191   if (*device == nullptr && op.DeviceName().empty() && first != nullptr) {
192     // If there are non-resource inputs on a custom device we will default the
193     // op to that custom device, but not override an explicit op placement.
194     *device = first;
195     return OkStatus();
196   }
197   return OkStatus();
198 }
199 
200 }  // namespace tensorflow
201