/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device.h"

#include <array>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration tests rather than tests of the parallel device in isolation.
// They correspond fairly closely to the implementation, but testing the C++
// classes directly is another option.

namespace tensorflow {
namespace parallel_device {

using ::testing::HasSubstr;

TEST(PARALLEL_DEVICE, TestBasicCPU) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
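  // num_cpu_devices=2 gives the context two virtual CPU devices (CPU:0 and
  // CPU:1), so the parallel device below has two distinct underlying devices
  // to span.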
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:1");
}

TEST(PARALLEL_DEVICE, TestBasicCPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:0");
}

TEST(PARALLEL_DEVICE, TestBasicTPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Skip the test body if no TPU is available.
  std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> devices(
      TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  bool has_tpu = false;
  for (int device_index = 0; device_index < TF_DeviceListCount(devices.get());
       ++device_index) {
    std::string device_type =
        TF_DeviceListType(devices.get(), device_index, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    if (device_type == "TPU") {
      has_tpu = true;
      break;
    }
  }
  if (has_tpu) {
    BasicTestsForTwoDevices(context.get(),
                            "/job:localhost/replica:0/task:0/device:TPU:0",
                            "/job:localhost/replica:0/task:0/device:TPU:0");
  }
}

TEST(PARALLEL_DEVICE, TestExplicitCopies) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:1";
  std::array<const char*, 2> underlying_devices{first_device_name,
                                                second_device_name};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Copying onto a parallel device must be explicit.
  TensorHandlePtr failed_copy_on_result(TFE_TensorHandleCopyToDevice(
      cpu_value.get(), context.get(), device_name, status.get()));
  EXPECT_EQ(TF_GetCode(status.get()), TF_UNIMPLEMENTED);

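  // The explicit path is CreatePerDeviceValues (below), which packs one
  // component per underlying device through the TPUReplicatedInput op.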
  std::array<TFE_TensorHandle*, 2> components{cpu_value.get(), cpu_value.get()};
  TensorHandlePtr device_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  // Copies off of parallel devices must be explicit.
  TensorHandlePtr copy_off(TFE_TensorHandleCopyToDevice(
      device_value.get(), context.get(), first_device_name, status.get()));
  EXPECT_EQ(TF_GetCode(status.get()), TF_UNIMPLEMENTED);
}

TEST(PARALLEL_DEVICE, TestDifferentShapes) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Create two vectors with different lengths.
  std::vector<float> size_two_value{1., 2.};
  std::vector<float> size_three_value{1., 2., 3.};
  TensorHandlePtr size_two(
      VectorFloatTensorHandle(size_two_value, status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TensorHandlePtr size_three(
      VectorFloatTensorHandle(size_three_value, status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Try to combine these values into a single parallel tensor.
  std::array<TFE_TensorHandle*, 2> components{size_two.get(), size_three.get()};
  TensorHandlePtr combined_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  int num_axes = TFE_TensorHandleNumDims(combined_value.get(), status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
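  // Both components are rank-1 vectors, so the parallel handle also reports
  // rank 1 even though the component lengths (2 vs. 3) disagree.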
  EXPECT_EQ(num_axes, 1);
}

TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/3),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Create a parallel device with two CPUs.
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> first_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), first_device_name,
                         first_underlying_devices, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Create a second parallel device with the first parallel device and one
  // additional CPU.
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:1";
  std::array<const char*, 2> second_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CUSTOM:0",
      "/job:localhost/replica:0/task:0/device:CPU:2"};
  RegisterParallelDevice(context.get(), second_device_name,
                         second_underlying_devices, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Create a tensor on the first parallel device.
  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
  TensorHandlePtr first_combined_value = CreatePerDeviceValues(
      context.get(), components, first_device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Nest the first parallel tensor into a second.
  TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  components[0] = first_combined_value.get();
  components[1] = value_three.get();
  TensorHandlePtr second_combined_value = CreatePerDeviceValues(
      context.get(), components, second_device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TensorHandlePtr three_cpu(FloatTensorHandle(3., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  components[0] = three_cpu.get();
  components[1] = three_cpu.get();
  TensorHandlePtr first_threes = CreatePerDeviceValues(
      context.get(), components, first_device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  components[0] = first_threes.get();
  components[1] = three_cpu.get();
  TensorHandlePtr second_threes = CreatePerDeviceValues(
      context.get(), components, second_device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TensorHandlePtr multiply_result(
      Multiply(context.get(), second_combined_value.get(),
               second_threes.get(), status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Un-pack the parallel tensor to verify that the operation was
  // successful. The resulting structure should be:
  //   second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}.
  std::array<TensorHandlePtr, 2> second_components;
  ExtractPerDeviceValues(context.get(), multiply_result.get(),
                         &second_components, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  ExpectScalarEq<float>(second_components[1].get(), 9.);

  // Verify that the mirrors are placed on the component devices.
  std::string first_device = TFE_TensorHandleBackingDeviceName(
      second_components[0].get(), status.get());
  ASSERT_EQ(second_underlying_devices[0], first_device);
  std::string second_device = TFE_TensorHandleBackingDeviceName(
      second_components[1].get(), status.get());
  ASSERT_EQ(second_underlying_devices[1], second_device);

  // Un-pack the first parallel device's tensor too.
  std::array<TensorHandlePtr, 2> first_components;
  ExtractPerDeviceValues(context.get(), second_components[0].get(),
                         &first_components, status.get());
  ExpectScalarEq<float>(first_components[0].get(), 3.);
  ExpectScalarEq<float>(first_components[1].get(), 6.);

  first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
                                                   status.get());
  ASSERT_EQ(first_underlying_devices[0], first_device);
  second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(),
                                                    status.get());
  ASSERT_EQ(first_underlying_devices[1], second_device);
}

TEST(PARALLEL_DEVICE, TestInvalidPacking) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 1> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  {
    // Try to pack two TensorHandles onto a parallel device with a single
    // component.
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                                value_two.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), components, device_name, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to extract the wrong number of components from a parallel tensor.
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

    std::array<TensorHandlePtr, 2> incorrect_components;
    ExtractPerDeviceValues(context.get(), combined_value.get(),
                           &incorrect_components, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a ParallelTensor to TPUReplicatedInput.
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

    std::array<TFE_TensorHandle*, 1> incorrect_components{combined_value.get()};
    TensorHandlePtr recombined_value = CreatePerDeviceValues(
        context.get(), incorrect_components, device_name, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a non-parallel tensor to TPUReplicatedOutput.
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
        TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()),
        TFE_DeleteOp);
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    TFE_OpSetAttrInt(op.get(), "num_replicas", 1);
    TFE_OpAddInput(op.get(), value_one.get(), status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    TFE_OpSetDevice(op.get(), device_name, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

    TFE_TensorHandle* result_handle;
    int num_retvals = 1;
    TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }
}

TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
                              int group_size, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  const char* device = TFE_TensorHandleDeviceName(input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
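  // Collective participants are matched on (group_size, group_key,
  // instance_key); every component of a parallel input joins the same
  // instance, so the reduction runs across the parallel device's components.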
  TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input));
  TFE_OpSetAttrInt(op.get(), "group_size", group_size);
  TFE_OpSetAttrInt(op.get(), "group_key", 0);
  TFE_OpSetAttrInt(op.get(), "instance_key", 0);
  const std::string merge_op("Add");
  TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(),
                      merge_op.length());
  const std::string final_op("Id");
  TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(),
                      final_op.length());
  TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);

  TFE_OpAddInput(op.get(), input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TFE_TensorHandle* result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(result_handle);
}

void TestCollective(bool async) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  TFE_ContextOptionsSetAsync(opts.get(), async);
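  // When `async` is true the context executes ops asynchronously. The test
  // body is identical either way; the async variant mainly checks that the
  // collective completes without deadlocking.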
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Create a tensor on the parallel device.
  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
  TensorHandlePtr parallel_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Run a collective sum, after which each component holds the same value.
  TensorHandlePtr reduced(
      CollectiveSum(context.get(), parallel_value.get(), 2, status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::array<TensorHandlePtr, 2> result_components;
  ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  ExpectScalarEq<float>(result_components[0].get(), 3.);
  ExpectScalarEq<float>(result_components[1].get(), 3.);
}

TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }

// Note that ops on the parallel device currently don't execute
// asynchronously. This test just checks that we don't get deadlocks.
TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }

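// Builds a one-input function `function_name` computing
// y = CollectiveReduce(x) with merge_op="Mul", then registers it with the
// context so TestFunction below can execute it on the parallel device.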
void RegisterCollectiveMulFunction(TFE_Context* context,
                                   const char* function_name, int group_size,
                                   TF_Status* status) {
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> body(TF_NewGraph(),
                                                            TF_DeleteGraph);
  TF_OperationDescription* placeholder_desc =
      TF_NewOperation(body.get(), "Placeholder", "Placeholder");
  TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT);
  TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_Output x{placeholder_op, 0};

  TF_OperationDescription* reduce_desc =
      TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce");
  TF_SetAttrType(reduce_desc, "T", TF_FLOAT);
  TF_SetAttrInt(reduce_desc, "group_size", group_size);
  TF_SetAttrInt(reduce_desc, "group_key", 0);
  TF_SetAttrInt(reduce_desc, "instance_key", 0);

  const std::string merge_op("Mul");
  TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(),
                   merge_op.length());
  const std::string final_op("Id");
  TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(),
                   final_op.length());
  TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
  TF_AddInput(reduce_desc, x);
  TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_Operation* operations[]{placeholder_op, reduce_op};
  TF_Output y{reduce_op, 0};
  const char* output_name = "y";
  std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> function(
      TF_GraphToFunction(
          /* fn_body */ body.get(), /* fn_name */ function_name,
          /* append_hash_to_fn_name */ 0, /* num_opers */ 2,
          /* opers */ operations, /* ninputs */ 1, /* inputs */ &x,
          /* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name,
          /* opts */ nullptr, /* description */ "", /* status */ status),
      TF_DeleteFunction);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_ContextAddFunction(context, function.get(), status);
}

TEST(PARALLEL_DEVICE, TestFunction) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* function_name = "test_reduce_mul";
  RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TensorHandlePtr value_one(FloatTensorHandle(7., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(9., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
  TensorHandlePtr parallel_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TFE_OpSetDevice(op.get(), device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TFE_OpAddInput(op.get(), parallel_value.get(), status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TFE_TensorHandle* raw_result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TensorHandlePtr reduced(raw_result_handle);

  std::array<TensorHandlePtr, 2> result_components;
  ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
  ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);

  std::string first_device = TFE_TensorHandleBackingDeviceName(
      result_components[0].get(), status.get());
  ASSERT_EQ(underlying_devices[0], first_device);
  std::string second_device = TFE_TensorHandleBackingDeviceName(
      result_components[1].get(), status.get());
  ASSERT_EQ(underlying_devices[1], second_device);
}

TEST(PARALLEL_DEVICE, TestSummaryString) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{cpu_value.get(), cpu_value.get()};
  TensorHandlePtr device_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
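  // Unwrap the C API handle to reach SummarizeValue on the underlying
  // ImmediateExecutionTensorHandle. For a parallel tensor the summary is
  // expected to list component values tagged with their devices; the
  // expectation below only checks for the "CPU:0" component.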
  ImmediateExecutionTensorHandle* unwrapped_handle =
      tensorflow::unwrap(device_value.get());
  std::string summarized;
  TF_ASSERT_OK(unwrapped_handle->SummarizeValue(summarized));
  EXPECT_THAT(summarized, HasSubstr("\"CPU:0\": 3"));
}

}  // namespace parallel_device
}  // namespace tensorflow