// Copyright 2021 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//===- ccl_pattern.cc -----------------------------------------------------===//
//
// Pattern to lower lmhlo collective ops to the tfrt_gpu/xlir dialect.
//
//===----------------------------------------------------------------------===//
#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h"

#include <functional>
#include <string>

#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/xlir_ops.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tfrt/gpu/kernels/gpu_ops.h"  // from @tf_runtime
#include "tfrt/gpu/pass/pass.h"  // from @tf_runtime

namespace tensorflow {
namespace {

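// Maps an XLA reduction kind to the corresponding NCCL reduction operator.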
ncclRedOp_t ToNcclReduction(xla::ReductionKind kind) {
  switch (kind) {
    case xla::ReductionKind::SUM:
      return ncclSum;
    case xla::ReductionKind::PRODUCT:
      return ncclProd;
    case xla::ReductionKind::MIN:
      return ncclMin;
    case xla::ReductionKind::MAX:
      return ncclMax;
  }
}

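// Maps an XLA primitive type to the NCCL data type used to communicate it.
// Returns failure for element types that NCCL does not support. Complex types
// are communicated as floats of the matching component width.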
FailureOr<ncclDataType_t> ToNcclDataType(xla::PrimitiveType element_type) {
  switch (element_type) {
    case xla::S8:
      return ncclInt8;
    case xla::PRED:
    case xla::U8:
      return ncclUint8;
    case xla::S32:
      return ncclInt32;
    case xla::U32:
      return ncclUint32;
    case xla::S64:
      return ncclInt64;
    case xla::U64:
      return ncclUint64;
    case xla::F16:
      return ncclFloat16;
    case xla::F32:
    case xla::C64:
      return ncclFloat32;
    case xla::F64:
    case xla::C128:
      return ncclFloat64;
#if defined(__CUDA_BF16_TYPES_EXIST__)
    case xla::BF16:
      return ncclBfloat16;
#endif
    default:
      return mlir::failure();
  }
}

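// Lowers an lmhlo::AllReduceOp to a sequence of tfrt_gpu/xlir ops: get the
// stream's context, create a CCL handle, queue one all-reduce per
// operand/result pair, and execute the queued ops on the stream. As a rough
// sketch (op mnemonics below are illustrative, not exact assembly syntax), a
// single-buffer f32 sum-all-reduce lowers to something like:
//
//   %ctx    = tfrt_gpu.stream.get_context %stream
//   %handle = xlir.ccl.create %ctx
//   %ch1    = tfrt_gpu.ccl.all_reduce %handle, %input, %output,
//             ncclFloat32, ncclSum, %ch0
//   %out    = tfrt_gpu.ccl.execute %stream, %handle, %ch1
//
// Returns the chain produced by the final execute op, or failure if the op
// cannot be lowered.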
FailureOr<Value> CclOpConversionRewrite(lmhlo::AllReduceOp srcOp, Value chain,
                                        Value stream,
                                        mlir::BlockAndValueMapping& mapping,
                                        ConversionPatternRewriter& rewriter) {
  const auto& operands = srcOp.operands();
  const auto& results = srcOp.results();
  if (operands.size() != results.size()) {
    return rewriter.notifyMatchFailure(
        srcOp, "Number of operands and results do not match.");
  }

  auto reduction_kind =
      xla::gpu::NcclAllReduceThunkBase::MatchAllReduceComputation(
          srcOp.computation());
  if (!reduction_kind.has_value()) {
    return rewriter.notifyMatchFailure(
        srcOp,
        "Failed to match the reduction computation to a reduction kind.");
  }
  ncclRedOp_t reduction_op = ToNcclReduction(*reduction_kind);

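  // Create a CCL handle from the context of the stream the collective will
  // execute on; the subsequent all-reduce ops are queued on this handle.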
  auto context =
      rewriter.create<tfrt::gpu::StreamGetContextOp>(srcOp.getLoc(), stream)
          .getResult();

  auto handle = rewriter.create<xla::gpu::CclCreateOp>(srcOp.getLoc(), context)
                    .getResult();

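  // Queue one all-reduce per operand/result pair, threading the chain through
  // each op so they stay ordered.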
  for (int i = 0; i < operands.size(); i++) {
    xla::Shape shape = xla::TypeToShape(operands[i].getType());
    auto nccl_data_type_or = ToNcclDataType(shape.element_type());
    if (mlir::failed(nccl_data_type_or)) {
      return rewriter.notifyMatchFailure(
          srcOp, "Failed to convert operand data type to ncclDataType_t.");
    }
    ncclDataType_t nccl_data_type = nccl_data_type_or.getValue();

    Value input = mapping.lookup(operands[i]);
    Value output = mapping.lookup(results[i]);

    chain = rewriter
                .create<tfrt::gpu::CclAllReduceOp>(
                    srcOp.getLoc(), handle, input, output, nccl_data_type,
                    reduction_op, chain)
                .getResult();
  }

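  // Launch all queued collective ops on the stream and return the out chain.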
  return rewriter
      .create<tfrt::gpu::CclExecuteOp>(srcOp.getLoc(), stream, handle, chain)
      .getResult();
}

// TODO(hanbinyoon): Support additional lmhlo collective ops (in addition to
// lmhlo::AllReduceOp).
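// Conversion pattern that rewrites lmhlo::AllReduceOp into tfrt_gpu dialect
// ops; matchAndRewriteOp receives the in-chain and stream from the
// GpuAsyncOpConversionPattern base class.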
struct CclRewritePattern
    : tfrt::gpu::GpuAsyncOpConversionPattern<lmhlo::AllReduceOp> {
  using tfrt::gpu::GpuAsyncOpConversionPattern<
      lmhlo::AllReduceOp>::GpuAsyncOpConversionPattern;
  FailureOr<Value> matchAndRewriteOp(
      lmhlo::AllReduceOp op, Value chain, Value stream,
      ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const override {
    if (!all_of(operands, [](Value operand) {
          return operand.getType().isa<tfrt::gpu::BufferType>();
        }))
      return rewriter.notifyMatchFailure(op, "expected buffer operands");

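    // Map each original lmhlo operand to its converted value so the rewrite
    // can look up the tfrt_gpu buffers.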
    BlockAndValueMapping mapping;
    for (auto pair : llvm::zip_first(op->getOperands(), operands))
      mapping.map(std::get<0>(pair), std::get<1>(pair));

    auto out_chain_or =
        CclOpConversionRewrite(op, chain, stream, mapping, rewriter);
    if (mlir::succeeded(out_chain_or)) {
      rewriter.eraseOp(op);
    }
    return out_chain_or;
  }
};

}  // namespace

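// Registers the CCL conversion pattern with the given pattern set.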
void populateCclConversionPattern(RewritePatternSet& patterns) {
  patterns.add<CclRewritePattern>(patterns.getContext());
}

}  // namespace tensorflow