/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"

#include <array>
#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

#if XLA_ENABLE_XCCL
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#endif

namespace xla {
namespace gpu {

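// Issues one ncclAllReduce per buffer inside a single
// ncclGroupStart/ncclGroupEnd group, on the GPU stream underlying `stream`.
// When the binary is built without NCCL support (XLA_ENABLE_XCCL undefined),
// this returns an Unimplemented error instead.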
Status RunAllReduce(ReductionKind reduction_kind,
                    std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
                    ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;

  ncclRedOp_t reduce_op = ToNcclReduction(reduction_kind);

  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  for (size_t i = 0; i < buffers.size(); ++i) {
    DeviceBufferPair& buffer = buffers[i];
    const void* send_buffer = buffer.source_buffer.opaque();
    void* recv_buffer = buffer.destination_buffer.opaque();

    TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                        ToNcclDataTypeAndCountMultiplier(buffer.element_type));
    ncclDataType_t dtype = dtype_and_multiplier.first;
    int element_count = buffer.element_count * dtype_and_multiplier.second;

    VLOG(3) << absl::StreamFormat(
        "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
        "comm=%p, stream=%p)",
        send_buffer, recv_buffer, element_count, static_cast<const void*>(comm),
        gpu_stream);

    XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
                                           element_count, dtype, reduce_op,
                                           comm, gpu_stream));
  }
  return XLA_CUDA_STATUS(ncclGroupEnd());
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

namespace {

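// An operand is NCCL-compatible if it lowers to a dense array shape whose
// element type NCCL supports.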
bool IsValidOperand(mlir::Value operand) {
  Shape shape = TypeToShape(operand.getType());
  return LayoutUtil::IsDenseArray(shape) &&
         IsTypeSupportedByNccl(shape.element_type());
}

// Generally, the reduction op should be the only operation in the block,
// except for the terminator. However, if the type is bf16, the
// `BFloat16Normalization` pass will have converted the op to float32 and
// added type conversions.
// TODO(cjfj): Can we prevent the bf16 conversion for this computation?
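// Illustrative sketch (pseudo-MLIR, not verbatim compiler output) of the bf16
// pattern FindReductionOp accepts after `BFloat16Normalization`, assuming an
// add reduction:
//
//   ^bb0(%a: bf16, %b: bf16):
//     %a32 = convert(%a) : f32
//     %b32 = convert(%b) : f32
//     %sum = add(%a32, %b32) : f32
//     %out = convert(%sum) : bf16
//     return(%out)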
StatusOr<mlir::Operation*> FindReductionOp(mlir::Block& block) {
  TF_RET_CHECK(block.getNumArguments() == 2);
  mlir::Operation* terminator = block.getTerminator();
  TF_RET_CHECK(terminator);
  TF_RET_CHECK(terminator->getNumOperands() == 1);
  mlir::Value result = terminator->getOperand(0);
  TF_RET_CHECK(block.getArgument(0).getType() == result.getType());
  TF_RET_CHECK(block.getArgument(1).getType() == result.getType());

  mlir::Operation* result_op = result.getDefiningOp();
  TF_RET_CHECK(result_op);

  // In the bf16 case, the type conversions and op might be fused.
  if (mlir::isa<mlir::mhlo::FusionOp>(result_op)) {
    return FindReductionOp(result_op->getRegion(0).front());
  }

  // Standard case.
  if (absl::c_is_permutation(result_op->getOperands(), block.getArguments())) {
    return result_op;
  }

  // bf16 case.
  TF_RET_CHECK(mlir::isa<mlir::mhlo::ConvertOp>(result_op));
  TF_RET_CHECK(result_op->getNumOperands() == 1);
  mlir::Operation* reduction_op = result_op->getOperand(0).getDefiningOp();
  TF_RET_CHECK(reduction_op);
  TF_RET_CHECK(reduction_op->getNumOperands() == 2);
  mlir::Value operand0 = reduction_op->getOperand(0);
  mlir::Value operand1 = reduction_op->getOperand(1);
  auto operand0_op = operand0.getDefiningOp<mlir::mhlo::ConvertOp>();
  auto operand1_op = operand1.getDefiningOp<mlir::mhlo::ConvertOp>();
  TF_RET_CHECK(operand0_op);
  TF_RET_CHECK(operand1_op);
  TF_RET_CHECK(operand0_op->getNumOperands() == 1);
  TF_RET_CHECK(operand1_op->getNumOperands() == 1);
  std::array<mlir::Value, 2> operands{operand0_op->getOperand(0),
                                      operand1_op->getOperand(0)};
  TF_RET_CHECK(absl::c_is_permutation(operands, block.getArguments()));
  return reduction_op;
}

}  // namespace

namespace impl {

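// Template helpers shared by the all-reduce-style ops handled in this file
// (lmhlo::AllReduceOp, lmhlo_gpu::AllReduceStartOp, lmhlo::ReduceScatterOp).
// Each op type provides getInputs(), getComputation() and
// getUseGlobalDeviceIds(), which is all these helpers rely on.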
template <typename OpT>
bool CanImplement(OpT op) {
  return absl::c_all_of(op.getInputs(), IsValidOperand) &&
         NcclAllReduceThunkBase::MatchAllReduceComputation(op.getComputation())
             .has_value();
}

template <typename OpT>
NcclAllReduceConfig GetNcclAllReduceConfig(OpT op) {
  std::optional<ReductionKind> reduction_kind =
      NcclAllReduceThunkBase::MatchAllReduceComputation(op.getComputation());
  CHECK(reduction_kind.has_value());

  NcclAllReduceConfig config;
  config.config =
      GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds());
  config.reduction_kind = *reduction_kind;
  return config;
}

template <typename OpT>
bool IsDegenerate(OpT op, int64_t replica_count, int64_t partition_count) {
  return GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds())
      .IsDegenerate(replica_count, partition_count);
}

template <typename OpT>
CollectiveOpGroupMode GetGroupMode(OpT op) {
  return GetNcclAllReduceConfig(op).config.group_mode;
}

}  // namespace impl
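// Maps the reduction computation attached to the op onto a ReductionKind that
// NCCL can execute natively, or returns std::nullopt if the computation is
// not a supported reduction.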
std::optional<ReductionKind> NcclAllReduceThunkBase::MatchAllReduceComputation(
    mlir::Region& computation) {
  mlir::Block& block = computation.front();
  StatusOr<mlir::Operation*> reduction_op = FindReductionOp(block);
  if (!reduction_op.ok()) return std::nullopt;
  StatusOr<HloOpcode> opcode = MhloToHloOpcode(*reduction_op);
  if (!opcode.ok()) return std::nullopt;
  // Match the operation to a reduction kind. We can represent and/or of pred
  // as min/max. This works because pred is stored as an 8-bit int of value 0
  // or 1.
  PrimitiveType type =
      TypeToShape(block.getArgument(0).getType()).element_type();
  if (type == PRED) {
    switch (opcode.ValueOrDie()) {
      case HloOpcode::kAnd:
        return ReductionKind::MIN;
      case HloOpcode::kOr:
        return ReductionKind::MAX;
      default:
        return std::nullopt;
    }
  } else if (primitive_util::IsComplexType(type)) {
    // Only addition is supported for complex types.
    if (*opcode == HloOpcode::kAdd) {
      return ReductionKind::SUM;
    } else {
      return std::nullopt;
    }
  } else {
    switch (*opcode) {
      case HloOpcode::kAdd:
        return ReductionKind::SUM;
      case HloOpcode::kMultiply:
        return ReductionKind::PRODUCT;
      case HloOpcode::kMaximum:
        return ReductionKind::MAX;
      case HloOpcode::kMinimum:
        return ReductionKind::MIN;
      default:
        return std::nullopt;
    }
  }
}
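// Base class shared by the all-reduce, all-reduce-start and reduce-scatter
// thunks: it pairs the reduction config with its buffers and checks that the
// config's operand count matches the number of buffers.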
NcclAllReduceThunkBase::NcclAllReduceThunkBase(Thunk::Kind kind,
                                               ThunkInfo thunk_info,
                                               NcclAllReduceConfig config,
                                               std::vector<Buffer> buffers)
    : NcclCollectiveThunk(kind, thunk_info),
      config_(std::move(config)),
      buffers_(std::move(buffers)) {
  CHECK_EQ(config_.config.operand_count, buffers_.size());
}

NcclAllReduceThunk::NcclAllReduceThunk(ThunkInfo thunk_info,
                                       mlir::lmhlo::AllReduceOp op,
                                       std::vector<Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclAllReduce, thunk_info,
                             impl::GetNcclAllReduceConfig(op),
                             std::move(buffers)) {}

bool NcclAllReduceThunk::CanImplement(mlir::lmhlo::AllReduceOp op) {
  return impl::CanImplement(op);
}

bool NcclAllReduceThunk::IsDegenerate(mlir::lmhlo::AllReduceOp op,
                                      int64_t replica_count,
                                      int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

CollectiveOpGroupMode NcclAllReduceThunk::GetGroupMode(
    mlir::lmhlo::AllReduceOp op) {
  return impl::GetGroupMode(op);
}

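// Synchronous all-reduce: the collective is enqueued on the main compute
// stream, so later work on that stream automatically sees the reduced result.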
Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams& params,
                                             ncclComm_t comm) {
  se::Stream& stream = *params.stream;
  TF_ASSIGN_OR_RETURN(
      std::vector<DeviceBufferPair> device_buffers,
      ConvertToDeviceBuffers(params, buffers_,
                             config_.config.operand_element_type));
  TF_RETURN_IF_ERROR(
      RunAllReduce(config_.reduction_kind, device_buffers, stream, comm));

  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal;
  return OkStatus();
}

NcclAllReduceStartThunk::NcclAllReduceStartThunk(
    ThunkInfo thunk_info, mlir::lmhlo_gpu::AllReduceStartOp op,
    std::vector<Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclAllReduceStart, thunk_info,
                             impl::GetNcclAllReduceConfig(op),
                             std::move(buffers)) {}

bool NcclAllReduceStartThunk::CanImplement(
    mlir::lmhlo_gpu::AllReduceStartOp op) {
  return impl::CanImplement(op);
}

bool NcclAllReduceStartThunk::IsDegenerate(mlir::lmhlo_gpu::AllReduceStartOp op,
                                           int64_t replica_count,
                                           int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

CollectiveOpGroupMode NcclAllReduceStartThunk::GetGroupMode(
    mlir::lmhlo_gpu::AllReduceStartOp op) {
  return impl::GetGroupMode(op);
}

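// Asynchronous start: the collective runs on the dedicated async_comms_stream
// (after waiting for the compute stream to produce the inputs), and a
// completion event is recorded for the matching NcclAllReduceDoneThunk to
// wait on.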
Status NcclAllReduceStartThunk::RunNcclCollective(const ExecuteParams& params,
                                                  ncclComm_t comm) {
  se::Stream& async_comms_stream = *params.async_comms_stream;
  // Wait until compute inputs are ready.
  async_comms_stream.ThenWaitFor(params.stream);

  TF_ASSIGN_OR_RETURN(
      std::vector<DeviceBufferPair> device_buffers,
      ConvertToDeviceBuffers(params, buffers_,
                             config_.config.operand_element_type));
  TF_RETURN_IF_ERROR(RunAllReduce(config_.reduction_kind, device_buffers,
                                  async_comms_stream, comm));

  // Create an event on the async stream for the completion of the all-reduce.
  se::Event done_event(async_comms_stream.parent());
  TF_RET_CHECK(done_event.Init());
  async_comms_stream.ThenRecordEvent(&done_event);

  int device_ordinal = async_comms_stream.parent()->device_ordinal();

  {
    absl::MutexLock lock(&mu_);
    auto result = done_events_.emplace(device_ordinal, std::move(done_event));
    TF_RET_CHECK(result.second) << "done event has not been consumed";
  }

  VLOG(3) << "Done performing all-reduce-start for ordinal: " << device_ordinal;
  return OkStatus();
}

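// Hands ownership of the per-device completion event to the caller (the
// matching done thunk); fails if no event was recorded for this device
// ordinal.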
StatusOr<se::Event> NcclAllReduceStartThunk::TakeDoneEvent(int device_ordinal) {
  absl::MutexLock lock(&mu_);
  auto it = done_events_.find(device_ordinal);
  TF_RET_CHECK(it != done_events_.end()) << "done event not found";
  // Take ownership of the event.
  se::Event done_event = std::move(it->second);
  done_events_.erase(it);
  return done_event;
}

NcclAllReduceDoneThunk::NcclAllReduceDoneThunk(
    ThunkInfo thunk_info, NcclAllReduceStartThunk& start_thunk)
    : Thunk(Thunk::kNcclAllReduceDone, thunk_info), start_thunk_(start_thunk) {}

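// Blocks the compute stream on the event recorded by the corresponding start
// thunk, completing the all-reduce-start / all-reduce-done pair.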
Status NcclAllReduceDoneThunk::ExecuteOnStream(const ExecuteParams& params) {
  int device_ordinal = params.stream->parent()->device_ordinal();
  TF_ASSIGN_OR_RETURN(se::Event done_event,
                      start_thunk_.TakeDoneEvent(device_ordinal));
  params.stream->ThenWaitFor(&done_event);
  return OkStatus();
}

NcclReduceScatterThunk::NcclReduceScatterThunk(
    ThunkInfo thunk_info, mlir::lmhlo::ReduceScatterOp op,
    std::vector<NcclAllReduceThunk::Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclReduceScatter, thunk_info,
                             impl::GetNcclAllReduceConfig(op),
                             std::move(buffers)) {}

/*static*/ bool NcclReduceScatterThunk::CanImplement(
    mlir::lmhlo::ReduceScatterOp op) {
  return impl::CanImplement(op);
}

/*static*/ bool NcclReduceScatterThunk::IsDegenerate(
    mlir::lmhlo::ReduceScatterOp op, int64_t replica_count,
    int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

/*static*/ CollectiveOpGroupMode NcclReduceScatterThunk::GetGroupMode(
    mlir::lmhlo::ReduceScatterOp op) {
  return impl::GetGroupMode(op);
}

Status NcclReduceScatterThunk::RunNcclCollective(const ExecuteParams& params,
                                                 ncclComm_t comm) {
  TF_ASSIGN_OR_RETURN(
      std::vector<DeviceBufferPair> device_buffers,
      ConvertToDeviceBuffers(params, buffers_,
                             config_.config.operand_element_type));
  return RunReduceScatter(config_.reduction_kind, device_buffers,
                          *params.stream, comm);
}

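// Like RunAllReduce, but each rank receives only its 1/num_participants share
// of the reduced result, so every source buffer's element count must be an
// exact multiple of the communicator size.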
Status RunReduceScatter(ReductionKind reduction_kind,
                        std::vector<DeviceBufferPair>& buffers,
                        se::Stream& stream, ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Performing reduce-scatter from device ordinal: "
          << device_ordinal;

  ncclRedOp_t reduce_op = ToNcclReduction(reduction_kind);

  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

  int num_participants = 0;
  XLA_CUDA_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants));

  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  for (size_t i = 0; i < buffers.size(); ++i) {
    DeviceBufferPair& buffer = buffers[i];
    const void* send_buffer = buffer.source_buffer.opaque();
    void* recv_buffer = buffer.destination_buffer.opaque();

    TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                        ToNcclDataTypeAndCountMultiplier(buffer.element_type));
    ncclDataType_t dtype = dtype_and_multiplier.first;
    int element_count = buffer.element_count * dtype_and_multiplier.second;

    // buffer.element_count is the source buffer's element count. For
    // ncclReduceScatter, we need the destination buffer's element count.
    TF_RET_CHECK(element_count % num_participants == 0)
        << "Source buffer was not an exact multiple of the number of "
           "participants.";

    int64_t recv_count = element_count / num_participants;
    VLOG(3) << absl::StreamFormat(
        "Calling ncclReduceScatter(send_buffer=%p, recv_buffer=%p, "
        "recvcount=%d, comm=%p, stream=%p)",
        send_buffer, recv_buffer, recv_count, static_cast<const void*>(comm),
        gpu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclReduceScatter(send_buffer, recv_buffer,
                                               recv_count, dtype, reduce_op,
                                               comm, gpu_stream));
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  VLOG(3) << "Done performing reduce-scatter for ordinal: " << device_ordinal;
  return OkStatus();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla