/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace parallel_device {

using ::testing::ElementsAre;
using ::testing::HasSubstr;

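// Runs an op that fails (reading variable handles that were never assigned)
// and verifies both that the error is reported and that the parallel device
// remains usable for subsequent ops.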
TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
                     status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  auto outputs =
      parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                              "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                              /*expected_max_outputs=*/1, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
  std::vector<ParallelTensor*> handle_inputs;
  handle_inputs.reserve(handles.size());
  for (auto& handle : handles) {
    handle_inputs.push_back(handle.get());
  }
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> read_op(
      TFE_NewOp(context.get(), "ReadVariableOp", status.get()), TFE_DeleteOp);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(read_op.get(), "dtype", TF_FLOAT);
  parallel_device.Execute(context.get(), handle_inputs, "ReadVariableOp",
                          TFE_OpGetAttrs(read_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_NE(TF_GetCode(status.get()), TF_OK);
  TF_SetStatus(status.get(), TF_OK, "");

  // Check that ops still run successfully on the device.
  parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                          "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
}

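// Join can be given explicit output shapes instead of inferring them from the
// per-device tensors.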
TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
                     status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  CancellationManager cancellation_manager;
  parallel_device.StartExecute(context.get(), std::vector<ParallelTensor*>(),
                               "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                               /*expected_max_outputs=*/1,
                               cancellation_manager);
  auto outputs = parallel_device.Join(
      /*expected_output_shapes=*/{PartialTensorShape({})}, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
  const std::vector<int64_t>* shape;
  TF_ASSERT_OK(handles[0]->Shape(&shape));
  EXPECT_EQ(0, shape->size());
}

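// A failing Assert on one device would leave the other device's collective
// blocked without cancellation; this test checks that pending execution is
// cancelled and that the assert's error, not a cancellation, is reported.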
TEST(PARALLEL_DEVICE_LIB, TestCancelOnError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(devices);
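  // AssertAndCollective asserts its boolean input and, with a control
  // dependency on the assert, all-reduces its float input across the group.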
  const FunctionDef assert_and_collective = FunctionDefHelper::Define(
      // Name
      "AssertAndCollective",
      // Args
      {"x: float", "condition: bool"},
      // Return values
      {"y: float"},
      // Attr def
      {},
      // Nodes
      {
          {{"assert"},
           "Assert",
           {"condition", "x"},
           {{"T", std::vector<DataType>{DT_FLOAT}}}},
          {{"y"},
           "CollectiveReduce",
           {"x"},
           {{"T", DT_FLOAT},
            {"group_size", static_cast<int>(devices.size())},
            {"group_key", 0},
            {"instance_key", 0},
            {"merge_op", "Add"},
            {"final_op", "Id"},
            {"subdiv_offsets", std::vector<int>()}},
           /*dep=*/{"assert"}},
      });
  TF_ASSERT_OK(ContextFromInterface(unwrap(context.get()))
                   ->AddFunctionDef(assert_and_collective));

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> call_op(
      TFE_NewOp(context.get(), "AssertAndCollective", status.get()),
      TFE_DeleteOp);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::unique_ptr<ParallelTensor> reduced_values =
      parallel_device.ScalarsFromSequence<float>({1.0, 2.0}, context.get(),
                                                 status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::unique_ptr<ParallelTensor> run_collective =
      parallel_device.ScalarsFromSequence<bool>({true, true}, context.get(),
                                                status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  auto outputs = parallel_device.Execute(
      context.get(), {reduced_values.get(), run_collective.get()},
      "AssertAndCollective", TFE_OpGetAttrs(call_op.get()),
      /*expected_max_outputs=*/1, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  ASSERT_EQ(outputs->size(), 1);
  ParallelTensor* parallel_result = (*outputs)[0].get();
  ExpectScalarEq<float>(parallel_result->tensor(0), 3.);
  ExpectScalarEq<float>(parallel_result->tensor(1), 3.);

  run_collective = parallel_device.ScalarsFromSequence<bool>(
      {true, false}, context.get(), status.get());
  parallel_device.Execute(context.get(),
                          {reduced_values.get(), run_collective.get()},
                          "AssertAndCollective", TFE_OpGetAttrs(call_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  EXPECT_NE(TF_GetCode(status.get()), TF_CANCELLED);
  EXPECT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT);
  EXPECT_THAT(TF_Message(status.get()), HasSubstr("assertion failed"));

  // Note that future collectives with the same context do not work at the
  // moment; once canceled, the collective executor requires the program to be
  // restarted / context to be reset.
}

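// Component tensors with different shapes yield a parallel tensor with a
// partially-known shape (unknown dimensions are reported as -1).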
TEST(PARALLEL_DEVICE_LIB, TestDifferentShapes) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  TensorHandlePtr two_vector = VectorFloatTensorHandle({3., 4.}, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TensorHandlePtr three_vector =
      VectorFloatTensorHandle({5., 6., 7.}, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::vector<TensorHandlePtr> vector_handles;
  vector_handles.reserve(2);
  vector_handles.push_back(std::move(two_vector));
  vector_handles.push_back(std::move(three_vector));
  std::unique_ptr<ParallelTensor> unknown_length_vector =
      ParallelTensor::FromTensorHandles(
          parallel_device, std::move(vector_handles), status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  const std::vector<int64_t>* shape;
  TF_ASSERT_OK(unknown_length_vector->Shape(&shape));
  EXPECT_THAT(*shape, ElementsAre(-1));

  TensorHandlePtr scalar = FloatTensorHandle(2., status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  two_vector = VectorFloatTensorHandle({3., 4.}, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::vector<TensorHandlePtr> mixed_handles;
  mixed_handles.reserve(2);
  mixed_handles.push_back(std::move(scalar));
  mixed_handles.push_back(std::move(two_vector));
  std::unique_ptr<ParallelTensor> unknown_dims_vector =
      ParallelTensor::FromTensorHandles(parallel_device,
                                        std::move(mixed_handles), status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  // There is no common shape for a parallel tensor whose components have
  // varying numbers of axes, so unknown_dims_vector's shape is not queried
  // here; running operations on it is still OK.
  TF_ASSERT_OK(unknown_length_vector->Shape(&shape));
  EXPECT_THAT(*shape, ElementsAre(-1));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> size_op(
      TFE_NewOp(context.get(), "Size", status.get()), TFE_DeleteOp);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  auto result = parallel_device.Execute(
      context.get(), {unknown_dims_vector.get()}, "Size",
      TFE_OpGetAttrs(size_op.get()), /*expected_max_outputs=*/1, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  TF_ASSERT_OK((*result)[0]->Shape(&shape));
  EXPECT_EQ(0, shape->size());
}

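// ScalarsFromSequence packs one scalar per device from a host-side sequence.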
TEST(PARALLEL_DEVICE_LIB, TestScalarsFromSequence) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  {
    std::unique_ptr<ParallelTensor> float_tensors =
        parallel_device.ScalarsFromSequence<float>({10.0, 11.0}, context.get(),
                                                   status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    ExpectScalarEq<float>(float_tensors->tensor(0), 10.0);
    ExpectScalarEq<float>(float_tensors->tensor(1), 11.0);
  }

  {
    std::unique_ptr<ParallelTensor> int_tensors =
        parallel_device.ScalarsFromSequence<int>({5, 6}, context.get(),
                                                 status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    ExpectScalarEq<int>(int_tensors->tensor(0), 5);
    ExpectScalarEq<int>(int_tensors->tensor(1), 6);
  }
}

}  // namespace parallel_device
}  // namespace tensorflow