/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/device_attributes.pb.h"
#ifdef GOOGLE_CUDA

#include "tensorflow/core/kernels/collective_nccl.h"

#include <algorithm>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/collective_nccl_broadcaster.h"
#include "tensorflow/core/kernels/collective_nccl_gatherer.h"
#include "tensorflow/core/kernels/collective_nccl_reducer.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/nccl/collective_communicator.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {
static constexpr int kStepId = 10;

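// Constructs a GPU OpKernel from `node` on `device`; fails fatally if the
// kernel cannot be created.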
std::unique_ptr<OpKernel> GetKernel(const NodeDef& node, DeviceBase* device) {
  Status status;
  std::unique_ptr<OpKernel> k = CreateOpKernel(
      DEVICE_GPU, device, device->GetAllocator(AllocatorAttributes()), node,
      TF_GRAPH_DEF_VERSION, &status);
  if (!status.ok()) LOG(FATAL) << status;
  return k;
}

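// Builds the element-wise Add kernel used as the merge_op in the reduction
// tests.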
std::unique_ptr<OpKernel> GetAdd(DeviceBase* device) {
  NodeDef node_def;
  NodeDefBuilder builder("add_node", "Add");
  TF_CHECK_OK(builder.Attr("T", DT_FLOAT)
                  .Input(FakeInput(DT_FLOAT))
                  .Input(FakeInput(DT_FLOAT))
                  .Finalize(&node_def));
  return GetKernel(node_def, device);
}

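// Builds the element-wise Div kernel used as the final_op in the reduction
// tests (turning the summed result into a mean).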
std::unique_ptr<OpKernel> GetDiv(DeviceBase* device) {
  NodeDef node_def;
  NodeDefBuilder builder("div_node", "Div");
  TF_CHECK_OK(builder.Attr("T", DT_FLOAT)
                  .Input(FakeInput(DT_FLOAT))
                  .Input(FakeInput(DT_FLOAT))
                  .Finalize(&node_def));
  return GetKernel(node_def, device);
}

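// Shared fixture for the NCCL reducer, broadcaster, and gatherer tests. It
// owns the local GPU devices, a collective executor, the shared
// CollectiveParams, and one DeviceInstance per participating rank.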
class NcclTestBase : public ::testing::Test {
 protected:
  class DeviceInstance;

  NcclTestBase(CollectiveType collective_type, const string& collective_name)
      : collective_type_(collective_type),
        collective_name_(collective_name),
        nccl_communicator_(MaybeCreateNcclCommunicator(config_proto_)),
        work_queue_(std::make_shared<UnboundedWorkQueue>(
            Env::Default(), "collective_executor")),
        col_exec_(nullptr),
        col_params_(nullptr) {}

  ~NcclTestBase() override {
    if (col_exec_) col_exec_->Unref();
    if (col_params_) col_params_->Unref();
  }

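  // Collects all local GPU devices; individual tests are skipped at runtime
  // if fewer GPUs are available than the requested number of ranks.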
  void SetUp() override {
    std::vector<std::unique_ptr<Device>> all_devices;
    TF_CHECK_OK(DeviceFactory::GetFactory(DEVICE_GPU)
                    ->AddDevices(SessionOptions(), "", &all_devices));
    for (std::unique_ptr<Device>& d : all_devices) {
      if (d->device_type() == "GPU") {
        gpus_.emplace_back(std::move(d));
      }
    }
  }

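  // Creates the collective executor, the shared CollectiveParams, and one
  // DeviceInstance per rank for a group of `num_ranks` GPUs.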
  void Init(const int num_ranks, const int instance_key) {
    setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
    setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */);
    std::vector<std::unique_ptr<Device>> local_devices;
    std::vector<string> device_names;
    CHECK_LE(num_ranks, gpus_.size());
    for (int rank = 0; rank < num_ranks; ++rank) {
      local_devices.emplace_back(std::move(gpus_[rank]));
    }
    int num_gpus = local_devices.size();
    for (const auto& device : local_devices) {
      device_names.push_back(device->name());
      VLOG(2) << device->name();
    }
    if (!dev_mgr_)
      dev_mgr_ = std::make_unique<StaticDeviceMgr>(std::move(local_devices));
    col_exec_ =
        new BaseCollectiveExecutor(&col_exec_mgr_, /*remote_access=*/nullptr,
                                   kStepId, dev_mgr_.get(), work_queue_);

    // Initialize collective params.
    col_params_ = new CollectiveParams();
    col_params_->name = "test_nccl_collective_op";
    const int group_key = num_ranks;
    col_params_->group.group_key = group_key;
    col_params_->group.device_type = DEVICE_GPU;
    col_params_->group.group_size = num_ranks;
    col_params_->instance.instance_key = instance_key;
    col_params_->instance.type = collective_type_;
    col_params_->instance.data_type = DT_FLOAT;
    col_params_->instance.impl_details.collective_name = collective_name_;
    const string task_name = "/job:worker/replica:0/task:0";
    col_params_->group.num_devices_per_task[task_name] = num_ranks;
    for (int rank = 0; rank < num_ranks; ++rank) {
      CollGroupMember member;
      member.device.set_name(device_names[rank % num_gpus]);
      col_params_->group.members.push_back(member);
    }
    for (int rank = 0; rank < num_ranks; ++rank) {
      instances_.push_back(std::make_unique<DeviceInstance>(
          rank, col_params_->group.members[rank].device.name(), this));
    }
  }

  // Initialize `input` tensor at rank `rank`.
  virtual void InitInput(Tensor* input, const int rank) = 0;

  // Initialize `expected` output at all `num_ranks` ranks.
  virtual void InitExpected(std::vector<float>* expected,
                            const int tensor_length, const int num_ranks) = 0;

  // Initialize device `di` specific to the collective op.
  virtual void InitDevice(DeviceInstance* di) = 0;

  // Run collective op on device `di`.
  virtual void RunCollectiveOnDevice(DeviceInstance* di) = 0;

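  // Schedules RunCollectiveOnDevice for every rank on the shared work queue
  // and blocks until all ranks have finished.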
  void RunCollective() {
    int done = 0;
    mutex done_mu;
    condition_variable done_cv;
    for (const auto& instance : instances_) {
      DeviceInstance* di = instance.get();
      InitDevice(di);
      SchedClosure([this, di, &done, &done_mu, &done_cv] {
        RunCollectiveOnDevice(di);
        mutex_lock l(done_mu);
        ++done;
        done_cv.notify_all();
      });
    }

    mutex_lock l(done_mu);
    while (done < instances_.size()) done_cv.wait(l);
  }

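  // End-to-end driver: initializes per-rank inputs, runs the collective, and
  // checks every rank's output against the expected values. Skips the test if
  // there are fewer GPUs than ranks.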
  void RunTest(int num_ranks, int input_length, int instance_key) {
    if (num_ranks > gpus_.size()) {
      LOG(WARNING) << "Skipping test because it requires " << num_ranks
                   << " GPUs but only " << gpus_.size() << " were found";
      return;
    }
    Init(num_ranks, instance_key);
    std::vector<float> expected;
    InitExpected(&expected, input_length, num_ranks);
    if (VLOG_IS_ON(3)) {
      string str_buf;
      for (const auto& x : expected) {
        strings::StrAppend(&str_buf, " ", x);
      }
      VLOG(3) << "Expected output " << str_buf;
    }
    for (int rank = 0; rank < num_ranks; ++rank) {
      DeviceInstance* instance = instances_[rank].get();
      instance->InitTensor(DT_FLOAT, TensorShape({input_length}),
                           [this, rank](Tensor* t) { InitInput(t, rank); });
    }
    RunCollective();
    // Confirm that every rank computed the same correct value.
    for (int rank = 0; rank < instances_.size(); ++rank) {
      TF_ASSERT_OK(instances_[rank]->status_);
      Tensor* output = &instances_[rank]->output_;
      const int output_length = output->NumElements();
      VLOG(2) << "rank " << rank << " output " << output << " buf "
              << DMAHelper::base(output);
      Tensor actual(DT_FLOAT, TensorShape({output_length}));
      Device* dev = instances_[rank]->device_;
      auto* dev_info = dev->tensorflow_accelerator_device_info();
      TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
          output, /*tensor_name=*/"", dev, &actual));
      VLOG(3) << "rank " << rank << " got output tensor "
              << actual.DebugString(output_length);
      for (int i = 0; i < output_length; ++i) {
        EXPECT_FLOAT_EQ(expected[i], actual.template flat<float>()(i))
            << "Mismatch at rank " << rank << " index " << i;
      }
    }
  }

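  // Builds a CollectiveReduce kernel whose attributes mirror `params`, with
  // Add as the merge op and Div as the final op, so the reducer can run under
  // a real OpKernelContext.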
  std::unique_ptr<OpKernel> GetCollectiveReduceOpKernel(
      const CollectiveParams& params, Tensor* input, DeviceBase* device) {
    mutex_lock l(mu_);
    NodeDef node_def;
    NodeDefBuilder builder(
        strings::StrCat("collective_reduce_", op_counter_++),
        "CollectiveReduce");
    TF_CHECK_OK(
        builder.Attr("T", params.instance.data_type)
            .Attr("merge_op", "Add")
            .Attr("final_op", "Div")
            .Attr("group_size", params.group.group_size)
            .Attr("group_key", params.group.group_key)
            .Attr("instance_key", params.instance.instance_key)
            .Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
            .Input(FakeInput(params.instance.data_type))
            .Finalize(&node_def));
    return GetKernel(node_def, device);
  }

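  // Per-rank test state: the rank's device, its input/output tensors, its own
  // copy of the CollectiveParams, and helpers to run each collective type.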
  class DeviceInstance {
   public:
    DeviceInstance(int rank, const string& device_name, NcclTestBase* parent)
        : parent_(parent),
          device_name_(device_name),
          rank_(rank),
          col_params_(new CollectiveParams()) {
      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(device_name_, &device_))
          << "Could not find device " << device_name_ << " existing devices "
          << parent_->dev_mgr_->DebugString();
      merge_op_ = GetAdd(device_);
      final_op_ = GetDiv(device_);
      col_params_->name = parent_->col_params_->name;
      col_params_->default_rank = rank;
      col_params_->group = parent_->col_params_->group;
      col_params_->instance = parent->col_params_->instance;
    }

    ~DeviceInstance() { col_params_->Unref(); }

    void InitTensor(DataType dtype, const TensorShape& shape,
                    const std::function<void(Tensor*)>& init_f) {
      input_ =
          Tensor(device_->GetAllocator(AllocatorAttributes()), dtype, shape);
      Tensor cpu_tensor(dtype, shape);
      init_f(&cpu_tensor);
      if (VLOG_IS_ON(3)) {
        VLOG(3) << "input tensor "
                << cpu_tensor.DebugString(shape.num_elements());
      } else {
        VLOG(2) << "input tensor " << cpu_tensor.DebugString();
      }
      auto* dev_info = device_->tensorflow_accelerator_device_info();
      TF_CHECK_OK(dev_info->default_context->CopyCPUTensorToDeviceSync(
          &cpu_tensor, device_, &input_));
    }

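    // Fills in the device and device context fields of `params`, taking a ref
    // on the context that the Run* methods release when they finish.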
    void PrepareDeviceContext(OpKernelContext::Params* params) {
      params->step_id = kStepId;
      params->device = device_;
      DeviceContext* dev_ctx = nullptr;
      auto* dev_info = device_->tensorflow_accelerator_device_info();
      if (dev_info) {
        dev_ctx = dev_info->default_context;
        dev_ctx->Ref();
      } else {
        dev_ctx = new DeviceContext;
      }
      params->op_device_context = dev_ctx;
    }

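    // Runs NcclReducer on this rank: the reduction runs in place (the
    // collective output aliases input_), and the result is copied into
    // output_ on success.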
    void RunReduce() {
      // Prepare an OpKernelContext.
      OpKernelContext::Params op_params;
      PrepareDeviceContext(&op_params);

      // Prepare inputs and outputs to OpKernel.
      gtl::InlinedVector<TensorValue, 4> inputs;
      inputs.push_back(TensorValue(&input_));
      op_params.inputs = inputs;
      gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
          {AllocatorAttributes()});
      op_params.input_alloc_attrs = input_aa;
      int forward_from = 0;
      op_params.forward_from_array = &forward_from;
      AllocatorAttributes generic_alloc_attr;
      op_params.output_attr_array = &generic_alloc_attr;
      std::unique_ptr<OpKernel> op =
          parent_->GetCollectiveReduceOpKernel(*col_params_, &input_, device_);
      op_params.op_kernel = op.get();
      OpKernelContext ctx(&op_params, 1);
      // We never actually execute the kernel, so we must do the output
      // allocation that the kernel itself would normally perform.
      Tensor* output_tensor_ptr = nullptr;
      TF_CHECK_OK(ctx.forward_input_or_allocate_output({0}, 0, input_.shape(),
                                                       &output_tensor_ptr));
      CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));

      // Run the all-reduce.
      string exec_key =
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
      auto* reducer = new NcclReducer();
      auto col_ctx = std::make_shared<CollectiveContext>(
          parent_->col_exec_, parent_->nccl_communicator_.get(),
          parent_->dev_mgr_.get(),
          /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
          /*input=*/&input_, /*output=*/&input_);
      TF_CHECK_OK(reducer->InitializeCollectiveContext(col_ctx));
      Notification note;
      reducer->Run([this, &note](Status s) {
        status_ = s;
        note.Notify();
      });
      note.WaitForNotification();
      if (status_.ok()) {
        CHECK(output_.CopyFrom(*ctx.mutable_output(0), input_.shape()));
      }

      reducer->Unref();
      op_params.op_device_context->Unref();
    }

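    // Runs NcclBroadcaster on this rank: the source rank provides input_ as
    // the collective input, every rank receives the result into input_, which
    // is then copied to output_ on success.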
    void RunBroadcast() {
      VLOG(2) << "RunBroadcast name " << parent_->collective_name_ << " rank "
              << col_params_->default_rank;
      // Prepare an OpKernelContext.
      OpKernelContext::Params op_params;
      PrepareDeviceContext(&op_params);
      OpKernelContext ctx(&op_params, 1);

      // Run broadcast.
      string exec_key =
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
      auto* broadcaster = new NcclBroadcaster();
      auto col_ctx = std::make_shared<CollectiveContext>(
          parent_->col_exec_, parent_->nccl_communicator_.get(),
          parent_->dev_mgr_.get(),
          /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
          /*input=*/col_params_->is_source ? &input_ : nullptr,
          /*output=*/&input_);
      TF_CHECK_OK(broadcaster->InitializeCollectiveContext(col_ctx));
      Notification note;
      broadcaster->Run([this, &note](Status s) {
        status_ = s;
        note.Notify();
      });
      note.WaitForNotification();
      if (status_.ok()) {
        CHECK(output_.CopyFrom(input_, input_.shape()));
      }

      broadcaster->Unref();
      op_params.op_device_context->Unref();
    }

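    // Runs NcclGatherer on this rank: every rank's input_ is concatenated in
    // rank order into output_, which is group_size times larger than input_.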
    void RunGather() {
      VLOG(2) << "RunGather name " << parent_->collective_name_ << " rank "
              << col_params_->default_rank;
      // Prepare an OpKernelContext.
      OpKernelContext::Params op_params;
      PrepareDeviceContext(&op_params);
      OpKernelContext ctx(&op_params, 1);

      // Allocate output. We can't reuse the input because output has a
      // different shape.
      auto output_shape = input_.shape();
      output_shape.set_dim(
          0, output_shape.dim_size(0) * col_params_->group.group_size);
      output_ = Tensor(device_->GetAllocator(AllocatorAttributes()), DT_FLOAT,
                       output_shape);

      // Run gather.
      string exec_key =
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
      auto* gatherer = new NcclGatherer();
      auto col_ctx = std::make_shared<CollectiveContext>(
          parent_->col_exec_, parent_->nccl_communicator_.get(),
          parent_->dev_mgr_.get(),
          /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
          /*input=*/&input_,
          /*output=*/&output_);
      TF_CHECK_OK(gatherer->InitializeCollectiveContext(col_ctx));
      Notification note;
      gatherer->Run([this, &note](Status s) {
        status_ = s;
        note.Notify();
      });
      note.WaitForNotification();

      gatherer->Unref();
      op_params.op_device_context->Unref();
    }

    NcclTestBase* parent_;
    string device_name_;
    int rank_;
    Tensor input_;
    Tensor output_;
    Device* device_;
    CollectiveParams* col_params_;
    std::unique_ptr<OpKernel> merge_op_;
    std::unique_ptr<OpKernel> final_op_;
    Status status_;
  };

  CollectiveType collective_type_;
  const string collective_name_;
  std::vector<std::unique_ptr<tensorflow::Device>> gpus_;
  TestCollectiveExecutorMgr col_exec_mgr_;
  ConfigProto config_proto_;
  std::unique_ptr<NcclCommunicatorInterface> nccl_communicator_;
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
  CollectiveExecutor* col_exec_;
  std::unique_ptr<DeviceMgr> dev_mgr_;
  std::vector<std::unique_ptr<DeviceInstance>> instances_;
  CollectiveParams* col_params_;
  mutex mu_;
  int32 op_counter_ TF_GUARDED_BY(mu_) = 0;
};

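// Reduction: rank r feeds pow(10, r) * i at index i; reducing with Add as the
// merge op and Div as the final op should yield the element-wise mean on
// every rank.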
class NcclReducerTest : public NcclTestBase {
 protected:
  NcclReducerTest()
      : NcclTestBase(/*collective_type=*/REDUCTION_COLLECTIVE,
                     /*collective_name=*/"NcclReduce") {}
  ~NcclReducerTest() override = default;

  void InitInput(Tensor* input, const int rank) override {
    for (size_t i = 0; i < input->NumElements(); ++i) {
      float value = pow(10, rank) * i;
      input->flat<float>()(i) = value;
    }
  }

  void InitExpected(std::vector<float>* expected, const int tensor_length,
                    const int num_ranks) override {
    expected->resize(tensor_length);
    for (int i = 0; i < tensor_length; ++i) {
      float expected_sum = 0.0;
      for (int rank = 0; rank < num_ranks; ++rank) {
        float value = pow(10, rank) * i;
        expected_sum += value;
      }
      (*expected)[i] = expected_sum / num_ranks;
    }
  }

  void InitDevice(DeviceInstance* di) override {
    di->col_params_->merge_op = di->merge_op_.get();
    di->col_params_->final_op = di->final_op_.get();
  }

  void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunReduce(); }
};

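// Broadcast: the source rank holds value i at index i and every other rank
// holds -1; after the broadcast all ranks should hold i.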
class NcclBroadcasterTest : public NcclTestBase {
 protected:
  NcclBroadcasterTest()
      : NcclTestBase(/*collective_type=*/BROADCAST_COLLECTIVE,
                     /*collective_name=*/"NcclBroadcast") {}
  ~NcclBroadcasterTest() override = default;

  void InitInput(Tensor* input, const int rank) override {
    bool source = rank == source_rank_;
    for (size_t i = 0; i < input->NumElements(); ++i) {
      input->flat<float>()(i) = source ? static_cast<float>(i) : -1.0;
    }
  }

  void InitExpected(std::vector<float>* expected, const int tensor_length,
                    const int num_ranks) override {
    expected->resize(tensor_length);
    for (int i = 0; i < tensor_length; ++i) {
      (*expected)[i] = i;
    }
  }

  void InitDevice(DeviceInstance* di) override {
    di->col_params_->source_rank = source_rank_;
    di->col_params_->is_source = di->col_params_->default_rank == source_rank_;
  }

  void RunCollectiveOnDevice(DeviceInstance* di) override {
    di->RunBroadcast();
  }

  int source_rank_ = 0;
};

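// Gather: rank r contributes pow(10, r) * j at index j; the per-rank inputs
// are concatenated in rank order into the output.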
class NcclGathererTest : public NcclTestBase {
 protected:
  NcclGathererTest()
      : NcclTestBase(/*collective_type=*/GATHER_COLLECTIVE,
                     /*collective_name=*/"NcclGather") {}
  ~NcclGathererTest() override = default;

  void InitInput(Tensor* input, const int rank) override {
    for (size_t i = 0; i < input->NumElements(); ++i) {
      float value = pow(10, rank) * i;
      input->flat<float>()(i) = value;
    }
  }

  void InitExpected(std::vector<float>* expected, const int tensor_length,
                    const int num_ranks) override {
    expected->resize(tensor_length * num_ranks, -1);
    for (int rank = 0, i = 0; rank < num_ranks; ++rank) {
      for (int j = 0; j < tensor_length; ++j, ++i) {
        (*expected)[i] = pow(10, rank) * j;
      }
    }
  }

  void InitDevice(DeviceInstance* di) override {}

  void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunGather(); }

  int source_rank_ = 0;
};

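// The tests below cover 2, 4, and 8 ranks with small and large tensors.
// RunTest skips a case at runtime when fewer GPUs are available than the
// requested number of ranks.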
TEST_F(NcclReducerTest, Test2Dev16Len) {
  RunTest(/*num_ranks=*/2, /*tensor_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclReducerTest, Test4Dev16Len) {
  RunTest(/*num_ranks=*/4, /*tensor_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclReducerTest, Test8Dev16Len) {
  RunTest(/*num_ranks=*/8, /*tensor_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclReducerTest, Test8Dev128Len) {
  RunTest(/*num_ranks=*/8, /*tensor_length=*/128, /*instance_key=*/23);
}
TEST_F(NcclReducerTest, Test8Dev1048576Len) {
  RunTest(/*num_ranks=*/8, /*tensor_length=*/1048576, /*instance_key=*/23);
}

TEST_F(NcclBroadcasterTest, Test2Dev16LenSrc0) {
  RunTest(/*num_ranks=*/2, /*tensor_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclBroadcasterTest, Test4Dev16LenSrc1) {
  source_rank_ = 1;
  RunTest(/*num_ranks=*/4, /*tensor_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclBroadcasterTest, Test8Dev16LenSrc7) {
  source_rank_ = 7;
  RunTest(/*num_ranks=*/8, /*tensor_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclBroadcasterTest, Test8Dev128LenSrc0) {
  RunTest(/*num_ranks=*/8, /*tensor_length=*/128, /*instance_key=*/24);
}
TEST_F(NcclBroadcasterTest, Test8Dev1048576LenSrc0) {
  RunTest(/*num_ranks=*/8, /*tensor_length=*/1048576, /*instance_key=*/23);
}

TEST_F(NcclGathererTest, Test2Dev16Len) {
  RunTest(/*num_ranks=*/2, /*tensor_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclGathererTest, Test4Dev16Len) {
  RunTest(/*num_ranks=*/4, /*tensor_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclGathererTest, Test8Dev16Len) {
  RunTest(/*num_ranks=*/8, /*tensor_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclGathererTest, Test8Dev128Len) {
  RunTest(/*num_ranks=*/8, /*tensor_length=*/128, /*instance_key=*/24);
}
TEST_F(NcclGathererTest, Test8Dev1048576Len) {
  RunTest(/*num_ranks=*/8, /*tensor_length=*/1048576, /*instance_key=*/23);
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA