/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace parallel_device {

using ::testing::ElementsAre;
using ::testing::HasSubstr;

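// Runs ReadVariableOp on freshly created (never assigned) variable handles,
// which is expected to fail, and verifies that the parallel device still
// executes ops successfully afterwards.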
TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(/*enable_xla_compilation=*/false,
                      /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
                     status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  auto outputs =
      parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                              "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                              /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
  std::vector<ParallelTensor*> handle_inputs;
  handle_inputs.reserve(handles.size());
  for (auto& handle : handles) {
    handle_inputs.push_back(handle.get());
  }
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> read_op(
      TFE_NewOp(context.get(), "ReadVariableOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(read_op.get(), "dtype", TF_FLOAT);
  parallel_device.Execute(context.get(), handle_inputs, "ReadVariableOp",
                          TFE_OpGetAttrs(read_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
  TF_SetStatus(status.get(), TF_OK, "");

  // Check that ops still run successfully on the device.
  parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                          "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}

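// StartExecute/Join with an explicit expected output shape should produce
// per-device outputs with that shape (here a scalar variable handle).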
TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(/*enable_xla_compilation=*/false,
                      /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
                     status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  CancellationManager cancellation_manager;
  parallel_device.StartExecute(context.get(), std::vector<ParallelTensor*>(),
                               "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                               /*expected_max_outputs=*/1,
                               cancellation_manager);
  auto outputs = parallel_device.Join(
      /*expected_output_shapes=*/{PartialTensorShape({})}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
  const std::vector<int64_t>* shape;
  Status s = handles[0]->Shape(&shape);
  ASSERT_TRUE(s.ok());
  EXPECT_EQ(0, shape->size());
}

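// A failing Assert on one device should cancel the pending CollectiveReduce
// on the other device, and the surfaced status should be the assertion error
// (INVALID_ARGUMENT) rather than CANCELLED.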
TEST(PARALLEL_DEVICE_LIB, TestCancelOnError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(/*enable_xla_compilation=*/false,
                      /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(devices);
  const FunctionDef assert_and_collective = FunctionDefHelper::Define(
      // Name
      "AssertAndCollective",
      // Args
      {"x: float", "condition: bool"},
      // Return values
      {"y: float"},
      // Attr def
      {},
      // Nodes
      {
          {{"assert"},
           "Assert",
           {"condition", "x"},
           {{"T", std::vector<DataType>{DT_FLOAT}}}},
          {{"y"},
           "CollectiveReduce",
           {"x"},
           {{"T", DT_FLOAT},
            {"group_size", static_cast<int>(devices.size())},
            {"group_key", 0},
            {"instance_key", 0},
            {"merge_op", "Add"},
            {"final_op", "Id"},
            {"subdiv_offsets", std::vector<int>()}},
           /*dep=*/{"assert"}},
      });
  TF_ASSERT_OK(ContextFromInterface(unwrap(context.get()))
                   ->AddFunctionDef(assert_and_collective));

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> call_op(
      TFE_NewOp(context.get(), "AssertAndCollective", status.get()),
      TFE_DeleteOp);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::unique_ptr<ParallelTensor> reduced_values =
      parallel_device.ScalarsFromSequence<float>({1.0, 2.0}, context.get(),
                                                 status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::unique_ptr<ParallelTensor> run_collective =
      parallel_device.ScalarsFromSequence<bool>({true, true}, context.get(),
                                                status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  auto outputs = parallel_device.Execute(
      context.get(), {reduced_values.get(), run_collective.get()},
      "AssertAndCollective", TFE_OpGetAttrs(call_op.get()),
      /*expected_max_outputs=*/1, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  ASSERT_EQ(outputs->size(), 1);
  ParallelTensor* parallel_result = (*outputs)[0].get();
  ExpectScalarEq<float>(parallel_result->tensor(0), 3.);
  ExpectScalarEq<float>(parallel_result->tensor(1), 3.);

  run_collective = parallel_device.ScalarsFromSequence<bool>(
      {true, false}, context.get(), status.get());
  parallel_device.Execute(context.get(),
                          {reduced_values.get(), run_collective.get()},
                          "AssertAndCollective", TFE_OpGetAttrs(call_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  EXPECT_NE(TF_GetCode(status.get()), TF_CANCELLED);
  EXPECT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT);
  EXPECT_THAT(TF_Message(status.get()), HasSubstr("assertion failed"));

  // Note that future collectives with the same context do not work at the
  // moment; once canceled, the collective executor requires the program to be
  // restarted / context to be reset.
}

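// Parallel tensors whose per-device components have different shapes report
// unknown dimensions, and ops can still be executed on them.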
TEST(PARALLEL_DEVICE_LIB, TestDifferentShapes) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(/*enable_xla_compilation=*/false,
                      /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  TensorHandlePtr two_vector = VectorFloatTensorHandle({3., 4.}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr three_vector =
      VectorFloatTensorHandle({5., 6., 7.}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<TensorHandlePtr> vector_handles;
  vector_handles.reserve(2);
  vector_handles.push_back(std::move(two_vector));
  vector_handles.push_back(std::move(three_vector));
  std::unique_ptr<ParallelTensor> unknown_length_vector =
      ParallelTensor::FromTensorHandles(
          parallel_device, std::move(vector_handles), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const std::vector<int64_t>* shape;
  TF_ASSERT_OK(unknown_length_vector->Shape(&shape));
  EXPECT_THAT(*shape, ElementsAre(-1));

  TensorHandlePtr scalar = FloatTensorHandle(2., status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  two_vector = VectorFloatTensorHandle({3., 4.}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::vector<TensorHandlePtr> mixed_handles;
  mixed_handles.reserve(2);
  mixed_handles.push_back(std::move(scalar));
  mixed_handles.push_back(std::move(two_vector));
  std::unique_ptr<ParallelTensor> unknown_dims_vector =
      ParallelTensor::FromTensorHandles(
          parallel_device, std::move(mixed_handles), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
  // Can't take the shape of a parallel tensor with varying numbers of axes,
  // but running operations on them is OK.
  TF_ASSERT_OK(unknown_length_vector->Shape(&shape));
  EXPECT_THAT(*shape, ElementsAre(-1));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> size_op(
      TFE_NewOp(context.get(), "Size", status.get()), TFE_DeleteOp);
  auto result = parallel_device.Execute(
      context.get(), {unknown_dims_vector.get()}, "Size",
      TFE_OpGetAttrs(size_op.get()), /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
  TF_ASSERT_OK((*result)[0]->Shape(&shape));
  EXPECT_EQ(0, shape->size());
}

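// ScalarsFromSequence should place one scalar per device, preserving order
// and element type.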
TEST(PARALLEL_DEVICE_LIB, TestScalarsFromSequence) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(/*enable_xla_compilation=*/false,
                      /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  {
    std::unique_ptr<ParallelTensor> float_tensors =
        parallel_device.ScalarsFromSequence<float>({10.0, 11.0}, context.get(),
                                                   status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    ExpectScalarEq<float>(float_tensors->tensor(0), 10.0);
    ExpectScalarEq<float>(float_tensors->tensor(1), 11.0);
  }

  {
    std::unique_ptr<ParallelTensor> int_tensors =
        parallel_device.ScalarsFromSequence<int>({5, 6}, context.get(),
                                                 status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    ExpectScalarEq<int>(int_tensors->tensor(0), 5);
    ExpectScalarEq<int>(int_tensors->tensor(1), 6);
  }
}

}  // namespace parallel_device
}  // namespace tensorflow