1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define USE_EIGEN_TENSOR
17 #define EIGEN_USE_THREADS
18
19 #include <utility>
20
21 #include "tensorflow/core/framework/kernel_shape_util.h"
22 #include "tensorflow/core/framework/numeric_op.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/tensor_slice.h"
28 #include "tensorflow/core/framework/tensor_util.h"
29 #include "tensorflow/core/kernels/conv_2d.h"
30 #include "tensorflow/core/kernels/conv_3d.h"
31 #include "tensorflow/core/kernels/conv_grad_ops.h"
32 #include "tensorflow/core/kernels/conv_grad_shape_utils.h"
33 #include "tensorflow/core/kernels/conv_ops_gpu.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/gtl/inlined_vector.h"
36 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
37 #include "tensorflow/core/util/padding.h"
38 #include "tensorflow/core/util/tensor_format.h"
39 #include "tensorflow/core/util/use_cudnn.h"
40 #include "tensorflow/core/util/work_sharder.h"
41
42 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
43 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
44 #endif
45
46 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
47 #include "tensorflow/core/platform/stream_executor.h"
48 using stream_executor::dnn::DimIndex;
49 #include "tensorflow/core/protobuf/autotuning.pb.h"
50 #include "tensorflow/core/util/autotune_maps/conv_parameters.h"
51 #include "tensorflow/core/util/proto/proto_utils.h"
52 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
53 #if GOOGLE_CUDA
54 #include "third_party/gpus/cudnn/cudnn.h"
55 #include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
56 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
57 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
58 #endif // GOOGLE_CUDA
59
60 namespace {
61
62 // TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
63 // conv_grad_input_ops_3d.cc.
64
65 // TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.
66
67 // "Depth" is already used for the channel dimension, so for the third spatial
68 // dimension in this file we use "plane", although in NDHWC layout it's
69 // indicated with a "D".
70
71 // Returns, in 'im_data' (assumed to be zero-initialized), the image in storage
72 // order (planes, height, width, depth), reconstructed from the patches in 'col_data',
73 // which is required to be in storage order (out_planes * out_height *
74 // out_width, filter_planes, filter_height, filter_width, in_depth).
75 //
76 // Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
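// For example, with no padding, planes = 5, filter_p = 3 and stride_p = 1,
// there are (5 - 3) / 1 + 1 = 3 patch positions along the plane dimension,
// matching the planes_col computation below.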
77 template <typename T>
78 void Col2im(const T* col_data, const int depth, const int planes,
79 const int height, const int width, const int filter_p,
80 const int filter_h, const int filter_w, const int pad_pt,
81 const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
82 const int pad_r, const int stride_p, const int stride_h,
83 const int stride_w, T* im_data) {
84 const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
85 const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
86 const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
87 int p_pad = -pad_pt;
88 for (int p = 0; p < planes_col; ++p) {
89 int h_pad = -pad_t;
90 for (int h = 0; h < height_col; ++h) {
91 int w_pad = -pad_l;
92 for (int w = 0; w < width_col; ++w) {
93 T* im_patch_data =
94 im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
95 for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
96 for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
97 for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
98 if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
99 iw < width) {
100 for (int i = 0; i < depth; ++i) {
101 im_patch_data[i] += col_data[i];
102 }
103 }
104 im_patch_data += depth;
105 col_data += depth;
106 }
107 // Skip the remaining (width - filter_w) columns of this image row.
108 im_patch_data += depth * (width - filter_w);
109 }
110 // Skip the remaining (height - filter_h) rows of this image plane.
111 im_patch_data += (depth * width) * (height - filter_h);
112 }
113 w_pad += stride_w;
114 }
115 h_pad += stride_h;
116 }
117 p_pad += stride_p;
118 }
119 }
120
121 // Returns, in 'col_data', image patches in storage order (planes, height, width,
122 // depth) extracted from the image at 'input_data', which is required to be in
123 // storage order (batch, planes, height, width, depth).
124 //
125 // Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
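// Each patch occupies filter_p * filter_h * filter_w * depth consecutive
// elements of 'col_data', and one patch is written per output position, so
// 'col_data' must hold out_planes * out_height * out_width such patches.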
126 template <typename T>
127 void Im2col(const T* input_data, const int depth, const int planes,
128 const int height, const int width, const int filter_p,
129 const int filter_h, const int filter_w, const int pad_pt,
130 const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
131 const int pad_r, const int stride_p, const int stride_h,
132 const int stride_w, T* col_data) {
133 const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
134 const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
135 const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
136
137 int p_pad = -pad_pt;
138 for (int p = 0; p < planes_col; ++p) {
139 int h_pad = -pad_t;
140 for (int h = 0; h < height_col; ++h) {
141 int w_pad = -pad_l;
142 for (int w = 0; w < width_col; ++w) {
143 for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
144 for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
145 for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
146 if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
147 iw < width) {
148 memcpy(col_data,
149 input_data +
150 (ip * height * width + ih * width + iw) * depth,
151 sizeof(T) * depth);
152 } else {
153 // Out-of-bounds locations are simply padded with zeros.
154 memset(col_data, 0, sizeof(T) * depth);
155 }
156 col_data += depth;
157 }
158 }
159 }
160 w_pad += stride_w;
161 }
162 h_pad += stride_h;
163 }
164 p_pad += stride_p;
165 }
166 }
167
168 } // namespace
169
170 namespace tensorflow {
171
172 typedef Eigen::ThreadPoolDevice CPUDevice;
173 typedef Eigen::GpuDevice GPUDevice;
174
175 // Backprop for input that offloads computation to
176 // Eigen::CuboidConvolutionBackwardInput.
177 template <typename Device, class T>
178 class Conv3DBackpropInputOp : public OpKernel {
179 public:
180 explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
181 : OpKernel(context),
182 data_format_(FORMAT_NHWC),
183 takes_shape_(type_string().find("V2") != std::string::npos) {
184 // data_format is only available in V2.
185 if (takes_shape_) {
186 string data_format;
187 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
188 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
189 errors::InvalidArgument("Invalid data format"));
190 OP_REQUIRES(
191 context, data_format_ == FORMAT_NHWC,
192 errors::InvalidArgument(
193 "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
194 }
195
196 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
197 OP_REQUIRES(context, dilation_.size() == 5,
198 errors::InvalidArgument("Dilation rates field must "
199 "specify 5 dimensions"));
200 OP_REQUIRES(context,
201 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
202 GetTensorDim(dilation_, data_format_, 'N') == 1),
203 errors::InvalidArgument(
204 "Current implementation does not yet support "
205 "dilation rates in the batch and depth dimensions."));
206
207 // TODO(yangzihao): Add CPU version of dilated conv 3D.
208 OP_REQUIRES(context,
209 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
210 GetTensorDim(dilation_, data_format_, '1') == 1 &&
211 GetTensorDim(dilation_, data_format_, '2') == 1),
212 errors::InvalidArgument(
213 "Current CPU implementation does not yet support "
214 "dilation rates larger than 1."));
215
216 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
217 OP_REQUIRES(context, stride_.size() == 5,
218 errors::InvalidArgument("Sliding window strides field must "
219 "specify 5 dimensions"));
220 OP_REQUIRES(
221 context,
222 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
223 GetTensorDim(stride_, data_format_, 'N') == 1),
224 errors::InvalidArgument("Current implementation does not yet support "
225 "strides in the batch and depth dimensions."));
226 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
227 }
228
229 void Compute(OpKernelContext* context) override {
230 const Tensor& filter = context->input(1);
231 const TensorShape& filter_shape = filter.shape();
232
233 const Tensor& out_backprop = context->input(2);
234 const TensorShape& out_backprop_shape = out_backprop.shape();
235
236 TensorShape input_shape;
237 if (takes_shape_) {
238 const Tensor& input_sizes = context->input(0);
239 // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
240 // input_sizes.
241 OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
242 } else {
243 input_shape = context->input(0).shape();
244 }
245
246 OP_REQUIRES(context, input_shape.dims() == 5,
247 errors::InvalidArgument("input tensor must have 5 dimensions"));
248 OP_REQUIRES(
249 context, filter_shape.dims() == 5,
250 errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
251 OP_REQUIRES(
252 context, out_backprop_shape.dims() == 5,
253 errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
254 OP_REQUIRES(
255 context, input_shape.dim_size(4) == filter_shape.dim_size(3),
256 errors::InvalidArgument("input and filter_sizes must have the same "
257 "number of channels. Got ",
258 input_shape.dim_size(4), " for input and ",
259 filter_shape.dim_size(3), " for filter_sizes"));
260 OP_REQUIRES(
261 context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
262 errors::InvalidArgument("out_backprop and filter_sizes must have the "
263 "same number of channels. Got ",
264 out_backprop_shape.dim_size(4),
265 " for out_backprop and ",
266 filter_shape.dim_size(4), " for filter_sizes"));
267
268 ConvBackpropDimensions dims;
269 OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
270 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
271 input_shape, filter_shape, out_backprop_shape,
272 stride_, padding_, data_format_, &dims));
273
274 Tensor* in_backprop;
275 OP_REQUIRES_OK(context,
276 context->allocate_output(0, input_shape, &in_backprop));
277
278 functor::CuboidConvolutionBackwardInput<Device, T>()(
279 context->eigen_device<Device>(),
280 in_backprop->tensor<T, 5>(), // input_backward
281 filter.tensor<T, 5>(), // filter
282 out_backprop.tensor<T, 5>(), // output_backward
283 static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
284 static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
285 static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
286 }
287
288 private:
289 std::vector<int32> dilation_;
290 std::vector<int32> stride_;
291 Padding padding_;
292 TensorFormat data_format_;
293 bool takes_shape_;
294
295 TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
296 };
297
298 // Custom backprop for input that explicitly does the work sharding and calls
299 // Eigen only to multiply matrices.
300 template <typename Device, class T>
301 class Conv3DCustomBackpropInputOp : public OpKernel {
302 // Limit the maximum size of the allocated temporary buffer to
303 // kMaxTempAllocationOverhead times the size of the input tensors (input,
304 // filter, out_backprop). If the size of the temporary buffer exceeds this
305 // limit, fall back on the Eigen implementation.
306 static constexpr int kMaxTempAllocationOverhead = 25;
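// For example, if input, filter and out_backprop together hold N elements, a
// temporary buffer of more than roughly 25 * N elements triggers the fallback
// to the Eigen implementation.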
307
308 public:
309 explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
310 : OpKernel(context),
311 data_format_(FORMAT_NHWC),
312 takes_shape_(type_string().find("V2") != std::string::npos) {
313 // data_format is only available in V2.
314 if (takes_shape_) {
315 string data_format;
316 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
317 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
318 errors::InvalidArgument("Invalid data format"));
319 OP_REQUIRES(
320 context, data_format_ == FORMAT_NHWC,
321 errors::InvalidArgument(
322 "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
323 }
324
325 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
326 OP_REQUIRES(context, dilation_.size() == 5,
327 errors::InvalidArgument("Dilation rates field must "
328 "specify 5 dimensions"));
329 OP_REQUIRES(context,
330 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
331 GetTensorDim(dilation_, data_format_, 'N') == 1),
332 errors::InvalidArgument(
333 "Current implementation does not yet support "
334 "dilation rates in the batch and depth dimensions."));
335
336 // TODO(yangzihao): Add CPU version of dilated conv 3D.
337 OP_REQUIRES(context,
338 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
339 GetTensorDim(dilation_, data_format_, '1') == 1 &&
340 GetTensorDim(dilation_, data_format_, '2') == 1),
341 errors::InvalidArgument(
342 "Current CPU implementation does not yet support "
343 "dilation rates larger than 1."));
344
345 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
346 OP_REQUIRES(context, stride_.size() == 5,
347 errors::InvalidArgument("Sliding window strides field must "
348 "specify 5 dimensions"));
349 OP_REQUIRES(
350 context,
351 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
352 GetTensorDim(stride_, data_format_, 'N') == 1),
353 errors::InvalidArgument("Current implementation does not yet support "
354 "strides in the batch and depth dimensions."));
355 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
356 }
357
358 void Compute(OpKernelContext* context) override {
359 const Tensor& filter = context->input(1);
360 const TensorShape& filter_shape = filter.shape();
361
362 const Tensor& out_backprop = context->input(2);
363 const TensorShape& out_backprop_shape = out_backprop.shape();
364
365 TensorShape input_shape;
366 if (takes_shape_) {
367 const Tensor& input_sizes = context->input(0);
368 // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
369 // input_sizes.
370 OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
371 } else {
372 input_shape = context->input(0).shape();
373 }
374
375 OP_REQUIRES(context, input_shape.dims() == 5,
376 errors::InvalidArgument("input tensor must have 5 dimensions"));
377 OP_REQUIRES(
378 context, filter_shape.dims() == 5,
379 errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
380 OP_REQUIRES(
381 context, out_backprop_shape.dims() == 5,
382 errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
383 OP_REQUIRES(
384 context, input_shape.dim_size(4) == filter_shape.dim_size(3),
385 errors::InvalidArgument("input and filter_sizes must have the same "
386 "number of channels. Got ",
387 input_shape.dim_size(4), " for input and ",
388 filter_shape.dim_size(3), " for filter_sizes"));
389 OP_REQUIRES(
390 context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
391 errors::InvalidArgument("out_backprop and filter_sizes must have the "
392 "same number of channels. Got ",
393 out_backprop_shape.dim_size(4),
394 " for out_backprop and ",
395 filter_shape.dim_size(4), " for filter_sizes"));
396
397 ConvBackpropDimensions dims;
398 OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
399 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
400 input_shape, filter_shape, out_backprop_shape,
401 stride_, padding_, data_format_, &dims));
402
403 Tensor* in_backprop;
404 OP_REQUIRES_OK(context,
405 context->allocate_output(0, input_shape, &in_backprop));
406
407 int64_t top_pad_planes, bottom_pad_planes;
408 int64_t top_pad_rows, bottom_pad_rows;
409 int64_t left_pad_cols, right_pad_cols;
410
411 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
412 dims.spatial_dims[0].input_size,
413 dims.spatial_dims[0].filter_size,
414 dims.spatial_dims[0].stride, padding_,
415 &dims.spatial_dims[0].output_size,
416 &top_pad_planes, &bottom_pad_planes));
417 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
418 dims.spatial_dims[1].input_size,
419 dims.spatial_dims[1].filter_size,
420 dims.spatial_dims[1].stride, padding_,
421 &dims.spatial_dims[1].output_size,
422 &top_pad_rows, &bottom_pad_rows));
423 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
424 dims.spatial_dims[2].input_size,
425 dims.spatial_dims[2].filter_size,
426 dims.spatial_dims[2].stride, padding_,
427 &dims.spatial_dims[2].output_size,
428 &left_pad_cols, &right_pad_cols));
429
430 // TODO(ezhulenev): Extract work size and shard estimation to shared
431 // functions in conv_grad_ops, and update 2d convolution backprop.
432
433 // The flattened size of a single filter: filter volume times input depth.
434 const int64_t filter_total_size =
435 dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
436 dims.spatial_dims[2].filter_size * dims.in_depth;
437
438 // The output image size is the spatial size of the output.
439 const int64_t output_image_size = dims.spatial_dims[0].output_size *
440 dims.spatial_dims[1].output_size *
441 dims.spatial_dims[2].output_size;
442
443 const auto cache_sizes = Eigen::internal::CacheSizes();
444 const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
445
446 // Use L3 cache size as target working set size.
447 const size_t target_working_set_size = l3_cache_size / sizeof(T);
448
449 // Calculate size of matrices involved in MatMul: C = A x B.
450 const int64_t size_A = output_image_size * dims.out_depth;
451
452 const int64_t size_B = filter_total_size * dims.out_depth;
453
454 const int64_t size_C = output_image_size * filter_total_size;
455
456 const int64_t work_unit_size = size_A + size_B + size_C;
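// One work unit is the matmul for a single image: A (out_backprop slice,
// output_image_size x out_depth) times the transposed filter B
// (filter_total_size x out_depth), producing the col_buffer C
// (output_image_size x filter_total_size).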
457
458 auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
459
460 // Use parallel tensor contractions if there is no batching.
461 //
462 // Compared to the Conv2D code, this version is missing a work size estimation.
463 // In benchmarks I didn't find a case where running a parallel contraction is
464 // beneficial compared to sharding with per-image matmuls.
465 const bool use_parallel_contraction = dims.batch_size == 1;
466
467 OP_REQUIRES(
468 context, work_unit_size > 0,
469 errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
470 "must all have at least 1 element"));
471
472 const size_t shard_size =
473 use_parallel_contraction
474 ? 1
475 : (target_working_set_size + work_unit_size - 1) / work_unit_size;
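// That is, 'shard_size' is the number of images per shard, chosen so that one
// shard's matmul working set (shard_size * work_unit_size elements) roughly
// fits in the L3 cache.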
476
477 // Total number of elements in all the tensors used by this kernel.
478 int64_t total_tensor_elements = input_shape.num_elements() +
479 filter_shape.num_elements() +
480 out_backprop_shape.num_elements();
481
482 // Shape of the temporary workspace buffer.
483 TensorShape col_buffer_shape = {static_cast<int64_t>(shard_size),
484 static_cast<int64_t>(output_image_size),
485 static_cast<int64_t>(filter_total_size)};
486 int64_t col_buffer_elements = col_buffer_shape.num_elements();
487
488 // If the temporary allocation overhead is too large, fall back on the Eigen
489 // implementation, which requires much less memory.
490 int64_t col_buffer_overhead = col_buffer_elements / total_tensor_elements;
491 if (col_buffer_overhead > kMaxTempAllocationOverhead) {
492 VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
493 "col_buffer_overhead="
494 << col_buffer_overhead;
495
496 functor::CuboidConvolutionBackwardInput<Device, T>()(
497 context->eigen_device<Device>(),
498 in_backprop->tensor<T, 5>(), // input_backward
499 filter.tensor<T, 5>(), // filter
500 out_backprop.tensor<T, 5>(), // output_backward
501 static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
502 static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
503 static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
504
505 return;
506 }
507
508 Tensor col_buffer;
509 OP_REQUIRES_OK(context,
510 context->allocate_temp(DataTypeToEnum<T>::value,
511 col_buffer_shape, &col_buffer));
512
513 // The input offset corresponding to a single input image.
514 const int64_t input_offset =
515 dims.spatial_dims[0].input_size * dims.spatial_dims[1].input_size *
516 dims.spatial_dims[2].input_size * dims.in_depth;
517
518 // The output offset corresponding to a single output image.
519 const int64_t output_offset =
520 dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
521 dims.spatial_dims[2].output_size * dims.out_depth;
522
523 const T* filter_data = filter.template flat<T>().data();
524 T* col_buffer_data = col_buffer.template flat<T>().data();
525 const T* out_backprop_data = out_backprop.template flat<T>().data();
526
527 auto in_backprop_flat = in_backprop->template flat<T>();
528 T* input_backprop_data = in_backprop_flat.data();
529 in_backprop_flat.device(context->eigen_device<Device>()) =
530 in_backprop_flat.constant(T(0));
531
532 if (use_parallel_contraction) {
533 typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
534 Eigen::Unaligned>
535 TensorMap;
536 typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
537 Eigen::Unaligned>
538 ConstTensorMap;
539
540 // Initialize contraction dims (we need to transpose 'B' below).
541 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
542 contract_dims[0].first = 1;
543 contract_dims[0].second = 1;
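// Contracting dimension 1 of A (out_depth) with dimension 1 of B (out_depth)
// computes A * B^T: (output_image_size x out_depth) x (filter_total_size x
// out_depth)^T = (output_image_size x filter_total_size), the shape of C.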
544
545 for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
546 // Compute gradient into col_buffer.
547 TensorMap C(col_buffer_data, output_image_size, filter_total_size);
548
549 ConstTensorMap A(out_backprop_data + output_offset * image_id,
550 output_image_size, dims.out_depth);
551 ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
552
553 C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
554
555 Col2im<T>(col_buffer_data, dims.in_depth,
556 // Input spatial dimensions.
557 dims.spatial_dims[0].input_size, // input planes
558 dims.spatial_dims[1].input_size, // input rows
559 dims.spatial_dims[2].input_size, // input cols
560 // Filter spatial dimensions.
561 dims.spatial_dims[0].filter_size, // filter planes
562 dims.spatial_dims[1].filter_size, // filter rows
563 dims.spatial_dims[2].filter_size, // filter cols
564 // Spatial padding.
565 top_pad_planes, top_pad_rows, left_pad_cols,
566 bottom_pad_planes, bottom_pad_rows, right_pad_cols,
567 // Spatial striding.
568 dims.spatial_dims[0].stride, // stride planes
569 dims.spatial_dims[1].stride, // stride rows
570 dims.spatial_dims[2].stride, // stride cols
571 input_backprop_data);
572
573 input_backprop_data += input_offset;
574 }
575 } else {
576 typedef Eigen::Map<
577 Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
578 MatrixMap;
579 typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
580 Eigen::RowMajor>>
581 ConstMatrixMap;
582
583 for (int image_id = 0; image_id < dims.batch_size;
584 image_id += shard_size) {
585 const int shard_limit =
586 std::min(static_cast<int>(shard_size),
587 static_cast<int>(dims.batch_size) - image_id);
588
589 auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
590 &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
591 &output_image_size, &filter_total_size,
592 &input_backprop_data, &col_buffer_data,
593 &out_backprop_data, &filter_data, &input_offset,
594 &output_offset, &size_C](int64_t start, int64_t limit) {
595 for (int shard_id = start; shard_id < limit; ++shard_id) {
596 T* im2col_buf = col_buffer_data + shard_id * size_C;
597 T* input_data = input_backprop_data + shard_id * input_offset;
598 const T* out_data = out_backprop_data + shard_id * output_offset;
599
600 // Compute gradient into 'im2col_buf'.
601 MatrixMap C(im2col_buf, output_image_size, filter_total_size);
602
603 ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
604 ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
605
606 C.noalias() = A * B.transpose();
607
608 Col2im<T>(im2col_buf, dims.in_depth,
609 // Input spatial dimensions.
610 dims.spatial_dims[0].input_size, // input planes
611 dims.spatial_dims[1].input_size, // input rows
612 dims.spatial_dims[2].input_size, // input cols
613 // Filter spatial dimensions.
614 dims.spatial_dims[0].filter_size, // filter planes
615 dims.spatial_dims[1].filter_size, // filter rows
616 dims.spatial_dims[2].filter_size, // filter cols
617 // Spatial padding.
618 top_pad_planes, top_pad_rows, left_pad_cols,
619 bottom_pad_planes, bottom_pad_rows, right_pad_cols,
620 // Spatial striding.
621 dims.spatial_dims[0].stride, // stride planes
622 dims.spatial_dims[1].stride, // stride rows
623 dims.spatial_dims[2].stride, // stride cols
624 input_data);
625 }
626 };
627 Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
628 work_unit_size, shard);
629
630 input_backprop_data += input_offset * shard_limit;
631 out_backprop_data += output_offset * shard_limit;
632 }
633 }
634 }
635
636 private:
637 std::vector<int32> dilation_;
638 std::vector<int32> stride_;
639 Padding padding_;
640 TensorFormat data_format_;
641 bool takes_shape_;
642
643 TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
644 };
645
646 // The custom backprop input kernel is 30% - 4x faster when compiled with AVX2
647 // than the default Eigen implementation (at the cost of ~2x-8x peak memory usage).
648
649 #define REGISTER_CPU_KERNEL(T) \
650 REGISTER_KERNEL_BUILDER( \
651 Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
652 Conv3DCustomBackpropInputOp<CPUDevice, T>); \
653 REGISTER_KERNEL_BUILDER( \
654 Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
655 Conv3DCustomBackpropInputOp<CPUDevice, T>); \
656 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
657 .Device(DEVICE_CPU) \
658 .Label("custom") \
659 .TypeConstraint<T>("T"), \
660 Conv3DCustomBackpropInputOp<CPUDevice, T>); \
661 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
662 .Device(DEVICE_CPU) \
663 .Label("custom") \
664 .TypeConstraint<T>("T"), \
665 Conv3DCustomBackpropInputOp<CPUDevice, T>); \
666 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
667 .Device(DEVICE_CPU) \
668 .Label("eigen_tensor") \
669 .TypeConstraint<T>("T"), \
670 Conv3DBackpropInputOp<CPUDevice, T>); \
671 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
672 .Device(DEVICE_CPU) \
673 .Label("eigen_tensor") \
674 .TypeConstraint<T>("T"), \
675 Conv3DBackpropInputOp<CPUDevice, T>);
676
677 TF_CALL_half(REGISTER_CPU_KERNEL);
678 TF_CALL_float(REGISTER_CPU_KERNEL);
679 TF_CALL_double(REGISTER_CPU_KERNEL);
680 #undef REGISTER_CPU_KERNEL
681
682 // Backprop for filter that offloads computation to
683 // Eigen::CuboidConvolutionBackwardFilter.
684 template <typename Device, class T>
685 class Conv3DBackpropFilterOp : public OpKernel {
686 public:
687 explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
688 : OpKernel(context),
689 data_format_(FORMAT_NHWC),
690 takes_shape_(type_string().find("V2") != std::string::npos) {
691 // data_format is only available in V2.
692 if (takes_shape_) {
693 string data_format;
694 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
695 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
696 errors::InvalidArgument("Invalid data format"));
697 OP_REQUIRES(
698 context, data_format_ == FORMAT_NHWC,
699 errors::InvalidArgument(
700 "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
701 }
702
703 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
704 OP_REQUIRES(context, dilation_.size() == 5,
705 errors::InvalidArgument("Dilation rates field must "
706 "specify 5 dimensions"));
707 OP_REQUIRES(context,
708 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
709 GetTensorDim(dilation_, data_format_, 'N') == 1),
710 errors::InvalidArgument(
711 "Current implementation does not yet support "
712 "dilation rates in the batch and depth dimensions."));
713
714 // TODO(yangzihao): Add CPU version of dilated conv 3D.
715 OP_REQUIRES(context,
716 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
717 GetTensorDim(dilation_, data_format_, '1') == 1 &&
718 GetTensorDim(dilation_, data_format_, '2') == 1),
719 errors::InvalidArgument(
720 "Current CPU implementation does not yet support "
721 "dilation rates larger than 1."));
722
723 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
724 OP_REQUIRES(context, stride_.size() == 5,
725 errors::InvalidArgument("Sliding window strides field must "
726 "specify 5 dimensions"));
727 OP_REQUIRES(
728 context,
729 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
730 GetTensorDim(stride_, data_format_, 'N') == 1),
731 errors::InvalidArgument("Current implementation does not yet support "
732 "strides in the batch and depth dimensions."));
733 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
734 }
735
736 void Compute(OpKernelContext* context) override {
737 const Tensor& input = context->input(0);
738 const TensorShape& input_shape = input.shape();
739
740 const Tensor& out_backprop = context->input(2);
741 const TensorShape& out_backprop_shape = out_backprop.shape();
742
743 TensorShape filter_shape;
744 if (takes_shape_) {
745 const Tensor& filter_sizes = context->input(1);
746 OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()),
747 errors::InvalidArgument(
748 "filter_sizes shape must be rank 1 but is rank ",
749 filter_sizes.shape().dims()));
750 OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
751 filter_sizes.vec<int32>(), &filter_shape));
752 } else {
753 filter_shape = context->input(1).shape();
754 }
755
756 OP_REQUIRES(context, input_shape.dims() == 5,
757 errors::InvalidArgument("input tensor must have 5 dimensions"));
758 OP_REQUIRES(
759 context, filter_shape.dims() == 5,
760 errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
761 OP_REQUIRES(
762 context, out_backprop_shape.dims() == 5,
763 errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
764 OP_REQUIRES(
765 context, input_shape.dim_size(4) == filter_shape.dim_size(3),
766 errors::InvalidArgument("input and filter_sizes must have the same "
767 "number of channels. Got ",
768 input_shape.dim_size(4), " for input and ",
769 filter_shape.dim_size(3), " for filter_sizes"));
770 OP_REQUIRES(
771 context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
772 errors::InvalidArgument("out_backprop and filter_sizes must have the "
773 "same number of channels. Got ",
774 out_backprop_shape.dim_size(4),
775 " for out_backprop and ",
776 filter_shape.dim_size(4), " for filter_sizes"));
777
778 ConvBackpropDimensions dims;
779 OP_REQUIRES_OK(context,
780 ConvBackpropComputeDimensions(
781 "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
782 input_shape, filter_shape, out_backprop_shape, stride_,
783 padding_, data_format_, &dims));
784
785 Tensor* filter_backprop;
786 OP_REQUIRES_OK(context,
787 context->allocate_output(0, filter_shape, &filter_backprop));
788
789 if (input_shape.num_elements() == 0) {
790 filter_backprop->template flat<T>().setZero();
791 return;
792 }
793
794 functor::CuboidConvolutionBackwardFilter<Device, T>()(
795 context->eigen_device<Device>(),
796 filter_backprop->tensor<T, 5>(), // filter_backward
797 input.tensor<T, 5>(), // input
798 out_backprop.tensor<T, 5>(), // output_backward
799 static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
800 static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
801 static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
802 }
803
804 private:
805 std::vector<int32> dilation_;
806 std::vector<int32> stride_;
807 Padding padding_;
808 TensorFormat data_format_;
809 bool takes_shape_;
810
811 TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
812 };
813
814 // Custom backprop for filter that explicitly does the work sharding and calls
815 // Eigen only to multiply matrices.
816 template <typename Device, class T>
817 class Conv3DCustomBackpropFilterOp : public OpKernel {
818 // Limit the maximum size of the allocated temporary buffer to
819 // kMaxTempAllocationOverhead times the size of the input tensors (input,
820 // filter, out_backprop). If the size of the temporary buffer exceeds this
821 // limit, fall back on the Eigen implementation.
822 static constexpr int kMaxTempAllocationOverhead = 25;
823
824 public:
825 explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
826 : OpKernel(context),
827 data_format_(FORMAT_NHWC),
828 takes_shape_(type_string().find("V2") != std::string::npos) {
829 // data_format is only available in V2.
830 if (takes_shape_) {
831 string data_format;
832 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
833 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
834 errors::InvalidArgument("Invalid data format"));
835 OP_REQUIRES(
836 context, data_format_ == FORMAT_NHWC,
837 errors::InvalidArgument(
838 "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
839 }
840
841 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
842 OP_REQUIRES(context, dilation_.size() == 5,
843 errors::InvalidArgument("Dilation rates field must "
844 "specify 5 dimensions"));
845 OP_REQUIRES(context,
846 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
847 GetTensorDim(dilation_, data_format_, 'N') == 1),
848 errors::InvalidArgument(
849 "Current implementation does not yet support "
850 "dilation rates in the batch and depth dimensions."));
851
852 // TODO(yangzihao): Add CPU version of dilated conv 3D.
853 OP_REQUIRES(context,
854 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
855 GetTensorDim(dilation_, data_format_, '1') == 1 &&
856 GetTensorDim(dilation_, data_format_, '2') == 1),
857 errors::InvalidArgument(
858 "Current CPU implementation does not yet support "
859 "dilation rates larger than 1."));
860
861 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
862 OP_REQUIRES(context, stride_.size() == 5,
863 errors::InvalidArgument("Sliding window strides field must "
864 "specify 5 dimensions"));
865 OP_REQUIRES(
866 context,
867 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
868 GetTensorDim(stride_, data_format_, 'N') == 1),
869 errors::InvalidArgument("Current implementation does not yet support "
870 "strides in the batch and depth dimensions."));
871 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
872 }
873
874 void Compute(OpKernelContext* context) override {
875 const Tensor& input = context->input(0);
876 const TensorShape& input_shape = input.shape();
877
878 const Tensor& out_backprop = context->input(2);
879 const TensorShape& out_backprop_shape = out_backprop.shape();
880
881 TensorShape filter_shape;
882 if (takes_shape_) {
883 const Tensor& filter_sizes = context->input(1);
884 OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()),
885 errors::InvalidArgument(
886 "filter_sizes shape must be rank 1 but is rank ",
887 filter_sizes.shape().dims()));
888 OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
889 filter_sizes.vec<int32>(), &filter_shape));
890 } else {
891 filter_shape = context->input(1).shape();
892 }
893
894 OP_REQUIRES(context, input_shape.dims() == 5,
895 errors::InvalidArgument("input tensor must have 5 dimensions"));
896 OP_REQUIRES(
897 context, filter_shape.dims() == 5,
898 errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
899 OP_REQUIRES(
900 context, out_backprop_shape.dims() == 5,
901 errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
902 OP_REQUIRES(
903 context, input_shape.dim_size(4) == filter_shape.dim_size(3),
904 errors::InvalidArgument("input and filter_sizes must have the same "
905 "number of channels. Got ",
906 input_shape.dim_size(4), " for input and ",
907 filter_shape.dim_size(3), " for filter_sizes"));
908 OP_REQUIRES(
909 context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
910 errors::InvalidArgument("out_backprop and filter_sizes must have the "
911 "same number of channels. Got ",
912 out_backprop_shape.dim_size(4),
913 " for out_backprop and ",
914 filter_shape.dim_size(4), " for filter_sizes"));
915
916 ConvBackpropDimensions dims;
917 OP_REQUIRES_OK(context,
918 ConvBackpropComputeDimensions(
919 "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
920 input_shape, filter_shape, out_backprop_shape, stride_,
921 padding_, data_format_, &dims));
922
923 Tensor* filter_backprop;
924 OP_REQUIRES_OK(context,
925 context->allocate_output(0, filter_shape, &filter_backprop));
926
927 if (input_shape.num_elements() == 0) {
928 filter_backprop->template flat<T>().setZero();
929 return;
930 }
931
932 int64_t top_pad_planes, bottom_pad_planes;
933 int64_t top_pad_rows, bottom_pad_rows;
934 int64_t left_pad_cols, right_pad_cols;
935
936 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
937 dims.spatial_dims[0].input_size,
938 dims.spatial_dims[0].filter_size,
939 dims.spatial_dims[0].stride, padding_,
940 &dims.spatial_dims[0].output_size,
941 &top_pad_planes, &bottom_pad_planes));
942 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
943 dims.spatial_dims[1].input_size,
944 dims.spatial_dims[1].filter_size,
945 dims.spatial_dims[1].stride, padding_,
946 &dims.spatial_dims[1].output_size,
947 &top_pad_rows, &bottom_pad_rows));
948 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
949 dims.spatial_dims[2].input_size,
950 dims.spatial_dims[2].filter_size,
951 dims.spatial_dims[2].stride, padding_,
952 &dims.spatial_dims[2].output_size,
953 &left_pad_cols, &right_pad_cols));
954
955 // TODO(ezhulenev): Extract work size and shard estimation to shared
956 // functions in conv_grad_ops, and update 2d convolution backprop.
957
958 // The flattened size of a single filter: filter volume times input depth.
959 const int64_t filter_total_size =
960 dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
961 dims.spatial_dims[2].filter_size * dims.in_depth;
962 // The output image size is the spatial size of the output.
963 const int64_t output_image_size = dims.spatial_dims[0].output_size *
964 dims.spatial_dims[1].output_size *
965 dims.spatial_dims[2].output_size;
966
967 // Shard 'batch' images (volumes) into 'shard_size' groups of images
968 // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
969 // dividing the L3 cache size ('target_working_set_size') by the matmul size
970 // of an individual image ('work_unit_size').
971
972 const auto cache_sizes = Eigen::internal::CacheSizes();
973 const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
974
975 // TODO(andydavis)
976 // *) Consider reducing 'target_working_set_size' if L3 is shared by
977 // other concurrently running tensorflow ops.
978 const size_t target_working_set_size = l3_cache_size / sizeof(T);
979
980 const int64_t size_A = output_image_size * filter_total_size;
981
982 const int64_t size_B = output_image_size * dims.out_depth;
983
984 const int64_t size_C = filter_total_size * dims.out_depth;
985
986 const int64_t work_unit_size = size_A + size_B + size_C;
987
988 OP_REQUIRES(
989 context, work_unit_size > 0,
990 errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
991 "must all have at least 1 element"));
992
993 const size_t shard_size =
994 (target_working_set_size + work_unit_size - 1) / work_unit_size;
995
996 // Total number of elements in all the tensors used by this kernel.
997 int64_t total_tensor_elements = input_shape.num_elements() +
998 filter_shape.num_elements() +
999 out_backprop_shape.num_elements();
1000
1001 // Shape of the temporary workspace buffer.
1002 TensorShape col_buffer_shape = {static_cast<int64_t>(shard_size),
1003 static_cast<int64_t>(output_image_size),
1004 static_cast<int64_t>(filter_total_size)};
1005 int64_t col_buffer_elements = col_buffer_shape.num_elements();
1006
1007 // If the temporary allocation overhead is too large, fall back on the Eigen
1008 // implementation, which requires much less memory.
1009 int64_t col_buffer_overhead = col_buffer_elements / total_tensor_elements;
1010 if (col_buffer_overhead > kMaxTempAllocationOverhead) {
1011 VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
1012 "col_buffer_overhead="
1013 << col_buffer_overhead;
1014
1015 functor::CuboidConvolutionBackwardFilter<Device, T>()(
1016 context->eigen_device<Device>(),
1017 filter_backprop->tensor<T, 5>(), // filter_backward
1018 input.tensor<T, 5>(), // input
1019 out_backprop.tensor<T, 5>(), // output_backward
1020 static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
1021 static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
1022 static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
1023
1024 return;
1025 }
1026
1027 Tensor col_buffer;
1028 OP_REQUIRES_OK(context,
1029 context->allocate_temp(DataTypeToEnum<T>::value,
1030 col_buffer_shape, &col_buffer));
1031
1032 // The input offset corresponding to a single input image.
1033 const int64_t input_offset =
1034 dims.spatial_dims[0].input_size * dims.spatial_dims[1].input_size *
1035 dims.spatial_dims[2].input_size * dims.in_depth;
1036 // The output offset corresponding to a single output image.
1037 const int64_t output_offset =
1038 dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
1039 dims.spatial_dims[2].output_size * dims.out_depth;
1040
1041 const T* input_data = input.template flat<T>().data();
1042 T* col_buffer_data = col_buffer.template flat<T>().data();
1043 const T* out_backprop_data = out_backprop.template flat<T>().data();
1044 T* filter_backprop_data = filter_backprop->template flat<T>().data();
1045
1046 typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
1047 Eigen::Unaligned>
1048 TensorMap;
1049 typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
1050 Eigen::Unaligned>
1051 ConstTensorMap;
1052
1053 TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
1054 C.setZero();
1055
1056 // Initialize contraction dims (we need to transpose 'A' below).
1057 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
1058 contract_dims[0].first = 0;
1059 contract_dims[0].second = 0;
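// Contracting dimension 0 of A with dimension 0 of B computes A^T * B:
// (output_image_size * shard_limit x filter_total_size)^T x
// (output_image_size * shard_limit x out_depth) =
// (filter_total_size x out_depth), the shape of the filter gradient C.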
1060
1061 auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
1062
1063 for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
1064 const int shard_limit =
1065 std::min(static_cast<int>(shard_size),
1066 static_cast<int>(dims.batch_size) - image_id);
1067
1068 auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
1069 &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
1070 &bottom_pad_rows, &right_pad_cols, &input_offset,
1071 &size_A](int64_t start, int64_t limit) {
1072 for (int shard_id = start; shard_id < limit; ++shard_id) {
1073 const T* input_data_shard = input_data + shard_id * input_offset;
1074 T* col_data_shard = col_buffer_data + shard_id * size_A;
1075
1076 // When we compute the gradient with respect to the filters, we need
1077 // to do im2col to allow gemm-type computation.
1078 Im2col<T>(input_data_shard, dims.in_depth,
1079 // Input spatial dimensions.
1080 dims.spatial_dims[0].input_size, // input planes
1081 dims.spatial_dims[1].input_size, // input rows
1082 dims.spatial_dims[2].input_size, // input cols
1083 // Filter spatial dimensions.
1084 dims.spatial_dims[0].filter_size, // filter planes
1085 dims.spatial_dims[1].filter_size, // filter rows
1086 dims.spatial_dims[2].filter_size, // filter cols
1087 // Spatial padding.
1088 top_pad_planes, top_pad_rows, left_pad_cols,
1089 bottom_pad_planes, bottom_pad_rows, right_pad_cols,
1090 // Spatial striding.
1091 dims.spatial_dims[0].stride, // stride planes
1092 dims.spatial_dims[1].stride, // stride rows
1093 dims.spatial_dims[2].stride, // stride cols
1094 col_data_shard);
1095 }
1096 };
1097 Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
1098 size_A, shard);
1099
1100 ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
1101 filter_total_size);
1102 ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
1103 dims.out_depth);
1104
1105 // Gradient with respect to filter.
1106 C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
1107
1108 input_data += input_offset * shard_limit;
1109 out_backprop_data += output_offset * shard_limit;
1110 }
1111 }
1112
1113 private:
1114 std::vector<int32> dilation_;
1115 std::vector<int32> stride_;
1116 Padding padding_;
1117 TensorFormat data_format_;
1118 bool takes_shape_;
1119
1120 TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
1121 };
1122
1123 // The custom backprop filter kernel is 30% - 4x faster when compiled with AVX2
1124 // than the default Eigen implementation (at the cost of ~2x-8x peak memory usage).
1125
1126 #define REGISTER_CPU_KERNEL(T) \
1127 REGISTER_KERNEL_BUILDER( \
1128 Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1129 Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
1130 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1131 .Device(DEVICE_CPU) \
1132 .TypeConstraint<T>("T"), \
1133 Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
1134 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
1135 .Device(DEVICE_CPU) \
1136 .Label("custom") \
1137 .TypeConstraint<T>("T"), \
1138 Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
1139 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1140 .Device(DEVICE_CPU) \
1141 .Label("custom") \
1142 .TypeConstraint<T>("T"), \
1143 Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
1144 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
1145 .Device(DEVICE_CPU) \
1146 .Label("eigen_tensor") \
1147 .TypeConstraint<T>("T"), \
1148 Conv3DBackpropFilterOp<CPUDevice, T>); \
1149 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1150 .Device(DEVICE_CPU) \
1151 .Label("eigen_tensor") \
1152 .TypeConstraint<T>("T"), \
1153 Conv3DBackpropFilterOp<CPUDevice, T>);
1154
1155 TF_CALL_float(REGISTER_CPU_KERNEL);
1156 TF_CALL_double(REGISTER_CPU_KERNEL);
1157 #undef REGISTER_CPU_KERNEL
1158
1159 // WARNING: Eigen::half is not trivially copyable and can't be used in the
1160 // custom backprop filter kernel because of the memcpy and memset in Im2col.
1161 #define REGISTER_CPU_KERNEL(T) \
1162 REGISTER_KERNEL_BUILDER( \
1163 Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1164 Conv3DBackpropFilterOp<CPUDevice, T>); \
1165 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1166 .Device(DEVICE_CPU) \
1167 .TypeConstraint<T>("T"), \
1168 Conv3DBackpropFilterOp<CPUDevice, T>);
1169
1170 TF_CALL_half(REGISTER_CPU_KERNEL);
1171 #undef REGISTER_CPU_KERNEL
1172
1173 // GPU definitions of both ops.
1174 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1175 // Forward declarations of the functor specializations for GPU.
1176 // This ensures that the custom implementation is used instead of the default
1177 // Eigen one (which is used for CPU).
1178 namespace functor {
1179 #define DECLARE_GPU_SPEC(T) \
1180 template <> \
1181 void TransformFilter<GPUDevice, T, int, 5>::operator()( \
1182 const GPUDevice& d, FilterTensorFormat dst_filter_format, \
1183 typename TTypes<T, 5, int>::ConstTensor in, \
1184 typename TTypes<T, 5, int>::Tensor out); \
1185 template <> \
1186 void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
1187 const GPUDevice& d, FilterTensorFormat src_filter_format, \
1188 typename TTypes<T, 5>::ConstTensor in, \
1189 typename TTypes<T, 5>::Tensor out); \
1190 template <> \
1191 void PadInput<GPUDevice, T, int, 5>::operator()( \
1192 const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
1193 const std::array<int, 3>& padding_left, \
1194 const std::array<int, 3>& padding_right, \
1195 typename TTypes<T, 5, int>::Tensor out, TensorFormat format, \
1196 const T& padding_value);
1197
1198 DECLARE_GPU_SPEC(Eigen::half);
1199 DECLARE_GPU_SPEC(float);
1200 DECLARE_GPU_SPEC(double);
1201 #undef DECLARE_GPU_SPEC
1202 } // namespace functor
1203
1204 // A dummy type to group backward data autotune results together.
1205 struct Conv3dBackwardDataAutotuneGroup {
1206 static string name() { return "Conv3dBwdData"; }
1207 };
1208
1209 typedef AutotuneSingleton<Conv3dBackwardDataAutotuneGroup, ConvParameters,
1210 AutotuneEntry<se::dnn::ConvOp>>
1211
1212 AutotuneConv3dBwdData;
1213
1214 template <typename T>
1215 class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
1216 public:
1217 explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
1218 : OpKernel(context),
1219 data_format_(FORMAT_NHWC),
1220 takes_shape_(type_string().find("V2") != std::string::npos) {
1221 // data_format is only available in V2.
1222 if (takes_shape_) {
1223 string data_format;
1224 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1225 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1226 errors::InvalidArgument("Invalid data format"));
1227 }
1228 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
1229 OP_REQUIRES(context, dilation_.size() == 5,
1230 errors::InvalidArgument("Dilation rates field must "
1231 "specify 5 dimensions"));
1232 OP_REQUIRES(context,
1233 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
1234 GetTensorDim(dilation_, data_format_, 'N') == 1),
1235 errors::InvalidArgument(
1236 "Current implementation does not yet support "
1237 "dilation rates in the batch and depth dimensions."));
1238 OP_REQUIRES(
1239 context,
1240 (GetTensorDim(dilation_, data_format_, '0') > 0 &&
1241 GetTensorDim(dilation_, data_format_, '1') > 0 &&
1242 GetTensorDim(dilation_, data_format_, '2') > 0),
1243 errors::InvalidArgument("Dilated rates should be larger than 0."));
1244 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1245 OP_REQUIRES(context, stride_.size() == 5,
1246 errors::InvalidArgument("Sliding window strides field must "
1247 "specify 5 dimensions"));
1248 OP_REQUIRES(
1249 context,
1250 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
1251 GetTensorDim(stride_, data_format_, 'N') == 1),
1252 errors::InvalidArgument("Current implementation does not yet support "
1253 "strides in the batch and depth dimensions."));
1254 OP_REQUIRES(
1255 context,
1256 (GetTensorDim(stride_, data_format_, '0') > 0 &&
1257 GetTensorDim(stride_, data_format_, '1') > 0 &&
1258 GetTensorDim(stride_, data_format_, '2') > 0),
1259 errors::InvalidArgument("Spatial strides should be larger than 0."));
1260 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1261 cudnn_use_autotune_ = CudnnUseAutotune();
1262 }
1263 void Compute(OpKernelContext* context) override {
1264 const Tensor& filter = context->input(1);
1265 const TensorShape& filter_shape = filter.shape();
1266
1267 const Tensor& out_backprop = context->input(2);
1268 const TensorShape& out_backprop_shape = out_backprop.shape();
1269
1270 TensorShape input_shape;
1271 if (takes_shape_) {
1272 const Tensor& input_sizes = context->input(0);
1273 OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
1274 } else {
1275 input_shape = context->input(0).shape();
1276 }
1277
1278 ConvBackpropDimensions dims;
1279 OP_REQUIRES_OK(context, ConvBackpropComputeDimensionsV2(
1280 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
1281 input_shape, filter_shape, out_backprop_shape,
1282 dilation_, stride_, padding_,
1283 /*explicit_paddings=*/{}, data_format_, &dims));
1284
1285 Tensor* in_backprop;
1286 OP_REQUIRES_OK(context,
1287 context->allocate_output(0, input_shape, &in_backprop));
1288
1289 auto* stream = context->op_device_context()->stream();
1290 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
1291
1292 bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
1293 if (!is_grouped_convolution && dims.filter_size(0) == 1 &&
1294 dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
1295 dims.dilation(0) == 1 && dims.dilation(1) == 1 &&
1296 dims.dilation(2) == 1 && dims.stride(0) == 1 && dims.stride(1) == 1 &&
1297 dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) {
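// With a 1x1x1 filter, unit strides and unit dilations, the input gradient
// reduces to a single GEMM: in_backprop (m x n) = out_backprop (m x k) *
// filter^T (k x n), where m = batch * spatial volume, k = out_depth and
// n = in_depth.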
1298 const uint64 m = dims.batch_size * dims.input_size(0) *
1299 dims.input_size(1) * dims.input_size(2);
1300 const uint64 k = dims.out_depth;
1301 const uint64 n = dims.in_depth;
1302
1303 auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1304 out_backprop.template flat<T>().size());
1305 auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
1306 filter.template flat<T>().size());
1307 auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
1308 in_backprop->template flat<T>().size());
1309
1310 auto transpose = se::blas::Transpose::kTranspose;
1311 auto no_transpose = se::blas::Transpose::kNoTranspose;
1312
1313 OP_REQUIRES_OK(
1314 context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr,
1315 k, a_ptr, k, &c_ptr, n,
1316 se::blas::kDefaultComputePrecision));
1317 return;
1318 } else if (!is_grouped_convolution &&
1319 dims.filter_size(0) == dims.input_size(0) &&
1320 dims.filter_size(1) == dims.input_size(1) &&
1321 dims.filter_size(2) == dims.input_size(2) &&
1322 padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
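// When the filter spatially covers the entire input under VALID padding, the
// output spatial volume is 1 and the gradient is again a single GEMM, with
// m = batch_size, k = out_depth and n = input volume * in_depth.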
1323 const uint64 m = dims.batch_size;
1324 const uint64 k = dims.out_depth;
1325 const uint64 n = dims.input_size(0) * dims.input_size(1) *
1326 dims.input_size(2) * dims.in_depth;
1327
1328 auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1329 out_backprop.template flat<T>().size());
1330 auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
1331 filter.template flat<T>().size());
1332 auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
1333 in_backprop->template flat<T>().size());
1334
1335 auto transpose = se::blas::Transpose::kTranspose;
1336 auto no_transpose = se::blas::Transpose::kNoTranspose;
1337
1338 OP_REQUIRES_OK(
1339 context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr,
1340 k, a_ptr, k, &c_ptr, n,
1341 se::blas::kDefaultComputePrecision));
1342 return;
1343 }
1344
1345 int padding_planes = dims.SpatialPadding(padding_, 0);
1346 int padding_rows = dims.SpatialPadding(padding_, 1);
1347 int padding_cols = dims.SpatialPadding(padding_, 2);
1348 const bool planes_odd = (padding_planes % 2 != 0);
1349 const bool rows_odd = (padding_rows % 2 != 0);
1350 const bool cols_odd = (padding_cols % 2 != 0);
1351
1352 TensorShape compatible_input_shape;
1353 if (rows_odd || cols_odd || planes_odd) {
1354 // cuDNN only supports the same amount of padding on both sides.
1355 compatible_input_shape = {
1356 dims.batch_size,
1357 dims.in_depth,
1358 dims.input_size(0) + planes_odd,
1359 dims.input_size(1) + rows_odd,
1360 dims.input_size(2) + cols_odd,
1361 };
1362 } else {
1363 compatible_input_shape = {dims.batch_size, dims.in_depth,
1364 dims.input_size(0), dims.input_size(1),
1365 dims.input_size(2)};
1366 }
1367
1368 CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
1369 << "Negative paddings: (" << padding_rows << ", " << padding_cols
1370 << ", " << padding_planes << ")";
1371
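    // NDHWC compute is only attempted for fp16 with cuDNN 8+, where NDHWC
    // kernels are generally the faster choice; everything else falls back to
    // NCDHW.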
1372 #if GOOGLE_CUDA
1373 const bool compute_in_nhwc =
1374 CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
1375 #else
1376 // The fast NDHWC implementation is a CUDA-only feature.
1377 const bool compute_in_nhwc = false;
1378 #endif
1379 const TensorFormat compute_data_format =
1380 (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
1381 : FORMAT_NCHW;
1382
1383 VLOG(3) << "Compute Conv3DBackpropInput with cuDNN:"
1384 << " data_format=" << ToString(data_format_)
1385 << " compute_data_format=" << ToString(compute_data_format);
1386
1387 constexpr auto kComputeInNHWC =
1388 std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
1389 se::dnn::FilterLayout::kOutputYXInput);
1390 constexpr auto kComputeInNCHW =
1391 std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
1392 se::dnn::FilterLayout::kOutputInputYX);
1393
1394 se::dnn::DataLayout compute_data_layout;
1395 se::dnn::FilterLayout filter_layout;
1396
1397 std::tie(compute_data_layout, filter_layout) =
1398 compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1399
1400 se::dnn::BatchDescriptor input_desc(3);
1401 input_desc.set_count(dims.batch_size)
1402 .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
1403 .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
1404 .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
1405 .set_feature_map_count(dims.in_depth)
1406 .set_layout(compute_data_layout);
1407 se::dnn::BatchDescriptor output_desc(3);
1408 output_desc.set_count(dims.batch_size)
1409 .set_spatial_dim(DimIndex::X, dims.output_size(2))
1410 .set_spatial_dim(DimIndex::Y, dims.output_size(1))
1411 .set_spatial_dim(DimIndex::Z, dims.output_size(0))
1412 .set_feature_map_count(dims.out_depth)
1413 .set_layout(compute_data_layout);
1414 se::dnn::FilterDescriptor filter_desc(3);
1415 filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
1416 .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
1417 .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
1418 .set_input_feature_map_count(filter_shape.dim_size(3))
1419 .set_output_feature_map_count(filter_shape.dim_size(4))
1420 .set_layout(filter_layout);
1421 se::dnn::ConvolutionDescriptor conv_desc(3);
1422 conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
1423 .set_dilation_rate(DimIndex::Y, dims.dilation(1))
1424 .set_dilation_rate(DimIndex::Z, dims.dilation(0))
1425 .set_filter_stride(DimIndex::X, dims.stride(2))
1426 .set_filter_stride(DimIndex::Y, dims.stride(1))
1427 .set_filter_stride(DimIndex::Z, dims.stride(0))
1428 .set_zero_padding(DimIndex::X, padding_cols / 2)
1429 .set_zero_padding(DimIndex::Y, padding_rows / 2)
1430 .set_zero_padding(DimIndex::Z, padding_planes / 2)
1431 .set_group_count(dims.in_depth / filter_shape.dim_size(3));
1432
1433 // Shape: (out, in, z, y, x) for FORMAT_OIHW, (out, z, y, x, in) for FORMAT_OHWI.
1434 Tensor transformed_filter;
1435 auto dst_format =
1436 compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
1437 TensorShape dst_shape =
1438 dst_format == FORMAT_OIHW
1439 ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
1440 dims.filter_size(0), dims.filter_size(1),
1441 dims.filter_size(2)})
1442 : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
1443 dims.filter_size(1), dims.filter_size(2),
1444 filter_shape.dim_size(3)});
1445 OP_REQUIRES_OK(context,
1446 context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1447 &transformed_filter));
1448
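    // Transform the filter from TensorFlow's (z, y, x, in, out) layout into
    // the cuDNN filter layout selected above.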
1449 functor::TransformFilter<GPUDevice, T, int, 5>()(
1450 context->eigen_device<GPUDevice>(), dst_format,
1451 To32Bit(filter.tensor<T, 5>()),
1452 To32Bit(transformed_filter.tensor<T, 5>()));
1453
1454 // out_backprop in the compute data format; (batch, filters, z, y, x) when NCDHW.
1455 Tensor transformed_out_backprop;
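    // If the op runs in NDHWC but the compute format is NCDHW, transpose
    // out_backprop. When out_depth == 1 the two formats share the same memory
    // layout, so a reshape (CopyFrom) is enough.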
1456 if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1457 TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
1458 dims.output_size(0), dims.output_size(1),
1459 dims.output_size(2)};
1460 if (dims.out_depth > 1) {
1461 OP_REQUIRES_OK(context, context->allocate_temp(
1462 DataTypeToEnum<T>::value, nchw_shape,
1463 &transformed_out_backprop));
1464 functor::NHWCToNCHW<GPUDevice, T, 5>()(
1465 context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
1466 transformed_out_backprop.tensor<T, 5>());
1467 } else {
1468 CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
1469 }
1470 } else {
1471 transformed_out_backprop = out_backprop;
1472 }
1473 // Scratch for the input gradient in the compute data format; (batch, filters, z, y, x) when NCDHW.
1474 Tensor pre_transformed_in_backprop;
1475 OP_REQUIRES_OK(context,
1476 context->allocate_temp(
1477 DataTypeToEnum<T>::value,
1478 ShapeFromFormat(compute_data_format,
1479 compatible_input_shape.dim_size(0),
1480 {{compatible_input_shape.dim_size(2),
1481 compatible_input_shape.dim_size(3),
1482 compatible_input_shape.dim_size(4)}},
1483 compatible_input_shape.dim_size(1)),
1484 &pre_transformed_in_backprop));
1485
1486 auto out_backprop_ptr =
1487 AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
1488 transformed_out_backprop.template flat<T>().size());
1489 auto filter_ptr =
1490 AsDeviceMemory(transformed_filter.template flat<T>().data(),
1491 transformed_filter.template flat<T>().size());
1492 auto in_backprop_ptr =
1493 AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
1494 pre_transformed_in_backprop.template flat<T>().size());
1495
1496 static int64_t ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
1497 "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 33); // 8GB by default
1498
1499 const int device_id = stream->parent()->device_ordinal();
1500 // To make sure Conv3DBackpropInputV2 gets the correct dtype, we infer
1501 // the dtype from the 2nd input, i.e., out_backprop.
1502 DataType dtype = context->input(2).dtype();
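    // These parameters key the autotune cache, so the cuDNN algorithm selected
    // below is reused for later convolutions with the same shape, dtype and
    // device.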
1503 const ConvParameters conv_parameters = {
1504 dims.batch_size,
1505 dims.in_depth,
1506 {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
1507 compute_data_format,
1508 dims.out_depth,
1509 {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
1510 {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
1511 {{dims.stride(0), dims.stride(1), dims.stride(2)}},
1512 {{padding_planes, padding_rows, padding_cols}},
1513 dtype,
1514 device_id,
1515 conv_desc.group_count()};
1516
1517 using se::dnn::AlgorithmConfig;
1518 using se::dnn::AlgorithmDesc;
1519 using se::dnn::ProfileResult;
1520
1521 auto entry_or = AutotuneUnfusedConv(
1522 cudnn_use_autotune_, AutotuneConv3dBwdData::GetInstance(),
1523 conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_DATA,
1524 input_desc, in_backprop_ptr, filter_desc, filter_ptr, conv_desc,
1525 output_desc, out_backprop_ptr, ConvolveBackwardDataScratchSize);
1526 OP_REQUIRES_OK(context, entry_or.status());
1527 auto autotune_entry = std::move(entry_or).value();
1528
1529 DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
1530 context);
1531 Status cudnn_launch_status = LaunchAutotunedConv(
1532 autotune_entry, &scratch_allocator,
1533 se::dnn::ConvolutionKind::BACKWARD_DATA, stream, input_desc,
1534 in_backprop_ptr, filter_desc, filter_ptr, conv_desc, output_desc,
1535 out_backprop_ptr);
1536 if (!cudnn_launch_status.ok()) {
1537 context->SetStatus(cudnn_launch_status);
1538 return;
1539 }
1540
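    // The backprop was computed on the enlarged "compatible" input; if an
    // extra plane/row/column was added to even out the padding, slice it off
    // again (expressed as negative padding) to recover the requested shape.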
1541 if (rows_odd || cols_odd || planes_odd) {
1542 Tensor in_backprop_remove_padding;
1543 OP_REQUIRES_OK(
1544 context, context->allocate_temp(
1545 DataTypeToEnum<T>::value,
1546 ShapeFromFormat(compute_data_format, dims.batch_size,
1547 {{dims.input_size(0), dims.input_size(1),
1548 dims.input_size(2)}},
1549 dims.in_depth),
1550 &in_backprop_remove_padding));
1551
1552 // Remove the padding for odd spatial dimensions.
1553 functor::PadInput<GPUDevice, T, int, 5>()(
1554 context->eigen_device<GPUDevice>(),
1555 To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
1556 .tensor<T, 5>()),
1557 {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
1558 To32Bit(in_backprop_remove_padding.tensor<T, 5>()),
1559 compute_data_format, T{});
1560
1561 pre_transformed_in_backprop = in_backprop_remove_padding;
1562 }
1563
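    // Transform the result back to NDHWC if that is the op's data format;
    // otherwise the computed buffer is handed to the output directly.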
1564 if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1565 auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
1566 functor::NCHWToNHWC<GPUDevice, T, 5>()(
1567 context->eigen_device<GPUDevice>(),
1568 toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
1569 in_backprop->tensor<T, 5>());
1570 } else {
1571 *in_backprop = pre_transformed_in_backprop;
1572 }
1573 }
1574
1575 private:
1576 std::vector<int32> dilation_;
1577 std::vector<int32> stride_;
1578 Padding padding_;
1579 TensorFormat data_format_;
1580 bool takes_shape_;
1581 bool cudnn_use_autotune_;
1582 };
1583
1584 // A dummy type to group backward filter autotune results together.
1585 struct Conv3dBackwardFilterAutotuneGroup {
1586   static string name() { return "Conv3dBwdFilter"; }
1587 };
1588
1589 typedef AutotuneSingleton<Conv3dBackwardFilterAutotuneGroup, ConvParameters,
1590 AutotuneEntry<se::dnn::ConvOp>>
1591 AutotuneConv3dBwdFilter;
1592
1593 template <typename T>
1594 class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
1595 public:
1596   explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
1597 : OpKernel(context),
1598 data_format_(FORMAT_NHWC),
1599 takes_shape_(type_string().find("V2") != std::string::npos) {
1600 // data_format is only available in V2.
1601 if (takes_shape_) {
1602 string data_format;
1603 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1604 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1605 errors::InvalidArgument("Invalid data format"));
1606 }
1607 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
1608 OP_REQUIRES(context, dilation_.size() == 5,
1609 errors::InvalidArgument("Dilation rates field must "
1610 "specify 5 dimensions"));
1611 OP_REQUIRES(context,
1612 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
1613 GetTensorDim(dilation_, data_format_, 'N') == 1),
1614 errors::InvalidArgument(
1615 "Current implementation does not yet support "
1616 "dilation rates in the batch and depth dimensions."));
1617 OP_REQUIRES(
1618 context,
1619 (GetTensorDim(dilation_, data_format_, '0') > 0 &&
1620 GetTensorDim(dilation_, data_format_, '1') > 0 &&
1621 GetTensorDim(dilation_, data_format_, '2') > 0),
1622 errors::InvalidArgument("Dilation rates should be larger than 0."));
1623 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1624 OP_REQUIRES(context, stride_.size() == 5,
1625 errors::InvalidArgument("Sliding window strides field must "
1626 "specify 5 dimensions"));
1627 OP_REQUIRES(
1628 context,
1629 (GetTensorDim(stride_, data_format_, 'C') == 1 &&
1630 GetTensorDim(stride_, data_format_, 'N') == 1),
1631 errors::InvalidArgument("Current implementation does not yet support "
1632 "strides in the batch and depth dimensions."));
1633 OP_REQUIRES(
1634 context,
1635 (GetTensorDim(stride_, data_format_, '0') > 0 &&
1636 GetTensorDim(stride_, data_format_, '1') > 0 &&
1637 GetTensorDim(stride_, data_format_, '2') > 0),
1638 errors::InvalidArgument("Spatial strides should be larger than 0."));
1639 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1640 cudnn_use_autotune_ = CudnnUseAutotune();
1641 }
1642
1643   void Compute(OpKernelContext* context) override {
1644 const Tensor& input = context->input(0);
1645 const TensorShape& input_shape = input.shape();
1646
1647 const Tensor& out_backprop = context->input(2);
1648 const TensorShape& out_backprop_shape = out_backprop.shape();
1649
1650 TensorShape filter_shape;
1651 if (takes_shape_) {
1652 const Tensor& filter_sizes = context->input(1);
1653 OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()),
1654 errors::InvalidArgument(
1655 "filter_sizes shape must be rank 1 but is rank ",
1656 filter_sizes.shape().dims()));
1657 OP_REQUIRES_OK(context, tensor::MakeShape(filter_sizes, &filter_shape));
1658 } else {
1659 filter_shape = context->input(1).shape();
1660 }
1661
1662 ConvBackpropDimensions dims;
1663 OP_REQUIRES_OK(
1664 context,
1665 ConvBackpropComputeDimensionsV2(
1666 "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, input_shape,
1667 filter_shape, out_backprop_shape, dilation_, stride_, padding_,
1668 /*explicit_paddings=*/{}, data_format_, &dims));
1669
1670 Tensor* filter_backprop;
1671 OP_REQUIRES_OK(context,
1672 context->allocate_output(0, filter_shape, &filter_backprop));
1673
1674 auto* stream = context->op_device_context()->stream();
1675 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
1676
1677 bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
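    // Fast path: a non-grouped 1x1x1 filter with unit strides and dilations in
    // NDHWC reduces the filter gradient to a single GEMM:
    //   dw[in_depth, out_depth] =
    //       x[batch*planes*rows*cols, in_depth]^T *
    //       dy[batch*planes*rows*cols, out_depth].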
1678 if (!is_grouped_convolution && dims.filter_size(1) == 1 &&
1679 dims.filter_size(2) == 1 && dims.filter_size(0) == 1 &&
1680 dims.dilation(2) == 1 && dims.dilation(1) == 1 &&
1681 dims.dilation(0) == 1 && dims.stride(2) == 1 && dims.stride(1) == 1 &&
1682 dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) {
1683 const uint64 m = dims.in_depth;
1684 const uint64 k = dims.batch_size * dims.input_size(1) *
1685 dims.input_size(2) * dims.input_size(0);
1686 const uint64 n = dims.out_depth;
1687
1688 // The shape of output backprop is
1689 // [batch, out_z, out_y, out_x, out_depth]
1690 // From cublas's perspective, it is: n x k
1691 auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1692 out_backprop.template flat<T>().size());
1693
1694 // The shape of input is:
1695 // [batch, in_z, in_y, in_x, in_depth],
1696 // From cublas's perspective, it is: m x k
1697 auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
1698 input.template flat<T>().size());
1699
1700 // The shape of the filter backprop is:
1701 // [1, 1, 1, in_depth, out_depth]
1702 // From cublas's perspective, it is: n x m
1703 auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
1704 filter_backprop->template flat<T>().size());
1705
1706 OP_REQUIRES_OK(context,
1707 stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
1708 se::blas::Transpose::kTranspose, n, m,
1709 k, a_ptr, n, b_ptr, m, &c_ptr, n,
1710 se::blas::kDefaultComputePrecision));
1711 return;
1712 } else if (!is_grouped_convolution &&
1713 dims.filter_size(0) == dims.input_size(0) &&
1714 dims.filter_size(1) == dims.input_size(1) &&
1715 dims.filter_size(2) == dims.input_size(2) &&
1716 padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
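      // Fast path: the filter spans the entire (VALID) input, so the filter
      // gradient is one GEMM over the batch dimension:
      //   dw[planes*rows*cols*in_depth, out_depth] =
      //       x[batch, planes*rows*cols*in_depth]^T * dy[batch, out_depth].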
1717 const uint64 m = dims.input_size(0) * dims.input_size(1) *
1718 dims.input_size(2) * dims.in_depth;
1719 const uint64 k = dims.batch_size;
1720 const uint64 n = dims.out_depth;
1721
1722 auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
1723 input.template flat<T>().size());
1724 auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1725 out_backprop.template flat<T>().size());
1726 auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
1727 filter_backprop->template flat<T>().size());
1728
1729 OP_REQUIRES_OK(context,
1730 stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
1731 se::blas::Transpose::kTranspose, n, m,
1732 k, b_ptr, n, a_ptr, m, &c_ptr, n,
1733 se::blas::kDefaultComputePrecision));
1734 return;
1735 }
1736
1737 int padding_planes = dims.SpatialPadding(padding_, 0);
1738 int padding_rows = dims.SpatialPadding(padding_, 1);
1739 int padding_cols = dims.SpatialPadding(padding_, 2);
1740 const bool planes_odd = (padding_planes % 2 != 0);
1741 const bool rows_odd = (padding_rows % 2 != 0);
1742 const bool cols_odd = (padding_cols % 2 != 0);
1743
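    // cuDNN only supports the same amount of padding on both sides, so when
    // the total padding along a spatial dimension is odd, pad the input with
    // one extra plane/row/column before handing it to cuDNN.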
1744 Tensor compatible_input;
1745 if (rows_odd || cols_odd || planes_odd) {
1746 OP_REQUIRES_OK(context,
1747 context->allocate_temp(
1748 DataTypeToEnum<T>::value,
1749 ShapeFromFormat(data_format_, dims.batch_size,
1750 {{dims.input_size(0) + planes_odd,
1751 dims.input_size(1) + rows_odd,
1752 dims.input_size(2) + cols_odd}},
1753 dims.in_depth),
1754 &compatible_input));
1755 functor::PadInput<GPUDevice, T, int, 5>()(
1756 context->template eigen_device<GPUDevice>(),
1757 To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
1758 {{planes_odd, rows_odd, cols_odd}},
1759 To32Bit(compatible_input.tensor<T, 5>()), data_format_, T{});
1760 } else {
1761 compatible_input = input;
1762 }
1763
1764 CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
1765 << "Negative paddings: (" << padding_rows << ", " << padding_cols
1766 << ", " << padding_planes << ")";
1767
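    // As in the input-backprop kernel, NDHWC compute is only attempted for
    // fp16 with cuDNN 8+; everything else is computed in NCDHW.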
1768 #if GOOGLE_CUDA
1769 const bool compute_in_nhwc =
1770 CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
1771 #else
1772 // The fast NDHWC implementation is a CUDA-only feature.
1773 const bool compute_in_nhwc = false;
1774 #endif
1775 const TensorFormat compute_data_format =
1776 (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
1777 : FORMAT_NCHW;
1778
1779 VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:"
1780 << " data_format=" << ToString(data_format_)
1781 << " compute_data_format=" << ToString(compute_data_format);
1782
1783 constexpr auto kComputeInNHWC =
1784 std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
1785 se::dnn::FilterLayout::kOutputYXInput);
1786 constexpr auto kComputeInNCHW =
1787 std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
1788 se::dnn::FilterLayout::kOutputInputYX);
1789
1790 se::dnn::DataLayout compute_data_layout;
1791 se::dnn::FilterLayout filter_layout;
1792
1793 std::tie(compute_data_layout, filter_layout) =
1794 compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1795
1796 se::dnn::BatchDescriptor input_desc(3);
1797 input_desc.set_count(dims.batch_size)
1798 .set_spatial_dim(DimIndex::X,
1799 GetTensorDim(compatible_input, data_format_, '2'))
1800 .set_spatial_dim(DimIndex::Y,
1801 GetTensorDim(compatible_input, data_format_, '1'))
1802 .set_spatial_dim(DimIndex::Z,
1803 GetTensorDim(compatible_input, data_format_, '0'))
1804 .set_feature_map_count(dims.in_depth)
1805 .set_layout(compute_data_layout);
1806 se::dnn::BatchDescriptor output_desc(3);
1807 output_desc.set_count(dims.batch_size)
1808 .set_spatial_dim(DimIndex::X, dims.output_size(2))
1809 .set_spatial_dim(DimIndex::Y, dims.output_size(1))
1810 .set_spatial_dim(DimIndex::Z, dims.output_size(0))
1811 .set_feature_map_count(dims.out_depth)
1812 .set_layout(compute_data_layout);
1813 se::dnn::FilterDescriptor filter_desc(3);
1814 filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
1815 .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
1816 .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
1817 .set_input_feature_map_count(filter_shape.dim_size(3))
1818 .set_output_feature_map_count(filter_shape.dim_size(4))
1819 .set_layout(filter_layout);
1820 se::dnn::ConvolutionDescriptor conv_desc(3);
1821 conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
1822 .set_dilation_rate(DimIndex::Y, dims.dilation(1))
1823 .set_dilation_rate(DimIndex::Z, dims.dilation(0))
1824 .set_filter_stride(DimIndex::X, dims.stride(2))
1825 .set_filter_stride(DimIndex::Y, dims.stride(1))
1826 .set_filter_stride(DimIndex::Z, dims.stride(0))
1827 .set_zero_padding(DimIndex::X, padding_cols / 2)
1828 .set_zero_padding(DimIndex::Y, padding_rows / 2)
1829 .set_zero_padding(DimIndex::Z, padding_planes / 2)
1830 .set_group_count(dims.in_depth / filter_shape.dim_size(3));
1831
1832 Tensor pre_transformed_filter_backprop;
1833 auto dst_format =
1834 compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
1835 TensorShape dst_shape =
1836 dst_format == FORMAT_OIHW
1837 ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
1838 dims.filter_size(0), dims.filter_size(1),
1839 dims.filter_size(2)})
1840 : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
1841 dims.filter_size(1), dims.filter_size(2),
1842 filter_shape.dim_size(3)});
1843 OP_REQUIRES_OK(context,
1844 context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1845 &pre_transformed_filter_backprop));
1846
1847 Tensor transformed_out_backprop;
1848 if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1849 VLOG(4) << "Convert the `out_backprop` tensor from NDHWC to NCDHW.";
1850 TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
1851 dims.output_size(0), dims.output_size(1),
1852 dims.output_size(2)};
1853 OP_REQUIRES_OK(
1854 context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
1855 &transformed_out_backprop));
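      // When out_depth == 1, NDHWC and NCDHW share the same memory layout, so
      // a reshape (CopyFrom) is enough instead of a transpose.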
1856 if (dims.out_depth > 1) {
1857 functor::NHWCToNCHW<GPUDevice, T, 5>()(
1858 context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
1859 transformed_out_backprop.tensor<T, 5>());
1860 } else {
1861 CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
1862 }
1863 } else {
1864 transformed_out_backprop = out_backprop;
1865 }
1866 Tensor transformed_input;
1867 if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1868 VLOG(4) << "Convert the `input` tensor from NDHWC to NCDHW.";
1869 TensorShape nchw_shape = {
1870 dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
1871 compatible_input.dim_size(2), compatible_input.dim_size(3)};
1872 if (dims.in_depth > 1) {
1873 OP_REQUIRES_OK(context,
1874 context->allocate_temp(DataTypeToEnum<T>::value,
1875 nchw_shape, &transformed_input));
1876 functor::NHWCToNCHW<GPUDevice, T, 5>()(
1877 context->eigen_device<GPUDevice>(),
1878 const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
1879 transformed_input.tensor<T, 5>());
1880 } else {
1881 CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
1882 }
1883 } else {
1884 transformed_input = compatible_input;
1885 }
1886
1887 auto out_backprop_ptr =
1888 AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
1889 transformed_out_backprop.template flat<T>().size());
1890 auto filter_backprop_ptr = AsDeviceMemory(
1891 pre_transformed_filter_backprop.template flat<T>().data(),
1892 pre_transformed_filter_backprop.template flat<T>().size());
1893 auto input_ptr =
1894 AsDeviceMemory(transformed_input.template flat<T>().data(),
1895 transformed_input.template flat<T>().size());
1896
1897 static int64_t ConvolveBackwardFilterScratchSize =
1898 GetDnnWorkspaceLimitOrDefault();
1899
1900 const int device_id = stream->parent()->device_ordinal();
1901 DataType dtype = input.dtype();
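    // These parameters key the autotune cache for the backward-filter
    // convolution; the selected cuDNN algorithm is reused for identical
    // convolutions on this device.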
1902 const ConvParameters conv_parameters = {
1903 dims.batch_size,
1904 dims.in_depth,
1905 {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
1906 compute_data_format,
1907 dims.out_depth,
1908 {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
1909 {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
1910 {{dims.stride(0), dims.stride(1), dims.stride(2)}},
1911 {{padding_planes, padding_rows, padding_cols}},
1912 dtype,
1913 device_id,
1914 conv_desc.group_count()};
1915
1916 using se::dnn::AlgorithmConfig;
1917 using se::dnn::AlgorithmDesc;
1918 using se::dnn::ProfileResult;
1919
1920 auto entry_or = AutotuneUnfusedConv(
1921 cudnn_use_autotune_, AutotuneConv3dBwdFilter::GetInstance(),
1922 conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_FILTER,
1923 input_desc, input_ptr, filter_desc, filter_backprop_ptr, conv_desc,
1924 output_desc, out_backprop_ptr, ConvolveBackwardFilterScratchSize);
1925 OP_REQUIRES_OK(context, entry_or.status());
1926 auto autotune_entry = std::move(entry_or).value();
1927
1928 DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
1929 context);
1930 Status cudnn_launch_status = LaunchAutotunedConv(
1931 autotune_entry, &scratch_allocator,
1932 se::dnn::ConvolutionKind::BACKWARD_FILTER, stream, input_desc,
1933 input_ptr, filter_desc, filter_backprop_ptr, conv_desc, output_desc,
1934 out_backprop_ptr);
1935 if (!cudnn_launch_status.ok()) {
1936 context->SetStatus(cudnn_launch_status);
1937 return;
1938 }
1939
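    // The filter gradient was computed in the cuDNN layout chosen above
    // (OIDHW or ODHWI); transform it back to TensorFlow's (z, y, x, in, out)
    // filter layout.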
1940 auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
1941 functor::ReverseTransformFilter<GPUDevice, T, 5>()(
1942 context->eigen_device<GPUDevice>(), /*src_filter_format=*/dst_format,
1943 toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
1944 filter_backprop->tensor<T, 5>());
1945 }
1946
1947 private:
1948 std::vector<int32> dilation_;
1949 std::vector<int32> stride_;
1950 Padding padding_;
1951 TensorFormat data_format_;
1952 bool takes_shape_;
1953 bool cudnn_use_autotune_;
1954 };
1955
1956 #define REGISTER_GPU_KERNEL(T) \
1957 REGISTER_KERNEL_BUILDER( \
1958 Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1959 Conv3DBackpropInputOp<GPUDevice, T>); \
1960 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
1961 .Device(DEVICE_GPU) \
1962 .TypeConstraint<T>("T") \
1963 .HostMemory("input_sizes"), \
1964 Conv3DBackpropInputOp<GPUDevice, T>); \
1965 REGISTER_KERNEL_BUILDER( \
1966 Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1967 Conv3DBackpropFilterOp<GPUDevice, T>); \
1968 REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
1969 .Device(DEVICE_GPU) \
1970 .TypeConstraint<T>("T") \
1971 .HostMemory("filter_sizes"), \
1972 Conv3DBackpropFilterOp<GPUDevice, T>);
1973 TF_CALL_half(REGISTER_GPU_KERNEL);
1974 TF_CALL_float(REGISTER_GPU_KERNEL);
1975 TF_CALL_double(REGISTER_GPU_KERNEL);
1976 #undef REGISTER_GPU_KERNEL
1977
1978 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1979
1980 } // namespace tensorflow
1981