xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_POOLING_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_POOLING_OPS_COMMON_H_

#ifdef INTEL_MKL

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "dnnl.hpp"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/mutex.h"
#endif

namespace tensorflow {

using dnnl::pooling_backward;
using dnnl::pooling_forward;
using dnnl::prop_kind;
using dnnl::stream;

using PoolingFwdPd = dnnl::pooling_forward::primitive_desc;
using PoolingBwdPd = dnnl::pooling_backward::primitive_desc;

struct MklPoolingParams {
  memory::dims src_dims;
  memory::dims dst_dims;
  memory::dims filter_dims;
  memory::dims strides;
  memory::dims padding_left;
  memory::dims padding_right;
  dnnl::algorithm alg_kind;
  dnnl::prop_kind prop_kind;
  memory::format_tag src_format;
  memory::desc src_md;
  bool native_format;

  MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
                   memory::dims filter_dims, memory::dims strides,
                   memory::dims padding_left, memory::dims padding_right,
                   dnnl::algorithm alg_kind, dnnl::prop_kind prop_kind,
                   memory::format_tag src_format, memory::desc src_md,
                   bool native_format)
      : src_dims(src_dims),
        dst_dims(dst_dims),
        filter_dims(filter_dims),
        strides(strides),
        padding_left(padding_left),
        padding_right(padding_right),
        alg_kind(alg_kind),
        prop_kind(prop_kind),
        src_format(src_format),
        src_md(src_md),
        native_format(native_format) {}
};

template <typename T>
class MklPoolingFwdPrimitive : public MklPrimitive {
 public:
  explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    if (context_.fwd == nullptr) Setup(fwdParams);
  }

  ~MklPoolingFwdPrimitive() {}

  // Pooling forward execute
  //   src_data:  input data buffer of src
  //   dst_data:  output data buffer of dst
  //   ws_data:   output data buffer of workspace
  void Execute(const T* src_data, T* dst_data, void* ws_data,
               std::shared_ptr<stream> fwd_stream);

  std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
    return context_.fwd_pd;
  }

  memory::format_tag GetSrcMemoryFormat() const { return context_.src_fmt; }
  memory::format_tag GetDstMemoryFormat() const { return context_.dst_fmt; }

 private:
  void Setup(const MklPoolingParams& fwdParams);

  struct PoolingFwdContext {
    // Algorithm.
    dnnl::algorithm alg_kind;

    // Kind of propagation, forward or backward.
    dnnl::prop_kind prop_kind;

    // Expected memory format.
    memory::format_tag src_fmt;
    memory::format_tag dst_fmt;
    memory::format_tag ws_fmt;

    // Workspace shape.
    memory::dims ws_dims;
    memory::data_type ws_dt;
    size_t ws_size;

    // oneDNN memory, just dummy data.
    std::shared_ptr<dnnl::memory> ws_mem;
    std::shared_ptr<dnnl::memory> src_mem;
    std::shared_ptr<dnnl::memory> dst_mem;

    // Pooling forward descriptor and primitive descriptor.
    std::shared_ptr<dnnl::pooling_forward::desc> fwd_desc;
    std::shared_ptr<PoolingFwdPd> fwd_pd;

    // Memory descriptor.
    std::shared_ptr<dnnl::memory::desc> src_md;
    std::shared_ptr<dnnl::memory::desc> dst_md;

    // Pooling primitive
    std::shared_ptr<dnnl::pooling_forward> fwd;
    std::shared_ptr<dnnl::stream> fwd_stream;
    std::vector<dnnl::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    PoolingFwdContext()
        : src_fmt(memory::format_tag::any),
          dst_fmt(memory::format_tag::any),
          ws_fmt(memory::format_tag::any),
          ws_mem(nullptr),
          src_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          dst_md(nullptr),
          fwd(nullptr) {}
  };

  struct PoolingFwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};

template <typename T>
class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklPoolingFwdPrimitive<T>* Get(const MklPoolingParams& fwdParams) {
    MklPoolingFwdPrimitive<T>* pooling_forward = nullptr;

    // Get pooling primitive from the pool
    pooling_forward = static_cast<MklPoolingFwdPrimitive<T>*>(
        MklPoolingFwdPrimitiveFactory<T>::GetInstance().GetPoolingFwd(
            fwdParams));

    if (pooling_forward == nullptr) {
      pooling_forward = new MklPoolingFwdPrimitive<T>(fwdParams);
      MklPoolingFwdPrimitiveFactory<T>::GetInstance().SetPoolingFwd(
          fwdParams, pooling_forward);
    }
    return pooling_forward;
  }

  static MklPoolingFwdPrimitiveFactory& GetInstance() {
    static MklPoolingFwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklPoolingFwdPrimitiveFactory() {}
  ~MklPoolingFwdPrimitiveFactory() {}

  // The key created here is used to get/set the pooling
  // primitive from the pool so that it can be reused.
  // A pooling key is a string that concatenates key parameters
  // as well as the algorithm kind (max versus avg).
  static string CreateKey(const MklPoolingParams& fwdParams) {
    string prefix = "pooling_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey(fwdParams.dst_dims);
    key_creator.AddAsKey(fwdParams.filter_dims);
    key_creator.AddAsKey(fwdParams.strides);
    key_creator.AddAsKey(fwdParams.padding_left);
    key_creator.AddAsKey(fwdParams.padding_right);
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.prop_kind));
    return key_creator.GetKey();
  }

  MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }

  void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};
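
// Typical usage (illustrative sketch only; src_dims, src_md, src_format,
// ws_data, fwd_cpu_stream, and the other names below are placeholders for
// values the calling pooling kernel computes from its inputs):
//
//   MklPoolingParams fwdParams(
//       src_dims, dst_dims, filter_dims, strides, padding_left, padding_right,
//       dnnl::algorithm::pooling_max, prop_kind::forward_training,
//       src_format, src_md, native_format);
//   MklPoolingFwdPrimitive<T>* pooling_fwd =
//       MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
//   pooling_fwd->Execute(src_data, dst_data, ws_data, fwd_cpu_stream);
//
// The factory caches the primitive under the key built by CreateKey(), so
// subsequent calls with the same parameters reuse the same primitive instead
// of recreating it.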

template <typename T>
class MklPoolingBwdPrimitive : public MklPrimitive {
 public:
  explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    if (context_.bwd == nullptr) Setup(bwdParams);
  }

  ~MklPoolingBwdPrimitive() {}

  // Pooling backward execute
  //   diff_dst_data:  input data buffer of diff_dst
  //   diff_src_data:  output data buffer of diff_src
  //   ws_data:        input data buffer of workspace
  void Execute(const T* diff_dst_data, T* diff_src_data, const void* ws_data,
               std::shared_ptr<stream> bwd_stream);

 public:
  std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
    return context_.fwd_pd;
  }
  std::shared_ptr<PoolingBwdPd> GetPoolingBwdPd() const {
    return context_.bwd_pd;
  }

  dnnl::memory::data_type GetWorkspaceDataType() const {
    return context_.ws_dt;
  }

 private:
  void Setup(const MklPoolingParams& bwdParams);

  // Primitive reuse context for pooling bwd ops
  struct PoolingBwdContext {
    // Algorithm.
    dnnl::algorithm alg_kind;

    // Expected memory format.
    memory::format_tag diff_src_fmt;
    memory::format_tag diff_dst_fmt;
    memory::format_tag ws_fmt;

    // Workspace attribute.
    dnnl::memory::dims ws_dims;
    dnnl::memory::data_type ws_dt;

    // oneDNN memory.
    std::shared_ptr<dnnl::memory> ws_mem;
    std::shared_ptr<dnnl::memory> diff_src_mem;
    std::shared_ptr<dnnl::memory> diff_dst_mem;

    // Memory descriptors.
    std::shared_ptr<dnnl::memory::desc> src_md;
    std::shared_ptr<dnnl::memory::desc> dst_md;

    // Forward and backward pooling descriptors and primitive descriptors.
    std::shared_ptr<dnnl::pooling_forward::desc> fwd_desc;
    std::shared_ptr<dnnl::pooling_backward::desc> bwd_desc;
    std::shared_ptr<PoolingFwdPd> fwd_pd;
    std::shared_ptr<PoolingBwdPd> bwd_pd;

    // Backward pooling primitive.
    std::shared_ptr<dnnl::pooling_backward> bwd;
    std::shared_ptr<dnnl::stream> bwd_stream;

    std::vector<dnnl::primitive> bwd_primitives;
    std::vector<std::unordered_map<int, memory>> net_args;

    PoolingBwdContext()
        : diff_src_fmt(memory::format_tag::any),
          diff_dst_fmt(memory::format_tag::any),
          ws_fmt(memory::format_tag::any),
          ws_mem(nullptr),
          diff_src_mem(nullptr),
          diff_dst_mem(nullptr),
          src_md(nullptr),
          dst_md(nullptr),
          fwd_desc(nullptr),
          bwd_desc(nullptr),
          fwd_pd(nullptr),
          bwd_pd(nullptr),
          bwd(nullptr) {}
  };

  struct PoolingBwdContext context_;
#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};

template <typename T>
class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklPoolingBwdPrimitive<T>* Get(const MklPoolingParams& bwdParams) {
    MklPoolingBwdPrimitive<T>* pooling_backward = nullptr;

    // Find a pooling backward primitive from the pool.
    // If it does not exist, create a new one.
    pooling_backward = static_cast<MklPoolingBwdPrimitive<T>*>(
        MklPoolingBwdPrimitiveFactory<T>::GetInstance().GetPoolingBwd(
            bwdParams));
    if (pooling_backward == nullptr) {
      pooling_backward = new MklPoolingBwdPrimitive<T>(bwdParams);
      MklPoolingBwdPrimitiveFactory<T>::GetInstance().SetPoolingBwd(
          bwdParams, pooling_backward);
    }
    return pooling_backward;
  }

  static MklPoolingBwdPrimitiveFactory& GetInstance() {
    static MklPoolingBwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklPoolingBwdPrimitiveFactory() {}
  ~MklPoolingBwdPrimitiveFactory() {}

  // The key created here is used to get/set the pooling
  // primitive from the pool so that it can be reused.
  // A pooling key is a string that concatenates key parameters
  // as well as the algorithm kind (max versus avg).
  static string CreateKey(const MklPoolingParams& bwdParams) {
    string prefix = "pooling_bwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(bwdParams.src_dims);
    key_creator.AddAsKey(bwdParams.dst_dims);
    key_creator.AddAsKey(bwdParams.filter_dims);
    key_creator.AddAsKey(bwdParams.strides);
    key_creator.AddAsKey(bwdParams.padding_left);
    key_creator.AddAsKey(bwdParams.padding_right);
    key_creator.AddAsKey<int>(static_cast<int>(bwdParams.alg_kind));
    return key_creator.GetKey();
  }

  MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) {
    string key = CreateKey(bwdParams);
    return this->GetOp(key);
  }

  void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive* op) {
    string key = CreateKey(bwdParams);
    this->SetOp(key, op);
  }
};
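
// Typical usage (illustrative sketch only; bwdParams and the buffer/stream
// names below are placeholders for values computed by the calling gradient
// op): the backward path mirrors the forward one.
//
//   MklPoolingBwdPrimitive<T>* pooling_bwd =
//       MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
//   pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data,
//                        bwd_cpu_stream);
//
// For max pooling, ws_data is the workspace written by the corresponding
// forward primitive; it records which input element produced each output.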

typedef Eigen::ThreadPoolDevice CPUDevice;

struct MklPoolParameters {
  int depth;

  int tensor_in_planes;  // Pool3D
  int tensor_in_cols;
  int tensor_in_rows;
  int tensor_in_batch;

  int window_planes;  // Pool3D
  int window_rows;
  int window_cols;
  int depth_window;

  int planes_stride;  // Pool3D
  int row_stride;
  int col_stride;
  int depth_stride;

  int64 out_planes;  // Pool3D
  int64 out_height;
  int64 out_width;
  int out_depth;

  int64 pad_P1;  // Pool3D
  int64 pad_P2;  // Pool3D
  int64 pad_left;
  int64 pad_right;
  int64 pad_top;
  int64 pad_bottom;
  int pad_depth;

  TensorFormat data_format;
  MklPoolParameters()
      : depth(0),
        tensor_in_planes(0),
        tensor_in_cols(0),
        tensor_in_rows(0),
        tensor_in_batch(0),
        window_planes(0),
        window_rows(0),
        window_cols(0),
        depth_window(0),
        planes_stride(0),
        row_stride(0),
        col_stride(0),
        depth_stride(0),
        out_planes(0),
        out_height(0),
        out_width(0),
        out_depth(0),
        pad_P1(0),
        pad_P2(0),
        pad_left(0),
        pad_right(0),
        pad_top(0),
        pad_bottom(0),
        pad_depth(0),
        data_format(TensorFormat::FORMAT_NCHW) {}

  // Updates context->status if there is an invalid input.
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const TensorShape& tensor_in_shape);
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const MklDnnShape* mkl_in_shape);

 private:
  // Common initialization for TensorFlow and MKL formats
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format);
};

template <class T>
class MklPoolingOpBase : public OpKernel {
 public:
  explicit MklPoolingOpBase(OpKernelConstruction* context)
      : OpKernel(context), workspace_enabled_(false) {
    string data_format;
    if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
      // Quantized pooling ops currently do not have a data_format attribute.
      data_format = "NHWC";
    } else {
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    }
    OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_));
    OP_REQUIRES(context, this->ksize_.size() == 4 || this->ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 or 5 dimensions"));
    for (int i = 0; i < this->ksize_.size(); ++i) {
      OP_REQUIRES(context, this->ksize_[i] > 0,
                  errors::InvalidArgument("Sliding window ksize for dimension ",
                                          i, " must be positive."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_));
    OP_REQUIRES(context, this->stride_.size() == 4 || this->stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 or 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_));
    OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1,
                errors::Unimplemented("Pooling is not yet supported on the "
                                      "batch dimension."));
    bool is_pool2d = (this->ksize_.size() == 4);
    this->tensor_format_mkldnn_ =
        is_pool2d ? TFDataFormatToMklDnnDataFormat(this->data_format_tf_)
                  : TFDataFormatToMklDnn3DDataFormat(this->data_format_tf_);

    this->data_format_mkldnn_ =
        MklTensorFormatToMklDnnDataFormat(this->tensor_format_mkldnn_);

    // We may not get this attribute if the node did not go through the graph
    // rewrite pass, so we do not check for errors when retrieving it.
    auto status =
        context->GetAttr("workspace_enabled", &this->workspace_enabled_);
    (void)status;
  }
  void Compute(OpKernelContext* context) override = 0;

 protected:
  // Calculates the output shape of the pooling op in oneDNN order.
  // oneDNN always uses NCHW (Pool2D) or NCDHW (Pool3D) for the output order,
  // whereas the TensorFlow output is in NHWC/NCHW (Pool2D) or NDHWC/NCDHW
  // (Pool3D) format depending on the data format. This function expects the
  // output height and width to have already been int32 bounds-checked.
  void GetOutputDims(const MklPoolParameters& mkl_pool_params,
                     memory::dims* output_dims_mkl_order) {
    if (this->ksize_.size() == 4) {
      // Pooling2D: oneDNN always needs output in NCHW format.
      *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
                                mkl_pool_params.out_depth,
                                static_cast<int>(mkl_pool_params.out_height),
                                static_cast<int>(mkl_pool_params.out_width)};
    } else {
      // Pooling3D: oneDNN always needs output in NCDHW format.
      *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
                                mkl_pool_params.out_depth,
                                static_cast<int>(mkl_pool_params.out_planes),
                                static_cast<int>(mkl_pool_params.out_height),
                                static_cast<int>(mkl_pool_params.out_width)};
    }
  }

  void InitMklPoolParameters(OpKernelContext* context,
                             MklPoolParameters* pool_params,
                             const MklDnnShape& original_input_mkl_shape,
                             const TensorShape& input_tensor_shape) {
    if (!original_input_mkl_shape.IsMklTensor()) {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, input_tensor_shape);
    } else {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, &original_input_mkl_shape);
    }
  }

  void PoolParamsToDims(const MklPoolParameters* pool_params,
                        memory::dims* filter_dims, memory::dims* strides,
                        memory::dims* padding_left, memory::dims* padding_right,
                        bool is_pool2d) {
    if (is_pool2d) {
      // Pool2D
      *filter_dims =
          memory::dims({pool_params->window_rows, pool_params->window_cols});
      *strides =
          memory::dims({pool_params->row_stride, pool_params->col_stride});
      *padding_left = memory::dims({static_cast<int>(pool_params->pad_top),
                                    static_cast<int>(pool_params->pad_left)});
      *padding_right = memory::dims({static_cast<int>(pool_params->pad_bottom),
                                     static_cast<int>(pool_params->pad_right)});
    } else {
      // Pool3D
      *filter_dims =
          memory::dims({pool_params->window_planes, pool_params->window_rows,
                        pool_params->window_cols});
      *strides =
          memory::dims({pool_params->planes_stride, pool_params->row_stride,
                        pool_params->col_stride});

      *padding_left = memory::dims({static_cast<int>(pool_params->pad_P1),
                                    static_cast<int>(pool_params->pad_top),
                                    static_cast<int>(pool_params->pad_left)});
      *padding_right = memory::dims({static_cast<int>(pool_params->pad_P2),
                                     static_cast<int>(pool_params->pad_bottom),
                                     static_cast<int>(pool_params->pad_right)});
    }
  }

  void AllocateEmptyOutputTensor(OpKernelContext* context,
                                 const int kOutputIndex,
                                 MklPoolParameters* pool_params,
                                 const memory::dims output_dims_mkl_order,
                                 Tensor** output_tensor) {
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(false);
    TensorShape output_tf_shape;
    if (pool_params->data_format == TensorFormat::FORMAT_NCHW) {
      output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
    } else {
      memory::dims output_dims_order;
      // determine Pooling2D (NHWC) or Pooling3D (NDHWC)
      if (this->ksize_.size() == 4) {
        output_dims_order = {pool_params->tensor_in_batch,
                             static_cast<int>(pool_params->out_height),
                             static_cast<int>(pool_params->out_width),
                             pool_params->out_depth};
      } else {
        output_dims_order = {pool_params->tensor_in_batch,
                             static_cast<int>(pool_params->out_planes),
                             static_cast<int>(pool_params->out_height),
                             static_cast<int>(pool_params->out_width),
                             pool_params->out_depth};
      }
      output_tf_shape = MklDnnDimsToTFShape(output_dims_order);
    }
    AllocateOutputSetMklShape(context, kOutputIndex, output_tensor,
                              output_tf_shape, output_mkl_shape,
                              native_format_);
    DCHECK(*output_tensor);
  }

  // Checks that the number of bytes described by the memory descriptor is a
  // multiple of sizeof(T) and returns the number of T elements to allocate,
  // rounding up when it is not an exact multiple.
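  // For example (illustrative): if pd.get_size() reports 18 bytes and
  // sizeof(T) == 4, this returns 5, so allocating 5 elements (20 bytes)
  // covers all 18 bytes.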
  size_t GetNumTElements(const memory::desc& pd) {
    size_t num_bytes = pd.get_size();
    size_t ret_val = num_bytes / sizeof(T);
    if (num_bytes % sizeof(T) != 0) {
      ret_val++;
    }
    return ret_val;
  }

  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_tf_;
  MklTensorFormat tensor_format_mkldnn_;
  memory::format_tag data_format_mkldnn_;
  bool workspace_enabled_;
  bool native_format_ = false;
};

template <class T>
class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
 public:
  explicit MklPoolingForwardOpBase<T>(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}
  void Compute(OpKernelContext* context) override = 0;

 protected:
  void ConfigureInput(OpKernelContext* context,
                      const MklDnnShape& input_mkl_shape,
                      const Tensor& input_tensor,
                      MklPoolParameters* pool_params,
                      MklDnnData<T>* dnn_data_input) {
    DCHECK(pool_params);
    DCHECK(dnn_data_input);
    TensorShape input_tensor_shape = input_tensor.shape();
    if (input_tensor.NumElements() != 0) {
      memory::desc input_md =
          input_mkl_shape.IsMklTensor()
              ? input_mkl_shape.GetMklLayout()
              : memory::desc(
                    (this->ksize_.size() == 4)
                        ? TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
                                                    this->data_format_tf_)
                        : TFShapeToMklDnnDimsInNCDHW(input_tensor_shape,
                                                     this->data_format_tf_),
                    MklDnnType<T>(), this->data_format_mkldnn_);
      dnn_data_input->SetUsrMem(input_md, &input_tensor);

      if (this->ksize_.size() == 5) {
        // Pool3D
        std::vector<dnnl::memory::dim> input_sizes(5, -1);
        input_sizes[MklDnnDims3D::Dim3d_N] = input_md.data.dims[0];
        input_sizes[MklDnnDims3D::Dim3d_C] = input_md.data.dims[1];
        input_sizes[MklDnnDims3D::Dim3d_D] = input_md.data.dims[2];
        input_sizes[MklDnnDims3D::Dim3d_H] = input_md.data.dims[3];
        input_sizes[MklDnnDims3D::Dim3d_W] = input_md.data.dims[4];
        dnn_data_input->SetOpMemDesc(input_sizes, this->data_format_mkldnn_);
      }
    }
    this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
                                input_tensor_shape);
  }

  void AllocateOutputTensor(OpKernelContext* context,
                            const PoolingFwdPd& pool_fwd_prim_desc,
                            const memory::dims output_dims_mkl_order,
                            const MklTensorFormat& output_tf_format,
                            Tensor** output_tensor) {
    TensorShape output_tf_shape;
    DCHECK(output_tensor);
    memory::desc dst_pd = pool_fwd_prim_desc.dst_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);
    // Only allocate enough space for the elements we need.
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));

    if (this->native_format_) {
      output_tf_shape = output_mkl_shape.GetTfShape();
    }
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape,
                              this->native_format_);
    DCHECK(*output_tensor);
  }

  void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
                        const MklDnnShape& input_mkl_shape) {
    if (!input_mkl_shape.IsMklTensor()) {
      OP_REQUIRES(context, input_tensor.dims() == 4 || input_tensor.dims() == 5,
                  errors::InvalidArgument("Input must be 4 or 5-dimensional"));
    } else {
      OP_REQUIRES(
          context,
          input_mkl_shape.GetDimension() == 4 ||
              input_mkl_shape.GetDimension() == 5,
          errors::InvalidArgument("Input shape must be 4 or 5-dimensional"));
    }
  }
  const int kInputTensorIndexInput = 0;
  const int kOutputTensorIndexOutput = 0;
};  // MklPoolingForwardOpBase

template <class T>
class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
 public:
  explicit MklPoolingBackwardOpBase<T>(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}
  void Compute(OpKernelContext* context) override = 0;

 protected:
  const int kOutputTensorIndexOutput = 0;

  void AllocateOutputTensor(OpKernelContext* context,
                            const PoolingBwdPd& pool_bkwd_prim_desc,
                            const memory::dims output_dims_mkl_order,
                            const MklTensorFormat& output_tf_format,
                            Tensor** output_tensor) {
    DCHECK(output_tensor);
    memory::desc dst_pd = pool_bkwd_prim_desc.diff_src_desc();
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    if (this->native_format_) {
      output_tf_shape = output_mkl_shape.GetTfShape();
    }
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape,
                              this->native_format_);
    DCHECK(*output_tensor);
  }
};

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_POOLING_OPS_COMMON_H_