/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The ROCM-specific DNN library support, implementing the general DnnSupport
// interface.

#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_

#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "rocm/include/miopen/miopen.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"

namespace stream_executor {
namespace gpu {

class GpuExecutor;
class MIOpenRnnDescriptor;
class MIOpenRnnSequenceTensorDescriptor;
class MIOpenRnnStateTensorDescriptor;
class MIOpenCTCLossDescriptor;

// Opaque and unique identifier for the MIOpen plugin.
extern const PluginId kMIOpenPlugin;

struct PoolingWorkspaceDescriptor {
  std::vector<int64_t> input_dims;
  std::vector<int64_t> output_dims;
  dnn::PoolingDescriptor op;
  int dtype;
  uint64_t timestamp;
  std::unique_ptr<TemporaryDeviceMemory<uint8>> workspace;
  size_t workspace_size;
  bool IsSame(const dnn::BatchDescriptor& input_dimensions,
              const dnn::BatchDescriptor& output_dimensions,
              const dnn::PoolingDescriptor& pooling_dimensions, int _type);
};

struct PoolingWorkspaceCache {
  std::map<const void*, PoolingWorkspaceDescriptor> cache;
  const int trim_size = 1000;
  const uint64_t memory_budget = 2e7;
  uint64_t timestamp = 0;
  uint64_t memory_used = 0;
  bool find(const void* p, const dnn::BatchDescriptor& input_dimensions,
            const dnn::BatchDescriptor& output_dimensions,
            const dnn::PoolingDescriptor& pooling_dimensions, int _type,
            PoolingWorkspaceDescriptor*& pdesc);
  void insert(const void* p, const dnn::BatchDescriptor& input_dimensions,
              const dnn::BatchDescriptor& output_dimensions,
              const dnn::PoolingDescriptor& pooling_dimensions, int _type,
              std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace,
              size_t wsp_size, hipStream_t hip_stream);

 private:
  void trim(hipStream_t hip_stream);
};
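
// Added note (not in the original header): a minimal sketch of how the cache
// above is intended to be used around a pooling call. The variable names
// (input_ptr, input_desc, output_desc, pool_desc, miopen_dtype, ws, ws_size,
// hip_stream) are hypothetical placeholders; the authoritative call sites
// live in rocm_dnn.cc.
//
//   PoolingWorkspaceDescriptor* pdesc = nullptr;
//   if (cache.find(input_ptr, input_desc, output_desc, pool_desc,
//                  miopen_dtype, pdesc)) {
//     // Hit: reuse pdesc->workspace (e.g. for the backward pass).
//   } else {
//     // Miss: allocate a workspace, run the op, then publish the entry.
//     cache.insert(input_ptr, input_desc, output_desc, pool_desc,
//                  miopen_dtype, ws, ws_size, hip_stream);
//   }
//
// insert() is expected to evict stale entries via the private trim() once the
// cache grows past trim_size entries or memory_used exceeds the ~20 MB
// memory_budget.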

// MIOpen-library-based DNN support. For details on overridden interface
// functions, see dnn.h.
class MIOpenSupport : public dnn::DnnSupport {
 public:
  explicit MIOpenSupport(GpuExecutor* parent);

  port::Status Init() override;
  port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;

  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
      float dropout, uint64_t seed, ScratchAllocator* state_allocator,
      bool use_padded_io) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
                                    int data_size,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<Eigen::half>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<Eigen::half>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<Eigen::half>& input_c_data,
                    const DeviceMemory<Eigen::half>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<Eigen::half>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<Eigen::half>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<Eigen::half>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<float>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<float>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<float>& input_c_data,
                    const DeviceMemory<float>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<float>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<float>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<float>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<double>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<double>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<double>& input_c_data,
                    const DeviceMemory<double>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<double>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<double>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<double>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;
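
  // Added sketch (not in the original header): the create-then-run sequence
  // the RNN entry points above assume. Variable names are hypothetical and
  // error handling is elided; see dnn.h for the authoritative contract.
  //
  //   auto rnn_desc_or = dnn->createRnnDescriptor(
  //       num_layers, hidden_size, input_size, cell_size, batch_size,
  //       input_mode, direction_mode, rnn_mode, dnn::DataType::kFloat,
  //       algorithm_config, dropout, seed, state_allocator,
  //       /*use_padded_io=*/false);
  //   auto input_desc_or = dnn->createRnnSequenceTensorDescriptor(
  //       seq_length, batch_size, input_size, dnn::DataType::kFloat);
  //   // ...create the h/c state descriptors the same way, then:
  //   dnn->DoRnnForward(stream, *rnn_desc_or.value(), *input_desc_or.value(),
  //                     input_data, seq_lengths_data,
  //                     /*...state descriptors and buffers as declared
  //                     above...*/, /*is_training=*/true,
  //                     reserve_space_allocator, workspace_allocator,
  //                     /*output_profile_result=*/nullptr);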

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<Eigen::half>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<Eigen::half>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<Eigen::half>& input_c_data,
                     const DeviceMemory<Eigen::half>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<Eigen::half>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<Eigen::half>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<Eigen::half>& output_c_data,
                     const DeviceMemory<Eigen::half>& output_backprop_data,
                     const DeviceMemory<Eigen::half>& output_h_backprop_data,
                     const DeviceMemory<Eigen::half>& output_c_backprop_data,
                     DeviceMemory<Eigen::half>* input_backprop_data,
                     DeviceMemory<Eigen::half>* input_h_backprop_data,
                     DeviceMemory<Eigen::half>* input_c_backprop_data,
                     DeviceMemory<Eigen::half>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<float>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<float>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<float>& input_c_data,
                     const DeviceMemory<float>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<float>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<float>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<float>& output_c_data,
                     const DeviceMemory<float>& output_backprop_data,
                     const DeviceMemory<float>& output_h_backprop_data,
                     const DeviceMemory<float>& output_c_backprop_data,
                     DeviceMemory<float>* input_backprop_data,
                     DeviceMemory<float>* input_h_backprop_data,
                     DeviceMemory<float>* input_c_backprop_data,
                     DeviceMemory<float>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<double>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<double>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<double>& input_c_data,
                     const DeviceMemory<double>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<double>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<double>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<double>& output_c_data,
                     const DeviceMemory<double>& output_backprop_data,
                     const DeviceMemory<double>& output_h_backprop_data,
                     const DeviceMemory<double>& output_c_backprop_data,
                     DeviceMemory<double>* input_backprop_data,
                     DeviceMemory<double>* input_h_backprop_data,
                     DeviceMemory<double>* input_c_backprop_data,
                     DeviceMemory<double>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool GetConvolveAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  port::Status GetConvolveRunners(
      bool use_cudnn_frontend, dnn::ConvolutionKind kind,
      dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      bool use_fallback, ScratchAllocator* scratch_allocator,
      std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_runners)
      override;

  port::StatusOr<std::unique_ptr<const dnn::ConvRunner>> ConvolveRunnerFromDesc(
      Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
      dnn::ConvolutionKind kind, dnn::DataType input_type,
      dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor) override;

  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms) override;

  bool GetRnnAlgorithms(
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardDataAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardFilterAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<float>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<Eigen::half>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<Eigen::half>& side_input,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<float>& y_backprop,
      const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
      const DeviceMemory<float>& variance, const DeviceMemory<float>& y,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* x_backprop,
      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
      DeviceMemory<float>* side_input_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
      const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
      const DeviceMemory<float>& inv_var, const DeviceMemory<Eigen::half>& y,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      dnn::ActivationMode activation_mode,
      DeviceMemory<Eigen::half>* x_backprop,
      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
      DeviceMemory<Eigen::half>* side_input_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  port::Status DoConvolve(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
      dnn::ProfileResult* output_profile_result) override;
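
  // Added note (not in the original header): DoConvolve consumes the
  // algorithm_desc and scratch_memory produced earlier by the private
  // DoPrepareForConvolution hook declared further down; the dnn::DnnSupport
  // base class sequences those two steps, so callers normally reach this
  // through the generic convolution entry points rather than invoking it
  // directly.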

  port::Status DoFusedConvolve(
      Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type,
      dnn::DataType bias_type, dnn::DataType output_type,
      const dnn::BatchDescriptor& conv_input_descriptor,
      DeviceMemoryBase conv_input_data, double conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      DeviceMemoryBase side_input_data, double side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
      dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by MIOpen";
    return false;
  }

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int16>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by MIOpen";
    return false;
  }

  bool DoSeparableConvolve(
      Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier,
      const DeviceMemory<float>& first_weights,
      const DeviceMemory<float>& second_weights,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "separable convolution not supported by MIOpen";
    return false;
  }

  bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
                const DeviceMemory<float>& weights,
                const dnn::BatchDescriptor& input_dimensions,
                const dnn::BatchDescriptor& output_dimensions,
                DeviceMemory<float>* output_data) override;

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int8>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by MIOpen";
    return false;
  }

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int16>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by MIOpen";
    return false;
  }

  bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
                 const DeviceMemory<float>& biases,
                 const dnn::BatchDescriptor& dimensions,
                 DeviceMemory<float>* output_data) override;

  bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode,
                  const dnn::BatchDescriptor& dimensions,
                  const DeviceMemory<float>& input_data,
                  DeviceMemory<float>* output_data, uint64_t options) override;

  port::Status DoPoolForward(dnn::DataType element_type, Stream* stream,
                             const dnn::PoolingDescriptor& pooling_dimensions,
                             const dnn::BatchDescriptor& input_dimensions,
                             DeviceMemoryBase input_data,
                             const dnn::BatchDescriptor& output_dimensions,
                             DeviceMemoryBase output_data,
                             ScratchAllocator* workspace_allocator) override;

  port::Status DoPoolBackward(dnn::DataType element_type, Stream* stream,
                              const dnn::PoolingDescriptor& pooling_dimensions,
                              const dnn::BatchDescriptor& input_dimensions,
                              DeviceMemoryBase input_data,
                              const dnn::BatchDescriptor& output_dimensions,
                              DeviceMemoryBase output_data,
                              DeviceMemoryBase input_diff_data,
                              DeviceMemoryBase output_diff_data,
                              ScratchAllocator* workspace_allocator) override;

  bool DoNormalizeWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& input_data,
      DeviceMemory<float>* output_data) override;

  bool DoNormalizeBackwardWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& raw_data,
      const DeviceMemory<float>& normalized_data,
      const DeviceMemory<float>& normalized_variable_gradient,
      DeviceMemory<float>* raw_variable_gradient,
      ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoDepthConcatenate(
      Stream* stream, absl::Span<const dnn::BatchDescriptor> input_dimensions,
      absl::Span<const DeviceMemory<float>* const> input_data,
      DeviceMemory<float>* output_data) override;

  bool DoElementwiseOperate(
      Stream* stream, dnn::ElementwiseOperation operation,
      absl::Span<const dnn::BatchDescriptor> input_dimensions,
      absl::Span<const DeviceMemory<float>* const> input_data,
      const dnn::BatchDescriptor& output_dimensions,
      DeviceMemory<float>* output_data) override;

  bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
               const DeviceMemory<float>& input_data, int64_t left_pad,
               int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
               DeviceMemory<float>* output_data) override;

  bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions,
                 const DeviceMemory<float>& input_data, int64_t left_trim,
                 int64_t right_trim, int64_t top_trim, int64_t bottom_trim,
                 DeviceMemory<float>* output_data) override;

  bool DoMemcpyD2HQuantized(Stream* stream,
                            const DeviceMemory<float>& device_unquantized_src,
                            dnn::QuantizedActivationMode mode, void* host_dst,
                            int64_t size) override;

  bool DoMemcpyH2DQuantized(
      Stream* stream, const void* host_src, int64_t size,
      dnn::QuantizedActivationMode mode,
      DeviceMemory<float>* device_unquantized_dst) override;

  // Derives an output batch descriptor from an input batch and convolution
  // descriptors.
  bool DeriveOutputBatchDescriptor(
      const dnn::BatchDescriptor& batch_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::BatchDescriptor* output_batch_descriptor);
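
  // Added sketch (not in the original header): typical use of the helper
  // above; variable names are hypothetical.
  //
  //   dnn::BatchDescriptor output_desc;
  //   if (dnn->DeriveOutputBatchDescriptor(input_desc, filter_desc, conv_desc,
  //                                        &output_desc)) {
  //     // output_desc now carries the inferred output dims; for a standard
  //     // convolution each spatial dim follows
  //     //   out = (in + 2 * pad - filter) / stride + 1.
  //   }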

  bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                         dnn::DataType input_type,
                         const DeviceMemoryBase& input_data,
                         const dnn::BatchDescriptor& output_desc,
                         dnn::DataType output_type, float scale,
                         DeviceMemoryBase* output_data) override;

  bool DoFusedConvolutionBiasActivation(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<float>& conv_input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<float>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationInference(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<float>& x_data,
      const dnn::BatchDescriptor& scale_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& mean_data,
      const DeviceMemory<float>& variance_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationInference(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<Eigen::half>& x_data,
      const dnn::BatchDescriptor& scale_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& mean_data,
      const DeviceMemory<float>& variance_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationForward(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<float>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
      DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
      DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationForward(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<Eigen::half>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
      DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
      DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationBackward(
      Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
      const DeviceMemory<float>& y_act_backprop_data,
      const DeviceMemory<float>& y_act_data,
      dnn::ActivationMode activation_mode, const DeviceMemory<float>& x_bn_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& saved_mean_data,
      const DeviceMemory<float>& saved_var_data,
      DeviceMemory<float>* x_bn_backprop_data,
      DeviceMemory<float>* scale_backprop_data,
      DeviceMemory<float>* offset_backprop_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationBackward(
      Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
      const DeviceMemory<Eigen::half>& y_act_backprop_data,
      const DeviceMemory<Eigen::half>& y_act_data,
      dnn::ActivationMode activation_mode,
      const DeviceMemory<Eigen::half>& x_bn_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& saved_mean_data,
      const DeviceMemory<float>& saved_var_data,
      DeviceMemory<Eigen::half>* x_bn_backprop_data,
      DeviceMemory<float>* scale_backprop_data,
      DeviceMemory<float>* offset_backprop_data,
      dnn::ProfileResult* output_profile_result) override;

  GpuExecutor* GetParentExecutor() { return parent_; }

  port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
                         const dnn::RnnStateTensorDescriptor& probs_desc,
                         const DeviceMemoryBase probs_data,
                         absl::Span<const int> labels_data,
                         absl::Span<const int> labels_lengths_data,
                         absl::Span<const int> input_lengths_data,
                         DeviceMemoryBase costs_data,
                         const dnn::RnnStateTensorDescriptor& grads_desc,
                         DeviceMemoryBase grads_data,
                         DeviceMemory<uint8> scratch_memory,
                         int ctc_loss_algo_id) override;

 private:
  GpuExecutor* parent_;  // Parent executor object. Not owned.

  // Flag to indicate whether Get*Algorithm routines should only return
  // the best algorithm (as opposed to a list of all applicable ones)
  bool return_best_algo_only_;

  // Flag to indicate whether to use Immediate (or Find) mode for Convolutions
  bool use_immediate_mode_;
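
  // Added note (not in the original header): both flags above are set once
  // during Init(). In the accompanying rocm_dnn.cc they appear to be driven
  // by environment variables (TF_ROCM_RETURN_BEST_ALGO_ONLY and
  // TF_ROCM_USE_IMMEDIATE_MODE); treat those exact names as an assumption
  // when reading this header in isolation.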

  // Provide access to the MIOpen handle.
  std::unique_ptr<class MIOpenAccess> miopen_;

  PoolingWorkspaceCache m_pooling_cache;
  bool m_pooling_cache_allowed = false;
  bool m_pooling_cache_enabled = false;

  template <class T, class U>
  bool DoBatchNormalizationForwardImpl(
      Stream* stream, dnn::DataType input_data_type,
      dnn::DataType scale_data_type, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
      const DeviceMemory<U>& estimated_mean,
      const DeviceMemory<U>& estimated_variance,
      const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
      DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
      DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
      bool is_training);

  template <class T, class U>
  bool DoBatchNormalizationBackwardImpl(
      Stream* stream, int miopen_input_type, int miopen_scale_type,
      const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
      const DeviceMemory<U>& variance, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
      DeviceMemory<U>* offset_backprop);

  template <class T>
  bool DoRnnForwardImpl(Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
                        const MIOpenRnnSequenceTensorDescriptor& input_desc,
                        const DeviceMemory<T>& input_data,
                        const MIOpenRnnStateTensorDescriptor& input_h_desc,
                        const DeviceMemory<T>& input_h_data,
                        const MIOpenRnnStateTensorDescriptor& input_c_desc,
                        const DeviceMemory<T>& input_c_data,
                        const DeviceMemory<T>& params,
                        const MIOpenRnnSequenceTensorDescriptor& output_desc,
                        DeviceMemory<T>* output_data,
                        const MIOpenRnnStateTensorDescriptor& output_h_desc,
                        DeviceMemory<T>* output_h_data,
                        const MIOpenRnnStateTensorDescriptor& output_c_desc,
                        DeviceMemory<T>* output_c_data, bool is_training,
                        ScratchAllocator* reserve_space_allocator,
                        ScratchAllocator* workspace_allocator);
  template <class T>
  bool DoRnnBackwardImpl(Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
                         const MIOpenRnnSequenceTensorDescriptor& input_desc,
                         const DeviceMemory<T>& input_data,
                         const MIOpenRnnStateTensorDescriptor& input_h_desc,
                         const DeviceMemory<T>& input_h_data,
                         const MIOpenRnnStateTensorDescriptor& input_c_desc,
                         const DeviceMemory<T>& input_c_data,
                         const DeviceMemory<T>& params,
                         const MIOpenRnnSequenceTensorDescriptor& output_desc,
                         const DeviceMemory<T>& output_data,
                         const MIOpenRnnStateTensorDescriptor& output_h_desc,
                         const DeviceMemory<T>& output_h_data,
                         const MIOpenRnnStateTensorDescriptor& output_c_desc,
                         const DeviceMemory<T>& output_c_data,
                         const DeviceMemory<T>& output_backprop_data,
                         const DeviceMemory<T>& output_h_backprop_data,
                         const DeviceMemory<T>& output_c_backprop_data,
                         DeviceMemory<T>* input_backprop_data,
                         DeviceMemory<T>* input_h_backprop_data,
                         DeviceMemory<T>* input_c_backprop_data,
                         DeviceMemory<T>* params_backprop_data,
                         DeviceMemory<uint8>* reserve_space_data,
                         ScratchAllocator* workspace_allocator);

  template <typename T>
  bool DoFusedConvolutionBiasActivationImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<T>& conv_input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<T>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<T>& bias_data, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<T>* output_data, dnn::ProfileResult* output_profile_result);

  template <typename T, typename U>
  bool DoFusedBatchNormActivationInferenceImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
      const DeviceMemory<U>& mean_data, const DeviceMemory<U>& variance_data,
      double epsilon, dnn::ActivationMode activation_mode,
      DeviceMemory<T>* y_data, dnn::ProfileResult* output_profile_result);

  template <typename T, typename U>
  bool DoFusedBatchNormActivationForwardImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
      double epsilon, dnn::ActivationMode activation_mode,
      DeviceMemory<T>* y_data, DeviceMemory<U>* batch_mean_data,
      DeviceMemory<U>* batch_var_data, DeviceMemory<U>* saved_mean_data,
      DeviceMemory<U>* saved_var_data,
      dnn::ProfileResult* output_profile_result);

  template <typename T, typename U>
  bool DoFusedBatchNormActivationBackwardImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& y_act_backprop_descriptor,
      const DeviceMemory<T>& y_act_backprop_data,
      const DeviceMemory<T>& y_act_data, dnn::ActivationMode activation_mode,
      const DeviceMemory<T>& x_bn_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
      const DeviceMemory<U>& saved_mean_data,
      const DeviceMemory<U>& saved_var_data,
      DeviceMemory<T>* x_bn_backprop_data, DeviceMemory<U>* scale_backprop_data,
      DeviceMemory<U>* offset_backprop_data,
      dnn::ProfileResult* output_profile_result);

  port::Status DoPrepareForConvolution(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::AlgorithmConfig& algorithm_config,
      ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
      DeviceMemory<uint8>* scratch_memory) override;

  port::Status DoCtcLossImpl(
      Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
      const MIOpenRnnStateTensorDescriptor& grads_desc,
      DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);

  port::Status DoPrepareForCtcLoss(
      Stream* stream, dnn::DataType element_type,
      const dnn::RnnStateTensorDescriptor& probs_desc,
      const dnn::RnnStateTensorDescriptor& grads_desc,
      absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data,
      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
      int* ctc_loss_algo_id) override;

  bool GetMIOpenConvolveAlgorithmsImmediateMode(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms);

  bool GetMIOpenConvolveAlgorithmsFindMode(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms);

  SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport);
};
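
// Added sketch (not in the original header): MIOpenSupport is not constructed
// directly by clients. It is registered with the plugin registry under
// kMIOpenPlugin and reached through the generic DnnSupport accessor, roughly:
//
//   stream_executor::StreamExecutor* executor = ...;  // a ROCm executor
//   dnn::DnnSupport* dnn = executor->AsDnn();
//   if (dnn != nullptr) {
//     // On ROCm builds this is backed by MIOpenSupport.
//   }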

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_