xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_

#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

#if XLA_ENABLE_XCCL
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
#endif  // XLA_ENABLE_XCCL

// Forward-declare the NCCL communicator type so this header remains usable
// in builds where the NCCL headers are unavailable.
struct ncclComm;
using ncclComm_t = ncclComm*;

namespace xla {
namespace gpu {

class NcclClique;

// Configuration common to all NCCL collective thunks, derived from the MLIR
// (LMHLO) collective op.
struct NcclCollectiveConfig {
  NcclCollectiveConfig();
  NcclCollectiveConfig(NcclCollectiveConfig&&);
  ~NcclCollectiveConfig();

  NcclCollectiveConfig& operator=(NcclCollectiveConfig&&);

  int64_t operand_count;
  std::vector<PrimitiveType> operand_element_type;
  std::vector<ReplicaGroup> replica_groups;
  RendezvousKey::CollectiveOpKind collective_op_kind;
  int64_t op_id;
  CollectiveOpGroupMode group_mode;

  template <typename OpT>
  void SetCollectiveOpKindAndID(OpT op);

  // A collective op is degenerate if every group formed by its replica_groups
  // has a single participant; the op then requires no communication and is
  // equivalent to copying the input to the output.
  bool IsDegenerate(int64_t replica_count, int64_t partition_count) const;
};

// Sets collective_op_kind and op_id from the MLIR op. Ops with a channel id
// are cross-module collectives and use the channel handle as the id; all
// other ops are cross-replica and use the module's "hlo.unique_id" attribute.
template <typename OpT>
void NcclCollectiveConfig::SetCollectiveOpKindAndID(OpT op) {
  if (op.getChannelId()) {
    collective_op_kind = RendezvousKey::kCrossModule;
    op_id = static_cast<int64_t>(op.getChannelId()->getHandle());
  } else {
    collective_op_kind = RendezvousKey::kCrossReplica;
    mlir::ModuleOp parent = op->template getParentOfType<mlir::ModuleOp>();
    mlir::IntegerAttr unique_id =
        parent->getAttrOfType<mlir::IntegerAttr>("hlo.unique_id");
    op_id = static_cast<int64_t>(unique_id.getInt());
  }
}

// Builds an NcclCollectiveConfig from an MLIR collective op.
template <typename OpT>
NcclCollectiveConfig GetNcclCollectiveConfigForMlir(
    OpT op, std::optional<bool> use_global_device_ids) {
  NcclCollectiveConfig config;
  config.operand_count = op.getInputs().size();
  config.operand_element_type.reserve(config.operand_count);
  for (int i = 0; i < config.operand_count; i++) {
    const Shape shape = GetShape(op.getInputs()[i]);
    config.operand_element_type.push_back(shape.element_type());
  }
  config.replica_groups =
      ConvertReplicaGroups(op.getReplicaGroups()).ValueOrDie();
  config.SetCollectiveOpKindAndID(op);
  config.group_mode = GetCollectiveOpGroupMode(op.getChannelId().has_value(),
                                               use_global_device_ids)
                          .ValueOrDie();
  return config;
}
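
// Illustrative sketch (not part of the original header): a concrete thunk
// would typically build its config by forwarding the op's
// use_global_device_ids attribute. The op type and accessor below are
// assumptions made for the example.
//
//   template <typename AllReduceOpT>
//   NcclCollectiveConfig GetExampleAllReduceConfig(AllReduceOpT op) {
//     return GetNcclCollectiveConfigForMlir(op, op.getUseGlobalDeviceIds());
//   }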

// Thunk base class for NCCL collective operations.
class NcclCollectiveThunk : public Thunk {
 public:
  using Thunk::Thunk;

  struct Buffer {
    int64_t element_count;
    BufferAllocation::Slice source_buffer;
    BufferAllocation::Slice destination_buffer;
  };

  // Returns whether NCCL operations can be performed at all; e.g. in a build
  // without CUDA compiler support the NCCL header cannot be compiled, so this
  // returns false.
  //
  // When this is false, ExecuteOnStream() simply returns an error status.
  static bool NcclIsEnabled();

  // Logging support.
  static std::string GetDeviceString(const NcclExecuteParams& params);

  Status ExecuteOnStream(const ExecuteParams& params) override;

 protected:
  virtual Status RunNcclCollective(const ExecuteParams& params,
                                   ncclComm_t comm) = 0;
  virtual const NcclCollectiveConfig& config() const = 0;

 private:
#if XLA_ENABLE_XCCL
  bool first_call_to_execute_ = true;
#endif  // XLA_ENABLE_XCCL
};
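
// A minimal sketch of a derived thunk (illustrative only; real collectives
// such as all-reduce and all-gather live in their own headers, and the thunk
// kind used below is an assumption):
//
//   class NcclExampleThunk : public NcclCollectiveThunk {
//    public:
//     NcclExampleThunk(ThunkInfo thunk_info, NcclCollectiveConfig config)
//         : NcclCollectiveThunk(Thunk::kNcclAllReduce, thunk_info),
//           config_(std::move(config)) {}
//
//    protected:
//     Status RunNcclCollective(const ExecuteParams& params,
//                              ncclComm_t comm) override {
//       // Issue NCCL calls on params.stream using `comm` here.
//       return OkStatus();
//     }
//     const NcclCollectiveConfig& config() const override { return config_; }
//
//    private:
//     const NcclCollectiveConfig config_;
//   };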

// Returns true if the given data type is supported by NCCL.
// Note: keep this in sync with ToNcclDataType().
bool IsTypeSupportedByNccl(PrimitiveType element_type);
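
// For example (sketch only), a caller can reject unsupported element types
// before attempting the collective:
//
//   if (!IsTypeSupportedByNccl(shape.element_type())) {
//     return Unimplemented("element type not supported by NCCL");
//   }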

#if XLA_ENABLE_XCCL
// TODO(hanbinyoon): Consider moving to nccl_utils.h when deprecating Thunks.
// Acquires the NCCL communicator for this device's clique, scoped to the
// returned lock.
StatusOr<NcclComm::Lock> LockNcclComm(
    const NcclExecuteParams& params,
    const std::vector<ReplicaGroup>& replica_groups,
    CollectiveOpGroupMode group_mode, int64_t op_id);
#endif  // XLA_ENABLE_XCCL
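
// Sketch of typical use from within ExecuteOnStream(); the field name
// `params.nccl_params` is an assumption based on Thunk::ExecuteParams:
//
//   TF_ASSIGN_OR_RETURN(
//       NcclComm::Lock comm,
//       LockNcclComm(params.nccl_params, config().replica_groups,
//                    config().group_mode, config().op_id));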

// Source/destination device buffers for one operand of a collective.
struct DeviceBufferPair {
  PrimitiveType element_type;
  int64_t element_count;
  se::DeviceMemoryBase source_buffer;
  se::DeviceMemoryBase destination_buffer;
};

// Resolves each thunk Buffer's allocation slices to device memory, pairing
// them with the corresponding element type.
StatusOr<std::vector<DeviceBufferPair>> ConvertToDeviceBuffers(
    const Thunk::ExecuteParams& params,
    const std::vector<NcclCollectiveThunk::Buffer>& buffers,
    const std::vector<PrimitiveType>& element_types);
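
// Sketch of typical use inside a RunNcclCollective() override, assuming the
// derived thunk stored its buffers in a `buffers_` member:
//
//   TF_ASSIGN_OR_RETURN(
//       std::vector<DeviceBufferPair> device_buffers,
//       ConvertToDeviceBuffers(params, buffers_,
//                              config().operand_element_type));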

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_