1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
17
18 #include <array>
19
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/c_api_experimental.h"
22 #include "tensorflow/c/eager/c_api.h"
23 #include "tensorflow/c/eager/c_api_experimental.h"
24 #include "tensorflow/core/platform/test.h"
25
26 // NOTE(allenl): These tests currently go through TFE_Execute and so are
27 // integration testing rather than purely testing the parallel device. They
28 // correspond fairly well to the implementation, but testing the C++ directly is
29 // another option.
30
31 namespace tensorflow {
32 namespace parallel_device {
33
Create(TFE_Context * context,TF_DataType type,const int64_t * dims,const int num_dims,const char * device,TF_Status * status)34 Variable* Variable::Create(TFE_Context* context, TF_DataType type,
35 const int64_t* dims, const int num_dims,
36 const char* device, TF_Status* status) {
37 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
38 TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
39 if (TF_GetCode(status) != TF_OK) return nullptr;
40 TFE_OpSetAttrType(op.get(), "dtype", type);
41 TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
42 TFE_OpSetAttrString(op.get(), "container", "", 0);
43 // Use the special GUID for no buffer sharing
44 //
45 // TODO(allenl): Should we provide a better API for this? AFAIK this is the
46 // only reasonable way to make variables with no aliasing using the eager C
47 // API.
48 std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
49 TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
50 no_sharing.length());
51 TFE_OpSetDevice(op.get(), device, status);
52 if (TF_GetCode(status) != TF_OK) return nullptr;
53 TFE_TensorHandle* var_handle = nullptr;
54 int num_retvals = 1;
55 TFE_Execute(op.get(), &var_handle, &num_retvals, status);
56 if (TF_GetCode(status) != TF_OK) return nullptr;
57 return new Variable(var_handle, type);
58 }
59
Destroy(TFE_Context * context,TF_Status * status)60 void Variable::Destroy(TFE_Context* context, TF_Status* status) {
61 // Free the backing buffer for the variable.
62 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
63 TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
64 if (TF_GetCode(status) != TF_OK) return;
65 TFE_OpAddInput(op.get(), handle_, status);
66 if (TF_GetCode(status) != TF_OK) return;
67 const char* device = TFE_TensorHandleDeviceName(handle_, status);
68 if (TF_GetCode(status) != TF_OK) return;
69 TFE_OpSetDevice(op.get(), device, status);
70 if (TF_GetCode(status) != TF_OK) return;
71 int num_retvals = 0;
72 TFE_Execute(op.get(), nullptr, &num_retvals, status);
73 if (TF_GetCode(status) != TF_OK) return;
74 // Delete the variable handle itself.
75 TFE_DeleteTensorHandle(handle_);
76 }
77
Read(TFE_Context * context,TF_Status * status)78 TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
79 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
80 TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
81 if (TF_GetCode(status) != TF_OK) return nullptr;
82 TFE_OpAddInput(op.get(), handle_, status);
83 if (TF_GetCode(status) != TF_OK) return nullptr;
84 const char* device = TFE_TensorHandleDeviceName(handle_, status);
85 if (TF_GetCode(status) != TF_OK) return nullptr;
86 TFE_OpSetDevice(op.get(), device, status);
87 if (TF_GetCode(status) != TF_OK) return nullptr;
88 TFE_OpSetAttrType(op.get(), "dtype", type_);
89 int num_retvals = 1;
90 TFE_TensorHandle* var_value = nullptr;
91 TFE_Execute(op.get(), &var_value, &num_retvals, status);
92 if (TF_GetCode(status) != TF_OK) return nullptr;
93 return TensorHandlePtr(var_value);
94 }
95
GeneralAssignment(const char * op_name,TFE_Context * context,TFE_TensorHandle * value,TF_Status * status)96 void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
97 TFE_TensorHandle* value, TF_Status* status) {
98 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
99 TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
100 if (TF_GetCode(status) != TF_OK) return;
101 TFE_OpSetAttrType(op.get(), "dtype", type_);
102 TFE_OpAddInput(op.get(), handle_, status);
103 if (TF_GetCode(status) != TF_OK) return;
104 TFE_OpAddInput(op.get(), value, status);
105 if (TF_GetCode(status) != TF_OK) return;
106 const char* device = TFE_TensorHandleDeviceName(handle_, status);
107 if (TF_GetCode(status) != TF_OK) return;
108 TFE_OpSetDevice(op.get(), device, status);
109
110 int num_retvals = 0;
111 TFE_Execute(op.get(), nullptr, &num_retvals, status);
112 if (TF_GetCode(status) != TF_OK) return;
113 }
114
// Adds `value` to the variable in place ("AssignAddVariableOp"), reporting
// failures through `status`.
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
                         TF_Status* status) {
  GeneralAssignment("AssignAddVariableOp", context, value, status);
}
119
// Overwrites the variable with `value` ("AssignVariableOp"), reporting
// failures through `status`.
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
                      TF_Status* status) {
  GeneralAssignment("AssignVariableOp", context, value, status);
}
124
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted.
static void FloatDeallocator(void* data, size_t, void* arg) {
  // `data` was allocated with `new float[...]`; `arg` is unused.
  delete[] static_cast<float*>(data);
}
130
131 // Creates a TFE_TensorHandle with value `v`.
FloatTensorHandle(float v,TF_Status * status)132 TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
133 const int num_bytes = sizeof(float);
134 float* values = new float[1];
135 values[0] = v;
136 std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
137 TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
138 nullptr),
139 TF_DeleteTensor);
140 return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
141 }
142
143 // Creates a rank-one TFE_TensorHandle with value `v`.
VectorFloatTensorHandle(const std::vector<float> & v,TF_Status * status)144 TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
145 TF_Status* status) {
146 const int num_bytes = v.size() * sizeof(float);
147 float* values = new float[v.size()];
148 memcpy(values, v.data(), num_bytes);
149 int64_t dims = v.size();
150 std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
151 TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
152 &FloatDeallocator, nullptr),
153 TF_DeleteTensor);
154 return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
155 }
156
157 // Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
158 template <std::size_t num_replicas>
ExtractPerDeviceValues(TFE_Context * context,TFE_TensorHandle * input,std::array<TensorHandlePtr,num_replicas> * components,TF_Status * status)159 void ExtractPerDeviceValues(
160 TFE_Context* context, TFE_TensorHandle* input,
161 std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
162 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
163 TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
164 if (TF_GetCode(status) != TF_OK) return;
165 TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
166 TFE_OpAddInput(op.get(), input, status);
167 if (TF_GetCode(status) != TF_OK) return;
168 const char* device = TFE_TensorHandleDeviceName(input, status);
169 if (TF_GetCode(status) != TF_OK) return;
170 TFE_OpSetDevice(op.get(), device, status);
171 if (TF_GetCode(status) != TF_OK) return;
172
173 TFE_TensorHandle* result_handles[num_replicas];
174 int num_retvals = num_replicas;
175 TFE_Execute(op.get(), result_handles, &num_retvals, status);
176 if (TF_GetCode(status) != TF_OK) return;
177 for (int i = 0; i < num_replicas; ++i) {
178 (*components)[i].reset(result_handles[i]);
179 }
180 }
181
Multiply(TFE_Context * context,TFE_TensorHandle * first,TFE_TensorHandle * second,TF_Status * status)182 TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
183 TFE_TensorHandle* second, TF_Status* status) {
184 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
185 TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
186 if (TF_GetCode(status) != TF_OK) return nullptr;
187 TFE_OpAddInput(op.get(), first, status);
188 if (TF_GetCode(status) != TF_OK) return nullptr;
189 TFE_OpAddInput(op.get(), second, status);
190 if (TF_GetCode(status) != TF_OK) return nullptr;
191 const char* first_device = TFE_TensorHandleDeviceName(first, status);
192 if (TF_GetCode(status) != TF_OK) return nullptr;
193 TFE_OpSetDevice(op.get(), first_device, status);
194
195 TFE_TensorHandle* result_handle;
196 int num_retvals = 1;
197 TFE_Execute(op.get(), &result_handle, &num_retvals, status);
198 if (TF_GetCode(status) != TF_OK) return nullptr;
199 return TensorHandlePtr(result_handle);
200 }
201
202 // Create and modify a variable placed on a parallel device which composes
203 // `first_device` and `second_device`.
BasicTestsForTwoDevices(TFE_Context * context,const char * first_device,const char * second_device)204 void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
205 const char* second_device) {
206 // Register the custom device
207 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
208 TF_NewStatus(), TF_DeleteStatus);
209 const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
210 std::array<const char*, 2> underlying_devices{first_device, second_device};
211 RegisterParallelDevice(context, device_name, underlying_devices,
212 status.get());
213 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
214
215 // Create a variable handle (uninitialized to start) placed on the parallel
216 // device.
217 std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
218 to_delete->Destroy(context, status.get());
219 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
220 delete to_delete;
221 };
222 std::unique_ptr<Variable, decltype(variable_deleter)> variable(
223 Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
224 status.get()),
225 variable_deleter);
226 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
227
228 // Assign an initial value to the variable, mirroring it to each component
229 // device.
230 {
231 TensorHandlePtr initial_value_cpu = FloatTensorHandle(20., status.get());
232 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
233 std::array<TFE_TensorHandle*, 2> components{initial_value_cpu.get(),
234 initial_value_cpu.get()};
235 TensorHandlePtr initial_value =
236 CreatePerDeviceValues(context, components, device_name, status.get());
237 variable->Assign(context, initial_value.get(), status.get());
238 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
239 }
240
241 // Read from the variable and verify that we have a parallel tensor.
242 {
243 TensorHandlePtr read = variable->Read(context, status.get());
244 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
245 std::array<TensorHandlePtr, 2> components;
246 ExtractPerDeviceValues(context, read.get(), &components, status.get());
247 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
248
249 ExpectScalarEq<float>(components[0].get(), 20.);
250 ExpectScalarEq<float>(components[1].get(), 20.);
251
252 std::string first_device =
253 TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
254 ASSERT_EQ(underlying_devices[0], first_device);
255 std::string second_device =
256 TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
257 ASSERT_EQ(underlying_devices[1], second_device);
258 }
259
260 // Add a parallel tensor with different values on each device to the variable.
261 {
262 TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
263 TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
264 std::array<TFE_TensorHandle*, 2> components{value_one.get(),
265 value_two.get()};
266 TensorHandlePtr combined_value =
267 CreatePerDeviceValues(context, components, device_name, status.get());
268 variable->AssignAdd(context, combined_value.get(), status.get());
269 }
270
271 // Read the variable and verify that each component has the right modified
272 // value.
273 {
274 TensorHandlePtr read = variable->Read(context, status.get());
275 std::array<TensorHandlePtr, 2> components;
276 ExtractPerDeviceValues(context, read.get(), &components, status.get());
277 ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
278
279 ExpectScalarEq<float>(components[0].get(), 23.);
280 ExpectScalarEq<float>(components[1].get(), 18.);
281
282 std::string first_device =
283 TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
284 ASSERT_EQ(underlying_devices[0], first_device);
285 std::string second_device =
286 TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
287 ASSERT_EQ(underlying_devices[1], second_device);
288 }
289 }
290
291 } // namespace parallel_device
292 } // namespace tensorflow
293