/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_

#ifdef INTEL_MKL
#include <limits>
#include <memory>
#include <vector>

#include "dnnl.hpp"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/onednn_env_vars.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

using dnnl::convolution_forward;
using dnnl::prop_kind;
using dnnl::stream;

namespace tensorflow {

using ConvFwdDesc = dnnl::convolution_forward::desc;
using ConvFwdPd = dnnl::convolution_forward::primitive_desc;

class MklDnnConvUtil {
 protected:
  OpKernelContext* context_;  // We don't own this.
  std::vector<int32> strides_;
  std::vector<int32> dilations_;
  Padding padding_;
  TensorFormat data_format_;

 public:
  MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides,
                 Padding pad, TensorFormat fm,
                 const std::vector<int32>& dilations,
                 bool is_depthwise = false)
      : context_(context),
        strides_(strides),
        dilations_(dilations),
        padding_(pad),
        data_format_(fm) {}

  virtual ~MklDnnConvUtil() { context_ = nullptr; }
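
  // Illustrative sketch (not part of the original interface docs): the helper
  // is typically constructed from a kernel's Compute() with the attributes
  // parsed at construction time, e.g. (assuming a valid OpKernelContext* ctx
  // and NHWC Conv2D attrs):
  //
  //   MklDnnConvUtil conv_util(ctx, /*strides=*/{1, 2, 2, 1}, Padding::SAME,
  //                            FORMAT_NHWC, /*dilations=*/{1, 1, 1, 1});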

  // Calculate Convolution strides
  virtual inline void GetStridesInMklOrder(memory::dims* strides) {
    // For now we take the stride from the second and third dimensions only
    // (we do not support striding on the batch or depth dimension).
    DCHECK(strides);
    if (strides_.size() == 4) {
      int stride_rows = GetTensorDim(strides_, data_format_, 'H');
      int stride_cols = GetTensorDim(strides_, data_format_, 'W');
      *strides = {stride_rows, stride_cols};
    } else if (strides_.size() == 5) {
      int stride_planes = GetTensorDim(strides_, data_format_, '0');
      int stride_rows = GetTensorDim(strides_, data_format_, '1');
      int stride_cols = GetTensorDim(strides_, data_format_, '2');
      *strides = {stride_planes, stride_rows, stride_cols};
    }
  }

  // Calculate Convolution dilations
  virtual inline void GetDilationsInMklOrder(memory::dims* dilations) {
    // For now we take the dilation from the second and third dimensions only
    // (we do not support dilation on the batch or depth dimension).
    DCHECK(dilations);
    if (dilations_.size() == 4) {
      int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
      int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
      *dilations = {dilations_rows, dilations_cols};
    } else if (dilations_.size() == 5) {
      int dilations_planes = GetTensorDim(dilations_, data_format_, '0');
      int dilations_rows = GetTensorDim(dilations_, data_format_, '1');
      int dilations_cols = GetTensorDim(dilations_, data_format_, '2');
      *dilations = {dilations_planes, dilations_rows, dilations_cols};
    }
  }
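
  // Illustrative example for the two helpers above (attrs assumed, not part
  // of the original code): with NHWC attrs strides = {1, 2, 2, 1} and
  // dilations = {1, 1, 1, 1},
  //
  //   memory::dims strides, dilations;
  //   conv_util.GetStridesInMklOrder(&strides);      // strides   == {2, 2}
  //   conv_util.GetDilationsInMklOrder(&dilations);  // dilations == {1, 1}
  //
  // i.e. only the spatial (H, W) components are kept, in MKL (row, col) order.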

  // Calculate Convolution input size in oneDNN order. oneDNN
  // requires input in NCHW/NCDHW format. Function does not return anything,
  // but errors arising from sanity checks are set in the context's status.
  virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
                                             memory::dims* input_dims) {
#define CHECK_BOUNDS(val, err_msg)                                     \
  do {                                                                 \
    OP_REQUIRES(context_,                                              \
                FastBoundsCheck(val, std::numeric_limits<int>::max()), \
                errors::InvalidArgument(err_msg));                     \
  } while (0)

    DCHECK(input_dims);

    // Input channel
    int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
    int input_depth = static_cast<int>(input_depth_raw);

    // Input batch
    int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
    CHECK_BOUNDS(input_batch_raw, "Input batch too large");
    int input_batch = static_cast<int>(input_batch_raw);

    if (strides_.size() == 4) {  // NCHW format for Conv2D
      // Input rows/height
      int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
      CHECK_BOUNDS(input_rows_raw, "Input rows too large");
      int input_rows = static_cast<int>(input_rows_raw);

      // Input columns/width
      int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
      CHECK_BOUNDS(input_cols_raw, "Input cols too large");
      int input_cols = static_cast<int>(input_cols_raw);

      // oneDNN always requires input in NCHW format for Conv2D.
      std::vector<memory::dim> input_sizes(4, -1);
      input_sizes[MklDnnDims::Dim_N] = input_batch;
      input_sizes[MklDnnDims::Dim_C] = input_depth;
      input_sizes[MklDnnDims::Dim_H] = input_rows;
      input_sizes[MklDnnDims::Dim_W] = input_cols;
      *input_dims = input_sizes;
    } else if (strides_.size() == 5) {  // NCDHW format for Conv3D
      // Input planes/third-dimension
      int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0');
      CHECK_BOUNDS(input_planes_raw, "Input depth too large");
      int input_planes = static_cast<int>(input_planes_raw);

      // Input rows/height
      int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1');
      CHECK_BOUNDS(input_rows_raw, "Input rows too large");
      int input_rows = static_cast<int>(input_rows_raw);

      // Input columns/width
      int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2');
      CHECK_BOUNDS(input_cols_raw, "Input cols too large");
      int input_cols = static_cast<int>(input_cols_raw);

      // oneDNN always requires input in NCDHW format for Conv3D.
      std::vector<memory::dim> input_sizes(5, -1);
      input_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
      input_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
      input_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
      input_sizes[MklDnnDims3D::Dim3d_H] = input_rows;
      input_sizes[MklDnnDims3D::Dim3d_W] = input_cols;
      *input_dims = input_sizes;
    }
#undef CHECK_BOUNDS
  }
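
  // Illustrative example for GetInputSizeInMklOrder above (values assumed,
  // not from the original code): an NHWC Conv2D input of shape
  // [8, 224, 224, 3] (batch, rows, cols, channels) is reported in oneDNN's
  // NCHW order as
  //
  //   *input_dims == {8, 3, 224, 224};  // {N, C, H, W}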

  // Calculate Convolution filter size in oneDNN order.
  // oneDNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
  // Function does not return anything, but errors arising from sanity
  // checks are set in the context's status. This function differs from
  // GetConvFilterSizeInMklOrder in parameter for input - it accepts src_shape
  // since Convolution Backward Input gets shape of input tensor rather than
  // actual tensor (Convolution forward gets actual tensor as input).
  //
  // TODO(intel-tf): Add similar function for input and filter in MklShape.
  virtual inline void GetFilterSizeInMklOrder(const TensorShape& input_shape,
                                              const TensorShape& filter_shape,
                                              memory::dims* filter_dims,
                                              bool* is_grouped_convolution,
                                              bool is_depthwise) {
    DCHECK(filter_dims);

    OP_REQUIRES(context_, filter_shape.dims() == strides_.size(),
                errors::InvalidArgument((strides_.size() == 4)
                                            ? "filter must be 4-dimensional: "
                                            : "filter must be 5-dimensional: ",
                                        filter_shape.DebugString()));

    for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) {
      OP_REQUIRES(context_,
                  FastBoundsCheck(filter_shape.dim_size(i),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("filter too large"));
    }

    int input_depth = GetTensorDim(input_shape, data_format_, 'C');

    if (strides_.size() == 4) {  // Conv2D
      // TF filter is always in (rows, cols, in_depth, out_depth) order.
      int filter_rows =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_H));
      int filter_cols =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_W));
      int filter_in_depth =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_I));
      int filter_out_depth =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_O));
      OP_REQUIRES(context_, input_depth % filter_in_depth == 0,
                  errors::InvalidArgument(
                      "input depth must be evenly divisible by filter depth: ",
                      input_depth, " vs ", filter_in_depth));
      *is_grouped_convolution = filter_in_depth != input_depth;
      int group_count = input_depth / filter_in_depth;
      // oneDNN always needs filter in OIHW format for regular convolutions
      // and GOIHW for grouped/depthwise convolutions,
      // OIHW = (out_depth, in_depth, rows, cols)
      // GOIHW = (group, out_depth, in_depth, rows, cols)
      // Specifically for depthwise G=filter_indepth, O=filter_outdepth, I=1
      if (is_depthwise) {
        std::vector<memory::dim> filter_sizes(5, -1);
        filter_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth;
        filter_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth;
        filter_sizes[MKL_GROUP_FILTER_DIM_I] = 1;
        filter_sizes[MKL_GROUP_FILTER_DIM_H] = filter_rows;
        filter_sizes[MKL_GROUP_FILTER_DIM_W] = filter_cols;
        *filter_dims = filter_sizes;
      } else if (*is_grouped_convolution) {
        // TODO(intel-tf): Directly set filter_dims. Same for other places.
        std::vector<memory::dim> filter_sizes(5, -1);
        filter_sizes[MKL_GROUP_FILTER_DIM_G] = group_count;
        filter_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth / group_count;
        filter_sizes[MKL_GROUP_FILTER_DIM_I] = filter_in_depth;
        filter_sizes[MKL_GROUP_FILTER_DIM_H] = filter_rows;
        filter_sizes[MKL_GROUP_FILTER_DIM_W] = filter_cols;
        *filter_dims = filter_sizes;
      } else {
        std::vector<memory::dim> filter_sizes(4, -1);
        filter_sizes[MklDnnDims::Dim_O] = filter_out_depth;
        filter_sizes[MklDnnDims::Dim_I] = filter_in_depth;
        filter_sizes[MklDnnDims::Dim_H] = filter_rows;
        filter_sizes[MklDnnDims::Dim_W] = filter_cols;
        *filter_dims = filter_sizes;
      }
    } else {  // Conv3D
      OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3),
                  errors::InvalidArgument(
                      "input and filter must have the same depth: ",
                      input_depth, " vs ", filter_shape.dim_size(3)));

      // TF filter is always in (planes, rows, cols, in_depth, out_depth)
      // order.
      int filter_planes =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_P));
      int filter_rows =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_H));
      int filter_cols =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_W));
      int filter_in_depth =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_I));
      int filter_out_depth =
          static_cast<int>(filter_shape.dim_size(TF_3DFILTER_DIM_O));

      // oneDNN always needs filter in OIDHW format.
      // OIDHW = (out_depth, in_depth, planes, rows, cols)
      std::vector<memory::dim> filter_sizes(5, -1);
      filter_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth;
      filter_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth;
      filter_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
      filter_sizes[MklDnnDims3D::Dim3d_H] = filter_rows;
      filter_sizes[MklDnnDims3D::Dim3d_W] = filter_cols;
      *filter_dims = filter_sizes;
    }
  }
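
  // Illustrative examples for GetFilterSizeInMklOrder above (shapes assumed,
  // not from the original code). A TF Conv2D filter of shape [3, 3, 8, 16]
  // (rows, cols, in_depth, out_depth) with input depth 8 maps to OIHW:
  //
  //   *filter_dims == {16, 8, 3, 3};    // {O, I, H, W}
  //
  // The same spatial filter with is_depthwise == true and shape [3, 3, 8, 2]
  // (channel_multiplier == 2) maps to GOIHW:
  //
  //   *filter_dims == {8, 2, 1, 3, 3};  // {G, O, I, H, W}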

  // Calculate Convolution filter size in oneDNN order.
  // oneDNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
  // Function does not return anything, but errors arising from sanity
  // checks are set in the context's status.
  virtual inline void GetFilterSizeInMklOrder(size_t src_index,
                                              size_t filter_index,
                                              memory::dims* filter_dims,
                                              bool* is_grouped_convolution,
                                              bool is_depthwise) {
    DCHECK(filter_dims);
    GetFilterSizeInMklOrder(GetTfShape(context_, src_index),
                            GetTfShape(context_, filter_index), filter_dims,
                            is_grouped_convolution, is_depthwise);
  }

  // Calculate Bias size for 2D or 3D Convolution. Function does not
  // return anything, but may set an error in context status.
  virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
                                            memory::dims* bias_dims) {
    const Tensor& bias = MklGetInput(context_, bias_index);
    if (bias.dims() > 1) {
      if (strides_.size() == 4) {
        OP_REQUIRES(
            context_, bias.dims() <= 4,
            errors::InvalidArgument("For NHWC format, bias should have "
                                    "4 or less dimensions",
                                    bias.shape().DebugString()));
      } else if (strides_.size() == 5) {
        OP_REQUIRES(
            context_, bias.dims() <= 5,
            errors::InvalidArgument("For NDHWC format, bias should have "
                                    "5 or less dimensions",
                                    bias.shape().DebugString()));
      }
      // Make sure all the dims except channel (last) are 1.
      for (int i = 0; i < bias.dims() - 1; i++) {
        OP_REQUIRES(
            context_, bias.dim_size(i) == 1,
            errors::InvalidArgument("For bias_dims > 1, all except the last "
                                    "dimension (channel) must be 1: ",
                                    bias.shape().DebugString()));
      }
      *bias_dims = {static_cast<int>(bias.dim_size(bias.dims() - 1))};
    } else {
      *bias_dims = {static_cast<int>(bias.dim_size(0))};
    }
  }
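
  // Illustrative example for GetBiasSizeInMklOrder above (shapes assumed, not
  // from the original code): a bias tensor of shape [64] or of shape
  // [1, 1, 1, 64] both yield
  //
  //   *bias_dims == {64};  // only the channel dimension is kept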

  // Function to calculate output and padding size for 2D/3D convolution.
  //
  // Calculate output shape of Convolution in oneDNN and TensorFlow order.
  // oneDNN uses NCHW (Conv2D) or NCDHW (Conv3D) for output order,
  // but the TensorFlow output will be in NHWC or NCHW (Conv2D), or
  // NDHWC or NCDHW (Conv3D), depending on data format.
  // Function also calculates left, right, top and bottom pads.
  // Function does not return any status; errors are set in the context's
  // status.
  //
  // TODO(intel-tf): Add similar function for input and filter in MklShape.
  virtual inline void GetOutputAndPadSizeInMklOrder(
      const TensorShape& input_shape, const TensorShape& filter_shape,
      const memory::dims& strides, const memory::dims& dilations,
      memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
      memory::dims* pad_l, memory::dims* pad_r, bool is_grouped_convolution,
      bool pad_enabled = false, bool is_depthwise = false) {
    DCHECK(output_dims_tf_order);
    DCHECK(output_dims_mkl_order);
    DCHECK(pad_l);
    DCHECK(pad_r);

    bool is_conv2d = (strides_.size() == 4);
    int input_planes, input_rows, input_cols;
    if (is_conv2d) {
      input_rows = GetTensorDim(input_shape, data_format_, 'H');
      input_cols = GetTensorDim(input_shape, data_format_, 'W');
    } else {
      input_planes = GetTensorDim(input_shape, data_format_, '0');
      input_rows = GetTensorDim(input_shape, data_format_, '1');
      input_cols = GetTensorDim(input_shape, data_format_, '2');
    }

    // Filter dimension
    // Conv2D:
    //   First dimension: rows/height.
    //   Second dimension: cols/width.
    // Conv3D:
    //   First dimension: planes/depth.
    //   Second dimension: rows/height.
    //   Third dimension: cols/width.

    int filter_planes, filter_rows, filter_cols;
    if (is_conv2d) {
      filter_rows = filter_shape.dim_size(TF_2DFILTER_DIM_H);
      filter_cols = filter_shape.dim_size(TF_2DFILTER_DIM_W);
    } else {
      filter_planes = filter_shape.dim_size(TF_3DFILTER_DIM_P);
      filter_rows = filter_shape.dim_size(TF_3DFILTER_DIM_H);
      filter_cols = filter_shape.dim_size(TF_3DFILTER_DIM_W);
    }

    int stride_planes, stride_rows, stride_cols;
    int dilation_planes, dilation_rows, dilation_cols;
    if (is_conv2d) {
      // Conv2D stride is a vector of 2 elements: {s_r, s_c}
      stride_rows = strides[0];
      stride_cols = strides[1];
      dilation_rows = dilations[0];
      dilation_cols = dilations[1];
    } else {
      // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c}
      stride_planes = strides[0];
      stride_rows = strides[1];
      stride_cols = strides[2];
      dilation_planes = dilations[0];
      dilation_rows = dilations[1];
      dilation_cols = dilations[2];
    }

    // Output batch is same as input batch.
    int out_batch = GetTensorDim(input_shape, data_format_, 'N');
    int out_depth;

    // TODO(intel-tf): Add support for 3-D Depthwise.

    // Output depth is the same as the last filter dimension for regular
    // and grouped convolutions. For depthwise it is in_depth *
    // channel_multiplier, where channel_multiplier is the last dimension of
    // the TF filter for depthwise convolutions.
    if (is_depthwise) {
      out_depth = (filter_shape.dim_size(TF_2DFILTER_DIM_I) *
                   filter_shape.dim_size(TF_2DFILTER_DIM_O));
    } else if (is_grouped_convolution) {
      out_depth = filter_shape.dim_size(TF_2DFILTER_DIM_O);
    } else {
      out_depth = filter_shape.dim_size(
          is_conv2d ? static_cast<int>(TF_2DFILTER_DIM_O)
                    : static_cast<int>(TF_3DFILTER_DIM_O));
    }
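
    // Illustrative depthwise example (values assumed, not from the original
    // code): a depthwise Conv2D with filter shape [3, 3, 8, 2], i.e.
    // in_depth == 8 and channel_multiplier == 2, produces
    //
    //   out_depth == 8 * 2 == 16;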

    int64 out_rows = 0, out_cols = 0, out_planes = 0;
    int64 pad_top = 0, pad_bottom = 0, pad_left = 0, pad_right = 0;
    int64 pad_front, pad_back;

    if (is_conv2d) {
      Padding padding_type;
      if (pad_enabled) {
        padding_type = Padding::EXPLICIT;
        pad_top = static_cast<int64_t>((*pad_l)[0]);
        pad_left = static_cast<int64_t>((*pad_l)[1]);
        pad_bottom = static_cast<int64_t>((*pad_r)[0]);
        pad_right = static_cast<int64_t>((*pad_r)[1]);
      } else {
        padding_type = padding_;
      }
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_rows, filter_rows, dilation_rows, stride_rows,
                         padding_type, &out_rows, &pad_top, &pad_bottom));
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_cols, filter_cols, dilation_cols, stride_cols,
                         padding_type, &out_cols, &pad_left, &pad_right));
    } else {
      Padding padding_type;
      if (pad_enabled) {
        padding_type = Padding::EXPLICIT;
        pad_front = static_cast<int64>((*pad_l)[0]);
        pad_top = static_cast<int64>((*pad_l)[1]);
        pad_left = static_cast<int64>((*pad_l)[2]);
        pad_back = static_cast<int64>((*pad_r)[0]);
        pad_bottom = static_cast<int64>((*pad_r)[1]);
        pad_right = static_cast<int64>((*pad_r)[2]);
      } else {
        padding_type = padding_;
      }
      OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerboseV2(
                                   input_planes, filter_planes,
                                   dilation_planes, stride_planes,
                                   padding_type, &out_planes,
                                   &pad_front, &pad_back));
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_rows, filter_rows, dilation_rows, stride_rows,
                         padding_type, &out_rows, &pad_top, &pad_bottom));
      OP_REQUIRES_OK(context_,
                     GetWindowedOutputSizeVerboseV2(
                         input_cols, filter_cols, dilation_cols, stride_cols,
                         padding_type, &out_cols, &pad_left, &pad_right));
    }

    if (is_conv2d) {
      // If pad_enabled, i.e., pad and conv op are fused, then
      // all pads are already passed from pad op through
      // *pad_l and *pad_r and they don't need to be set here.
      if (!pad_enabled) {
        *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
        *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
      }
    } else {
      // If pad_enabled, i.e., pad and conv op are fused, then
      // all pads are already passed from pad op through
      // *pad_l and *pad_r and they don't need to be set here.
      if (!pad_enabled) {
        *pad_l = {static_cast<int>(pad_front), static_cast<int>(pad_top),
                  static_cast<int>(pad_left)};
        *pad_r = {static_cast<int>(pad_back), static_cast<int>(pad_bottom),
                  static_cast<int>(pad_right)};
      }
    }
    // Tensorflow output is in data_format order.
    //   Conv2D: NHWC or NCHW
    //   Conv3D: NDHWC or NCDHW
    // oneDNN uses asymmetric padding.
    TensorShape out_shape =
        is_conv2d
            ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols,
                              out_depth)
            : ShapeFromFormat(data_format_, out_batch,
                              {{out_planes, out_rows, out_cols}}, out_depth);
    *output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
    if (is_grouped_convolution) {
      int out_depth = GetTensorDim(out_shape, data_format_, 'C');
      int input_depth = GetTensorDim(input_shape, data_format_, 'C');
      int filter_in_depth =
          static_cast<int>(filter_shape.dim_size(TF_2DFILTER_DIM_I));
      int num_groups = input_depth / filter_in_depth;
      OP_REQUIRES(
          context_, out_depth % num_groups == 0 && out_depth >= num_groups,
          errors::InvalidArgument(
              "output depth must be evenly divisible by number of groups: ",
              out_depth, " vs ", num_groups));
    }
    if (is_conv2d) {
      // For Conv2D, oneDNN always needs output in NCHW format.
      std::vector<memory::dim> output_sizes(4, -1);
      output_sizes[MklDnnDims::Dim_N] = out_batch;
      output_sizes[MklDnnDims::Dim_C] = out_depth;
      output_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
      output_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
      *output_dims_mkl_order = output_sizes;
    } else {
      std::vector<memory::dim> output_sizes(5, -1);
      output_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
      output_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
      output_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
      output_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows);
      output_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols);
      *output_dims_mkl_order = output_sizes;
    }
  }

  // Calculate output and pad size of forward Convolution operator.
  // See comment on GetOutputAndPadSizeInMklOrder above for parameters.
  //
  // Function does not return anything, but sets error in context status.
  inline void GetOutputAndPadSizeInMklOrder(
      size_t src_index, size_t filter_index, const memory::dims& strides,
      const memory::dims& dilations, memory::dims* output_dims_tf_order,
      memory::dims* output_dims_mkl_order, memory::dims* pad_l,
      memory::dims* pad_r, bool is_grouped_convolution, bool is_depthwise) {
    DCHECK(output_dims_tf_order);
    DCHECK(output_dims_mkl_order);
    DCHECK(pad_l);
    DCHECK(pad_r);

    auto input_tf_shape = GetTfShape(context_, src_index);
    auto filter_tf_shape = GetTfShape(context_, filter_index);

    if (strides_.size() == 4) {
      // Conv2D
      OP_REQUIRES(context_, input_tf_shape.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional",
                                          input_tf_shape.DebugString()));
    } else {
      // Conv3D
      OP_REQUIRES(context_, input_tf_shape.dims() == 5,
                  errors::InvalidArgument("input must be 5-dimensional",
                                          input_tf_shape.DebugString()));
    }

    GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides,
                                  dilations, output_dims_tf_order,
                                  output_dims_mkl_order, pad_l, pad_r,
                                  is_grouped_convolution, is_depthwise);
  }
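
  // Illustrative example for GetOutputAndPadSizeInMklOrder above (values
  // assumed, not from the original code): a Conv2D with 112x112 spatial
  // input, a 3x3 filter, stride 2, dilation 1, and SAME padding yields
  //
  //   out_rows == out_cols == 56;  // ceil(112 / 2)
  //   *pad_l == {0, 0};            // {pad_top, pad_left}
  //   *pad_r == {1, 1};            // {pad_bottom, pad_right}
  //
  // and *output_dims_mkl_order == {N, out_depth, 56, 56} in NCHW order.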

  // Wrapper function to calculate input, filter, and output sizes of
  // Conv2D/Conv3D in MKL order:
  //   Conv2D: NCHW for input and output; OIHW for filter.
  //   Conv3D: NCDHW for input and output; OIDHW for filter.
  // Function also calculates output shape in TensorFlow order, as well as
  // strides and paddings.
  //
  // Function does not return anything, but sets error in context status.
  inline void GetConvFwdSizesInMklOrder(
      const TensorShape& input_shape, const TensorShape& filter_shape,
      memory::dims* input_dims, memory::dims* filter_dims,
      memory::dims* strides, memory::dims* dilations,
      memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
      memory::dims* pad_l, memory::dims* pad_r, bool* is_grouped_convolution,
      bool pad_enabled = false, bool is_depthwise = false) {
    DCHECK(input_dims);
    DCHECK(filter_dims);
    DCHECK(strides);
    DCHECK(dilations);
    DCHECK(output_dims_tf_order);
    DCHECK(output_dims_mkl_order);
    DCHECK(pad_l);
    DCHECK(pad_r);

    GetInputSizeInMklOrder(input_shape, input_dims);
    if (!context_->status().ok()) return;
    GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims,
                            is_grouped_convolution, is_depthwise);
    if (!context_->status().ok()) return;
    GetStridesInMklOrder(strides);
    GetDilationsInMklOrder(dilations);
    GetOutputAndPadSizeInMklOrder(
        input_shape, filter_shape, *strides, *dilations, output_dims_tf_order,
        output_dims_mkl_order, pad_l, pad_r, *is_grouped_convolution,
        pad_enabled, is_depthwise);
    if (!context_->status().ok()) return;
  }
};
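
// Illustrative end-to-end sketch of the helper class above (names and shapes
// assumed, not part of the original code):
//
//   MklDnnConvUtil conv_util(ctx, {1, 2, 2, 1}, Padding::SAME, FORMAT_NHWC,
//                            {1, 1, 1, 1});
//   memory::dims src_dims, filter_dims, strides, dilations;
//   memory::dims out_dims_tf, out_dims_mkl, pad_l, pad_r;
//   bool is_grouped = false;
//   conv_util.GetConvFwdSizesInMklOrder(
//       input_shape, filter_shape, &src_dims, &filter_dims, &strides,
//       &dilations, &out_dims_tf, &out_dims_mkl, &pad_l, &pad_r, &is_grouped);
//   if (!ctx->status().ok()) return;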

/////////////////////////////////////////////////////////////////////
/// Common class that implements ConvBackpropFilter and Input
/////////////////////////////////////////////////////////////////////

template <typename Device, class T, bool is_depthwise>
class MklConvBackpropCommonOp : public OpKernel {
 public:
  ~MklConvBackpropCommonOp() {}

  explicit MklConvBackpropCommonOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format_str;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));

    // Depthwise Convolution doesn't have a dilation parameter.
    if (!is_depthwise) {
      OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
      if (strides_.size() == 4) {
        // Check Conv2D dilations.
        OP_REQUIRES(
            context, dilations_.size() == 4,
            errors::InvalidArgument("Sliding window dilations field must "
                                    "specify 4 dimensions"));
        int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
        int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
        int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
        int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
        OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
                    errors::InvalidArgument(
                        "Current implementation does not yet support "
                        "dilations in the batch and depth dimensions."));
        OP_REQUIRES(
            context, dilation_h > 0 && dilation_w > 0,
            errors::InvalidArgument("Dilated rates should be larger than 0."));
      }
    } else {
      // Set dilations to 1 for depthwise conv, for future support and to
      // align with TensorFlow.
      dilations_ = {1, 1, 1, 1};
    }

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

 protected:
  // Data members accessible to derived classes.
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;  // NCHW or NHWC
};

/////////////////////////////////////////////////////////////////////
/// Dummy Mkl op that is just used for operators that are intermediate
/// output of node fusion in the graph
/////////////////////////////////////////////////////////////////////

template <typename Device, typename T>
class MklDummyOp : public OpKernel {
 public:
  ~MklDummyOp() {}

  explicit MklDummyOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    TF_CHECK_OK(
        errors::Unimplemented("This is a dummy op. "
                              "It should not have been invoked."));
  }
};

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_