/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_

#ifdef INTEL_MKL
#include <limits>
#include <memory>
#include <vector>

#include "dnnl.hpp"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/onednn_env_vars.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

using dnnl::convolution_forward;
using dnnl::prop_kind;
using dnnl::stream;

namespace tensorflow {

using ConvFwdDesc = dnnl::convolution_forward::desc;
using ConvFwdPd = dnnl::convolution_forward::primitive_desc;

class MklDnnConvUtil {
 protected:
  OpKernelContext* context_;  // We don't own this.
  std::vector<int32> strides_;
  std::vector<int32> dilations_;
  Padding padding_;
  TensorFormat data_format_;

 public:
  MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides,
                 Padding pad, TensorFormat fm,
                 const std::vector<int32>& dilations, bool is_depthwise = false)
      : context_(context),
        strides_(strides),
        dilations_(dilations),
        padding_(pad),
        data_format_(fm) {}

  virtual ~MklDnnConvUtil() { context_ = nullptr; }

  // Calculate Convolution strides
  virtual inline void GetStridesInMklOrder(memory::dims* strides) {
    // For now we take the stride from the second and third dimensions only
    // (we do not support striding on the batch or depth dimension).
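    // Example (a sketch; the attribute values are hypothetical): for NHWC
    // with strides_ = {1, 2, 3, 1}, GetTensorDim picks 'H' = 2 and 'W' = 3,
    // so *strides = {2, 3}.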
    DCHECK(strides);
    if (strides_.size() == 4) {
      int stride_rows = GetTensorDim(strides_, data_format_, 'H');
      int stride_cols = GetTensorDim(strides_, data_format_, 'W');
      *strides = {stride_rows, stride_cols};
    } else if (strides_.size() == 5) {
      int stride_planes = GetTensorDim(strides_, data_format_, '0');
      int stride_rows = GetTensorDim(strides_, data_format_, '1');
      int stride_cols = GetTensorDim(strides_, data_format_, '2');
      *strides = {stride_planes, stride_rows, stride_cols};
    }
  }

  // Calculate Convolution dilations
  virtual inline void GetDilationsInMklOrder(memory::dims* dilations) {
    // For now we take the dilation from the second and third dimensions only
    // (we do not support dilation on the batch or depth dimension).
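    // Example (hypothetical attribute values): for NDHWC with
    // dilations_ = {1, 2, 2, 2, 1}, dimensions '0', '1', '2' map to the
    // D/H/W entries, so *dilations = {2, 2, 2}.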
    DCHECK(dilations);
    if (dilations_.size() == 4) {
      int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
      int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
      *dilations = {dilations_rows, dilations_cols};
    } else if (dilations_.size() == 5) {
      int dilations_planes = GetTensorDim(dilations_, data_format_, '0');
      int dilations_rows = GetTensorDim(dilations_, data_format_, '1');
      int dilations_cols = GetTensorDim(dilations_, data_format_, '2');
      *dilations = {dilations_planes, dilations_rows, dilations_cols};
    }
  }

  // Calculate convolution input size in oneDNN order. oneDNN requires input
  // in NCHW/NCDHW format. The function does not return anything, but errors
  // arising from sanity checks are set in the context's status.
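  // Example (hypothetical shapes): an NHWC input of shape [8, 224, 224, 3]
  // produces *input_dims = {8, 3, 224, 224}, i.e. NCHW.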
  virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
                                             memory::dims* input_dims) {
#define CHECK_BOUNDS(val, err_msg)                                     \
  do {                                                                 \
    OP_REQUIRES(context_,                                              \
                FastBoundsCheck(val, std::numeric_limits<int>::max()), \
                errors::InvalidArgument(err_msg));                     \
  } while (0)
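    // CHECK_BOUNDS (the local helper macro above) fails the op with
    // InvalidArgument when `val` exceeds the int32 range; the raw int64
    // dims below are then narrowed to int for oneDNN.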

    DCHECK(input_dims);

    // Input channel
    int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
    int input_depth = static_cast<int>(input_depth_raw);

    // Input batch
    int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
    CHECK_BOUNDS(input_batch_raw, "Input batch too large");
    int input_batch = static_cast<int>(input_batch_raw);

    if (strides_.size() == 4) {  // NCHW format for Conv2D
      // Input rows/height
      int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
      CHECK_BOUNDS(input_rows_raw, "Input rows too large");
      int input_rows = static_cast<int>(input_rows_raw);

      // Input columns/width
      int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
      CHECK_BOUNDS(input_cols_raw, "Input cols too large");
      int input_cols = static_cast<int>(input_cols_raw);

      // oneDNN always requires input in NCHW format for Conv2D.
      std::vector<memory::dim> input_sizes(4, -1);
      input_sizes[MklDnnDims::Dim_N] = input_batch;
      input_sizes[MklDnnDims::Dim_C] = input_depth;
      input_sizes[MklDnnDims::Dim_H] = input_rows;
      input_sizes[MklDnnDims::Dim_W] = input_cols;
      *input_dims = input_sizes;
    } else if (strides_.size() == 5) {  // NCDHW format for Conv3D
      // Input planes/third-dimension
      int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0');
      CHECK_BOUNDS(input_planes_raw, "Input depth too large");
      int input_planes = static_cast<int>(input_planes_raw);

      // Input rows/height
      int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1');
      CHECK_BOUNDS(input_rows_raw, "Input rows too large");
      int input_rows = static_cast<int>(input_rows_raw);

      // Input columns/width
      int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2');
      CHECK_BOUNDS(input_cols_raw, "Input cols too large");
      int input_cols = static_cast<int>(input_cols_raw);

      // oneDNN always requires input in NCDHW format for Conv3D.
      std::vector<memory::dim> input_sizes(5, -1);
      input_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
      input_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
      input_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
      input_sizes[MklDnnDims3D::Dim3d_H] = input_rows;
      input_sizes[MklDnnDims3D::Dim3d_W] = input_cols;
      *input_dims = input_sizes;
    }
#undef CHECK_BOUNDS
  }

  // Calculate convolution filter size in oneDNN order.
  // oneDNN requires the filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
  // The function does not return anything, but errors arising from sanity
  // checks are set in the context's status. This overload differs from the
  // index-based one below in its input parameter: it accepts the source
  // shape directly, since Convolution Backward Input receives the shape of
  // the input tensor rather than the tensor itself (Convolution forward
  // receives the actual tensor as input).
  //
  // TODO(intel-tf): Add similar function for input and filter in MklShape.
  virtual inline void GetFilterSizeInMklOrder(const TensorShape& input_shape,
                                              const TensorShape& filter_shape,
                                              memory::dims* filter_dims,
                                              bool* is_grouped_convolution,
                                              bool is_depthwise) {
    DCHECK(filter_dims);

    OP_REQUIRES(context_, filter_shape.dims() == strides_.size(),
                errors::InvalidArgument((strides_.size() == 4)
                                            ? "filter must be 4-dimensional: "
                                            : "filter must be 5-dimensional: ",
                                        filter_shape.DebugString()));

    for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) {
      OP_REQUIRES(context_,
                  FastBoundsCheck(filter_shape.dim_size(i),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("filter too large"));
    }

    int input_depth = GetTensorDim(input_shape, data_format_, 'C');

    if (strides_.size() == 4) {  // Conv2D
      // TF filter is always in (rows, cols, in_depth, out_depth) order.
      int filter_rows =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_H));
      int filter_cols =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_W));
      int filter_in_depth =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_I));
      int filter_out_depth =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_O));
      OP_REQUIRES(context_, input_depth % filter_in_depth == 0,
                  errors::InvalidArgument(
                      "input depth must be evenly divisible by filter depth: ",
                      input_depth, " vs ", filter_in_depth));
      *is_grouped_convolution = filter_in_depth != input_depth;
      int group_count = input_depth / filter_in_depth;
      // oneDNN always needs filter in OIHW format for regular convolutions
      // and GOIHW for grouped/depthwise convolutions:
      //   OIHW = (out_depth, in_depth, rows, cols)
      //   GOIHW = (group, out_depth, in_depth, rows, cols)
      // Specifically for depthwise: G = filter_in_depth,
      // O = filter_out_depth, I = 1.
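      // Worked example (hypothetical filter shapes):
      //   depthwise: TF filter (3, 3, 8, 2) -> GOIHW = {8, 2, 1, 3, 3};
      //   grouped:   TF filter (3, 3, 4, 16) with input_depth = 8 gives
      //              group_count = 2 -> GOIHW = {2, 8, 4, 3, 3}.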
      if (is_depthwise) {
        std::vector<memory::dim> filter_sizes(5, -1);
        filter_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth;
        filter_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth;
        filter_sizes[MKL_GROUP_FILTER_DIM_I] = 1;
        filter_sizes[MKL_GROUP_FILTER_DIM_H] = filter_rows;
        filter_sizes[MKL_GROUP_FILTER_DIM_W] = filter_cols;
        *filter_dims = filter_sizes;
      } else if (*is_grouped_convolution) {
        // TODO(intel-tf): Directly set filter_dims. Same for other places.
        std::vector<memory::dim> filter_sizes(5, -1);
        filter_sizes[MKL_GROUP_FILTER_DIM_G] = group_count;
        filter_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth / group_count;
        filter_sizes[MKL_GROUP_FILTER_DIM_I] = filter_in_depth;
        filter_sizes[MKL_GROUP_FILTER_DIM_H] = filter_rows;
        filter_sizes[MKL_GROUP_FILTER_DIM_W] = filter_cols;
        *filter_dims = filter_sizes;
      } else {
        std::vector<memory::dim> filter_sizes(4, -1);
        filter_sizes[MklDnnDims::Dim_O] = filter_out_depth;
        filter_sizes[MklDnnDims::Dim_I] = filter_in_depth;
        filter_sizes[MklDnnDims::Dim_H] = filter_rows;
        filter_sizes[MklDnnDims::Dim_W] = filter_cols;
        *filter_dims = filter_sizes;
      }
    } else {  // Conv3D
      OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3),
                  errors::InvalidArgument(
                      "input and filter must have the same depth: ",
                      input_depth, " vs ", filter_shape.dim_size(3)));

      // TF filter is always in (planes, rows, cols, in_depth, out_depth) order.
      int filter_planes =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_P));
      int filter_rows =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_H));
      int filter_cols =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_W));
      int filter_in_depth =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_I));
      int filter_out_depth =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_O));

      // oneDNN always needs filter in OIDHW format.
      // OIDHW = (out_depth, in_depth, planes, rows, cols)
      std::vector<memory::dim> filter_sizes(5, -1);
      filter_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth;
      filter_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth;
      filter_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
      filter_sizes[MklDnnDims3D::Dim3d_H] = filter_rows;
      filter_sizes[MklDnnDims3D::Dim3d_W] = filter_cols;
      *filter_dims = filter_sizes;
    }
  }

  // Calculate convolution filter size in oneDNN order.
  // oneDNN requires the filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
  // The function does not return anything, but errors arising from sanity
  // checks are set in the context's status.
  virtual inline void GetFilterSizeInMklOrder(size_t src_index,
                                              size_t filter_index,
                                              memory::dims* filter_dims,
                                              bool* is_grouped_convolution,
                                              bool is_depthwise) {
    DCHECK(filter_dims);
    GetFilterSizeInMklOrder(GetTfShape(context_, src_index),
                            GetTfShape(context_, filter_index), filter_dims,
                            is_grouped_convolution, is_depthwise);
  }

  // Calculate Bias size for 2D or 3D Convolution. Function does not
  // return anything, but may set an error in context status.
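  // Example (hypothetical shape): a bias of shape [1, 1, 1, 64] passes the
  // checks below and yields *bias_dims = {64}, as does a 1-D bias of [64].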
  virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
                                            memory::dims* bias_dims) {
    const Tensor& bias = MklGetInput(context_, bias_index);
    if (bias.dims() > 1) {
      if (strides_.size() == 4) {
        OP_REQUIRES(
            context_, bias.dims() <= 4,
            errors::InvalidArgument("For NHWC format, bias should have "
                                    "4 or fewer dimensions: ",
                                    bias.shape().DebugString()));
      } else if (strides_.size() == 5) {
        OP_REQUIRES(
            context_, bias.dims() <= 5,
            errors::InvalidArgument("For NDHWC format, bias should have "
                                    "5 or fewer dimensions: ",
                                    bias.shape().DebugString()));
      }
      // Make sure all dims except the channel (last) dimension are 1.
      for (int i = 0; i < bias.dims() - 1; i++) {
        OP_REQUIRES(
            context_, bias.dim_size(i) == 1,
            errors::InvalidArgument("For bias_dims > 1, all except the last "
                                    "dimension (channel) must be 1: ",
                                    bias.shape().DebugString()));
      }
      *bias_dims = {static_cast<int>(bias.dim_size(bias.dims() - 1))};
    } else {
      *bias_dims = {static_cast<int>(bias.dim_size(0))};
    }
  }

  // Function to calculate output and padding size for 2D/3D convolution.
  //
  // Calculates the output shape of convolution in oneDNN and TensorFlow
  // order. oneDNN uses NCHW (Conv2D) or NCDHW (Conv3D) for the output order,
  // while the TensorFlow output is NHWC or NCHW (Conv2D), or NDHWC or NCDHW
  // (Conv3D), depending on data format.
  // The function also calculates the left, right, top, and bottom pads.
  // It does not return a status; errors are set in the context's status.
  //
  // TODO(intel-tf): Add similar function for input and filter in MklShape.
  virtual inline void GetOutputAndPadSizeInMklOrder(
      const TensorShape& input_shape, const TensorShape& filter_shape,
      const memory::dims& strides, const memory::dims& dilations,
      memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
      memory::dims* pad_l, memory::dims* pad_r, bool is_grouped_convolution,
      bool pad_enabled = false, bool is_depthwise = false) {
    DCHECK(output_dims_tf_order);
    DCHECK(output_dims_mkl_order);
    DCHECK(pad_l);
    DCHECK(pad_r);

    bool is_conv2d = (strides_.size() == 4);
    int input_planes, input_rows, input_cols;
    if (is_conv2d) {
      input_rows = GetTensorDim(input_shape, data_format_, 'H');
      input_cols = GetTensorDim(input_shape, data_format_, 'W');
    } else {
      input_planes = GetTensorDim(input_shape, data_format_, '0');
      input_rows = GetTensorDim(input_shape, data_format_, '1');
      input_cols = GetTensorDim(input_shape, data_format_, '2');
    }

    // Filter dimension
    // Conv2D:
    //    First dimension: rows/height.
    //    Second dimension: cols/width.
    // Conv3D:
    //    First dimension: planes/depth.
    //    Second dimension: rows/height.
    //    Third dimension: cols/width.

    int filter_planes, filter_rows, filter_cols;
    if (is_conv2d) {
      filter_rows = filter_shape.dim_size(TF_2DFILTER_DIM_H);
      filter_cols = filter_shape.dim_size(TF_2DFILTER_DIM_W);
    } else {
      filter_planes = filter_shape.dim_size(TF_3DFILTER_DIM_P);
      filter_rows = filter_shape.dim_size(TF_3DFILTER_DIM_H);
      filter_cols = filter_shape.dim_size(TF_3DFILTER_DIM_W);
    }

    int stride_planes, stride_rows, stride_cols;
    int dilation_planes, dilation_rows, dilation_cols;
    if (is_conv2d) {
      // Conv2D stride is a vector of 2 elements: {s_r, s_c}
      stride_rows = strides[0];
      stride_cols = strides[1];
      dilation_rows = dilations[0];
      dilation_cols = dilations[1];
    } else {
      // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c}
      stride_planes = strides[0];
      stride_rows = strides[1];
      stride_cols = strides[2];
      dilation_planes = dilations[0];
      dilation_rows = dilations[1];
      dilation_cols = dilations[2];
    }

    // Output batch is same as input batch.
    int out_batch = GetTensorDim(input_shape, data_format_, 'N');
    int out_depth;

    // TODO(intel-tf): add support for 3-D Depthwise

    // Output depth is same as last dimension for filters for regular
    // convolutions and group convolutions. For depthwise it is in_depth *
    // channel_multiplier. The channel_multiplier is the last dimension of
    // TF filter for depthwise convolutions.
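    // Example (hypothetical filter): a depthwise TF filter of shape
    // (3, 3, 8, 2) has in_depth = 8 and channel_multiplier = 2, so
    // out_depth = 8 * 2 = 16.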
    if (is_depthwise) {
      out_depth = (filter_shape.dim_size(TF_2DFILTER_DIM_I) *
                   filter_shape.dim_size(TF_2DFILTER_DIM_O));
    } else if (is_grouped_convolution) {
      out_depth = filter_shape.dim_size(TF_2DFILTER_DIM_O);
    } else {
      out_depth = filter_shape.dim_size(
          is_conv2d ? static_cast<int>(TF_2DFILTER_DIM_O)
                    : static_cast<int>(TF_3DFILTER_DIM_O));
    }

    int64 out_rows = 0, out_cols = 0, out_planes = 0;
    int64 pad_top = 0, pad_bottom = 0, pad_left = 0, pad_right = 0;
    int64 pad_front, pad_back;

    if (is_conv2d) {
      Padding padding_type;
      if (pad_enabled) {
        padding_type = Padding::EXPLICIT;
        pad_top = static_cast<int64_t>((*pad_l)[0]);
        pad_left = static_cast<int64_t>((*pad_l)[1]);
        pad_bottom = static_cast<int64_t>((*pad_r)[0]);
        pad_right = static_cast<int64_t>((*pad_r)[1]);
      } else {
        padding_type = padding_;
      }
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_rows, filter_rows, dilation_rows, stride_rows,
                         padding_type, &out_rows, &pad_top, &pad_bottom));
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_cols, filter_cols, dilation_cols, stride_cols,
                         padding_type, &out_cols, &pad_left, &pad_right));
    } else {
      Padding padding_type;
      if (pad_enabled) {
        padding_type = Padding::EXPLICIT;
        pad_front = static_cast<int64>((*pad_l)[0]);
        pad_top = static_cast<int64>((*pad_l)[1]);
        pad_left = static_cast<int64>((*pad_l)[2]);
        pad_back = static_cast<int64>((*pad_r)[0]);
        pad_bottom = static_cast<int64>((*pad_r)[1]);
        pad_right = static_cast<int64>((*pad_r)[2]);
      } else {
        padding_type = padding_;
      }
      OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerboseV2(
                                   input_planes, filter_planes, dilation_planes,
                                   stride_planes, padding_type, &out_planes,
                                   &pad_front, &pad_back));
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_rows, filter_rows, dilation_rows, stride_rows,
                         padding_type, &out_rows, &pad_top, &pad_bottom));
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_cols, filter_cols, dilation_cols, stride_cols,
                         padding_type, &out_cols, &pad_left, &pad_right));
    }
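    // Worked example (hypothetical sizes): with SAME padding,
    // input_rows = 10, filter_rows = 3, dilation_rows = 1, stride_rows = 2
    // gives out_rows = ceil(10 / 2) = 5 and a total padding need of
    // (5 - 1) * 2 + 3 - 10 = 1, split as pad_top = 0, pad_bottom = 1.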

    if (is_conv2d) {
      // If pad_enabled, i.e., pad and conv op are fused, then
      // all pads are already passed from pad op through
      // *pad_l and *pad_r and they don't need to be set here.
      if (!pad_enabled) {
        *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
        *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
      }
    } else {
      // If pad_enabled, i.e., pad and conv op are fused, then
      // all pads are already passed from pad op through
      // *pad_l and *pad_r and they don't need to be set here.
      if (!pad_enabled) {
        *pad_l = {static_cast<int>(pad_front), static_cast<int>(pad_top),
                  static_cast<int>(pad_left)};
        *pad_r = {static_cast<int>(pad_back), static_cast<int>(pad_bottom),
                  static_cast<int>(pad_right)};
      }
    }
    // TensorFlow output is in data_format order.
    //     Conv2D: NHWC or NCHW
    //     Conv3D: NDHWC or NCDHW
    // oneDNN uses asymmetric padding.
    TensorShape out_shape =
        is_conv2d
            ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols,
                              out_depth)
            : ShapeFromFormat(data_format_, out_batch,
                              {{out_planes, out_rows, out_cols}}, out_depth);
    *output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
    if (is_grouped_convolution) {
      int out_depth = GetTensorDim(out_shape, data_format_, 'C');
      int input_depth = GetTensorDim(input_shape, data_format_, 'C');
      int filter_in_depth =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_I));
      int num_groups = input_depth / filter_in_depth;
      OP_REQUIRES(
          context_, out_depth % num_groups == 0 && out_depth >= num_groups,
          errors::InvalidArgument(
              "output depth must be evenly divisible by number of groups: ",
              out_depth, " vs ", num_groups));
    }
    if (is_conv2d) {
      // For Conv2D, oneDNN always needs output in NCHW format.
      std::vector<memory::dim> output_sizes(4, -1);
      output_sizes[MklDnnDims::Dim_N] = out_batch;
      output_sizes[MklDnnDims::Dim_C] = out_depth;
      output_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
      output_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
      *output_dims_mkl_order = output_sizes;
    } else {
      std::vector<memory::dim> output_sizes(5, -1);
      output_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
      output_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
      output_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
      output_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows);
      output_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols);
      *output_dims_mkl_order = output_sizes;
    }
  }

  // Calculate output and pad size of forward Convolution operator.
  // See comment on GetOutputAndPadSizeInMklOrder above for parameters.
  //
  // Function does not return anything, but sets error in context status.
  inline void GetOutputAndPadSizeInMklOrder(
      size_t src_index, size_t filter_index, const memory::dims& strides,
      const memory::dims& dilations, memory::dims* output_dims_tf_order,
      memory::dims* output_dims_mkl_order, memory::dims* pad_l,
      memory::dims* pad_r, bool is_grouped_convolution, bool is_depthwise) {
    DCHECK(output_dims_tf_order);
    DCHECK(output_dims_mkl_order);
    DCHECK(pad_l);
    DCHECK(pad_r);

    auto input_tf_shape = GetTfShape(context_, src_index);
    auto filter_tf_shape = GetTfShape(context_, filter_index);

    if (strides_.size() == 4) {
      // Conv2D
      OP_REQUIRES(context_, input_tf_shape.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional",
                                          input_tf_shape.DebugString()));
    } else {
      // Conv3D
      OP_REQUIRES(context_, input_tf_shape.dims() == 5,
                  errors::InvalidArgument("input must be 5-dimensional",
                                          input_tf_shape.DebugString()));
    }

    GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides,
                                  dilations, output_dims_tf_order,
                                  output_dims_mkl_order, pad_l, pad_r,
                                  is_grouped_convolution, is_depthwise);
  }

  // Wrapper function to calculate input, filter, and output sizes of
  // Conv2D/Conv3D in MKL order:
  //     Conv2D: NCHW for input and output; OIHW for filter.
  //     Conv3D: NCDHW for input and output; OIDHW for filter.
  // The function also calculates the output shape in TensorFlow order, as
  // well as strides and paddings.
  //
  // Function does not return anything, but sets error in context status.
  inline void GetConvFwdSizesInMklOrder(
      const TensorShape& input_shape, const TensorShape& filter_shape,
      memory::dims* input_dims, memory::dims* filter_dims,
      memory::dims* strides, memory::dims* dilations,
      memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
      memory::dims* pad_l, memory::dims* pad_r, bool* is_grouped_convolution,
      bool pad_enabled = false, bool is_depthwise = false) {
    DCHECK(input_dims);
    DCHECK(filter_dims);
    DCHECK(strides);
    DCHECK(dilations);
    DCHECK(output_dims_tf_order);
    DCHECK(output_dims_mkl_order);
    DCHECK(pad_l);
    DCHECK(pad_r);

    GetInputSizeInMklOrder(input_shape, input_dims);
    if (!context_->status().ok()) return;
    GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims,
                            is_grouped_convolution, is_depthwise);
    if (!context_->status().ok()) return;
    GetStridesInMklOrder(strides);
    GetDilationsInMklOrder(dilations);
    GetOutputAndPadSizeInMklOrder(
        input_shape, filter_shape, *strides, *dilations, output_dims_tf_order,
        output_dims_mkl_order, pad_l, pad_r, *is_grouped_convolution,
        pad_enabled, is_depthwise);
    if (!context_->status().ok()) return;
  }
};
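
// Usage sketch for MklDnnConvUtil (hypothetical kernel code; `context`, the
// attribute values, and the shapes below are assumptions, not taken from an
// actual op):
//
//   MklDnnConvUtil conv_util(context, /*strides=*/{1, 2, 2, 1},
//                            Padding::SAME, FORMAT_NHWC,
//                            /*dilations=*/{1, 1, 1, 1});
//   memory::dims src_dims, filter_dims, strides, dilations;
//   memory::dims out_dims_tf, out_dims_mkl, pad_l, pad_r;
//   bool is_grouped = false;
//   conv_util.GetConvFwdSizesInMklOrder(
//       input_shape, filter_shape, &src_dims, &filter_dims, &strides,
//       &dilations, &out_dims_tf, &out_dims_mkl, &pad_l, &pad_r,
//       &is_grouped);
//   if (!context->status().ok()) return;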

/////////////////////////////////////////////////////////////////////
///  Common class that implements ConvBackpropFilter and ConvBackpropInput
/////////////////////////////////////////////////////////////////////

template <typename Device, class T, bool is_depthwise>
class MklConvBackpropCommonOp : public OpKernel {
 public:
  ~MklConvBackpropCommonOp() {}
  explicit MklConvBackpropCommonOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format_str;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));

    // Depthwise Convolution doesn't have dilation parameter
    if (!is_depthwise) {
      OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
      if (strides_.size() == 4) {
        // Check Conv2D dilations
        OP_REQUIRES(
            context, dilations_.size() == 4,
            errors::InvalidArgument("Sliding window dilations field must "
                                    "specify 4 dimensions"));
        int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
        int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
        int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
        int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
        OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
                    errors::InvalidArgument(
                        "Current implementation does not yet support "
                        "dilations in the batch and depth dimensions."));
        OP_REQUIRES(
            context, dilation_h > 0 && dilation_w > 0,
            errors::InvalidArgument("Dilated rates should be larger than 0."));
      }
    } else {
      // Set dilations to 1 for depthwise conv, for future support and to
      // align with TensorFlow behavior.
      dilations_ = {1, 1, 1, 1};
    }

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

 protected:
  // data members accessible to derived classes.
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;  // NCHW or NHWC
};

/////////////////////////////////////////////////////////////////////
///  Dummy Mkl op that is just used for operators that are intermediate
///  outputs of node fusion in the graph
/////////////////////////////////////////////////////////////////////

template <typename Device, typename T>
class MklDummyOp : public OpKernel {
 public:
  ~MklDummyOp() {}

  explicit MklDummyOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    TF_CHECK_OK(
        errors::Unimplemented("This is a dummy op. "
                              "It should not have been invoked."));
  }
};
}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_