/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/device_attributes.pb.h"
#ifdef GOOGLE_CUDA

#include "tensorflow/core/kernels/collective_nccl.h"

#include <algorithm>
#include <cmath>  // pow(), used by the InitInput/InitExpected overrides below

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/collective_nccl_broadcaster.h"
#include "tensorflow/core/kernels/collective_nccl_gatherer.h"
#include "tensorflow/core/kernels/collective_nccl_reducer.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/nccl/collective_communicator.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {
static constexpr int kStepId = 10;

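// Builds an OpKernel for `node` on `device`; CHECK-fails if kernel
// construction fails, so setup errors abort the test immediately.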
std::unique_ptr<OpKernel> GetKernel(const NodeDef& node, DeviceBase* device) {
  Status status;
  std::unique_ptr<OpKernel> k = CreateOpKernel(
      DEVICE_GPU, device, device->GetAllocator(AllocatorAttributes()), node,
      TF_GRAPH_DEF_VERSION, &status);
  if (!status.ok()) LOG(FATAL) << status;
  return k;
}

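// GetAdd and GetDiv build the elementwise Add and Div kernels that the
// reducer test installs as the collective's merge_op and final_op.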
std::unique_ptr<OpKernel> GetAdd(DeviceBase* device) {
  NodeDef node_def;
  NodeDefBuilder builder("add_node", "Add");
  TF_CHECK_OK(builder.Attr("T", DT_FLOAT)
                  .Input(FakeInput(DT_FLOAT))
                  .Input(FakeInput(DT_FLOAT))
                  .Finalize(&node_def));
  return GetKernel(node_def, device);
}

std::unique_ptr<OpKernel> GetDiv(DeviceBase* device) {
  NodeDef node_def;
  NodeDefBuilder builder("div_node", "Div");
  TF_CHECK_OK(builder.Attr("T", DT_FLOAT)
                  .Input(FakeInput(DT_FLOAT))
                  .Input(FakeInput(DT_FLOAT))
                  .Finalize(&node_def));
  return GetKernel(node_def, device);
}

class NcclTestBase : public ::testing::Test {
 protected:
  class DeviceInstance;

  NcclTestBase(CollectiveType collective_type, const string& collective_name)
      : collective_type_(collective_type),
        collective_name_(collective_name),
        nccl_communicator_(MaybeCreateNcclCommunicator(config_proto_)),
        work_queue_(std::make_shared<UnboundedWorkQueue>(
            Env::Default(), "collective_executor")),
        col_exec_(nullptr),
        col_params_(nullptr) {}

  ~NcclTestBase() override {
    if (col_exec_) col_exec_->Unref();
    if (col_params_) col_params_->Unref();
  }

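  // Enumerates the local GPU devices; individual tests skip themselves when
  // fewer GPUs are available than the requested number of ranks.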
  void SetUp() override {
    std::vector<std::unique_ptr<Device>> all_devices;
    TF_CHECK_OK(DeviceFactory::GetFactory(DEVICE_GPU)
                    ->AddDevices(SessionOptions(), "", &all_devices));
    for (std::unique_ptr<Device>& d : all_devices) {
      if (d->device_type() == "GPU") {
        gpus_.emplace_back(std::move(d));
      }
    }
  }

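  // Builds the collective executor and the shared CollectiveParams for a
  // group of `num_ranks` single-GPU ranks, then creates one DeviceInstance
  // per rank.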
  void Init(const int num_ranks, const int instance_key) {
    setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
    setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */);
    std::vector<std::unique_ptr<Device>> local_devices;
    std::vector<string> device_names;
    CHECK_LE(num_ranks, gpus_.size());
    for (int rank = 0; rank < num_ranks; ++rank) {
      local_devices.emplace_back(std::move(gpus_[rank]));
    }
    int num_gpus = local_devices.size();
    for (const auto& device : local_devices) {
      device_names.push_back(device->name());
      VLOG(2) << device->name();
    }
    if (!dev_mgr_)
      dev_mgr_ = std::make_unique<StaticDeviceMgr>(std::move(local_devices));
    col_exec_ =
        new BaseCollectiveExecutor(&col_exec_mgr_, /*remote_access=*/nullptr,
                                   kStepId, dev_mgr_.get(), work_queue_);

    // Initialize collective params.
    col_params_ = new CollectiveParams();
    col_params_->name = "test_nccl_collective_op";
    const int group_key = num_ranks;
    col_params_->group.group_key = group_key;
    col_params_->group.device_type = DEVICE_GPU;
    col_params_->group.group_size = num_ranks;
    col_params_->instance.instance_key = instance_key;
    col_params_->instance.type = collective_type_;
    col_params_->instance.data_type = DT_FLOAT;
    col_params_->instance.impl_details.collective_name = collective_name_;
    const string task_name = "/job:worker/replica:0/task:0";
    col_params_->group.num_devices_per_task[task_name] = num_ranks;
    for (int rank = 0; rank < num_ranks; ++rank) {
      CollGroupMember member;
      member.device.set_name(device_names[rank % num_gpus]);
      col_params_->group.members.push_back(member);
    }
    for (int rank = 0; rank < num_ranks; ++rank) {
      instances_.push_back(std::make_unique<DeviceInstance>(
          rank, col_params_->group.members[rank].device.name(), this));
    }
  }

  // Initialize `input` tensor at rank `rank`.
  virtual void InitInput(Tensor* input, const int rank) = 0;

  // Initialize `expected` output at all `num_ranks` ranks.
  virtual void InitExpected(std::vector<float>* expected,
                            const int tensor_length, const int num_ranks) = 0;

  // Initialize device `di` specific to the collective op.
  virtual void InitDevice(DeviceInstance* di) = 0;

  // Run collective op on device `di`.
  virtual void RunCollectiveOnDevice(DeviceInstance* di) = 0;

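  // Launches the collective on every rank's device in parallel and blocks
  // until all closures have reported completion.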
  void RunCollective() {
    int done = 0;
    mutex done_mu;
    condition_variable done_cv;
    for (const auto& instance : instances_) {
      DeviceInstance* di = instance.get();
      InitDevice(di);
      SchedClosure([this, di, &done, &done_mu, &done_cv] {
        RunCollectiveOnDevice(di);
        mutex_lock l(done_mu);
        ++done;
        done_cv.notify_all();
      });
    }

    mutex_lock l(done_mu);
    while (done < instances_.size()) done_cv.wait(l);
  }

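  // End-to-end driver: initializes the group, seeds per-rank inputs, runs
  // the collective, and verifies every rank's output against `expected`.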
  void RunTest(int num_ranks, int input_length, int instance_key) {
    if (num_ranks > gpus_.size()) {
      LOG(WARNING) << "Skipping test: requires " << num_ranks
                   << " GPUs but only found " << gpus_.size();
      return;
    }
    Init(num_ranks, instance_key);
    std::vector<float> expected;
    InitExpected(&expected, input_length, num_ranks);
    if (VLOG_IS_ON(3)) {
      string str_buf;
      for (const auto& x : expected) {
        strings::StrAppend(&str_buf, " ", x);
      }
      VLOG(3) << "Expected output" << str_buf;
    }
    for (int rank = 0; rank < num_ranks; ++rank) {
      DeviceInstance* instance = instances_[rank].get();
      instance->InitTensor(DT_FLOAT, TensorShape({input_length}),
                           [this, rank](Tensor* t) { InitInput(t, rank); });
    }
    RunCollective();
    // Confirm that every rank computed the same correct value.
    for (int rank = 0; rank < instances_.size(); ++rank) {
      TF_ASSERT_OK(instances_[rank]->status_);
      Tensor* output = &instances_[rank]->output_;
      const int output_length = output->NumElements();
      VLOG(2) << "rank " << rank << " output " << output << " buf "
              << DMAHelper::base(output);
      Tensor actual(DT_FLOAT, TensorShape({output_length}));
      Device* dev = instances_[rank]->device_;
      auto* dev_info = dev->tensorflow_accelerator_device_info();
      TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
          output, /*tensor_name=*/"", dev, &actual));
      VLOG(3) << "rank " << rank << " got output tensor "
              << actual.DebugString(output_length);
      for (int i = 0; i < output_length; ++i) {
        EXPECT_FLOAT_EQ(expected[i], actual.template flat<float>()(i))
            << "Mismatch at rank " << rank << " index " << i;
      }
    }
  }

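  // Builds a CollectiveReduce op kernel matching `params`. The kernel is
  // never executed; the tests only use it to construct a valid
  // OpKernelContext for the NCCL collective implementations.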
  std::unique_ptr<OpKernel> GetCollectiveReduceOpKernel(
      const CollectiveParams& params, Tensor* input, DeviceBase* device) {
    mutex_lock l(mu_);
    NodeDef node_def;
    NodeDefBuilder builder(strings::StrCat("collective_reduce_", op_counter_++),
                           "CollectiveReduce");
    TF_CHECK_OK(
        builder.Attr("T", params.instance.data_type)
            .Attr("merge_op", "Add")
            .Attr("final_op", "Div")
            .Attr("group_size", params.group.group_size)
            .Attr("group_key", params.group.group_key)
            .Attr("instance_key", params.instance.instance_key)
            .Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
            .Input(FakeInput(params.instance.data_type))
            .Finalize(&node_def));
    return GetKernel(node_def, device);
  }

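  // Per-rank state: owns the rank's input/output tensors and drives a single
  // NCCL reducer/broadcaster/gatherer on its device.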
  class DeviceInstance {
   public:
    DeviceInstance(int rank, const string& device_name, NcclTestBase* parent)
        : parent_(parent),
          device_name_(device_name),
          rank_(rank),
          col_params_(new CollectiveParams()) {
      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(device_name_, &device_))
          << "Could not find device " << device_name_ << " existing devices "
          << parent_->dev_mgr_->DebugString();
      merge_op_ = GetAdd(device_);
      final_op_ = GetDiv(device_);
      col_params_->name = parent_->col_params_->name;
      col_params_->default_rank = rank;
      col_params_->group = parent_->col_params_->group;
      col_params_->instance = parent_->col_params_->instance;
    }

    ~DeviceInstance() { col_params_->Unref(); }

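    // Fills a host-side staging tensor via `init_f`, then copies it
    // synchronously into the device-resident input_.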
    void InitTensor(DataType dtype, const TensorShape& shape,
                    const std::function<void(Tensor*)>& init_f) {
      input_ =
          Tensor(device_->GetAllocator(AllocatorAttributes()), dtype, shape);
      Tensor cpu_tensor(dtype, shape);
      init_f(&cpu_tensor);
      if (VLOG_IS_ON(3)) {
        VLOG(3) << "input tensor "
                << cpu_tensor.DebugString(shape.num_elements());
      } else {
        VLOG(2) << "input tensor " << cpu_tensor.DebugString();
      }
      auto* dev_info = device_->tensorflow_accelerator_device_info();
      TF_CHECK_OK(dev_info->default_context->CopyCPUTensorToDeviceSync(
          &cpu_tensor, device_, &input_));
    }

    void PrepareDeviceContext(OpKernelContext::Params* params) {
      params->step_id = kStepId;
      params->device = device_;
      DeviceContext* dev_ctx = nullptr;
      auto* dev_info = device_->tensorflow_accelerator_device_info();
      if (dev_info) {
        dev_ctx = dev_info->default_context;
        dev_ctx->Ref();
      } else {
        dev_ctx = new DeviceContext;
      }
      params->op_device_context = dev_ctx;
    }

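    // Runs an in-place NCCL all-reduce on this rank: input_ serves as both
    // the collective's input and output buffer, and on success the reduced
    // values are copied into output_.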
    void RunReduce() {
      // Prepare an OpKernelContext.
      OpKernelContext::Params op_params;
      PrepareDeviceContext(&op_params);

      // Prepare inputs and outputs to OpKernel.
      gtl::InlinedVector<TensorValue, 4> inputs;
      inputs.push_back(TensorValue(&input_));
      op_params.inputs = inputs;
      gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
          {AllocatorAttributes()});
      op_params.input_alloc_attrs = input_aa;
      int forward_from = 0;
      op_params.forward_from_array = &forward_from;
      AllocatorAttributes generic_alloc_attr;
      op_params.output_attr_array = &generic_alloc_attr;
      std::unique_ptr<OpKernel> op =
          parent_->GetCollectiveReduceOpKernel(*col_params_, &input_, device_);
      op_params.op_kernel = op.get();
      OpKernelContext ctx(&op_params, 1);
      // We never actually execute the kernel, so we need to do the output
      // allocation it would do, ourselves.
      Tensor* output_tensor_ptr = nullptr;
      TF_CHECK_OK(ctx.forward_input_or_allocate_output({0}, 0, input_.shape(),
                                                       &output_tensor_ptr));
      CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));

      // Run the all-reduce.
      string exec_key =
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
      auto* reducer = new NcclReducer();
      auto col_ctx = std::make_shared<CollectiveContext>(
          parent_->col_exec_, parent_->nccl_communicator_.get(),
          parent_->dev_mgr_.get(),
          /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
          /*input=*/&input_, /*output=*/&input_);
      TF_CHECK_OK(reducer->InitializeCollectiveContext(col_ctx));
      Notification note;
      reducer->Run([this, &note](Status s) {
        status_ = s;
        note.Notify();
      });
      note.WaitForNotification();
      if (status_.ok()) {
        CHECK(output_.CopyFrom(*ctx.mutable_output(0), input_.shape()));
      }

      reducer->Unref();
      op_params.op_device_context->Unref();
    }

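    // Runs an NCCL broadcast on this rank. Only the source rank passes its
    // input_ as the collective's input; every rank receives the broadcast
    // result in input_, which then becomes output_.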
    void RunBroadcast() {
      VLOG(2) << "RunBroadcast name " << parent_->collective_name_ << " rank "
              << col_params_->default_rank;
      // Prepare an OpKernelContext.
      OpKernelContext::Params op_params;
      PrepareDeviceContext(&op_params);
      OpKernelContext ctx(&op_params, 1);

      // Run broadcast.
      string exec_key =
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
      auto* broadcaster = new NcclBroadcaster();
      auto col_ctx = std::make_shared<CollectiveContext>(
          parent_->col_exec_, parent_->nccl_communicator_.get(),
          parent_->dev_mgr_.get(),
          /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
          /*input=*/col_params_->is_source ? &input_ : nullptr,
          /*output=*/&input_);
      TF_CHECK_OK(broadcaster->InitializeCollectiveContext(col_ctx));
      Notification note;
      broadcaster->Run([this, &note](Status s) {
        status_ = s;
        note.Notify();
      });
      note.WaitForNotification();
      if (status_.ok()) {
        CHECK(output_.CopyFrom(input_, input_.shape()));
      }

      broadcaster->Unref();
      op_params.op_device_context->Unref();
    }

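    // Runs an NCCL all-gather on this rank into a freshly allocated output_
    // whose leading dimension is group_size times the input's.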
    void RunGather() {
      VLOG(2) << "RunGather name " << parent_->collective_name_ << " rank "
              << col_params_->default_rank;
      // Prepare an OpKernelContext.
      OpKernelContext::Params op_params;
      PrepareDeviceContext(&op_params);
      OpKernelContext ctx(&op_params, 1);

      // Allocate output.  We can't reuse the input because output has a
      // different shape.
      auto output_shape = input_.shape();
      output_shape.set_dim(
          0, output_shape.dim_size(0) * col_params_->group.group_size);
      output_ = Tensor(device_->GetAllocator(AllocatorAttributes()), DT_FLOAT,
                       output_shape);

      // Run gather.
      string exec_key =
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
      auto* gatherer = new NcclGatherer();
      auto col_ctx = std::make_shared<CollectiveContext>(
          parent_->col_exec_, parent_->nccl_communicator_.get(),
          parent_->dev_mgr_.get(),
          /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
          /*input=*/&input_,
          /*output=*/&output_);
      TF_CHECK_OK(gatherer->InitializeCollectiveContext(col_ctx));
      Notification note;
      gatherer->Run([this, &note](Status s) {
        status_ = s;
        note.Notify();
      });
      note.WaitForNotification();

      gatherer->Unref();
      op_params.op_device_context->Unref();
    }

    NcclTestBase* parent_;
    string device_name_;
    int rank_;
    Tensor input_;
    Tensor output_;
    Device* device_;
    CollectiveParams* col_params_;
    std::unique_ptr<OpKernel> merge_op_;
    std::unique_ptr<OpKernel> final_op_;
    Status status_;
  };

  CollectiveType collective_type_;
  const string collective_name_;
  std::vector<std::unique_ptr<tensorflow::Device>> gpus_;
  TestCollectiveExecutorMgr col_exec_mgr_;
  ConfigProto config_proto_;
  std::unique_ptr<NcclCommunicatorInterface> nccl_communicator_;
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
  CollectiveExecutor* col_exec_;
  std::unique_ptr<DeviceMgr> dev_mgr_;
  std::vector<std::unique_ptr<DeviceInstance>> instances_;
  CollectiveParams* col_params_;
  mutex mu_;
  int32 op_counter_ TF_GUARDED_BY(mu_) = 0;
};

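// All-reduce test: rank r contributes 10^r * i at element i, so after the
// Add merge op and Div final op every rank should hold
// sum_r(10^r * i) / num_ranks at element i.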
class NcclReducerTest : public NcclTestBase {
 protected:
  NcclReducerTest()
      : NcclTestBase(/*collective_type=*/REDUCTION_COLLECTIVE,
                     /*collective_name=*/"NcclReduce") {}
  ~NcclReducerTest() override = default;

  void InitInput(Tensor* input, const int rank) override {
    for (size_t i = 0; i < input->NumElements(); ++i) {
      float value = pow(10, rank) * i;
      input->flat<float>()(i) = value;
    }
  }

  void InitExpected(std::vector<float>* expected, const int tensor_length,
                    const int num_ranks) override {
    expected->resize(tensor_length);
    for (int i = 0; i < tensor_length; ++i) {
      float expected_sum = 0.0;
      for (int rank = 0; rank < num_ranks; ++rank) {
        float value = pow(10, rank) * i;
        expected_sum += value;
      }
      (*expected)[i] = expected_sum / num_ranks;
    }
  }

  void InitDevice(DeviceInstance* di) override {
    di->col_params_->merge_op = di->merge_op_.get();
    di->col_params_->final_op = di->final_op_.get();
  }

  void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunReduce(); }
};

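// Broadcast test: the source rank holds 0, 1, 2, ... and every other rank
// starts at -1; afterwards all ranks should hold the source's values.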
class NcclBroadcasterTest : public NcclTestBase {
 protected:
  NcclBroadcasterTest()
      : NcclTestBase(/*collective_type=*/BROADCAST_COLLECTIVE,
                     /*collective_name=*/"NcclBroadcast") {}
  ~NcclBroadcasterTest() override = default;

  void InitInput(Tensor* input, const int rank) override {
    bool source = rank == source_rank_;
    for (size_t i = 0; i < input->NumElements(); ++i) {
      input->flat<float>()(i) = source ? static_cast<float>(i) : -1.0;
    }
  }

  void InitExpected(std::vector<float>* expected, const int tensor_length,
                    const int num_ranks) override {
    expected->resize(tensor_length);
    for (int i = 0; i < tensor_length; ++i) {
      (*expected)[i] = i;
    }
  }

  void InitDevice(DeviceInstance* di) override {
    di->col_params_->source_rank = source_rank_;
    di->col_params_->is_source = di->col_params_->default_rank == source_rank_;
  }

  void RunCollectiveOnDevice(DeviceInstance* di) override {
    di->RunBroadcast();
  }

  int source_rank_ = 0;
};

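// Gather test: rank r contributes 10^r * j at element j, so the gathered
// output should concatenate each rank's tensor in rank order.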
class NcclGathererTest : public NcclTestBase {
 protected:
  NcclGathererTest()
      : NcclTestBase(/*collective_type=*/GATHER_COLLECTIVE,
                     /*collective_name=*/"NcclGather") {}
  ~NcclGathererTest() override = default;

  void InitInput(Tensor* input, const int rank) override {
    for (size_t i = 0; i < input->NumElements(); ++i) {
      float value = pow(10, rank) * i;
      input->flat<float>()(i) = value;
    }
  }

  void InitExpected(std::vector<float>* expected, const int tensor_length,
                    const int num_ranks) override {
    expected->resize(tensor_length * num_ranks, -1);
    for (int rank = 0, i = 0; rank < num_ranks; ++rank) {
      for (int j = 0; j < tensor_length; ++j, ++i) {
        (*expected)[i] = pow(10, rank) * j;
      }
    }
  }

  void InitDevice(DeviceInstance* di) override {}

  void RunCollectiveOnDevice(DeviceInstance* di) override { di->RunGather(); }
};

TEST_F(NcclReducerTest, Test2Dev16Len) {
  RunTest(/*num_ranks=*/2, /*input_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclReducerTest, Test4Dev16Len) {
  RunTest(/*num_ranks=*/4, /*input_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclReducerTest, Test8Dev16Len) {
  RunTest(/*num_ranks=*/8, /*input_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclReducerTest, Test8Dev128Len) {
  RunTest(/*num_ranks=*/8, /*input_length=*/128, /*instance_key=*/23);
}
TEST_F(NcclReducerTest, Test8Dev1048576Len) {
  RunTest(/*num_ranks=*/8, /*input_length=*/1048576, /*instance_key=*/23);
}

TEST_F(NcclBroadcasterTest, Test2Dev16LenSrc0) {
  RunTest(/*num_ranks=*/2, /*input_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclBroadcasterTest, Test4Dev16LenSrc1) {
  source_rank_ = 1;
  RunTest(/*num_ranks=*/4, /*input_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclBroadcasterTest, Test8Dev16LenSrc7) {
  source_rank_ = 7;
  RunTest(/*num_ranks=*/8, /*input_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclBroadcasterTest, Test8Dev128LenSrc0) {
  RunTest(/*num_ranks=*/8, /*input_length=*/128, /*instance_key=*/24);
}
TEST_F(NcclBroadcasterTest, Test8Dev1048576LenSrc0) {
  RunTest(/*num_ranks=*/8, /*input_length=*/1048576, /*instance_key=*/23);
}

TEST_F(NcclGathererTest, Test2Dev16Len) {
  RunTest(/*num_ranks=*/2, /*input_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclGathererTest, Test4Dev16Len) {
  RunTest(/*num_ranks=*/4, /*input_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclGathererTest, Test8Dev16Len) {
  RunTest(/*num_ranks=*/8, /*input_length=*/16, /*instance_key=*/23);
}
TEST_F(NcclGathererTest, Test8Dev128Len) {
  RunTest(/*num_ranks=*/8, /*input_length=*/128, /*instance_key=*/24);
}
TEST_F(NcclGathererTest, Test8Dev1048576Len) {
  RunTest(/*num_ranks=*/8, /*input_length=*/1048576, /*instance_key=*/23);
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA