/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"

#include <algorithm>
#include <cstdint>
#include <optional>
#include <tuple>
#include <type_traits>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {

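// Builds the cuDNN batch descriptor for the bias operand of a fused
// convolution: a rank-matching descriptor with count/height/width of 1 whose
// feature-map count matches the convolution output.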
se::dnn::BatchDescriptor GetBiasDescriptor(const GpuConvConfig& config) {
  se::dnn::BatchDescriptor result(config.output_descriptor.ndims());
  result.set_count(1)
      .set_height(1)
      .set_width(1)
      .set_feature_map_count(config.output_descriptor.feature_map_count())
      .set_layout([&] {
        // Normalize NCHW_VECT_C to NCHW for layout of `bias`, even though it's
        // actually the same (because `bias` only has one dimension): cudnn
        // does not accept NCHW_VECT_C for `bias`.
        se::dnn::DataLayout layout = config.output_descriptor.layout();
        switch (layout) {
          case se::dnn::DataLayout::kBatchDepthYX4:
          case se::dnn::DataLayout::kBatchDepthYX32:
            return se::dnn::DataLayout::kBatchDepthYX;
          default:
            return layout;
        }
      }());
  if (result.ndims() == 3) {
    result.set_spatial_dim(se::dnn::DimIndex::Z, 1);
  }
  return result;
}

namespace {

using se::DeviceMemory;
using se::DeviceMemoryBase;
using se::Stream;
using se::dnn::AlgorithmConfig;
using se::dnn::BatchDescriptor;
using se::dnn::ConvolutionDescriptor;
using se::dnn::DataLayout;
using se::dnn::DimIndex;
using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
using se::dnn::ProfileResult;

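// Runs a plain (unfused) convolution through StreamExecutor's ConvOp runner.
// Reuses the cached lazy runner from `options` when one is provided;
// otherwise builds a local runner from the algorithm stored in the config.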
template <typename ElementType, typename OutputType>
Status RunGpuConvUnfused(GpuConvParams params, se::Stream* stream,
                         RunConvOptions options,
                         DeviceMemory<ElementType> input_buf,
                         DeviceMemory<ElementType> filter_buf,
                         DeviceMemory<OutputType> output_buf,
                         DeviceMemoryBase scratch_memory) {
  if (params.config->conv_result_scale != 1) {
    return InternalError(
        "StreamExecutor doesn't support scaled convolution: %lf.",
        params.config->conv_result_scale);
  }

  TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
                      GetDNNConvKindFromCudnnConvKind(params.config->kind));

  TF_ASSIGN_OR_RETURN(
      se::dnn::DataType input_type,
      GetDNNDataTypeFromPrimitiveType(params.config->input_type));

  TF_ASSIGN_OR_RETURN(
      se::dnn::DataType output_type,
      GetDNNDataTypeFromPrimitiveType(params.config->output_type));

  se::dnn::LazyOpRunner<se::dnn::ConvOp>* lazy_runner =
      options.runner_cache->AsConvRunner();
  std::optional<se::dnn::LazyOpRunner<se::dnn::ConvOp>> local_runner;
  if (!lazy_runner) {
    local_runner.emplace(params.config->algorithm);
    lazy_runner = &*local_runner;
  }

  se::dnn::ConvOp::Config config{kind,
                                 input_type,
                                 output_type,
                                 params.config->input_descriptor,
                                 params.config->filter_descriptor,
                                 params.config->output_descriptor,
                                 params.config->conv_desc};
  TF_ASSIGN_OR_RETURN(auto* runner,
                      lazy_runner->GetOrCreateRunner(config, stream));

  return (*runner)(stream, options.profile_result, scratch_memory, input_buf,
                   filter_buf, output_buf);
}

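// Runs a fused convolution + bias + activation ("forward activation"). If no
// side input buffer was provided (and its scale is therefore required to be
// 0), the output buffer is substituted, since cudnn rejects null pointers.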
template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvForwardActivation(const GpuConvParams& params,
                                   se::Stream* stream, RunConvOptions options,
                                   DeviceMemory<ElementType> input_buf,
                                   DeviceMemory<ElementType> filter_buf,
                                   DeviceMemory<OutputType> output_buf,
                                   DeviceMemoryBase scratch_memory) {
  BatchDescriptor bias_desc = GetBiasDescriptor(*params.config);

  se::DeviceMemory<OutputType> side_input(params.fusion->side_input_buf);
  // If there is no side input, use output as the side input.
  if (side_input.is_null()) {
    if (params.config->fusion->side_input_scale != 0) {
      return InternalError(
          "Side input scale is not 0, yet no side input buffer is "
          "provided");
    }
    // Since side-input scale is 0, the values in the side input don't
    // matter. The simplest thing to do would be to pass in a null buffer
    // for the side input, but cudnn doesn't allow this. cudnn does promise
    // that if side-input-scale is 0 the side input won't be read, so we
    // just pass in the output buffer, since it's handy and has the correct
    // size.
    side_input = output_buf;
  }

  se::dnn::LazyOpRunner<se::dnn::FusedConvOp>* lazy_runner =
      options.runner_cache->AsFusedConvRunner();
  std::optional<se::dnn::LazyOpRunner<se::dnn::FusedConvOp>> local_runner;
  if (!lazy_runner) {
    local_runner.emplace(params.config->algorithm);
    lazy_runner = &*local_runner;
  }

  TF_ASSIGN_OR_RETURN(
      se::dnn::DataType input_type,
      GetDNNDataTypeFromPrimitiveType(params.config->input_type));

  TF_ASSIGN_OR_RETURN(
      se::dnn::DataType output_type,
      GetDNNDataTypeFromPrimitiveType(params.config->output_type));

  se::dnn::FusedConvOp::Config config{se::dnn::ConvolutionKind::FORWARD,
                                      input_type,
                                      BiasTypeForInputType(input_type),
                                      output_type,
                                      params.config->conv_result_scale,
                                      params.config->fusion->side_input_scale,
                                      /* leakyrelu_alpha = */ 0.0,
                                      params.config->input_descriptor,
                                      params.config->filter_descriptor,
                                      bias_desc,
                                      params.config->output_descriptor,
                                      params.config->conv_desc,
                                      params.config->fusion->mode};
  TF_ASSIGN_OR_RETURN(auto* runner,
                      lazy_runner->GetOrCreateRunner(config, stream));

  return (*runner)(stream, options.profile_result, scratch_memory, input_buf,
                   filter_buf, side_input, params.fusion->bias_buf, output_buf);
}

// StreamExecutor supports various data types via overloading, and the support
// is maintained on-demand. To avoid calling into non-existent overloads, we
// carefully guard each call with enable_if.
// TODO(timshen): Ideally, to avoid this complication in the runner, we could
// turn the StreamExecutor overloads into template functions and return
// runtime errors for unsupported data types.
// This is the specialization for double, float, and half types. All kinds of
// convolutions are supported here.
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<
              !std::is_integral<ElementType>::value>::type* = nullptr>
Status RunGpuConvInternalImpl(const GpuConvParams& params, se::Stream* stream,
                              RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              DeviceMemoryBase scratch_memory) {
  switch (params.config->kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kBackwardInput:
    case CudnnConvKind::kBackwardFilter:
      return RunGpuConvUnfused(params, stream, options, input_buf, filter_buf,
                               output_buf, scratch_memory);
    case CudnnConvKind::kForwardActivation: {
      return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
          params, stream, options, input_buf, filter_buf, output_buf,
          scratch_memory);
    }
  }
  return OkStatus();
}

// Specialization for integer types. Only the two forward convolution kinds
// (kForward and kForwardActivation) are supported.
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<std::is_integral<ElementType>::value>::type* =
              nullptr>
Status RunGpuConvInternalImpl(const GpuConvParams& params, se::Stream* stream,
                              RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              DeviceMemoryBase scratch_memory) {
  switch (params.config->kind) {
    case CudnnConvKind::kForward:
      return RunGpuConvUnfused(params, stream, options, input_buf, filter_buf,
                               output_buf, scratch_memory);
    case CudnnConvKind::kForwardActivation:
      return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
          params, stream, options, input_buf, filter_buf, output_buf,
          scratch_memory);
    default:
      return InternalError(
          "Only convolution kinds kForward and kForwardActivation are "
          "supported for integer types");
  }
  return OkStatus();
}

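// Wraps the untyped buffers from `params` in typed DeviceMemory handles,
// dispatches to the matching RunGpuConvInternalImpl specialization, and then
// verifies the stream is still in a good state so launch failures surface
// with the convolution kind and algorithm attached.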
template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvImpl(const GpuConvParams& params, se::Stream* stream,
                      se::DeviceMemoryBase scratch_memory,
                      RunConvOptions options) {
  auto input_buf = se::DeviceMemory<ElementType>(params.input_buf);
  auto filter_buf = se::DeviceMemory<ElementType>(params.filter_buf);
  auto output_buf = se::DeviceMemory<OutputType>(params.output_buf);

  se::dnn::AlgorithmDesc algorithm = params.config->algorithm;
  if (options.runner_cache) {
    algorithm = options.runner_cache->ToAlgorithmDesc();
  }

  Status run_status = RunGpuConvInternalImpl<ElementType, BiasType, OutputType>(
      params, stream, options, input_buf, filter_buf, output_buf,
      scratch_memory);

  if (run_status != OkStatus()) {
    return run_status;
  }

  if (!stream->ok()) {
    return InternalError(
        "Unable to launch convolution with type %s and algorithm %s",
        CudnnConvKindToString(params.config->kind), algorithm.ToString());
  }
  return OkStatus();
}

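// Vectorized layouts (NCHW_VECT_C and friends) pack 4 or 32 feature elements
// per logical channel, so descriptors must scale the feature-map count taken
// from the XLA shape by this factor. Non-vectorized layouts return 1.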
int64_t GetVectCSize(DataLayout layout) {
  switch (layout) {
    case DataLayout::kBatchDepthYX4:
      return 4;
    case DataLayout::kBatchDepthYX32:
      return 32;
    default:
      return 1;
  }
}

int64_t GetVectCSize(FilterLayout layout) {
  switch (layout) {
    case FilterLayout::kOutputInputYX4:
      return 4;
    case FilterLayout::kOutputInputYX32:
      return 32;
    default:
      return 1;
  }
}

}  // anonymous namespace

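// Translates a GpuConvDescriptor (shapes, window, dimension numbers, and
// backend config) into the StreamExecutor/cuDNN descriptors stored in a
// GpuConvConfig. `inst_as_string` is used only to annotate CHECK failures.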
StatusOr<GpuConvConfig> GetGpuConvConfig(
    const GpuConvDescriptor& desc, const absl::string_view inst_as_string) {
  GpuConvConfig config;

  const Shape& operand0_shape = desc.operand0_shape;
  const Shape& operand1_shape = desc.operand1_shape;
  const Shape& result_shape = desc.result_shape;
  const CudnnConvBackendConfig& backend_config = desc.backend_config;

  config.input_type = operand0_shape.element_type();
  config.output_type = result_shape.element_type();
  config.kind = desc.kind;
  config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm());
  config.conv_result_scale = backend_config.conv_result_scale();

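  // The input/filter/output shapes below always refer to the operands of the
  // equivalent *forward* convolution, so the backward kinds assign the XLA
  // operand and result shapes to swapped roles.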
  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      config.input_shape = operand0_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = result_shape;
      break;
    case CudnnConvKind::kBackwardInput:
      config.input_shape = result_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = operand0_shape;
      break;
    case CudnnConvKind::kBackwardFilter:
      config.input_shape = operand0_shape;
      config.filter_shape = result_shape;
      config.output_shape = operand1_shape;
      break;
    default:
      return InternalError("Unknown convolution kind");
  }

  if (config.kind == CudnnConvKind::kForwardActivation) {
    config.fusion.emplace();
    GpuConvConfig::FusionConfig& fusion = *config.fusion;
    if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) {
      return InternalError("Bad activation mode: %s",
                           backend_config.ShortDebugString());
    }
    fusion.mode =
        static_cast<se::dnn::ActivationMode>(backend_config.activation_mode());
    fusion.side_input_scale = backend_config.side_input_scale();
  }

  const Window& window = desc.window;
  const ConvolutionDimensionNumbers& dnums = desc.dnums;

  VLOG(3) << "Convolution Algorithm: " << config.algorithm.ToString();
  VLOG(3) << "Convolution kind: " << CudnnConvKindToString(config.kind);
  VLOG(3) << "input shape: "
          << ShapeUtil::HumanStringWithLayout(config.input_shape);
  VLOG(3) << "filter shape: "
          << ShapeUtil::HumanStringWithLayout(config.filter_shape);
  VLOG(3) << "Output shape: "
          << ShapeUtil::HumanStringWithLayout(config.output_shape);
  VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
  VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";

  const int num_dimensions = window.dimensions_size();
  CHECK_LE(num_dimensions, 3) << inst_as_string;

  // cuDNN does not support 1D convolutions. We therefore express 1D
  // convolutions as 2D convolutions where the first spatial dimension is 1.
  // This matches the behavior of TF (see definition of conv1d in
  // tensorflow/python/ops/nn_ops.py).
  const int effective_num_dimensions = std::max(2, num_dimensions);

  // If one dimension is reversed, we need to have all dimensions reversed (so
  // we're doing convolution, not cross-correlation).
  const bool dims_reversed =
      window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();

  CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
      << inst_as_string;
  CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
      << inst_as_string;
  CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
      << inst_as_string;
  for (const WindowDimension& dim : window.dimensions()) {
    CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string;
    CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string;
    CHECK_EQ(dim.base_dilation(), 1)
        << "cudnn does not support base dilation; it "
           "must be made explicit with a kPad: "
        << inst_as_string;
  }

  // cuDNN's convolution APIs support the BDYX layout for activations/output
  // and the OIYX layout for weights.
  DataLayout input_dl;
  FilterLayout filter_dl;
  DataLayout output_dl;

  const Shape& input_shape = config.input_shape;
  const Shape& filter_shape = config.filter_shape;
  const Shape& output_shape = config.output_shape;

  TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
                      XlaConvShapesToStreamExecutorLayouts(
                          dnums, input_shape, filter_shape, output_shape));

  BatchDescriptor& input_descriptor = config.input_descriptor;
  input_descriptor = BatchDescriptor(effective_num_dimensions);
  input_descriptor.set_layout(input_dl)
      .set_feature_map_count(
          GetVectCSize(input_dl) *
          input_shape.dimensions(dnums.input_feature_dimension()))
      .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    // Note that the dimensions are reversed. The same holds below.
    input_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
  }

  FilterDescriptor& filter_descriptor = config.filter_descriptor;
  filter_descriptor = FilterDescriptor(effective_num_dimensions);
  filter_descriptor.set_layout(filter_dl)
      .set_input_feature_map_count(
          GetVectCSize(filter_dl) *
          filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
      .set_output_feature_map_count(
          filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    filter_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
  }

  config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
  config.conv_desc.set_group_count(desc.feature_group_count);
  config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
  for (int dim = 0; dim < num_dimensions; ++dim) {
    config.conv_desc
        .set_zero_padding(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).padding_low())
        .set_filter_stride(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).stride())
        .set_dilation_rate(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).window_dilation());
  }

  BatchDescriptor& output_descriptor = config.output_descriptor;
  output_descriptor = BatchDescriptor(effective_num_dimensions);
  output_descriptor.set_layout(output_dl)
      .set_feature_map_count(
          GetVectCSize(output_dl) *
          output_shape.dimensions(dnums.output_feature_dimension()))
      .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    output_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
  }

  // Add a singleton dimension in the 1D convolution case.
  for (int dim = 0; dim < effective_num_dimensions - num_dimensions; dim++) {
    input_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    output_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    filter_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    config.conv_desc.set_zero_padding(static_cast<DimIndex>(dim), 0)
        .set_filter_stride(static_cast<DimIndex>(dim), 1);
  }

  return config;
}

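// Builds a GpuConvConfig from a cuDNN convolution custom call. The custom
// call returns a (conv result, scratch buffer) tuple, so the result shape is
// tuple element 0 and the scratch size comes from tuple element 1.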
StatusOr<GpuConvConfig> GetGpuConvConfig(
    const HloCustomCallInstruction* cudnn_call) {
  GpuConvDescriptor descriptor;

  TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call));
  TF_ASSIGN_OR_RETURN(descriptor.backend_config,
                      cudnn_call->backend_config<CudnnConvBackendConfig>());
  descriptor.operand0_shape = cudnn_call->operand(0)->shape();
  descriptor.operand1_shape = cudnn_call->operand(1)->shape();
  descriptor.result_shape = cudnn_call->shape().tuple_shapes(0);
  descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
  descriptor.window = cudnn_call->window();
  descriptor.dnums = cudnn_call->convolution_dimension_numbers();
  descriptor.feature_group_count = cudnn_call->feature_group_count();
  return GetGpuConvConfig(descriptor, cudnn_call->ToString());
}

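// Binds the operand and result buffers to the input/filter/output roles
// expected for the given convolution kind; for fused convolutions the bias
// (and optional side input) come from operands 2 and 3.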
StatusOr<GpuConvParams> GetGpuConvParams(
    const GpuConvConfig& config,
    absl::Span<const se::DeviceMemoryBase> operand_buffers,
    se::DeviceMemoryBase result_buffer) {
  GpuConvParams params;
  params.config = &config;

  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      params.input_buf = operand_buffers[0];
      params.filter_buf = operand_buffers[1];
      params.output_buf = result_buffer;
      break;
    case CudnnConvKind::kBackwardInput:
      params.input_buf = result_buffer;
      params.filter_buf = operand_buffers[1];
      params.output_buf = operand_buffers[0];
      break;
    case CudnnConvKind::kBackwardFilter:
      params.input_buf = operand_buffers[0];
      params.filter_buf = result_buffer;
      params.output_buf = operand_buffers[1];
      break;
  }

  if (config.kind == CudnnConvKind::kForwardActivation) {
    params.fusion.emplace();
    GpuConvParams::FusionParams& fusion = *params.fusion;
    fusion.bias_buf = operand_buffers[2];
    if (operand_buffers.size() >= 4) {
      fusion.side_input_buf = operand_buffers[3];
    }
  }

  return params;
}

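// Dispatches on the input element type (and, for S8, the output type) to the
// corresponding typed RunGpuConvImpl instantiation.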
Status RunGpuConv(const gpu::GpuConvConfig& config,
                  absl::Span<const se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::DeviceMemoryBase scratch_memory, se::Stream* stream,
                  RunConvOptions options) {
  TF_ASSIGN_OR_RETURN(GpuConvParams params,
                      GetGpuConvParams(config, operand_buffers, result_buffer));

  PrimitiveType input_primitive_type = config.input_type;
  switch (input_primitive_type) {
    case F16:
      return RunGpuConvImpl<Eigen::half, Eigen::half, Eigen::half>(
          params, stream, scratch_memory, options);
    case BF16:
      return RunGpuConvImpl<Eigen::bfloat16, Eigen::bfloat16, Eigen::bfloat16>(
          params, stream, scratch_memory, options);
    case F32:
      return RunGpuConvImpl<float, float, float>(params, stream,
                                                 scratch_memory, options);
    case F64:
      return RunGpuConvImpl<double, double, double>(params, stream,
                                                    scratch_memory, options);
    case S8: {
      PrimitiveType output_primitive_type = config.output_type;
      switch (output_primitive_type) {
        case F32:
          return RunGpuConvImpl<int8_t, float, float>(params, stream,
                                                      scratch_memory, options);
        case S8:
          return RunGpuConvImpl<int8_t, float, int8_t>(params, stream,
                                                       scratch_memory, options);
        default:
          return Unimplemented("Unimplemented convolution");
      }
    }
    default:
      return Unimplemented("Unimplemented convolution");
  }
}

}  // namespace gpu
}  // namespace xla