/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstring>
#include <functional>
#include <string>
#include <thread>
#include <vector>

#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace {
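// Runs `fn` once per worker, each on its own thread, against a multi-client
// cluster of `cluster_size` in-process worker tasks, then tears the cluster
// down.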
void StartWorkers(int cluster_size,
                  std::function<void(TFE_Context* ctx, TF_Status* status,
                                     int worker_id, int cluster_size)>
                      fn) {
  tensorflow::ServerDef server_def =
      GetMultiClientServerDef("worker", cluster_size, /*num_virtual_gpus=*/2);
  // Enable coordination service for propagating remote device attributes.
  auto* config = server_def.mutable_default_session_config()
                     ->mutable_experimental()
                     ->mutable_coordination_config();
  config->set_service_type("standalone");
  config->set_service_leader("/job:worker/replica:0/task:0");
  // Use shutdown barrier to make sure that worker/0 thread (leader that starts
  // the coordination service instance) does not exit early while other workers
  // are still interacting with the coordination service.
  config->set_shutdown_barrier_timeout_in_ms(3 * 1000);  // 3 seconds

  auto worker_thread_fn = [&](int worker_id) {
    tensorflow::ServerDef server_def_copy = server_def;
    // By default, server_def has task index set to 0.
    server_def_copy.set_task_index(worker_id);
    std::string serialized = server_def_copy.SerializeAsString();

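    // Create an async eager context with silent device placement: async mode
    // lets the ops issued below make progress without blocking this thread,
    // and silent placement avoids errors when an op's inputs live on another
    // device.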
    TF_Status* status = TF_NewStatus();
    TFE_ContextOptions* opts = TFE_NewContextOptions();
    TFE_ContextOptionsSetAsync(opts,
                               static_cast<unsigned char>(/*enable=*/true));
    TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                               TFE_DEVICE_PLACEMENT_SILENT);

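    // Copy the session config (which carries the coordination-service
    // settings above) into the context options through the internal handle.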
    tensorflow::SessionOptions options;
    options.config = server_def_copy.default_session_config();
    opts->session_options.options = options;
    TFE_Context* ctx = TFE_NewContext(opts, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteContextOptions(opts);

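    // Bring this task up as part of the multi-client cluster; with the
    // coordination service enabled, this is also where workers exchange
    // device attributes.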
    TFE_EnableCollectiveOps(ctx, serialized.data(), serialized.size(), status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    fn(ctx, status, worker_id, cluster_size);

    // Since we created an async EagerContext, wait for all pending operations
    // to finish before deleting the context.
    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteExecutor(executor);

    TFE_DeleteContext(ctx);
    TF_DeleteStatus(status);
  };

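  // Run one client per thread. In a real deployment each client would be a
  // separate process; threads suffice here because every worker builds its
  // own TFE_Context.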
  std::vector<std::thread> worker_threads;
  for (int i = 0; i < cluster_size; ++i) {
    worker_threads.emplace_back([i, worker_thread_fn] { worker_thread_fn(i); });
  }
  for (int i = 0; i < cluster_size; ++i) {
    worker_threads[i].join();
  }
}

TEST(CAPI, MultiClientCollectiveOps) {
  auto fn = [](TFE_Context* ctx, TF_Status* status, int worker_id,
               int cluster_size) {
    TFE_TensorHandle* in = TestMatrixTensorHandle(ctx);
    TFE_Op* allreduce = AllReduceOp(ctx, in, cluster_size);
    TFE_TensorHandle* retvals[1];
    int num_retvals = 1;
    TFE_Execute(allreduce, &retvals[0], &num_retvals, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    float result[4] = {0};
    EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
    memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
    TF_DeleteTensor(t);
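    // Each of the two workers feeds the same 2x2 matrix {1, 2, 3, 4}
    // (TestMatrixTensorHandle), so the summing all-reduce yields {2, 4, 6, 8}.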
    EXPECT_EQ(2.0, result[0]);
    EXPECT_EQ(4.0, result[1]);
    EXPECT_EQ(6.0, result[2]);
    EXPECT_EQ(8.0, result[3]);

    TFE_DeleteTensorHandle(in);
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_DeleteOp(allreduce);
  };
  StartWorkers(2, fn);
}

TEST(CAPI, MultiClientRemoteDevices) {
  auto fn = [](TFE_Context* ctx, TF_Status* status, int worker_id,
               int cluster_size) {
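    // Reach into the internal EagerContext to enumerate every device this
    // worker knows about, local and remote.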
    std::vector<tensorflow::DeviceAttributes> device_attrs;
    tensorflow::EagerContext* context =
        tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
    context->ListDevices(&device_attrs);
    std::vector<std::string> device_names;
    for (const auto& device_attr : device_attrs) {
      device_names.push_back(device_attr.name());
    }

    bool has_gpu_devices = false;
    std::string unused_gpu_device_name;
    if (GetDeviceName(ctx, &unused_gpu_device_name, "GPU")) {
      has_gpu_devices = true;
    }

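    // Every task contributes one CPU device and, when GPUs are available, the
    // two virtual GPUs configured by StartWorkers (num_virtual_gpus=2).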
    std::vector<std::string> expected_device_names;
    for (int i = 0; i < cluster_size; ++i) {
      expected_device_names.push_back(tensorflow::strings::StrCat(
          "/job:worker/replica:0/task:", i, "/device:CPU:0"));
      if (has_gpu_devices) {
        expected_device_names.push_back(tensorflow::strings::StrCat(
            "/job:worker/replica:0/task:", i, "/device:GPU:0"));
        expected_device_names.push_back(tensorflow::strings::StrCat(
            "/job:worker/replica:0/task:", i, "/device:GPU:1"));
      }
    }

    EXPECT_THAT(device_names,
                testing::UnorderedElementsAreArray(expected_device_names));
  };
  StartWorkers(3, fn);
}

TEST(CAPI, MultiClientSendRecv) {
  auto fn = [](TFE_Context* ctx, TF_Status* status, int worker_id,
               int cluster_size) {
    // Test with GPUs if present (based on test configuration) and CPUs
    // otherwise.
    auto send_device = "/job:worker/replica:0/task:0/device:CPU:0";
    auto recv_device = "/job:worker/replica:0/task:1/device:CPU:0";
    std::string unused_gpu_device_name;
    if (GetDeviceName(ctx, &unused_gpu_device_name, "GPU")) {
      send_device = "/job:worker/replica:0/task:0/device:GPU:0";
      recv_device = "/job:worker/replica:0/task:1/device:GPU:0";
    }

    std::vector<tensorflow::DeviceAttributes> device_attrs;
    tensorflow::EagerContext* context =
        tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
    context->ListDevices(&device_attrs);

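    // The send device's incarnation is part of the rendezvous key; Send/Recv
    // use it to detect a device that restarted between producer and consumer.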
    tensorflow::uint64 send_device_incarnation = 0;
    for (const auto& device_attr : device_attrs) {
      if (device_attr.name() == send_device) {
        send_device_incarnation = device_attr.incarnation();
        break;
      }
    }

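    // Worker 0 produces and sends the tensor; worker 1 receives it. CPU
    // transfers go through _HostSend/_HostRecv, GPU transfers through
    // Send/Recv.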
    if (worker_id == 0) {
      TFE_TensorHandle* in = TestMatrixTensorHandle(ctx);
      const std::string& op_name =
          tensorflow::str_util::StrContains(send_device, "GPU") ? "Send"
                                                                : "_HostSend";
      TFE_Op* sendop = SendOp(ctx, in, op_name, send_device, recv_device,
                              send_device_incarnation);
      TFE_TensorHandle* retvals[1];
      int num_retvals = 1;
      TFE_Execute(sendop, &retvals[0], &num_retvals, status);
      EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TFE_DeleteOp(sendop);
      TFE_DeleteTensorHandle(in);
    } else {
      const std::string& op_name =
          tensorflow::str_util::StrContains(send_device, "GPU") ? "Recv"
                                                                : "_HostRecv";
      TFE_Op* recvop = RecvOp(ctx, op_name, send_device, recv_device,
                              send_device_incarnation);
      TFE_TensorHandle* retvals[1];
      int num_retvals = 1;
      TFE_Execute(recvop, &retvals[0], &num_retvals, status);
      // Check the status before touching retvals: on failure the output
      // handle is not set, and resolving it would be undefined behavior.
      ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
      EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TF_DeleteTensor(t);
      TFE_DeleteTensorHandle(retvals[0]);
      TFE_DeleteOp(recvop);
    }
  };
  StartWorkers(2, fn);
}

}  // namespace