/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifdef INTEL_MKL

#include <algorithm>
#include <vector>

#include "tensorflow/core/kernels/mkl/mkl_conv_ops.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::convolution_backward_weights;
using dnnl::memory;
using dnnl::prop_kind;
using dnnl::stream;

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

using ConvBwdFilterDesc = dnnl::convolution_backward_weights::desc;
using ConvBwdFilterPd = dnnl::convolution_backward_weights::primitive_desc;

struct MklConvBwdFilterParams {
  memory::dims src_dims;
  memory::dims diff_filter_dims;
  memory::dims diff_bias_dims;
  memory::dims diff_dst_dims;
  memory::dims strides;
  MklTensorFormat tf_fmt;
  bool native_format;
  memory::dims dilations;
  memory::dims padding_left;
  memory::dims padding_right;

  MklConvBwdFilterParams(memory::dims src_dims, memory::dims diff_filter_dims,
                         memory::dims diff_bias_dims,
                         memory::dims diff_dst_dims, memory::dims strides,
                         MklTensorFormat tf_fmt, bool native_format,
                         memory::dims dilations, memory::dims padding_left,
                         memory::dims padding_right)
      : src_dims(src_dims),
        diff_filter_dims(diff_filter_dims),
        diff_bias_dims(diff_bias_dims),
        diff_dst_dims(diff_dst_dims),
        strides(strides),
        tf_fmt(tf_fmt),
        native_format(native_format),
        dilations(dilations),
        padding_left(padding_left),
        padding_right(padding_right) {}
};

template <typename T>
class MklConvBwdFilterPrimitive : public MklPrimitive {
 public:
  explicit MklConvBwdFilterPrimitive(
      const MklConvBwdFilterParams& convBwdFilterDims)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create convolution backward filter primitive.
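    // Setup() runs at most once per primitive: it builds the descriptors,
    // memory objects, and the primitive itself, and caches them in context_
    // for reuse across calls to Execute().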
    if (context_.conv_bwd_filter == nullptr) {
      Setup(convBwdFilterDims);
    }
  }

  ~MklConvBwdFilterPrimitive() {}

  // Convolution backward weights execution with bias
  //   src_data: input data buffer for src
  //   diff_filter_data: output data buffer for diff_filter
  //   diff_bias_data: output data buffer for diff_bias
  //   diff_dst_data: input data buffer for diff_dst
  void Execute(const T* src_data, const T* diff_filter_data,
               const T* diff_bias_data, const T* diff_dst_data,
               std::shared_ptr<stream> bwd_filter_stream) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    // TODO(intel-tf): Create a common function and avoid the duplicate code
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *bwd_filter_stream);
    context_.diff_filter_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_filter_data)),
        *bwd_filter_stream);
    if (diff_bias_data != nullptr) {
      context_.diff_bias_mem->set_data_handle(
          static_cast<void*>(const_cast<T*>(diff_bias_data)),
          *bwd_filter_stream);
    }
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_filter_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.diff_filter_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_filter_data)));
    if (diff_bias_data != nullptr) {
      context_.diff_bias_mem->set_data_handle(
          static_cast<void*>(const_cast<T*>(diff_bias_data)));
    }
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)));
#endif  // !ENABLE_ONEDNN_OPENMP
    execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream,
                       context_.bwd_filter_primitives_args);

    context_.src_mem->set_data_handle(DummyData);
    context_.diff_filter_mem->set_data_handle(DummyData);
    if (diff_bias_data != nullptr) {
      context_.diff_bias_mem->set_data_handle(DummyData);
    }
    context_.diff_dst_mem->set_data_handle(DummyData);
  }

  // Convolution backward weights without bias.
  //   src_data: input data buffer of src
  //   diff_filter_data: output data buffer of diff_filter
  //   diff_dst_data: input data buffer of diff_dst
  void Execute(const T* src_data, const T* diff_filter_data,
               const T* diff_dst_data,
               std::shared_ptr<stream> bwd_filter_stream) {
    Execute(src_data, diff_filter_data, nullptr, diff_dst_data,
            bwd_filter_stream);
  }

  std::shared_ptr<ConvBwdFilterPd> GetPrimitiveDesc() const {
    return context_.bwd_filter_pd;
  }

 private:
  // Primitive reuse context for Conv2D backward filter op.
  struct ConvBwdFilterContext {
    // oneDNN memory for inputs and outputs.
    std::shared_ptr<dnnl::memory> src_mem;
    std::shared_ptr<dnnl::memory> diff_filter_mem;
    std::shared_ptr<dnnl::memory> diff_bias_mem;
    std::shared_ptr<dnnl::memory> diff_dst_mem;
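
    // Note: oneDNN creates a backward-weights primitive descriptor using a
    // matching forward primitive descriptor as a hint (see the
    // ConvBwdFilterPd construction in Setup()), which is why the forward
    // descriptor and primitive descriptor are cached here as well.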

    // Primitive descriptor and descriptor for convolution backward filter.
    std::shared_ptr<ConvBwdFilterPd> bwd_filter_pd;
    std::shared_ptr<ConvBwdFilterDesc> bwd_filter_desc;

    // Primitive descriptor and descriptor for convolution forward.
    std::shared_ptr<ConvFwdPd> fwd_pd;
    std::shared_ptr<ConvFwdDesc> fwd_desc;

    // Convolution backward filter primitive.
    std::shared_ptr<dnnl::primitive> conv_bwd_filter;

    // Memory descriptors: forward & backward share the same memory
    // descriptors.
    std::shared_ptr<dnnl::memory::desc> src_md;
    std::shared_ptr<dnnl::memory::desc> diff_filter_md;
    std::shared_ptr<dnnl::memory::desc> diff_bias_md;
    std::shared_ptr<dnnl::memory::desc> diff_dst_md;

    // oneDNN pipeline for executing primitives.
    std::shared_ptr<dnnl::stream> bwd_filter_stream;
    std::vector<dnnl::primitive> bwd_filter_primitives;
    std::vector<MemoryArgsMap> bwd_filter_primitives_args;

    ConvBwdFilterContext()
        : src_mem(nullptr),
          diff_filter_mem(nullptr),
          diff_bias_mem(nullptr),
          diff_dst_mem(nullptr),
          bwd_filter_desc(nullptr),
          fwd_pd(nullptr),
          fwd_desc(nullptr),
          src_md(nullptr),
          diff_filter_md(nullptr),
          diff_bias_md(nullptr),
          diff_dst_md(nullptr) {}
  };

  void Setup(const MklConvBwdFilterParams& convBwdFilterDims) {
    memory::format_tag user_data_fmt;
    if (convBwdFilterDims.native_format) {
      user_data_fmt =
          MklTensorFormatToMklDnnDataFormat(convBwdFilterDims.tf_fmt);
    } else {
      // Create memory descriptors for convolution backward filter without any
      // specific format so that oneDNN can pick an appropriate one depending
      // on the input parameters.
      user_data_fmt = memory::format_tag::any;
    }
    context_.src_md.reset(new memory::desc({convBwdFilterDims.src_dims},
                                           MklDnnType<T>(), user_data_fmt));

    context_.diff_dst_md.reset(new memory::desc(
        {convBwdFilterDims.diff_dst_dims}, MklDnnType<T>(), user_data_fmt));

    context_.diff_filter_md.reset(
        new memory::desc({convBwdFilterDims.diff_filter_dims}, MklDnnType<T>(),
                         memory::format_tag::any));

    if (!convBwdFilterDims.diff_bias_dims.empty())
      context_.diff_bias_md.reset(
          new memory::desc({convBwdFilterDims.diff_bias_dims}, MklDnnType<T>(),
                           memory::format_tag::x));

    // Create descriptor and primitive descriptor for convolution forward.
    context_.fwd_desc.reset(new ConvFwdDesc(
        prop_kind::forward, dnnl::algorithm::convolution_direct,
        *context_.src_md, *context_.diff_filter_md, *context_.diff_dst_md,
        convBwdFilterDims.strides, convBwdFilterDims.dilations,
        convBwdFilterDims.padding_left, convBwdFilterDims.padding_right));
    context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
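
    // Since the descriptors above may use format_tag::any, the layouts that
    // oneDNN actually selects must be queried from the final primitive
    // descriptor (src_desc(), diff_dst_desc(), diff_weights_desc() below),
    // and callers reorder their buffers to those layouts if needed.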

    // Create descriptor and primitive descriptor for convolution bwd filter.
    if (!convBwdFilterDims.diff_bias_dims.empty()) {
      context_.bwd_filter_desc.reset(new ConvBwdFilterDesc(
          dnnl::algorithm::convolution_direct, *context_.src_md,
          *context_.diff_filter_md, *context_.diff_bias_md,
          *context_.diff_dst_md, convBwdFilterDims.strides,
          convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
          convBwdFilterDims.padding_right));
    } else {
      context_.bwd_filter_desc.reset(new ConvBwdFilterDesc(
          dnnl::algorithm::convolution_direct, *context_.src_md,
          *context_.diff_filter_md, *context_.diff_dst_md,
          convBwdFilterDims.strides, convBwdFilterDims.dilations,
          convBwdFilterDims.padding_left, convBwdFilterDims.padding_right));
    }
    context_.bwd_filter_pd.reset(new ConvBwdFilterPd(
        *context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));

    auto bwd_filter_pd = context_.bwd_filter_pd.get();

    // Create memory using dummy data.
    context_.src_mem.reset(
        new memory(bwd_filter_pd->src_desc(), cpu_engine_, DummyData));
    context_.diff_filter_mem.reset(
        new memory(bwd_filter_pd->diff_weights_desc(), cpu_engine_, DummyData));
    context_.diff_dst_mem.reset(
        new memory(bwd_filter_pd->diff_dst_desc(), cpu_engine_, DummyData));

    // Create convolution backward filter primitive and add it to the net.
    if (!convBwdFilterDims.diff_bias_dims.empty()) {
      context_.diff_bias_mem.reset(
          new memory({{convBwdFilterDims.diff_bias_dims},
                      MklDnnType<T>(),
                      memory::format_tag::x},
                     cpu_engine_, DummyData));
      context_.conv_bwd_filter.reset(
          new convolution_backward_weights(*context_.bwd_filter_pd));
      context_.bwd_filter_primitives_args.push_back(
          {{DNNL_ARG_SRC, *context_.src_mem},
           {DNNL_ARG_DIFF_WEIGHTS, *context_.diff_filter_mem},
           {DNNL_ARG_DIFF_BIAS, *context_.diff_bias_mem},
           {DNNL_ARG_DIFF_DST, *context_.diff_dst_mem}});
    } else {
      context_.conv_bwd_filter.reset(
          new convolution_backward_weights(*context_.bwd_filter_pd));
      context_.bwd_filter_primitives_args.push_back(
          {{DNNL_ARG_SRC, *context_.src_mem},
           {DNNL_ARG_DIFF_WEIGHTS, *context_.diff_filter_mem},
           {DNNL_ARG_DIFF_DST, *context_.diff_dst_mem}});
    }
    context_.bwd_filter_primitives.push_back(*context_.conv_bwd_filter);
  }

  struct ConvBwdFilterContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};
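
// A minimal usage sketch for the factory below (illustrative only; `params`
// stands for a fully-populated MklConvBwdFilterParams and `stream` for a
// std::shared_ptr<dnnl::stream>):
//
//   MklConvBwdFilterPrimitive<float>* prim =
//       MklConvBwdFilterPrimitiveFactory<float>::Get(params,
//                                                    /*do_not_cache=*/false);
//   prim->Execute(src_data, diff_filter_data, diff_dst_data, stream);
//
// Cached primitives are owned by the factory; a caller deletes the primitive
// only when it was created with do_not_cache == true (see Compute() below).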

template <typename T>
class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklConvBwdFilterPrimitive<T>* Get(
      const MklConvBwdFilterParams& convBwdFilterDims, bool do_not_cache) {
    MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;

    if (do_not_cache) {
      // Always create a new primitive.
      conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
    } else {
      // Look into the pool for a reusable primitive.
      conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*>(
          MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
              convBwdFilterDims));

      if (conv_bwd_filter == nullptr) {
        conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
        MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
            convBwdFilterDims, conv_bwd_filter);
      }
    }

    return conv_bwd_filter;
  }

 private:
  MklConvBwdFilterPrimitiveFactory() {}
  ~MklConvBwdFilterPrimitiveFactory() {}

  static MklConvBwdFilterPrimitiveFactory& GetInstance() {
    static MklConvBwdFilterPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklConvBwdFilterParams& convBwdFilterDims) {
    string prefix = "conv_bwd_filter";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(convBwdFilterDims.src_dims);
    key_creator.AddAsKey(convBwdFilterDims.diff_filter_dims);
    key_creator.AddAsKey(convBwdFilterDims.diff_bias_dims);
    key_creator.AddAsKey(convBwdFilterDims.diff_dst_dims);
    key_creator.AddAsKey(convBwdFilterDims.strides);
    key_creator.AddAsKey(convBwdFilterDims.dilations);
    key_creator.AddAsKey(convBwdFilterDims.padding_left);
    key_creator.AddAsKey(convBwdFilterDims.padding_right);
    if (convBwdFilterDims.native_format) {
      key_creator.AddAsKey(convBwdFilterDims.tf_fmt);
    }
    return key_creator.GetKey();
  }

  MklPrimitive* GetConvBwdFilter(
      const MklConvBwdFilterParams& convBwdFilterDims) {
    string key = CreateKey(convBwdFilterDims);
    return this->GetOp(key);
  }

  void SetConvBwdFilter(const MklConvBwdFilterParams& convBwdFilterDims,
                        MklPrimitive* op) {
    string key = CreateKey(convBwdFilterDims);
    this->SetOp(key, op);
  }
};

template <typename Device, class T, bool bias_enabled, bool is_depthwise,
          bool native_format>
class MklConvCustomBackpropFilterOp
    : public MklConvBackpropCommonOp<Device, T, is_depthwise> {
 public:
  explicit MklConvCustomBackpropFilterOp(OpKernelConstruction* context)
      : MklConvBackpropCommonOp<Device, T, is_depthwise>(context) {}

  ~MklConvCustomBackpropFilterOp() {}
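
  // Compute() proceeds in stages: read the inputs and their (possibly
  // MKL-specific) shapes, derive the forward-convolution parameters, fetch
  // or build a backward-filter primitive, allocate the diff_filter (and
  // optional diff_bias) outputs, reorder inputs to the layouts the primitive
  // expects, execute the primitive, and finally reorder the filter gradient
  // back to TensorFlow layout if required.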
  void Compute(OpKernelContext* context) {
    try {
      // Input tensors.
      const Tensor& src_tensor = MklGetInput(context, kInputIdx);
      const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
      const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIdx);

      MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
      GetMklShape(context, kInputIdx, &src_mkl_shape, native_format);
      GetMklShape(context, kFilterIdx, &filter_mkl_shape, native_format);
      GetMklShape(context, kDiffDstIdx, &diff_dst_mkl_shape, native_format);
      // Allow operator-specific sanity checking of shapes.
      ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);

      // Allow operator-specific generation of shapes.
      // E.g., Conv2DBackpropFilter gets the filter as filter_sizes, a tensor
      // containing the shape of the filter, so filter.shape() is not the
      // correct way to get the filter shape. These operator-specific calls
      // allow this class to handle that case.
      TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
      const string& op_type = this->type_string();
      if ((op_type.find("3D") != std::string::npos) &&
          (op_type.find("V2") != std::string::npos)) {
        OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_tensor.shape()),
                    errors::InvalidArgument(
                        "filter_sizes shape must be rank 1 but is rank ",
                        filter_tensor.shape().dims()));
      }
      TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
      TensorShape diff_dst_tf_shape =
          GetTfShape(context, kDiffDstIdx, native_format);

      // Corner cases: output with 0 elements and 0 batch size.
      Tensor* diff_filter_tensor = nullptr;
      if (src_tf_shape.num_elements() == 0 ||
          filter_tf_shape.num_elements() == 0 ||
          diff_dst_tf_shape.num_elements() == 0) {
        MklDnnShape diff_filter_mkl_shape;
        diff_filter_mkl_shape.SetMklTensor(false);
        TensorShape diff_filter_tf_shape =
            GetOutputTfShape(src_tf_shape, filter_tf_shape, diff_dst_tf_shape);
        const int kOutputIdx = 0;
        AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor,
                                  diff_filter_tf_shape, diff_filter_mkl_shape,
                                  native_format);
        DCHECK(diff_filter_tensor != nullptr);

        // If the output tensor has any elements, zero them out.
        auto diff_filter_data = diff_filter_tensor->flat<T>().data();
        for (size_t i = 0; i < diff_filter_tf_shape.num_elements(); ++i) {
          diff_filter_data[i] = static_cast<T>(0);
        }
        return;
      }

      // By default, all dims are in MKL order except those that are suffixed
      // with `tf_order`.
      memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
      memory::dims padding_left, padding_right, dilations, strides;
      memory::dims fwd_dst_dims, fwd_dst_dims_tf_order;

      // Get forward convolution parameters.
      bool is_grouped_convolution = false;
      MklDnnConvUtil conv_util(context, this->strides_, this->padding_,
                               this->data_format_, this->dilations_);
      conv_util.GetConvFwdSizesInMklOrder(
          src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims,
          &strides, &dilations, &fwd_dst_dims_tf_order, &fwd_dst_dims,
          &padding_left, &padding_right, &is_grouped_convolution, false,
          is_depthwise);
      if (!context->status().ok()) return;

      bool is_conv2d = (this->strides_.size() == 4);

      auto tf_fmt = is_conv2d
                        ? TFDataFormatToMklDnnDataFormat(this->data_format_)
                        : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
      auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
      OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
                  errors::InvalidArgument("Invalid data format"));

      auto fwd_src_md =
          src_mkl_shape.IsMklTensor()
              ? src_mkl_shape.GetMklLayout()
              : memory::desc(fwd_src_dims, MklDnnType<T>(), mkl_fmt_tag);

      conv_util.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
      if (!context->status().ok()) return;

      auto diff_dst_md =
          diff_dst_mkl_shape.IsMklTensor()
              ? diff_dst_mkl_shape.GetMklLayout()
              : memory::desc(diff_dst_dims, MklDnnType<T>(), mkl_fmt_tag);

      memory::dims diff_bias_dims = {};
      int64 depth = 0;
      if (bias_enabled) {
        TensorShape obp_tf_shape = GetTfShape(context, 2, native_format);
        depth = (this->data_format_ == FORMAT_NCHW)
                    ? obp_tf_shape.dim_size(1)
                    : obp_tf_shape.dim_size(is_conv2d ? 3 : 4);
        diff_bias_dims = {static_cast<int>(depth)};
      }

      // The default dilation factor for each dimension is 1 in TF and
      // 0 in oneDNN.
      for (int i = 0; i < dilations.size(); ++i) --dilations[i];
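      // (For example, a TF dilation of 1, i.e. no dilation, becomes 0 in
      // oneDNN, and a TF dilation of 2 becomes 1.)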
      MklConvBwdFilterParams convBwdFilterDims(
          fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides,
          tf_fmt, native_format, dilations, padding_left, padding_right);

      // oneDNN allocates large buffers when a conv gradient filter primitive
      // is created. So we don't cache conv backward primitives when the env
      // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true.
      bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();

      MklConvBwdFilterPrimitive<T>* conv_bwd_filter =
          MklConvBwdFilterPrimitiveFactory<T>::Get(convBwdFilterDims,
                                                   do_not_cache);

      // Allocate output tensors: diff_filter and, when bias is enabled,
      // diff_bias.
      auto diff_filter_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);

      MklDnnShape diff_filter_mkl_shape;
      diff_filter_mkl_shape.SetMklTensor(false);

      if (is_conv2d) {
        if (!is_depthwise && !is_grouped_convolution) {
          // Conv2D: output_dims_mkl_order is in OIHW format.
          TensorShape diff_filter_tf_shape(
              {diff_filter_dims[MklDnnDims::Dim_H],
               diff_filter_dims[MklDnnDims::Dim_W],
               diff_filter_dims[MklDnnDims::Dim_I],
               diff_filter_dims[MklDnnDims::Dim_O]});
          AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
                                    diff_filter_tf_shape, diff_filter_mkl_shape,
                                    native_format);
        } else if (is_depthwise) {
          // Depthwise Conv2D: diff_filter_dims is in GOIHW format.
          //                   | TensorFlow       | oneDNN
          // ----------------------------------------------------------------
          //  filter_out_depth | depth_multiplier | depth_multiplier *
          //                   |                  | group_count
          // ----------------------------------------------------------------
          //  filter_in_depth  | in_depth         | in_depth / group_count
          // For depthwise convolution, group_count == in_depth, so here
          // G = original I and I = 1. GOIHW is the oneDNN format; the TF
          // format to extract is HWIO, and since G = original I, the shape
          // built here is HWGO.
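          // Example: with in_depth = 8, depth_multiplier = 2, and a 3x3
          // kernel, diff_filter_dims (GOIHW) is {8, 2, 1, 3, 3}, and the TF
          // diff_filter shape built below is {3, 3, 8, 2}.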
          TensorShape diff_filter_tf_shape(
              {diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_H],
               diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_W],
               diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_G],
               diff_filter_dims
                   [MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_O]});
          AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
                                    diff_filter_tf_shape, diff_filter_mkl_shape,
                                    native_format);
        } else {
          // For grouped convolution, group_count == in_depth /
          // filter_in_depth, so here G = in_depth / filter_in_depth and
          // O = original O / group_count. GOIHW is the oneDNN format; the
          // TF format to extract is HWIO, where O here is O * G.
          TensorShape diff_filter_tf_shape(
              {diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_H],
               diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_W],
               diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_I],
               diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_O] *
                   diff_filter_dims
                       [MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_G]});
          AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
                                    diff_filter_tf_shape, diff_filter_mkl_shape,
                                    native_format);
        }
      } else {
        // Conv3D: output_dims_mkl_order is in OIDHW format.
        TensorShape diff_filter_tf_shape(
            {diff_filter_dims[MklDnnDims3D::Dim3d_D],
             diff_filter_dims[MklDnnDims3D::Dim3d_H],
             diff_filter_dims[MklDnnDims3D::Dim3d_W],
             diff_filter_dims[MklDnnDims3D::Dim3d_I],
             diff_filter_dims[MklDnnDims3D::Dim3d_O]});
        AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
                                  diff_filter_tf_shape, diff_filter_mkl_shape,
                                  native_format);
      }

      Tensor* diff_bias_tensor = nullptr;
      if (bias_enabled) {
        TensorShape diff_bias_shape({depth});
        AllocateBiasGradTensor(context, diff_bias_shape, &diff_bias_tensor);
      }

      // Check if src and diff_dst need to be reordered.
      T* src_data = nullptr;
      MklDnnData<T> src(&cpu_engine_);
      auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
      if (fwd_src_md != bwd_filter_pd->src_desc()) {
        src.SetUsrMem(fwd_src_md, &src_tensor);
        src.CheckReorderToOpMem(bwd_filter_pd->src_desc(), cpu_engine_,
                                context);
        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
      } else {
        src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
      }

      T* diff_dst_data = nullptr;
      MklDnnData<T> diff_dst(&cpu_engine_);
      if (diff_dst_md != bwd_filter_pd->diff_dst_desc()) {
        diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
        diff_dst.CheckReorderToOpMem(bwd_filter_pd->diff_dst_desc(),
                                     cpu_engine_, context);
        diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
      } else {
        diff_dst_data =
            static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
      }

      DCHECK(!diff_filter_mkl_shape.IsMklTensor());
      // Output layout is TensorFlow's filter layout.
      //   Conv2D: HWIO; Conv3D: DHWIO; Depthwise Conv: HWIGO;
      //   Grouped Conv: HWIGO.
      auto diff_filter_format =
          (is_depthwise || is_grouped_convolution)
              ? memory::format_tag::hwigo
              : ((this->strides_.size() == 4) ? memory::format_tag::hwio
                                              : memory::format_tag::dhwio);
      auto diff_filter_md =
          memory::desc(diff_filter_dims, MklDnnType<T>(), diff_filter_format);

      // Convert diff_filter (output) back to TF layout if needed
      // (i.e. reorder op memory back to user memory).
      MklDnnData<T> diff_filter(&cpu_engine_);
      bool diff_filter_reorder_required = false;
      T* diff_filter_data = nullptr;
      if (diff_filter_md != bwd_filter_pd->diff_weights_desc()) {
        // Allocate diff_filter tensor with TensorFlow layout.
        diff_filter.SetUsrMem(diff_filter_dims, diff_filter_format,
                              diff_filter_tensor);
        diff_filter_reorder_required = true;
        diff_filter.PrepareReorderToUserMemIfReq(
            bwd_filter_pd->diff_weights_desc());
        diff_filter_data =
            static_cast<T*>(diff_filter.GetOpMem().get_data_handle());
      } else {
        diff_filter_data = static_cast<T*>(
            const_cast<T*>(diff_filter_tensor->flat<T>().data()));
      }
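
      // The same check-and-reorder pattern is used for all three buffers:
      // when a user memory descriptor differs from the one chosen by the
      // primitive, the data is staged through a temporary oneDNN buffer;
      // otherwise the tensor's own buffer is used directly.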

      // Execute convolution backward filter.
      std::shared_ptr<stream> bwd_cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      bwd_cpu_stream.reset(
          CreateStream(&eigen_tp, conv_bwd_filter->GetEngine()));
      if (bias_enabled) {
        T* diff_bias_data =
            static_cast<T*>(const_cast<T*>(diff_bias_tensor->flat<T>().data()));
        conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data,
                                 diff_dst_data, bwd_cpu_stream);
      } else {
        conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data,
                                 bwd_cpu_stream);
      }

      // Reorder diff_filter back to TensorFlow layout if necessary.
      if (diff_filter_reorder_required) {
        diff_filter.InsertReorderToUserMem(context);
      }

      // Delete the primitive since it is not cached.
      if (do_not_cache) delete conv_bwd_filter;
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  const int kInputIdx = 0, kFilterIdx = 1, kDiffDstIdx = 2;
  const int kDilationH = 0, kDilationW = 1;

  engine cpu_engine_ = engine(engine::kind::cpu, 0);

  // Assert that input shapes are valid.
  void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
                         const MklDnnShape& filter_mkl_shape,
                         const MklDnnShape& obp_mkl_shape) {
    CHECK(!filter_mkl_shape.IsMklTensor())
        << "ConvBackpropFilter: filter should not be in MKL Layout";
  }

  // Get TensorFlow shape of input tensor.
  TensorShape MakeInputTfShape(OpKernelContext* context,
                               const Tensor& input_tensor) {
    size_t input_idx = 0;
    return GetTfShape(context, input_idx, native_format);
  }

  // Get TensorFlow shape of filter tensor.
  TensorShape MakeFilterTfShape(OpKernelContext* context,
                                const Tensor& filter_tensor) {
    TensorShape filter_tf_shape;
    CHECK_EQ(TensorShapeUtils::IsVector(filter_tensor.shape()), true);
    CHECK_EQ(TensorShapeUtils::MakeShape(filter_tensor.vec<int32>(),
                                         &filter_tf_shape)
                 .ok(),
             true);
    return filter_tf_shape;
  }

  // Get TensorFlow shape of output tensor (diff_filter), which is the same
  // as the shape of the filter.
  TensorShape GetOutputTfShape(const TensorShape& input_shape,
                               const TensorShape& filter_shape,
                               const TensorShape& outbprop_shape) {
    return filter_shape;
  }
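
  // (The filter gradient has, by definition, the same shape as the filter
  // itself, so GetOutputTfShape() above and GetOutputDims() below simply
  // return the filter shape/dims.)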

  // Get the shape of output (diff_filter) in oneDNN order.
  // Computes the shape of the output from the input shape (fwd_input_dims)
  // and filter shape (fwd_filter_dims).
  const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
                                    const memory::dims& fwd_filter_dims) {
    return fwd_filter_dims;
  }

  void AllocateOutputTensor(OpKernelContext* context,
                            const memory::dims& output_dims_mkl_order,
                            Tensor** output_tensor) {
    DCHECK(output_tensor != nullptr);

    // For BackpropFilter, we convert the output tensor back to TensorFlow
    // layout: BackpropFilter is typically the last operator in the graph
    // that emits the filter gradient, which is then handed to an
    // ApplyGradient method to update the filter. This conversion could be
    // eliminated by forwarding the filter in MKL layout, if ApplyGradient
    // supported MKL layout propagation.
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(false);
    // output_dims_mkl_order is in OIHW format.
    // Allocate the shape of the TF tensor in HWIO format.
    TensorShape output_tf_shape({output_dims_mkl_order[MklDnnDims::Dim_H],
                                 output_dims_mkl_order[MklDnnDims::Dim_W],
                                 output_dims_mkl_order[MklDnnDims::Dim_I],
                                 output_dims_mkl_order[MklDnnDims::Dim_O]});
    AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
                              output_mkl_shape);
  }

  void AllocateBiasGradTensor(OpKernelContext* context,
                              const TensorShape& bias_grad_shape,
                              Tensor** bias_grad_tensor) {
    DCHECK(bias_grad_tensor);

    MklDnnShape bias_grad_mkl_shape;
    bias_grad_mkl_shape.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 1, bias_grad_tensor, bias_grad_shape,
                              bias_grad_mkl_shape, native_format);
  }
};
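
// Each registration below instantiates MklConvCustomBackpropFilterOp with
// template arguments <Device, T, bias_enabled, is_depthwise, native_format>;
// e.g. "_MklNativeConv2DBackpropFilterWithBias" uses
// <CPUDevice, T, true, false, true>.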

#define REGISTER_MKL_FILTER_KERNELS(T)                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_MklConv2DBackpropFilter")                                    \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T")                                         \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),            \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, false>);  \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_MklConv2DBackpropFilterWithBias")                            \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T")                                         \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),            \
      MklConvCustomBackpropFilterOp<CPUDevice, T, true, false, false>);   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_MklDepthwiseConv2dNativeBackpropFilter")                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T")                                         \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),            \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, true, false>);   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("__MklDummyConv2DBackpropFilterWithBias")                      \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T")                                         \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),            \
      MklDummyOp<CPUDevice, T>);                                          \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_MklConv3DBackpropFilterV2")                                  \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T")                                         \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),            \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, false>);  \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_MklNativeConv2DBackpropFilter")                              \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T")                                         \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                 \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, true>);   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_MklNativeDepthwiseConv2dNativeBackpropFilter")               \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T")                                         \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                 \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, true, true>);    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_MklNativeConv3DBackpropFilterV2")                            \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T")                                         \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                 \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, true>);   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_MklNativeConv2DBackpropFilterWithBias")                      \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T")                                         \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                 \
      MklConvCustomBackpropFilterOp<CPUDevice, T, true, false, true>);

TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_FILTER_KERNELS);

#undef REGISTER_MKL_FILTER_KERNELS

}  // namespace tensorflow

#endif  // INTEL_MKL