1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/custom_device.h"
17
18 #include "tensorflow/core/common_runtime/device_mgr.h"
19 #include "tensorflow/core/common_runtime/eager/context.h"
20 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
21 #include "tensorflow/core/common_runtime/eager/placement_utils.h"
22 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
23 #include "tensorflow/core/framework/device_factory.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace tensorflow {
28 namespace eager {
29 namespace {
30
31 using ::testing::ContainsRegex;
32 using ::testing::HasSubstr;
33
34 class TestCustomDevice : public CustomDevice {
35 public:
TestCustomDevice(std::string name)36 explicit TestCustomDevice(std::string name) : name_(name) {}
name()37 const std::string& name() override { return name_; }
CopyTensorToDevice(ImmediateExecutionTensorHandle * tensor,ImmediateExecutionTensorHandle ** result)38 Status CopyTensorToDevice(ImmediateExecutionTensorHandle* tensor,
39 ImmediateExecutionTensorHandle** result) override {
40 tensor->Ref();
41 *result = tensor;
42 return OkStatus();
43 }
CopyTensorFromDevice(ImmediateExecutionTensorHandle * tensor,const std::string & target_device_name,ImmediateExecutionTensorHandle ** result)44 Status CopyTensorFromDevice(
45 ImmediateExecutionTensorHandle* tensor,
46 const std::string& target_device_name,
47 ImmediateExecutionTensorHandle** result) override {
48 tensor->Ref();
49 *result = tensor;
50 return OkStatus();
51 }
Execute(const ImmediateExecutionOperation * op,ImmediateExecutionTensorHandle ** retvals,int * num_retvals)52 Status Execute(const ImmediateExecutionOperation* op,
53 ImmediateExecutionTensorHandle** retvals,
54 int* num_retvals) override {
55 return errors::Unimplemented("Not implemented");
56 }
57
Pack(absl::Span<ImmediateExecutionTensorHandle * > handles,ImmediateExecutionTensorHandle ** result)58 Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
59 ImmediateExecutionTensorHandle** result) override {
60 return errors::Unimplemented("Packing is not implemented");
61 }
62
63 private:
64 std::string name_;
65 };
66
67 class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle {
68 public:
TestCustomDeviceTensorHandle(ImmediateExecutionContext * context,TestCustomDevice * device,tensorflow::DataType dtype,int64_t length)69 TestCustomDeviceTensorHandle(ImmediateExecutionContext* context,
70 TestCustomDevice* device,
71 tensorflow::DataType dtype, int64_t length)
72 : CustomDeviceTensorHandle(context, device, dtype), length_(length) {}
73
DevicePointer() const74 void* DevicePointer() const override { return nullptr; }
NumDims(int * num_dims) const75 Status NumDims(int* num_dims) const override {
76 *num_dims = 1;
77 return OkStatus();
78 }
Dim(int dim_index,int64_t * dim) const79 Status Dim(int dim_index, int64_t* dim) const override {
80 if (dim_index == 0) {
81 *dim = length_;
82 return OkStatus();
83 } else {
84 return errors::Internal("Dim out of bounds");
85 }
86 }
87
SummarizeValue(std::string & summary) const88 Status SummarizeValue(std::string& summary) const override {
89 summary = std::string("TestValue");
90 return OkStatus();
91 }
92
93 private:
94 const int64_t length_;
95 };
96
// Verifies the metadata a custom-device tensor handle exposes: device type,
// device index, element count, and its DebugString rendering.
TEST(CustomDevice, TestTensorHandle) {
  StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
      "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
  core::RefCountPtr<EagerContext> ctx(new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
      &device_mgr, false, nullptr, nullptr));
  TestCustomDevice custom_device(
      "/job:localhost/replica:0/task:0/device:CUSTOM:15");
  core::RefCountPtr<TestCustomDeviceTensorHandle> handle(
      new TestCustomDeviceTensorHandle(ctx.get(), &custom_device, DT_FLOAT,
                                       /*length=*/3));

  // The expected type and index match the "CUSTOM:15" suffix of the device
  // name above.
  Status status;
  const std::string reported_type = handle->DeviceType(&status);
  TF_ASSERT_OK(status);
  EXPECT_EQ("CUSTOM", reported_type);
  const int reported_index = handle->DeviceId(&status);
  TF_ASSERT_OK(status);
  EXPECT_EQ(15, reported_index);

  // A rank-1 handle of length 3 has three elements.
  int64_t element_count = 0;
  TF_ASSERT_OK(handle->NumElements(&element_count));
  EXPECT_EQ(3, element_count);

  EXPECT_THAT(
      handle->DebugString(),
      ContainsRegex(
          R"re(TensorHandle\(TestValue, shape=\[3\], dtype=DT_FLOAT, device=.*\))re"));
}
125
// A handle whose only dimension is unknown (length -1) cannot report a fixed
// element count. NumElements is expected to fail with an explanatory message.
TEST(CustomDevice, TestTensorHandleUnknownDimNumElements) {
  StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
      "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
  core::RefCountPtr<EagerContext> ctx(new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
      &device_mgr, false, nullptr, nullptr));
  std::string device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:15";
  TestCustomDevice device(device_name);
  core::RefCountPtr<TestCustomDeviceTensorHandle> tensor(
      new TestCustomDeviceTensorHandle(ctx.get(), &device, DT_FLOAT,
                                       /*length=*/-1));
  // Initialized so the out-parameter never holds an indeterminate value, even
  // though the call below is expected to fail without writing it.
  int64_t num_elements = 0;
  Status s = tensor->NumElements(&num_elements);
  EXPECT_FALSE(s.ok());
  EXPECT_THAT(s.error_message(), HasSubstr("representing varying shapes"));
}
143
// Exercises CustomDeviceOpHandler::MaybePinToCustomDevice with combinations of
// physical vs. custom-device inputs (float and resource dtypes) and explicit
// device requests. Each case checks whether the op gets pinned to the custom
// device (placed_device non-null) or left for physical placement (nullptr).
TEST(CustomDevice, TestResourcePlacement) {
  StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
      "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
  core::RefCountPtr<EagerContext> ctx(new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
      &device_mgr, false, nullptr, nullptr));
  std::string custom_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:15";
  TestCustomDevice custom_device(custom_device_name);
  // Handles living on the custom device: one float, one resource.
  core::RefCountPtr<TestCustomDeviceTensorHandle> custom_float_tensor(
      new TestCustomDeviceTensorHandle(ctx.get(), &custom_device, DT_FLOAT,
                                       /*length=*/3));
  core::RefCountPtr<TestCustomDeviceTensorHandle> custom_resource_tensor(
      new TestCustomDeviceTensorHandle(ctx.get(), &custom_device, DT_RESOURCE,
                                       /*length=*/3));

  // Matching scalar resource and float handles on the physical CPU device
  // (the only device registered in device_mgr).
  Tensor resource_tensor(DT_RESOURCE, {});
  Device* physical_device = device_mgr.ListDevices().at(0);
  core::RefCountPtr<TensorHandle> physical_resource_tensor(
      TensorHandle::CreateLocalHandle(std::move(resource_tensor),
                                      physical_device, physical_device,
                                      physical_device, ctx.get()));
  Tensor float_tensor(DT_FLOAT, {});
  core::RefCountPtr<TensorHandle> physical_float_tensor(
      TensorHandle::CreateLocalHandle(std::move(float_tensor), physical_device,
                                      physical_device, physical_device,
                                      ctx.get()));
  // Case 1: physical resource + custom float input, no explicit device.
  EagerOperation op(ctx.get());
  TF_ASSERT_OK(op.Reset("AssignVariableOp", ""));
  TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
  TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
  CustomDevice* placed_device = nullptr;
  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
      &placed_device, op));
  // MaybePinToCustomDevice has no opinion about ops which have physical
  // resource-dtype inputs. They'll get placed on physical devices.
  EXPECT_EQ(nullptr, placed_device);

  // Case 2: same inputs, but the op explicitly requests the custom device.
  op.Clear();
  TF_ASSERT_OK(op.Reset("AssignVariableOp", custom_device_name.c_str()));
  TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
  TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
  placed_device = nullptr;
  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
      &placed_device, op));
  // Explicit placement onto a custom device also doesn't trigger custom device
  // placement if there's a physical device resource input.
  EXPECT_EQ(nullptr, placed_device);

  // Case 3: only a physical float input, explicit request for the CPU.
  op.Clear();
  TF_ASSERT_OK(
      op.Reset("Identity", "/job:localhost/replica:0/task:0/device:CPU:0"));
  TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
  placed_device = nullptr;
  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
      &placed_device, op));
  // Explicit placements typically override input-based placement onto a custom
  // device.
  EXPECT_EQ(nullptr, placed_device);

  // Case 4: custom resource + physical float input, explicit request for the
  // CPU.
  op.Clear();
  TF_ASSERT_OK(op.Reset("AssignVariableOp",
                        "/job:localhost/replica:0/task:0/device:CPU:0"));
  TF_ASSERT_OK(op.AddInput(custom_resource_tensor.get()));
  TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
  placed_device = nullptr;
  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
      &placed_device, op));
  // Even with an explicit physical device placement, custom device resource
  // inputs place the op on the custom device.
  ASSERT_NE(placed_device, nullptr);
  EXPECT_EQ(&custom_device, placed_device);
}
218
219 } // namespace
220 } // namespace eager
221 } // namespace tensorflow
222