1 // Copyright 2021 The TensorFlow Runtime Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
//===- ccl_pattern.cc -----------------------------------------------------===//
17 //
18 // Pattern to lower lmhlo collective ops to tfrt_gpu/xlir dialect.
19 //
20 //===----------------------------------------------------------------------===//
21 #include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h"
22 
23 #include <functional>
24 #include <string>
25 
26 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
27 #include "mlir/IR/BlockAndValueMapping.h"
28 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
29 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
30 #include "tensorflow/compiler/xla/service/gpu/xlir_ops.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tfrt/gpu/kernels/gpu_ops.h"  // from @tf_runtime
33 #include "tfrt/gpu/pass/pass.h"  // from @tf_runtime
34 
35 namespace tensorflow {
36 namespace {
37 
ToNcclReduction(xla::ReductionKind kind)38 ncclRedOp_t ToNcclReduction(xla::ReductionKind kind) {
39   switch (kind) {
40     case xla::ReductionKind::SUM:
41       return ncclSum;
42     case xla::ReductionKind::PRODUCT:
43       return ncclProd;
44     case xla::ReductionKind::MIN:
45       return ncclMin;
46     case xla::ReductionKind::MAX:
47       return ncclMax;
48   }
49 }
50 
ToNcclDataType(xla::PrimitiveType element_type)51 FailureOr<ncclDataType_t> ToNcclDataType(xla::PrimitiveType element_type) {
52   switch (element_type) {
53     case xla::S8:
54       return ncclInt8;
55     case xla::PRED:
56     case xla::U8:
57       return ncclUint8;
58     case xla::S32:
59       return ncclInt32;
60     case xla::U32:
61       return ncclUint32;
62     case xla::S64:
63       return ncclInt64;
64     case xla::U64:
65       return ncclUint64;
66     case xla::F16:
67       return ncclFloat16;
68     case xla::F32:
69     case xla::C64:
70       return ncclFloat32;
71     case xla::F64:
72     case xla::C128:
73       return ncclFloat64;
74 #if defined(__CUDA_BF16_TYPES_EXIST__)
75     case xla::BF16:
76       return ncclBfloat16;
77 #endif
78     default:
79       return mlir::failure();
80   }
81 }
82 
CclOpConversionRewrite(lmhlo::AllReduceOp srcOp,Value chain,Value stream,mlir::BlockAndValueMapping & mapping,ConversionPatternRewriter & rewriter)83 FailureOr<Value> CclOpConversionRewrite(lmhlo::AllReduceOp srcOp, Value chain,
84                                         Value stream,
85                                         mlir::BlockAndValueMapping& mapping,
86                                         ConversionPatternRewriter& rewriter) {
87   const auto& operands = srcOp.operands();
88   const auto& results = srcOp.results();
89   if (operands.size() != results.size()) {
90     return rewriter.notifyMatchFailure(
91         srcOp, "Number of operands and results do not match.");
92   }
93 
94   auto reduction_kind =
95       xla::gpu::NcclAllReduceThunkBase::MatchAllReduceComputation(
96           srcOp.computation());
97   if (!reduction_kind.has_value()) {
98     return rewriter.notifyMatchFailure(
99         srcOp,
100         "Failed to match the reduction computation to a reduction kind.");
101   }
102   ncclRedOp_t reduction_op = ToNcclReduction(*reduction_kind);
103 
104   auto context =
105       rewriter.create<tfrt::gpu::StreamGetContextOp>(srcOp.getLoc(), stream)
106           .getResult();
107 
108   auto handle = rewriter.create<xla::gpu::CclCreateOp>(srcOp.getLoc(), context)
109                     .getResult();
110 
111   for (int i = 0; i < operands.size(); i++) {
112     xla::Shape shape = xla::TypeToShape(operands[i].getType());
113     auto nccl_data_type_or = ToNcclDataType(shape.element_type());
114     if (mlir::failed(nccl_data_type_or)) {
115       return rewriter.notifyMatchFailure(
116           srcOp, "Failed to convert operand data type to ncclDataType_t.");
117     }
118     ncclDataType_t nccl_data_type = nccl_data_type_or.getValue();
119 
120     Value input = mapping.lookup(operands[i]);
121     Value output = mapping.lookup(results[i]);
122 
123     chain = rewriter
124                 .create<tfrt::gpu::CclAllReduceOp>(
125                     srcOp.getLoc(), handle, input, output, nccl_data_type,
126                     reduction_op, chain)
127                 .getResult();
128   }
129 
130   return rewriter
131       .create<tfrt::gpu::CclExecuteOp>(srcOp.getLoc(), stream, handle, chain)
132       .getResult();
133 }
134 
135 // TODO(hanbinyoon): Support additional lmhlo collective ops (in addition to
136 // lmhlo::AllReduceOp).
137 struct CclRewritePattern
138     : tfrt::gpu::GpuAsyncOpConversionPattern<lmhlo::AllReduceOp> {
139   using tfrt::gpu::GpuAsyncOpConversionPattern<
140       lmhlo::AllReduceOp>::GpuAsyncOpConversionPattern;
matchAndRewriteOptensorflow::__anond6ddd7dd0111::CclRewritePattern141   FailureOr<Value> matchAndRewriteOp(
142       lmhlo::AllReduceOp op, Value chain, Value stream,
143       ArrayRef<Value> operands,
144       ConversionPatternRewriter& rewriter) const override {
145     if (!all_of(operands, [](Value operand) {
146           return operand.getType().isa<tfrt::gpu::BufferType>();
147         }))
148       return rewriter.notifyMatchFailure(op, "expected buffer operands");
149 
150     BlockAndValueMapping mapping;
151     for (auto pair : llvm::zip_first(op->getOperands(), operands))
152       mapping.map(std::get<0>(pair), std::get<1>(pair));
153 
154     auto out_chain_or =
155         CclOpConversionRewrite(op, chain, stream, mapping, rewriter);
156     if (mlir::succeeded(out_chain_or)) {
157       rewriter.eraseOp(op);
158     }
159     return out_chain_or;
160   }
161 };
162 
163 }  // namespace
164 
populateCclConversionPattern(RewritePatternSet & patterns)165 void populateCclConversionPattern(RewritePatternSet& patterns) {
166   patterns.add<CclRewritePattern>(patterns.getContext());
167 }
168 
169 }  // namespace tensorflow
170