/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"

#if XLA_ENABLE_XCCL
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#endif

namespace xla {
namespace gpu {

/*static*/ NcclAllToAllConfig NcclAllToAllThunk::GetNcclAllToAllConfig(
    mlir::lmhlo::AllToAllOp op) {
  NcclAllToAllConfig config;
  // FIXME(b/180174349): LMHLO AllToAll incorrectly has use_global_device_ids
  // attribute and it should be removed.
  config.config = GetNcclCollectiveConfigForMlir(op, std::nullopt);
  config.has_split_dimension = op.getSplitDimension().has_value();
  return config;
}

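// An all-to-all is handled by this thunk only if every operand is a dense
// array with an element type NCCL supports and, when a split dimension is
// given, that dimension is the most-major dimension of the operand's layout.
// The latter keeps each per-rank chunk contiguous in memory, which is what
// the flat byte-offset arithmetic in RunAllToAll relies on.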
/*static*/ bool NcclAllToAllThunk::CanImplement(mlir::lmhlo::AllToAllOp op) {
  return absl::c_all_of(op.getInputs(), [&op](mlir::Value operand) {
    Shape shape = GetShape(operand);
    return LayoutUtil::IsDenseArray(shape) &&
           IsTypeSupportedByNccl(shape.element_type()) &&
           (!op.getSplitDimension() ||
            LayoutUtil::MinorToMajor(shape).back() == *op.getSplitDimension());
  });
}

NcclAllToAllThunk::NcclAllToAllThunk(
    ThunkInfo thunk_info, mlir::lmhlo::AllToAllOp op,
    std::vector<NcclAllToAllThunk::Buffer> buffers)
    : NcclCollectiveThunk(Thunk::kNcclAllToAll, thunk_info),
      config_(GetNcclAllToAllConfig(op)),
      buffers_(std::move(buffers)) {
  CHECK_EQ(config_.config.operand_count, buffers_.size());
}

Status NcclAllToAllThunk::RunNcclCollective(const ExecuteParams& params,
                                            ncclComm_t comm) {
  TF_ASSIGN_OR_RETURN(
      std::vector<DeviceBufferPair> device_buffers,
      ConvertToDeviceBuffers(params, buffers_,
                             config_.config.operand_element_type));
  return RunAllToAll(config_.has_split_dimension, device_buffers,
                     *params.stream, comm);
}

Status RunAllToAll(bool has_split_dimension,
                   std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
                   ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal;

  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

  int num_participants;
  XLA_CUDA_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants));

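  // All of the sends and receives below are issued inside a single NCCL
  // group, so NCCL can aggregate them into one all-to-all exchange and the
  // per-peer send/recv pairs cannot deadlock against one another.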
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  // AllToAll can operate in two modes. Either it specifies a split dimension,
  // in which case inputs are split and outputs concatenated in that dimension
  // (here, we only support dimension 0), or it takes a list of inputs
  // and produces a tuple of outputs.
  if (has_split_dimension) {
    for (size_t i = 0; i < buffers.size(); ++i) {
      DeviceBufferPair& buffer = buffers[i];
      const uint8_t* send_buffer =
          static_cast<uint8_t*>(buffer.source_buffer.opaque());
      uint8_t* recv_buffer =
          static_cast<uint8_t*>(buffer.destination_buffer.opaque());

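      // Map the XLA element type to an NCCL datatype. The count multiplier
      // covers element types NCCL has no native equivalent for (e.g. a
      // complex element is transferred as two real elements).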
      TF_ASSIGN_OR_RETURN(
          auto dtype_and_multiplier,
          ToNcclDataTypeAndCountMultiplier(buffer.element_type));
      ncclDataType_t dtype = dtype_and_multiplier.first;
      int element_count = buffer.element_count * dtype_and_multiplier.second;

      TF_RET_CHECK(element_count % num_participants == 0)
          << "Buffer was not an exact multiple of the number of participants.";
      size_t chunk_elements = element_count / num_participants;
      size_t chunk_bytes = chunk_elements * ShapeUtil::ByteSizeOfPrimitiveType(
                                                buffer.element_type);

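      // Exchange one contiguous chunk with every participant: chunk `rank` of
      // the local input is sent to that rank, and the chunk received from it
      // is written at the same offset of the local output.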
      for (int rank = 0; rank < num_participants; ++rank) {
        XLA_CUDA_RETURN_IF_ERROR(ncclSend(send_buffer + rank * chunk_bytes,
                                          chunk_elements, dtype, rank, comm,
                                          gpu_stream));
        XLA_CUDA_RETURN_IF_ERROR(ncclRecv(recv_buffer + rank * chunk_bytes,
                                          chunk_elements, dtype, rank, comm,
                                          gpu_stream));
      }
    }
  } else {
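    // List-of-inputs mode: there is exactly one buffer per participant, and
    // buffer i is exchanged wholesale with rank i.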
    TF_RET_CHECK(buffers.size() == num_participants)
        << "Number of inputs didn't match the number of participants.";

    for (size_t i = 0; i < buffers.size(); ++i) {
      DeviceBufferPair& buffer = buffers[i];
      const uint8_t* send_buffer =
          static_cast<uint8_t*>(buffer.source_buffer.opaque());
      uint8_t* recv_buffer =
          static_cast<uint8_t*>(buffer.destination_buffer.opaque());

      TF_ASSIGN_OR_RETURN(
          auto dtype_and_multiplier,
          ToNcclDataTypeAndCountMultiplier(buffer.element_type));
      ncclDataType_t dtype = dtype_and_multiplier.first;
      int element_count = buffer.element_count * dtype_and_multiplier.second;

      XLA_CUDA_RETURN_IF_ERROR(ncclSend(send_buffer, element_count, dtype,
                                        /*rank=*/i, comm, gpu_stream));
      XLA_CUDA_RETURN_IF_ERROR(ncclRecv(recv_buffer, element_count, dtype,
                                        /*rank=*/i, comm, gpu_stream));
    }
  }
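  // Closing the group is what actually enqueues the batched sends/receives on
  // gpu_stream.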
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  VLOG(3) << "Done performing all-to-all for ordinal: " << device_ordinal;
  return OkStatus();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla