/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The CUDA-specific DNN library support, implementing the general DnnSupport
// interface.

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_

#include "absl/base/thread_annotations.h"
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.h"
#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
#include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
#include "tensorflow/compiler/xla/stream_executor/temporary_device_memory.h"

namespace stream_executor {
namespace gpu {

class GpuExecutor;
class CudnnRnnDescriptor;
class CudnnRnnSequenceTensorDescriptor;
class CudnnRnnStateTensorDescriptor;
class CudnnCtcLossDescriptor;

// Opaque and unique identifier for the cuDNN plugin.
extern const PluginId kCuDnnPlugin;

using BatchDescriptorSlice =
    port::ArraySlice<dnn::BatchDescriptor>;  // non-absl ok

template <typename T>
using DeviceMemorySlice =
    port::ArraySlice<const DeviceMemory<T>*>;  // non-absl ok

// cudnn-library based DNN support. For details on overridden interface
// functions, see dnn.h.
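//
// Illustrative usage sketch (an assumption for documentation purposes, not
// part of this header's contract): clients do not construct CudnnSupport
// directly; it is registered as a DNN plugin and reached through
// StreamExecutor's generic accessor, roughly:
//
//   StreamExecutor* executor = ...;            // a CUDA-backed executor
//   dnn::DnnSupport* dnn = executor->AsDnn();  // resolves to CudnnSupport
//   if (dnn != nullptr) {
//     auto version = dnn->GetVersion();        // loaded cuDNN version
//   }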
class CudnnSupport : public dnn::DnnSupport {
 public:
  explicit CudnnSupport(GpuExecutor* parent);

  port::Status Init() override;
  port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;

  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
      float dropout, uint64_t seed, ScratchAllocator* state_allocator,
      bool use_padded_io) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int>& seq_lengths,
                                    bool time_major,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<Eigen::half>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<Eigen::half>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<Eigen::half>& input_c_data,
                    const DeviceMemory<Eigen::half>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<Eigen::half>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<Eigen::half>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<Eigen::half>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<float>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<float>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<float>& input_c_data,
                    const DeviceMemory<float>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<float>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<float>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<float>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<double>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<double>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<double>& input_c_data,
                    const DeviceMemory<double>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<double>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<double>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<double>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;
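
  // The three DoRnnForward overloads above and the DoRnnBackward overloads
  // below differ only in element type (Eigen::half, float, double). A rough
  // training-time call sequence, sketched here purely as an illustration
  // (local names are assumptions; dnn.h holds the authoritative contract):
  //
  //   auto rnn_desc = dnn->createRnnDescriptor(...).value();
  //   dnn->DoRnnForward(stream, *rnn_desc, ..., /*is_training=*/true,
  //                     &reserve_allocator, &workspace_allocator,
  //                     /*output_profile_result=*/nullptr);
  //   // The reserve space written during the forward pass is then consumed
  //   // by the matching backward call:
  //   dnn->DoRnnBackward(stream, *rnn_desc, ..., &reserve_space_data,
  //                      &workspace_allocator,
  //                      /*output_profile_result=*/nullptr);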

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<Eigen::half>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<Eigen::half>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<Eigen::half>& input_c_data,
                     const DeviceMemory<Eigen::half>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<Eigen::half>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<Eigen::half>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<Eigen::half>& output_c_data,
                     const DeviceMemory<Eigen::half>& output_backprop_data,
                     const DeviceMemory<Eigen::half>& output_h_backprop_data,
                     const DeviceMemory<Eigen::half>& output_c_backprop_data,
                     DeviceMemory<Eigen::half>* input_backprop_data,
                     DeviceMemory<Eigen::half>* input_h_backprop_data,
                     DeviceMemory<Eigen::half>* input_c_backprop_data,
                     DeviceMemory<Eigen::half>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<float>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<float>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<float>& input_c_data,
                     const DeviceMemory<float>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<float>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<float>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<float>& output_c_data,
                     const DeviceMemory<float>& output_backprop_data,
                     const DeviceMemory<float>& output_h_backprop_data,
                     const DeviceMemory<float>& output_c_backprop_data,
                     DeviceMemory<float>* input_backprop_data,
                     DeviceMemory<float>* input_h_backprop_data,
                     DeviceMemory<float>* input_c_backprop_data,
                     DeviceMemory<float>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<double>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<double>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<double>& input_c_data,
                     const DeviceMemory<double>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<double>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<double>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<double>& output_c_data,
                     const DeviceMemory<double>& output_backprop_data,
                     const DeviceMemory<double>& output_h_backprop_data,
                     const DeviceMemory<double>& output_c_backprop_data,
                     DeviceMemory<double>* input_backprop_data,
                     DeviceMemory<double>* input_h_backprop_data,
                     DeviceMemory<double>* input_c_backprop_data,
                     DeviceMemory<double>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool GetConvolveAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  port::Status GetConvolveRunners(
      bool use_cudnn_frontend, dnn::ConvolutionKind kind,
      dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      bool use_fallback, ScratchAllocator* scratch_allocator,
      std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_exec_plans)
      override;

  port::StatusOr<std::unique_ptr<const dnn::ConvRunner>> ConvolveRunnerFromDesc(
      Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
      dnn::ConvolutionKind kind, dnn::DataType input_type,
      dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor) override;

  port::Status GetFusedConvolveRunners(
      bool use_cudnn_frontend, dnn::ConvolutionKind kind,
      dnn::DataType input_type, dnn::DataType bias_type,
      dnn::DataType output_type, double conv_scale, double side_input_scale,
      double leakyrelu_alpha, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      bool use_fallback, dnn::ActivationMode activation_mode,
      std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans)
      override;

  port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>>
  FusedConvolveRunnerFromDesc(
      Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
      dnn::ConvolutionKind kind, dnn::DataType input_type,
      dnn::DataType bias_type, dnn::DataType output_type, double conv_scale,
      double side_input_scale, double leakyrelu_alpha,
      const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::ActivationMode activation_mode) override;

  bool GetRnnAlgorithms(
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardDataAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardFilterAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<float>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<Eigen::half>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<Eigen::half>& side_input,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<float>& y_backprop,
      const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
      const DeviceMemory<float>& inv_var, const DeviceMemory<float>& y,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* x_backprop,
      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
      DeviceMemory<float>* side_input_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
      const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
      const DeviceMemory<float>& inv_var, const DeviceMemory<Eigen::half>& y,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      dnn::ActivationMode activation_mode,
      DeviceMemory<Eigen::half>* x_backprop,
      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
      DeviceMemory<Eigen::half>* side_input_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;
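
  // A note on the convolution entry points that follow: algorithm selection
  // and scratch allocation happen before execution. One plausible sequence,
  // shown only as a hedged sketch (local names are assumptions; the base
  // class in dnn.h defines the real dispatch):
  //
  //   dnn::AlgorithmDesc algorithm_desc;
  //   DeviceMemory<uint8> scratch_memory;
  //   // DoPrepareForConvolution (declared in the private section below)
  //   // picks an algorithm and sizes the scratch buffer; DoConvolve then
  //   // executes with that choice:
  //   TF_RETURN_IF_ERROR(DoConvolve(kind, element_type, output_type, stream,
  //                                 input_descriptor, input_data,
  //                                 filter_descriptor, filter_data,
  //                                 output_descriptor, output_data,
  //                                 convolution_descriptor, algorithm_desc,
  //                                 scratch_memory, nullptr));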

  port::Status DoConvolve(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoFusedConvolve(
      Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type,
      dnn::DataType bias_type, dnn::DataType output_type,
      const dnn::BatchDescriptor& conv_input_descriptor,
      DeviceMemoryBase conv_input_data, double conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      DeviceMemoryBase side_input_data, double side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
      dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int16>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoSeparableConvolve(
      Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier,
      const DeviceMemory<float>& first_weights,
      const DeviceMemory<float>& second_weights,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "separable convolution not supported by CUDNN";
    return false;
  }

  bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
                const DeviceMemory<float>& weights,
                const dnn::BatchDescriptor& input_dimensions,
                const dnn::BatchDescriptor& output_dimensions,
                DeviceMemory<float>* output_data) override;

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int8>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN";
    return false;
  }

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int16>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN";
    return false;
  }

  bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
                 const DeviceMemory<float>& biases,
                 const dnn::BatchDescriptor& dimensions,
                 DeviceMemory<float>* output_data) override;

  bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode,
                  const dnn::BatchDescriptor& dimensions,
                  const DeviceMemory<float>& input_data,
                  DeviceMemory<float>* output_data, uint64_t options) override;

  port::Status DoPoolForward(dnn::DataType element_type, Stream* stream,
                             const dnn::PoolingDescriptor& pooling_dimensions,
                             const dnn::BatchDescriptor& input_dimensions,
                             DeviceMemoryBase input_data,
                             const dnn::BatchDescriptor& output_dimensions,
                             DeviceMemoryBase output_data,
                             ScratchAllocator* workspace_allocator) override;

  port::Status DoPoolBackward(dnn::DataType element_type, Stream* stream,
                              const dnn::PoolingDescriptor& pooling_dimensions,
                              const dnn::BatchDescriptor& input_dimensions,
                              DeviceMemoryBase input_data,
                              const dnn::BatchDescriptor& output_dimensions,
                              DeviceMemoryBase output_data,
                              DeviceMemoryBase input_diff_data,
                              DeviceMemoryBase output_diff_data,
                              ScratchAllocator* workspace_allocator) override;

  bool DoNormalizeWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& input_data,
      DeviceMemory<float>* output_data) override;

  bool DoNormalizeBackwardWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& raw_data,
      const DeviceMemory<float>& normalized_data,
      const DeviceMemory<float>& normalized_variable_gradient,
      DeviceMemory<float>* raw_variable_gradient,
      ScratchAllocator* workspace_allocator) override;

  bool DoDepthConcatenate(Stream* stream, BatchDescriptorSlice input_dimensions,
                          DeviceMemorySlice<float> input_data,
                          DeviceMemory<float>* output_data) override;

  bool DoElementwiseOperate(Stream* stream, dnn::ElementwiseOperation operation,
                            BatchDescriptorSlice input_dimensions,
                            DeviceMemorySlice<float> input_data,
                            const dnn::BatchDescriptor& output_dimensions,
                            DeviceMemory<float>* output_data) override;

  bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
               const DeviceMemory<float>& input_data, int64_t left_pad,
               int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
               DeviceMemory<float>* output_data) override;

  bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions,
                 const DeviceMemory<float>& input_data, int64_t left_trim,
                 int64_t right_trim, int64_t top_trim, int64_t bottom_trim,
                 DeviceMemory<float>* output_data) override;

  bool DoMemcpyD2HQuantized(Stream* stream,
                            const DeviceMemory<float>& device_unquantized_src,
                            dnn::QuantizedActivationMode mode, void* host_dst,
                            int64_t size) override;

  bool DoMemcpyH2DQuantized(
      Stream* stream, const void* host_src, int64_t size,
      dnn::QuantizedActivationMode mode,
      DeviceMemory<float>* device_unquantized_dst) override;

  // Derives an output batch descriptor from an input batch and convolution
  // descriptors.
  bool DeriveOutputBatchDescriptor(
      const dnn::BatchDescriptor& batch_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::BatchDescriptor* output_batch_descriptor);

  port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
                         const dnn::RnnStateTensorDescriptor& probs_desc,
                         const DeviceMemoryBase probs_data,
                         absl::Span<const int> labels_data,
                         absl::Span<const int> labels_lengths_data,
                         absl::Span<const int> input_lengths_data,
                         DeviceMemoryBase costs_data,
                         const dnn::RnnStateTensorDescriptor& grads_desc,
                         DeviceMemoryBase grads_data,
                         DeviceMemory<uint8> scratch_memory,
                         int ctc_loss_algo_id) override;

  bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                         dnn::DataType input_type,
                         const DeviceMemoryBase& input_data,
                         const dnn::BatchDescriptor& output_desc,
                         dnn::DataType output_type, float scale,
                         DeviceMemoryBase* output_data) override;

  void NotifyStreamDestroyed(Stream* stream) override;

 private:
  GpuExecutor* parent_;  // Parent executor object. Not owned.

  // Provides access to the cuDNN handle.
  std::unique_ptr<class CudnnAccess> cudnn_;

  template <class T, class U>
  port::Status DoBatchNormalizationForwardImpl(
      Stream* stream, dnn::DataType input_data_type,
      dnn::DataType scale_data_type, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
      const DeviceMemory<U>& estimated_mean,
      const DeviceMemory<U>& estimated_variance,
      const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
      DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
      DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator);

  template <class T, class U>
  port::Status DoBatchNormalizationBackwardImpl(
      Stream* stream, int cudnn_input_type, int cudnn_scale_type,
      const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
      const DeviceMemory<U>& mean, const DeviceMemory<U>& inv_var,
      const DeviceMemory<T>& y, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<T>* x_backprop,
      DeviceMemory<U>* scale_backprop, DeviceMemory<U>* offset_backprop,
      DeviceMemory<T>* side_input_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator);

  template <class T>
  port::Status DoRnnForwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const DeviceMemory<int>& seq_lengths_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      DeviceMemory<T>* output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      DeviceMemory<T>* output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      DeviceMemory<T>* output_c_data, bool is_training,
      ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

  template <class T>
  port::Status DoRnnBackwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const DeviceMemory<int>& seq_lengths_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      const DeviceMemory<T>& output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      const DeviceMemory<T>& output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      const DeviceMemory<T>& output_c_data,
      const DeviceMemory<T>& output_backprop_data,
      const DeviceMemory<T>& output_h_backprop_data,
      const DeviceMemory<T>& output_c_backprop_data,
      DeviceMemory<T>* input_backprop_data,
      DeviceMemory<T>* input_h_backprop_data,
      DeviceMemory<T>* input_c_backprop_data,
      DeviceMemory<T>* params_backprop_data,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

  port::Status DoCtcLossImpl(
      Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
      const CudnnRnnStateTensorDescriptor& grads_desc,
      DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);

 private:
  port::Status DoPrepareForConvolution(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::AlgorithmConfig& algorithm_config,
      ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
      DeviceMemory<uint8>* scratch_memory) override;

  port::Status DoPrepareForCtcLoss(
      Stream* stream, dnn::DataType element_type,
      const dnn::RnnStateTensorDescriptor& probs_desc,
      const dnn::RnnStateTensorDescriptor& grads_desc,
      absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data,
      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
      int* ctc_loss_algo_id) override;

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_