/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_POOLING_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_POOLING_OPS_COMMON_H_

#ifdef INTEL_MKL

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "dnnl.hpp"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/mutex.h"
#endif

namespace tensorflow {

using dnnl::pooling_backward;
using dnnl::pooling_forward;
using dnnl::prop_kind;
using dnnl::stream;

using PoolingFwdPd = dnnl::pooling_forward::primitive_desc;
using PoolingBwdPd = dnnl::pooling_backward::primitive_desc;

// Parameter bundle that fully describes a pooling primitive. It is used both
// to set up oneDNN forward/backward pooling primitives and (via the factory
// key) to cache and reuse already-created primitives.
struct MklPoolingParams {
  memory::dims src_dims;
  memory::dims dst_dims;
  memory::dims filter_dims;
  memory::dims strides;
  memory::dims padding_left;
  memory::dims padding_right;
  dnnl::algorithm alg_kind;
  dnnl::prop_kind prop_kind;
  memory::format_tag src_format;
  memory::desc src_md;
  bool native_format;

  MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
                   memory::dims filter_dims, memory::dims strides,
                   memory::dims padding_left, memory::dims padding_right,
                   dnnl::algorithm alg_kind, dnnl::prop_kind prop_kind,
                   memory::format_tag src_format, memory::desc src_md,
                   bool native_format)
      : src_dims(src_dims),
        dst_dims(dst_dims),
        filter_dims(filter_dims),
        strides(strides),
        padding_left(padding_left),
        padding_right(padding_right),
        alg_kind(alg_kind),
        prop_kind(prop_kind),
        src_format(src_format),
        src_md(src_md),
        native_format(native_format) {}
};

// Reusable forward-pooling primitive (max or avg, selected by alg_kind).
template <typename T>
class MklPoolingFwdPrimitive : public MklPrimitive {
 public:
  explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    if (context_.fwd == nullptr) Setup(fwdParams);
  }

  ~MklPoolingFwdPrimitive() {}

  // Pooling forward execute
  //   src_data: input data buffer of src
  //   dst_data: output data buffer of dst
  //   ws_data:  output data buffer of workspace (max pooling only; may be
  //             null for avg pooling)
  void Execute(const T* src_data, T* dst_data, void* ws_data,
               std::shared_ptr<stream> fwd_stream);

  std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
    return context_.fwd_pd;
  }

  memory::format_tag GetSrcMemoryFormat() const { return context_.src_fmt; }
  memory::format_tag GetDstMemoryFormat() const { return context_.dst_fmt; }

 private:
  void Setup(const MklPoolingParams& fwdParams);

  // Primitive reuse context for pooling fwd ops.
  struct PoolingFwdContext {
    // Algorithm.
    dnnl::algorithm alg_kind;

    // Kind of propagation, forward or backward.
    dnnl::prop_kind prop_kind;

    // Expected memory format.
    memory::format_tag src_fmt;
    memory::format_tag dst_fmt;
    memory::format_tag ws_fmt;

    // Workspace shape.
    memory::dims ws_dims;
    memory::data_type ws_dt;
    size_t ws_size;

    // oneDNN memory, just dummy data.
    std::shared_ptr<dnnl::memory> ws_mem;
    std::shared_ptr<dnnl::memory> src_mem;
    std::shared_ptr<dnnl::memory> dst_mem;

    // Pooling forward descriptor and primitive descriptor.
    std::shared_ptr<dnnl::pooling_forward::desc> fwd_desc;
    std::shared_ptr<PoolingFwdPd> fwd_pd;

    // Memory descriptor.
    std::shared_ptr<dnnl::memory::desc> src_md;
    std::shared_ptr<dnnl::memory::desc> dst_md;

    // Pooling primitive.
    std::shared_ptr<dnnl::pooling_forward> fwd;
    std::shared_ptr<dnnl::stream> fwd_stream;
    std::vector<dnnl::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    PoolingFwdContext()
        : src_fmt(memory::format_tag::any),
          dst_fmt(memory::format_tag::any),
          ws_fmt(memory::format_tag::any),
          ws_mem(nullptr),
          src_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          dst_md(nullptr),
          fwd(nullptr) {}
  };

  struct PoolingFwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  // ACL primitives are not thread-safe; serialize execution per primitive.
  mutex primitive_execution_mu_;
#endif
};

// Singleton factory that caches forward pooling primitives keyed on the
// pooling parameters, so identical ops share one primitive.
template <typename T>
class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklPoolingFwdPrimitive<T>* Get(const MklPoolingParams& fwdParams) {
    MklPoolingFwdPrimitive<T>* pooling_forward = nullptr;

    // Get pooling primitive from the pool.
    pooling_forward = static_cast<MklPoolingFwdPrimitive<T>*>(
        MklPoolingFwdPrimitiveFactory<T>::GetInstance().GetPoolingFwd(
            fwdParams));

    if (pooling_forward == nullptr) {
      pooling_forward = new MklPoolingFwdPrimitive<T>(fwdParams);
      MklPoolingFwdPrimitiveFactory<T>::GetInstance().SetPoolingFwd(
          fwdParams, pooling_forward);
    }
    return pooling_forward;
  }

  static MklPoolingFwdPrimitiveFactory& GetInstance() {
    static MklPoolingFwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklPoolingFwdPrimitiveFactory() {}
  ~MklPoolingFwdPrimitiveFactory() {}

  // The key to be created will be used to get/set pooling
  // primitive op from reuse perspective.
  // A pooling key is a string which concatenates key parameters
  // as well as algorithm kind (max versus avg).
  static string CreateKey(const MklPoolingParams& fwdParams) {
    string prefix = "pooling_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey(fwdParams.dst_dims);
    key_creator.AddAsKey(fwdParams.filter_dims);
    key_creator.AddAsKey(fwdParams.strides);
    key_creator.AddAsKey(fwdParams.padding_left);
    key_creator.AddAsKey(fwdParams.padding_right);
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.prop_kind));
    return key_creator.GetKey();
  }

  MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }

  void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};

// Reusable backward-pooling primitive (gradient of max/avg pooling).
template <typename T>
class MklPoolingBwdPrimitive : public MklPrimitive {
 public:
  explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    if (context_.bwd == nullptr) Setup(bwdParams);
  }

  ~MklPoolingBwdPrimitive() {}

  // Pooling backward execute
  //   diff_dst_data: input data buffer of diff_dst
  //   diff_src_data: output data buffer of diff_src
  //   ws_data:       input data buffer of workspace (max pooling only; may
  //                  be null for avg pooling)
  void Execute(const T* diff_dst_data, T* diff_src_data, const void* ws_data,
               std::shared_ptr<stream> bwd_stream);

 public:
  std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
    return context_.fwd_pd;
  }
  std::shared_ptr<PoolingBwdPd> GetPoolingBwdPd() const {
    return context_.bwd_pd;
  }

  dnnl::memory::data_type GetWorkspaceDataType() const {
    return context_.ws_dt;
  }

 private:
  void Setup(const MklPoolingParams& bwdParams);

  // Primitive reuse context for pooling bwd ops.
  struct PoolingBwdContext {
    // Algorithm.
    dnnl::algorithm alg_kind;

    // Expected memory format.
    memory::format_tag diff_src_fmt;
    memory::format_tag diff_dst_fmt;
    memory::format_tag ws_fmt;

    // Workspace attribute.
    dnnl::memory::dims ws_dims;
    dnnl::memory::data_type ws_dt;

    // oneDNN memory.
    std::shared_ptr<dnnl::memory> ws_mem;
    std::shared_ptr<dnnl::memory> diff_src_mem;
    std::shared_ptr<dnnl::memory> diff_dst_mem;

    // Memory descriptors.
    std::shared_ptr<dnnl::memory::desc> src_md;
    std::shared_ptr<dnnl::memory::desc> dst_md;

    // Forward and backward pooling descriptors and primitive descriptors.
    // oneDNN requires a forward primitive_desc as a hint to build the
    // backward primitive.
    std::shared_ptr<dnnl::pooling_forward::desc> fwd_desc;
    std::shared_ptr<dnnl::pooling_backward::desc> bwd_desc;
    std::shared_ptr<PoolingFwdPd> fwd_pd;
    std::shared_ptr<PoolingBwdPd> bwd_pd;

    // Backward pooling primitive.
    std::shared_ptr<dnnl::pooling_backward> bwd;
    std::shared_ptr<dnnl::stream> bwd_stream;

    std::vector<dnnl::primitive> bwd_primitives;
    std::vector<std::unordered_map<int, memory>> net_args;

    PoolingBwdContext()
        : diff_src_fmt(memory::format_tag::any),
          diff_dst_fmt(memory::format_tag::any),
          ws_fmt(memory::format_tag::any),
          ws_mem(nullptr),
          diff_src_mem(nullptr),
          diff_dst_mem(nullptr),
          src_md(nullptr),
          dst_md(nullptr),
          fwd_desc(nullptr),
          bwd_desc(nullptr),
          fwd_pd(nullptr),
          bwd_pd(nullptr),
          bwd(nullptr) {}
  };

  struct PoolingBwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  // ACL primitives are not thread-safe; serialize execution per primitive.
  mutex primitive_execution_mu_;
#endif
};

// Singleton factory that caches backward pooling primitives keyed on the
// pooling parameters.
template <typename T>
class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklPoolingBwdPrimitive<T>* Get(const MklPoolingParams& bwdParams) {
    MklPoolingBwdPrimitive<T>* pooling_backward = nullptr;

    // Find a pooling backward primitive from the pool.
    // If it does not exist, create a new one.
    pooling_backward = static_cast<MklPoolingBwdPrimitive<T>*>(
        MklPoolingBwdPrimitiveFactory<T>::GetInstance().GetPoolingBwd(
            bwdParams));
    if (pooling_backward == nullptr) {
      pooling_backward = new MklPoolingBwdPrimitive<T>(bwdParams);
      MklPoolingBwdPrimitiveFactory<T>::GetInstance().SetPoolingBwd(
          bwdParams, pooling_backward);
    }
    return pooling_backward;
  }

  static MklPoolingBwdPrimitiveFactory& GetInstance() {
    static MklPoolingBwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklPoolingBwdPrimitiveFactory() {}
  ~MklPoolingBwdPrimitiveFactory() {}

  // The key to be created will be used to get/set pooling
  // primitive op from reuse perspective.
  // A pooling key is a string which concatenates key parameters
  // as well as algorithm kind (max versus avg).
  static string CreateKey(const MklPoolingParams& bwdParams) {
    string prefix = "pooling_bwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(bwdParams.src_dims);
    key_creator.AddAsKey(bwdParams.dst_dims);
    key_creator.AddAsKey(bwdParams.filter_dims);
    key_creator.AddAsKey(bwdParams.strides);
    key_creator.AddAsKey(bwdParams.padding_left);
    key_creator.AddAsKey(bwdParams.padding_right);
    key_creator.AddAsKey<int>(static_cast<int>(bwdParams.alg_kind));
    return key_creator.GetKey();
  }

  MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) {
    string key = CreateKey(bwdParams);
    return this->GetOp(key);
  }

  void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive* op) {
    string key = CreateKey(bwdParams);
    this->SetOp(key, op);
  }
};

using CPUDevice = Eigen::ThreadPoolDevice;

// Shape/stride/padding bookkeeping shared by all MKL pooling kernels.
// Fields marked "Pool3D" are only meaningful for 3D pooling.
struct MklPoolParameters {
  int depth;

  int tensor_in_planes;  // Pool3D
  int tensor_in_cols;
  int tensor_in_rows;
  int tensor_in_batch;

  int window_planes;  // Pool3D
  int window_rows;
  int window_cols;
  int depth_window;

  int planes_stride;  // Pool3D
  int row_stride;
  int col_stride;
  int depth_stride;

  int64 out_planes;  // Pool3D
  int64 out_height;
  int64 out_width;
  int out_depth;

  int64 pad_P1;  // Pool3D
  int64 pad_P2;  // Pool3D
  int64 pad_left;
  int64 pad_right;
  int64 pad_top;
  int64 pad_bottom;
  int pad_depth;

  TensorFormat data_format;

  MklPoolParameters()
      : depth(0),
        tensor_in_planes(0),
        tensor_in_cols(0),
        tensor_in_rows(0),
        tensor_in_batch(0),
        window_planes(0),
        window_rows(0),
        window_cols(0),
        depth_window(0),
        planes_stride(0),
        row_stride(0),
        col_stride(0),
        depth_stride(0),
        out_planes(0),
        out_height(0),
        out_width(0),
        out_depth(0),
        pad_P1(0),
        pad_P2(0),
        pad_left(0),
        pad_right(0),
        pad_top(0),
        pad_bottom(0),
        pad_depth(0),
        data_format(TensorFormat::FORMAT_NCHW) {}

  // Updates context->status if there is an invalid input.
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const TensorShape& tensor_in_shape);
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const MklDnnShape* mkl_in_shape);

 private:
  // Common initialization for TensorFlow and MKL formats.
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format);
};

// Common base kernel for MKL pooling ops (2D and 3D, forward and backward).
// Parses and validates the pooling attributes at construction time.
template <class T>
class MklPoolingOpBase : public OpKernel {
 public:
  explicit MklPoolingOpBase(OpKernelConstruction* context)
      : OpKernel(context), workspace_enabled_(false) {
    string data_format;
    if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
      // Current quantized convolution doesn't have data_format attribute.
      data_format = "NHWC";
    } else {
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    }
    OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_));
    OP_REQUIRES(context, this->ksize_.size() == 4 || this->ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 or 5 dimensions"));
    for (int i = 0; i < this->ksize_.size(); ++i) {
      OP_REQUIRES(context, this->ksize_[i] > 0,
                  errors::InvalidArgument("Sliding window ksize for dimension ",
                                          i, " was zero."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_));
    OP_REQUIRES(context, this->stride_.size() == 4 || this->stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 or 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_));
    OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1,
                errors::Unimplemented("Pooling is not yet supported on the "
                                      "batch dimension."));
    // ksize of length 4 means Pool2D (NHWC/NCHW); length 5 means Pool3D.
    bool is_pool2d = (this->ksize_.size() == 4);
    this->tensor_format_mkldnn_ =
        is_pool2d ? TFDataFormatToMklDnnDataFormat(this->data_format_tf_)
                  : TFDataFormatToMklDnn3DDataFormat(this->data_format_tf_);

    this->data_format_mkldnn_ =
        MklTensorFormatToMklDnnDataFormat(this->tensor_format_mkldnn_);

    // We may not get this attribute for this node if it does not go through
    // graph rewrite pass. So we do not check for error while retrieving this
    // attribute value.
    auto status =
        context->GetAttr("workspace_enabled", &this->workspace_enabled_);
    (void)status;
  }
  void Compute(OpKernelContext* context) override = 0;

 protected:
  // Calculate output shape of pooling op in oneDNN and TensorFlow order.
  // oneDNN uses NCHW(Pool2D) or NCDHW(Pool3D) for output order.
  // But TensorFlow output will be in NHWC/NCHW(Pool2D) or
  // NDHWC/NCDHW(Pool3D) format depending on data format. Function expects
  // output height and width to have already been int32 bounds-checked.
  void GetOutputDims(const MklPoolParameters& mkl_pool_params,
                     memory::dims* output_dims_mkl_order) {
    if (this->ksize_.size() == 4) {
      // Pooling2D: oneDNN always needs output in NCHW format.
      *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
                                mkl_pool_params.out_depth,
                                static_cast<int>(mkl_pool_params.out_height),
                                static_cast<int>(mkl_pool_params.out_width)};
    } else {
      // Pooling3D: oneDNN always needs output in NCDHW format.
      *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
                                mkl_pool_params.out_depth,
                                static_cast<int>(mkl_pool_params.out_planes),
                                static_cast<int>(mkl_pool_params.out_height),
                                static_cast<int>(mkl_pool_params.out_width)};
    }
  }

  // Initializes pool_params from either the TF shape or the MKL shape,
  // depending on whether the input carries an MKL layout.
  void InitMklPoolParameters(OpKernelContext* context,
                             MklPoolParameters* pool_params,
                             const MklDnnShape& original_input_mkl_shape,
                             const TensorShape& input_tensor_shape) {
    if (!original_input_mkl_shape.IsMklTensor()) {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, input_tensor_shape);
    } else {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, &original_input_mkl_shape);
    }
  }

  // Converts pool_params into the oneDNN dims (filter, strides, paddings)
  // expected by the pooling primitive, for 2D or 3D pooling.
  void PoolParamsToDims(const MklPoolParameters* pool_params,
                        memory::dims* filter_dims, memory::dims* strides,
                        memory::dims* padding_left, memory::dims* padding_right,
                        bool is_pool2d) {
    if (is_pool2d) {
      // Pool2D
      *filter_dims =
          memory::dims({pool_params->window_rows, pool_params->window_cols});
      *strides =
          memory::dims({pool_params->row_stride, pool_params->col_stride});
      *padding_left = memory::dims({static_cast<int>(pool_params->pad_top),
                                    static_cast<int>(pool_params->pad_left)});
      *padding_right = memory::dims({static_cast<int>(pool_params->pad_bottom),
                                     static_cast<int>(pool_params->pad_right)});
    } else {
      // Pool3D
      *filter_dims =
          memory::dims({pool_params->window_planes, pool_params->window_rows,
                        pool_params->window_cols});
      *strides =
          memory::dims({pool_params->planes_stride, pool_params->row_stride,
                        pool_params->col_stride});

      *padding_left = memory::dims({static_cast<int>(pool_params->pad_P1),
                                    static_cast<int>(pool_params->pad_top),
                                    static_cast<int>(pool_params->pad_left)});
      *padding_right = memory::dims({static_cast<int>(pool_params->pad_P2),
                                     static_cast<int>(pool_params->pad_bottom),
                                     static_cast<int>(pool_params->pad_right)});
    }
  }

  // Allocates the output tensor (in TF layout, no MKL metadata) for the
  // degenerate case where the op produces an empty/passthrough result.
  void AllocateEmptyOutputTensor(OpKernelContext* context,
                                 const int kOutputIndex,
                                 MklPoolParameters* pool_params,
                                 const memory::dims output_dims_mkl_order,
                                 Tensor** output_tensor) {
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(false);
    TensorShape output_tf_shape;
    if (pool_params->data_format == TensorFormat::FORMAT_NCHW) {
      output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
    } else {
      memory::dims output_dims_order;
      // Determine Pooling2D (NHWC) or Pooling3D (NDHWC).
      if (this->ksize_.size() == 4) {
        output_dims_order = {pool_params->tensor_in_batch,
                             static_cast<int>(pool_params->out_height),
                             static_cast<int>(pool_params->out_width),
                             pool_params->out_depth};
      } else {
        output_dims_order = {pool_params->tensor_in_batch,
                             static_cast<int>(pool_params->out_planes),
                             static_cast<int>(pool_params->out_height),
                             static_cast<int>(pool_params->out_width),
                             pool_params->out_depth};
      }
      output_tf_shape = MklDnnDimsToTFShape(output_dims_order);
    }
    AllocateOutputSetMklShape(context, kOutputIndex, output_tensor,
                              output_tf_shape, output_mkl_shape,
                              native_format_);
    DCHECK(output_tensor);
  }

  // Checks to make sure that the memory we need to allocate
  // is a multiple of sizeof(T).
  // Returns the number of elements (rounded up).
  size_t GetNumTElements(const memory::desc& pd) {
    size_t num_bytes = pd.get_size();
    size_t ret_val = num_bytes / sizeof(T);
    if (num_bytes % sizeof(T) != 0) {
      ret_val++;
    }
    return ret_val;
  }

  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_tf_;
  MklTensorFormat tensor_format_mkldnn_;
  memory::format_tag data_format_mkldnn_;
  bool workspace_enabled_;
  bool native_format_ = false;
};

// Base kernel for forward pooling ops.
template <class T>
class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
 public:
  // NOTE: constructor name must be the plain injected-class-name;
  // `MklPoolingForwardOpBase<T>(...)` is ill-formed in newer C++ standards.
  explicit MklPoolingForwardOpBase(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}
  void Compute(OpKernelContext* context) override = 0;

 protected:
  // Wraps the input tensor in an oneDNN memory descriptor (honoring an MKL
  // layout if present) and initializes pool_params from the input shape.
  void ConfigureInput(OpKernelContext* context,
                      const MklDnnShape& input_mkl_shape,
                      const Tensor& input_tensor,
                      MklPoolParameters* pool_params,
                      MklDnnData<T>* dnn_data_input) {
    DCHECK(pool_params);
    DCHECK(dnn_data_input);
    TensorShape input_tensor_shape = input_tensor.shape();
    if (input_tensor.NumElements() != 0) {
      memory::desc input_md =
          input_mkl_shape.IsMklTensor()
              ? input_mkl_shape.GetMklLayout()
              : memory::desc(
                    (this->ksize_.size() == 4)
                        ? TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
                                                    this->data_format_tf_)
                        : TFShapeToMklDnnDimsInNCDHW(input_tensor_shape,
                                                     this->data_format_tf_),
                    MklDnnType<T>(), this->data_format_mkldnn_);
      dnn_data_input->SetUsrMem(input_md, &input_tensor);

      if (this->ksize_.size() == 5) {
        // Pool3D
        std::vector<dnnl::memory::dim> input_sizes(5, -1);
        input_sizes[MklDnnDims3D::Dim3d_N] = input_md.data.dims[0];
        input_sizes[MklDnnDims3D::Dim3d_C] = input_md.data.dims[1];
        input_sizes[MklDnnDims3D::Dim3d_D] = input_md.data.dims[2];
        input_sizes[MklDnnDims3D::Dim3d_H] = input_md.data.dims[3];
        input_sizes[MklDnnDims3D::Dim3d_W] = input_md.data.dims[4];
        dnn_data_input->SetOpMemDesc(input_sizes, this->data_format_mkldnn_);
      }
    }
    this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
                                input_tensor_shape);
  }

  // Allocates the forward output tensor with the MKL layout taken from the
  // forward primitive descriptor's dst desc.
  void AllocateOutputTensor(OpKernelContext* context,
                            const PoolingFwdPd& pool_fwd_prim_desc,
                            const memory::dims output_dims_mkl_order,
                            const MklTensorFormat& output_tf_format,
                            Tensor** output_tensor) {
    TensorShape output_tf_shape;
    DCHECK(output_tensor);
    memory::desc dst_pd = pool_fwd_prim_desc.dst_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);
    // Only allocate enough space for the elements we need.
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));

    if (this->native_format_) {
      output_tf_shape = output_mkl_shape.GetTfShape();
    }
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape,
                              this->native_format_);
    DCHECK(*output_tensor);
  }

  // Validates that the input is 4-D or 5-D; updates context status otherwise.
  void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
                        const MklDnnShape& input_mkl_shape) {
    if (!input_mkl_shape.IsMklTensor()) {
      OP_REQUIRES(context, input_tensor.dims() == 4 || input_tensor.dims() == 5,
                  errors::InvalidArgument("Input must be 4 or 5-dimensional"));
    } else {
      OP_REQUIRES(
          context,
          input_mkl_shape.GetDimension() == 4 ||
              input_mkl_shape.GetDimension() == 5,
          errors::InvalidArgument("Input shape must be 4 or 5-dimensional"));
    }
  }
  const int kInputTensorIndexInput = 0;
  const int kOutputTensorIndexOutput = 0;
};  // MklPoolingForwardBaseOp

// Base kernel for backward (gradient) pooling ops.
template <class T>
class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
 public:
  // NOTE: constructor name must be the plain injected-class-name;
  // `MklPoolingBackwardOpBase<T>(...)` is ill-formed in newer C++ standards.
  explicit MklPoolingBackwardOpBase(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}
  void Compute(OpKernelContext* context) override = 0;

 protected:
  const int kOutputTensorIndexOutput = 0;

  // Allocates the gradient output tensor with the MKL layout taken from the
  // backward primitive descriptor's diff_src desc.
  void AllocateOutputTensor(OpKernelContext* context,
                            const PoolingBwdPd& pool_bkwd_prim_desc,
                            const memory::dims output_dims_mkl_order,
                            const MklTensorFormat& output_tf_format,
                            Tensor** output_tensor) {
    DCHECK(output_tensor);
    memory::desc dst_pd = pool_bkwd_prim_desc.diff_src_desc();
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    if (this->native_format_) {
      output_tf_shape = output_mkl_shape.GetTfShape();
    }
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape,
                              this->native_format_);
    DCHECK(*output_tensor);
  }
};

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_POOLING_OPS_COMMON_H_