/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include <utility>

#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_3d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
using stream_executor::dnn::DimIndex;
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
#endif  // GOOGLE_CUDA

namespace {

// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
// conv_grad_input_ops_3d.cc.

// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.

// "Depth" is already used for the channel dimension, so for the third spatial
// dimension in this file we use "plane", although in NDHWC layout it's
// indicated with a "D".

// Returns in 'im_data' (assumed to be zero-initialized) the image patch in
// storage order (planes, height, width, depth), constructed from patches in
// 'col_data', which is required to be in storage order (out_planes *
// out_height * out_width, filter_planes, filter_height, filter_width,
// in_depth).
//
// Based on the 2-dimensional implementation written by Yangqing Jia (jiayq).
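//
// For example, with filter_p == 2, stride_p == 1, and no plane padding, a
// 3-plane input yields planes_col == 2 overlapping patches covering planes
// {0, 1} and {1, 2}; plane 1 accumulates contributions from both patches,
// which is why 'im_data' must start out zero-initialized (values are added,
// not overwritten).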
template <typename T>
void Col2im(const T* col_data, const int depth, const int planes,
            const int height, const int width, const int filter_p,
            const int filter_h, const int filter_w, const int pad_pt,
            const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
            const int pad_r, const int stride_p, const int stride_h,
            const int stride_w, T* im_data) {
  const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
  const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
  int p_pad = -pad_pt;
  for (int p = 0; p < planes_col; ++p) {
    int h_pad = -pad_t;
    for (int h = 0; h < height_col; ++h) {
      int w_pad = -pad_l;
      for (int w = 0; w < width_col; ++w) {
        T* im_patch_data =
            im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
        for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
          for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
            for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
              if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
                  iw < width) {
                for (int i = 0; i < depth; ++i) {
                  im_patch_data[i] += col_data[i];
                }
              }
              im_patch_data += depth;
              col_data += depth;
            }
            // Jump over the remaining depth values in this row.
            im_patch_data += depth * (width - filter_w);
          }
          // Jump over the remaining (depth * width) values in this plane.
          im_patch_data += (depth * width) * (height - filter_h);
        }
        w_pad += stride_w;
      }
      h_pad += stride_h;
    }
    p_pad += stride_p;
  }
}

// Returns in 'col_data' the image patches in storage order (planes, height,
// width, depth), extracted from the image at 'input_data', which is required
// to be in storage order (batch, planes, height, width, depth).
//
// Based on the 2-dimensional implementation written by Yangqing Jia (jiayq).
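//
// For example, with a 2x2x2 filter and stride 1, each output position
// produces one column of 2 * 2 * 2 * depth values; positions whose window
// extends into the padding region are filled with zeros rather than skipped,
// so every column has the same length.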
template <typename T>
void Im2col(const T* input_data, const int depth, const int planes,
            const int height, const int width, const int filter_p,
            const int filter_h, const int filter_w, const int pad_pt,
            const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
            const int pad_r, const int stride_p, const int stride_h,
            const int stride_w, T* col_data) {
  const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
  const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;

  int p_pad = -pad_pt;
  for (int p = 0; p < planes_col; ++p) {
    int h_pad = -pad_t;
    for (int h = 0; h < height_col; ++h) {
      int w_pad = -pad_l;
      for (int w = 0; w < width_col; ++w) {
        for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
          for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
            for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
              if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
                  iw < width) {
                memcpy(col_data,
                       input_data +
                           (ip * height * width + ih * width + iw) * depth,
                       sizeof(T) * depth);
              } else {
                // Out-of-range input (padding region): fill with zeros.
                memset(col_data, 0, sizeof(T) * depth);
              }
              col_data += depth;
            }
          }
        }
        w_pad += stride_w;
      }
      h_pad += stride_h;
    }
    p_pad += stride_p;
  }
}

}  // namespace

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Backprop for input that offloads computation to
// Eigen::CuboidConvolutionBackwardInput.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
 public:
  explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
      // input_sizes.
      OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }

    OP_REQUIRES(context, input_shape.dims() == 5,
                errors::InvalidArgument("input tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, filter_shape.dims() == 5,
        errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, out_backprop_shape.dims() == 5,
        errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, input_shape.dim_size(4) == filter_shape.dim_size(3),
        errors::InvalidArgument("input and filter_sizes must have the same "
                                "number of channels. Got ",
                                input_shape.dim_size(4), " for input and ",
                                filter_shape.dim_size(3), " for filter_sizes"));
    OP_REQUIRES(
        context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
        errors::InvalidArgument("out_backprop and filter_sizes must have the "
                                "same number of channels. Got ",
                                out_backprop_shape.dim_size(4),
                                " for out_backprop and ",
                                filter_shape.dim_size(4), " for filter_sizes"));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
                                "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
                                input_shape, filter_shape, out_backprop_shape,
                                stride_, padding_, data_format_, &dims));

    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    functor::CuboidConvolutionBackwardInput<Device, T>()(
        context->eigen_device<Device>(),
        in_backprop->tensor<T, 5>(),                     // input_backward
        filter.tensor<T, 5>(),                           // filter
        out_backprop.tensor<T, 5>(),                     // output_backward
        static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
        static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
        static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
};

// Custom backprop for input that explicitly does the work sharding and calls
// Eigen only to multiply matrices.
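//
// In outline, for each input image this kernel computes
//   col_buffer = out_backprop * filter^T   (a single matmul), followed by
//   in_backprop = Col2im(col_buffer)       (scatter-add back to image form),
// with batches of images sharded across the worker threads.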
template <typename Device, class T>
class Conv3DCustomBackpropInputOp : public OpKernel {
  // Limit the maximum size of the allocated temporary buffer to
  // kMaxTempAllocationOverhead times the total size of the input tensors
  // (input, filter, out_backprop). If the size of the temporary buffer
  // exceeds this limit, fall back on the Eigen implementation.
  static constexpr int kMaxTempAllocationOverhead = 25;
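  // Example (illustrative numbers only): if input, filter and out_backprop
  // together hold 1M elements, a temporary buffer of more than 25M elements
  // triggers the Eigen fallback.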

 public:
  explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
      // input_sizes.
      OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }

    OP_REQUIRES(context, input_shape.dims() == 5,
                errors::InvalidArgument("input tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, filter_shape.dims() == 5,
        errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, out_backprop_shape.dims() == 5,
        errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, input_shape.dim_size(4) == filter_shape.dim_size(3),
        errors::InvalidArgument("input and filter_sizes must have the same "
                                "number of channels. Got ",
                                input_shape.dim_size(4), " for input and ",
                                filter_shape.dim_size(3), " for filter_sizes"));
    OP_REQUIRES(
        context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
        errors::InvalidArgument("out_backprop and filter_sizes must have the "
                                "same number of channels. Got ",
                                out_backprop_shape.dim_size(4),
                                " for out_backprop and ",
                                filter_shape.dim_size(4), " for filter_sizes"));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
                                "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
                                input_shape, filter_shape, out_backprop_shape,
                                stride_, padding_, data_format_, &dims));

    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    int64_t top_pad_planes, bottom_pad_planes;
    int64_t top_pad_rows, bottom_pad_rows;
    int64_t left_pad_cols, right_pad_cols;

    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[0].input_size,
                                dims.spatial_dims[0].filter_size,
                                dims.spatial_dims[0].stride, padding_,
                                &dims.spatial_dims[0].output_size,
                                &top_pad_planes, &bottom_pad_planes));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[1].input_size,
                                dims.spatial_dims[1].filter_size,
                                dims.spatial_dims[1].stride, padding_,
                                &dims.spatial_dims[1].output_size,
                                &top_pad_rows, &bottom_pad_rows));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[2].input_size,
                                dims.spatial_dims[2].filter_size,
                                dims.spatial_dims[2].stride, padding_,
                                &dims.spatial_dims[2].output_size,
                                &left_pad_cols, &right_pad_cols));

    // TODO(ezhulenev): Extract work size and shard estimation to shared
    // functions in conv_grad_ops, and update 2d convolution backprop.

    // The total dimension size of each kernel.
    const int64_t filter_total_size =
        dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
        dims.spatial_dims[2].filter_size * dims.in_depth;

    // The output image size is the spatial size of the output.
    const int64_t output_image_size = dims.spatial_dims[0].output_size *
                                      dims.spatial_dims[1].output_size *
                                      dims.spatial_dims[2].output_size;

    const auto cache_sizes = Eigen::internal::CacheSizes();
    const ptrdiff_t l3_cache_size = cache_sizes.m_l3;

    // Use L3 cache size as target working set size.
    const size_t target_working_set_size = l3_cache_size / sizeof(T);

    // Calculate size of matrices involved in MatMul: C = A x B.
    const int64_t size_A = output_image_size * dims.out_depth;

    const int64_t size_B = filter_total_size * dims.out_depth;

    const int64_t size_C = output_image_size * filter_total_size;

    const int64_t work_unit_size = size_A + size_B + size_C;

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    // Use a parallel tensor contraction only when there is no batching.
    //
    // Unlike the Conv2D code, this version does no work-size estimation: in
    // benchmarks, running a parallel contraction was never faster than
    // sharding plus per-shard matmuls.
    const bool use_parallel_contraction = dims.batch_size == 1;

    OP_REQUIRES(
        context, work_unit_size > 0,
        errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
                                "must all have at least 1 element"));

    const size_t shard_size =
        use_parallel_contraction
            ? 1
            : (target_working_set_size + work_unit_size - 1) / work_unit_size;
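    // Example (illustrative numbers only): for float data and an 8 MiB L3,
    // 'target_working_set_size' is ~2M elements; if the matrices for one
    // image ('work_unit_size') total 512K elements, 'shard_size' rounds up
    // to 4 images per shard.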

    // Total number of elements in all the tensors used by this kernel.
    int64_t total_tensor_elements = input_shape.num_elements() +
                                    filter_shape.num_elements() +
                                    out_backprop_shape.num_elements();

    // Shape of the temporary workspace buffer.
    TensorShape col_buffer_shape = {static_cast<int64_t>(shard_size),
                                    static_cast<int64_t>(output_image_size),
                                    static_cast<int64_t>(filter_total_size)};
    int64_t col_buffer_elements = col_buffer_shape.num_elements();

    // If the temporary allocation overhead is too large, fall back on the
    // Eigen implementation, which requires much less memory.
    int64_t col_buffer_overhead = col_buffer_elements / total_tensor_elements;
    if (col_buffer_overhead > kMaxTempAllocationOverhead) {
      VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
                 "col_buffer_overhead="
              << col_buffer_overhead;

      functor::CuboidConvolutionBackwardInput<Device, T>()(
          context->eigen_device<Device>(),
          in_backprop->tensor<T, 5>(),                     // input_backward
          filter.tensor<T, 5>(),                           // filter
          out_backprop.tensor<T, 5>(),                     // output_backward
          static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
          static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
          static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols

      return;
    }

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::value,
                                          col_buffer_shape, &col_buffer));

    // The input offset corresponding to a single input image.
    const int64_t input_offset =
        dims.spatial_dims[0].input_size * dims.spatial_dims[1].input_size *
        dims.spatial_dims[2].input_size * dims.in_depth;

    // The output offset corresponding to a single output image.
    const int64_t output_offset =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
        dims.spatial_dims[2].output_size * dims.out_depth;

    const T* filter_data = filter.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();

    auto in_backprop_flat = in_backprop->template flat<T>();
    T* input_backprop_data = in_backprop_flat.data();
    in_backprop_flat.device(context->eigen_device<Device>()) =
        in_backprop_flat.constant(T(0));

    if (use_parallel_contraction) {
      typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          TensorMap;
      typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          ConstTensorMap;

      // Initialize contraction dims (we need to transpose 'B' below).
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
      contract_dims[0].first = 1;
      contract_dims[0].second = 1;
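      // Contracting dimension 1 of 'A' with dimension 1 of 'B' computes
      // C(i, j) = sum_k A(i, k) * B(j, k), i.e. C = A * B^T, without
      // materializing the transposed filter matrix.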

      for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
        // Compute gradient into col_buffer.
        TensorMap C(col_buffer_data, output_image_size, filter_total_size);

        ConstTensorMap A(out_backprop_data + output_offset * image_id,
                         output_image_size, dims.out_depth);
        ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);

        C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);

        Col2im<T>(col_buffer_data, dims.in_depth,
                  // Input spatial dimensions.
                  dims.spatial_dims[0].input_size,  // input planes
                  dims.spatial_dims[1].input_size,  // input rows
                  dims.spatial_dims[2].input_size,  // input cols
                  // Filter spatial dimensions.
                  dims.spatial_dims[0].filter_size,  // filter planes
                  dims.spatial_dims[1].filter_size,  // filter rows
                  dims.spatial_dims[2].filter_size,  // filter cols
                  // Spatial padding.
                  top_pad_planes, top_pad_rows, left_pad_cols,
                  bottom_pad_planes, bottom_pad_rows, right_pad_cols,
                  // Spatial striding.
                  dims.spatial_dims[0].stride,  // stride planes
                  dims.spatial_dims[1].stride,  // stride rows
                  dims.spatial_dims[2].stride,  // stride cols
                  input_backprop_data);

        input_backprop_data += input_offset;
      }
    } else {
      typedef Eigen::Map<
          Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
          MatrixMap;
      typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
                                             Eigen::RowMajor>>
          ConstMatrixMap;

      for (int image_id = 0; image_id < dims.batch_size;
           image_id += shard_size) {
        const int shard_limit =
            std::min(static_cast<int>(shard_size),
                     static_cast<int>(dims.batch_size) - image_id);

        auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
                      &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
                      &output_image_size, &filter_total_size,
                      &input_backprop_data, &col_buffer_data,
                      &out_backprop_data, &filter_data, &input_offset,
                      &output_offset, &size_C](int64_t start, int64_t limit) {
          for (int shard_id = start; shard_id < limit; ++shard_id) {
            T* im2col_buf = col_buffer_data + shard_id * size_C;
            T* input_data = input_backprop_data + shard_id * input_offset;
            const T* out_data = out_backprop_data + shard_id * output_offset;

            // Compute gradient into 'im2col_buf'.
            MatrixMap C(im2col_buf, output_image_size, filter_total_size);

            ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
            ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);

            C.noalias() = A * B.transpose();

            Col2im<T>(im2col_buf, dims.in_depth,
                      // Input spatial dimensions.
                      dims.spatial_dims[0].input_size,  // input planes
                      dims.spatial_dims[1].input_size,  // input rows
                      dims.spatial_dims[2].input_size,  // input cols
                      // Filter spatial dimensions.
                      dims.spatial_dims[0].filter_size,  // filter planes
                      dims.spatial_dims[1].filter_size,  // filter rows
                      dims.spatial_dims[2].filter_size,  // filter cols
                      // Spatial padding.
                      top_pad_planes, top_pad_rows, left_pad_cols,
                      bottom_pad_planes, bottom_pad_rows, right_pad_cols,
                      // Spatial striding.
                      dims.spatial_dims[0].stride,  // stride planes
                      dims.spatial_dims[1].stride,  // stride rows
                      dims.spatial_dims[2].stride,  // stride cols
                      input_data);
          }
        };
        Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
              work_unit_size, shard);

        input_backprop_data += input_offset * shard_limit;
        out_backprop_data += output_offset * shard_limit;
      }
    }
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
};

// The custom backprop input kernel is 30%-4x faster when compiled with AVX2
// than the default Eigen implementation (at the cost of ~2x-8x peak memory
// usage).

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
                              .Device(DEVICE_CPU)                              \
                              .Label("custom")                                 \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
                              .Device(DEVICE_CPU)                              \
                              .Label("custom")                                 \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
                              .Device(DEVICE_CPU)                              \
                              .Label("eigen_tensor")                           \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DBackpropInputOp<CPUDevice, T>);                \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
                              .Device(DEVICE_CPU)                              \
                              .Label("eigen_tensor")                           \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DBackpropInputOp<CPUDevice, T>);

TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// Backprop for filter that offloads computation to
// Eigen::CuboidConvolutionBackwardFilter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
 public:
  explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape filter_shape;
    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()),
                  errors::InvalidArgument(
                      "filter_sizes shape must be rank 1 but is rank ",
                      filter_sizes.shape().dims()));
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    OP_REQUIRES(context, input_shape.dims() == 5,
                errors::InvalidArgument("input tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, filter_shape.dims() == 5,
        errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, out_backprop_shape.dims() == 5,
        errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, input_shape.dim_size(4) == filter_shape.dim_size(3),
        errors::InvalidArgument("input and filter_sizes must have the same "
                                "number of channels. Got ",
                                input_shape.dim_size(4), " for input and ",
                                filter_shape.dim_size(3), " for filter_sizes"));
    OP_REQUIRES(
        context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
        errors::InvalidArgument("out_backprop and filter_sizes must have the "
                                "same number of channels. Got ",
                                out_backprop_shape.dim_size(4),
                                " for out_backprop and ",
                                filter_shape.dim_size(4), " for filter_sizes"));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
                       input_shape, filter_shape, out_backprop_shape, stride_,
                       padding_, data_format_, &dims));

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    if (input_shape.num_elements() == 0) {
      filter_backprop->template flat<T>().setZero();
      return;
    }

    functor::CuboidConvolutionBackwardFilter<Device, T>()(
        context->eigen_device<Device>(),
        filter_backprop->tensor<T, 5>(),                 // filter_backward
        input.tensor<T, 5>(),                            // input
        out_backprop.tensor<T, 5>(),                     // output_backward
        static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
        static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
        static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
};

// Custom backprop for filter that explicitly does the work sharding and calls
// Eigen only to multiply matrices.
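//
// In outline, for each shard of images this kernel computes
//   col_buffer = Im2col(input)                      (one patch matrix per image),
//   filter_backprop += col_buffer^T * out_backprop  (a single accumulated matmul),
// with the Im2col work sharded across the worker threads.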
template <typename Device, class T>
class Conv3DCustomBackpropFilterOp : public OpKernel {
  // Limit the maximum size of the allocated temporary buffer to
  // kMaxTempAllocationOverhead times the total size of the input tensors
  // (input, filter, out_backprop). If the size of the temporary buffer
  // exceeds this limit, fall back on the Eigen implementation.
  static constexpr int kMaxTempAllocationOverhead = 25;

 public:
  explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape filter_shape;
    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()),
                  errors::InvalidArgument(
                      "filter_sizes shape must be rank 1 but is rank ",
                      filter_sizes.shape().dims()));
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    OP_REQUIRES(context, input_shape.dims() == 5,
                errors::InvalidArgument("input tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, filter_shape.dims() == 5,
        errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, out_backprop_shape.dims() == 5,
        errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
    OP_REQUIRES(
        context, input_shape.dim_size(4) == filter_shape.dim_size(3),
        errors::InvalidArgument("input and filter_sizes must have the same "
                                "number of channels. Got ",
                                input_shape.dim_size(4), " for input and ",
                                filter_shape.dim_size(3), " for filter_sizes"));
    OP_REQUIRES(
        context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
        errors::InvalidArgument("out_backprop and filter_sizes must have the "
                                "same number of channels. Got ",
                                out_backprop_shape.dim_size(4),
                                " for out_backprop and ",
                                filter_shape.dim_size(4), " for filter_sizes"));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
                       input_shape, filter_shape, out_backprop_shape, stride_,
                       padding_, data_format_, &dims));

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    if (input_shape.num_elements() == 0) {
      filter_backprop->template flat<T>().setZero();
      return;
    }

    int64_t top_pad_planes, bottom_pad_planes;
    int64_t top_pad_rows, bottom_pad_rows;
    int64_t left_pad_cols, right_pad_cols;

    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[0].input_size,
                                dims.spatial_dims[0].filter_size,
                                dims.spatial_dims[0].stride, padding_,
                                &dims.spatial_dims[0].output_size,
                                &top_pad_planes, &bottom_pad_planes));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[1].input_size,
                                dims.spatial_dims[1].filter_size,
                                dims.spatial_dims[1].stride, padding_,
                                &dims.spatial_dims[1].output_size,
                                &top_pad_rows, &bottom_pad_rows));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[2].input_size,
                                dims.spatial_dims[2].filter_size,
                                dims.spatial_dims[2].stride, padding_,
                                &dims.spatial_dims[2].output_size,
                                &left_pad_cols, &right_pad_cols));

    // TODO(ezhulenev): Extract work size and shard estimation to shared
    // functions in conv_grad_ops, and update 2d convolution backprop.

    // The total dimension size of each kernel.
    const int64_t filter_total_size =
        dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
        dims.spatial_dims[2].filter_size * dims.in_depth;
    // The output image size is the spatial size of the output.
    const int64_t output_image_size = dims.spatial_dims[0].output_size *
                                      dims.spatial_dims[1].output_size *
                                      dims.spatial_dims[2].output_size;

    // Shard 'batch' images (volumes) into 'shard_size' groups of images
    // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
    // dividing the L3 cache size ('target_working_set_size') by the matmul size
    // of an individual image ('work_unit_size').

    const auto cache_sizes = Eigen::internal::CacheSizes();
    const ptrdiff_t l3_cache_size = cache_sizes.m_l3;

    // TODO(andydavis)
    // *) Consider reducing 'target_working_set_size' if L3 is shared by
    //    other concurrently running tensorflow ops.
    const size_t target_working_set_size = l3_cache_size / sizeof(T);

    const int64_t size_A = output_image_size * filter_total_size;

    const int64_t size_B = output_image_size * dims.out_depth;

    const int64_t size_C = filter_total_size * dims.out_depth;

    const int64_t work_unit_size = size_A + size_B + size_C;

    OP_REQUIRES(
        context, work_unit_size > 0,
        errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
                                "must all have at least 1 element"));

    const size_t shard_size =
        (target_working_set_size + work_unit_size - 1) / work_unit_size;

    // Total number of elements in all the tensors used by this kernel.
    int64_t total_tensor_elements = input_shape.num_elements() +
                                    filter_shape.num_elements() +
                                    out_backprop_shape.num_elements();

    // Shape of the temporary workspace buffer.
    TensorShape col_buffer_shape = {static_cast<int64_t>(shard_size),
                                    static_cast<int64_t>(output_image_size),
                                    static_cast<int64_t>(filter_total_size)};
    int64_t col_buffer_elements = col_buffer_shape.num_elements();

    // If the temporary allocation overhead is too large, fall back on the
    // Eigen implementation, which requires much less memory.
    int64_t col_buffer_overhead = col_buffer_elements / total_tensor_elements;
    if (col_buffer_overhead > kMaxTempAllocationOverhead) {
      VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
                 "col_buffer_overhead="
              << col_buffer_overhead;

      functor::CuboidConvolutionBackwardFilter<Device, T>()(
          context->eigen_device<Device>(),
          filter_backprop->tensor<T, 5>(),                 // filter_backward
          input.tensor<T, 5>(),                            // input
          out_backprop.tensor<T, 5>(),                     // output_backward
          static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
          static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
          static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols

      return;
    }

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::value,
                                          col_buffer_shape, &col_buffer));

    // The input offset corresponding to a single input image.
    const int64_t input_offset =
        dims.spatial_dims[0].input_size * dims.spatial_dims[1].input_size *
        dims.spatial_dims[2].input_size * dims.in_depth;
    // The output offset corresponding to a single output image.
    const int64_t output_offset =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
        dims.spatial_dims[2].output_size * dims.out_depth;

    const T* input_data = input.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();
    T* filter_backprop_data = filter_backprop->template flat<T>().data();

    typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        TensorMap;
    typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        ConstTensorMap;

    TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
    C.setZero();

    // Initialize contraction dims (we need to transpose 'A' below).
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
    contract_dims[0].first = 0;
    contract_dims[0].second = 0;
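    // Contracting dimension 0 of 'A' with dimension 0 of 'B' computes
    // C(k, n) = sum_m A(m, k) * B(m, n), i.e. C = A^T * B, without
    // materializing the transposed patch matrix.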
1060 
1061     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
1062 
1063     for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
1064       const int shard_limit =
1065           std::min(static_cast<int>(shard_size),
1066                    static_cast<int>(dims.batch_size) - image_id);
1067 
1068       auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
1069                     &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
1070                     &bottom_pad_rows, &right_pad_cols, &input_offset,
1071                     &size_A](int64_t start, int64_t limit) {
1072         for (int shard_id = start; shard_id < limit; ++shard_id) {
1073           const T* input_data_shard = input_data + shard_id * input_offset;
1074           T* col_data_shard = col_buffer_data + shard_id * size_A;
1075 
1076           // When we compute the gradient with respect to the filters, we need
1077           // to do im2col to allow gemm-type computation.
1078           Im2col<T>(input_data_shard, dims.in_depth,
1079                     // Input spatial dimensions.
1080                     dims.spatial_dims[0].input_size,  // input planes
1081                     dims.spatial_dims[1].input_size,  // input rows
1082                     dims.spatial_dims[2].input_size,  // input cols
1083                     // Filter spatial dimensions.
1084                     dims.spatial_dims[0].filter_size,  // filter planes
1085                     dims.spatial_dims[1].filter_size,  // filter rows
1086                     dims.spatial_dims[2].filter_size,  // filter cols
1087                     // Spatial padding.
1088                     top_pad_planes, top_pad_rows, left_pad_cols,
1089                     bottom_pad_planes, bottom_pad_rows, right_pad_cols,
1090                     // Spatial striding.
1091                     dims.spatial_dims[0].stride,  // stride planes
1092                     dims.spatial_dims[1].stride,  // stride rows
1093                     dims.spatial_dims[2].stride,  // stride cols
1094                     col_data_shard);
1095         }
1096       };
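      // Split the im2col work for this chunk of images across the CPU worker
      // threads; size_A (elements written per image) is the per-unit cost
      // estimate the sharder uses to choose a partitioning.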
1097       Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
1098             size_A, shard);
1099 
1100       ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
1101                        filter_total_size);
1102       ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
1103                        dims.out_depth);
1104 
1105       // Gradient with respect to filter.
1106       C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
1107 
1108       input_data += input_offset * shard_limit;
1109       out_backprop_data += output_offset * shard_limit;
1110     }
1111   }
1112 
1113  private:
1114   std::vector<int32> dilation_;
1115   std::vector<int32> stride_;
1116   Padding padding_;
1117   TensorFormat data_format_;
1118   bool takes_shape_;
1119 
1120   TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
1121 };
1122 
1123 // Custom backprop filter kernel is 30% - 4x faster when compiled with AVX2
1124 // than the default Eigen implementation (at the cost of ~2x-8x peak memory usage).
1125 
1126 #define REGISTER_CPU_KERNEL(T)                                                \
1127   REGISTER_KERNEL_BUILDER(                                                    \
1128       Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1129       Conv3DCustomBackpropFilterOp<CPUDevice, T>);                            \
1130   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1131                               .Device(DEVICE_CPU)                             \
1132                               .TypeConstraint<T>("T"),                        \
1133                           Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
1134   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
1135                               .Device(DEVICE_CPU)                             \
1136                               .Label("custom")                                \
1137                               .TypeConstraint<T>("T"),                        \
1138                           Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
1139   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1140                               .Device(DEVICE_CPU)                             \
1141                               .Label("custom")                                \
1142                               .TypeConstraint<T>("T"),                        \
1143                           Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
1144   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
1145                               .Device(DEVICE_CPU)                             \
1146                               .Label("eigen_tensor")                          \
1147                               .TypeConstraint<T>("T"),                        \
1148                           Conv3DBackpropFilterOp<CPUDevice, T>);              \
1149   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1150                               .Device(DEVICE_CPU)                             \
1151                               .Label("eigen_tensor")                          \
1152                               .TypeConstraint<T>("T"),                        \
1153                           Conv3DBackpropFilterOp<CPUDevice, T>);
1154 
1155 TF_CALL_float(REGISTER_CPU_KERNEL);
1156 TF_CALL_double(REGISTER_CPU_KERNEL);
1157 #undef REGISTER_CPU_KERNEL
1158 
1159 // WARNING: Eigen::half is not trivially copyable, so it can't be used in the
1160 // custom backprop filter kernel because of the memcpy and memset calls in Im2col.
1161 #define REGISTER_CPU_KERNEL(T)                                                \
1162   REGISTER_KERNEL_BUILDER(                                                    \
1163       Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1164       Conv3DBackpropFilterOp<CPUDevice, T>);                                  \
1165   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1166                               .Device(DEVICE_CPU)                             \
1167                               .TypeConstraint<T>("T"),                        \
1168                           Conv3DBackpropFilterOp<CPUDevice, T>);
1169 
1170 TF_CALL_half(REGISTER_CPU_KERNEL);
1171 #undef REGISTER_CPU_KERNEL
1172 
1173 // GPU definitions of both ops.
1174 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1175 // Forward declarations of the functor specializations for GPU.
1176 // This ensures that the custom implementation is used instead of the default
1177 // Eigen one (which is used for CPU).
1178 namespace functor {
1179 #define DECLARE_GPU_SPEC(T)                                           \
1180   template <>                                                         \
1181   void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
1182       const GPUDevice& d, FilterTensorFormat dst_filter_format,       \
1183       typename TTypes<T, 5, int>::ConstTensor in,                     \
1184       typename TTypes<T, 5, int>::Tensor out);                        \
1185   template <>                                                         \
1186   void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
1187       const GPUDevice& d, FilterTensorFormat src_filter_format,       \
1188       typename TTypes<T, 5>::ConstTensor in,                          \
1189       typename TTypes<T, 5>::Tensor out);                             \
1190   template <>                                                         \
1191   void PadInput<GPUDevice, T, int, 5>::operator()(                    \
1192       const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
1193       const std::array<int, 3>& padding_left,                         \
1194       const std::array<int, 3>& padding_right,                        \
1195       typename TTypes<T, 5, int>::Tensor out, TensorFormat format,    \
1196       const T& padding_value);
1197 
1198 DECLARE_GPU_SPEC(Eigen::half);
1199 DECLARE_GPU_SPEC(float);
1200 DECLARE_GPU_SPEC(double);
1201 #undef DECLARE_GPU_SPEC
1202 }  // namespace functor
1203 
1204 // A dummy type to group backward data autotune results together.
1205 struct Conv3dBackwardDataAutotuneGroup {
1206   static string name() { return "Conv3dBwdData"; }
1207 };
1208 
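// Process-wide cache mapping ConvParameters to the best backward-data
// algorithm found by autotuning; shared by every instance of this op.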
1209 typedef AutotuneSingleton<Conv3dBackwardDataAutotuneGroup, ConvParameters,
1210                           AutotuneEntry<se::dnn::ConvOp>>
1211     AutotuneConv3dBwdData;
1212 
1213 
1214 template <typename T>
1215 class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
1216  public:
1217   explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
1218       : OpKernel(context),
1219         data_format_(FORMAT_NHWC),
1220         takes_shape_(type_string().find("V2") != std::string::npos) {
1221     // data_format is only available in V2.
1222     if (takes_shape_) {
1223       string data_format;
1224       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1225       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1226                   errors::InvalidArgument("Invalid data format"));
1227     }
1228     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
1229     OP_REQUIRES(context, dilation_.size() == 5,
1230                 errors::InvalidArgument("Dilation rates field must "
1231                                         "specify 5 dimensions"));
1232     OP_REQUIRES(context,
1233                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
1234                  GetTensorDim(dilation_, data_format_, 'N') == 1),
1235                 errors::InvalidArgument(
1236                     "Current implementation does not yet support "
1237                     "dilation rates in the batch and depth dimensions."));
1238     OP_REQUIRES(
1239         context,
1240         (GetTensorDim(dilation_, data_format_, '0') > 0 &&
1241          GetTensorDim(dilation_, data_format_, '1') > 0 &&
1242          GetTensorDim(dilation_, data_format_, '2') > 0),
1243         errors::InvalidArgument("Dilated rates should be larger than 0."));
1244     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1245     OP_REQUIRES(context, stride_.size() == 5,
1246                 errors::InvalidArgument("Sliding window strides field must "
1247                                         "specify 5 dimensions"));
1248     OP_REQUIRES(
1249         context,
1250         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
1251          GetTensorDim(stride_, data_format_, 'N') == 1),
1252         errors::InvalidArgument("Current implementation does not yet support "
1253                                 "strides in the batch and depth dimensions."));
1254     OP_REQUIRES(
1255         context,
1256         (GetTensorDim(stride_, data_format_, '0') > 0 &&
1257          GetTensorDim(stride_, data_format_, '1') > 0 &&
1258          GetTensorDim(stride_, data_format_, '2') > 0),
1259         errors::InvalidArgument("Spatial strides should be larger than 0."));
1260     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1261     cudnn_use_autotune_ = CudnnUseAutotune();
1262   }
1263   void Compute(OpKernelContext* context) override {
1264     const Tensor& filter = context->input(1);
1265     const TensorShape& filter_shape = filter.shape();
1266 
1267     const Tensor& out_backprop = context->input(2);
1268     const TensorShape& out_backprop_shape = out_backprop.shape();
1269 
1270     TensorShape input_shape;
1271     if (takes_shape_) {
1272       const Tensor& input_sizes = context->input(0);
1273       OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
1274     } else {
1275       input_shape = context->input(0).shape();
1276     }
1277 
1278     ConvBackpropDimensions dims;
1279     OP_REQUIRES_OK(context, ConvBackpropComputeDimensionsV2(
1280                                 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
1281                                 input_shape, filter_shape, out_backprop_shape,
1282                                 dilation_, stride_, padding_,
1283                                 /*explicit_paddings=*/{}, data_format_, &dims));
1284 
1285     Tensor* in_backprop;
1286     OP_REQUIRES_OK(context,
1287                    context->allocate_output(0, input_shape, &in_backprop));
1288 
1289     auto* stream = context->op_device_context()->stream();
1290     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
1291 
1292     bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
1293     if (!is_grouped_convolution && dims.filter_size(0) == 1 &&
1294         dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
1295         dims.dilation(0) == 1 && dims.dilation(1) == 1 &&
1296         dims.dilation(2) == 1 && dims.stride(0) == 1 && dims.stride(1) == 1 &&
1297         dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) {
1298       const uint64 m = dims.batch_size * dims.input_size(0) *
1299                        dims.input_size(1) * dims.input_size(2);
1300       const uint64 k = dims.out_depth;
1301       const uint64 n = dims.in_depth;
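      // With a 1x1x1 filter, unit strides and unit dilations, the convolution
      // is a plain matrix product, so the input gradient is a single GEMM:
      //   in_backprop[m, n] = out_backprop[m, k] * filter^T[k, n]
      // where m = batch * planes * rows * cols, k = out_depth, n = in_depth.
      // The transpose and leading-dimension arguments below express this in
      // cuBLAS's column-major convention.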
1302 
1303       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1304                                   out_backprop.template flat<T>().size());
1305       auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
1306                                   filter.template flat<T>().size());
1307       auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
1308                                   in_backprop->template flat<T>().size());
1309 
1310       auto transpose = se::blas::Transpose::kTranspose;
1311       auto no_transpose = se::blas::Transpose::kNoTranspose;
1312 
1313       OP_REQUIRES_OK(
1314           context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr,
1315                                         k, a_ptr, k, &c_ptr, n,
1316                                         se::blas::kDefaultComputePrecision));
1317       return;
1318     } else if (!is_grouped_convolution &&
1319                dims.filter_size(0) == dims.input_size(0) &&
1320                dims.filter_size(1) == dims.input_size(1) &&
1321                dims.filter_size(2) == dims.input_size(2) &&
1322                padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
1323       const uint64 m = dims.batch_size;
1324       const uint64 k = dims.out_depth;
1325       const uint64 n = dims.input_size(0) * dims.input_size(1) *
1326                        dims.input_size(2) * dims.in_depth;
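      // When the filter spans the entire input volume with VALID padding, each
      // output image is a single 1x1x1 location, so one GEMM again suffices:
      //   in_backprop[m, n] = out_backprop[m, k] * filter^T[k, n]
      // with m = batch, k = out_depth, n = the flattened input volume.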
1327 
1328       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1329                                   out_backprop.template flat<T>().size());
1330       auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
1331                                   filter.template flat<T>().size());
1332       auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
1333                                   in_backprop->template flat<T>().size());
1334 
1335       auto transpose = se::blas::Transpose::kTranspose;
1336       auto no_transpose = se::blas::Transpose::kNoTranspose;
1337 
1338       OP_REQUIRES_OK(
1339           context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr,
1340                                         k, a_ptr, k, &c_ptr, n,
1341                                         se::blas::kDefaultComputePrecision));
1342       return;
1343     }
1344 
1345     int padding_planes = dims.SpatialPadding(padding_, 0);
1346     int padding_rows = dims.SpatialPadding(padding_, 1);
1347     int padding_cols = dims.SpatialPadding(padding_, 2);
1348     const bool planes_odd = (padding_planes % 2 != 0);
1349     const bool rows_odd = (padding_rows % 2 != 0);
1350     const bool cols_odd = (padding_cols % 2 != 0);
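    // cuDNN only accepts symmetric padding, but SAME padding may require an
    // odd total amount along a dimension. The workaround: solve for an input
    // that is one element larger on the trailing edge of each odd dimension,
    // use the even padding padding_* / 2 in the convolution descriptor, and
    // slice the extra elements off afterwards.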
1351 
1352     TensorShape compatible_input_shape;
1353     if (rows_odd || cols_odd || planes_odd) {
1354       // cuDNN only supports the same amount of padding on both sides.
1355       compatible_input_shape = {
1356           dims.batch_size,
1357           dims.in_depth,
1358           dims.input_size(0) + planes_odd,
1359           dims.input_size(1) + rows_odd,
1360           dims.input_size(2) + cols_odd,
1361       };
1362     } else {
1363       compatible_input_shape = {dims.batch_size, dims.in_depth,
1364                                 dims.input_size(0), dims.input_size(1),
1365                                 dims.input_size(2)};
1366     }
1367 
1368     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
1369         << "Negative paddings: (" << padding_rows << ", " << padding_cols
1370         << ", " << padding_planes << ")";
1371 
1372 #if GOOGLE_CUDA
1373     const bool compute_in_nhwc =
1374         CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
1375 #else
1376     // The fast NDHWC implementation is a CUDA-only feature.
1377     const bool compute_in_nhwc = false;
1378 #endif
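    // NDHWC (channels-last) compute is only attempted for fp16 on cuDNN 8 or
    // newer, where fast channels-last kernels are available (e.g. tensor-core
    // paths); every other configuration is transposed to NCDHW for cuDNN.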
1379     const TensorFormat compute_data_format =
1380         (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
1381                                                          : FORMAT_NCHW;
1382 
1383     VLOG(3) << "Compute Conv3DBackpropInput with cuDNN:"
1384             << " data_format=" << ToString(data_format_)
1385             << " compute_data_format=" << ToString(compute_data_format);
1386 
1387     constexpr auto kComputeInNHWC =
1388         std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
1389                         se::dnn::FilterLayout::kOutputYXInput);
1390     constexpr auto kComputeInNCHW =
1391         std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
1392                         se::dnn::FilterLayout::kOutputInputYX);
1393 
1394     se::dnn::DataLayout compute_data_layout;
1395     se::dnn::FilterLayout filter_layout;
1396 
1397     std::tie(compute_data_layout, filter_layout) =
1398         compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1399 
1400     se::dnn::BatchDescriptor input_desc(3);
1401     input_desc.set_count(dims.batch_size)
1402         .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
1403         .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
1404         .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
1405         .set_feature_map_count(dims.in_depth)
1406         .set_layout(compute_data_layout);
1407     se::dnn::BatchDescriptor output_desc(3);
1408     output_desc.set_count(dims.batch_size)
1409         .set_spatial_dim(DimIndex::X, dims.output_size(2))
1410         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
1411         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
1412         .set_feature_map_count(dims.out_depth)
1413         .set_layout(compute_data_layout);
1414     se::dnn::FilterDescriptor filter_desc(3);
1415     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
1416         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
1417         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
1418         .set_input_feature_map_count(filter_shape.dim_size(3))
1419         .set_output_feature_map_count(filter_shape.dim_size(4))
1420         .set_layout(filter_layout);
1421     se::dnn::ConvolutionDescriptor conv_desc(3);
1422     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
1423         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
1424         .set_dilation_rate(DimIndex::Z, dims.dilation(0))
1425         .set_filter_stride(DimIndex::X, dims.stride(2))
1426         .set_filter_stride(DimIndex::Y, dims.stride(1))
1427         .set_filter_stride(DimIndex::Z, dims.stride(0))
1428         .set_zero_padding(DimIndex::X, padding_cols / 2)
1429         .set_zero_padding(DimIndex::Y, padding_rows / 2)
1430         .set_zero_padding(DimIndex::Z, padding_planes / 2)
1431         .set_group_count(dims.in_depth / filter_shape.dim_size(3));
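    // For grouped convolutions the filter's input-channel dimension stores
    // in_depth / group_count channels, so dividing recovers the group count
    // (1 for a regular convolution).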
1432 
1433     // Shape: out, in, z, y, x.
1434     Tensor transformed_filter;
1435     auto dst_format =
1436         compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
1437     TensorShape dst_shape =
1438         dst_format == FORMAT_OIHW
1439             ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
1440                            dims.filter_size(0), dims.filter_size(1),
1441                            dims.filter_size(2)})
1442             : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
1443                            dims.filter_size(1), dims.filter_size(2),
1444                            filter_shape.dim_size(3)});
1445     OP_REQUIRES_OK(context,
1446                    context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1447                                           &transformed_filter));
1448 
1449     functor::TransformFilter<GPUDevice, T, int, 5>()(
1450         context->eigen_device<GPUDevice>(), dst_format,
1451         To32Bit(filter.tensor<T, 5>()),
1452         To32Bit(transformed_filter.tensor<T, 5>()));
1453 
1454     // Shape: batch, filters, z, y, x.
1455     Tensor transformed_out_backprop;
1456     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1457       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
1458                                 dims.output_size(0), dims.output_size(1),
1459                                 dims.output_size(2)};
1460       if (dims.out_depth > 1) {
1461         OP_REQUIRES_OK(context, context->allocate_temp(
1462                                     DataTypeToEnum<T>::value, nchw_shape,
1463                                     &transformed_out_backprop));
1464         functor::NHWCToNCHW<GPUDevice, T, 5>()(
1465             context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
1466             transformed_out_backprop.tensor<T, 5>());
1467       } else {
1468         CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
1469       }
1470     } else {
1471       transformed_out_backprop = out_backprop;
1472     }
1473     // Shape: batch, filters, z, y, x.
1474     Tensor pre_transformed_in_backprop;
1475     OP_REQUIRES_OK(context,
1476                    context->allocate_temp(
1477                        DataTypeToEnum<T>::value,
1478                        ShapeFromFormat(compute_data_format,
1479                                        compatible_input_shape.dim_size(0),
1480                                        {{compatible_input_shape.dim_size(2),
1481                                          compatible_input_shape.dim_size(3),
1482                                          compatible_input_shape.dim_size(4)}},
1483                                        compatible_input_shape.dim_size(1)),
1484                        &pre_transformed_in_backprop));
1485 
1486     auto out_backprop_ptr =
1487         AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
1488                        transformed_out_backprop.template flat<T>().size());
1489     auto filter_ptr =
1490         AsDeviceMemory(transformed_filter.template flat<T>().data(),
1491                        transformed_filter.template flat<T>().size());
1492     auto in_backprop_ptr =
1493         AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
1494                        pre_transformed_in_backprop.template flat<T>().size());
1495 
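    // The cuDNN scratch limit defaults to 1LL << 33 bytes (8 GiB) and can be
    // overridden through the TF_CUDNN_WORKSPACE_LIMIT_IN_MB environment
    // variable.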
1496     static int64_t ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
1497         "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 33);  // 8GB by default
1498 
1499     const int device_id = stream->parent()->device_ordinal();
1500     // To make sure Conv3DBackpropInputV2 gets the correct dtype, we infer the
1501     // dtype from input 2, i.e., out_backprop.
1502     DataType dtype = context->input(2).dtype();
1503     const ConvParameters conv_parameters = {
1504         dims.batch_size,
1505         dims.in_depth,
1506         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
1507         compute_data_format,
1508         dims.out_depth,
1509         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
1510         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
1511         {{dims.stride(0), dims.stride(1), dims.stride(2)}},
1512         {{padding_planes, padding_rows, padding_cols}},
1513         dtype,
1514         device_id,
1515         conv_desc.group_count()};
1516 
1517     using se::dnn::AlgorithmConfig;
1518     using se::dnn::AlgorithmDesc;
1519     using se::dnn::ProfileResult;
1520 
1521     auto entry_or = AutotuneUnfusedConv(
1522         cudnn_use_autotune_, AutotuneConv3dBwdData::GetInstance(),
1523         conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_DATA,
1524         input_desc, in_backprop_ptr, filter_desc, filter_ptr, conv_desc,
1525         output_desc, out_backprop_ptr, ConvolveBackwardDataScratchSize);
1526     OP_REQUIRES_OK(context, entry_or.status());
1527     auto autotune_entry = std::move(entry_or).value();
1528 
1529     DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
1530                                           context);
1531     Status cudnn_launch_status = LaunchAutotunedConv(
1532         autotune_entry, &scratch_allocator,
1533         se::dnn::ConvolutionKind::BACKWARD_DATA, stream, input_desc,
1534         in_backprop_ptr, filter_desc, filter_ptr, conv_desc, output_desc,
1535         out_backprop_ptr);
1536     if (!cudnn_launch_status.ok()) {
1537       context->SetStatus(cudnn_launch_status);
1538       return;
1539     }
1540 
1541     if (rows_odd || cols_odd || planes_odd) {
1542       Tensor in_backprop_remove_padding;
1543       OP_REQUIRES_OK(
1544           context, context->allocate_temp(
1545                        DataTypeToEnum<T>::value,
1546                        ShapeFromFormat(compute_data_format, dims.batch_size,
1547                                        {{dims.input_size(0), dims.input_size(1),
1548                                          dims.input_size(2)}},
1549                                        dims.in_depth),
1550                        &in_backprop_remove_padding));
1551 
1552       // Remove the padding for odd spatial dimensions.
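      // (Negative trailing padding amounts make PadInput crop rather than pad,
      // trimming the extra plane/row/column added above for cuDNN's
      // symmetric-padding requirement.)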
1553       functor::PadInput<GPUDevice, T, int, 5>()(
1554           context->eigen_device<GPUDevice>(),
1555           To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
1556                       .tensor<T, 5>()),
1557           {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
1558           To32Bit(in_backprop_remove_padding.tensor<T, 5>()),
1559           compute_data_format, T{});
1560 
1561       pre_transformed_in_backprop = in_backprop_remove_padding;
1562     }
1563 
1564     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1565       auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
1566       functor::NCHWToNHWC<GPUDevice, T, 5>()(
1567           context->eigen_device<GPUDevice>(),
1568           toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
1569           in_backprop->tensor<T, 5>());
1570     } else {
1571       *in_backprop = pre_transformed_in_backprop;
1572     }
1573   }
1574 
1575  private:
1576   std::vector<int32> dilation_;
1577   std::vector<int32> stride_;
1578   Padding padding_;
1579   TensorFormat data_format_;
1580   bool takes_shape_;
1581   bool cudnn_use_autotune_;
1582 };
1583 
1584 // A dummy type to group backward filter autotune results together.
1585 struct Conv3dBackwardFilterAutotuneGroup {
1586   static string name() { return "Conv3dBwdFilter"; }
1587 };
1588 
1589 typedef AutotuneSingleton<Conv3dBackwardFilterAutotuneGroup, ConvParameters,
1590                           AutotuneEntry<se::dnn::ConvOp>>
1591     AutotuneConv3dBwdFilter;
1592 
1593 template <typename T>
1594 class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
1595  public:
1596   explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
1597       : OpKernel(context),
1598         data_format_(FORMAT_NHWC),
1599         takes_shape_(type_string().find("V2") != std::string::npos) {
1600     // data_format is only available in V2.
1601     if (takes_shape_) {
1602       string data_format;
1603       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1604       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1605                   errors::InvalidArgument("Invalid data format"));
1606     }
1607     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
1608     OP_REQUIRES(context, dilation_.size() == 5,
1609                 errors::InvalidArgument("Dilation rates field must "
1610                                         "specify 5 dimensions"));
1611     OP_REQUIRES(context,
1612                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
1613                  GetTensorDim(dilation_, data_format_, 'N') == 1),
1614                 errors::InvalidArgument(
1615                     "Current implementation does not yet support "
1616                     "dilation rates in the batch and depth dimensions."));
1617     OP_REQUIRES(
1618         context,
1619         (GetTensorDim(dilation_, data_format_, '0') > 0 &&
1620          GetTensorDim(dilation_, data_format_, '1') > 0 &&
1621          GetTensorDim(dilation_, data_format_, '2') > 0),
1622         errors::InvalidArgument("Dilated rates should be larger than 0."));
1623     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1624     OP_REQUIRES(context, stride_.size() == 5,
1625                 errors::InvalidArgument("Sliding window strides field must "
1626                                         "specify 5 dimensions"));
1627     OP_REQUIRES(
1628         context,
1629         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
1630          GetTensorDim(stride_, data_format_, 'N') == 1),
1631         errors::InvalidArgument("Current implementation does not yet support "
1632                                 "strides in the batch and depth dimensions."));
1633     OP_REQUIRES(
1634         context,
1635         (GetTensorDim(stride_, data_format_, '0') > 0 &&
1636          GetTensorDim(stride_, data_format_, '1') > 0 &&
1637          GetTensorDim(stride_, data_format_, '2') > 0),
1638         errors::InvalidArgument("Spatial strides should be larger than 0."));
1639     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1640     cudnn_use_autotune_ = CudnnUseAutotune();
1641   }
1642 
1643   void Compute(OpKernelContext* context) override {
1644     const Tensor& input = context->input(0);
1645     const TensorShape& input_shape = input.shape();
1646 
1647     const Tensor& out_backprop = context->input(2);
1648     const TensorShape& out_backprop_shape = out_backprop.shape();
1649 
1650     TensorShape filter_shape;
1651     if (takes_shape_) {
1652       const Tensor& filter_sizes = context->input(1);
1653       OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()),
1654                   errors::InvalidArgument(
1655                       "filter_sizes shape must be rank 1 but is rank ",
1656                       filter_sizes.shape().dims()));
1657       OP_REQUIRES_OK(context, tensor::MakeShape(filter_sizes, &filter_shape));
1658     } else {
1659       filter_shape = context->input(1).shape();
1660     }
1661 
1662     ConvBackpropDimensions dims;
1663     OP_REQUIRES_OK(
1664         context,
1665         ConvBackpropComputeDimensionsV2(
1666             "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, input_shape,
1667             filter_shape, out_backprop_shape, dilation_, stride_, padding_,
1668             /*explicit_paddings=*/{}, data_format_, &dims));
1669 
1670     Tensor* filter_backprop;
1671     OP_REQUIRES_OK(context,
1672                    context->allocate_output(0, filter_shape, &filter_backprop));
1673 
1674     auto* stream = context->op_device_context()->stream();
1675     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
1676 
1677     bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
1678     if (!is_grouped_convolution && dims.filter_size(1) == 1 &&
1679         dims.filter_size(2) == 1 && dims.filter_size(0) == 1 &&
1680         dims.dilation(2) == 1 && dims.dilation(1) == 1 &&
1681         dims.dilation(0) == 1 && dims.stride(2) == 1 && dims.stride(1) == 1 &&
1682         dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) {
1683       const uint64 m = dims.in_depth;
1684       const uint64 k = dims.batch_size * dims.input_size(1) *
1685                        dims.input_size(2) * dims.input_size(0);
1686       const uint64 n = dims.out_depth;
1687 
1688       // The shape of output backprop is
1689       //   [batch, out_z, out_y, out_x, out_depth]
1690       // From cublas's perspective, it is: n x k
1691       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1692                                   out_backprop.template flat<T>().size());
1693 
1694       // The shape of input is:
1695       //   [batch, in_z, in_y, in_x, in_depth],
1696       // From cublas's perspective, it is: m x k
1697       auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
1698                                   input.template flat<T>().size());
1699 
1700       // The shape of the filter backprop is:
1701       //   [1, 1, 1, in_depth, out_depth]
1702       // From cublas's perspective, it is: n x m
1703       auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
1704                                   filter_backprop->template flat<T>().size());
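      // Putting the three views together, the GEMM below computes
      //   filter_backprop[m, n] = input^T[m, k] * out_backprop[k, n],
      // i.e. each (in_depth, out_depth) entry sums input activations times
      // output gradients over the batch and all spatial positions.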
1705 
1706       OP_REQUIRES_OK(context,
1707                      stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
1708                                           se::blas::Transpose::kTranspose, n, m,
1709                                           k, a_ptr, n, b_ptr, m, &c_ptr, n,
1710                                           se::blas::kDefaultComputePrecision));
1711       return;
1712     } else if (!is_grouped_convolution &&
1713                dims.filter_size(0) == dims.input_size(0) &&
1714                dims.filter_size(1) == dims.input_size(1) &&
1715                dims.filter_size(2) == dims.input_size(2) &&
1716                padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
1717       const uint64 m = dims.input_size(0) * dims.input_size(1) *
1718                        dims.input_size(2) * dims.in_depth;
1719       const uint64 k = dims.batch_size;
1720       const uint64 n = dims.out_depth;
1721 
1722       auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
1723                                   input.template flat<T>().size());
1724       auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1725                                   out_backprop.template flat<T>().size());
1726       auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
1727                                   filter_backprop->template flat<T>().size());
1728 
1729       OP_REQUIRES_OK(context,
1730                      stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
1731                                           se::blas::Transpose::kTranspose, n, m,
1732                                           k, b_ptr, n, a_ptr, m, &c_ptr, n,
1733                                           se::blas::kDefaultComputePrecision));
1734       return;
1735     }
1736 
1737     int padding_planes = dims.SpatialPadding(padding_, 0);
1738     int padding_rows = dims.SpatialPadding(padding_, 1);
1739     int padding_cols = dims.SpatialPadding(padding_, 2);
1740     const bool planes_odd = (padding_planes % 2 != 0);
1741     const bool rows_odd = (padding_rows % 2 != 0);
1742     const bool cols_odd = (padding_cols % 2 != 0);
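    // Same symmetric-padding workaround as in the backward-input kernel, except
    // here the input itself is explicitly padded by one trailing element in
    // each odd dimension (via PadInput below) before the convolution.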
1743 
1744     Tensor compatible_input;
1745     if (rows_odd || cols_odd || planes_odd) {
1746       OP_REQUIRES_OK(context,
1747                      context->allocate_temp(
1748                          DataTypeToEnum<T>::value,
1749                          ShapeFromFormat(data_format_, dims.batch_size,
1750                                          {{dims.input_size(0) + planes_odd,
1751                                            dims.input_size(1) + rows_odd,
1752                                            dims.input_size(2) + cols_odd}},
1753                                          dims.in_depth),
1754                          &compatible_input));
1755       functor::PadInput<GPUDevice, T, int, 5>()(
1756           context->template eigen_device<GPUDevice>(),
1757           To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
1758           {{planes_odd, rows_odd, cols_odd}},
1759           To32Bit(compatible_input.tensor<T, 5>()), data_format_, T{});
1760     } else {
1761       compatible_input = input;
1762     }
1763 
1764     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
1765         << "Negative paddings: (" << padding_rows << ", " << padding_cols
1766         << ", " << padding_planes << ")";
1767 
1768 #if GOOGLE_CUDA
1769     const bool compute_in_nhwc =
1770         CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
1771 #else
1772     // The fast NDHWC implementation is a CUDA-only feature.
1773     const bool compute_in_nhwc = false;
1774 #endif
1775     const TensorFormat compute_data_format =
1776         (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
1777                                                          : FORMAT_NCHW;
1778 
1779     VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:"
1780             << " data_format=" << ToString(data_format_)
1781             << " compute_data_format=" << ToString(compute_data_format);
1782 
1783     constexpr auto kComputeInNHWC =
1784         std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
1785                         se::dnn::FilterLayout::kOutputYXInput);
1786     constexpr auto kComputeInNCHW =
1787         std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
1788                         se::dnn::FilterLayout::kOutputInputYX);
1789 
1790     se::dnn::DataLayout compute_data_layout;
1791     se::dnn::FilterLayout filter_layout;
1792 
1793     std::tie(compute_data_layout, filter_layout) =
1794         compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1795 
1796     se::dnn::BatchDescriptor input_desc(3);
1797     input_desc.set_count(dims.batch_size)
1798         .set_spatial_dim(DimIndex::X,
1799                          GetTensorDim(compatible_input, data_format_, '2'))
1800         .set_spatial_dim(DimIndex::Y,
1801                          GetTensorDim(compatible_input, data_format_, '1'))
1802         .set_spatial_dim(DimIndex::Z,
1803                          GetTensorDim(compatible_input, data_format_, '0'))
1804         .set_feature_map_count(dims.in_depth)
1805         .set_layout(compute_data_layout);
1806     se::dnn::BatchDescriptor output_desc(3);
1807     output_desc.set_count(dims.batch_size)
1808         .set_spatial_dim(DimIndex::X, dims.output_size(2))
1809         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
1810         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
1811         .set_feature_map_count(dims.out_depth)
1812         .set_layout(compute_data_layout);
1813     se::dnn::FilterDescriptor filter_desc(3);
1814     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
1815         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
1816         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
1817         .set_input_feature_map_count(filter_shape.dim_size(3))
1818         .set_output_feature_map_count(filter_shape.dim_size(4))
1819         .set_layout(filter_layout);
1820     se::dnn::ConvolutionDescriptor conv_desc(3);
1821     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
1822         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
1823         .set_dilation_rate(DimIndex::Z, dims.dilation(0))
1824         .set_filter_stride(DimIndex::X, dims.stride(2))
1825         .set_filter_stride(DimIndex::Y, dims.stride(1))
1826         .set_filter_stride(DimIndex::Z, dims.stride(0))
1827         .set_zero_padding(DimIndex::X, padding_cols / 2)
1828         .set_zero_padding(DimIndex::Y, padding_rows / 2)
1829         .set_zero_padding(DimIndex::Z, padding_planes / 2)
1830         .set_group_count(dims.in_depth / filter_shape.dim_size(3));
1831 
1832     Tensor pre_transformed_filter_backprop;
1833     auto dst_format =
1834         compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
1835     TensorShape dst_shape =
1836         dst_format == FORMAT_OIHW
1837             ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
1838                            dims.filter_size(0), dims.filter_size(1),
1839                            dims.filter_size(2)})
1840             : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
1841                            dims.filter_size(1), dims.filter_size(2),
1842                            filter_shape.dim_size(3)});
1843     OP_REQUIRES_OK(context,
1844                    context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1845                                           &pre_transformed_filter_backprop));
1846 
1847     Tensor transformed_out_backprop;
1848     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1849       VLOG(4) << "Convert the `out_backprop` tensor from NDHWC to NCDHW.";
1850       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
1851                                 dims.output_size(0), dims.output_size(1),
1852                                 dims.output_size(2)};
1853       OP_REQUIRES_OK(
1854           context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
1855                                           &transformed_out_backprop));
1856       if (dims.out_depth > 1) {
1857         functor::NHWCToNCHW<GPUDevice, T, 5>()(
1858             context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
1859             transformed_out_backprop.tensor<T, 5>());
1860       } else {
1861         CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
1862       }
1863     } else {
1864       transformed_out_backprop = out_backprop;
1865     }
1866     Tensor transformed_input;
1867     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1868       VLOG(4) << "Convert the `input` tensor from NDHWC to NCDHW.";
1869       TensorShape nchw_shape = {
1870           dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
1871           compatible_input.dim_size(2), compatible_input.dim_size(3)};
1872       if (dims.in_depth > 1) {
1873         OP_REQUIRES_OK(context,
1874                        context->allocate_temp(DataTypeToEnum<T>::value,
1875                                               nchw_shape, &transformed_input));
1876         functor::NHWCToNCHW<GPUDevice, T, 5>()(
1877             context->eigen_device<GPUDevice>(),
1878             const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
1879             transformed_input.tensor<T, 5>());
1880       } else {
1881         CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
1882       }
1883     } else {
1884       transformed_input = compatible_input;
1885     }
1886 
1887     auto out_backprop_ptr =
1888         AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
1889                        transformed_out_backprop.template flat<T>().size());
1890     auto filter_backprop_ptr = AsDeviceMemory(
1891         pre_transformed_filter_backprop.template flat<T>().data(),
1892         pre_transformed_filter_backprop.template flat<T>().size());
1893     auto input_ptr =
1894         AsDeviceMemory(transformed_input.template flat<T>().data(),
1895                        transformed_input.template flat<T>().size());
1896 
1897     static int64_t ConvolveBackwardFilterScratchSize =
1898         GetDnnWorkspaceLimitOrDefault();
1899 
1900     const int device_id = stream->parent()->device_ordinal();
1901     DataType dtype = input.dtype();
1902     const ConvParameters conv_parameters = {
1903         dims.batch_size,
1904         dims.in_depth,
1905         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
1906         compute_data_format,
1907         dims.out_depth,
1908         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
1909         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
1910         {{dims.stride(0), dims.stride(1), dims.stride(2)}},
1911         {{padding_planes, padding_rows, padding_cols}},
1912         dtype,
1913         device_id,
1914         conv_desc.group_count()};
1915 
1916     using se::dnn::AlgorithmConfig;
1917     using se::dnn::AlgorithmDesc;
1918     using se::dnn::ProfileResult;
1919 
1920     auto entry_or = AutotuneUnfusedConv(
1921         cudnn_use_autotune_, AutotuneConv3dBwdFilter::GetInstance(),
1922         conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_FILTER,
1923         input_desc, input_ptr, filter_desc, filter_backprop_ptr, conv_desc,
1924         output_desc, out_backprop_ptr, ConvolveBackwardFilterScratchSize);
1925     OP_REQUIRES_OK(context, entry_or.status());
1926     auto autotune_entry = std::move(entry_or).value();
1927 
1928     DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
1929                                           context);
1930     Status cudnn_launch_status = LaunchAutotunedConv(
1931         autotune_entry, &scratch_allocator,
1932         se::dnn::ConvolutionKind::BACKWARD_FILTER, stream, input_desc,
1933         input_ptr, filter_desc, filter_backprop_ptr, conv_desc, output_desc,
1934         out_backprop_ptr);
1935     if (!cudnn_launch_status.ok()) {
1936       context->SetStatus(cudnn_launch_status);
1937       return;
1938     }
1939 
1940     auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
1941     functor::ReverseTransformFilter<GPUDevice, T, 5>()(
1942         context->eigen_device<GPUDevice>(), /*src_filter_format=*/dst_format,
1943         toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
1944         filter_backprop->tensor<T, 5>());
1945   }
1946 
1947  private:
1948   std::vector<int32> dilation_;
1949   std::vector<int32> stride_;
1950   Padding padding_;
1951   TensorFormat data_format_;
1952   bool takes_shape_;
1953   bool cudnn_use_autotune_;
1954 };
1955 
1956 #define REGISTER_GPU_KERNEL(T)                                                \
1957   REGISTER_KERNEL_BUILDER(                                                    \
1958       Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
1959       Conv3DBackpropInputOp<GPUDevice, T>);                                   \
1960   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                       \
1961                               .Device(DEVICE_GPU)                             \
1962                               .TypeConstraint<T>("T")                         \
1963                               .HostMemory("input_sizes"),                     \
1964                           Conv3DBackpropInputOp<GPUDevice, T>);               \
1965   REGISTER_KERNEL_BUILDER(                                                    \
1966       Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1967       Conv3DBackpropFilterOp<GPUDevice, T>);                                  \
1968   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1969                               .Device(DEVICE_GPU)                             \
1970                               .TypeConstraint<T>("T")                         \
1971                               .HostMemory("filter_sizes"),                    \
1972                           Conv3DBackpropFilterOp<GPUDevice, T>);
1973 TF_CALL_half(REGISTER_GPU_KERNEL);
1974 TF_CALL_float(REGISTER_GPU_KERNEL);
1975 TF_CALL_double(REGISTER_GPU_KERNEL);
1976 #undef REGISTER_GPU_KERNEL
1977 
1978 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1979 
1980 }  // namespace tensorflow
1981