xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/depthwise_conv_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/depthwise_conv_op.h"

#include <algorithm>
#include <cmath>
#include <type_traits>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif

#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

// In depthwise convolution, each input channel is convolved into
// depth_multiplier output channels, and the outputs are not summed across
// channels as they are in regular convolution.
// However, the filters are applied to the input in exactly the same way as in
// regular convolution. Please refer to the regular convolution kernels for
// more details.
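//
// As a point of reference only: the hypothetical helper below is an
// illustrative sketch added here (it is not part of the original file and is
// not used by the kernels that follow). It spells out the scalar semantics of
// one NHWC depthwise-conv output element at
// (b, out_r, out_c, d * depth_multiplier + dm).
template <typename T>
T DepthwiseConv2DScalarReference(const DepthwiseArgs& args, const T* input,
                                 const T* filter, int64_t b, int64_t out_r,
                                 int64_t out_c, int64_t d, int64_t dm) {
  T sum = static_cast<T>(0);
  for (int64_t f_r = 0; f_r < args.filter_rows; ++f_r) {
    const int64_t in_r = out_r * args.stride - args.pad_rows + f_r;
    if (in_r < 0 || in_r >= args.in_rows) continue;  // implicit zero padding
    for (int64_t f_c = 0; f_c < args.filter_cols; ++f_c) {
      const int64_t in_c = out_c * args.stride - args.pad_cols + f_c;
      if (in_c < 0 || in_c >= args.in_cols) continue;  // implicit zero padding
      // input is [batch, in_rows, in_cols, in_depth] (NHWC); filter is
      // [filter_rows, filter_cols, in_depth, depth_multiplier].
      const T x = input[((b * args.in_rows + in_r) * args.in_cols + in_c) *
                            args.in_depth +
                        d];
      const T w = filter[((f_r * args.filter_cols + f_c) * args.in_depth + d) *
                             args.depth_multiplier +
                         dm];
      sum += x * w;
    }
  }
  return sum;
}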

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Computes the vectorized product of 'input_buffer' and 'filter' and stores
// the result in 'output' at the location specified by 'out_r' and 'out_c'.
//
// EX:
//   in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4
//   Both 'input_buffer' and 'filter' are padded to register-width boundaries.
//
//   input_buffer [rows, cols, in_depth, depth_multiplier]
//     [a0, a0, a1, a1] [a2, a2, 0, 0] [b0, b0, b1, b1] [b2, b2, 0, 0]
//     [e0, e0, e1, e1] [e2, e2, 0, 0] [f0, f0, f1, f1] [f2, f2, 0, 0]
//
//   filter [rows, cols, in_depth, depth_multiplier]
//     [u0, v0, w0, x0] [y0, z0, 0, 0] [u1, v1, w1, x1] [y1, z1, 0, 0]
//     [u2, v2, w2, x2] [y2, z2, 0, 0] [u3, v3, w3, x3] [y3, z3, 0, 0]
//
//   First output register [in_depth, depth_multiplier]
//     [q0, q1, q2, q3] = ([a0, a0, a1, a1] x [u0, v0, w0, x0]) +
//                        ([b0, b0, b1, b1] x [u1, v1, w1, x1]) +
//                        ([e0, e0, e1, e1] x [u2, v2, w2, x2]) +
//                        ([f0, f0, f1, f1] x [u3, v3, w3, x3])
//
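//   Expanding lane 0 of that register scalar-wise (q0 is the output for
//   input channel 0, depth multiplier 0):
//     q0 = a0 * u0 + b0 * u1 + e0 * u2 + f0 * u3
//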
// TODO(andydavis) Experiment with processing multiple inputs per input buffer.
template <typename T>
struct DepthwiseConv2DKernel {
  static void Run(const DepthwiseArgs& args,
                  const int64_t padded_filter_inner_dim_size,
                  const int64_t out_r, const int64_t out_c, const T* filter,
                  const T* input_buffer, T* output, TensorFormat data_format) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
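    // kPacketSize depends on the instruction set Eigen targets for this build;
    // e.g. 4 floats per packet with SSE or 8 with AVX (illustrative
    // assumptions, not guarantees).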

    const int64_t out_depth = args.out_depth;
    const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
    const int64_t output_scalar_size = out_depth % kPacketSize;
    const int64_t output_vectorized_size =
        (out_depth / kPacketSize) * kPacketSize;
    const int64_t base_output_index =
        (out_r * args.out_cols + out_c) * out_depth;

    for (int i = 0; i < output_vectorized_size; i += kPacketSize) {
      // Reset accumulator.
      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
      for (int j = 0; j < filter_spatial_size; ++j) {
        // Calculate index.
        const int64_t index = i + j * padded_filter_inner_dim_size;
        // Load filter.
        // TODO(andydavis) Unroll 'out_c' loop in caller so we can load
        // multiple inputs here to amortize the cost of each filter block load.
        const auto filter_block =
            Eigen::internal::ploadu<Packet>(filter + index);
        // Load input.
        const auto data_block =
            Eigen::internal::ploadu<Packet>(input_buffer + index);
        // Vector multiply-add.
        vaccum =
            Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
      }
      // Store vector accumulator to output.
      Eigen::internal::pstoreu<T>(output + base_output_index + i, vaccum);
    }

    if (output_scalar_size > 0) {
      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
      for (int j = 0; j < filter_spatial_size; ++j) {
        const int64_t index =
            output_vectorized_size + j * padded_filter_inner_dim_size;
        const auto filter_block =
            Eigen::internal::ploadu<Packet>(filter + index);
        const auto data_block =
            Eigen::internal::ploadu<Packet>(input_buffer + index);
        vaccum =
            Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
      }
      // Load accumulator into an array and loop through output.
      T out_buf[kPacketSize];
      Eigen::internal::pstoreu<T>(out_buf, vaccum);
      const int64_t last_output_index =
          base_output_index + output_vectorized_size;
      for (int j = 0; j < output_scalar_size; ++j) {
        output[last_output_index + j] = out_buf[j];
      }
    }
  }
};

// Computes the depthwise conv2d of 'input' by 'depthwise_filter' and stores
// the result in 'output'. This implementation trades the cost of copying small
// patches of the input for better data alignment, which enables vectorized
// load/store and multiply-add operations (see comments at DepthwiseInputCopyOp
// and DepthwiseConv2DKernel for details).
//
// TODO(andydavis) Evaluate the performance of processing multiple input
// patches in the inner loop.
// TODO(andydavis) Consider a zero-copy implementation for the case when
// 'in_depth' is a multiple of register width, and 'depth_multiplier' is one.
// TODO(andydavis) Evaluate the performance of alternative implementations.
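//
// For example, with out_depth = 6 and kPacketSize = 4, the filter's inner
// dimension is padded up to padded_filter_inner_dim_size =
// ((6 + 4 - 1) / 4) * 4 = 8, so each spatial filter tap occupies exactly two
// packets and can be consumed with full-width packet loads.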
template <typename T>
struct LaunchDepthwiseConvOp<CPUDevice, T> {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;

  void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
                  const T* input, const T* depthwise_filter, T* output,
                  TensorFormat data_format) {
    OP_REQUIRES(
        ctx, data_format == FORMAT_NHWC,
        errors::Unimplemented(
            "Depthwise convolution on CPU is only supported for NHWC format"));
    static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    // Pad 'depthwise_filter' to vector register width (if needed).
    const bool pad_filter = (args.out_depth % kPacketSize) != 0;
    Tensor padded_filter;
    if (pad_filter) {
      // Allocate space for padded filter.
      const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
      const int64_t padded_filter_inner_dim_size =
          ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  TensorShape({filter_spatial_size,
                                               padded_filter_inner_dim_size}),
                                  &padded_filter));
      // Write out padded filter.
      functor::DepthwiseFilterPadOp<T>()(
          args, depthwise_filter, padded_filter.template flat<T>().data());
    }
    const T* filter_data =
        pad_filter ? padded_filter.template flat<T>().data() : depthwise_filter;

    // Computes one shard of depthwise conv2d output.
    auto shard = [&ctx, &args, &input, &filter_data, &output, data_format](
                     int64_t start, int64_t limit) {
      static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
      const int64_t input_image_size =
          args.in_rows * args.in_cols * args.in_depth;
      const int64_t output_image_size =
          args.out_rows * args.out_cols * args.out_depth;
      const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
      const int64_t padded_filter_inner_dim_size =
          ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;

      // Allocate buffer for local input regions.
      Tensor input_buffer;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  TensorShape({filter_spatial_size,
                                               padded_filter_inner_dim_size}),
                                  &input_buffer));
      T* input_buffer_data = input_buffer.template flat<T>().data();

      for (int64_t i = start; i < limit; ++i) {
        const int64_t b = i / args.out_rows;
        const int64_t in_base = b * input_image_size;
        const int64_t out_base = b * output_image_size;

        const int64_t out_r = i % args.out_rows;

        for (int64_t out_c = 0; out_c < args.out_cols; ++out_c) {
          // Populate 'input_buffer_data' with data from local input region.
          functor::DepthwiseInputCopyOp<T>()(args, padded_filter_inner_dim_size,
                                             out_r, out_c, input + in_base,
                                             input_buffer_data);

          // Process buffered input across all filters and store to output.
          DepthwiseConv2DKernel<T>::Run(
              args, padded_filter_inner_dim_size, out_r, out_c, filter_data,
              input_buffer_data, output + out_base, data_format);
        }
      }
    };

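    // Shard the work over batch * out_rows units: shard index i selects batch
    // image i / out_rows and output row i % out_rows in the lambda above.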
    const int64_t total_shards = args.batch * args.out_rows;

    // Empirically tested to give reasonable performance boosts at batch size 1
    // without reducing throughput at batch size 32.
    const float kCostMultiplier = 2.5f;

    // TODO(andydavis): Estimate shard cost (in cycles) based on the number of
    // flops/loads/stores required to compute one shard.
    const int64_t shard_cost = kCostMultiplier * args.out_cols * args.out_depth;

    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, total_shards,
          shard_cost, shard);
  }
};

// Extern template instantiated in conv_ops.cc.
extern template struct LaunchConv2DOp<CPUDevice, bfloat16>;
extern template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<CPUDevice, float>;
extern template struct LaunchConv2DOp<CPUDevice, double>;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Extern template instantiated in conv_ops.cc.
extern template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<GPUDevice, float>;
extern template struct LaunchConv2DOp<GPUDevice, double>;

// Extern template instantiated in depthwise_conv_op_gpu.cc.
extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T>
class DepthwiseConv2dNativeOp : public BinaryOp<T> {
 public:
  explicit DepthwiseConv2dNativeOp(OpKernelConstruction* context)
      : BinaryOp<T>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    stride_ = GetTensorDim(strides_, data_format_, 'H');
    const int64_t stride_w = GetTensorDim(strides_, data_format_, 'W');
    const int64_t stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64_t stride_c = GetTensorDim(strides_, data_format_, 'C');

    OP_REQUIRES(context, stride_ == stride_w,
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));

    // CPU/GPU kernel currently ignores dilations, so all must be 1.
    std::vector<int32_t> dilations;
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations));
    bool unit_dilations = true;
    for (int32_t dilation : dilations) {
      if (dilation != 1) {
        unit_dilations = false;
      }
    }
    OP_REQUIRES(context, unit_dilations,
                errors::Unimplemented(
                    "Current kernel implementation does not support "
                    "dilations, received [",
                    Eigen::Map<Eigen::Matrix<int32_t, 1, Eigen::Dynamic>>(
                        dilations.data(), dilations.size()),
                    "]"));

    cudnn_use_autotune_ = CudnnUseAutotune();
    dtype_ = DataTypeToEnum<T>::value;
#if CUDNN_VERSION >= 8000
    // From the cuDNN release note 8.0: We’ve extended the fprop and dgrad
    // NHWC depthwise kernels to support more combinations (filter
    // sizes/strides) such as 5x5/1x1, 5x5/2x2, 7x7/1x1, 7x7/2x2 (in addition
    // to what we already have, 1x1/1x1, 3x3/1x1, 3x3/2x2), which provides
    // good performance. (https://docs.nvidia.com/deeplearning/sdk/cudnn-
    // release-notes/rel_8.html#rel_8)
    use_cudnn_grouped_conv_ =
        dtype_ == DT_HALF &&
        (data_format_ == FORMAT_NCHW ||
         (data_format_ == FORMAT_NHWC && stride_ == stride_w &&
          (stride_ == 1 || stride_ == 2)));
#elif CUDNN_VERSION >= 7603
    // Use CuDNN grouped conv only when input/output is NCHW and float16(half).
    // See cudnn release note 7.6.3. (https://docs.nvidia.com/deeplearning/sdk/c
    // udnn-release-notes/rel_763.html#rel_763)
    use_cudnn_grouped_conv_ = dtype_ == DT_HALF && data_format_ == FORMAT_NCHW;
#else
    use_cudnn_grouped_conv_ = false;
#endif
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, depth_multiplier]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    // in_depth for input and filter must match.
    const int64_t in_depth = GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is depth multiplier.
    const int32_t depth_multiplier = filter.dim_size(3);

    // The output depth is input depth x depth multiplier
    const int32_t out_depth = in_depth * depth_multiplier;

    const int64_t input_rows_raw = GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int32_t input_rows = static_cast<int32>(input_rows_raw);
    const int32_t filter_rows = filter.dim_size(0);

    const int64_t input_cols_raw = GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int32_t input_cols = static_cast<int32>(input_cols_raw);
    const int32_t filter_cols = filter.dim_size(1);

    // The first dimension for input is batch.
    const int32_t batch = input.dim_size(0);

    int64_t out_rows = 0, out_cols = 0, pad_top = 0, pad_bottom = 0,
            pad_left = 0, pad_right = 0;
    if (padding_ == Padding::EXPLICIT) {
      GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'H', &pad_top,
                               &pad_bottom);
      GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', &pad_left,
                               &pad_right);
    }
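    // GetWindowedOutputSizeVerbose computes the spatial output size and any
    // implicit padding. As a concrete example, a 32-wide dimension with a
    // 3-wide filter and stride 2 yields ceil((32 - 3 + 1) / 2) = 15 outputs
    // under VALID padding and ceil(32 / 2) = 16 outputs under SAME padding.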
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                input_rows, filter_rows, stride_, padding_,
                                &out_rows, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                input_cols, filter_cols, stride_, padding_,
                                &out_cols, &pad_left, &pad_right));
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
    OP_REQUIRES(
        context,
        (!std::is_same<Device, GPUDevice>::value ||
         FastBoundsCheck(out_shape.num_elements(),
                         std::numeric_limits<int32>::max())),
        errors::InvalidArgument("Output elements too large for GPU kernel"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

    // TODO(csigg): Have autotune decide if native is faster than cuDNN.
    // If in_depth==1, this operation is just a standard convolution.
    // Depthwise convolution is a special case of cuDNN's grouped convolution.
    bool use_cudnn =
        std::is_same<Device, GPUDevice>::value &&
        (in_depth == 1 || (use_cudnn_grouped_conv_ &&
                           ShouldCudnnGroupedConvolutionBeUsed(
                               filter_rows, filter_cols, in_depth, out_depth)));

    VLOG(2) << "DepthwiseConv2dNative: "
            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
            << filter_cols << ", " << in_depth << ", " << depth_multiplier
            << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
            << ", " << out_depth << "], stride = " << stride_
            << ", pad_top = " << pad_top << ", pad_left = " << pad_left
            << ", Use cuDNN: " << use_cudnn;

    if (use_cudnn) {
      // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
      //
      //                  | TensorFlow       | cuDNN
      // --------------------------------------------------------------------
      // filter_out_depth | depth_multiplier | depth_multiplier * group_count
      // filter_in_depth  | in_depth         | in_depth / group_count
      //
      // For depthwise convolution, we have group_count == in_depth.
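      //
      // For example, with in_depth = 8 and depth_multiplier = 2, the
      // TensorFlow filter of shape [rows, cols, 8, 2] is viewed as a cuDNN
      // grouped-convolution filter of shape [rows, cols, 1, 16] with
      // group_count = 8.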
      int32_t filter_in_depth = 1;
      TensorShape shape =
          TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
      Tensor reshaped_filter(/*type=*/dtype_);
      OP_REQUIRES(
          context, reshaped_filter.CopyFrom(filter, shape),
          errors::Internal(
              "Failed to reshape filter tensor for grouped convolution."));
      // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
      // conv is supported.
      launcher_(context, /*use_cudnn=*/true, cudnn_use_autotune_, input,
                reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
                stride_, stride_, padding_, explicit_paddings_, output,
                data_format_);
      return;
    }

    DepthwiseArgs args;
    args.batch = batch;
    args.in_rows = input_rows;
    args.in_cols = input_cols;
    args.in_depth = in_depth;
    args.filter_rows = filter_rows;
    args.filter_cols = filter_cols;
    args.depth_multiplier = depth_multiplier;
    args.stride = stride_;
    args.pad_rows = pad_top;
    args.pad_cols = pad_left;
    args.out_rows = out_rows;
    args.out_cols = out_cols;
    args.out_depth = out_depth;

    auto input_ptr = input.template flat<T>().data();
    auto filter_ptr = filter.template flat<T>().data();
    auto output_ptr = output->template flat<T>().data();
    LaunchDepthwiseConvOp<Device, T>()(context, args, input_ptr, filter_ptr,
                                       output_ptr, data_format_);
  }

 protected:
  bool use_cudnn_grouped_conv_;

 private:
  std::vector<int32_t> strides_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  TensorFormat data_format_;

  int64_t stride_;  // in height/width dimension.

  // For in_depth == 1 and grouped convolutions.
  LaunchConv2DOp<Device, T> launcher_;
  bool cudnn_use_autotune_;
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
};

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DepthwiseConv2dNativeOp<CPUDevice, T>)

TF_CALL_bfloat16(REGISTER_CPU_KERNEL);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DepthwiseConv2dNativeOp<GPUDevice, T>)

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);

#if CUDNN_VERSION >= 7000
template <typename T>
class DepthwiseConv2dGroupedConvOp
    : public DepthwiseConv2dNativeOp<GPUDevice, T> {
 public:
  DepthwiseConv2dGroupedConvOp(OpKernelConstruction* context)
      : DepthwiseConv2dNativeOp<GPUDevice, T>(context) {
    this->use_cudnn_grouped_conv_ = true;
  }
};

#define REGISTER_GROUPED_CONV_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")            \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .Label("cudnn_grouped_convolution"), \
                          DepthwiseConv2dGroupedConvOp<T>)

TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
#endif  // CUDNN_VERSION
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow