1Intermittent patch to TFRT to submit a TF/TFRT cross-cutting change. 2This patch will be applied only until TF's TFRT commit is automatically bumped. 3 4--- 5 6diff --git a/backends/gpu/include/tfrt/gpu/gpu_types.h b/backends/gpu/include/tfrt/gpu/gpu_types.h 7index 3d311c3..a216716 100644 8--- a/backends/gpu/include/tfrt/gpu/gpu_types.h 9+++ b/backends/gpu/include/tfrt/gpu/gpu_types.h 10@@ -295,11 +295,7 @@ 11 wrapper::CurrentContext current, wrapper::Stream stream, 12 wrapper::CclComm comm)>; 13 14- explicit GpuCclHandle(AsyncValueRef<GpuContext> context, 15- wrapper::OwningCclComm comm, int num_ranks); 16- // TODO(hanbinyoon): Remove after transitioning to the above constructor. 17- explicit GpuCclHandle(AsyncValueRef<GpuContext> context, 18- wrapper::OwningCclComm comm); 19+ GpuCclHandle(AsyncValueRef<GpuContext> context, wrapper::OwningCclComm comm); 20 ~GpuCclHandle(); 21 22 GpuCclHandle(GpuCclHandle&&) = default; 23@@ -311,8 +307,6 @@ 24 llvm::Error ExecuteCallbacks(wrapper::CurrentContext current, 25 wrapper::Stream stream); 26 27- int num_ranks() const { return num_ranks_; } 28- 29 const wrapper::OwningCclComm& operator->() const { return comm_; } 30 wrapper::CclComm get() const { return comm_.get(); } 31 wrapper::CclComm release(); 32@@ -322,7 +316,6 @@ 33 private: 34 AsyncValueRef<GpuContext> context_; 35 wrapper::OwningCclComm comm_; 36- int num_ranks_; 37 std::vector<Callback> callbacks_; 38 }; 39 40diff --git a/backends/gpu/lib/gpu_types.cc b/backends/gpu/lib/gpu_types.cc 41index 38529bc..01e3dba 100644 42--- a/backends/gpu/lib/gpu_types.cc 43+++ b/backends/gpu/lib/gpu_types.cc 44@@ -214,15 +214,8 @@ 45 GpuBlasHandle::~GpuBlasHandle() = default; 46 47 GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context, 48- wrapper::OwningCclComm comm, int num_ranks) 49- : context_(std::move(context)), 50- comm_(std::move(comm)), 51- num_ranks_(num_ranks) {} 52- 53-// TODO(hanbinyoon): Remove after transitioning to the above constructor. 54-GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context, 55 wrapper::OwningCclComm comm) 56- : context_(std::move(context)), comm_(std::move(comm)), num_ranks_(0) {} 57+ : context_(std::move(context)), comm_(std::move(comm)) {} 58 59 GpuCclHandle::~GpuCclHandle() = default; 60 61diff --git a/backends/gpu/lib/kernels/ccl_kernels.cc b/backends/gpu/lib/kernels/ccl_kernels.cc 62index 52ce820..9cfc1de 100644 63--- a/backends/gpu/lib/kernels/ccl_kernels.cc 64+++ b/backends/gpu/lib/kernels/ccl_kernels.cc 65@@ -107,8 +107,6 @@ 66 auto width = ToWidthInBytes(type); 67 if (!width) return width.takeError(); 68 assert(*width != 0); 69- if (input->size() != output->size() * handle->num_ranks()) 70- return MakeStringError("Input size must be output size times ranks."); 71 72 handle->AddCallback([input = input.ValueRef(), output = output.ValueRef(), 73 recvcount = output->size() / *width, type, 74@@ -116,6 +114,10 @@ 75 wrapper::CurrentContext current, 76 wrapper::Stream stream, 77 wrapper::CclComm comm) -> llvm::Error { 78+ auto count = wrapper::CclCommCount(comm); 79+ if (!count) return count.takeError(); 80+ if (input->size() != output->size() * *count) 81+ return MakeStringError("Input size must be output size times ranks."); 82 return wrapper::CclReduceScatter(current, input->pointer(), 83 output->pointer(), recvcount, type, op, 84 comm, stream); 85