/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {

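// Builds the cudnn batch descriptor for the `bias` operand of a fused
// convolution: a single-element batch of 1x1 "images" whose feature-map count
// matches the convolution output (e.g. an NCHW output with 64 feature maps
// gets a 1x64x1x1 bias descriptor).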
se::dnn::BatchDescriptor GetBiasDescriptor(const GpuConvConfig& config) {
  se::dnn::BatchDescriptor result(config.output_descriptor.ndims());
  result.set_count(1)
      .set_height(1)
      .set_width(1)
      .set_feature_map_count(config.output_descriptor.feature_map_count())
      .set_layout([&] {
        // Normalize NCHW_VECT_C to NCHW for the layout of `bias`, even though
        // it's actually the same (because `bias` only has one dimension):
        // cudnn does not accept NCHW_VECT_C for `bias`.
        se::dnn::DataLayout layout = config.output_descriptor.layout();
        switch (layout) {
          case se::dnn::DataLayout::kBatchDepthYX4:
          case se::dnn::DataLayout::kBatchDepthYX32:
            return se::dnn::DataLayout::kBatchDepthYX;
          default:
            return layout;
        }
      }());
  if (result.ndims() == 3) {
    result.set_spatial_dim(se::dnn::DimIndex::Z, 1);
  }
  return result;
}

namespace {

using se::DeviceMemory;
using se::DeviceMemoryBase;
using se::Stream;
using se::dnn::AlgorithmConfig;
using se::dnn::BatchDescriptor;
using se::dnn::ConvolutionDescriptor;
using se::dnn::DataLayout;
using se::dnn::DimIndex;
using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
using se::dnn::ProfileResult;

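// Runs a plain (unfused) convolution -- forward, backward-input, or
// backward-filter -- through StreamExecutor's ConvOp runner interface,
// reusing a cached runner from `options` when one is provided.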
template <typename ElementType, typename OutputType>
Status RunGpuConvUnfused(GpuConvParams params, se::Stream* stream,
                         RunConvOptions options,
                         DeviceMemory<ElementType> input_buf,
                         DeviceMemory<ElementType> filter_buf,
                         DeviceMemory<OutputType> output_buf,
                         DeviceMemoryBase scratch_memory) {
  if (params.config->conv_result_scale != 1) {
    return InternalError(
        "StreamExecutor doesn't support scaled convolution: %lf.",
        params.config->conv_result_scale);
  }

  TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
                      GetDNNConvKindFromCudnnConvKind(params.config->kind));

  TF_ASSIGN_OR_RETURN(
      se::dnn::DataType input_type,
      GetDNNDataTypeFromPrimitiveType(params.config->input_type));

  TF_ASSIGN_OR_RETURN(
      se::dnn::DataType output_type,
      GetDNNDataTypeFromPrimitiveType(params.config->output_type));

  se::dnn::LazyOpRunner<se::dnn::ConvOp>* lazy_runner =
      options.runner_cache->AsConvRunner();
  std::optional<se::dnn::LazyOpRunner<se::dnn::ConvOp>> local_runner;
  if (!lazy_runner) {
    local_runner.emplace(params.config->algorithm);
    lazy_runner = &*local_runner;
  }

  se::dnn::ConvOp::Config config{kind,
                                 input_type,
                                 output_type,
                                 params.config->input_descriptor,
                                 params.config->filter_descriptor,
                                 params.config->output_descriptor,
                                 params.config->conv_desc};
  TF_ASSIGN_OR_RETURN(auto* runner,
                      lazy_runner->GetOrCreateRunner(config, stream));

  return (*runner)(stream, options.profile_result, scratch_memory, input_buf,
                   filter_buf, output_buf);
}

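// Runs a fused forward convolution (conv + bias + optional side input +
// pointwise activation) through StreamExecutor's FusedConvOp runner
// interface.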
template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvForwardActivation(const GpuConvParams& params,
                                   se::Stream* stream, RunConvOptions options,
                                   DeviceMemory<ElementType> input_buf,
                                   DeviceMemory<ElementType> filter_buf,
                                   DeviceMemory<OutputType> output_buf,
                                   DeviceMemoryBase scratch_memory) {
  BatchDescriptor bias_desc = GetBiasDescriptor(*params.config);

  se::DeviceMemory<OutputType> side_input(params.fusion->side_input_buf);
  // If there is no side input, use output as the side input.
  if (side_input.is_null()) {
    if (params.config->fusion->side_input_scale != 0) {
      return InternalError(
          "Side input scale is not 0, yet no side input buffer is "
          "provided");
    }
    // Since side-input scale is 0, the values in the side input don't
    // matter.  The simplest thing to do would be to pass in a null buffer
    // for the side input, but cudnn doesn't allow this.  cudnn does promise
    // that if side-input-scale is 0 the side input won't be read, so we
    // just pass in the output buffer, since it's handy and has the correct
    // size.
    side_input = output_buf;
  }

  se::dnn::LazyOpRunner<se::dnn::FusedConvOp>* lazy_runner =
      options.runner_cache->AsFusedConvRunner();
  std::optional<se::dnn::LazyOpRunner<se::dnn::FusedConvOp>> local_runner;
  if (!lazy_runner) {
    local_runner.emplace(params.config->algorithm);
    lazy_runner = &*local_runner;
  }

  TF_ASSIGN_OR_RETURN(
      se::dnn::DataType input_type,
      GetDNNDataTypeFromPrimitiveType(params.config->input_type));

  TF_ASSIGN_OR_RETURN(
      se::dnn::DataType output_type,
      GetDNNDataTypeFromPrimitiveType(params.config->output_type));

  se::dnn::FusedConvOp::Config config{se::dnn::ConvolutionKind::FORWARD,
                                      input_type,
                                      BiasTypeForInputType(input_type),
                                      output_type,
                                      params.config->conv_result_scale,
                                      params.config->fusion->side_input_scale,
                                      /* leakyrelu_alpha = */ 0.0,
                                      params.config->input_descriptor,
                                      params.config->filter_descriptor,
                                      bias_desc,
                                      params.config->output_descriptor,
                                      params.config->conv_desc,
                                      params.config->fusion->mode};
  TF_ASSIGN_OR_RETURN(auto* runner,
                      lazy_runner->GetOrCreateRunner(config, stream));

  return (*runner)(stream, options.profile_result, scratch_memory, input_buf,
                   filter_buf, side_input, params.fusion->bias_buf, output_buf);
}

// StreamExecutor supports various data types via overloading, and overloads
// are added on demand.  We use enable_if to make sure we never instantiate a
// call to an overload that doesn't exist.
// TODO(timshen): Ideally, to avoid such complication in the runner, we could
// turn the StreamExecutor overloads into template functions that return
// runtime errors for unsupported data types.
// This is the specialization for double, float, and half types.  All kinds of
// convolutions are supported here.
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<
              !std::is_integral<ElementType>::value>::type* = nullptr>
Status RunGpuConvInternalImpl(const GpuConvParams& params, se::Stream* stream,
                              RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              DeviceMemoryBase scratch_memory) {
  switch (params.config->kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kBackwardInput:
    case CudnnConvKind::kBackwardFilter:
      return RunGpuConvUnfused(params, stream, options, input_buf, filter_buf,
                               output_buf, scratch_memory);
    case CudnnConvKind::kForwardActivation: {
      return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
          params, stream, options, input_buf, filter_buf, output_buf,
          scratch_memory);
    }
  }
  return OkStatus();
}

// Specialization for integer types.  Only the kForward and kForwardActivation
// convolution kinds are supported.
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<std::is_integral<ElementType>::value>::type* =
              nullptr>
Status RunGpuConvInternalImpl(const GpuConvParams& params, se::Stream* stream,
                              RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              DeviceMemoryBase scratch_memory) {
  switch (params.config->kind) {
    case CudnnConvKind::kForward:
      return RunGpuConvUnfused(params, stream, options, input_buf, filter_buf,
                               output_buf, scratch_memory);
    case CudnnConvKind::kForwardActivation:
      return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
          params, stream, options, input_buf, filter_buf, output_buf,
          scratch_memory);
    default:
      return InternalError(
          "Only convolution kinds kForward and kForwardActivation are "
          "supported for integer types");
  }
  return OkStatus();
}

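// Wraps the raw device buffers from `params` in typed DeviceMemory handles
// and dispatches to the matching RunGpuConvInternalImpl specialization,
// reporting a failed stream launch as an InternalError.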
template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvImpl(const GpuConvParams& params, se::Stream* stream,
                      se::DeviceMemoryBase scratch_memory,
                      RunConvOptions options) {
  auto input_buf = se::DeviceMemory<ElementType>(params.input_buf);
  auto filter_buf = se::DeviceMemory<ElementType>(params.filter_buf);
  auto output_buf = se::DeviceMemory<OutputType>(params.output_buf);

  se::dnn::AlgorithmDesc algorithm = params.config->algorithm;
  if (options.runner_cache) {
    algorithm = options.runner_cache->ToAlgorithmDesc();
  }

  Status run_status = RunGpuConvInternalImpl<ElementType, BiasType, OutputType>(
      params, stream, options, input_buf, filter_buf, output_buf,
      scratch_memory);

  if (run_status != OkStatus()) {
    return run_status;
  }

  if (!stream->ok()) {
    return InternalError(
        "Unable to launch convolution with type %s and algorithm %s",
        CudnnConvKindToString(params.config->kind), algorithm.ToString());
  }
  return OkStatus();
}

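// Returns the channel-vectorization factor implied by a vectorized cudnn
// layout (e.g. kBatchDepthYX32, i.e. NCHW_VECT_C with 32-wide vectors, yields
// 32), or 1 for non-vectorized layouts.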
int64_t GetVectCSize(DataLayout layout) {
  switch (layout) {
    case DataLayout::kBatchDepthYX4:
      return 4;
    case DataLayout::kBatchDepthYX32:
      return 32;
    default:
      return 1;
  }
}

int64_t GetVectCSize(FilterLayout layout) {
  switch (layout) {
    case FilterLayout::kOutputInputYX4:
      return 4;
    case FilterLayout::kOutputInputYX32:
      return 32;
    default:
      return 1;
  }
}

}  // anonymous namespace

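// Translates a GpuConvDescriptor (operand/result shapes, window, dimension
// numbers, and backend config) into a GpuConvConfig: assigns the
// input/filter/output roles implied by the convolution kind, validates the
// window, and builds the StreamExecutor batch, filter, and convolution
// descriptors.  `inst_as_string` is used only to make CHECK failures easier
// to attribute.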
StatusOr<GpuConvConfig> GetGpuConvConfig(
    const GpuConvDescriptor& desc, const absl::string_view inst_as_string) {
  GpuConvConfig config;

  const Shape& operand0_shape = desc.operand0_shape;
  const Shape& operand1_shape = desc.operand1_shape;
  const Shape& result_shape = desc.result_shape;
  const CudnnConvBackendConfig& backend_config = desc.backend_config;

  config.input_type = operand0_shape.element_type();
  config.output_type = result_shape.element_type();
  config.kind = desc.kind;
  config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm());
  config.conv_result_scale = backend_config.conv_result_scale();

  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      config.input_shape = operand0_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = result_shape;
      break;
    case CudnnConvKind::kBackwardInput:
      config.input_shape = result_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = operand0_shape;
      break;
    case CudnnConvKind::kBackwardFilter:
      config.input_shape = operand0_shape;
      config.filter_shape = result_shape;
      config.output_shape = operand1_shape;
      break;
    default:
      return InternalError("Unknown convolution kind");
  }

  if (config.kind == CudnnConvKind::kForwardActivation) {
    config.fusion.emplace();
    GpuConvConfig::FusionConfig& fusion = *config.fusion;
    if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) {
      return InternalError("Bad activation mode: %s",
                           backend_config.ShortDebugString());
    }
    fusion.mode =
        static_cast<se::dnn::ActivationMode>(backend_config.activation_mode());
    fusion.side_input_scale = backend_config.side_input_scale();
  }

  const Window& window = desc.window;
  const ConvolutionDimensionNumbers& dnums = desc.dnums;

  VLOG(3) << "Convolution Algorithm: " << config.algorithm.ToString();
  VLOG(3) << "Convolution kind: " << CudnnConvKindToString(config.kind);
  VLOG(3) << "input shape: "
          << ShapeUtil::HumanStringWithLayout(config.input_shape);
  VLOG(3) << "filter shape: "
          << ShapeUtil::HumanStringWithLayout(config.filter_shape);
  VLOG(3) << "Output shape: "
          << ShapeUtil::HumanStringWithLayout(config.output_shape);
  VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
  VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";

  const int num_dimensions = window.dimensions_size();
  CHECK_LE(num_dimensions, 3) << inst_as_string;

  // cuDNN does not support 1D convolutions. We therefore express 1D
  // convolutions as 2D convolutions where the first spatial dimension is 1.
  // This matches the behavior of TF (see definition of conv1d in
  // tensorflow/python/ops/nn_ops.py).
  const int effective_num_dimensions = std::max(2, num_dimensions);

  // If one dimension is reversed, we need to have all dimensions reversed (so
  // that we're doing convolution, not cross-correlation).
  const bool dims_reversed =
      window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();

  CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
      << inst_as_string;
  CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
      << inst_as_string;
  CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
      << inst_as_string;
  for (const WindowDimension& dim : window.dimensions()) {
    CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string;
    CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string;
    CHECK_EQ(dim.base_dilation(), 1)
        << "cudnn does not support base dilation; it "
           "must be made explicit with a kPad: "
        << inst_as_string;
  }

  // cuDNN's convolution APIs support the BDYX layout for activations/output
  // and the OIYX layout for weights.
  DataLayout input_dl;
  FilterLayout filter_dl;
  DataLayout output_dl;

  const Shape& input_shape = config.input_shape;
  const Shape& filter_shape = config.filter_shape;
  const Shape& output_shape = config.output_shape;

  TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
                      XlaConvShapesToStreamExecutorLayouts(
                          dnums, input_shape, filter_shape, output_shape));

  BatchDescriptor& input_descriptor = config.input_descriptor;
  input_descriptor = BatchDescriptor(effective_num_dimensions);
  input_descriptor.set_layout(input_dl)
      .set_feature_map_count(
          GetVectCSize(input_dl) *
          input_shape.dimensions(dnums.input_feature_dimension()))
      .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    // Note that the dimensions are reversed: the first spatial dimension in
    // `dnums` maps to the highest DimIndex and the last maps to DimIndex::X,
    // i.e. spatial dimension `dim` becomes
    // DimIndex(effective_num_dimensions - dim - 1).  The same holds below.
    input_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
  }

  FilterDescriptor& filter_descriptor = config.filter_descriptor;
  filter_descriptor = FilterDescriptor(effective_num_dimensions);
  filter_descriptor.set_layout(filter_dl)
      .set_input_feature_map_count(
          GetVectCSize(filter_dl) *
          filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
      .set_output_feature_map_count(
          filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    filter_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
  }

  config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
  config.conv_desc.set_group_count(desc.feature_group_count);
  config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
  for (int dim = 0; dim < num_dimensions; ++dim) {
    config.conv_desc
        .set_zero_padding(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).padding_low())
        .set_filter_stride(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).stride())
        .set_dilation_rate(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).window_dilation());
  }

  BatchDescriptor& output_descriptor = config.output_descriptor;
  output_descriptor = BatchDescriptor(effective_num_dimensions);
  output_descriptor.set_layout(output_dl)
      .set_feature_map_count(
          GetVectCSize(output_dl) *
          output_shape.dimensions(dnums.output_feature_dimension()))
      .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    output_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
  }

  // Add a singleton dimension in the 1D convolution case.
  for (int dim = 0; dim < effective_num_dimensions - num_dimensions; dim++) {
    input_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    output_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    filter_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    config.conv_desc.set_zero_padding(static_cast<DimIndex>(dim), 0)
        .set_filter_stride(static_cast<DimIndex>(dim), 1);
  }

  return config;
}

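// Builds a GpuConvConfig directly from a cudnn custom-call instruction.  The
// custom call returns a (conv_result, scratch) tuple, so the result shape is
// tuple element 0 and the scratch size comes from tuple element 1.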
StatusOr<GpuConvConfig> GetGpuConvConfig(
    const HloCustomCallInstruction* cudnn_call) {
  GpuConvDescriptor descriptor;

  TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call));
  TF_ASSIGN_OR_RETURN(descriptor.backend_config,
                      cudnn_call->backend_config<CudnnConvBackendConfig>());
  descriptor.operand0_shape = cudnn_call->operand(0)->shape();
  descriptor.operand1_shape = cudnn_call->operand(1)->shape();
  descriptor.result_shape = cudnn_call->shape().tuple_shapes(0);
  descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
  descriptor.window = cudnn_call->window();
  descriptor.dnums = cudnn_call->convolution_dimension_numbers();
  descriptor.feature_group_count = cudnn_call->feature_group_count();
  return GetGpuConvConfig(descriptor, cudnn_call->ToString());
}

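// Binds a GpuConvConfig to the device buffers for one invocation.  The
// operand and result buffers take the input/filter/output roles dictated by
// the convolution kind; for fused convolutions, operand 2 is the bias and
// the optional operand 3 is the side input.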
StatusOr<GpuConvParams> GetGpuConvParams(
    const GpuConvConfig& config,
    absl::Span<const se::DeviceMemoryBase> operand_buffers,
    se::DeviceMemoryBase result_buffer) {
  GpuConvParams params;
  params.config = &config;

  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      params.input_buf = operand_buffers[0];
      params.filter_buf = operand_buffers[1];
      params.output_buf = result_buffer;
      break;
    case CudnnConvKind::kBackwardInput:
      params.input_buf = result_buffer;
      params.filter_buf = operand_buffers[1];
      params.output_buf = operand_buffers[0];
      break;
    case CudnnConvKind::kBackwardFilter:
      params.input_buf = operand_buffers[0];
      params.filter_buf = result_buffer;
      params.output_buf = operand_buffers[1];
      break;
  }

  if (config.kind == CudnnConvKind::kForwardActivation) {
    params.fusion.emplace();
    GpuConvParams::FusionParams& fusion = *params.fusion;
    fusion.bias_buf = operand_buffers[2];
    if (operand_buffers.size() >= 4) {
      fusion.side_input_buf = operand_buffers[3];
    }
  }

  return params;
}

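// Entry point: runs the convolution described by `config` on `stream`,
// dispatching on the input element type (and, for S8 inputs, on the output
// type as well).  Unsupported type combinations yield an Unimplemented error.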
Status RunGpuConv(const gpu::GpuConvConfig& config,
                  absl::Span<const se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::DeviceMemoryBase scratch_memory, se::Stream* stream,
                  RunConvOptions options) {
  TF_ASSIGN_OR_RETURN(GpuConvParams params,
                      GetGpuConvParams(config, operand_buffers, result_buffer));

  PrimitiveType input_primitive_type = config.input_type;
  switch (input_primitive_type) {
    case F16:
      return RunGpuConvImpl<Eigen::half, Eigen::half, Eigen::half>(
          params, stream, scratch_memory, options);
    case BF16:
      return RunGpuConvImpl<Eigen::bfloat16, Eigen::bfloat16, Eigen::bfloat16>(
          params, stream, scratch_memory, options);
    case F32:
      return RunGpuConvImpl<float, float, float>(params, stream, scratch_memory,
                                                 options);
    case F64:
      return RunGpuConvImpl<double, double, double>(params, stream,
                                                    scratch_memory, options);
    case S8: {
      PrimitiveType output_primitive_type = config.output_type;
      switch (output_primitive_type) {
        case F32:
          return RunGpuConvImpl<int8_t, float, float>(params, stream,
                                                      scratch_memory, options);
        case S8:
          return RunGpuConvImpl<int8_t, float, int8_t>(params, stream,
                                                       scratch_memory, options);
        default:
          return Unimplemented("Unimplemented convolution");
      }
    }
    default:
      return Unimplemented("Unimplemented convolution");
  }
}

}  // namespace gpu
}  // namespace xla