1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Implements convolution operations with other kernels baked into the
17 // processing, to optimize latency and memory usage:
18 //  - Conv2D + BiasAdd + <Activation>
19 //  - Conv2D + FusedBatchNorm + <Activation>
20 //
21 // Activation: Relu, Relu6, Elu, etc...
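//
// For example, a kernel fused with ["BiasAdd", "Relu"] computes, in a single
// pass over the output (an illustrative sketch of the semantics rather than
// the exact graph rewrite):
//
//   output = Relu(BiasAdd(Conv2D(input, filter), bias))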
22 //
23 // Kernels for convolutions fused with image transformations (resize and mirror
24 // padding) are defined in `conv_ops_fused_image_transform.cc`.
25 //
26 // For the CPU device we implement fusion with an Eigen tensor contraction
27 // output kernel. For the GPU device we rely on CuDNN primitives.
28 //
29 // NOTE: GPU only supports Conv2D + BiasAdd fused with Relu, Relu6, Elu or LeakyRelu.
30 
31 #ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_
32 #define TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_
33 
34 #define USE_EIGEN_TENSOR
35 #define EIGEN_USE_THREADS
36 
37 #if GOOGLE_CUDA
38 #define EIGEN_USE_GPU
39 #endif  // GOOGLE_CUDA
40 
41 #include <string>
42 #include <utility>
43 #include <vector>
44 
45 #include "absl/strings/str_cat.h"
46 #include "absl/strings/str_join.h"
47 #include "absl/strings/substitute.h"
48 #include "tensorflow/core/framework/bounds_check.h"
49 #include "tensorflow/core/framework/op_kernel.h"
50 #include "tensorflow/core/framework/register_types.h"
51 #include "tensorflow/core/framework/tensor.h"
52 #include "tensorflow/core/framework/tensor_shape.h"
53 #include "tensorflow/core/kernels/conv_2d.h"
54 #include "tensorflow/core/kernels/conv_ops.h"
55 #include "tensorflow/core/kernels/fused_eigen_output_kernels.h"
56 #include "tensorflow/core/kernels/ops_util.h"
57 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
58 #include "tensorflow/core/util/tensor_format.h"
59 #include "tensorflow/core/util/use_cudnn.h"
60 
61 #if GOOGLE_CUDA
62 #include "third_party/gpus/cudnn/cudnn.h"
63 #include "tensorflow/core/kernels/conv_ops_gpu.h"
64 #include "tensorflow/core/platform/stream_executor.h"
65 #include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h"
66 #include "tensorflow/core/util/autotune_maps/conv_parameters.h"
67 #include "tensorflow/core/util/proto/proto_utils.h"
68 #include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
69 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
70 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
71 #endif  // GOOGLE_CUDA
72 
73 namespace tensorflow {
74 
75 typedef Eigen::ThreadPoolDevice CPUDevice;
76 typedef Eigen::GpuDevice GPUDevice;
77 
78 template <typename Device, typename T>
79 struct LaunchFusedConv2DOp {
80   void operator()(OpKernelContext* context, bool use_cudnn,
81                   bool cudnn_use_autotune, const Tensor& input,
82                   const Tensor& filter, FusedComputationType fusion,
83                   const FusedComputationArgs& fusion_args,
84                   const Conv2DParameters& params,
85                   const Conv2DDimensions& dimensions, Tensor* output);
86 };
87 
88 // This is a CPU-only implementation that uses Eigen contraction output kernels.
89 //
90 // Dispatch 2D convolution to the appropriate primitive operation:
91 //   (1) MatMul for the case of 1x1 convolution.
92 //   (2) MatMul for the case when the filter size equals the input size.
93 //   (3) General spatial 2D convolution for all other cases.
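//
// For example (shapes only, as a sketch): with a 1x1 filter and unit strides,
// case (1) views an NHWC input of shape [N, H, W, C_in] as an [N*H*W, C_in]
// matrix and contracts it with the [C_in, C_out] filter matrix; the fused
// bias/activation is applied by the contraction's output kernel while each
// output block is still in cache.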
94 template <typename T>
95 class LaunchFusedConv2DWithOutputKernel {
96  public:
97   LaunchFusedConv2DWithOutputKernel(
98       int row_stride, int col_stride,      //
99       int row_dilation, int col_dilation,  //
100       Padding padding, const std::vector<int64_t>& explicit_paddings)
101       : row_stride_(row_stride),
102         col_stride_(col_stride),
103         row_dilation_(row_dilation),
104         col_dilation_(col_dilation),
105         padding_(padding),
106         explicit_paddings_(explicit_paddings) {}
107 
108   template <typename OutputKernel>
109   void operator()(const OutputKernel& output_kernel, OpKernelContext* ctx,
110                   const Tensor& input, const Tensor& filter, Tensor* output) {
111     // Wrap output_kernel into type erased wrapper to reduce the number of
112     // unique template instantiations for Eigen Tensor contraction expressions.
113     OutputKernelWrapper output_kernel_wrapper(
114         [&output_kernel](
115             const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
116             const Eigen::TensorContractionParams& params, Eigen::Index i,
117             Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) {
118           output_kernel(output_mapper, params, i, j, num_rows, num_cols);
119         });
120 
121     if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 &&
122         row_stride_ == 1 && col_stride_ == 1 && padding_ != EXPLICIT) {
123       int conv_width = 1;  // Width for the convolution step.
124       for (int i = 0; i < 3; ++i) {
125         conv_width *= output->dim_size(i);
126       }
127 
128       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
129       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
130       functor::MatMulConvFunctor<CPUDevice, T, OutputKernelWrapper>()(
131           ctx->eigen_device<CPUDevice>(),
132           output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
133           input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
134           filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
135           dim_pair, std::move(output_kernel_wrapper));
136 
137     } else if (filter.dim_size(0) == input.dim_size(1) &&
138                filter.dim_size(1) == input.dim_size(2) && row_dilation_ == 1 &&
139                col_dilation_ == 1 && padding_ == VALID) {
140       // If the input data and filter have the same height/width,
141       // reduce the 2D convolution to matrix multiplication.
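      //
      // For example (a shape sketch): a VALID convolution of an input
      // [N, H, W, C_in] with a filter of shape [H, W, C_in, C_out] produces a
      // single output position per image, so it is equivalent to multiplying
      // the input reshaped to [N, H*W*C_in] by the filter reshaped to
      // [H*W*C_in, C_out].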
142       const auto k =  // Length of reduction dimension.
143           filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);
144 
145       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
146       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
147       functor::MatMulConvFunctor<CPUDevice, T, OutputKernelWrapper>()(
148           ctx->eigen_device<CPUDevice>(),
149           output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
150           input.shaped<T, 2>({input.dim_size(0), k}),
151           filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair,
152           std::move(output_kernel_wrapper));
153 
154     } else {
155       if (padding_ == EXPLICIT) {
156         functor::SpatialConvolution<CPUDevice, T, OutputKernelWrapper>()(
157             ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
158             input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
159             col_stride_, row_dilation_, col_dilation_,
160             static_cast<int>(explicit_paddings_[2]),
161             static_cast<int>(explicit_paddings_[3]),
162             static_cast<int>(explicit_paddings_[4]),
163             static_cast<int>(explicit_paddings_[5]),
164             std::move(output_kernel_wrapper));
165       } else {
166         functor::SpatialConvolution<CPUDevice, T, OutputKernelWrapper>()(
167             ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
168             input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
169             col_stride_, row_dilation_, col_dilation_,
170             BrainPadding2EigenPadding(padding_),
171             std::move(output_kernel_wrapper));
172       }
173     }
174   }
175 
176  private:
177   // Wrap output_kernel into type erased struct to reduce the number of unique
178   // template instantiations for Eigen Tensor contraction expressions.
179   //
180   // We do not pass std::function directly as an output kernel because it blows
181   // up the binary size in debug mode with super long symbol names.
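  //
  // In other words (a sketch of the rationale): each fusion above
  // (WithBiasAdd<T>, WithBiasAddAndRelu<T>, ...) is a distinct static type, so
  // passing them directly would instantiate a separate copy of the Eigen
  // contraction/convolution expression per fusion. Funneling them all through
  // OutputKernelWrapper instantiates those heavy templates only once per
  // scalar type T, at the cost of an indirect call per output block.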
182   struct OutputKernelWrapper {
183     using OutputKernelFn =
184         std::function<void(const ContractionOutputMapper<T, Eigen::Index>&,
185                            const Eigen::TensorContractionParams&, Eigen::Index,
186                            Eigen::Index, Eigen::Index, Eigen::Index)>;
187 
188     explicit OutputKernelWrapper(OutputKernelFn fn)
189         : output_kernel_fn(std::move(fn)) {}
190 
191     void operator()(
192         const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
193         const Eigen::TensorContractionParams& params, Eigen::Index i,
194         Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const {
195       output_kernel_fn(output_mapper, params, i, j, num_rows, num_cols);
196     }
197 
198     OutputKernelFn output_kernel_fn;
199   };
200 
201   int row_stride_;
202   int col_stride_;
203   int row_dilation_;
204   int col_dilation_;
205   const Padding padding_;
206   const std::vector<int64_t>& explicit_paddings_;
207 };
208 
209 template <typename T>
210 struct LaunchFusedConv2DOp<CPUDevice, T> {
211   void operator()(OpKernelContext* context, bool use_cudnn,
212                   bool cudnn_use_autotune, const Tensor& input,
213                   const Tensor& filter, const FusedComputationType fusion,
214                   const FusedComputationArgs& fusion_args,
215                   const Conv2DParameters& params,
216                   const Conv2DDimensions& dimensions, Tensor* output) {
217     OP_REQUIRES(context, dimensions.in_depth == filter.dim_size(2),
218                 errors::Unimplemented("Fused conv implementation does not "
219                                       "support grouped convolutions for now."));
220     OP_REQUIRES(context, params.data_format == FORMAT_NHWC,
221                 errors::Unimplemented("Fused conv implementation only supports "
222                                       "NHWC tensor format for now."));
223     OP_REQUIRES(context, DataTypeToEnum<T>::value != DT_HALF,
224                 errors::Unimplemented("Fused conv implementation with half "
225                                       "precision is not supported on CPU."));
226 
227     BiasAddArgs<T> bias_add_args;
228     if (BiasAddArgs<T>::IsSupported(fusion)) {
229       if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
230         OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
231                                                 &fusion_args.leakyrelu_alpha));
232       } else {
233         OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
234       }
235     }
236 
237     FusedBatchNormArgs<T> fused_batch_norm_args;
238     if (FusedBatchNormArgs<T>::IsSupported(fusion)) {
239       if (fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu) {
240         OP_REQUIRES_OK(context,
241                        InitFusedBatchNormArgs(context, fusion_args.epsilon,
242                                               &fused_batch_norm_args,
243                                               &fusion_args.leakyrelu_alpha));
244       } else {
245         OP_REQUIRES_OK(context,
246                        InitFusedBatchNormArgs(context, fusion_args.epsilon,
247                                               &fused_batch_norm_args));
248       }
249     }
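    // At inference time FusedBatchNorm is an affine transform of the
    // convolution output; assuming the standard formula, the fused output
    // kernels below compute
    //
    //   y = scale * (conv_out - mean) / sqrt(variance + epsilon) + offset
    //
    // optionally followed by the activation, which is why only epsilon and the
    // per-channel scale/offset/mean/variance vectors are needed here.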
250 
251     LaunchFusedConv2DWithOutputKernel<T> conv2d(
252         dimensions.stride_rows, dimensions.stride_cols,
253         dimensions.dilation_rows, dimensions.dilation_cols, params.padding,
254         params.explicit_paddings);
255 
256     switch (fusion) {
257       case FusedComputationType::kUndefined:
258         OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
259         break;
260       case FusedComputationType::kBiasAddWithGeluApproximate:
261         OP_REQUIRES_OK(context, errors::Internal("Fusion type is unsupported"));
262         break;
263       case FusedComputationType::kBiasAdd:
264         conv2d(WithBiasAdd<T>(bias_add_args), context, input, filter, output);
265         break;
266       case FusedComputationType::kBiasAddWithRelu:
267         conv2d(WithBiasAddAndRelu<T>(bias_add_args), context, input, filter,
268                output);
269         break;
270       case FusedComputationType::kBiasAddWithRelu6:
271         conv2d(WithBiasAddAndRelu6<T>(bias_add_args), context, input, filter,
272                output);
273         break;
274       case FusedComputationType::kBiasAddWithLeakyRelu:
275         conv2d(WithBiasAddAndLeakyRelu<T>(bias_add_args), context, input,
276                filter, output);
277         break;
278       case FusedComputationType::kBiasAddWithElu:
279         conv2d(WithBiasAddAndElu<T>(bias_add_args), context, input, filter,
280                output);
281         break;
282       case FusedComputationType::kFusedBatchNorm:
283         conv2d(
284             WithFusedBatchNorm<T>(fusion_args.epsilon, fused_batch_norm_args),
285             context, input, filter, output);
286         break;
287       case FusedComputationType::kFusedBatchNormWithRelu:
288         conv2d(WithFusedBatchNormAndRelu<T>(fusion_args.epsilon,
289                                             fused_batch_norm_args),
290                context, input, filter, output);
291         break;
292       case FusedComputationType::kFusedBatchNormWithRelu6:
293         conv2d(WithFusedBatchNormAndRelu6<T>(fusion_args.epsilon,
294                                              fused_batch_norm_args),
295                context, input, filter, output);
296         break;
297       case FusedComputationType::kFusedBatchNormWithLeakyRelu:
298         conv2d(WithFusedBatchNormAndLeakyRelu<T>(fusion_args.epsilon,
299                                                  fused_batch_norm_args),
300                context, input, filter, output);
301         break;
302       case FusedComputationType::kFusedBatchNormWithElu:
303         conv2d(WithFusedBatchNormAndElu<T>(fusion_args.epsilon,
304                                            fused_batch_norm_args),
305                context, input, filter, output);
306         break;
307     }
308   }
309 };
310 
311 #if GOOGLE_CUDA
312 
313 inline int64_t ConvolveScratchSize() {
314   static int64_t convolve_scratch_size = GetDnnWorkspaceLimit(
315       // default value is in bytes despite the name of the environment variable
316       "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
317   );
318   return convolve_scratch_size;
319 }
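//
// For example, exporting TF_CUDNN_WORKSPACE_LIMIT_IN_MB=1024 caps the cuDNN
// scratch space at 1 GiB (the variable is read in MiB, as its name suggests),
// while the 1LL << 32 default above is expressed in bytes (4 GiB).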
320 
321 template <typename T>
322 struct LaunchFusedConv2DOp<GPUDevice, T> {
323   void operator()(OpKernelContext* context, bool use_cudnn,
324                   bool cudnn_use_autotune, const Tensor& input_param,
325                   const Tensor& filter, FusedComputationType fusion,
326                   const FusedComputationArgs& fusion_args,
327                   const Conv2DParameters& params,
328                   const Conv2DDimensions& dimensions, Tensor* output) {
329     OP_REQUIRES(
330         context,
331         params.data_format == FORMAT_NHWC || params.data_format == FORMAT_NCHW,
332         errors::Unimplemented("Fused conv implementation only supports "
333                               "NHWC and NCHW tensor formats for now."));
334 
335     auto* stream = context->op_device_context()->stream();
336     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
337     OP_REQUIRES(
338         context, use_cudnn,
339         errors::Unimplemented("FusedConv2D for GPU is not currently supported "
340                               "without cudnn"));
341 
342     bool is_supported_activation =
343         fusion == FusedComputationType::kBiasAddWithRelu ||
344         fusion == FusedComputationType::kBiasAddWithRelu6 ||
345         fusion == FusedComputationType::kBiasAddWithElu ||
346         fusion == FusedComputationType::kBiasAddWithLeakyRelu;
347     OP_REQUIRES(
348         context, is_supported_activation,
349         errors::Unimplemented("FusedConv2D implementation only supports "
350                               "fusing with `BiasAdd + Relu|Relu6|Elu|LeakyRelu`"
351                               " for now."));
352 
353     Tensor input = input_param;
354 
355     const int64_t in_batch = GetTensorDim(input, params.data_format, 'N');
356     int64_t in_rows = GetTensorDim(input, params.data_format, 'H');
357     int64_t in_cols = GetTensorDim(input, params.data_format, 'W');
358     const int64_t in_depths = GetTensorDim(input, params.data_format, 'C');
359 
360     const int64_t patch_rows = filter.dim_size(0);
361     const int64_t patch_cols = filter.dim_size(1);
362     const int64_t patch_depths = filter.dim_size(2);
363 
364     const int64_t out_batch = GetTensorDim(*output, params.data_format, 'N');
365     const int64_t out_rows = GetTensorDim(*output, params.data_format, 'H');
366     const int64_t out_cols = GetTensorDim(*output, params.data_format, 'W');
367     const int64_t out_depths = GetTensorDim(*output, params.data_format, 'C');
368 
369     // Bias of the following dimensions: [ output_depth ]
370     const Tensor& bias = context->input(2);
371     OP_REQUIRES(context, bias.dims() == 1,
372                 errors::InvalidArgument("bias must be 1-dimensional",
373                                         bias.shape().DebugString()));
374     OP_REQUIRES(context, bias.dim_size(0) == out_depths,
375                 errors::InvalidArgument("bias depth must be equal to out depth",
376                                         bias.shape().DebugString()));
377 
378     const int64_t common_padding_rows =
379         std::min(dimensions.pad_rows_before, dimensions.pad_rows_after);
380     const int64_t common_padding_cols =
381         std::min(dimensions.pad_cols_before, dimensions.pad_cols_after);
382     if (dimensions.pad_rows_before != dimensions.pad_rows_after ||
383         dimensions.pad_cols_before != dimensions.pad_cols_after) {
384       // cuDNN only supports padding the same amount on the left and right
385       // sides, and on the top and bottom sides. So we manually create a new
386       // padded input tensor such that we can pass it to cuDNN.
387 
388       // TODO(reedwm): In some cases, we can avoid an allocation even if the two
389       // padding sides are different. For example, if the input is 2x2, the
390       // filter is 1x1, the stride is 2, and the padding is (1, 0, 1, 0), the
391       // result is equivalent to as if the padding is (1, 1, 1, 1). Changing the
392       // padding in such a way would allow us to avoid the allocation.
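      //
      // Worked example (a sketch): with pad_rows_before = 1 and
      // pad_rows_after = 2, common_padding_rows is 1, so a single extra row of
      // zeros is appended at the bottom of the input here, and the remaining
      // symmetric padding of one row on each side is passed to cuDNN through
      // the convolution descriptor below.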
393       Tensor transformed_input;
394       const int64_t padding_rows_diff =
395           std::abs(dimensions.pad_rows_after - dimensions.pad_rows_before);
396       const int64_t padding_cols_diff =
397           std::abs(dimensions.pad_cols_after - dimensions.pad_cols_before);
398       const int64_t new_in_rows = in_rows + padding_rows_diff;
399       const int64_t new_in_cols = in_cols + padding_cols_diff;
400       OP_REQUIRES_OK(context,
401                      context->allocate_temp(
402                          DataTypeToEnum<T>::value,
403                          ShapeFromFormat(params.data_format, in_batch,
404                                          new_in_rows, new_in_cols, in_depths),
405                          &transformed_input));
406       const int64_t input_pad_top =
407           dimensions.pad_rows_before - common_padding_rows;
408       const int64_t input_pad_bottom =
409           dimensions.pad_rows_after - common_padding_rows;
410       const int64_t input_pad_left =
411           dimensions.pad_cols_before - common_padding_cols;
412       const int64_t input_pad_right =
413           dimensions.pad_cols_after - common_padding_cols;
414       bool in_bounds =
415           FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
416           FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
417           FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
418           FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
419       if (!in_bounds) {
420         context->SetStatus(errors::InvalidArgument("Padding is too large."));
421         return;
422       }
423       functor::PadInput<GPUDevice, T, int, 4>()(
424           context->eigen_device<GPUDevice>(),
425           To32Bit(input_param.tensor<T, 4>()),
426           {{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
427           {{static_cast<int>(input_pad_bottom),
428             static_cast<int>(input_pad_right)}},
429           To32Bit(transformed_input.tensor<T, 4>()), params.data_format, T{});
430       input = transformed_input;
431       in_rows = new_in_rows;
432       in_cols = new_in_cols;
433     }
434 
435     const bool compute_in_nhwc = DataTypeToEnum<T>::value == DT_HALF &&
436                                  stream->GetCudaComputeCapability().IsAtLeast(
437                                      se::CudaComputeCapability::VOLTA);
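    // fp16 convolutions on Volta and newer GPUs are computed in NHWC, where
    // cuDNN's Tensor Core kernels perform best; every other case is transposed
    // to NCHW below.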
438     if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) {
439       // Convert the input tensor from NHWC to NCHW.
440       TensorShape nchw_shape =
441           ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
442       if (in_depths > 1) {
443         Tensor transformed_input;
444         OP_REQUIRES_OK(context,
445                        context->allocate_temp(DataTypeToEnum<T>::value,
446                                               nchw_shape, &transformed_input));
447         functor::NHWCToNCHW<GPUDevice, T, 4>()(
448             context->eigen_device<GPUDevice>(),
449             const_cast<const Tensor&>(input).tensor<T, 4>(),
450             transformed_input.tensor<T, 4>());
451         input = transformed_input;
452       } else {
453         // If depth <= 1, then just reshape.
454         CHECK(input.CopyFrom(input, nchw_shape));  // Crash OK
455       }
456     }
457 
458     CHECK(common_padding_rows >= 0) << "Negative padding rows";  // Crash OK
459     CHECK(common_padding_cols >= 0) << "Negative padding cols";  // Crash OK
460 
461     se::dnn::ActivationMode dnn_activation_mode;
462     switch (fusion) {
463       case FusedComputationType::kBiasAddWithRelu:
464         dnn_activation_mode = se::dnn::ActivationMode::kRelu;
465         break;
466       case FusedComputationType::kBiasAddWithRelu6:
467         dnn_activation_mode = se::dnn::ActivationMode::kRelu6;
468         break;
469       case FusedComputationType::kBiasAddWithElu:
470         dnn_activation_mode = se::dnn::ActivationMode::kElu;
471         break;
472       case FusedComputationType::kBiasAddWithLeakyRelu:
473         dnn_activation_mode = se::dnn::ActivationMode::kLeakyRelu;
474         break;
475       default:
476         LOG(FATAL) << "Unsupported fusion type";  // Crash OK
477     }
478 
479     const TensorFormat compute_data_format =
480         compute_in_nhwc ? FORMAT_NHWC : FORMAT_NCHW;
481     constexpr auto kComputeInNHWC =
482         std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
483                         se::dnn::FilterLayout::kOutputYXInput);
484     constexpr auto kComputeInNCHW =
485         std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
486                         se::dnn::FilterLayout::kOutputInputYX);
487     se::dnn::DataLayout compute_data_layout;
488     se::dnn::FilterLayout filter_layout;
489     std::tie(compute_data_layout, filter_layout) =
490         compute_in_nhwc ? kComputeInNHWC : kComputeInNCHW;
491 
492     se::dnn::BatchDescriptor input_desc;
493     input_desc.set_count(in_batch)
494         .set_feature_map_count(in_depths)
495         .set_height(in_rows)
496         .set_width(in_cols)
497         .set_layout(compute_data_layout);
498     se::dnn::FilterDescriptor filter_desc;
499     filter_desc.set_input_filter_height(patch_rows)
500         .set_input_filter_width(patch_cols)
501         .set_input_feature_map_count(patch_depths)
502         .set_output_feature_map_count(filter.dim_size(3))
503         .set_layout(filter_layout);
504     se::dnn::BatchDescriptor bias_desc;
505     bias_desc.set_count(1)
506         .set_height(1)
507         .set_width(1)
508         .set_feature_map_count(out_depths)
509         .set_layout(compute_data_layout);
510     se::dnn::ConvolutionDescriptor conv_desc;
511     conv_desc.set_vertical_dilation_rate(dimensions.dilation_rows)
512         .set_horizontal_dilation_rate(dimensions.dilation_cols)
513         .set_vertical_filter_stride(dimensions.stride_rows)
514         .set_horizontal_filter_stride(dimensions.stride_cols)
515         .set_zero_padding_height(common_padding_rows)
516         .set_zero_padding_width(common_padding_cols)
517         .set_group_count(in_depths / patch_depths);
518     se::dnn::BatchDescriptor output_desc;
519     output_desc.set_count(out_batch)
520         .set_height(out_rows)
521         .set_width(out_cols)
522         .set_feature_map_count(out_depths)
523         .set_layout(compute_data_layout);
524 
525     Tensor transformed_filter;
526     const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
527       VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
528               << " to " << ToString(dst_format);
529 
530       TensorShape dst_shape =
531           dst_format == FORMAT_OIHW
532               ? TensorShape({filter.dim_size(3), filter.dim_size(2),
533                              filter.dim_size(0), filter.dim_size(1)})
534               : TensorShape({filter.dim_size(3), filter.dim_size(0),
535                              filter.dim_size(1), filter.dim_size(2)});
536 
537       TF_RETURN_IF_ERROR(context->allocate_temp(
538           DataTypeToEnum<T>::value, dst_shape, &transformed_filter));
539       functor::TransformFilter<GPUDevice, T, int, 4>()(
540           context->eigen_device<GPUDevice>(), dst_format,
541           To32Bit(filter.tensor<T, 4>()),
542           To32Bit(transformed_filter.tensor<T, 4>()));
543 
544       return OkStatus();
545     };
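    // For example (shapes only): an HWIO filter of shape [3, 3, 64, 128]
    // becomes [128, 64, 3, 3] for FORMAT_OIHW (the NCHW compute path) or
    // [128, 3, 3, 64] for FORMAT_OHWI (the NHWC compute path), matching the
    // filter_layout selected above.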
546 
547     if (compute_in_nhwc) {
548       OP_REQUIRES_OK(context, transform_filter(FORMAT_OHWI));
549     } else {
550       OP_REQUIRES_OK(context, transform_filter(FORMAT_OIHW));
551     }
552 
553     Tensor transformed_output;
554     if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) {
555       // Only allocate temporary memory when a layout transformation is needed.
556       OP_REQUIRES_OK(context,
557                      context->allocate_temp(
558                          DataTypeToEnum<T>::value,
559                          ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
560                                          out_cols, out_depths),
561                          &transformed_output));
562     } else {
563       transformed_output = *output;
564     }
565 
566     const auto tensor_on_device = [](const Tensor& t) -> se::DeviceMemory<T> {
567       return AsDeviceMemory(t.template flat<T>().data(),
568                             t.template flat<T>().size());
569     };
570 
571     se::DeviceMemory<T> input_ptr = tensor_on_device(input);
572     se::DeviceMemory<T> filter_ptr = tensor_on_device(transformed_filter);
573     se::DeviceMemory<T> bias_ptr = tensor_on_device(bias);
574     se::DeviceMemory<T> output_ptr = tensor_on_device(transformed_output);
575 
576     // We do not use side inputs, so we can safely pass nullptr.
577     se::DeviceMemory<T> side_input_ptr =
578         AsDeviceMemory(static_cast<T*>(nullptr), 0);
579 
580     constexpr double kConvScale = 1.0;
581     constexpr double kSideInputScale = 0.0;
582     double leakyrelu_alpha = fusion_args.leakyrelu_alpha;
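    // With these values the fused op computes (a sketch of the
    // cudnnConvolutionBiasActivationForward semantics)
    //
    //   output = activation(kConvScale * conv(input, filter) +
    //                       kSideInputScale * side_input + bias)
    //
    // i.e. a plain convolution plus bias plus activation, since no side input
    // is used here.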
583 
584     int device_id = stream->parent()->device_ordinal();
585     DataType dtype = input.dtype();
586     ConvParameters conv_parameters = {
587         in_batch,                      // batch
588         in_depths,                     // in_depths
589         {{in_rows,                     // in_rows
590           in_cols}},                   // in_cols
591         compute_data_format,           // compute_data_format
592         out_depths,                    // out_depths
593         {{patch_rows,                  // filter_rows
594           patch_cols,                  // filter_cols
595           patch_depths}},              // filter_depths
596         {{dimensions.dilation_rows,    // dilation_rows
597           dimensions.dilation_cols}},  // dilation_cols
598         {{dimensions.stride_rows,      // stride_rows
599           dimensions.stride_cols}},    // stride_cols
600         {{common_padding_rows,         // padding_rows
601           common_padding_cols}},       // padding_cols
602         dtype,                         // tensor datatype
603         device_id,                     // device_id
604         conv_desc.group_count(),
605         ConvParameters::FusionInfo{kConvScale, kSideInputScale, leakyrelu_alpha,
606                                    dnn_activation_mode,  // activation_mode
607                                    /*is_contrib=*/false}};
608 
609     se::dnn::DataType element_type = se::dnn::ToDataType<T>::value;
610 
611     auto entry_or = AutotuneFusedConv<T>(
612         cudnn_use_autotune, FusedConvAutotuneMap::GetInstance(),
613         conv_parameters, context, input_desc, filter_desc, bias_desc,
614         output_desc, conv_desc, dnn_activation_mode, kConvScale,
615         kSideInputScale, leakyrelu_alpha, input_ptr, filter_ptr, output_ptr,
616         bias_ptr, side_input_ptr, ConvolveScratchSize());
617     OP_REQUIRES_OK(context, entry_or.status());
618     auto autotune_entry = std::move(entry_or).value();
619 
620     DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
621     Status cudnn_launch_status;
622     if (!autotune_entry.is_algorithm_config()) {
623       auto& runners = autotune_entry.GetOpRunners();
624       se::dnn::FusedConvOp::Config config{se::dnn::ConvolutionKind::FORWARD,
625                                           element_type,
626                                           element_type,
627                                           element_type,
628                                           kConvScale,
629                                           kSideInputScale,
630                                           leakyrelu_alpha,
631                                           input_desc,
632                                           filter_desc,
633                                           bias_desc,
634                                           output_desc,
635                                           conv_desc,
636                                           dnn_activation_mode};
637       auto primary_or = runners.primary->GetOrCreateRunner(config, stream);
638       OP_REQUIRES_OK(context, primary_or.status());
639       auto* primary = primary_or.ValueOrDie();
640 
641       const se::dnn::FusedConvRunner* no_scratch_fallback = nullptr;
642       if (runners.no_scratch_fallback) {
643         auto no_scratch_fallback_or =
644             runners.no_scratch_fallback->GetOrCreateRunner(config, stream);
645         OP_REQUIRES_OK(context, no_scratch_fallback_or.status());
646         no_scratch_fallback = no_scratch_fallback_or.ValueOrDie();
647       }
648 
649       auto runner_and_scratch_or =
650           AllocateScratchOrFallback<se::dnn::FusedConvOp::Signature>(
651               &scratch_allocator, primary, no_scratch_fallback);
652       OP_REQUIRES_OK(context, runner_and_scratch_or.status());
653       auto runner_and_scratch = std::move(runner_and_scratch_or).value();
654       auto& runner =
655           *std::get<const se::dnn::FusedConvRunner*>(runner_and_scratch);
656       cudnn_launch_status = runner(
657           stream, nullptr, std::get<se::DeviceMemoryBase>(runner_and_scratch),
658           input_ptr, filter_ptr, side_input_ptr, bias_ptr, output_ptr);
659     } else {
660       cudnn_launch_status = stream->FusedConvolveWithAlgorithm(
661           input_desc, input_ptr,            // input
662           kConvScale,                       // input_scale
663           filter_desc, filter_ptr,          // filter
664           conv_desc,                        // conv
665           side_input_ptr, kSideInputScale,  // side_input
666           bias_desc, bias_ptr,              // bias
667           dnn_activation_mode,              // activation
668           output_desc, &output_ptr,         // output
669           &scratch_allocator, autotune_entry.GetAlgorithmConfig(), nullptr);
670     }
671 
672     OP_REQUIRES_OK(context, cudnn_launch_status);
673 
674     // Convert the output tensor back from NCHW to NHWC.
675     if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) {
676       functor::NCHWToNHWC<GPUDevice, T, 4>()(
677           context->eigen_device<GPUDevice>(),
678           const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
679           output->tensor<T, 4>());
680     }
681   }
682 };
683 
684 #endif  // GOOGLE_CUDA
685 
686 template <typename Device, typename T>
687 class FusedConv2DOp : public OpKernel {
688  public:
689   explicit FusedConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
690     OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
691 
692     OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
693     cudnn_use_autotune_ = CudnnUseAutotune();
694 
695     using FCT = FusedComputationType;
696 
697     std::vector<FusedComputationPattern> patterns;
698     if (std::is_same<Device, CPUDevice>::value) {
699       patterns = {
700           {FCT::kBiasAdd, {"BiasAdd"}},
701           {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
702           {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
703           {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
704           {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
705           {FCT::kFusedBatchNorm, {"FusedBatchNorm"}},
706           {FCT::kFusedBatchNormWithRelu, {"FusedBatchNorm", "Relu"}},
707           {FCT::kFusedBatchNormWithRelu6, {"FusedBatchNorm", "Relu6"}},
708           {FCT::kFusedBatchNormWithElu, {"FusedBatchNorm", "Elu"}},
709           {FCT::kFusedBatchNormWithLeakyRelu, {"FusedBatchNorm", "LeakyRelu"}},
710       };
711     }
712 
713     // NOTE(ezhulenev): Although CuDNN `cudnnConvolutionBiasActivationForward`
714     // supports the identity activation function, which in theory should allow
715     // fusing convolution with just BiasAdd, in practice it does not work: cuDNN
716     // ignores this parameter and always applies Relu activation.
717     if (std::is_same<Device, GPUDevice>::value) {
718       patterns = {
719           {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
720           {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
721           {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
722           {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
723       };
724     }
725 
726     OP_REQUIRES_OK(context, InitializeFusedComputation(
727                                 context, "Conv2D", patterns,
728                                 &fused_computation_, &fused_computation_args_));
729   }
730 
731   void Compute(OpKernelContext* context) override {
732     // Input tensor is of the following dimensions:
733     // [ batch, in_rows, in_cols, in_depth ]
734     const Tensor& input = context->input(0);
735 
736     // Input filter is of the following dimensions:
737     // [ filter_rows, filter_cols, in_depth, out_depth]
738     const Tensor& filter = context->input(1);
739 
740     Conv2DDimensions dimensions;
741     OP_REQUIRES_OK(context,
742                    ComputeConv2DDimension(params_, input, filter, &dimensions));
743 
744     TensorShape out_shape = ShapeFromFormat(
745         params_.data_format, dimensions.batch, dimensions.out_rows,
746         dimensions.out_cols, dimensions.out_depth);
747 
748     // Output tensor is of the following dimensions:
749     // [ in_batch, out_rows, out_cols, out_depth ]
750     Tensor* output = nullptr;
751     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
752 
753     VLOG(2) << "FusedConv2D: in_depth = " << dimensions.in_depth
754             << ", patch_depth = " << dimensions.patch_depth
755             << ", input_cols = " << dimensions.input_cols
756             << ", filter_cols = " << dimensions.filter_cols
757             << ", input_rows = " << dimensions.input_rows
758             << ", filter_rows = " << dimensions.filter_rows
759             << ", stride_rows = " << dimensions.stride_rows
760             << ", stride_cols = " << dimensions.stride_cols
761             << ", dilation_rows = " << dimensions.dilation_rows
762             << ", dilation_cols = " << dimensions.dilation_cols
763             << ", out_depth = " << dimensions.out_depth;
764 
765     // If there is nothing to compute, return.
766     if (out_shape.num_elements() == 0) {
767       return;
768     }
769 
770     LaunchFusedConv2DOp<Device, T>()(context, use_cudnn_, cudnn_use_autotune_,
771                                      input, filter, fused_computation_,
772                                      fused_computation_args_, params_,
773                                      dimensions, output);
774   }
775 
776  private:
777   Conv2DParameters params_;
778   bool use_cudnn_;
779   bool cudnn_use_autotune_;
780 
781   FusedComputationType fused_computation_ = FusedComputationType::kUndefined;
782   FusedComputationArgs fused_computation_args_;
783 
784   TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DOp);
785 };
786 
787 // Registration of the CPU implementations.
788 #define REGISTER_FUSED_CPU_CONV2D(T)                                  \
789   REGISTER_KERNEL_BUILDER(                                            \
790       Name("_FusedConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
791       FusedConv2DOp<CPUDevice, T>);
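
// A minimal usage sketch: the .cc files that include this header are expected
// to invoke this macro once per supported element type, e.g.
//
//   REGISTER_FUSED_CPU_CONV2D(float);
//
// which registers FusedConv2DOp<CPUDevice, float> as the CPU kernel for the
// _FusedConv2D op.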
792 
793 #if GOOGLE_CUDA
794 
795 #define DECLARE_FUNCTOR_GPU_SPEC(T)                                     \
796   template <>                                                           \
797   void TransformFilter<GPUDevice, T, int, 4>::operator()(               \
798       const GPUDevice& d, FilterTensorFormat dst_filter_format,         \
799       typename TTypes<T, 4, int>::ConstTensor in,                       \
800       typename TTypes<T, 4, int>::Tensor out);                          \
801   extern template struct TransformFilter<GPUDevice, T, int, 4>;         \
802   template <>                                                           \
803   void PadInput<GPUDevice, T, int, 4>::operator()(                      \
804       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,   \
805       const std::array<int, 2>& padding_left,                           \
806       const std::array<int, 2>& padding_right,                          \
807       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
808       const T& padding_value);                                          \
809   extern template struct PadInput<GPUDevice, T, int, 4>
810 
811 // Registration of the GPU implementations.
812 #define REGISTER_FUSED_GPU_CONV2D(T)                                  \
813   REGISTER_KERNEL_BUILDER(                                            \
814       Name("_FusedConv2D").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
815       FusedConv2DOp<GPUDevice, T>);
816 
817 #endif  // GOOGLE_CUDA
818 
819 }  // namespace tensorflow
820 
821 #endif  // TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_
822