// Copyright 2022 The TensorFlow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XLA_RUNTIME_CUSTOM_CALLS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XLA_RUNTIME_CUSTOM_CALLS_H_

#include <cstdint>
#include <memory>
#include <tuple>

#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "tensorflow/compiler/xla/mlir/transforms/runtime/custom_call_encoding.h"
#include "tensorflow/compiler/xla/runtime/custom_call.h"
#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tfrt/support/type_id.h"  // from @tf_runtime

namespace xla {
namespace gpu {
class JitRtKernelsCache;
class JitRtGemmConfigCache;
class JitRtCollectiveSupport;
class JitRtAsyncCollectiveSupport;

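// Dot (GEMM) dimension numbers decoded from the corresponding LMHLO attribute
// at the custom call boundary.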
struct DotDimensionNumbers {
  llvm::ArrayRef<int64_t> lhs_batch;
  llvm::ArrayRef<int64_t> lhs_contract;
  llvm::ArrayRef<int64_t> rhs_batch;
  llvm::ArrayRef<int64_t> rhs_contract;
};

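// Convolution dimension numbers decoded from the corresponding LMHLO
// attribute.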
struct ConvDimensionNumbers {
  int64_t input_batch_dim;
  int64_t input_feature_dim;
  llvm::ArrayRef<int64_t> input_spatial_dims;

  int64_t kernel_in_feature_dim;
  int64_t kernel_out_feature_dim;
  llvm::ArrayRef<int64_t> kernel_spatial_dims;

  int64_t output_batch_dim;
  int64_t output_feature_dim;
  llvm::ArrayRef<int64_t> output_spatial_dims;
};

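// Backend configuration for convolutions (algorithm selection, cuDNN knobs,
// operand/result layouts and scratch workspace size) decoded from the backend
// config attribute.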
struct ConvBackendConfig {
  int64_t algorithm;
  bool tensor_ops_enabled;
  bool is_cudnn_frontend;
  llvm::ArrayRef<int64_t> knob_ids;
  llvm::ArrayRef<int64_t> knob_values;
  llvm::ArrayRef<int64_t> operand_0_layout;
  llvm::ArrayRef<int64_t> operand_1_layout;
  llvm::ArrayRef<int64_t> result_layout;
  int64_t workspace_size;
};
}  // namespace gpu
}  // namespace xla

namespace xla {
namespace runtime {

using llvm::ArrayRef;

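// Register decoding for the enum and aggregate attributes passed to custom
// calls, so that handlers receive them as the strongly typed StreamExecutor
// enums and XLA GPU structs declared above.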
XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(stream_executor::dnn::ActivationMode);
XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(stream_executor::fft::Type);
XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(
    stream_executor::cuda::BlasLt::Epilogue);

XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING(
    xla::gpu::DotDimensionNumbers,
    XLA_RUNTIME_AGGREGATE_FIELDS("lhs_batch", "lhs_contract", "rhs_batch",
                                 "rhs_contract"),
    ArrayRef<int64_t>, ArrayRef<int64_t>, ArrayRef<int64_t>, ArrayRef<int64_t>);

XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING(
    xla::gpu::ConvDimensionNumbers,
    XLA_RUNTIME_AGGREGATE_FIELDS("input_batch_dim", "input_feature_dim",
                                 "input_spatial_dims", "kernel_in_feature_dim",
                                 "kernel_out_feature_dim",
                                 "kernel_spatial_dims", "output_batch_dim",
                                 "output_feature_dim", "output_spatial_dims"),
    int64_t, int64_t, ArrayRef<int64_t>, int64_t, int64_t, ArrayRef<int64_t>,
    int64_t, int64_t, ArrayRef<int64_t>);

XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING(
    xla::gpu::ConvBackendConfig,
    XLA_RUNTIME_AGGREGATE_FIELDS("algorithm", "tensor_ops_enabled",
                                 "is_cudnn_frontend", "knob_ids", "knob_values",
                                 "operand_0_layout", "operand_1_layout",
                                 "result_layout", "workspace_size"),
    int64_t, bool, bool, ArrayRef<int64_t>, ArrayRef<int64_t>,
    ArrayRef<int64_t>, ArrayRef<int64_t>, ArrayRef<int64_t>, int64_t);
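
// With these decodings registered, a custom call handler can bind the structs
// above as typed attributes. A minimal sketch (the target and attribute names
// below are illustrative, not defined in this header):
//
//   CustomCall::Bind("xla.gpu.gemm")
//       .Attr<xla::gpu::DotDimensionNumbers>("dot_dims")
//       ...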

}  // namespace runtime
}  // namespace xla

// Declare explicit dense type ids for all types passed to the custom calls
// as user data, to generate template specializations for fast id lookup.
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtKernelsCache);
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtGemmConfigCache);
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtCollectiveSupport);
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtAsyncCollectiveSupport);
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    const xla::ServiceExecutableRunOptions);
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    const xla::DebugOptions);
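// The matching TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID definitions are expected to
// live in the corresponding .cc file.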

namespace xla {
namespace gpu {

// Populate mapping from XLA (SE) enum/struct type ids to symbol names.
void PopulateXlaTypeIdNames(runtime::TypeIDNameRegistry& registry);

// Populate encoding from LMHLO attributes to XLA (SE) enums and structs.
void PopulateLmhloToXlaAttrEncoding(
    runtime::CustomCallAttrEncodingSet& encoding);

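// Cache of device kernels loaded on a StreamExecutor, keyed by
// (executor, kernel data pointer, kernel name).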
class JitRtKernelsCache {
 public:
  JitRtKernelsCache() = default;

  ::stream_executor::KernelBase* Get(
      ::stream_executor::StreamExecutor* executor, const char* data,
      llvm::StringRef name);

  ::stream_executor::KernelBase* Set(
      ::stream_executor::StreamExecutor* executor, const char* data,
      llvm::StringRef name,
      std::unique_ptr<::stream_executor::KernelBase> kernel);

 private:
  mutable absl::Mutex mutex_;

  using Key = std::tuple<::stream_executor::StreamExecutor*, const char*,
                         llvm::StringRef>;
  llvm::SmallDenseMap<Key, std::unique_ptr<::stream_executor::KernelBase>>
      kernels_cache_ ABSL_GUARDED_BY(mutex_);
};

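// Cache of GemmConfigs, keyed by the uid assigned to each GEMM operation.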
class JitRtGemmConfigCache {
 public:
  const GemmConfig* Get(int64_t uid);
  const GemmConfig* Set(int64_t uid, GemmConfig config);

 private:
  mutable absl::Mutex mutex_;

  llvm::SmallDenseMap<int64_t, GemmConfig> configs_ ABSL_GUARDED_BY(mutex_);
};

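// Support for running collective operations: tracks which collectives have
// already executed so that the host can be blocked after their first run
// (see MaybeBlockAfterFirstRun below).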
class JitRtCollectiveSupport {
 public:
  // Maybe block the host after the first call to the collective operation
  // with the given uid, to ensure that all devices have allocated the
  // required buffers for their communicators before allowing any device to
  // continue enqueuing operations. Otherwise, the allocations can cause a
  // deadlock in the CUDA driver.
  //
  // This basically ports the workaround from cr/435058849 to JitRt (see
  // details in b/215649390).
  Status MaybeBlockAfterFirstRun(int32_t uid, int32_t device_ordinal,
                                 se::Stream* stream);

 private:
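  // Packs (uid, device_ordinal) into a single map key: uid in the upper
  // 32 bits, device ordinal in the lower 32 bits.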
  static int64_t Key(int32_t uid, int32_t device_ordinal) {
    return static_cast<int64_t>(uid) << 32 | device_ordinal;
  }

  mutable absl::Mutex mutex_;

  // Store whether a particular collective operation was executed at least
  // once. We rely on the unique `uid` assigned to each collective operation
  // by the lowering pass.
  llvm::SmallDenseMap<int64_t, bool> executed_ ABSL_GUARDED_BY(mutex_);
};

// Support for running async collective operations communicating via events.
class JitRtAsyncCollectiveSupport {
 public:
  explicit JitRtAsyncCollectiveSupport(se::Stream* async_comm_stream);

  mlir::FailureOr<se::Event> PopEvent(int32_t uid, int32_t device_ordinal);
  mlir::LogicalResult PushEvent(int32_t uid, int32_t device_ordinal,
                                se::Event done_event);

  ::stream_executor::Stream* async_comm_stream() const {
    return async_comm_stream_;
  }

 private:
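  // Same (uid, device_ordinal) key packing as JitRtCollectiveSupport::Key.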
  static int64_t EventKey(int32_t uid, int32_t device_ordinal) {
    return static_cast<int64_t>(uid) << 32 | device_ordinal;
  }

  mutable absl::Mutex mutex_;

  ::stream_executor::Stream* async_comm_stream_;

  // Store done events for the AllReduceDone to wait on.
  llvm::SmallDenseMap<int64_t, se::Event> done_events_ ABSL_GUARDED_BY(mutex_);
};

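// Returns the library of direct custom calls backing XLA GPU operations
// (e.g. kernel launches, GEMMs, convolutions and collectives).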
xla::runtime::DirectCustomCallLibrary JitRtGpuCustomCalls();

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XLA_RUNTIME_CUSTOM_CALLS_H_