/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"

#include <array>
#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

#if XLA_ENABLE_XCCL
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#endif

namespace xla {
namespace gpu {

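// Performs one NCCL all-reduce per buffer in `buffers`, reducing with
// `reduction_kind` on the NCCL communicator `comm` and enqueueing the work on
// `stream`.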
Status RunAllReduce(ReductionKind reduction_kind,
                    std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
                    ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;

  ncclRedOp_t reduce_op = ToNcclReduction(reduction_kind);

  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

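  // Issue the per-buffer all-reduce calls inside a single NCCL group so that
  // they are submitted to the communicator together when the group ends.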
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  for (size_t i = 0; i < buffers.size(); ++i) {
    DeviceBufferPair& buffer = buffers[i];
    const void* send_buffer = buffer.source_buffer.opaque();
    void* recv_buffer = buffer.destination_buffer.opaque();

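    // ToNcclDataTypeAndCountMultiplier maps the XLA element type to an NCCL
    // data type plus a count multiplier, so `element_count` below is expressed
    // in NCCL elements (e.g. complex values count as two elements of the
    // underlying real type).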
    TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                        ToNcclDataTypeAndCountMultiplier(buffer.element_type));
    ncclDataType_t dtype = dtype_and_multiplier.first;
    int element_count = buffer.element_count * dtype_and_multiplier.second;

    VLOG(3) << absl::StreamFormat(
        "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
        "comm=%p, stream=%p)",
        send_buffer, recv_buffer, element_count, static_cast<const void*>(comm),
        gpu_stream);

    XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
                                           element_count, dtype, reduce_op,
                                           comm, gpu_stream));
  }
  return XLA_CUDA_STATUS(ncclGroupEnd());
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

namespace {

bool IsValidOperand(mlir::Value operand) {
  Shape shape = TypeToShape(operand.getType());
  return LayoutUtil::IsDenseArray(shape) &&
         IsTypeSupportedByNccl(shape.element_type());
}

// Generally, the reduction op should be the only operation in the block, other
// than the terminator. However, if the type is bf16, the
// `BFloat16Normalization` pass will have converted the op to float32 and added
// type conversions.
// TODO(cjfj): Can we prevent the bf16 conversion for this computation?
StatusOr<mlir::Operation*> FindReductionOp(mlir::Block& block) {
  TF_RET_CHECK(block.getNumArguments() == 2);
  mlir::Operation* terminator = block.getTerminator();
  TF_RET_CHECK(terminator);
  TF_RET_CHECK(terminator->getNumOperands() == 1);
  mlir::Value result = terminator->getOperand(0);
  TF_RET_CHECK(block.getArgument(0).getType() == result.getType());
  TF_RET_CHECK(block.getArgument(1).getType() == result.getType());

  mlir::Operation* result_op = result.getDefiningOp();
  TF_RET_CHECK(result_op);

  // In the bf16 case, the type conversions and op might be fused.
  if (mlir::isa<mlir::mhlo::FusionOp>(result_op)) {
    return FindReductionOp(result_op->getRegion(0).front());
  }

  // Standard case.
  if (absl::c_is_permutation(result_op->getOperands(), block.getArguments())) {
    return result_op;
  }

  // bf16 case: the reduction operates on f32 values, with `ConvertOp`s turning
  // the bf16 block arguments into f32 and the f32 result back into bf16.
  TF_RET_CHECK(mlir::isa<mlir::mhlo::ConvertOp>(result_op));
  TF_RET_CHECK(result_op->getNumOperands() == 1);
  mlir::Operation* reduction_op = result_op->getOperand(0).getDefiningOp();
  TF_RET_CHECK(reduction_op);
  TF_RET_CHECK(reduction_op->getNumOperands() == 2);
  mlir::Value operand0 = reduction_op->getOperand(0);
  mlir::Value operand1 = reduction_op->getOperand(1);
  auto operand0_op = operand0.getDefiningOp<mlir::mhlo::ConvertOp>();
  auto operand1_op = operand1.getDefiningOp<mlir::mhlo::ConvertOp>();
  TF_RET_CHECK(operand0_op);
  TF_RET_CHECK(operand1_op);
  TF_RET_CHECK(operand0_op->getNumOperands() == 1);
  TF_RET_CHECK(operand1_op->getNumOperands() == 1);
  std::array<mlir::Value, 2> operands{operand0_op->getOperand(0),
                                      operand1_op->getOperand(0)};
  TF_RET_CHECK(absl::c_is_permutation(operands, block.getArguments()));
  return reduction_op;
}

}  // namespace

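// The `impl` helpers below are templated on the lmhlo op type so they can be
// shared by the all-reduce, all-reduce-start and reduce-scatter thunks, which
// differ only in the op they are constructed from.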
namespace impl {

template <typename OpT>
bool CanImplement(OpT op) {
  return absl::c_all_of(op.getInputs(), IsValidOperand) &&
         NcclAllReduceThunkBase::MatchAllReduceComputation(op.getComputation())
             .has_value();
}

template <typename OpT>
NcclAllReduceConfig GetNcclAllReduceConfig(OpT op) {
  std::optional<ReductionKind> reduction_kind =
      NcclAllReduceThunkBase::MatchAllReduceComputation(op.getComputation());
  CHECK(reduction_kind.has_value());

  NcclAllReduceConfig config;
  config.config =
      GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds());
  config.reduction_kind = *reduction_kind;
  return config;
}

template <typename OpT>
bool IsDegenerate(OpT op, int64_t replica_count, int64_t partition_count) {
  return GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds())
      .IsDegenerate(replica_count, partition_count);
}

template <typename OpT>
CollectiveOpGroupMode GetGroupMode(OpT op) {
  return GetNcclAllReduceConfig(op).config.group_mode;
}

}  // namespace impl

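// Matches the reduction computation of an all-reduce-like op to one of the
// reduction kinds supported by NCCL, or returns `std::nullopt` if it cannot
// be represented.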
std::optional<ReductionKind> NcclAllReduceThunkBase::MatchAllReduceComputation(
    mlir::Region& computation) {
  mlir::Block& block = computation.front();
  StatusOr<mlir::Operation*> reduction_op = FindReductionOp(block);
  if (!reduction_op.ok()) return std::nullopt;
  StatusOr<HloOpcode> opcode = MhloToHloOpcode(*reduction_op);
  if (!opcode.ok()) return std::nullopt;
  // Match the operation to a reduction kind. We can represent and/or of pred
  // as min/max. This works because pred is stored as an 8-bit int of value 0
  // or 1.
  PrimitiveType type =
      TypeToShape(block.getArgument(0).getType()).element_type();
  if (type == PRED) {
    switch (opcode.ValueOrDie()) {
      case HloOpcode::kAnd:
        return ReductionKind::MIN;
      case HloOpcode::kOr:
        return ReductionKind::MAX;
      default:
        return std::nullopt;
    }
  } else if (primitive_util::IsComplexType(type)) {
    // Only addition is supported for complex types.
    if (*opcode == HloOpcode::kAdd) {
      return ReductionKind::SUM;
    } else {
      return std::nullopt;
    }
  } else {
    switch (*opcode) {
      case HloOpcode::kAdd:
        return ReductionKind::SUM;
      case HloOpcode::kMultiply:
        return ReductionKind::PRODUCT;
      case HloOpcode::kMaximum:
        return ReductionKind::MAX;
      case HloOpcode::kMinimum:
        return ReductionKind::MIN;
      default:
        return std::nullopt;
    }
  }
}

NcclAllReduceThunkBase::NcclAllReduceThunkBase(Thunk::Kind kind,
                                               ThunkInfo thunk_info,
                                               NcclAllReduceConfig config,
                                               std::vector<Buffer> buffers)
    : NcclCollectiveThunk(kind, thunk_info),
      config_(std::move(config)),
      buffers_(std::move(buffers)) {
  CHECK_EQ(config_.config.operand_count, buffers_.size());
}

NcclAllReduceThunk::NcclAllReduceThunk(ThunkInfo thunk_info,
                                       mlir::lmhlo::AllReduceOp op,
                                       std::vector<Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclAllReduce, thunk_info,
                             impl::GetNcclAllReduceConfig(op), buffers) {}

bool NcclAllReduceThunk::CanImplement(mlir::lmhlo::AllReduceOp op) {
  return impl::CanImplement(op);
}

bool NcclAllReduceThunk::IsDegenerate(mlir::lmhlo::AllReduceOp op,
                                      int64_t replica_count,
                                      int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

CollectiveOpGroupMode NcclAllReduceThunk::GetGroupMode(
    mlir::lmhlo::AllReduceOp op) {
  return impl::GetGroupMode(op);
}

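// Runs the all-reduce on the main compute stream (params.stream), so later
// work enqueued on that stream is ordered after the collective.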
Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams& params,
                                             ncclComm_t comm) {
  se::Stream& stream = *params.stream;
  TF_ASSIGN_OR_RETURN(
      std::vector<DeviceBufferPair> device_buffers,
      ConvertToDeviceBuffers(params, buffers_,
                             config_.config.operand_element_type));
  TF_RETURN_IF_ERROR(
      RunAllReduce(config_.reduction_kind, device_buffers, stream, comm));

  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal;
  return OkStatus();
}

NcclAllReduceStartThunk::NcclAllReduceStartThunk(
    ThunkInfo thunk_info, mlir::lmhlo_gpu::AllReduceStartOp op,
    std::vector<Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclAllReduceStart, thunk_info,
                             impl::GetNcclAllReduceConfig(op), buffers) {}

bool NcclAllReduceStartThunk::CanImplement(
    mlir::lmhlo_gpu::AllReduceStartOp op) {
  return impl::CanImplement(op);
}

bool NcclAllReduceStartThunk::IsDegenerate(mlir::lmhlo_gpu::AllReduceStartOp op,
                                           int64_t replica_count,
                                           int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

CollectiveOpGroupMode NcclAllReduceStartThunk::GetGroupMode(
    mlir::lmhlo_gpu::AllReduceStartOp op) {
  return impl::GetGroupMode(op);
}

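// Starts the all-reduce on the dedicated async communication stream and
// records a per-device event that the matching NcclAllReduceDoneThunk later
// waits on.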
Status NcclAllReduceStartThunk::RunNcclCollective(const ExecuteParams& params,
                                                  ncclComm_t comm) {
  se::Stream& async_comms_stream = *params.async_comms_stream;
  // Wait until compute inputs are ready.
  async_comms_stream.ThenWaitFor(params.stream);

  TF_ASSIGN_OR_RETURN(
      std::vector<DeviceBufferPair> device_buffers,
      ConvertToDeviceBuffers(params, buffers_,
                             config_.config.operand_element_type));
  TF_RETURN_IF_ERROR(RunAllReduce(config_.reduction_kind, device_buffers,
                                  async_comms_stream, comm));

  // Create an event on the async stream for the completion of the all-reduce.
  se::Event done_event(async_comms_stream.parent());
  TF_RET_CHECK(done_event.Init());
  async_comms_stream.ThenRecordEvent(&done_event);

  int device_ordinal = async_comms_stream.parent()->device_ordinal();

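  // Stash the event keyed by device ordinal so the corresponding done thunk
  // can take ownership of it; at most one event per device may be outstanding.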
  {
    absl::MutexLock lock(&mu_);
    auto result = done_events_.emplace(device_ordinal, std::move(done_event));
    TF_RET_CHECK(result.second) << "done event has not been consumed";
  }

  VLOG(3) << "Done performing all-reduce-start for ordinal: "
          << device_ordinal;
  return OkStatus();
}

StatusOr<se::Event> NcclAllReduceStartThunk::TakeDoneEvent(int device_ordinal) {
  absl::MutexLock lock(&mu_);
  auto it = done_events_.find(device_ordinal);
  TF_RET_CHECK(it != done_events_.end()) << "done event not found";
  // Take ownership of the event.
  se::Event done_event = std::move(it->second);
  done_events_.erase(it);
  return done_event;
}

NcclAllReduceDoneThunk::NcclAllReduceDoneThunk(
    ThunkInfo thunk_info, NcclAllReduceStartThunk& start_thunk)
    : Thunk(Thunk::kNcclAllReduceDone, thunk_info), start_thunk_(start_thunk) {}

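// Completes the async all-reduce by making the main compute stream wait for
// the event recorded by the start thunk on the async stream.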
Status NcclAllReduceDoneThunk::ExecuteOnStream(const ExecuteParams& params) {
  int device_ordinal = params.stream->parent()->device_ordinal();
  TF_ASSIGN_OR_RETURN(se::Event done_event,
                      start_thunk_.TakeDoneEvent(device_ordinal));
  params.stream->ThenWaitFor(&done_event);
  return OkStatus();
}

NcclReduceScatterThunk::NcclReduceScatterThunk(
    ThunkInfo thunk_info, mlir::lmhlo::ReduceScatterOp op,
    std::vector<NcclAllReduceThunk::Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclReduceScatter, thunk_info,
                             impl::GetNcclAllReduceConfig(op),
                             std::move(buffers)) {}

/*static*/ bool NcclReduceScatterThunk::CanImplement(
    mlir::lmhlo::ReduceScatterOp op) {
  return impl::CanImplement(op);
}

/*static*/ bool NcclReduceScatterThunk::IsDegenerate(
    mlir::lmhlo::ReduceScatterOp op, int64_t replica_count,
    int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

/*static*/ CollectiveOpGroupMode NcclReduceScatterThunk::GetGroupMode(
    mlir::lmhlo::ReduceScatterOp op) {
  return impl::GetGroupMode(op);
}

Status NcclReduceScatterThunk::RunNcclCollective(const ExecuteParams& params,
                                                 ncclComm_t comm) {
  TF_ASSIGN_OR_RETURN(
      std::vector<DeviceBufferPair> device_buffers,
      ConvertToDeviceBuffers(params, buffers_,
                             config_.config.operand_element_type));
  return RunReduceScatter(config_.reduction_kind, device_buffers,
                          *params.stream, comm);
}

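// Performs one NCCL reduce-scatter per buffer in `buffers`: each rank reduces
// the full source buffer with `reduction_kind` and receives its own shard of
// the result in the destination buffer.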
Status RunReduceScatter(ReductionKind reduction_kind,
                        std::vector<DeviceBufferPair>& buffers,
                        se::Stream& stream, ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Performing reduce-scatter from device ordinal: "
          << device_ordinal;

  ncclRedOp_t reduce_op = ToNcclReduction(reduction_kind);

  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

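  // Reduce-scatter splits each source buffer evenly across the ranks in the
  // communicator, so the participant count is needed to compute the per-rank
  // receive count below.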
  int num_participants = 0;
  XLA_CUDA_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants));

  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  for (size_t i = 0; i < buffers.size(); ++i) {
    DeviceBufferPair& buffer = buffers[i];
    const void* send_buffer = buffer.source_buffer.opaque();
    void* recv_buffer = buffer.destination_buffer.opaque();

    TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                        ToNcclDataTypeAndCountMultiplier(buffer.element_type));
    ncclDataType_t dtype = dtype_and_multiplier.first;
    int element_count = buffer.element_count * dtype_and_multiplier.second;

    // buffer.element_count is the source buffer's element count. For
    // ncclReduceScatter we need the destination buffer's element count.
    TF_RET_CHECK(element_count % num_participants == 0)
        << "Source buffer was not an exact multiple of the number of "
           "participants.";

    int64_t recv_count = element_count / num_participants;
    VLOG(3) << absl::StreamFormat(
        "Calling ncclReduceScatter(send_buffer=%p, recv_buffer=%p, "
        "recvcount=%d, comm=%p, stream=%p)",
        send_buffer, recv_buffer, recv_count, static_cast<const void*>(comm),
        gpu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclReduceScatter(send_buffer, recv_buffer,
                                               recv_count, dtype, reduce_op,
                                               comm, gpu_stream));
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  VLOG(3) << "Done performing reduce-scatter for ordinal: " << device_ordinal;
  return OkStatus();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla