/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The CUDA-specific DNN library support, implementing the general DnnSupport
// interface.

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_

#include "absl/base/thread_annotations.h"
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/compiler/xla/stream_executor/dnn.h"
#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
#include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
#include "tensorflow/compiler/xla/stream_executor/temporary_device_memory.h"

namespace stream_executor {
namespace gpu {

class GpuExecutor;
class CudnnRnnDescriptor;
class CudnnRnnSequenceTensorDescriptor;
class CudnnRnnStateTensorDescriptor;
class CudnnCtcLossDescriptor;

// Opaque and unique identifier for the cuDNN plugin.
extern const PluginId kCuDnnPlugin;

using BatchDescriptorSlice =
    port::ArraySlice<dnn::BatchDescriptor>;  // non-absl ok

template <typename T>
using DeviceMemorySlice =
    port::ArraySlice<const DeviceMemory<T>*>;  // non-absl ok
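
// Minimal access sketch (illustrative only; `stream` is a hypothetical
// Stream* owned by the caller, and the exact accessor is an assumption, not
// something this header guarantees). Client code normally reaches this class
// through the generic dnn::DnnSupport interface rather than by naming
// CudnnSupport directly:
//
//   stream_executor::dnn::DnnSupport* dnn = stream->parent()->AsDnn();
//   if (dnn != nullptr) {
//     // On a CUDA-backed StreamExecutor this dispatches to the
//     // CudnnSupport overrides declared below.
//   }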

// cudnn-library based DNN support. For details on overridden interface
// functions, see dnn.h.
class CudnnSupport : public dnn::DnnSupport {
 public:
  explicit CudnnSupport(GpuExecutor* parent);

  port::Status Init() override;
  port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;

  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
      float dropout, uint64_t seed, ScratchAllocator* state_allocator,
      bool use_padded_io) override;
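
  // Illustrative call sketch (all values below are assumptions made for the
  // example, not defaults of this API): a single-layer unidirectional LSTM
  // descriptor might be requested as
  //
  //   auto rnn_desc_or = cudnn->createRnnDescriptor(
  //       /*num_layers=*/1, /*hidden_size=*/256, /*input_size=*/128,
  //       /*cell_size=*/256, /*batch_size=*/32,
  //       dnn::RnnInputMode::kRnnLinearInput,
  //       dnn::RnnDirectionMode::kRnnUnidirectional, dnn::RnnMode::kRnnLstm,
  //       dnn::DataType::kFloat, dnn::AlgorithmConfig(), /*dropout=*/0.0f,
  //       /*seed=*/0, state_allocator, /*use_padded_io=*/false);
  //
  // and the resulting dnn::RnnDescriptor is then threaded through the
  // DoRnnForward / DoRnnBackward overloads declared below.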

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int>& seq_lengths,
                                    bool time_major,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<Eigen::half>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<Eigen::half>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<Eigen::half>& input_c_data,
                    const DeviceMemory<Eigen::half>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<Eigen::half>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<Eigen::half>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<Eigen::half>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<float>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<float>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<float>& input_c_data,
                    const DeviceMemory<float>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<float>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<float>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<float>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<double>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<double>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<double>& input_c_data,
                    const DeviceMemory<double>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<double>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<double>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<double>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<Eigen::half>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<Eigen::half>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<Eigen::half>& input_c_data,
                     const DeviceMemory<Eigen::half>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<Eigen::half>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<Eigen::half>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<Eigen::half>& output_c_data,
                     const DeviceMemory<Eigen::half>& output_backprop_data,
                     const DeviceMemory<Eigen::half>& output_h_backprop_data,
                     const DeviceMemory<Eigen::half>& output_c_backprop_data,
                     DeviceMemory<Eigen::half>* input_backprop_data,
                     DeviceMemory<Eigen::half>* input_h_backprop_data,
                     DeviceMemory<Eigen::half>* input_c_backprop_data,
                     DeviceMemory<Eigen::half>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<float>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<float>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<float>& input_c_data,
                     const DeviceMemory<float>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<float>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<float>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<float>& output_c_data,
                     const DeviceMemory<float>& output_backprop_data,
                     const DeviceMemory<float>& output_h_backprop_data,
                     const DeviceMemory<float>& output_c_backprop_data,
                     DeviceMemory<float>* input_backprop_data,
                     DeviceMemory<float>* input_h_backprop_data,
                     DeviceMemory<float>* input_c_backprop_data,
                     DeviceMemory<float>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<double>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<double>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<double>& input_c_data,
                     const DeviceMemory<double>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<double>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<double>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<double>& output_c_data,
                     const DeviceMemory<double>& output_backprop_data,
                     const DeviceMemory<double>& output_h_backprop_data,
                     const DeviceMemory<double>& output_c_backprop_data,
                     DeviceMemory<double>* input_backprop_data,
                     DeviceMemory<double>* input_h_backprop_data,
                     DeviceMemory<double>* input_c_backprop_data,
                     DeviceMemory<double>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool GetConvolveAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  port::Status GetConvolveRunners(
      bool use_cudnn_frontend, dnn::ConvolutionKind kind,
      dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      bool use_fallback, ScratchAllocator* scratch_allocator,
      std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_exec_plans)
      override;

  port::StatusOr<std::unique_ptr<const dnn::ConvRunner>> ConvolveRunnerFromDesc(
      Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
      dnn::ConvolutionKind kind, dnn::DataType input_type,
      dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor) override;

  port::Status GetFusedConvolveRunners(
      bool use_cudnn_frontend, dnn::ConvolutionKind kind,
      dnn::DataType input_type, dnn::DataType bias_type,
      dnn::DataType output_type, double conv_scale, double side_input_scale,
      double leakyrelu_alpha, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      bool use_fallback, dnn::ActivationMode activation_mode,
      std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans)
      override;

  port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>>
  FusedConvolveRunnerFromDesc(
      Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
      dnn::ConvolutionKind kind, dnn::DataType input_type,
      dnn::DataType bias_type, dnn::DataType output_type, double conv_scale,
      double side_input_scale, double leakyrelu_alpha,
      const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::ActivationMode activation_mode) override;

  bool GetRnnAlgorithms(
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardDataAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardFilterAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<float>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<Eigen::half>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<Eigen::half>& side_input,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;
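
  // Call-pattern sketch (assumed semantics for illustration; see dnn.h for
  // the authoritative contract): with is_training true, the batch statistics
  // outputs (batch_mean, batch_var, saved_mean, saved_inv_var) are populated
  // and the running estimates are updated with weight
  // exponential_average_factor; with is_training false, estimated_mean and
  // estimated_variance are consumed directly. A training-mode call might
  // look roughly like:
  //
  //   dnn->DoBatchNormalizationForward(
  //       stream, x, scale, offset, estimated_mean, estimated_variance,
  //       side_input, x_desc, scale_offset_desc, /*epsilon=*/1e-3,
  //       /*exponential_average_factor=*/1.0, dnn::ActivationMode::kNone,
  //       &y, &batch_mean, &batch_var, &saved_mean, &saved_inv_var,
  //       /*is_training=*/true, reserve_space_allocator, workspace_allocator);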

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<float>& y_backprop,
      const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
      const DeviceMemory<float>& inv_var, const DeviceMemory<float>& y,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* x_backprop,
      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
      DeviceMemory<float>* side_input_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
      const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
      const DeviceMemory<float>& inv_var, const DeviceMemory<Eigen::half>& y,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      dnn::ActivationMode activation_mode,
      DeviceMemory<Eigen::half>* x_backprop,
      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
      DeviceMemory<Eigen::half>* side_input_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  port::Status DoConvolve(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
      dnn::ProfileResult* output_profile_result) override;
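
  // Descriptor-setup sketch (hypothetical shapes chosen for the example; the
  // descriptor APIs live in dnn.h). A caller typically fills in the batch,
  // filter, and convolution descriptors before invoking DoConvolve, roughly:
  //
  //   dnn::BatchDescriptor input_desc;
  //   input_desc.set_count(32).set_feature_map_count(64)
  //       .set_height(28).set_width(28)
  //       .set_layout(dnn::DataLayout::kBatchDepthYX);
  //   dnn::FilterDescriptor filter_desc;
  //   filter_desc.set_output_feature_map_count(128)
  //       .set_input_feature_map_count(64)
  //       .set_input_filter_height(3).set_input_filter_width(3);
  //   dnn::ConvolutionDescriptor conv_desc;
  //   conv_desc.set_zero_padding_height(1).set_zero_padding_width(1)
  //       .set_vertical_filter_stride(1).set_horizontal_filter_stride(1);
  //
  // together with an output descriptor (see DeriveOutputBatchDescriptor
  // below), a previously selected dnn::AlgorithmDesc, and scratch memory
  // sized by DoPrepareForConvolution.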

  port::Status DoFusedConvolve(
      Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type,
      dnn::DataType bias_type, dnn::DataType output_type,
      const dnn::BatchDescriptor& conv_input_descriptor,
      DeviceMemoryBase conv_input_data, double conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      DeviceMemoryBase side_input_data, double side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
      dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int16>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoSeparableConvolve(
      Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier,
      const DeviceMemory<float>& first_weights,
      const DeviceMemory<float>& second_weights,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "separable convolution not supported by CUDNN";
    return false;
  }

  bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
                const DeviceMemory<float>& weights,
                const dnn::BatchDescriptor& input_dimensions,
                const dnn::BatchDescriptor& output_dimensions,
                DeviceMemory<float>* output_data) override;

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int8>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN";
    return false;
  }

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int16>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN";
    return false;
  }

  bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
                 const DeviceMemory<float>& biases,
                 const dnn::BatchDescriptor& dimensions,
                 DeviceMemory<float>* output_data) override;

  bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode,
                  const dnn::BatchDescriptor& dimensions,
                  const DeviceMemory<float>& input_data,
                  DeviceMemory<float>* output_data, uint64_t options) override;

  port::Status DoPoolForward(dnn::DataType element_type, Stream* stream,
                             const dnn::PoolingDescriptor& pooling_dimensions,
                             const dnn::BatchDescriptor& input_dimensions,
                             DeviceMemoryBase input_data,
                             const dnn::BatchDescriptor& output_dimensions,
                             DeviceMemoryBase output_data,
                             ScratchAllocator* workspace_allocator) override;

  port::Status DoPoolBackward(dnn::DataType element_type, Stream* stream,
                              const dnn::PoolingDescriptor& pooling_dimensions,
                              const dnn::BatchDescriptor& input_dimensions,
                              DeviceMemoryBase input_data,
                              const dnn::BatchDescriptor& output_dimensions,
                              DeviceMemoryBase output_data,
                              DeviceMemoryBase input_diff_data,
                              DeviceMemoryBase output_diff_data,
                              ScratchAllocator* workspace_allocator) override;

  bool DoNormalizeWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& input_data,
      DeviceMemory<float>* output_data) override;

  bool DoNormalizeBackwardWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& raw_data,
      const DeviceMemory<float>& normalized_data,
      const DeviceMemory<float>& normalized_variable_gradient,
      DeviceMemory<float>* raw_variable_gradient,
      ScratchAllocator* workspace_allocator) override;

  bool DoDepthConcatenate(Stream* stream, BatchDescriptorSlice input_dimensions,
                          DeviceMemorySlice<float> input_data,
                          DeviceMemory<float>* output_data) override;

  bool DoElementwiseOperate(Stream* stream, dnn::ElementwiseOperation operation,
                            BatchDescriptorSlice input_dimensions,
                            DeviceMemorySlice<float> input_data,
                            const dnn::BatchDescriptor& output_dimensions,
                            DeviceMemory<float>* output_data) override;

  bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
               const DeviceMemory<float>& input_data, int64_t left_pad,
               int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
               DeviceMemory<float>* output_data) override;

  bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions,
                 const DeviceMemory<float>& input_data, int64_t left_trim,
                 int64_t right_trim, int64_t top_trim, int64_t bottom_trim,
                 DeviceMemory<float>* output_data) override;

  bool DoMemcpyD2HQuantized(Stream* stream,
                            const DeviceMemory<float>& device_unquantized_src,
                            dnn::QuantizedActivationMode mode, void* host_dst,
                            int64_t size) override;

  bool DoMemcpyH2DQuantized(
      Stream* stream, const void* host_src, int64_t size,
      dnn::QuantizedActivationMode mode,
      DeviceMemory<float>* device_unquantized_dst) override;

  // Derives an output batch descriptor from an input batch and convolution
  // descriptors.
  bool DeriveOutputBatchDescriptor(
      const dnn::BatchDescriptor& batch_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::BatchDescriptor* output_batch_descriptor);
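
  // Illustrative sketch (continuing the hypothetical descriptors from the
  // DoConvolve example above): the output descriptor can be derived rather
  // than hand-computed, e.g.
  //
  //   dnn::BatchDescriptor output_desc;
  //   if (!cudnn->DeriveOutputBatchDescriptor(input_desc, filter_desc,
  //                                           conv_desc, &output_desc)) {
  //     // Fall back to computing the output dimensions manually.
  //   }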

  port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
                         const dnn::RnnStateTensorDescriptor& probs_desc,
                         const DeviceMemoryBase probs_data,
                         absl::Span<const int> labels_data,
                         absl::Span<const int> labels_lengths_data,
                         absl::Span<const int> input_lengths_data,
                         DeviceMemoryBase costs_data,
                         const dnn::RnnStateTensorDescriptor& grads_desc,
                         DeviceMemoryBase grads_data,
                         DeviceMemory<uint8> scratch_memory,
                         int ctc_loss_algo_id) override;
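
  // Shape sketch (assumed conventions inherited from the generic dnn.h
  // interface, not additional guarantees of this header): probs_desc is
  // expected to describe a time-major [max_time, batch_size, num_classes]
  // activation tensor; labels_lengths_data and input_lengths_data carry one
  // entry per batch element, with labels_data holding the concatenated label
  // sequences; scratch_memory and ctc_loss_algo_id come from a preceding
  // DoPrepareForCtcLoss call (declared below).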

  bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                         dnn::DataType input_type,
                         const DeviceMemoryBase& input_data,
                         const dnn::BatchDescriptor& output_desc,
                         dnn::DataType output_type, float scale,
                         DeviceMemoryBase* output_data) override;

  void NotifyStreamDestroyed(Stream* stream) override;

 private:
  GpuExecutor* parent_;  // Parent executor object. Not owned.

  // Provides access to the cuDNN handle.
  std::unique_ptr<class CudnnAccess> cudnn_;

  template <class T, class U>
  port::Status DoBatchNormalizationForwardImpl(
      Stream* stream, dnn::DataType input_data_type,
      dnn::DataType scale_data_type, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
      const DeviceMemory<U>& estimated_mean,
      const DeviceMemory<U>& estimated_variance,
      const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
      DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
      DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator);

  template <class T, class U>
  port::Status DoBatchNormalizationBackwardImpl(
      Stream* stream, int cudnn_input_type, int cudnn_scale_type,
      const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
      const DeviceMemory<U>& mean, const DeviceMemory<U>& inv_var,
      const DeviceMemory<T>& y, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<T>* x_backprop,
      DeviceMemory<U>* scale_backprop, DeviceMemory<U>* offset_backprop,
      DeviceMemory<T>* side_input_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator);

  template <class T>
  port::Status DoRnnForwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const DeviceMemory<int>& seq_lengths_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      DeviceMemory<T>* output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      DeviceMemory<T>* output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      DeviceMemory<T>* output_c_data, bool is_training,
      ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

  template <class T>
  port::Status DoRnnBackwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const DeviceMemory<int>& seq_lengths_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      const DeviceMemory<T>& output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      const DeviceMemory<T>& output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      const DeviceMemory<T>& output_c_data,
      const DeviceMemory<T>& output_backprop_data,
      const DeviceMemory<T>& output_h_backprop_data,
      const DeviceMemory<T>& output_c_backprop_data,
      DeviceMemory<T>* input_backprop_data,
      DeviceMemory<T>* input_h_backprop_data,
      DeviceMemory<T>* input_c_backprop_data,
      DeviceMemory<T>* params_backprop_data,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

  port::Status DoCtcLossImpl(
      Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
      const CudnnRnnStateTensorDescriptor& grads_desc,
      DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);

 private:
  port::Status DoPrepareForConvolution(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::AlgorithmConfig& algorithm_config,
      ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
      DeviceMemory<uint8>* scratch_memory) override;

  port::Status DoPrepareForCtcLoss(
      Stream* stream, dnn::DataType element_type,
      const dnn::RnnStateTensorDescriptor& probs_desc,
      const dnn::RnnStateTensorDescriptor& grads_desc,
      absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data,
      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
      int* ctc_loss_algo_id) override;

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_