1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
17
18 #include "tensorflow/core/platform/errors.h"
19
20 namespace tensorflow {
21
Clear()22 void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); }
23
RegisterCustomDevice(const string & device_name,std::unique_ptr<CustomDevice> device)24 Status CustomDeviceOpHandler::RegisterCustomDevice(
25 const string& device_name, std::unique_ptr<CustomDevice> device) {
26 DeviceNameUtils::ParsedName parsed;
27 if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
28 !parsed.has_job || !parsed.has_replica || !parsed.has_task ||
29 !parsed.has_type || !parsed.has_id) {
30 return errors::InvalidArgument(
31 device_name,
32 " could not be parsed as a device name. Use the full "
33 "/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
34 "format.");
35 }
36
37 if (!custom_devices_.emplace(device_name, std::move(device)).second) {
38 return errors::AlreadyExists(device_name,
39 " already registered as a custom device.");
40 }
41 return OkStatus();
42 }
43
FindCustomDeviceFromName(const string & name,CustomDevice ** device) const44 bool CustomDeviceOpHandler::FindCustomDeviceFromName(
45 const string& name, CustomDevice** device) const {
46 auto dev_it = custom_devices_.find(name);
47 if (dev_it == custom_devices_.end()) {
48 return false;
49 }
50 *device = dev_it->second.get();
51 return true;
52 }
53
Execute(ImmediateExecutionOperation * op,ImmediateExecutionTensorHandle ** retvals,int * num_retvals)54 Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op,
55 ImmediateExecutionTensorHandle** retvals,
56 int* num_retvals) {
57 tensorflow::CustomDevice* custom_device = nullptr;
58
59 TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op));
60
61 if (custom_device != nullptr) {
62 return custom_device->Execute(op, retvals, num_retvals);
63 }
64
65 // The op will be placed on physical device. However, it contains custom
66 // device tensor handles. The tensor handles will be copy to physical device
67 // first.
68 if (op->HasCustomDeviceInput()) {
69 auto inputs = op->GetInputs();
70 for (int i = 0; i < inputs.size(); ++i) {
71 auto target_device = op->DeviceName();
72 if (target_device.empty()) {
73 target_device = op->GetContext()->HostCPUName();
74 }
75 // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
76 // here.
77 if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) {
78 tensorflow::CustomDeviceTensorHandle* previous =
79 tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
80 inputs[i]);
81 tensorflow::ImmediateExecutionTensorHandle* new_tensor;
82 TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
83 previous, target_device, &new_tensor));
84 Status s = op->SetInput(i, new_tensor);
85 new_tensor->Unref();
86 TF_RETURN_IF_ERROR(s);
87 }
88 }
89 }
90
91 return op->Execute(
92 absl::MakeSpan(
93 reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals),
94 *num_retvals),
95 num_retvals);
96 }
97
CopyTensorHandleToDevice(ImmediateExecutionContext * context,ImmediateExecutionTensorHandle * handle,const char * device_name,Status * status)98 ImmediateExecutionTensorHandle* CustomDeviceOpHandler::CopyTensorHandleToDevice(
99 ImmediateExecutionContext* context, ImmediateExecutionTensorHandle* handle,
100 const char* device_name, Status* status) {
101 *status = OkStatus();
102 ImmediateExecutionTensorHandle* result = nullptr;
103 tensorflow::CustomDevice* dev;
104
105 if (FindCustomDeviceFromName(device_name, &dev)) {
106 *status = dev->CopyTensorToDevice(handle, &result);
107 if (status->ok()) {
108 return result;
109 }
110 return nullptr;
111 }
112
113 // Target device is regular device. Check if the input is on custom
114 // device
115 const char* handle_device_name = handle->DeviceName(status);
116 if (!status->ok()) {
117 return nullptr;
118 }
119 if (FindCustomDeviceFromName(handle_device_name, &dev)) {
120 *status = dev->CopyTensorFromDevice(handle, device_name, &result);
121 if (status->ok()) {
122 return result;
123 }
124 return nullptr;
125 }
126
127 // Both source and target device are regular device.
128 return context->CopyTensorHandleToDevice(handle, device_name, status);
129 }
130
MaybePinToCustomDevice(CustomDevice ** device,const ImmediateExecutionOperation & op) const131 Status CustomDeviceOpHandler::MaybePinToCustomDevice(
132 CustomDevice** device, const ImmediateExecutionOperation& op) const {
133 *device = nullptr;
134 if (!FindCustomDeviceFromName(op.DeviceName(), device) &&
135 !op.HasCustomDeviceInput()) {
136 return OkStatus();
137 }
138
139 // Ops are placed on a custom device if there's no other explicit requested
140 // placement and there is only one custom device in the op
141 // inputs.
142 //
143 // Resource-dtype inputs take precedence over non-resource inputs and explicit
144 // placements; this function pins ops with a resource-dtype custom device
145 // input to that custom device.
146 CustomDevice* first = nullptr;
147 if (!op.GetInputs().empty()) {
148 for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
149 // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
150 // here.
151 if (CustomDeviceTensorHandle::classof(generic_input)) {
152 const CustomDeviceTensorHandle* input =
153 down_cast<const CustomDeviceTensorHandle*>(generic_input);
154 CustomDevice* current = input->device();
155 if (first == nullptr) {
156 first = current;
157 } else if (first != current) {
158 return errors::InvalidArgument(absl::StrCat(
159 "If an operation has one of its inputs in a custom device, then "
160 "all inputs should be on that same custom device or another "
161 "physical device. Operation ",
162 op.Name(),
163 " has one input in custom "
164 "device ",
165 first->name(),
166 " and at least one input in a different custom device ",
167 current->name()));
168 }
169 }
170 }
171 for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
172 if (generic_input->DataType() == DT_RESOURCE) {
173 if (CustomDeviceTensorHandle::classof(generic_input)) {
174 const CustomDeviceTensorHandle* input =
175 down_cast<const CustomDeviceTensorHandle*>(generic_input);
176 // There's only one custom device input, and it's a resource input, so
177 // we'll force-place the op on to that custom device. As with physical
178 // devices, this overrides any explicit placement for the op.
179 *device = input->device();
180 return OkStatus();
181 } else {
182 // Don't set a custom device if there's a physical-device resource
183 // input.
184 return OkStatus();
185 }
186 }
187 }
188 }
189 // Since there are no resource-dtype inputs, we'll respect explicit placements
190 // before considering input-based placement.
191 if (*device == nullptr && op.DeviceName().empty() && first != nullptr) {
192 // If there are non-resource inputs on a custom device we will default the
193 // op to that custom device, but not override an explicit op placement.
194 *device = first;
195 return OkStatus();
196 }
197 return OkStatus();
198 }
199
200 } // namespace tensorflow
201