1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" 21 #include "tensorflow/compiler/xla/service/collective_ops_utils.h" 22 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" 23 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/xla_data.pb.h" 26 27 namespace xla { 28 namespace gpu { 29 30 struct NcclCollectivePermuteConfig { 31 // During a collective permute, every node optionally sends its data to 32 // another node (including possibly itself) and received data from another 33 // node. For each node, remember who it receives data from (source) and who 34 // it send data to (target). Either are optional. 35 struct SourceTargetMapEntry { 36 std::optional<int64_t> source; 37 std::optional<int64_t> target; 38 }; 39 40 using IdToSourceTargetMap = 41 absl::flat_hash_map<int64_t, SourceTargetMapEntry>; 42 43 // Returns the source and target ID corresponding to the given ID (these IDs 44 // are replica_ids for cross replica permute or partition_ids for cross 45 // partition permute). The source ID is the id which will send data to this 46 // ID and the target ID is the id to which this ID will send its data. Either 47 // can be optional. GetSourceTargetNcclCollectivePermuteConfig48 static SourceTargetMapEntry GetSourceTarget( 49 const IdToSourceTargetMap& id_to_source_target, int64_t id) { 50 auto it = id_to_source_target.find(id); 51 if (it != id_to_source_target.end()) return it->second; 52 return SourceTargetMapEntry{}; 53 } 54 55 NcclCollectiveConfig config; 56 IdToSourceTargetMap id_to_source_target; 57 }; 58 59 // Thunk that performs a NCCL-based collective permute. 60 class NcclCollectivePermuteThunk : public NcclCollectiveThunk { 61 public: 62 static NcclCollectivePermuteConfig GetNcclCollectivePermuteConfig( 63 mlir::lmhlo::CollectivePermuteOp op, int64_t replica_count, 64 int64_t partition_count); 65 66 NcclCollectivePermuteThunk(ThunkInfo thunk_info, 67 mlir::lmhlo::CollectivePermuteOp op, 68 int64_t replica_count, int64_t partition_count, 69 const Buffer& buffer); 70 71 // Returns whether the given instruction can be lowered to a nccl collective 72 // permute thunk. 73 static bool CanImplement(mlir::lmhlo::CollectivePermuteOp op); 74 GetName()75 static const char* GetName() { return "CollectivePermute"; } 76 static bool IsDegenerate(mlir::lmhlo::CollectivePermuteOp op, 77 int64_t replica_count, int64_t partition_count); GetGroupMode(mlir::lmhlo::CollectivePermuteOp op)78 static CollectiveOpGroupMode GetGroupMode( 79 mlir::lmhlo::CollectivePermuteOp op) { 80 return GetCollectiveOpGroupMode(op.getChannelId().has_value(), std::nullopt) 81 .ValueOrDie(); 82 } 83 84 protected: 85 Status RunNcclCollective(const ExecuteParams& params, 86 ncclComm_t comm) override; 87 config()88 const NcclCollectiveConfig& config() const override { return config_.config; } 89 90 private: 91 const NcclCollectivePermuteConfig config_; 92 const Buffer buffer_; 93 }; 94 95 Status RunCollectivePermute( 96 NcclCollectivePermuteConfig::SourceTargetMapEntry source_target, 97 DeviceBufferPair& buffer, se::Stream& stream, ncclComm_t comm, 98 absl::string_view device_string, int64_t current_id); 99 100 } // namespace gpu 101 } // namespace xla 102 103 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ 104