/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_

#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

#if XLA_ENABLE_XCCL
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#endif  // XLA_ENABLE_XCCL

struct ncclComm;
using ncclComm_t = ncclComm*;

namespace xla {
namespace gpu {

class NcclClique;

struct NcclCollectiveConfig {
  NcclCollectiveConfig();
  NcclCollectiveConfig(NcclCollectiveConfig&&);
  ~NcclCollectiveConfig();

  NcclCollectiveConfig& operator=(NcclCollectiveConfig&&);

  // Number of operands, and the element type of each operand.
  int64_t operand_count;
  std::vector<PrimitiveType> operand_element_type;
  // Replica groups participating in the collective.
  std::vector<ReplicaGroup> replica_groups;
  // Whether the op communicates across modules or across replicas, together
  // with the id used to match up participating ops.
  RendezvousKey::CollectiveOpKind collective_op_kind;
  int64_t op_id;
  CollectiveOpGroupMode group_mode;

  template <typename OpT>
  void SetCollectiveOpKindAndID(OpT op);
  // Returns true if the collective has at most one participant per replica
  // group, in which case it degenerates to a copy.
  bool IsDegenerate(int64_t replica_count, int64_t partition_count) const;
};

template <typename OpT>
void NcclCollectiveConfig::SetCollectiveOpKindAndID(OpT op) {
  if (op.getChannelId()) {
    // A channel id is only present for cross-module communication; the
    // channel handle identifies the op across modules.
    collective_op_kind = RendezvousKey::kCrossModule;
    op_id = static_cast<int64_t>(op.getChannelId()->getHandle());
  } else {
    // Otherwise the op communicates across replicas; use the HLO unique id
    // recorded on the enclosing module to identify it.
    collective_op_kind = RendezvousKey::kCrossReplica;
    mlir::ModuleOp parent = op->template getParentOfType<mlir::ModuleOp>();
    mlir::IntegerAttr unique_id =
        parent->getAttrOfType<mlir::IntegerAttr>("hlo.unique_id");
    op_id = static_cast<int64_t>(unique_id.getInt());
  }
}

// Builds an NcclCollectiveConfig from an MLIR (LMHLO) collective op.
template <typename OpT>
NcclCollectiveConfig GetNcclCollectiveConfigForMlir(
    OpT op, std::optional<bool> use_global_device_ids) {
  NcclCollectiveConfig config;
  config.operand_count = op.getInputs().size();
  config.operand_element_type.reserve(config.operand_count);
  for (int i = 0; i < config.operand_count; i++) {
    const Shape shape = GetShape(op.getInputs()[i]);
    config.operand_element_type.push_back(shape.element_type());
  }
  config.replica_groups =
      ConvertReplicaGroups(op.getReplicaGroups()).ValueOrDie();
  config.SetCollectiveOpKindAndID(op);
  config.group_mode = GetCollectiveOpGroupMode(op.getChannelId().has_value(),
                                               use_global_device_ids)
                          .ValueOrDie();
  return config;
}
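
// Example (illustrative sketch only): for an LMHLO collective op that exposes
// getInputs()/getReplicaGroups()/getChannelId(), a config can be built as
//   NcclCollectiveConfig config =
//       GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds());
// where getUseGlobalDeviceIds() is assumed to be a generated accessor
// returning std::optional<bool>; the exact name depends on the op definition.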

// Thunk base class for NCCL collective operations.
class NcclCollectiveThunk : public Thunk {
 public:
  using Thunk::Thunk;

  // A source/destination buffer pair for the collective, described as
  // buffer-allocation slices plus the number of elements to communicate.
  struct Buffer {
    int64_t element_count;
    BufferAllocation::Slice source_buffer;
    BufferAllocation::Slice destination_buffer;
  };

  // Returns whether NCCL operations appear possible to perform; e.g. if we
  // haven't done a build with the CUDA compiler enabled, we can't compile the
  // NCCL header, and thus this will be false.
  //
  // When this is false, the ExecuteOnStream() call will simply return an
  // error status.
  static bool NcclIsEnabled();

  // Logging support.
  static std::string GetDeviceString(const NcclExecuteParams& params);

  Status ExecuteOnStream(const ExecuteParams& params) override;

 protected:
  virtual Status RunNcclCollective(const ExecuteParams& params,
                                   ncclComm_t comm) = 0;
  virtual const NcclCollectiveConfig& config() const = 0;

 private:
#if XLA_ENABLE_XCCL
  bool first_call_to_execute_ = true;
#endif  // XLA_ENABLE_XCCL
};
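
// Sketch of a derived thunk (illustrative only; NcclAllReduceThunkExample and
// its members are hypothetical). A subclass typically stores its config and
// buffers and implements RunNcclCollective() to issue the actual NCCL call:
//
//   class NcclAllReduceThunkExample : public NcclCollectiveThunk {
//    protected:
//     Status RunNcclCollective(const ExecuteParams& params,
//                              ncclComm_t comm) override;
//     const NcclCollectiveConfig& config() const override { return config_; }
//
//    private:
//     NcclCollectiveConfig config_;
//     std::vector<Buffer> buffers_;
//   };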

// Returns true if the given data type is supported by NCCL.
// Note: Keep this in sync with ToNcclDataType().
bool IsTypeSupportedByNccl(PrimitiveType element_type);

#if XLA_ENABLE_XCCL
// Acquires the NCCL communicator for the participating device; the returned
// lock holds it for the duration of the collective.
// TODO(hanbinyoon): Consider moving to nccl_utils.h when deprecating Thunks.
StatusOr<NcclComm::Lock> LockNcclComm(
    const NcclExecuteParams& params,
    const std::vector<ReplicaGroup>& replica_groups,
    CollectiveOpGroupMode group_mode, int64_t op_id);
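
// Example (sketch only): a caller typically locks the communicator before
// running the collective, e.g.
//   TF_ASSIGN_OR_RETURN(NcclComm::Lock comm,
//                       LockNcclComm(params, config.replica_groups,
//                                    config.group_mode, config.op_id));
// where `params` and `config` are assumed to be in scope.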
#endif  // XLA_ENABLE_XCCL

// A source/destination device-memory pair, resolved from a
// NcclCollectiveThunk::Buffer against the buffer allocations of a particular
// execution.
struct DeviceBufferPair {
  PrimitiveType element_type;
  int64_t element_count;
  se::DeviceMemoryBase source_buffer;
  se::DeviceMemoryBase destination_buffer;
};

// Resolves the given buffer slices to device memory using the buffer
// allocations in `params`, pairing each buffer with its element type.
StatusOr<std::vector<DeviceBufferPair>> ConvertToDeviceBuffers(
    const Thunk::ExecuteParams& params,
    const std::vector<NcclCollectiveThunk::Buffer>& buffers,
    const std::vector<PrimitiveType>& element_types);
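
// Example (sketch only): a derived thunk's RunNcclCollective() typically
// resolves its buffers first, e.g.
//   TF_ASSIGN_OR_RETURN(
//       std::vector<DeviceBufferPair> device_buffers,
//       ConvertToDeviceBuffers(params, buffers_,
//                              config().operand_element_type));
// where `buffers_` is a hypothetical member of the derived thunk.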

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_