/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device.h"

#include <array>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration testing rather than purely testing the parallel device. They
// correspond fairly well to the implementation, but testing the C++ directly
// is another option.

namespace tensorflow {
namespace parallel_device {

using ::testing::HasSubstr;

TEST(PARALLEL_DEVICE, TestBasicCPU) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:1");
}

TEST(PARALLEL_DEVICE, TestBasicCPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:0");
}

TEST(PARALLEL_DEVICE, TestBasicTPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Skip the test if no TPU is available.
  std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> devices(
      TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  bool has_tpu = false;
  for (int device_index = 0; device_index < TF_DeviceListCount(devices.get());
       ++device_index) {
    std::string device_type =
        TF_DeviceListType(devices.get(), device_index, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    if (device_type == "TPU") {
      has_tpu = true;
      break;
    }
  }
  if (has_tpu) {
    BasicTestsForTwoDevices(context.get(),
                            "/job:localhost/replica:0/task:0/device:TPU:0",
                            "/job:localhost/replica:0/task:0/device:TPU:0");
  }
}

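// Verifies that copies onto and off of the parallel device are rejected:
// components must be packed and unpacked explicitly rather than moved with
// TFE_TensorHandleCopyToDevice.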
TEST(PARALLEL_DEVICE, TestExplicitCopies) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:1";
  std::array<const char*, 2> underlying_devices{first_device_name,
                                                second_device_name};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Copying on to a parallel device must be explicit.
  TensorHandlePtr failed_copy_on_result(TFE_TensorHandleCopyToDevice(
      cpu_value.get(), context.get(), device_name, status.get()));
  EXPECT_EQ(TF_GetCode(status.get()), TF_UNIMPLEMENTED);

  std::array<TFE_TensorHandle*, 2> components{cpu_value.get(), cpu_value.get()};
  TensorHandlePtr device_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  // Copies off of parallel devices must be explicit.
  TensorHandlePtr copy_off(TFE_TensorHandleCopyToDevice(
      device_value.get(), context.get(), first_device_name, status.get()));
  EXPECT_EQ(TF_GetCode(status.get()), TF_UNIMPLEMENTED);
}

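// Components with different shapes may be packed into one parallel tensor;
// the combined handle still reports a single axis for these vector components
// even though their lengths differ.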
TEST(PARALLEL_DEVICE, TestDifferentShapes) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Create two vectors with different lengths
  std::vector<float> size_two_value{1., 2.};
  std::vector<float> size_three_value{1., 2., 3.};
  TensorHandlePtr size_two(
      VectorFloatTensorHandle(size_two_value, status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TensorHandlePtr size_three(
      VectorFloatTensorHandle(size_three_value, status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Try to combine these values into a single parallel tensor.
  std::array<TFE_TensorHandle*, 2> components{size_two.get(), size_three.get()};
  TensorHandlePtr combined_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  int num_axes = TFE_TensorHandleNumDims(combined_value.get(), status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  EXPECT_EQ(num_axes, 1);
}

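// Packs a parallel tensor from one parallel device as a component of a second
// parallel device's tensor, then checks that element-wise multiplication,
// unpacking, and device placement all work through the nesting.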
TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/3),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Create a parallel device with two CPUs
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> first_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), first_device_name,
                         first_underlying_devices, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Create a second parallel device with the first parallel device and one
  // additional CPU.
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:1";
  std::array<const char*, 2> second_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CUSTOM:0",
      "/job:localhost/replica:0/task:0/device:CPU:2"};
  RegisterParallelDevice(context.get(), second_device_name,
                         second_underlying_devices, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Create a tensor on the first parallel device
  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
  TensorHandlePtr first_combined_value = CreatePerDeviceValues(
      context.get(), components, first_device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Nest the first parallel tensor into a second
  TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  components[0] = first_combined_value.get();
  components[1] = value_three.get();
  TensorHandlePtr second_combined_value = CreatePerDeviceValues(
      context.get(), components, second_device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Build a matching nested structure of scalar 3. tensors to multiply by.
  TensorHandlePtr three_cpu(FloatTensorHandle(3., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  components[0] = three_cpu.get();
  components[1] = three_cpu.get();
  TensorHandlePtr first_three = CreatePerDeviceValues(
      context.get(), components, first_device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  components[0] = first_three.get();
  components[1] = three_cpu.get();
  TensorHandlePtr second_three = CreatePerDeviceValues(
      context.get(), components, second_device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TensorHandlePtr multiply_result(
      Multiply(context.get(), second_combined_value.get(),
               second_three.get(), status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Un-pack the parallel tensor to verify that the operation was
  // successful. The resulting structure should be:
  // second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}.
  std::array<TensorHandlePtr, 2> second_components;
  ExtractPerDeviceValues(context.get(), multiply_result.get(),
                         &second_components, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  ExpectScalarEq<float>(second_components[1].get(), 9.);

  // Verify that the mirrors are placed on the component devices.
  std::string first_device = TFE_TensorHandleBackingDeviceName(
      second_components[0].get(), status.get());
  ASSERT_EQ(second_underlying_devices[0], first_device);
  std::string second_device = TFE_TensorHandleBackingDeviceName(
      second_components[1].get(), status.get());
  ASSERT_EQ(second_underlying_devices[1], second_device);

  // Un-pack the first parallel device's tensor too
  std::array<TensorHandlePtr, 2> first_components;
  ExtractPerDeviceValues(context.get(), second_components[0].get(),
                         &first_components, status.get());
  ExpectScalarEq<float>(first_components[0].get(), 3.);
  ExpectScalarEq<float>(first_components[1].get(), 6.);

  first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
                                                   status.get());
  ASSERT_EQ(first_underlying_devices[0], first_device);
  second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(),
                                                    status.get());
  ASSERT_EQ(first_underlying_devices[1], second_device);
}

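// Exercises error handling when packing or unpacking the wrong number of
// components, re-packing an already-parallel tensor, and running
// TPUReplicatedOutput on a non-parallel input; each case should fail with
// INVALID_ARGUMENT.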
TEST(PARALLEL_DEVICE, TestInvalidPacking) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 1> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  {
    // Try to pack two TensorHandles onto a parallel device with a single
    // component.
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                                value_two.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), components, device_name, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to extract the wrong number of components from a parallel tensor
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

    std::array<TensorHandlePtr, 2> incorrect_components;
    ExtractPerDeviceValues(context.get(), combined_value.get(),
                           &incorrect_components, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a ParallelTensor to TPUReplicatedInput
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

    std::array<TFE_TensorHandle*, 1> incorrect_components{combined_value.get()};
    TensorHandlePtr recombined_value = CreatePerDeviceValues(
        context.get(), incorrect_components, device_name, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a non-parallel tensor to TPUReplicatedOutput
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
        TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()),
        TFE_DeleteOp);
    if (TF_GetCode(status.get()) != TF_OK) return;
    TFE_OpSetAttrInt(op.get(), "num_replicas", 1);
    TFE_OpAddInput(op.get(), value_one.get(), status.get());
    if (TF_GetCode(status.get()) != TF_OK) return;
    TFE_OpSetDevice(op.get(), device_name, status.get());
    if (TF_GetCode(status.get()) != TF_OK) return;

    TFE_TensorHandle* result_handles;
    int num_retvals = 1;
    TFE_Execute(op.get(), &result_handles, &num_retvals, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }
}

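// Runs a CollectiveReduce with an "Add" merge op over `input`. The op is
// placed on the input's device, so when `input` lives on the parallel device
// the reduction runs across its per-device components. Returns nullptr on
// error.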
TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
                              int group_size, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  const char* device = TFE_TensorHandleDeviceName(input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input));
  TFE_OpSetAttrInt(op.get(), "group_size", group_size);
  TFE_OpSetAttrInt(op.get(), "group_key", 0);
  TFE_OpSetAttrInt(op.get(), "instance_key", 0);
  const std::string merge_op("Add");
  TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(),
                      merge_op.length());
  const std::string final_op("Id");
  TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(),
                      final_op.length());
  TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);

  TFE_OpAddInput(op.get(), input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TFE_TensorHandle* result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(result_handle);
}

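// Packs 1. and 2. onto a two-CPU parallel device, runs CollectiveSum across
// the components, and expects each component to hold 3. afterwards.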
void TestCollective(bool async) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  TFE_ContextOptionsSetAsync(opts.get(), async);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Create a tensor on the parallel device
  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
  TensorHandlePtr parallel_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Run a collective sum, so each component should now be the same.
  TensorHandlePtr reduced(
      CollectiveSum(context.get(), parallel_value.get(), 2, status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::array<TensorHandlePtr, 2> result_components;
  ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  ExpectScalarEq<float>(result_components[0].get(), 3.);
  ExpectScalarEq<float>(result_components[1].get(), 3.);
}

TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }

// Note that ops on the parallel device currently don't execute
// asynchronously. The test is just that we don't get deadlocks.
TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }

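// Builds a one-input, one-output function whose body is a CollectiveReduce
// with a "Mul" merge op, then registers it with the context under
// `function_name` so eager execution can invoke it like an op.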
void RegisterCollectiveMulFunction(TFE_Context* context,
                                   const char* function_name, int group_size,
                                   TF_Status* status) {
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> body(TF_NewGraph(),
                                                            TF_DeleteGraph);
  TF_OperationDescription* placeholder_desc =
      TF_NewOperation(body.get(), "Placeholder", "Placeholder");
  TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT);
  TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_Output x{placeholder_op, 0};

  TF_OperationDescription* reduce_desc =
      TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce");
  TF_SetAttrType(reduce_desc, "T", TF_FLOAT);
  TF_SetAttrInt(reduce_desc, "group_size", group_size);
  TF_SetAttrInt(reduce_desc, "group_key", 0);
  TF_SetAttrInt(reduce_desc, "instance_key", 0);

  const std::string merge_op("Mul");
  TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(),
                   merge_op.length());
  const std::string final_op("Id");
  TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(),
                   final_op.length());
  TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
  TF_AddInput(reduce_desc, x);
  TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_Operation* operations[]{placeholder_op, reduce_op};
  TF_Output y{reduce_op, 0};
  const char* output_name = "y";
  std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> function(
      TF_GraphToFunction(
          /* fn_body */ body.get(), /* fn_name */ function_name,
          /* append_hash_to_fn_name */ 0, /* num_opers */ 2,
          /* opers */ operations, /* ninputs */ 1, /* inputs */ &x,
          /* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name,
          /* opts */ nullptr, /* description */ "", /* status */ status),
      TF_DeleteFunction);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_ContextAddFunction(context, function.get(), status);
}

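// Registers the collective-multiply function above, runs it on a parallel
// tensor holding 7. and 9., and expects each component of the result to equal
// 7. * 9. and to stay on its underlying CPU device.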
TEST(PARALLEL_DEVICE, TestFunction) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* function_name = "test_reduce_mul";
  RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TensorHandlePtr value_one(FloatTensorHandle(7., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(9., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
  TensorHandlePtr parallel_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TFE_OpSetDevice(op.get(), device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TFE_OpAddInput(op.get(), parallel_value.get(), status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TFE_TensorHandle* raw_result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TensorHandlePtr reduced(raw_result_handle);

  std::array<TensorHandlePtr, 2> result_components;
  ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
  ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);

  std::string first_device = TFE_TensorHandleBackingDeviceName(
      result_components[0].get(), status.get());
  ASSERT_EQ(underlying_devices[0], first_device);
  std::string second_device = TFE_TensorHandleBackingDeviceName(
      result_components[1].get(), status.get());
  ASSERT_EQ(underlying_devices[1], second_device);
}

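// Checks that the debug summary of a packed parallel tensor names its
// component devices and values (here the substring "CPU:0": 3).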
TEST(PARALLEL_DEVICE, TestSummaryString) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{cpu_value.get(), cpu_value.get()};
  TensorHandlePtr device_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  ImmediateExecutionTensorHandle* unwrapped_handle =
      tensorflow::unwrap(device_value.get());
  std::string summarized;
  TF_ASSERT_OK(unwrapped_handle->SummarizeValue(summarized));
  EXPECT_THAT(summarized, HasSubstr("\"CPU:0\": 3"));
}

}  // namespace parallel_device
}  // namespace tensorflow