/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstring>
#include <functional>
#include <string>
#include <thread>
#include <vector>

#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace {

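// Launches `cluster_size` in-process worker threads. Each worker builds its
// own async eager context from a shared multi-client ServerDef, enables
// collective ops, runs `fn`, and then drains its executor before shutdown.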
void StartWorkers(int cluster_size,
                  std::function<void(TFE_Context* ctx, TF_Status* status,
                                     int worker_id, int cluster_size)>
                      fn) {
  tensorflow::ServerDef server_def =
      GetMultiClientServerDef("worker", cluster_size, /*num_virtual_gpus=*/2);
  // Enable coordination service for propagating remote device attributes.
  auto* config = server_def.mutable_default_session_config()
                     ->mutable_experimental()
                     ->mutable_coordination_config();
  config->set_service_type("standalone");
  config->set_service_leader("/job:worker/replica:0/task:0");
  // Use shutdown barrier to make sure that worker/0 thread (leader that starts
  // the coordination service instance) does not exit early while other workers
  // are still interacting with the coordination service.
  config->set_shutdown_barrier_timeout_in_ms(3 * 1000);  // 3 seconds

  auto worker_thread_fn = [&](int worker_id) {
    tensorflow::ServerDef server_def_copy = server_def;
    // By default, server_def has task index set to 0.
    server_def_copy.set_task_index(worker_id);
    std::string serialized = server_def_copy.SerializeAsString();

    TF_Status* status = TF_NewStatus();
    TFE_ContextOptions* opts = TFE_NewContextOptions();
    TFE_ContextOptionsSetAsync(opts,
                               static_cast<unsigned char>(/*enable=*/true));
    TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                               TFE_DEVICE_PLACEMENT_SILENT);

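    // Forward the per-worker session config (including the coordination
    // config set above) into the context options.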
    tensorflow::SessionOptions options;
    options.config = server_def_copy.default_session_config();
    opts->session_options.options = options;
    TFE_Context* ctx = TFE_NewContext(opts, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteContextOptions(opts);

    TFE_EnableCollectiveOps(ctx, serialized.data(), serialized.size(), status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    fn(ctx, status, worker_id, cluster_size);

    // Since we created an async EagerContext, wait for all pending operations
    // to finish before deleting the context.
    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteExecutor(executor);

    TFE_DeleteContext(ctx);
    TF_DeleteStatus(status);
  };

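  // Run one thread per worker and wait for all of them to finish.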
  std::vector<std::thread> worker_threads;
  for (int i = 0; i < cluster_size; ++i) {
    worker_threads.emplace_back([i, worker_thread_fn] { worker_thread_fn(i); });
  }
  for (auto i = 0; i < cluster_size; ++i) {
    worker_threads[i].join();
  }
}

TEST(CAPI, MultiClientCollectiveOps) {
  auto fn = [](TFE_Context* ctx, TF_Status* status, int worker_id,
               int cluster_size) {
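    // Every worker contributes the same 2x2 test matrix to an all-reduce, so
    // with two workers each element of the result should be doubled.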
    TFE_TensorHandle* in = TestMatrixTensorHandle(ctx);
    TFE_Op* allreduce = AllReduceOp(ctx, in, cluster_size);
    TFE_TensorHandle* retvals[1];
    int num_retvals = 1;
    TFE_Execute(allreduce, &retvals[0], &num_retvals, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

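    // Resolve the all-reduced tensor on the host and verify its contents.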
    TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    float result[4] = {0};
    EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
    memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
    TF_DeleteTensor(t);
    EXPECT_EQ(2.0, result[0]);
    EXPECT_EQ(4.0, result[1]);
    EXPECT_EQ(6.0, result[2]);
    EXPECT_EQ(8.0, result[3]);

    TFE_DeleteTensorHandle(in);
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_DeleteOp(allreduce);
  };
  StartWorkers(2, fn);
}

TEST(CAPI, MultiClientRemoteDevices) {
  auto fn = [](TFE_Context* ctx, TF_Status* status, int worker_id,
               int cluster_size) {
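    // With the coordination service propagating device attributes, each
    // worker's context should list the devices of every task in the cluster.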
    std::vector<tensorflow::DeviceAttributes> device_attrs;
    tensorflow::EagerContext* context =
        tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
    context->ListDevices(&device_attrs);
    std::vector<std::string> device_names;
    for (const auto& device_attr : device_attrs) {
      device_names.push_back(device_attr.name());
    }

    bool has_gpu_devices = false;
    std::string unused_gpu_device_name;
    if (GetDeviceName(ctx, &unused_gpu_device_name, "GPU")) {
      has_gpu_devices = true;
    }

    std::vector<std::string> expected_device_names;
    for (int i = 0; i < cluster_size; ++i) {
      expected_device_names.push_back(tensorflow::strings::StrCat(
          "/job:worker/replica:0/task:", i, "/device:CPU:0"));
      if (has_gpu_devices) {
        expected_device_names.push_back(tensorflow::strings::StrCat(
            "/job:worker/replica:0/task:", i, "/device:GPU:0"));
        expected_device_names.push_back(tensorflow::strings::StrCat(
            "/job:worker/replica:0/task:", i, "/device:GPU:1"));
      }
    }

    EXPECT_THAT(device_names,
                testing::UnorderedElementsAreArray(expected_device_names));
  };
  StartWorkers(3, fn);
}

TEST(CAPI, MultiClientSendRecv) {
  auto fn = [](TFE_Context* ctx, TF_Status* status, int worker_id,
               int cluster_size) {
    // Test with GPUs if present (based on test configuration) and CPUs
    // otherwise.
    auto send_device = "/job:worker/replica:0/task:0/device:CPU:0";
    auto recv_device = "/job:worker/replica:0/task:1/device:CPU:0";
    std::string unused_gpu_device_name;
    if (GetDeviceName(ctx, &unused_gpu_device_name, "GPU")) {
      send_device = "/job:worker/replica:0/task:0/device:GPU:0";
      recv_device = "/job:worker/replica:0/task:1/device:GPU:0";
    }

    std::vector<tensorflow::DeviceAttributes> device_attrs;
    tensorflow::EagerContext* context =
        tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
    context->ListDevices(&device_attrs);

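    // Look up the incarnation of the send device; the Send/Recv op pair uses
    // it to identify the matching rendezvous.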
    tensorflow::uint64 send_device_incarnation = 0;
    for (const auto& device_attr : device_attrs) {
      if (device_attr.name() == send_device) {
        send_device_incarnation = device_attr.incarnation();
        break;
      }
    }

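    // Worker 0 runs the Send op; the other worker runs the matching Recv op.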
    if (worker_id == 0) {
      TFE_TensorHandle* in = TestMatrixTensorHandle(ctx);
      const std::string& op_name =
          tensorflow::str_util::StrContains(send_device, "GPU") ? "Send"
                                                                : "_HostSend";
      TFE_Op* sendop = SendOp(ctx, in, op_name, send_device, recv_device,
                              send_device_incarnation);
      TFE_TensorHandle* retvals[1];
      int num_retvals = 1;
      TFE_Execute(sendop, &retvals[0], &num_retvals, status);
      EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TFE_DeleteOp(sendop);
      TFE_DeleteTensorHandle(in);
    } else {
      const std::string& op_name =
          tensorflow::str_util::StrContains(send_device, "GPU") ? "Recv"
                                                                : "_HostRecv";
      TFE_Op* recvop = RecvOp(ctx, op_name, send_device, recv_device,
                              send_device_incarnation);
      TFE_TensorHandle* retvals[1];
      int num_retvals = 1;
      TFE_Execute(recvop, &retvals[0], &num_retvals, status);
      TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
      TF_DeleteTensor(t);
      EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TFE_DeleteTensorHandle(retvals[0]);
      TFE_DeleteOp(recvop);
    }
  };
  StartWorkers(2, fn);
}

}  // namespace