1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 
18 #define USE_EIGEN_TENSOR
19 #define EIGEN_USE_THREADS
20 
21 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
22 #define EIGEN_USE_GPU
23 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
24 
25 #include "tensorflow/core/kernels/conv_ops.h"
26 
27 #include <string.h>
28 
29 #include <atomic>
30 #include <map>
31 #include <utility>
32 #include <vector>
33 
34 #include "absl/synchronization/blocking_counter.h"
35 #include "tensorflow/core/framework/allocator.h"
36 #include "tensorflow/core/framework/bounds_check.h"
37 #include "tensorflow/core/framework/kernel_shape_util.h"
38 #include "tensorflow/core/framework/numeric_op.h"
39 #include "tensorflow/core/framework/op_kernel.h"
40 #include "tensorflow/core/framework/register_types.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor_shape.h"
43 #include "tensorflow/core/framework/tensor_slice.h"
44 #include "tensorflow/core/framework/types.h"
45 #include "tensorflow/core/kernels/conv_2d.h"
46 #include "tensorflow/core/kernels/deep_conv2d.h"
47 #include "tensorflow/core/kernels/fill_functor.h"
48 #include "tensorflow/core/kernels/ops_util.h"
49 #include "tensorflow/core/lib/core/errors.h"
50 #include "tensorflow/core/lib/gtl/array_slice.h"
51 #include "tensorflow/core/lib/strings/numbers.h"
52 #include "tensorflow/core/lib/strings/str_util.h"
53 #include "tensorflow/core/platform/logging.h"
54 #include "tensorflow/core/platform/macros.h"
55 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
56 #include "tensorflow/core/util/padding.h"
57 #include "tensorflow/core/util/tensor_format.h"
58 #include "tensorflow/core/util/use_cudnn.h"
59 
60 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
61 #include "tensorflow/core/kernels/xsmm_conv2d.h"
62 #endif
63 
64 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
65 #include "tensorflow/core/kernels/conv_ops_gpu.h"
66 #include "tensorflow/core/platform/stream_executor.h"
67 #include "tensorflow/core/protobuf/autotuning.pb.h"
68 #include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h"
69 #include "tensorflow/core/util/autotune_maps/conv_parameters.h"
70 #include "tensorflow/core/util/proto/proto_utils.h"
71 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
72 #if GOOGLE_CUDA
73 #include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
74 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
75 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
76 #endif  // GOOGLE_CUDA
77 
78 namespace tensorflow {
79 
80 typedef Eigen::ThreadPoolDevice CPUDevice;
81 typedef Eigen::GpuDevice GPUDevice;
82 
83 namespace {
84 template <typename Device, typename T>
85 struct LaunchGeneric {
86   void operator()(OpKernelContext* ctx, const Tensor& input,
87                   const Tensor& filter, int row_stride, int col_stride,
88                   int row_dilation, int col_dilation, const Padding& padding,
89                   const std::vector<int64_t>& explicit_paddings, Tensor* output,
90                   TensorFormat data_format) {
91     CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
92                                          "supports NHWC tensor format for now.";
93     if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
94         col_stride == 1 && (padding == SAME || padding == VALID)) {
95       // For 1x1 kernel, the 2D convolution is reduced to matrix
96       // multiplication.
97       //
98       // TODO(vrv): We should be able to call SpatialConvolution
99       // and it will produce the same result, but doing so
100       // led to NaNs during training.  Using matmul instead for now.
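      // For example (shapes are illustrative only): an NHWC input of
      // [2, 32, 32, 16] with a 1x1x16x64 filter becomes a single
      // [2*32*32, 16] x [16, 64] matmul, whose [2048, 64] result is the
      // output viewed as [2, 32, 32, 64].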
101       int conv_width = 1;  // Width for the convolution step.
102       for (int i = 0; i < 3; ++i) {
103         conv_width *= output->dim_size(i);
104       }
105 
106       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
107       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
108       functor::MatMulConvFunctor<Device, T>()(
109           ctx->eigen_device<Device>(),
110           output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
111           input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
112           filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
113           dim_pair);
114     } else if (filter.dim_size(0) == input.dim_size(1) &&
115                filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
116                col_dilation == 1 && padding == VALID) {
117       // If the input data and filter have the same height/width,
118       // the 2D convolution is reduced to matrix multiplication.
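      // For example (illustrative shapes): a [4, 7, 7, 8] input with a
      // 7x7x8x16 filter and VALID padding collapses to a [4, 392] x [392, 16]
      // matmul, i.e. one output value per (batch, output channel) pair.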
119       const int k =  // Length of reduction dimension.
120           filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);
121 
122       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
123       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
124       functor::MatMulConvFunctor<Device, T>()(
125           ctx->eigen_device<Device>(),
126           output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
127           input.shaped<T, 2>({input.dim_size(0), k}),
128           filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair);
129     } else {
130       if (padding == EXPLICIT) {
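        // With NHWC (checked above), explicit_paddings stores (before, after)
        // pairs per dimension, so indices 2-5 are the top/bottom/left/right
        // spatial paddings forwarded to SpatialConvolution.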
131         functor::SpatialConvolution<Device, T>()(
132             ctx->eigen_device<Device>(), output->tensor<T, 4>(),
133             input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
134             row_dilation, col_dilation, static_cast<int>(explicit_paddings[2]),
135             static_cast<int>(explicit_paddings[3]),
136             static_cast<int>(explicit_paddings[4]),
137             static_cast<int>(explicit_paddings[5]));
138       } else {
139         functor::SpatialConvolution<Device, T>()(
140             ctx->eigen_device<Device>(), output->tensor<T, 4>(),
141             input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
142             row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
143       }
144     }
145   }
146 };
147 
148 // Compute grouped 2D convolutions on CPU. Unlike the grouped convolution
149 // implementation in cuDNN, this is far from optimal and needs more work
150 // to deliver competitive performance. Currently it exists to close the feature
151 // parity gap between convolution operations on different devices.
152 template <typename T>
153 struct LaunchGrouped {
154   void operator()(OpKernelContext* ctx, const Tensor& input,
155                   const Tensor& filter, int row_stride, int col_stride,
156                   int row_dilation, int col_dilation, const Padding& padding,
157                   const std::vector<int64_t>& explicit_paddings, Tensor* output,
158                   TensorFormat data_format) {
159     DCHECK(data_format == FORMAT_NHWC)
160         << "Grouped conv implementation only "
161            "supports NHWC tensor format for now.";
162 
163     const int64_t in_depth = input.dim_size(3);
164     const int64_t patch_depth = filter.dim_size(2);
165     const int64_t num_groups = in_depth / patch_depth;
166 
167     // Shuffle input/filter tensors to have group as a leading dimension.
168     std::array<int64_t, 5> shuffle({3, 0, 1, 2, 4});
169 
170     // Compute pre-shuffle dimensions.
171     auto pre_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
172       return {tensor.dim_size(0), tensor.dim_size(1), tensor.dim_size(2),
173               num_groups, tensor.dim_size(3) / num_groups};
174     };
175 
176     // Compute post-shuffle dimensions.
177     auto post_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
178       return {num_groups, tensor.dim_size(0), tensor.dim_size(1),
179               tensor.dim_size(2), tensor.dim_size(3) / num_groups};
180     };
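    // For example, with in_depth = 8 and patch_depth = 2 there are 4 groups:
    // an input of shape [N, H, W, 8] is viewed as [N, H, W, 4, 2] (pre_shuffle)
    // and permuted to [4, N, H, W, 2] (post_shuffle), so chipping along
    // dimension 0 below yields one independent convolution per group.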
181 
182     auto& device = ctx->eigen_device<CPUDevice>();
183 
184     absl::BlockingCounter shuffles_completed(2);
185     auto on_shuffled = [&]() { shuffles_completed.DecrementCount(); };
186 
187     // Shuffle input into temporary tensor.
188     Tensor input_shuffled;
189     OP_REQUIRES_OK(
190         ctx, ctx->allocate_temp(input.dtype(), TensorShape(post_shuffle(input)),
191                                 &input_shuffled));
192     input_shuffled.tensor<T, 5>().device(device, on_shuffled) =
193         input.shaped<T, 5>(pre_shuffle(input)).shuffle(shuffle);
194 
195     // Shuffle filter into temporary tensor.
196     Tensor filter_shuffled;
197     OP_REQUIRES_OK(ctx, ctx->allocate_temp(filter.dtype(),
198                                            TensorShape(post_shuffle(filter)),
199                                            &filter_shuffled));
200     filter_shuffled.tensor<T, 5>().device(device, on_shuffled) =
201         filter.shaped<T, 5>(pre_shuffle(filter)).shuffle(shuffle);
202 
203     // Wait for the completion of input/filter shuffles.
204     shuffles_completed.Wait();
205 
206     // Write group convolution results into temporary output tensor.
207     Tensor output_shuffled;
208     OP_REQUIRES_OK(ctx, ctx->allocate_temp(output->dtype(),
209                                            TensorShape(post_shuffle(*output)),
210                                            &output_shuffled));
211 
212     for (int64_t i = 0; i < num_groups; ++i) {
213       // TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor
214       // will lead to deadlock, SpatialConvolution has to use async Eigen
215       // assignment). This requires small changes to Eigen to support async
216       // execution for the tensor chipping operation.
217 
218       // TODO(ezhulenev): Grouped convolution should also support 1x1 filter
219       // optimization.
220 
221       auto input_slice = input_shuffled.tensor<T, 5>().template chip<0>(i);
222       auto filter_slice = filter_shuffled.tensor<T, 5>().template chip<0>(i);
223       auto output_slice = output_shuffled.tensor<T, 5>().template chip<0>(i);
224 
225       if (padding == EXPLICIT) {
226         functor::SpatialConvolution<CPUDevice, T>()(
227             ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
228             filter_slice, row_stride, col_stride, row_dilation, col_dilation,
229             static_cast<int>(explicit_paddings[2]),
230             static_cast<int>(explicit_paddings[3]),
231             static_cast<int>(explicit_paddings[4]),
232             static_cast<int>(explicit_paddings[5]));
233       } else {
234         functor::SpatialConvolution<CPUDevice, T>()(
235             ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
236             filter_slice, row_stride, col_stride, row_dilation, col_dilation,
237             BrainPadding2EigenPadding(padding));
238       }
239     }
240 
241     // Shuffle temporary output back into pre-shuffled shape.
242     std::array<int64_t, 5> rev_shuffle({1, 2, 3, 0, 4});
243     output->shaped<T, 5>(pre_shuffle(*output)).device(device) =
244         output_shuffled.tensor<T, 5>().shuffle(rev_shuffle);
245   }
246 };
247 
248 }  // namespace
249 
250 template <typename T>
251 struct LaunchConv2DOp<CPUDevice, T> {
252   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
253                   const Tensor& input, const Tensor& filter, int row_dilation,
254                   int col_dilation, int row_stride, int col_stride,
255                   const Padding& padding,
256                   const std::vector<int64_t>& explicit_paddings, Tensor* output,
257                   TensorFormat data_format) {
258     if (data_format != FORMAT_NHWC) {
259       ctx->SetStatus(errors::Unimplemented(
260           "The Conv2D op currently only supports the NHWC tensor format on the "
261           "CPU. The op was given the format: ",
262           ToString(data_format)));
263       return;
264     }
265 
266     for (int64_t explicit_padding : explicit_paddings) {
267       if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
268         ctx->SetStatus(errors::InvalidArgument("filter too large"));
269         return;
270       }
271     }
272 
273     const int64_t in_depth = input.dim_size(3);
274     const int64_t out_depth = output->dim_size(3);
275     const int64_t patch_depth = filter.dim_size(2);
276 
277     if (patch_depth <= 0) {
278       ctx->SetStatus(errors::InvalidArgument(
279           "filter depth must be strictly positive, got ", patch_depth));
280       return;
281     }
282     if (in_depth % patch_depth != 0) {
283       ctx->SetStatus(errors::InvalidArgument(
284           "input depth must be evenly divisible by filter depth: ", in_depth,
285           " vs ", patch_depth));
286       return;
287     }
288     if (filter.NumElements() <= 0) {
289       ctx->SetStatus(
290           errors::InvalidArgument("filter must not have zero elements "
291                                   "(i.e. all dimensions must be non-zero)"));
292       return;
293     }
294 
295     const int64_t num_groups = in_depth / patch_depth;
296     if (num_groups <= 0) {
297       ctx->SetStatus(errors::InvalidArgument(
298           "number of groups must be strictly positive, got ", num_groups));
299       return;
300     }
301     if (out_depth % num_groups != 0 || out_depth < num_groups) {
302       ctx->SetStatus(errors::InvalidArgument(
303           "output depth must be evenly divisible by number of groups: ",
304           out_depth, " vs ", num_groups));
305       return;
306     }
307 
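    // in_depth == patch_depth means there is a single group (a regular
    // convolution); otherwise dispatch to the grouped implementation above.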
308     if (in_depth != patch_depth) {
309       LaunchGrouped<T>()(ctx, input, filter, row_stride, col_stride,
310                          row_dilation, col_dilation, padding, explicit_paddings,
311                          output, data_format);
312     } else {
313       LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
314                                     row_dilation, col_dilation, padding,
315                                     explicit_paddings, output, data_format);
316     }
317   }
318 };
319 
320 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
321 template <>
322 struct LaunchConv2DOp<GPUDevice, int32> {
323   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
324                   const Tensor& input, const Tensor& filter, int row_dilation,
325                   int col_dilation, int row_stride, int col_stride,
326                   const Padding& padding,
327                   const std::vector<int64_t>& explicit_paddings, Tensor* output,
328                   TensorFormat data_format) {
329     if (data_format != FORMAT_NHWC) {
330       ctx->SetStatus(
331           errors::Unimplemented("The Conv2D op currently only supports the "
332                                 "NHWC tensor format for integer types. "
333                                 "The op was given the format: ",
334                                 ToString(data_format)));
335       return;
336     }
337     const int64_t in_depth = GetTensorDim(input, data_format, 'C');
338     OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
339                 errors::Unimplemented(
340                     "The Conv2D op currently does not support grouped "
341                     "convolutions for integer types. A grouped convolution was "
342                     "attempted to be run because the input depth of ",
343                     in_depth, " does not match the filter input depth of ",
344                     filter.dim_size(2)));
345     OP_REQUIRES(
346         ctx, filter.NumElements() > 0,
347         errors::InvalidArgument("filter must not have zero elements "
348                                 "(i.e. all dimensions must be non-zero)"));
349 
350     for (int64_t explicit_padding : explicit_paddings) {
351       if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
352         ctx->SetStatus(errors::InvalidArgument("filter too large"));
353         return;
354       }
355     }
356     LaunchGeneric<GPUDevice, int32>()(
357         ctx, input, filter, row_stride, col_stride, row_dilation, col_dilation,
358         padding, explicit_paddings, output, data_format);
359   }
360 };
361 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
362 
363 template <typename Device, typename T>
364 class LaunchDeepConvOp {
365  public:
366   static bool Run(OpKernelContext* ctx, const Tensor& input,
367                   const Tensor& filter, int batch, int input_rows,
368                   int input_cols, int in_depth, int filter_rows,
369                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
370                   int /*out_cols*/, int /*out_depth*/, int /*dilation_rows*/,
371                   int /*dilation_cols*/, int /*stride_rows*/,
372                   int /*stride_cols*/, Tensor* /*output*/,
373                   TensorFormat /*data_format*/) {
374     return false;
375   }
376 };
377 
378 // Conditionally launches DeepConv operation based on convolution parameters.
379 template <>
380 class LaunchDeepConvOp<CPUDevice, float> {
381  public:
382   static bool Run(OpKernelContext* ctx, const Tensor& input,
383                   const Tensor& filter, int batch, int input_rows,
384                   int input_cols, int in_depth, int filter_rows,
385                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
386                   int out_cols, int out_depth, int dilation_rows,
387                   int dilation_cols, int stride_rows, int stride_cols,
388                   Tensor* output, TensorFormat data_format) {
389     if (data_format != FORMAT_NHWC || dilation_rows != 1 ||
390         dilation_cols != 1 ||
391         !CanUseDeepConv2D(stride_rows, stride_cols, filter_rows, filter_cols,
392                           in_depth, out_depth, out_rows, out_cols)) {
393       return false;
394     }
395 
396     Conv2DArgs args;
397     args.batch = batch;
398     args.in_rows = input_rows;
399     args.in_cols = input_cols;
400     args.in_depth = in_depth;
401     args.filter_rows = filter_rows;
402     args.filter_cols = filter_cols;
403     args.pad_rows = pad_rows;
404     args.pad_cols = pad_cols;
405     args.out_rows = out_rows;
406     args.out_cols = out_cols;
407     args.out_depth = out_depth;
408 
409     auto input_ptr = input.template flat<float>().data();
410     auto filter_ptr = filter.template flat<float>().data();
411     auto output_ptr = output->template flat<float>().data();
412 
413     functor::DeepConv2D<CPUDevice, float>()(ctx, args, input_ptr, filter_ptr,
414                                             output_ptr);
415     return true;
416   }
417 };
418 
419 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
420 template <typename Device, typename T>
421 class LaunchXsmmConvOp {
422  public:
423   static bool Run(OpKernelContext* ctx, const Tensor& input,
424                   const Tensor& filter, int batch, int input_rows,
425                   int input_cols, int in_depth, int filter_rows,
426                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
427                   int out_cols, int out_depth, int stride_rows, int stride_cols,
428                   int dilation_rows, int dilation_cols, Tensor* output,
429                   TensorFormat data_format) {
430     return false;
431   }
432 };
433 
434 template <>
435 class LaunchXsmmConvOp<CPUDevice, float> {
436  public:
437   static bool Run(OpKernelContext* ctx, const Tensor& input,
438                   const Tensor& filter, int batch, int input_rows,
439                   int input_cols, int in_depth, int filter_rows,
440                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
441                   int out_cols, int out_depth, int dilation_rows,
442                   int dilation_cols, int stride_rows, int stride_cols,
443                   Tensor* output, TensorFormat data_format) {
444     auto num_threads =
445         ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
446     // See libxsmm_dnn.h for this struct definition.
447     libxsmm_dnn_conv_desc desc;
448     desc.N = batch;
449     desc.C = in_depth;
450     desc.H = input_rows;
451     desc.W = input_cols;
452     desc.K = out_depth;
453     desc.R = filter_rows;
454     desc.S = filter_cols;
455     desc.u = stride_rows;
456     desc.v = stride_cols;
457     desc.pad_h = pad_rows;
458     desc.pad_w = pad_cols;
459     desc.pad_h_in = 0;
460     desc.pad_w_in = 0;
461     desc.pad_h_out = 0;
462     desc.pad_w_out = 0;
463     desc.threads = num_threads;
464     desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
465     desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
466     desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
467     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
468     desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
469     desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
470     desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
471     if (dilation_rows != 1 || dilation_cols != 1 ||
472         !CanUseXsmmConv2D(desc, data_format)) {
473       return false;
474     }
475 
476     auto input_ptr = input.template flat<float>().data();
477     auto filter_ptr = filter.template flat<float>().data();
478     auto output_ptr = output->template flat<float>().data();
479 
480     bool success = functor::XsmmFwdConv2D<CPUDevice, float>()(
481         ctx, desc, input_ptr, filter_ptr, output_ptr);
482     return success;
483   }
484 };
485 #endif
486 
487 #define TF_REQUIRES(EXP, STATUS)                \
488   do {                                          \
489     if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \
490   } while (false)
491 
492 Status InitConv2DParameters(const OpKernelConstruction* context,
493                             Conv2DParameters* params) {
494   TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations));
495   TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides));
496   TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding));
497   if (context->HasAttr("explicit_paddings")) {
498     TF_RETURN_IF_ERROR(
499         context->GetAttr("explicit_paddings", &params->explicit_paddings));
500   }
501   string data_format_string;
502   TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string));
503   TF_REQUIRES(FormatFromString(data_format_string, &params->data_format),
504               errors::InvalidArgument("Invalid data format"));
505 
506   const auto& strides = params->strides;
507   const auto& dilations = params->dilations;
508   const auto& data_format = params->data_format;
509 
510   TF_REQUIRES(dilations.size() == 4,
511               errors::InvalidArgument("Sliding window dilations field must "
512                                       "specify 4 dimensions"));
513   TF_REQUIRES(strides.size() == 4,
514               errors::InvalidArgument("Sliding window strides field must "
515                                       "specify 4 dimensions"));
516   const int64_t stride_n = GetTensorDim(strides, data_format, 'N');
517   const int64_t stride_c = GetTensorDim(strides, data_format, 'C');
518   const int64_t stride_h = GetTensorDim(strides, data_format, 'H');
519   const int64_t stride_w = GetTensorDim(strides, data_format, 'W');
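  // Note that strides (and dilations below) are laid out according to
  // data_format: e.g. for NHWC, strides = {1, 2, 2, 1} means
  // stride_h = stride_w = 2 with no striding over batch or channels.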
520   TF_REQUIRES(
521       stride_n == 1 && stride_c == 1,
522       errors::Unimplemented("Current implementation does not yet support "
523                             "strides in the batch and depth dimensions."));
524   TF_REQUIRES(stride_h > 0 && stride_w > 0,
525               errors::InvalidArgument(
526                   "Row and column strides should be larger than 0."));
527 
528   const int64_t dilation_n = GetTensorDim(dilations, data_format, 'N');
529   const int64_t dilation_c = GetTensorDim(dilations, data_format, 'C');
530   const int64_t dilation_h = GetTensorDim(dilations, data_format, 'H');
531   const int64_t dilation_w = GetTensorDim(dilations, data_format, 'W');
532   TF_REQUIRES(
533       dilation_n == 1 && dilation_c == 1,
534       errors::Unimplemented("Current implementation does not yet support "
535                             "dilations in the batch and depth dimensions."));
536   TF_REQUIRES(
537       dilation_h > 0 && dilation_w > 0,
538       errors::InvalidArgument("Dilated rates should be larger than 0."));
539 
540   TF_RETURN_IF_ERROR(CheckValidPadding(params->padding,
541                                        params->explicit_paddings,
542                                        /*num_dims=*/4, data_format));
543 
544   return OkStatus();
545 }
546 
547 Status ComputeConv2DDimension(const Conv2DParameters& params,
548                               const Tensor& input, const Tensor& filter,
549                               Conv2DDimensions* dimensions) {
550   // Check that 2D convolution input and filter have exactly 4 dimensions.
551   TF_REQUIRES(input.dims() == 4,
552               errors::InvalidArgument("input must be 4-dimensional: ",
553                                       input.shape().DebugString()));
554   TF_REQUIRES(filter.dims() == 4,
555               errors::InvalidArgument("filter must be 4-dimensional: ",
556                                       filter.shape().DebugString()));
557   for (int i = 0; i < 3; i++) {
558     TF_REQUIRES(
559         FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
560         errors::InvalidArgument("filter too large"));
561   }
562 
563   // The last dimension for input is in_depth. Check that it is the same as the
564   // filter's in_depth or it is evenly divisible by filter's in_depth.
565   const int64_t in_depth_raw = GetTensorDim(input, params.data_format, 'C');
566   const int64_t patch_depth_raw = filter.dim_size(2);
567   TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()),
568               errors::InvalidArgument("Input depth too large"));
569   TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()),
570               errors::InvalidArgument("Patch depth too large"));
571   const int in_depth = static_cast<int>(in_depth_raw);
572   const int patch_depth = static_cast<int>(patch_depth_raw);
573   TF_REQUIRES(patch_depth > 0,
574               errors::InvalidArgument(
575                   "filter depth must be strictly positive, got ", patch_depth));
576   TF_REQUIRES(in_depth % patch_depth == 0,
577               errors::InvalidArgument(
578                   "input depth must be evenly divisible by filter depth: ",
579                   in_depth, " vs ", patch_depth));
580 
581   // The last dimension for filter is out_depth.
582   const int out_depth = static_cast<int>(filter.dim_size(3));
583 
584   // The second dimension for input is rows/height.
585   // The first dimension for filter is rows/height.
586   const int64_t input_rows_raw = GetTensorDim(input, params.data_format, 'H');
587   TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
588               errors::InvalidArgument("Input rows too large"));
589   const int input_rows = static_cast<int>(input_rows_raw);
590   const int filter_rows = static_cast<int>(filter.dim_size(0));
591 
592   // The third dimension for input is columns/width.
593   // The second dimension for filter is columns/width.
594   const int64_t input_cols_raw = GetTensorDim(input, params.data_format, 'W');
595   TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
596               errors::InvalidArgument("Input cols too large"));
597   const int input_cols = static_cast<int>(input_cols_raw);
598   const int filter_cols = static_cast<int>(filter.dim_size(1));
599 
600   // The first dimension for input is batch.
601   const int64_t batch_raw = GetTensorDim(input, params.data_format, 'N');
602   TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
603               errors::InvalidArgument("batch is too large"));
604   const int batch = static_cast<int>(batch_raw);
605 
606   // Take the stride and dilation from the second and third dimensions only (we
607   // do not support striding or dilation on the batch or depth dimension).
608   const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H');
609   const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W');
610   const int dilation_rows =
611       GetTensorDim(params.dilations, params.data_format, 'H');
612   const int dilation_cols =
613       GetTensorDim(params.dilations, params.data_format, 'W');
614 
615   int64_t pad_rows_before, pad_rows_after, pad_cols_before, pad_cols_after;
616   if (params.padding == Padding::EXPLICIT) {
617     GetExplicitPaddingForDim(params.explicit_paddings, params.data_format, 'H',
618                              &pad_rows_before, &pad_rows_after);
619     GetExplicitPaddingForDim(params.explicit_paddings, params.data_format, 'W',
620                              &pad_cols_before, &pad_cols_after);
621   }
622 
623   // Compute windowed output sizes for rows and columns.
624   int64_t out_rows = 0, out_cols = 0;
625   TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
626       input_rows, filter_rows, dilation_rows, stride_rows, params.padding,
627       &out_rows, &pad_rows_before, &pad_rows_after));
628   TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
629       input_cols, filter_cols, dilation_cols, stride_cols, params.padding,
630       &out_cols, &pad_cols_before, &pad_cols_after));
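  // The usual windowed-output formulas apply: with effective filter size
  // k_eff = (filter - 1) * dilation + 1, VALID gives
  // out = ceil((in - k_eff + 1) / stride) and SAME gives out = ceil(in / stride).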
631 
632   dimensions->batch = batch;
633   dimensions->input_rows = input_rows;
634   dimensions->input_cols = input_cols;
635   dimensions->in_depth = in_depth;
636   dimensions->filter_rows = filter_rows;
637   dimensions->filter_cols = filter_cols;
638   dimensions->patch_depth = patch_depth;
639   dimensions->out_depth = out_depth;
640   dimensions->stride_rows = stride_rows;
641   dimensions->stride_cols = stride_cols;
642   dimensions->dilation_rows = dilation_rows;
643   dimensions->dilation_cols = dilation_cols;
644   dimensions->out_rows = out_rows;
645   dimensions->out_cols = out_cols;
646   dimensions->pad_rows_before = pad_rows_before;
647   dimensions->pad_rows_after = pad_rows_after;
648   dimensions->pad_cols_before = pad_cols_before;
649   dimensions->pad_cols_after = pad_cols_after;
650 
651   return OkStatus();
652 }
653 
654 #undef TF_REQUIRES
655 
656 template <typename Device, typename T>
657 class Conv2DOp : public BinaryOp<T> {
658  public:
659   explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
660     OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
661 
662     OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
663     cudnn_use_autotune_ = CudnnUseAutotune();
664   }
665 
666   void Compute(OpKernelContext* context) override {
667     // Input tensor is of the following dimensions:
668     // [ batch, in_rows, in_cols, in_depth ]
669     const Tensor& input = context->input(0);
670 
671     // Input filter is of the following dimensions:
672     // [ filter_rows, filter_cols, in_depth, out_depth]
673     const Tensor& filter = context->input(1);
674 
675     Conv2DDimensions dimensions;
676     OP_REQUIRES_OK(context,
677                    ComputeConv2DDimension(params_, input, filter, &dimensions));
678 
679     TensorShape out_shape = ShapeFromFormat(
680         params_.data_format, dimensions.batch, dimensions.out_rows,
681         dimensions.out_cols, dimensions.out_depth);
682 
683     // Output tensor is of the following dimensions:
684     // [ in_batch, out_rows, out_cols, out_depth ]
685     Tensor* output = nullptr;
686     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
687 
688     VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth
689             << ", patch_depth = " << dimensions.patch_depth
690             << ", input_cols = " << dimensions.input_cols
691             << ", filter_cols = " << dimensions.filter_cols
692             << ", input_rows = " << dimensions.input_rows
693             << ", filter_rows = " << dimensions.filter_rows
694             << ", stride_rows = " << dimensions.stride_rows
695             << ", stride_cols = " << dimensions.stride_cols
696             << ", dilation_rows = " << dimensions.dilation_rows
697             << ", dilation_cols = " << dimensions.dilation_cols
698             << ", out_depth = " << dimensions.out_depth;
699 
700     // If there is nothing to compute, return.
701     if (out_shape.num_elements() == 0) {
702       return;
703     }
704 
705     // If the input is empty, result can only be due to padding.
706     if (input.NumElements() == 0) {
707       // Zero-out output and return.
708       functor::SetZeroFunctor<Device, T>()(context->eigen_device<Device>(),
709                                            output->template flat<T>());
710 
711       return;
712     }
713 
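    // From here on, progressively more general launchers are tried: libxsmm
    // (if compiled in), then LaunchDeepConvOp (a Winograd-transform based CPU
    // path for small spatial filters), and finally the generic per-device
    // launcher_ below.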
714 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
715     if (params_.padding != EXPLICIT &&
716         LaunchXsmmConvOp<Device, T>::Run(
717             context, input, filter, dimensions.batch, dimensions.input_rows,
718             dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
719             dimensions.filter_cols, dimensions.pad_rows_before,
720             dimensions.pad_cols_before, dimensions.out_rows,
721             dimensions.out_cols, dimensions.out_depth, dimensions.dilation_rows,
722             dimensions.dilation_cols, dimensions.stride_rows,
723             dimensions.stride_cols, output, params_.data_format)) {
724       return;
725     }
726 #endif
727 
728     if (params_.padding != EXPLICIT &&
729         LaunchDeepConvOp<Device, T>::Run(
730             context, input, filter, dimensions.batch, dimensions.input_rows,
731             dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
732             dimensions.filter_cols, dimensions.pad_rows_before,
733             dimensions.pad_cols_before, dimensions.out_rows,
734             dimensions.out_cols, dimensions.out_depth, dimensions.dilation_rows,
735             dimensions.dilation_cols, dimensions.stride_rows,
736             dimensions.stride_cols, output, params_.data_format)) {
737       return;
738     }
739 
740     launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
741               dimensions.dilation_rows, dimensions.dilation_cols,
742               dimensions.stride_rows, dimensions.stride_cols, params_.padding,
743               params_.explicit_paddings, output, params_.data_format);
744   }
745 
746  private:
747   Conv2DParameters params_;
748   bool use_cudnn_;
749   bool cudnn_use_autotune_;
750 
751   LaunchConv2DOp<Device, T> launcher_;
752 
753   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
754 };
755 
756 #define REGISTER_CPU(T)                                         \
757   REGISTER_KERNEL_BUILDER(                                      \
758       Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
759       Conv2DOp<CPUDevice, T>);
760 
761 // If we're using the alternative GEMM-based implementation of Conv2D for the
762 // CPU, don't register this EigenTensor-based version.
763 #if !defined(USE_GEMM_FOR_CONV)
764 TF_CALL_half(REGISTER_CPU);
765 TF_CALL_float(REGISTER_CPU);
766 TF_CALL_double(REGISTER_CPU);
767 TF_CALL_int32(REGISTER_CPU);
768 #endif  // USE_GEMM_FOR_CONV
769 
770 // To be used inside depthwise_conv_op.cc.
771 template struct LaunchConv2DOp<CPUDevice, Eigen::bfloat16>;
772 template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
773 template struct LaunchConv2DOp<CPUDevice, float>;
774 template struct LaunchConv2DOp<CPUDevice, double>;
775 
776 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
777 
778 int64_t GetDnnWorkspaceLimit(const string& envvar_in_mb,
779                              int64_t default_value_in_bytes) {
780   const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
781   if (workspace_limit_in_mb_str != nullptr &&
782       strcmp(workspace_limit_in_mb_str, "") != 0) {
783     int64_t scratch_limit_in_mb = -1;
784     if (strings::safe_strto64(workspace_limit_in_mb_str,
785                               &scratch_limit_in_mb)) {
786       return scratch_limit_in_mb * (1 << 20);
787     } else {
788       LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
789                    << workspace_limit_in_mb_str;
790     }
791   }
792   return default_value_in_bytes;
793 }
794 
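// For example, TF_CUDNN_WORKSPACE_LIMIT_IN_MB=1024 caps cuDNN scratch space at
// 1024 * 2^20 bytes (1 GiB); an unset or unparsable value falls back to the
// default below (1LL << 33 bytes, i.e. 8 GiB).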
795 int64_t GetDnnWorkspaceLimitOrDefault() {
796   return GetDnnWorkspaceLimit("TF_CUDNN_WORKSPACE_LIMIT_IN_MB",
797                               1LL << 33);  // 8GB by default
798 }
799 
800 template <typename T>
801 void LaunchConv2DOp<GPUDevice, T>::operator()(
802     OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
803     const Tensor& input_param, const Tensor& filter, int row_dilation,
804     int col_dilation, int row_stride, int col_stride, const Padding& padding,
805     const std::vector<int64_t>& explicit_paddings, Tensor* output,
806     TensorFormat data_format) {
807   using se::dnn::AlgorithmConfig;
808   using se::dnn::AlgorithmDesc;
809   using se::dnn::ProfileResult;
810   auto* stream = ctx->op_device_context()->stream();
811   OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
812 
813   if (!use_cudnn) {
814     ctx->SetStatus(
815         errors::Unimplemented("Conv2D for GPU is not currently supported "
816                               "without cudnn"));
817     return;
818   }
819 
820   Tensor input = input_param;
821   const int64_t in_batch = GetTensorDim(input, data_format, 'N');
822   int64_t in_rows = GetTensorDim(input, data_format, 'H');
823   int64_t in_cols = GetTensorDim(input, data_format, 'W');
824   const int64_t in_depths = GetTensorDim(input, data_format, 'C');
825   const int64_t patch_rows = filter.dim_size(0);
826   const int64_t patch_cols = filter.dim_size(1);
827   const int64_t patch_depths = filter.dim_size(2);
828 
829   OP_REQUIRES(
830       ctx, filter.NumElements() > 0,
831       errors::InvalidArgument("filter must not have zero elements "
832                               "(i.e. all dimensions must be non-zero)"));
833 
834   // If the filter in-depth (patch_depths) is 1 and smaller than the input
835   // depth, it's a depthwise convolution. More generally, if the filter in-depth
836   // divides but is smaller than the input depth, it is a grouped convolution.
837   bool is_grouped_convolution = patch_depths != in_depths;
838   if (patch_rows == 1 && patch_cols == 1 && !is_grouped_convolution &&
839       row_dilation == 1 && col_dilation == 1 && row_stride == 1 &&
840       col_stride == 1 && data_format == FORMAT_NHWC &&
841       (padding == VALID || padding == SAME)) {
842     // 1x1 filter, so call cublas directly.
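    // The whole convolution is a single GEMM, C[m, n] = A[m, k] * B[k, n], with
    // m = batch * height * width, k = input channels and n = output channels.
    // The operands are passed to ThenBlasGemm in swapped (filter-first) order,
    // the usual trick for computing a row-major product with a column-major
    // BLAS.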
843     const uint64 m = in_batch * in_rows * in_cols;
844     const uint64 k = patch_depths;
845     const uint64 n = filter.dim_size(3);
846 
847     auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
848                                 input.template flat<T>().size());
849     auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
850                                 filter.template flat<T>().size());
851     auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
852                                 output->template flat<T>().size());
853 
854     auto no_transpose = se::blas::Transpose::kNoTranspose;
855     OP_REQUIRES_OK(
856         ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n,
857                                   a_ptr, k, &c_ptr, n,
858                                   se::blas::kDefaultComputePrecision));
859     return;
860   } else if (patch_rows == in_rows && patch_cols == in_cols &&
861              !is_grouped_convolution && row_dilation == 1 &&
862              col_dilation == 1 && padding == VALID &&
863              data_format == FORMAT_NHWC) {
864     // The input data and filter have the same height/width, so call cublas
865     // directly.
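    // Here each image is reduced in full: m = batch,
    // k = filter_height * filter_width * input_channels, n = output channels,
    // so the GEMM produces one row of n output values per image.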
866     const uint64 m = in_batch;
867     const uint64 k = patch_rows * patch_cols * patch_depths;
868     const uint64 n = filter.dim_size(3);
869 
870     auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
871                                 input.template flat<T>().size());
872     auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
873                                 filter.template flat<T>().size());
874     auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
875                                 output->template flat<T>().size());
876 
877     auto no_transpose = se::blas::Transpose::kNoTranspose;
878     OP_REQUIRES_OK(
879         ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n,
880                                   a_ptr, k, &c_ptr, n,
881                                   se::blas::kDefaultComputePrecision));
882     return;
883   }
884 
885 #if GOOGLE_CUDA
886   // Tensor Core (NVIDIA Volta+ GPUs) supports efficient convolution with fp16
887   // in NHWC data layout. In all other configurations it's more efficient to
888   // run computation in NCHW data format.
889   const bool compute_in_nhwc = DataTypeToEnum<T>::value == DT_HALF &&
890                                stream->GetCudaComputeCapability().IsAtLeast(
891                                    se::CudaComputeCapability::VOLTA);
892 #else
893   // The fast NHWC implementation is a CUDA-only feature.
894   const bool compute_in_nhwc = false;
895 #endif
896 
897   // We only do one directional conversion: NHWC->NCHW. We never convert in the
898   // other direction. Grappler layout optimizer selects preferred layout and
899   // adds necessary annotations to the graph.
900   // TODO(ezhulenev): Convert in other direction for fp16?
901   const TensorFormat compute_data_format =
902       (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
903                                                       : FORMAT_NCHW;
904 
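  // In practice: fp16 on Volta-or-newer GPUs stays in NHWC (Tensor Core
  // friendly); everything else is transposed to NCHW before calling cuDNN and,
  // if needed, transposed back to NHWC at the end of this function.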
905   VLOG(3) << "Compute Conv2D with cuDNN:"
906           << " data_format=" << ToString(data_format)
907           << " compute_data_format=" << ToString(compute_data_format);
908 
909   const int64_t out_batch = GetTensorDim(*output, data_format, 'N');
910   const int64_t out_rows = GetTensorDim(*output, data_format, 'H');
911   const int64_t out_cols = GetTensorDim(*output, data_format, 'W');
912   const int64_t out_depths = GetTensorDim(*output, data_format, 'C');
913   int64_t padding_top = -1, padding_bottom = -1;
914   int64_t padding_left = -1, padding_right = -1;
915   if (padding == EXPLICIT) {
916     GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
917                              &padding_bottom);
918     GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
919                              &padding_right);
920   }
921   int64_t out_rows_check, out_cols_check;
922   Status status = GetWindowedOutputSizeVerboseV2(
923       in_rows, patch_rows, row_dilation, row_stride, padding, &out_rows_check,
924       &padding_top, &padding_bottom);
925   // The status is guaranteed to be OK because we checked that the output and
926   // padding were valid earlier.
927   TF_CHECK_OK(status);
928   DCHECK_EQ(out_rows, out_rows_check);
929   status = GetWindowedOutputSizeVerboseV2(in_cols, patch_cols, col_dilation,
930                                           col_stride, padding, &out_cols_check,
931                                           &padding_left, &padding_right);
932   TF_CHECK_OK(status);
933   DCHECK_EQ(out_cols, out_cols_check);
934 
935   const int64_t common_padding_rows = std::min(padding_top, padding_bottom);
936   const int64_t common_padding_cols = std::min(padding_left, padding_right);
937   if (padding_top != padding_bottom || padding_left != padding_right) {
938     // cuDNN only supports padding the same amount on the left and right sides,
939     // and on the top and bottom sides. So we manually create a new padded
940     // input tensor such that we can pass it to cuDNN.
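    // For example, explicit padding (top=1, bottom=2, left=0, right=0) gives a
    // common padding of (1, 0); one extra row is added at the bottom of the
    // input here, and cuDNN is then asked for the symmetric (1, 0) padding.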
941     VLOG(4) << "Pad input tensor:"
942             << " padding_top=" << padding_top
943             << " padding_bottom=" << padding_bottom
944             << " padding_left=" << padding_left
945             << " padding_right=" << padding_right;
946 
947     // TODO(reedwm): In some cases, we can avoid an allocation even if the two
948     // padding sides are different. For example, if the input is 2x2, the filter
949     // is 1x1, the stride is 2, and the padding is (1, 0, 1, 0), the result is
950     // equivalent to as if the padding is (1, 1, 1, 1). Changing the padding in
951     // such a way would allow us to avoid the allocation.
952     Tensor transformed_input;
953     const int64_t padding_rows_diff = std::abs(padding_bottom - padding_top);
954     const int64_t padding_cols_diff = std::abs(padding_right - padding_left);
955     const int64_t new_in_rows = in_rows + padding_rows_diff;
956     const int64_t new_in_cols = in_cols + padding_cols_diff;
957     OP_REQUIRES_OK(ctx, ctx->allocate_temp(
958                             DataTypeToEnum<T>::value,
959                             ShapeFromFormat(data_format, in_batch, new_in_rows,
960                                             new_in_cols, in_depths),
961                             &transformed_input));
962 
963     const int64_t input_pad_top = padding_top - common_padding_rows;
964     const int64_t input_pad_bottom = padding_bottom - common_padding_rows;
965     const int64_t input_pad_left = padding_left - common_padding_cols;
966     const int64_t input_pad_right = padding_right - common_padding_cols;
967     bool in_bounds =
968         FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
969         FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
970         FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
971         FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
972     if (!in_bounds) {
973       ctx->SetStatus(errors::InvalidArgument("Padding is too large."));
974       return;
975     }
976     functor::PadInput<GPUDevice, T, int, 4>()(
977         ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
978         {{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
979         {{static_cast<int>(input_pad_bottom),
980           static_cast<int>(input_pad_right)}},
981         To32Bit(transformed_input.tensor<T, 4>()), data_format, T{});
982 
983     input = transformed_input;
984     in_rows = new_in_rows;
985     in_cols = new_in_cols;
986   }
987 
988   if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
989     VLOG(4) << "Convert the input tensor from NHWC to NCHW.";
990 
991     TensorShape nchw_shape =
992         ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
993     if (in_depths > 1) {
994       Tensor transformed_input;
995       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
996                                              nchw_shape, &transformed_input));
997       functor::NHWCToNCHW<GPUDevice, T, 4>()(
998           ctx->eigen_device<GPUDevice>(),
999           const_cast<const Tensor&>(input).tensor<T, 4>(),
1000           transformed_input.tensor<T, 4>());
1001       input = transformed_input;
1002     } else {
1003       // If depth <= 1, then just reshape.
1004       CHECK(input.CopyFrom(input, nchw_shape));
1005     }
1006   } else {
1007     CHECK(data_format == compute_data_format)  // Crash OK
1008         << "Illegal data and compute format pair:"
1009         << " data_format=" << ToString(data_format)
1010         << " compute_data_format=" << ToString(compute_data_format);
1011   }
1012 
1013   CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
1014       << "Negative row or col paddings: (" << common_padding_rows << ", "
1015       << common_padding_cols << ")";
1016 
1017   constexpr auto kComputeInNHWC =
1018       std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
1019                       se::dnn::FilterLayout::kOutputYXInput);
1020   constexpr auto kComputeInNCHW =
1021       std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
1022                       se::dnn::FilterLayout::kOutputInputYX);
1023 
1024   se::dnn::DataLayout compute_data_layout;
1025   se::dnn::FilterLayout filter_layout;
1026 
1027   std::tie(compute_data_layout, filter_layout) =
1028       compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1029 
1030   se::dnn::BatchDescriptor input_desc;
1031   input_desc.set_count(in_batch)
1032       .set_feature_map_count(in_depths)
1033       .set_height(in_rows)
1034       .set_width(in_cols)
1035       .set_layout(compute_data_layout);
1036   se::dnn::BatchDescriptor output_desc;
1037   output_desc.set_count(out_batch)
1038       .set_height(out_rows)
1039       .set_width(out_cols)
1040       .set_feature_map_count(out_depths)
1041       .set_layout(compute_data_layout);
1042   se::dnn::FilterDescriptor filter_desc;
1043   filter_desc.set_input_filter_height(patch_rows)
1044       .set_input_filter_width(patch_cols)
1045       .set_input_feature_map_count(patch_depths)
1046       .set_output_feature_map_count(filter.dim_size(3))
1047       .set_layout(filter_layout);
1048   se::dnn::ConvolutionDescriptor conv_desc;
1049   conv_desc.set_vertical_dilation_rate(row_dilation)
1050       .set_horizontal_dilation_rate(col_dilation)
1051       .set_vertical_filter_stride(row_stride)
1052       .set_horizontal_filter_stride(col_stride)
1053       .set_zero_padding_height(common_padding_rows)
1054       .set_zero_padding_width(common_padding_cols)
1055       .set_group_count(in_depths / patch_depths);
1056 
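  // The cuDNN call below takes the filter in OIHW (for NCHW compute) or OHWI
  // (for NHWC compute), while TensorFlow stores it as HWIO, so it is transposed
  // into a temporary tensor first.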
1057   Tensor transformed_filter;
1058 
1059   const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
1060     VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
1061             << " to " << ToString(dst_format);
1062 
1063     TensorShape dst_shape =
1064         dst_format == FORMAT_OIHW
1065             ? TensorShape({filter.dim_size(3), filter.dim_size(2),
1066                            filter.dim_size(0), filter.dim_size(1)})
1067             : TensorShape({filter.dim_size(3), filter.dim_size(0),
1068                            filter.dim_size(1), filter.dim_size(2)});
1069 
1070     TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1071                                           &transformed_filter));
1072     functor::TransformFilter<GPUDevice, T, int, 4>()(
1073         ctx->eigen_device<GPUDevice>(), dst_format,
1074         To32Bit(filter.tensor<T, 4>()),
1075         To32Bit(transformed_filter.tensor<T, 4>()));
1076 
1077     return OkStatus();
1078   };
1079 
1080   if (compute_data_format == FORMAT_NCHW) {
1081     OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OIHW));
1082   } else if (compute_data_format == FORMAT_NHWC) {
1083     OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OHWI));
1084   } else {
1085     ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
1086                                            ToString(compute_data_format)));
1087     return;
1088   }
1089 
1090   Tensor transformed_output;
1091   if (data_format != compute_data_format) {
1092     VLOG(4) << "Allocate temporary memory for output in compute data format";
1093     OP_REQUIRES_OK(
1094         ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
1095                                 ShapeFromFormat(compute_data_format, out_batch,
1096                                                 out_rows, out_cols, out_depths),
1097                                 &transformed_output));
1098   } else {
1099     transformed_output = *output;
1100   }
1101 
1102   auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
1103                                   input.template flat<T>().size());
1104   auto filter_ptr =
1105       AsDeviceMemory(transformed_filter.template flat<T>().data(),
1106                      transformed_filter.template flat<T>().size());
1107   auto output_ptr =
1108       AsDeviceMemory(transformed_output.template flat<T>().data(),
1109                      transformed_output.template flat<T>().size());
1110 
1111   static int64_t ConvolveScratchSize = GetDnnWorkspaceLimitOrDefault();
1112 
1113   int device_id = stream->parent()->device_ordinal();
1114   DataType dtype = input.dtype();
1115   ConvParameters conv_parameters = {in_batch,             // batch
1116                                     in_depths,            // in_depths
1117                                     {{in_rows,            // in_rows
1118                                       in_cols}},          // in_cols
1119                                     compute_data_format,  // compute_data_format
1120                                     out_depths,           // out_depths
1121                                     {{patch_rows,         // filter_rows
1122                                       patch_cols,         // filter_cols
1123                                       patch_depths}},     // filter_depths
1124                                     {{row_dilation,       // dilation_rows
1125                                       col_dilation}},     // dilation_cols
1126                                     {{row_stride,         // stride_rows
1127                                       col_stride}},       // stride_cols
1128                                     {{common_padding_rows,    // padding_rows
1129                                       common_padding_cols}},  // padding_cols
1130                                     dtype,                    // tensor datatype
1131                                     device_id,                // device_id
1132                                     conv_desc.group_count()};
1133 
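  // Roughly speaking, conv_parameters is the autotune cache key: the first call
  // with a given shape/dtype/device combination profiles candidate cuDNN
  // algorithms (when cudnn_use_autotune is set) and later calls reuse the
  // cached choice.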
1134   auto entry_or = AutotuneUnfusedConv(
1135       cudnn_use_autotune, ConvAutotuneMap::GetInstance(), conv_parameters, ctx,
1136       se::dnn::ConvolutionKind::FORWARD, input_desc, input_ptr, filter_desc,
1137       filter_ptr, conv_desc, output_desc, output_ptr, ConvolveScratchSize);
1138   OP_REQUIRES_OK(ctx, entry_or.status());
1139   auto autotune_entry = std::move(entry_or).value();
1140 
1141   DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
1142   Status cudnn_launch_status = LaunchAutotunedConv(
1143       autotune_entry, &scratch_allocator, se::dnn::ConvolutionKind::FORWARD,
1144       stream, input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
1145       output_desc, output_ptr);
1146   if (!cudnn_launch_status.ok()) {
1147     ctx->SetStatus(cudnn_launch_status);
1148     return;
1149   }
1150 
1151   if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1152     VLOG(4) << "Convert the output tensor back from NCHW to NHWC.";
1153     functor::NCHWToNHWC<GPUDevice, T, 4>()(
1154         ctx->eigen_device<GPUDevice>(),
1155         const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
1156         output->tensor<T, 4>());
1157   }
1158 }
1159 
1160 // Forward declarations of the functor specializations for GPU.
1161 namespace functor {
1162 #define DECLARE_GPU_SPEC(T)                                                 \
1163   template <>                                                               \
1164   void SpatialConvolution<GPUDevice, T>::operator()(                        \
1165       const GPUDevice& d, typename TTypes<T, 4>::Tensor output,             \
1166       typename TTypes<T, 4>::ConstTensor input,                             \
1167       typename TTypes<T, 4>::ConstTensor filter, int row_stride,            \
1168       int col_stride, int row_dilation, int col_dilation,                   \
1169       const Eigen::PaddingType& padding,                                    \
1170       const Eigen::NoOpOutputKernel& output_kernel);                        \
1171   template <>                                                               \
1172   void SpatialConvolution<GPUDevice, T>::operator()(                        \
1173       const GPUDevice& d, typename TTypes<T, 4>::Tensor output,             \
1174       typename TTypes<T, 4>::ConstTensor input,                             \
1175       typename TTypes<T, 4>::ConstTensor filter, int row_stride,            \
1176       int col_stride, int row_dilation, int col_dilation, int padding_top,  \
1177       int padding_bottom, int padding_left, int padding_right,              \
1178       const Eigen::NoOpOutputKernel& output_kernel);                        \
1179   extern template struct SpatialConvolution<GPUDevice, T>;                  \
1180   template <>                                                               \
1181   void MatMulConvFunctor<GPUDevice, T>::operator()(                         \
1182       const GPUDevice& d, typename TTypes<T, 2>::Tensor out,                \
1183       typename TTypes<T, 2>::ConstTensor in0,                               \
1184       typename TTypes<T, 2>::ConstTensor in1,                               \
1185       const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair, \
1186       const Eigen::NoOpOutputKernel& output_kernel);                        \
1187   extern template struct MatMulConvFunctor<GPUDevice, T>;                   \
1188   template <>                                                               \
1189   void TransformFilter<GPUDevice, T, int, 4>::operator()(                   \
1190       const GPUDevice& d, FilterTensorFormat dst_filter_format,             \
1191       typename TTypes<T, 4, int>::ConstTensor in,                           \
1192       typename TTypes<T, 4, int>::Tensor out);                              \
1193   extern template struct TransformFilter<GPUDevice, T, int, 4>;             \
1194   template <>                                                               \
1195   void PadInput<GPUDevice, T, int, 4>::operator()(                          \
1196       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,       \
1197       const std::array<int, 2>& padding_left,                               \
1198       const std::array<int, 2>& padding_right,                              \
1199       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format,     \
1200       const T& padding_value);                                              \
1201   extern template struct PadInput<GPUDevice, T, int, 4>
1202 
1203 DECLARE_GPU_SPEC(float);
1204 DECLARE_GPU_SPEC(Eigen::half);
1205 DECLARE_GPU_SPEC(double);
1206 DECLARE_GPU_SPEC(int32);
1207 #undef DECLARE_GPU_SPEC
1208 
1209 }  // namespace functor
1210 
1211 // Registration of the GPU implementations.
1212 REGISTER_KERNEL_BUILDER(
1213     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
1214     Conv2DOp<GPUDevice, Eigen::half>);
1215 REGISTER_KERNEL_BUILDER(
1216     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
1217     Conv2DOp<GPUDevice, float>);
1218 REGISTER_KERNEL_BUILDER(
1219     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<double>("T"),
1220     Conv2DOp<GPUDevice, double>);
1221 REGISTER_KERNEL_BUILDER(
1222     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<int32>("T"),
1223     Conv2DOp<GPUDevice, int32>);
1224 
1225 // To be used inside depthwise_conv_op.cc.
1226 template struct LaunchConv2DOp<GPUDevice, float>;
1227 template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
1228 template struct LaunchConv2DOp<GPUDevice, double>;
1229 
1230 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1231 
1232 }  // namespace tensorflow
1233