// Copyright 2022 The TensorFlow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XLA_RUNTIME_CUSTOM_CALLS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XLA_RUNTIME_CUSTOM_CALLS_H_

#include <cstdint>
#include <memory>
#include <tuple>

#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "tensorflow/compiler/xla/mlir/transforms/runtime/custom_call_encoding.h"
#include "tensorflow/compiler/xla/runtime/custom_call.h"
#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tfrt/support/type_id.h"  // from @tf_runtime

namespace xla {
namespace gpu {
class JitRtKernelsCache;
class JitRtGemmConfigCache;
class JitRtCollectiveSupport;
class JitRtAsyncCollectiveSupport;

struct DotDimensionNumbers {
  llvm::ArrayRef<int64_t> lhs_batch;
  llvm::ArrayRef<int64_t> lhs_contract;
  llvm::ArrayRef<int64_t> rhs_batch;
  llvm::ArrayRef<int64_t> rhs_contract;
};

struct ConvDimensionNumbers {
  int64_t input_batch_dim;
  int64_t input_feature_dim;
  llvm::ArrayRef<int64_t> input_spatial_dims;

  int64_t kernel_in_feature_dim;
  int64_t kernel_out_feature_dim;
  llvm::ArrayRef<int64_t> kernel_spatial_dims;

  int64_t output_batch_dim;
  int64_t output_feature_dim;
  llvm::ArrayRef<int64_t> output_spatial_dims;
};

struct ConvBackendConfig {
  int64_t algorithm;
  bool tensor_ops_enabled;
  bool is_cudnn_frontend;
  llvm::ArrayRef<int64_t> knob_ids;
  llvm::ArrayRef<int64_t> knob_values;
  llvm::ArrayRef<int64_t> operand_0_layout;
  llvm::ArrayRef<int64_t> operand_1_layout;
  llvm::ArrayRef<int64_t> result_layout;
  int64_t workspace_size;
};
}  // namespace gpu
}  // namespace xla

namespace xla {
namespace runtime {

using llvm::ArrayRef;

XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(stream_executor::dnn::ActivationMode);
XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(stream_executor::fft::Type);
XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(
    stream_executor::cuda::BlasLt::Epilogue);

XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING(
    xla::gpu::DotDimensionNumbers,
    XLA_RUNTIME_AGGREGATE_FIELDS("lhs_batch", "lhs_contract", "rhs_batch",
                                 "rhs_contract"),
    ArrayRef<int64_t>, ArrayRef<int64_t>, ArrayRef<int64_t>, ArrayRef<int64_t>);

XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING(
    xla::gpu::ConvDimensionNumbers,
    XLA_RUNTIME_AGGREGATE_FIELDS("input_batch_dim", "input_feature_dim",
                                 "input_spatial_dims", "kernel_in_feature_dim",
                                 "kernel_out_feature_dim",
                                 "kernel_spatial_dims", "output_batch_dim",
                                 "output_feature_dim", "output_spatial_dims"),
    int64_t, int64_t, ArrayRef<int64_t>, int64_t, int64_t, ArrayRef<int64_t>,
    int64_t, int64_t, ArrayRef<int64_t>);

XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING(
    xla::gpu::ConvBackendConfig,
    XLA_RUNTIME_AGGREGATE_FIELDS("algorithm", "tensor_ops_enabled",
                                 "is_cudnn_frontend", "knob_ids", "knob_values",
                                 "operand_0_layout", "operand_1_layout",
                                 "result_layout", "workspace_size"),
    int64_t, bool, bool, ArrayRef<int64_t>, ArrayRef<int64_t>,
    ArrayRef<int64_t>, ArrayRef<int64_t>, ArrayRef<int64_t>, int64_t);
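
// A minimal sketch (not part of this header) of how the decodings above are
// typically consumed: a custom call handler takes the decoded struct as a
// plain argument, and the registration maps the fields listed in
// XLA_RUNTIME_AGGREGATE_FIELDS onto the struct members in order. The handler
// and argument names below are illustrative assumptions, not symbols defined
// by XLA:
//
//   absl::Status GemmImpl(/* buffer arguments */,
//                         xla::gpu::DotDimensionNumbers dot_dims) {
//     // dot_dims.lhs_contract / dot_dims.rhs_contract hold the contracting
//     // dimensions decoded from the custom call's aggregate attribute.
//   }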

}  // namespace runtime
}  // namespace xla

// Declare explicit dense type ids for all types passed to the custom calls
// as user data to generate template specializations for fast id lookup.
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtKernelsCache);
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtGemmConfigCache);
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtCollectiveSupport);
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    xla::gpu::JitRtAsyncCollectiveSupport);
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    const xla::ServiceExecutableRunOptions);
TFRT_DECLARE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
                                    const xla::DebugOptions);

namespace xla {
namespace gpu {

// Populate mapping from XLA (SE) enums/structs type id to symbol names.
void PopulateXlaTypeIdNames(runtime::TypeIDNameRegistry& registry);

// Populate encoding from LMHLO attributes to XLA (SE) enums and structs.
void PopulateLmhloToXlaAttrEncoding(
    runtime::CustomCallAttrEncodingSet& encoding);

class JitRtKernelsCache {
 public:
  JitRtKernelsCache() = default;

  ::stream_executor::KernelBase* Get(
      ::stream_executor::StreamExecutor* executor, const char* data,
      llvm::StringRef name);

  ::stream_executor::KernelBase* Set(
      ::stream_executor::StreamExecutor* executor, const char* data,
      llvm::StringRef name,
      std::unique_ptr<::stream_executor::KernelBase> kernel);

 private:
  mutable absl::Mutex mutex_;

  using Key = std::tuple<::stream_executor::StreamExecutor*, const char*,
                         llvm::StringRef>;
  llvm::SmallDenseMap<Key, std::unique_ptr<::stream_executor::KernelBase>>
      kernels_cache_ ABSL_GUARDED_BY(mutex_);
};
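
// A minimal usage sketch for the cache above, under two assumptions that this
// header does not guarantee: `Get` returns nullptr on a miss, and `LoadKernel`
// is a hypothetical helper that materializes the kernel from the embedded
// `data` blob:
//
//   se::KernelBase* kernel = cache.Get(executor, data, name);
//   if (kernel == nullptr) {
//     std::unique_ptr<se::KernelBase> loaded = LoadKernel(executor, data, name);
//     kernel = cache.Set(executor, data, name, std::move(loaded));
//   }
//   // ... launch `kernel` on the stream as usual.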

class JitRtGemmConfigCache {
 public:
  const GemmConfig* Get(int64_t uid);
  const GemmConfig* Set(int64_t uid, GemmConfig config);

 private:
  mutable absl::Mutex mutex_;

  llvm::SmallDenseMap<int64_t, GemmConfig> configs_ ABSL_GUARDED_BY(mutex_);
};

class JitRtCollectiveSupport {
 public:
  // Maybe block host after the first call to the collective operation with the
  // given uid, to ensure that all devices have allocated the required buffers
  // for their communicators before allowing any device to continue enqueuing
  // operations. Otherwise, the allocations can cause deadlock in the CUDA
  // driver.
  //
  // This basically ports the workaround from cr/435058849 to JitRt (see
  // details in b/215649390).
  Status MaybeBlockAfterFirstRun(int32_t uid, int32_t device_ordinal,
                                 se::Stream* stream);

 private:
  static int64_t Key(int32_t uid, int32_t device_ordinal) {
    return static_cast<int64_t>(uid) << 32 | device_ordinal;
  }

  mutable absl::Mutex mutex_;

  // Stores whether a particular collective operation was executed at least
  // once. We rely on the unique `uid` assigned to each collective operation by
  // the lowering pass.
  llvm::SmallDenseMap<int64_t, bool> executed_ ABSL_GUARDED_BY(mutex_);
};

// Support for running async collective operations communicating via events.
class JitRtAsyncCollectiveSupport {
 public:
  explicit JitRtAsyncCollectiveSupport(se::Stream* async_comm_stream);

  mlir::FailureOr<se::Event> PopEvent(int32_t uid, int32_t device_ordinal);
  mlir::LogicalResult PushEvent(int32_t uid, int32_t device_ordinal,
                                se::Event done_event);

  ::stream_executor::Stream* async_comm_stream() const {
    return async_comm_stream_;
  }

 private:
  static int64_t EventKey(int32_t uid, int32_t device_ordinal) {
    return static_cast<int64_t>(uid) << 32 | device_ordinal;
  }

  mutable absl::Mutex mutex_;

  ::stream_executor::Stream* async_comm_stream_;

  // Store done events for the AllReduceDone to wait on.
  llvm::SmallDenseMap<int64_t, se::Event> done_events_ ABSL_GUARDED_BY(mutex_);
};

xla::runtime::DirectCustomCallLibrary JitRtGpuCustomCalls();

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XLA_RUNTIME_CUSTOM_CALLS_H_