xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/conv_ops_gpu.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------*/

#include "tensorflow/core/kernels/conv_ops_gpu.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

#if GOOGLE_CUDA
namespace {

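// Runs each candidate runner once through `launch_func` and returns one
// AutotuneResult per runner, in the same order as `runners`. When
// `actually_do_autotune` is false, nothing is launched and a zero-time
// placeholder result is recorded for each runner instead.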
template <typename LaunchFunc, typename Sig>
StatusOr<std::vector<tensorflow::AutotuneResult>> AutotuneConvImpl(
    OpKernelContext* ctx,
    std::vector<std::unique_ptr<const se::dnn::OpRunner<Sig>>>& runners,
    bool actually_do_autotune, const LaunchFunc& launch_func,
    size_t scratch_size_limit, const se::RedzoneAllocator& rz_allocator) {
  auto* stream = ctx->op_device_context()->stream();

  se::TfAllocatorAdapter tf_allocator_adapter(ctx->device()->GetAllocator({}),
                                              stream);

  std::vector<tensorflow::AutotuneResult> results;
  // TODO(reedwm): Warn if determinism is enabled after autotune is run
  for (auto& runner : runners) {
    // TODO(zhengxq): profile each algorithm multiple times for better
    // accuracy.
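    // Unless redzone checks are disabled, scratch memory is handed out by a
    // redzone-instrumented allocator so that out-of-bounds writes by a
    // candidate engine can be detected and recorded in its AutotuneResult.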
    se::RedzoneAllocator rz_scratch_allocator(
        stream, &tf_allocator_adapter, se::GpuAsmOpts(),
        /*memory_limit=*/scratch_size_limit);
    DnnScratchAllocator scratch_allocator(scratch_size_limit, ctx);
    se::ScratchAllocator* allocator_used =
        !RedzoneCheckDisabled()
            ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
            : static_cast<se::ScratchAllocator*>(&scratch_allocator);

    TF_ASSIGN_OR_RETURN(auto desc, runner->ToAlgorithmDesc());
    se::dnn::ProfileResult profile_result;
    Status cudnn_launch_status =
        actually_do_autotune
            ? launch_func(allocator_used, runner, &profile_result)
            : OkStatus();
    if (!actually_do_autotune) {
      // Make the result valid according to `is_valid`.
      profile_result.set_algorithm(desc);
      profile_result.set_elapsed_time_in_ms(0);
    }

    // We need to make sure the profiling results are one-to-one with the
    // "runners". So, we insert dummy results when the execution fails.
    results.emplace_back();
    auto& result = results.back();
    *result.mutable_algorithm() = desc.ToProto();
    if (cudnn_launch_status.ok() && profile_result.is_valid()) {
      result.set_scratch_bytes(
          !RedzoneCheckDisabled()
              ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
              : scratch_allocator.TotalByteSize());
      *result.mutable_run_time() = proto_utils::ToDurationProto(
          absl::Milliseconds(profile_result.elapsed_time_in_ms()));

      CheckRedzones(rz_scratch_allocator, &result);
      CheckRedzones(rz_allocator, &result);
    } else {
      result.mutable_failure()->set_kind(AutotuneResult::UNKNOWN);
      result.mutable_failure()->set_msg(
          absl::StrCat("Profiling failure on CUDNN engine ", desc.ToString(),
                       ": ", cudnn_launch_status.ToString()));
    }
  }

  return results;
}
}  // namespace
#endif  // GOOGLE_CUDA

// Finds the best convolution algorithm for the given ConvLaunch (CUDA
// convolution on the stream) and parameters, by running all possible
// algorithms and measuring execution time.
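// A caller in a fused-convolution kernel would invoke this roughly as below
// (a sketch only; the surrounding kernel state and helper variables such as
// `autotune_map` and `params` are illustrative, not defined in this file):
//
//   TF_ASSIGN_OR_RETURN(
//       auto autotune_entry,
//       AutotuneFusedConv<float>(
//           cudnn_use_autotune, autotune_map, params, ctx, input_desc,
//           filter_desc, bias_desc, output_desc, conv_desc, activation_mode,
//           conv_scale, side_input_scale, leakyrelu_alpha, input_ptr,
//           filter_ptr, output_ptr, bias_ptr, side_input_ptr,
//           scratch_size_limit));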
template <typename T>
StatusOr<AutotuneEntry<se::dnn::FusedConvOp>> AutotuneFusedConv(
    bool cudnn_use_autotune,
    AutotuneMap<ConvParameters, AutotuneEntry<se::dnn::FusedConvOp>>*
        autotune_map,
    const ConvParameters& params, OpKernelContext* ctx,
    const se::dnn::BatchDescriptor& input_desc,
    const se::dnn::FilterDescriptor& filter_desc,
    const se::dnn::BatchDescriptor& bias_desc,
    const se::dnn::BatchDescriptor& output_desc,
    const se::dnn::ConvolutionDescriptor& conv_desc,
    const se::dnn::ActivationMode activation_mode, double conv_scale,
    double side_input_scale, double leakyrelu_alpha,
    se::DeviceMemory<T> input_ptr, se::DeviceMemory<T> filter_ptr,
    se::DeviceMemory<T> output_ptr, se::DeviceMemory<T> bias_ptr,
    se::DeviceMemory<T> side_input_ptr, int64_t scratch_size_limit) {
#if GOOGLE_CUDA
  AutotuneEntry<se::dnn::FusedConvOp> autotune_entry;
  auto* stream = ctx->op_device_context()->stream();

  if (!autotune_map->Find(params, &autotune_entry)) {
    profiler::ScopedAnnotation trace("cudnn_autotuning");

    se::TfAllocatorAdapter tf_allocator_adapter(ctx->device()->GetAllocator({}),
                                                stream);
    se::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
                                      se::GpuAsmOpts());
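    // Wrap the output buffer in redzones (best effort) so that candidate
    // engines that write out of bounds during autotuning can be flagged by
    // CheckRedzones.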
    se::DeviceMemory<T> output_ptr_rz(
        WrapRedzoneBestEffort(&rz_allocator, output_ptr));

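    // Ask cuDNN for the heuristics list of fused-convolution engines; the
    // fallback list is only requested later if none of these engines work.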
    std::vector<std::unique_ptr<const se::dnn::FusedConvRunner>> runners;
    auto element_type = se::dnn::ToDataType<T>::value;
    TF_RETURN_IF_ERROR(stream->parent()->GetFusedConvolveRunners(
        CudnnUseFrontend(), se::dnn::ConvolutionKind::FORWARD, element_type,
        element_type, element_type, conv_scale, side_input_scale,
        leakyrelu_alpha, stream, input_desc, filter_desc, bias_desc,
        output_desc, conv_desc, /*use_fallback=*/false, activation_mode,
        &runners));

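    // Runs one candidate runner, allocating its workspace from the given
    // scratch allocator and recording timing into `profile_result`.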
    auto launch_func =
        [&](se::ScratchAllocator* allocator_used,
            const std::unique_ptr<const se::dnn::FusedConvRunner>& runner,
            se::dnn::ProfileResult* profile_result) -> Status {
      TF_ASSIGN_OR_RETURN(auto scratch, allocator_used->AllocateBytes(
                                            runner->GetWorkspaceSize()));
      return (*runner)(stream, profile_result, scratch, input_ptr, filter_ptr,
                       side_input_ptr, bias_ptr, output_ptr_rz);
    };

    TF_ASSIGN_OR_RETURN(
        auto results,
        AutotuneConvImpl(ctx, runners, cudnn_use_autotune, launch_func,
                         scratch_size_limit, rz_allocator));
    // Only log on an AutotuneConv cache miss.
    LogFusedConvForwardAutotuneResults(
        se::dnn::ToDataType<T>::value, input_ptr, filter_ptr, output_ptr,
        bias_ptr, side_input_ptr, input_desc, filter_desc, output_desc,
        conv_desc, conv_scale, side_input_scale, activation_mode,
        stream->parent(), results);

    // Two-level autotuning: Cudnn frontend supports two engine lists:
    // heuristics and fallback. Heuristics engines are normally faster.
    // To reduce autotuning time, we evaluate the fallback engines only when
    // none of the heuristics engines work.
    bool found_working_engine = false;
    for (auto& result : results) {
      if (!result.has_failure()) {
        found_working_engine = true;
        break;
      }
    }

    if (!CudnnUseFrontend() || found_working_engine) {
      TF_ASSIGN_OR_RETURN(autotune_entry,
                          BestCudnnConvAlgorithm<se::dnn::FusedConvOp>(
                              results, std::move(runners)));
    } else {
      LOG(WARNING)
          << "None of the algorithms provided by cuDNN frontend heuristics "
             "worked; trying fallback algorithms.  Conv: "
          << params.ToString();
      std::vector<std::unique_ptr<const se::dnn::FusedConvRunner>>
          fallback_runners;
      TF_RETURN_IF_ERROR(stream->parent()->GetFusedConvolveRunners(
          CudnnUseFrontend(), se::dnn::ConvolutionKind::FORWARD, element_type,
          element_type, element_type, conv_scale, side_input_scale,
          leakyrelu_alpha, stream, input_desc, filter_desc, bias_desc,
          output_desc, conv_desc, /*use_fallback=*/true, activation_mode,
          &fallback_runners));

      TF_ASSIGN_OR_RETURN(
          auto fallback_results,
          AutotuneConvImpl(ctx, fallback_runners, cudnn_use_autotune,
                           launch_func, scratch_size_limit, rz_allocator));

      LogFusedConvForwardAutotuneResults(
          se::dnn::ToDataType<T>::value, input_ptr, filter_ptr, output_ptr,
          bias_ptr, side_input_ptr, input_desc, filter_desc, output_desc,
          conv_desc, conv_scale, side_input_scale, activation_mode,
          stream->parent(), fallback_results);

      TF_ASSIGN_OR_RETURN(autotune_entry,
                          BestCudnnConvAlgorithm<se::dnn::FusedConvOp>(
                              fallback_results, std::move(fallback_runners)));
    }

    autotune_map->Insert(params, autotune_entry);
  }
  return autotune_entry;
#else
  return errors::Unimplemented(
      "Fused conv not implemented on non-CUDA platforms.");
#endif
}

template StatusOr<AutotuneEntry<se::dnn::FusedConvOp>>
AutotuneFusedConv<double>(
    bool cudnn_use_autotune,
    AutotuneMap<ConvParameters, AutotuneEntry<se::dnn::FusedConvOp>>*
        autotune_map,
    const ConvParameters& params, OpKernelContext* ctx,
    const se::dnn::BatchDescriptor& input_desc,
    const se::dnn::FilterDescriptor& filter_desc,
    const se::dnn::BatchDescriptor& bias_desc,
    const se::dnn::BatchDescriptor& output_desc,
    const se::dnn::ConvolutionDescriptor& conv_desc,
    const se::dnn::ActivationMode activation_mode, double conv_scale,
    double side_input_scale, double leakyrelu_alpha,
    se::DeviceMemory<double> input_ptr, se::DeviceMemory<double> filter_ptr,
    se::DeviceMemory<double> output_ptr, se::DeviceMemory<double> bias_ptr,
    se::DeviceMemory<double> side_input_ptr, int64_t scratch_size_limit);

template StatusOr<AutotuneEntry<se::dnn::FusedConvOp>> AutotuneFusedConv<float>(
    bool cudnn_use_autotune,
    AutotuneMap<ConvParameters, AutotuneEntry<se::dnn::FusedConvOp>>*
        autotune_map,
    const ConvParameters& params, OpKernelContext* ctx,
    const se::dnn::BatchDescriptor& input_desc,
    const se::dnn::FilterDescriptor& filter_desc,
    const se::dnn::BatchDescriptor& bias_desc,
    const se::dnn::BatchDescriptor& output_desc,
    const se::dnn::ConvolutionDescriptor& conv_desc,
    const se::dnn::ActivationMode activation_mode, double conv_scale,
    double side_input_scale, double leakyrelu_alpha,
    se::DeviceMemory<float> input_ptr, se::DeviceMemory<float> filter_ptr,
    se::DeviceMemory<float> output_ptr, se::DeviceMemory<float> bias_ptr,
    se::DeviceMemory<float> side_input_ptr, int64_t scratch_size_limit);

template StatusOr<AutotuneEntry<se::dnn::FusedConvOp>>
AutotuneFusedConv<Eigen::half>(
    bool cudnn_use_autotune,
    AutotuneMap<ConvParameters, AutotuneEntry<se::dnn::FusedConvOp>>*
        autotune_map,
    const ConvParameters& params, OpKernelContext* ctx,
    const se::dnn::BatchDescriptor& input_desc,
    const se::dnn::FilterDescriptor& filter_desc,
    const se::dnn::BatchDescriptor& bias_desc,
    const se::dnn::BatchDescriptor& output_desc,
    const se::dnn::ConvolutionDescriptor& conv_desc,
    const se::dnn::ActivationMode activation_mode, double conv_scale,
    double side_input_scale, double leakyrelu_alpha,
    se::DeviceMemory<Eigen::half> input_ptr,
    se::DeviceMemory<Eigen::half> filter_ptr,
    se::DeviceMemory<Eigen::half> output_ptr,
    se::DeviceMemory<Eigen::half> bias_ptr,
    se::DeviceMemory<Eigen::half> side_input_ptr, int64_t scratch_size_limit);

template <typename T>
StatusOr<AutotuneEntry<se::dnn::ConvOp>> AutotuneUnfusedConv(
    bool cudnn_use_autotune,
    AutotuneMap<ConvParameters, AutotuneEntry<se::dnn::ConvOp>>* autotune_map,
    const ConvParameters& conv_parameters, OpKernelContext* ctx,
    se::dnn::ConvolutionKind kind, const se::dnn::BatchDescriptor& input_desc,
    se::DeviceMemory<T> input_ptr, const se::dnn::FilterDescriptor& filter_desc,
    se::DeviceMemory<T> filter_ptr,
    const se::dnn::ConvolutionDescriptor& conv_desc,
    const se::dnn::BatchDescriptor& output_desc, se::DeviceMemory<T> output_ptr,
    int64_t scratch_size_limit) {
  AutotuneEntry<se::dnn::ConvOp> autotune_entry;

  auto* stream = ctx->op_device_context()->stream();

  if (!autotune_map->Find(conv_parameters, &autotune_entry)) {
    profiler::ScopedAnnotation annotation("cudnn_autotuning");

#if GOOGLE_CUDA
    se::TfAllocatorAdapter tf_allocator_adapter(ctx->device()->GetAllocator({}),
                                                stream);
    se::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
                                      se::GpuAsmOpts());

    // TODO(awpr): second-guess whether it's okay that this profiles
    // convolutions on uninitialized memory.
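    // Only the buffer this convolution kind writes to gets a redzone wrap:
    // the output for forward convolutions, the input gradient for
    // BACKWARD_DATA, and the filter gradient for BACKWARD_FILTER.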
    switch (kind) {
      case se::dnn::ConvolutionKind::FORWARD:
      case se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION:
        output_ptr = se::DeviceMemory<T>(
            WrapRedzoneBestEffort(&rz_allocator, output_ptr));
        break;
      case se::dnn::ConvolutionKind::BACKWARD_DATA:
        input_ptr = se::DeviceMemory<T>(
            WrapRedzoneBestEffort(&rz_allocator, input_ptr));
        break;
      case se::dnn::ConvolutionKind::BACKWARD_FILTER:
        filter_ptr = se::DeviceMemory<T>(
            WrapRedzoneBestEffort(&rz_allocator, filter_ptr));
        break;
      default:
        return errors::InvalidArgument(
            absl::StrFormat("Unknown ConvolutionKind %d", kind));
    }

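    // Ask cuDNN for the heuristics list of convolution engines; the fallback
    // list is only requested later if none of these engines work.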
    const auto element_type = se::dnn::ToDataType<T>::value;
    std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners;
    TF_RETURN_IF_ERROR(stream->parent()->GetConvolveRunners(
        CudnnUseFrontend(), kind, element_type, element_type, stream,
        input_desc, input_ptr, filter_desc, filter_ptr, output_desc, output_ptr,
        conv_desc, /*use_fallback=*/false, &rz_allocator, &runners));
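    // Launches a single candidate runner with workspace taken from
    // `allocator_used`, recording its execution time in `profile_result`.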
    auto launch_func =
        [&](se::ScratchAllocator* allocator_used,
            const std::unique_ptr<const se::dnn::ConvRunner>& runner,
            se::dnn::ProfileResult* profile_result) -> Status {
      TF_ASSIGN_OR_RETURN(auto scratch, allocator_used->AllocateBytes(
                                            runner->GetWorkspaceSize()));
      return (*runner)(stream, profile_result, scratch, input_ptr, filter_ptr,
                       output_ptr);
    };
    TF_ASSIGN_OR_RETURN(
        auto results,
        AutotuneConvImpl(ctx, runners, cudnn_use_autotune, launch_func,
                         scratch_size_limit, rz_allocator));

    LogConvAutotuneResults(kind, se::dnn::ToDataType<T>::value, input_ptr,
                           filter_ptr, output_ptr, input_desc, filter_desc,
                           output_desc, conv_desc, stream->parent(), results);

    // Two-level autotuning: Cudnn frontend supports two engine lists:
    // heuristics and fallback. Heuristics engines are normally faster.
    // To reduce autotuning time, we evaluate the fallback engines only when
    // none of the heuristics engines work.
    bool found_working_engine = false;
    for (auto& result : results) {
      if (!result.has_failure()) {
        found_working_engine = true;
        break;
      }
    }

    if (!CudnnUseFrontend() || found_working_engine) {
      TF_ASSIGN_OR_RETURN(
          autotune_entry,
          BestCudnnConvAlgorithm<se::dnn::ConvOp>(results, std::move(runners)));
    } else {
      LOG(WARNING)
          << "None of the algorithms provided by cuDNN frontend heuristics "
             "worked; trying fallback algorithms.  Conv: "
          << conv_parameters.ToString();
      std::vector<std::unique_ptr<const se::dnn::ConvRunner>> fallback_runners;
      TF_RETURN_IF_ERROR(stream->parent()->GetConvolveRunners(
          CudnnUseFrontend(), kind, element_type, element_type, stream,
          input_desc, input_ptr, filter_desc, filter_ptr, output_desc,
          output_ptr, conv_desc, /*use_fallback=*/true, &rz_allocator,
          &fallback_runners));

      TF_ASSIGN_OR_RETURN(
          auto fallback_results,
          AutotuneConvImpl(ctx, fallback_runners, cudnn_use_autotune,
                           launch_func, scratch_size_limit, rz_allocator));

      LogConvAutotuneResults(kind, se::dnn::ToDataType<T>::value, input_ptr,
                             filter_ptr, output_ptr, input_desc, filter_desc,
                             output_desc, conv_desc, stream->parent(),
                             fallback_results);

      TF_ASSIGN_OR_RETURN(autotune_entry,
                          BestCudnnConvAlgorithm<se::dnn::ConvOp>(
                              fallback_results, std::move(fallback_runners)));
    }

#elif TENSORFLOW_USE_ROCM
    DnnScratchAllocator scratch_allocator(scratch_size_limit, ctx);

    std::vector<se::dnn::ProfileResult> algorithms;
    if (!stream->parent()->GetMIOpenConvolveAlgorithms(
            kind, se::dnn::ToDataType<T>::value, stream, input_desc, input_ptr,
            filter_desc, filter_ptr, output_desc, output_ptr, conv_desc,
            &scratch_allocator, &algorithms)) {
      return errors::Unknown(
          "Failed to get convolution algorithm. This is probably "
          "because MIOpen failed to initialize, so try looking to "
          "see if a warning log message was printed above.");
    }

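    // When MIOpen returns a single algorithm it already comes with timing and
    // scratch-size information, so it is recorded directly; otherwise each
    // candidate is re-run here to measure its execution time.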
    std::vector<tensorflow::AutotuneResult> results;
    if (algorithms.size() == 1) {
      auto profile_result = algorithms[0];
      results.emplace_back();
      auto& result = results.back();
      *result.mutable_algorithm() = profile_result.algorithm().ToProto();

      result.set_scratch_bytes(profile_result.scratch_size());
      *result.mutable_run_time() = proto_utils::ToDurationProto(
          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
    } else {
      for (auto miopen_algorithm : algorithms) {
        auto profile_algorithm = miopen_algorithm.algorithm();
        se::dnn::ProfileResult profile_result;
        auto miopen_launch_status = stream->ConvolveWithAlgorithm(
            kind, input_desc, input_ptr, filter_desc, filter_ptr, output_desc,
            output_ptr, conv_desc, &scratch_allocator,
            se::dnn::AlgorithmConfig(profile_algorithm,
                                     miopen_algorithm.scratch_size()),
            &profile_result);
        if (miopen_launch_status.ok() && profile_result.is_valid()) {
          results.emplace_back();
          auto& result = results.back();
          *result.mutable_algorithm() = profile_algorithm.ToProto();

          result.set_scratch_bytes(scratch_allocator.TotalByteSize());
          *result.mutable_run_time() = proto_utils::ToDurationProto(
              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
    LogConvAutotuneResults(kind, se::dnn::ToDataType<T>::value, input_ptr,
                           filter_ptr, output_ptr, input_desc, filter_desc,
                           output_desc, conv_desc, stream->parent(), results);

    TF_ASSIGN_OR_RETURN(auto algo_desc, BestCudnnConvAlgorithm(results));
    autotune_entry = AutotuneEntry<se::dnn::ConvOp>(algo_desc);
#endif

    autotune_map->Insert(conv_parameters, autotune_entry);
  }

  return autotune_entry;
}

template StatusOr<AutotuneEntry<se::dnn::ConvOp>> AutotuneUnfusedConv<double>(
    bool cudnn_use_autotune,
    AutotuneMap<ConvParameters, AutotuneEntry<se::dnn::ConvOp>>* autotune_map,
    const ConvParameters& conv_parameters, OpKernelContext* ctx,
    se::dnn::ConvolutionKind kind, const se::dnn::BatchDescriptor& input_desc,
    se::DeviceMemory<double> input_ptr,
    const se::dnn::FilterDescriptor& filter_desc,
    se::DeviceMemory<double> filter_ptr,
    const se::dnn::ConvolutionDescriptor& conv_desc,
    const se::dnn::BatchDescriptor& output_desc,
    se::DeviceMemory<double> output_ptr, int64_t scratch_size_limit);

template StatusOr<AutotuneEntry<se::dnn::ConvOp>> AutotuneUnfusedConv<float>(
    bool cudnn_use_autotune,
    AutotuneMap<ConvParameters, AutotuneEntry<se::dnn::ConvOp>>* autotune_map,
    const ConvParameters& conv_parameters, OpKernelContext* ctx,
    se::dnn::ConvolutionKind kind, const se::dnn::BatchDescriptor& input_desc,
    se::DeviceMemory<float> input_ptr,
    const se::dnn::FilterDescriptor& filter_desc,
    se::DeviceMemory<float> filter_ptr,
    const se::dnn::ConvolutionDescriptor& conv_desc,
    const se::dnn::BatchDescriptor& output_desc,
    se::DeviceMemory<float> output_ptr, int64_t scratch_size_limit);

template StatusOr<AutotuneEntry<se::dnn::ConvOp>>
AutotuneUnfusedConv<Eigen::half>(
    bool cudnn_use_autotune,
    AutotuneMap<ConvParameters, AutotuneEntry<se::dnn::ConvOp>>* autotune_map,
    const ConvParameters& conv_parameters, OpKernelContext* ctx,
    se::dnn::ConvolutionKind kind, const se::dnn::BatchDescriptor& input_desc,
    se::DeviceMemory<Eigen::half> input_ptr,
    const se::dnn::FilterDescriptor& filter_desc,
    se::DeviceMemory<Eigen::half> filter_ptr,
    const se::dnn::ConvolutionDescriptor& conv_desc,
    const se::dnn::BatchDescriptor& output_desc,
    se::DeviceMemory<Eigen::half> output_ptr, int64_t scratch_size_limit);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM