1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
17 #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
18 #ifdef INTEL_MKL
19 
20 #include <list>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26 
27 #include "dnnl.hpp"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/graph/mkl_graph_util.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/stringpiece.h"
34 #include "tensorflow/core/lib/gtl/array_slice.h"
35 #include "tensorflow/core/platform/cpu_info.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/macros.h"
38 #include "tensorflow/core/util/env_var.h"
39 #include "tensorflow/core/util/mkl_threadpool.h"
40 #include "tensorflow/core/util/padding.h"
41 #include "tensorflow/core/util/tensor_format.h"
42 #ifdef DNNL_AARCH64_USE_ACL
43 #include "tensorflow/core/platform/mutex.h"
44 #endif
45 
46 using dnnl::engine;
47 using dnnl::memory;
48 using dnnl::primitive;
49 using dnnl::reorder;
50 using dnnl::stream;
51 using CPUDevice = Eigen::ThreadPoolDevice;
52 using MemoryArgsMap = std::unordered_map<int, memory>;
53 using ReorderPd = dnnl::reorder::primitive_desc;
54 
55 #ifdef _WIN32
56 typedef unsigned int uint;
57 #endif
58 
59 namespace tensorflow {
60 
61 // This file contains a number of utility classes and functions used by
62 // MKL-enabled kernels.
63 
64 // This class encapsulates all the meta data that is associated with an MKL
65 // tensor. A tensor is an MKL tensor if it was created as the result of an
66 // MKL operation, and did not go through a conversion to a standard
67 // Tensorflow tensor.
68 
69 // The dimension order that oneDNN uses internally for 2D activations is
70 // [Batch, Channel, Height, Width], and for 2D filters it is
71 // [Out_Channel, In_Channel, Height, Width].
72 typedef enum {
73   Dim_N = 0,
74   Dim_C = 1,
75   Dim_H = 2,
76   Dim_W = 3,
77   Dim_O = 0,
78   Dim_I = 1
79 } MklDnnDims;
80 
81 // The dimension order that oneDNN uses internally for 3D activations is
82 // [Batch, Channel, Depth, Height, Width], and for 3D filters it is
83 // [Out_Channel, In_Channel, Depth, Height, Width].
84 typedef enum {
85   Dim3d_N = 0,
86   Dim3d_C = 1,
87   Dim3d_D = 2,
88   Dim3d_H = 3,
89   Dim3d_W = 4,
90   Dim3d_O = 0,
91   Dim3d_I = 1
92 } MklDnnDims3D;
93 
94 // Enum for the order of dimensions of a TF 2D filter with shape [filter_height,
95 // filter_width, in_channels, out_channels]
96 typedef enum {
97   TF_2DFILTER_DIM_H = 0,
98   TF_2DFILTER_DIM_W = 1,
99   TF_2DFILTER_DIM_I = 2,
100   TF_2DFILTER_DIM_O = 3
101 } TFFilterDims2d;
102 
103 // Enum for the order of dimensions of a TF 3D filter with shape [filter_depth,
104 // filter_height, filter_width, in_channels, out_channels]
105 typedef enum {
106   TF_3DFILTER_DIM_P = 0,
107   TF_3DFILTER_DIM_H = 1,
108   TF_3DFILTER_DIM_W = 2,
109   TF_3DFILTER_DIM_I = 3,
110   TF_3DFILTER_DIM_O = 4
111 } TFFilterDims3d;
112 
113 // The dimensions order that oneDNN requires for the filter in a grouped
114 // convolution (2D only)
115 typedef enum {
116   MKL_GROUP_FILTER_DIM_G = 0,
117   MKL_GROUP_FILTER_DIM_O = 1,
118   MKL_GROUP_FILTER_DIM_I = 2,
119   MKL_GROUP_FILTER_DIM_H = 3,
120   MKL_GROUP_FILTER_DIM_W = 4
121 } MklDnnFilterGroupDims;
122 
123 // Enum used to templatize MklOp kernel implementations
124 // that support both fp32 and int8 versions.
125 enum class MklQuantization {
126   QUANTIZED_VERSION,
127   FP_VERSION,
128 };
129 
130 static const int kSmallBatchSize = 32;
131 
132 inline void execute_primitives(
133     std::vector<dnnl::primitive>& primitives, std::shared_ptr<stream> stream,
134     std::vector<std::unordered_map<int, memory>>& net_args) {
135   DCHECK_EQ(primitives.size(), net_args.size());
136   for (size_t i = 0; i < primitives.size(); ++i) {
137     primitives.at(i).execute(*stream, net_args.at(i));
138   }
139 }
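// Usage sketch (illustrative, mirroring the pattern used by
// CreateAndExecuteReorder() later in this file): primitives[i] is executed
// with the argument map net_args[i], so the two vectors must be built in
// lock-step. Here reorder_pd, src_mem, dst_mem and cpu_stream (a
// std::shared_ptr<stream>) are assumed to be set up by the caller:
//
//   std::vector<dnnl::primitive> net;
//   std::vector<std::unordered_map<int, memory>> net_args;
//   net.push_back(dnnl::reorder(reorder_pd));
//   net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}});
//   execute_primitives(net, cpu_stream, net_args);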
140 
141 // In oneDNN v1.x, the format (ex. NCHW) used to initialize a memory descriptor
142 // (md) structure will no longer be recorded in its `format` field. Instead, it
143 // will be set to a canonical `blocked` format for every fully described md.
144 //
145 // Currently, we query this `format` field while mapping oneDNN's data format
146 // to TF's data format. Due to the above restriction, we will now get this data
147 // format information from TF's `data_format` attribute (i.e. via
148 // `TensorFormat`) for oneDNN v1.x.
149 //
150 // Some oneDNN operators such as ReLU do not have a `data_format` attribute
151 // since they are usually in `blocked` format. Therefore, in order to
152 // distinguish between blocked and non-blocked formats, we have defined a new
153 // enum called `MklTensorFormat` that is semantically similar to `TensorFormat`
154 // but with the following additional enum values:
155 //  1) FORMAT_BLOCKED: as described above, this is needed for element-wise
156 //     operators such as ReLU.
157 //  2) FORMAT_INVALID: for error-checking (ex. unsupported format)
158 //  3) FORMAT_X, FORMAT_NC, FORMAT_TNC: to distinguish between MKL tensors based
159 //     on their dimensions in operators such as Softmax, i.e.:
160 //        FORMAT_X   - 1D tensor
161 //        FORMAT_NC  - 2D tensor
162 //        FORMAT_TNC - 3D tensor
163 enum class MklTensorFormat {
164   FORMAT_NHWC = 0,
165   FORMAT_NCHW = 1,
166   FORMAT_NDHWC = 2,
167   FORMAT_NCDHW = 3,
168   FORMAT_X = 4,
169   FORMAT_NC = 5,
170   FORMAT_TNC = 6,
171   FORMAT_BLOCKED = 7,
172   FORMAT_INVALID = 8,
173 };
174 
175 // Forward declarations
176 memory::format_tag MklTensorFormatToMklDnnDataFormat(MklTensorFormat format);
177 
178 TensorFormat MklDnn3DDataFormatToTFDataFormat(MklTensorFormat format);
179 TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format);
180 
181 memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
182 Status CreateBlockedMemDescHelper(const memory::dims& dim,
183                                   const memory::dims& strides,
184                                   memory::data_type dtype,
185                                   dnnl_memory_desc_t* blocked_md);
186 
187 inline std::ostream& operator<<(std::ostream& os,
188                                 const memory::format_tag& tag) {
189   if (tag == memory::format_tag::undef) {
190     os << "undef";
191   } else if (tag == memory::format_tag::any) {
192     os << "any";
193   } else {
194     os << "invalid";
195   }
196   return os;
197 }
198 
199 inline void operator<<(std::ostream& os, const MklTensorFormat& format) {
200   if (format == MklTensorFormat::FORMAT_NHWC) {
201     os << "FORMAT_NHWC";
202   } else if (format == MklTensorFormat::FORMAT_NCHW) {
203     os << "FORMAT_NCHW";
204   } else if (format == MklTensorFormat::FORMAT_NDHWC) {
205     os << "FORMAT_NDHWC";
206   } else if (format == MklTensorFormat::FORMAT_NCDHW) {
207     os << "FORMAT_NCDHW";
208   } else if (format == MklTensorFormat::FORMAT_X) {
209     os << "FORMAT_X";
210   } else if (format == MklTensorFormat::FORMAT_NC) {
211     os << "FORMAT_NC";
212   } else if (format == MklTensorFormat::FORMAT_TNC) {
213     os << "FORMAT_TNC";
214   } else if (format == MklTensorFormat::FORMAT_BLOCKED) {
215     os << "FORMAT_BLOCKED";
216   } else {
217     os << "INVALID FORMAT";
218   }
219 }
220 
221 template <typename T>
222 inline bool array_cmp(const T* a1, const T* a2, size_t size) {
223   for (size_t i = 0; i < size; ++i)
224     if (a1[i] != a2[i]) return false;
225   return true;
226 }
227 
228 inline dnnl::stream* CreateStream(MklDnnThreadPool* eigen_tp,
229                                   const engine& engine) {
230 #ifndef ENABLE_ONEDNN_OPENMP
231   if (eigen_tp != nullptr) {
232     stream* tp_stream =
233         new stream(dnnl::threadpool_interop::make_stream(engine, eigen_tp));
234     return tp_stream;
235   } else {
236     stream* tp_stream = new stream(engine);
237     return tp_stream;
238   }
239 #else
240   stream* tp_stream = new stream(engine);
241   return tp_stream;
242 #endif  // !ENABLE_ONEDNN_OPENMP
243 }
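// Ownership note with an illustrative sketch (not part of the original
// header): the stream returned by CreateStream() is heap-allocated and owned
// by the caller, so it is typically wrapped in a smart pointer, as
// ExecutePrimitive() below does. Here cpu_engine, prim and args are assumed
// to be set up by the caller:
//
//   std::unique_ptr<stream> cpu_stream(CreateStream(nullptr, cpu_engine));
//   prim.execute(*cpu_stream, args);
//   cpu_stream->wait();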
244 
245 class MklDnnShape {
246  private:
247   struct MklShapeData {
248     // Flag to indicate if the tensor is an MKL tensor or not
249     bool is_mkl_tensor_ = false;
250     // Number of dimensions in Tensorflow format
251     size_t dimension_ = 0;
252     dnnl_dims_t sizes_;  // Required by MKL for conversions
253     MklTensorFormat tf_data_format_ = MklTensorFormat::FORMAT_BLOCKED;
254     memory::data_type T_ = memory::data_type::undef;
255     // MKL layout
256     dnnl_memory_desc_t mkl_md_;
257     /// TF dimension corresponding to this MKL dimension
258     dnnl_dims_t map_;
259   };
260   MklShapeData data_;
261 
262   typedef std::remove_extent<dnnl_dims_t>::type dnnl_dim_t;
263 
264 #define INVALID_DIM_SIZE -1
265 
266  public:
267   MklDnnShape() : data_{} {
268     for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
269          ++i) {
270       data_.sizes_[i] = -1;
271     }
272     for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
273       data_.map_[i] = -1;
274     }
275   }
276 
277   ~MklDnnShape() {}
278   TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape);  // Cannot copy
279 
280   /// Equality function for MklDnnShape objects
281   /// @return true if both are equal; false otherwise.
282   inline bool operator==(const MklDnnShape& input_shape) const {
283     if (this->IsMklTensor() != input_shape.IsMklTensor()) {
284       return false;
285     }
286 
287     // If input tensors are in MKL layout, then we check for dimensions and
288     // sizes.
289     if (this->IsMklTensor()) {
290       const dnnl_memory_desc_t& cur_md = (this->GetMklLayout()).data;
291       const dnnl_memory_desc_t& input_shape_md =
292           input_shape.GetMklLayout().data;
293       return this->GetTfShape() == input_shape.GetTfShape() &&
294              dnnl_memory_desc_equal(&cur_md, &input_shape_md);
295     }
296 
297     // Neither input is an MKL tensor.
298     return true;
299   }
300 
301   /// Equality operator for MklDnnShape and TFShape.
302   /// Returns: true if TF shapes for both are the same, false otherwise
303   inline bool operator==(const TensorShape& input_shape) const {
304     if (!this->IsMklTensor()) {
305       return false;
306     }
307 
308     return this->GetTfShape() == input_shape;
309   }
310 
311   inline const bool IsMklTensor() const { return data_.is_mkl_tensor_; }
312   inline void SetMklTensor(bool is_mkl_tensor) {
313     data_.is_mkl_tensor_ = is_mkl_tensor;
314   }
315 
316   inline void SetDimensions(const size_t dimension) {
317     data_.dimension_ = dimension;
318   }
319   inline size_t GetDimension(char dimension) const {
320     int index = GetMklDnnTensorDimIndex(dimension);
321     CHECK(index >= 0 && index < this->GetDimension())
322         << "Invalid index from the dimension: " << index << ", " << dimension;
323     return this->DimSize(index);
324   }
325 
326   inline size_t GetDimension3D(char dimension) const {
327     int index = GetMklDnnTensor3DDimIndex(dimension);
328     CHECK(index >= 0 && index < this->GetDimension())
329         << "Invalid index from the dimension: " << index << ", " << dimension;
330     return this->DimSize(index);
331   }
332 
333   inline int32 GetMklDnnTensorDimIndex(char dimension) const {
334     switch (dimension) {
335       case 'N':
336         return MklDnnDims::Dim_N;
337       case 'C':
338         return MklDnnDims::Dim_C;
339       case 'H':
340         return MklDnnDims::Dim_H;
341       case 'W':
342         return MklDnnDims::Dim_W;
343       default:
344         LOG(FATAL) << "Invalid dimension: " << dimension;
345         return -1;  // Avoid compiler warning about missing return value
346     }
347   }
348 
349   inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
350     switch (dimension) {
351       case 'N':
352         return MklDnnDims3D::Dim3d_N;
353       case 'C':
354         return MklDnnDims3D::Dim3d_C;
355       case 'D':
356         return MklDnnDims3D::Dim3d_D;
357       case 'H':
358         return MklDnnDims3D::Dim3d_H;
359       case 'W':
360         return MklDnnDims3D::Dim3d_W;
361       default:
362         LOG(FATAL) << "Invalid dimension: " << dimension;
363         return -1;  // Avoid compiler warning about missing return value
364     }
365   }
366 
367   inline size_t GetDimension() const { return data_.dimension_; }
368   inline const int* GetSizes() const {
369     return reinterpret_cast<const int*>(&data_.sizes_[0]);
370   }
371 
372   // Returns a dnnl::memory::dims object that contains the sizes of this
373   // MklDnnShape object.
374   inline memory::dims GetSizesAsMklDnnDims() const {
375     memory::dims retVal;
376     if (data_.is_mkl_tensor_) {
377       size_t dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
378       for (size_t i = 0; i < dimensions; i++) {
379         if (data_.sizes_[i] != INVALID_DIM_SIZE)
380           retVal.push_back(data_.sizes_[i]);
381       }
382     } else {
383       CHECK_EQ(data_.is_mkl_tensor_, true);
384     }
385     return retVal;
386   }
387 
388   inline int64 DimSize(int index) const {
389     CHECK_LT(index, sizeof(data_.sizes_) / sizeof(data_.sizes_[0]));
390     return data_.sizes_[index];
391   }
392 
393   /// Return TensorShape that describes the Tensorflow shape of the tensor
394   /// represented by this MklShape.
395   inline TensorShape GetTfShape() const {
396     CHECK_EQ(data_.is_mkl_tensor_, true);
397 
398     std::vector<int32> shape(data_.dimension_, -1);
399     // As mentioned in the comment above, we now rely on TF's `data_format`
400     // attribute to determine if TF shape is in blocked format or not.
401     if (data_.tf_data_format_ != MklTensorFormat::FORMAT_BLOCKED) {
402       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
403         shape[idx] = data_.sizes_[TfDimIdx(idx)];
404       }
405     } else {
406       // If Tensorflow shape is in Blocked format, then we don't have dimension
407       // map for it. So we just create Tensorflow shape from sizes in the
408       // specified order.
409       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
410         shape[idx] = data_.sizes_[idx];
411       }
412     }
413 
414     TensorShape ts;
415     bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok();
416     CHECK_EQ(ret, true);
417     return ts;
418   }
419 
420   inline void SetElemType(memory::data_type dt) { data_.T_ = dt; }
421   inline const memory::data_type GetElemType() { return data_.T_; }
422 
423   inline void SetMklLayout(memory::desc* md) {
424     CHECK_NOTNULL(md);
425     data_.mkl_md_ = md->data;
426   }
427 
428   inline const memory::desc GetMklLayout() const {
429     return memory::desc(data_.mkl_md_);
430   }
431 
432   inline MklTensorFormat GetTfDataFormat() const {
433     return data_.tf_data_format_;
434   }
435 
436   /// We don't create primitive_descriptor for TensorFlow layout now.
437   /// We use lazy evaluation and create it only when needed. Input format can
438   /// also be Blocked format.
439   inline void SetTfLayout(size_t dims, const memory::dims& sizes,
440                           MklTensorFormat format) {
441     DCHECK_EQ(dims, sizes.size())
442         << "SetTfLayout: Number of dimensions does not "
443            "match the dimension array";
444     data_.dimension_ = dims;
445     for (size_t ii = 0; ii < dims; ++ii) {
446       data_.sizes_[ii] = sizes[ii];
447     }
448     data_.tf_data_format_ = format;
449     if (format != MklTensorFormat::FORMAT_BLOCKED) {
450       if (dims == 2) {
451         data_.map_[0] = MklDnnDims::Dim_N;
452         data_.map_[1] = MklDnnDims::Dim_C;
453       } else {
454         SetTfDimOrder(dims, format);
455       }
456     }
457   }
458 
459   inline const memory::desc GetTfLayout() const {
460     memory::dims dims;
461     for (size_t ii = 0; ii < data_.dimension_; ++ii) {
462       dims.push_back(data_.sizes_[ii]);
463     }
464 
465     // Create a blocked memory descriptor if the TF format was set to FORMAT_BLOCKED.
466     if (data_.tf_data_format_ == MklTensorFormat::FORMAT_BLOCKED) {
467       auto strides = CalculateTFStrides(dims);
468       dnnl_memory_desc_t blocked_md;
469       TF_CHECK_OK(
470           CreateBlockedMemDescHelper(dims, strides, data_.T_, &blocked_md));
471       return memory::desc(blocked_md);
472     } else {
473       auto format_tag =
474           MklTensorFormatToMklDnnDataFormat(data_.tf_data_format_);
475       return memory::desc(dims, data_.T_, format_tag);
476     }
477   }
478 
479   inline const memory::desc GetCurLayout() const {
480     return IsMklTensor() ? GetMklLayout() : GetTfLayout();
481   }
482 
483   // We don't need a default dimension-order case: an operator that has no
484   // data_format attribute receives all inputs in Tensorflow format and
485   // produces its output in Tensorflow format as well.
486   inline void SetTfDimOrder(const size_t dimension, const dnnl_dims_t map) {
487     CHECK(dimension == data_.dimension_);
488     for (size_t ii = 0; ii < dimension; ii++) {
489       data_.map_[ii] = map[ii];
490     }
491   }
492 
493   inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
494     if (dimension == 5) {
495       CHECK(dimension == data_.dimension_);
496       data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
497           MklDnnDims3D::Dim3d_D;
498       data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
499           MklDnnDims3D::Dim3d_H;
500       data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
501           MklDnnDims3D::Dim3d_W;
502       data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
503           MklDnnDims3D::Dim3d_C;
504       data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
505           MklDnnDims3D::Dim3d_N;
506     } else {
507       CHECK_EQ(dimension, 4);
508       CHECK(dimension == data_.dimension_);
509       data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
510       data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
511       data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
512       data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
513     }
514   }
515 
516   inline void SetTfDimOrder(const size_t dimension, MklTensorFormat format) {
517     TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
518     SetTfDimOrder(dimension, data_format);
519   }
520 
521   inline const dnnl_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; }
522   inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
523   inline int64 TfDimSize(int index) const {
524     return data_.sizes_[TfDimIdx(index)];
525   }
526 
527   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
528   /// corresponds to MKL's Channel dimension.
529   inline bool IsMklChannelDim(int d) const {
530     return TfDimIdx(d) == MklDnnDims::Dim_C;
531   }
532 
533   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
534   /// corresponds to MKL's Batch dimension.
535   inline bool IsMklBatchDim(int d) const {
536     return TfDimIdx(d) == MklDnnDims::Dim_N;
537   }
538 
539   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
540   /// corresponds to MKL's Width dimension.
541   inline bool IsMklWidthDim(int d) const {
542     return TfDimIdx(d) == MklDnnDims::Dim_W;
543   }
544   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
545   /// corresponds to MKL's Height dimension.
546   inline bool IsMklHeightDim(int d) const {
547     return TfDimIdx(d) == MklDnnDims::Dim_H;
548   }
549 
550   /// Check if the TF-MKL dimension ordering map specifies if the input
551   /// tensor is in NCHW format.
552   inline bool IsTensorInNCHWFormat() const {
553     TensorFormat data_format = FORMAT_NCHW;
554     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
555             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
556             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
557             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
558   }
559 
560   /// Check if the TF-MKL dimension ordering map specifies if the input
561   /// tensor is in NHWC format.
562   inline bool IsTensorInNHWCFormat() const {
563     TensorFormat data_format = FORMAT_NHWC;
564     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
565             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
566             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
567             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
568   }
569 
570   /// The following methods are used for serializing and de-serializing the
571   /// contents of the mklshape object.
572   /// The data is serialized in this order
573   /// is_mkl_tensor_ : dimension_ : sizes_ : map_: format_ : T_ : mkl_pd_;
574 
575   /// Size of the buffer needed to hold the serialized object; the size is
576   /// computed following the order mentioned above.
577   inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); }
578 
579   void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
580     CHECK(buf_size >= GetSerializeBufferSize())
581         << "Buffer size is too small to SerializeMklDnnShape";
582     *reinterpret_cast<MklShapeData*>(buf) = data_;
583   }
584 
585   void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) {
586     // Make sure buffer holds at least is_mkl_tensor_.
587     CHECK(buf_size >= sizeof(data_.is_mkl_tensor_))
588         << "Buffer size is too small in DeSerializeMklDnnShape";
589 
590     const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf);
591     if (is_mkl_tensor) {  // If it is an MKL Tensor then read the rest
592       CHECK(buf_size >= GetSerializeBufferSize())
593           << "Buffer size is too small in DeSerializeMklDnnShape";
594       data_ = *reinterpret_cast<const MklShapeData*>(buf);
595     }
596   }
597 };
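// Illustrative sketch (hypothetical helper, not part of the original header):
// a serialize/de-serialize round trip for MklDnnShape using the buffer size
// reported by GetSerializeBufferSize(). Kernels normally serialize into the
// uint8 meta-data tensor instead of a local buffer.
inline void ExampleMklDnnShapeRoundTrip(const MklDnnShape& src,
                                        MklDnnShape* dst) {
  std::vector<unsigned char> buffer(src.GetSerializeBufferSize());
  src.SerializeMklDnnShape(buffer.data(), buffer.size());
  dst->DeSerializeMklDnnShape(buffer.data(), buffer.size());
}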
598 
599 // List of MklShape objects. Used in Concat/Split layers.
600 typedef std::vector<MklDnnShape> MklDnnShapeList;
601 
602 template <typename T>
603 class MklDnnData;
604 
605 // TODO(intel-tf): Merge with the execute_primitives.
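// Unlike execute_primitives() above, which runs on a caller-provided stream
// and does not wait, this helper creates its own stream (backed by the kernel
// context's threadpool when one is supplied and the non-OpenMP build is used)
// and blocks until all primitives have finished.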
606 inline void ExecutePrimitive(const std::vector<primitive>& net,
607                              const std::vector<MemoryArgsMap>* net_args,
608                              const engine& cpu_engine,
609                              OpKernelContext* context = nullptr) {
610   DCHECK(net_args);
611   DCHECK_EQ(net.size(), net_args->size());
612   std::unique_ptr<stream> cpu_stream;
613   MklDnnThreadPool eigen_tp;
614   if (context != nullptr) {
615     eigen_tp = MklDnnThreadPool(context);
616     cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));
617   } else {
618     cpu_stream.reset(CreateStream(nullptr, cpu_engine));
619   }
620   for (size_t i = 0; i < net.size(); ++i) {
621     net.at(i).execute(*cpu_stream, net_args->at(i));
622   }
623   cpu_stream->wait();
624 }
625 template <typename T>
626 inline Status ConvertMklToTF(OpKernelContext* context,
627                              const Tensor& input_mkl_tensor,
628                              const MklDnnShape& input_mkl_shape,
629                              Tensor* output_tf_tensor) {
630   try {
631     if (!input_mkl_shape.IsMklTensor()) {
632       // Return input as is since it is already a TF tensor
633       *output_tf_tensor = input_mkl_tensor;
634       return Status::OK();
635     }
636 
637     // Allocate output tensor.
638     TensorShape output_tf_shape = input_mkl_shape.GetTfShape();
639     TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<T>::v(), output_tf_shape,
640                                        output_tf_tensor));
641 
642     engine cpu_engine(engine::kind::cpu, 0);
643     MklDnnData<T> input(&cpu_engine);
644 
645     // Get MKL layout of input tensor.
646     auto input_mkl_md = input_mkl_shape.GetMklLayout();
647     auto output_tf_md = input_mkl_shape.GetTfLayout();
648     input.SetUsrMem(input_mkl_md, &input_mkl_tensor);
649 
650     if (input.IsReorderNeeded(output_tf_md)) {
651       std::vector<primitive> net;
652       std::vector<MemoryArgsMap> net_args;
653       bool status = input.CheckReorderToOpMem(output_tf_md, output_tf_tensor,
654                                               net, net_args, cpu_engine);
655       if (!status) {
656         return Status(error::Code::INTERNAL,
657                       "ConvertMklToTF(): Failed to create reorder for input");
658       }
659       ExecutePrimitive(net, &net_args, cpu_engine, context);
660     } else {
661       // If not, just forward input tensor to output tensor.
662       bool status =
663           output_tf_tensor->CopyFrom(input_mkl_tensor, output_tf_shape);
664       if (!status) {
665         return Status(
666             error::Code::INTERNAL,
667             "ConvertMklToTF(): Failed to forward input tensor to output");
668       }
669     }
670     return Status::OK();
671   } catch (dnnl::error& e) {
672     string error_msg = "Status: " + std::to_string(e.status) +
673                        ", message: " + string(e.message) + ", in file " +
674                        string(__FILE__) + ":" + std::to_string(__LINE__);
675     LOG(FATAL) << "Operation received an exception: " << error_msg;
676   }
677 }
678 
679 // Get the MKL shape from the second (meta-data) tensor
680 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape,
681                         bool eager_mode) {
682   if (!eager_mode) {
683     mklshape->DeSerializeMklDnnShape(
684         ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
685             .flat<uint8>()
686             .data(),
687         ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
688                 .flat<uint8>()
689                 .size() *
690             sizeof(uint8));
691   } else {
692     mklshape->SetMklTensor(false);
693   }
694 }
695 
696 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
697   GetMklShape(ctext, n, mklshape, false);
698 }
699 
700 // Gets the actual input
701 inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
702   return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
703 }
704 
705 inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
706                             OpInputList* input_tensors) {
707   CHECK_NOTNULL(input_tensors);
708   TF_CHECK_OK(ctext->input_list(name, input_tensors));
709 }
710 
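// NOTE: GetMklShapeList() below writes into (*mkl_shapes)[i] without resizing
// the list, so callers are expected to size `mkl_shapes` to the number of
// data inputs before calling it.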
711 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
712                             MklDnnShapeList* mkl_shapes,
713                             bool native_format = false) {
714   if (!native_format) {
715     OpInputList input_mkl_tensors;
716     GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
717 
718     for (int i = 0; i < input_mkl_tensors.size(); i++) {
719       (*mkl_shapes)[i].DeSerializeMklDnnShape(
720           input_mkl_tensors[i].flat<uint8>().data(),
721           input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
722     }
723   } else {
724     for (int i = 0; i < mkl_shapes->size(); ++i) {
725       (*mkl_shapes)[i].SetMklTensor(false);
726     }
727   }
728 }
729 
730 /// Get shape of input tensor pointed by 'input_idx' in TensorShape format.
731 /// If the input tensor is in MKL layout, then obtains TensorShape from
732 /// MklShape.
733 inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx,
734                               bool eager_mode = false) {
735   // Sanity check.
736   CHECK_NOTNULL(context);
737   CHECK_LT(input_idx, context->num_inputs());
738 
739   MklDnnShape input_mkl_shape;
740   GetMklShape(context, input_idx, &input_mkl_shape, eager_mode);
741   if (input_mkl_shape.IsMklTensor() && !eager_mode) {
742     return input_mkl_shape.GetTfShape();
743   } else {
744     const Tensor& t = MklGetInput(context, input_idx);
745     return t.shape();
746   }
747 }
748 
749 // Allocate the second output tensor that will contain
750 // the MKL shape serialized
751 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
752                                       const MklDnnShape& mkl_shape) {
753   Tensor* second_tensor = nullptr;
754   TensorShape second_shape;
755   second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
756   OP_REQUIRES_OK(ctext, ctext->allocate_output(
757                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
758                             second_shape, &second_tensor));
759   mkl_shape.SerializeMklDnnShape(
760       second_tensor->flat<uint8>().data(),
761       second_tensor->flat<uint8>().size() * sizeof(uint8));
762 }
763 
764 // Allocate the output tensor, create a second output tensor that will contain
765 // the MKL shape serialized
766 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
767                                       Tensor** output,
768                                       const TensorShape& tf_shape,
769                                       const MklDnnShape& mkl_shape,
770                                       bool eager_mode = false) {
771   OP_REQUIRES_OK(
772       ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
773                                     tf_shape, output));
774   if (!eager_mode) {
775     Tensor* second_tensor = nullptr;
776     TensorShape second_shape;
777     second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
778     OP_REQUIRES_OK(ctext, ctext->allocate_output(
779                               GetTensorMetaDataIndex(n, ctext->num_outputs()),
780                               second_shape, &second_tensor));
781     mkl_shape.SerializeMklDnnShape(
782         second_tensor->flat<uint8>().data(),
783         second_tensor->flat<uint8>().size() * sizeof(uint8));
784   }
785 }
786 
787 // Allocates a temp tensor and returns the data buffer for temporary storage.
788 template <typename T>
789 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
790                            const memory::desc& pd, void** buf_out) {
791   TensorShape tf_shape;
792 
793   tf_shape.AddDim(pd.get_size() / sizeof(T) + 1);
794   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
795                                                  tf_shape, tensor_out));
796   *buf_out = static_cast<void*>(tensor_out->flat<T>().data());
797 }
798 
799 template <typename T>
800 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
801                            TensorShape tf_shape) {
802   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
803                                                  tf_shape, tensor_out));
804 }
805 
806 template <typename T>
807 struct UserScratchPad {
808   template <typename MklPrim>
809   // NOTE: if scratchpad is not required for a particular primitive the
810   //      spad_md.get_size() will return 0. It is fine to return
811   //      nullptr in this case
812   inline void AllocateSPTensor(MklPrim* mkl_prim, OpKernelContext* context) {
813     allocated_ = false;
814     auto spad_md = mkl_prim->GetScratchPadDesc();
815     size_t spad_size = spad_md.get_size();
816     if (spad_size == 0) return;
817 
818     size_t allocate_size = (spad_size + sizeof(T) - 1) / sizeof(T);
819     TensorShape tf_shape;
820     tf_shape.AddDim(allocate_size);
821     AllocTmpBuffer<T>(context, &scratch_pad_, tf_shape);
822     allocated_ = true;
823   }
824   inline void* Get() {
825     if (allocated_) {
826       return static_cast<void*>(scratch_pad_.flat<T>().data());
827     } else {
828       return nullptr;
829     }
830   }
831 
832  private:
833   Tensor scratch_pad_;
834   bool allocated_ = false;
835 };
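// Illustrative sketch (not part of the original header): typical use of
// UserScratchPad inside a kernel's Compute(), where prim_ptr is assumed to be
// a pointer to an MKL primitive object exposing GetScratchPadDesc():
//
//   UserScratchPad<float> scratch_pad;
//   scratch_pad.AllocateSPTensor(prim_ptr, context);
//   void* scratch_buf = scratch_pad.Get();  // nullptr if no scratchpad needed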
836 
837 inline void GetStridesFromSizes(MklTensorFormat data_format, size_t* strides,
838                                 const size_t* sizes) {
839   DCHECK_NE(data_format, MklTensorFormat::FORMAT_INVALID);
840   // MKL requires strides in NCHW
841   if (data_format == MklTensorFormat::FORMAT_NHWC) {
842     strides[0] = sizes[2];
843     strides[1] = sizes[0] * sizes[2];
844     strides[2] = 1;
845     strides[3] = sizes[0] * sizes[1] * sizes[2];
846   } else {
847     strides[0] = 1;
848     strides[1] = sizes[0];
849     strides[2] = sizes[0] * sizes[1];
850     strides[3] = sizes[0] * sizes[1] * sizes[2];
851   }
852 }
853 
854 inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
855                                  int idx_out) {
856   int num_inputs = context->num_inputs();
857   int num_outputs = context->num_outputs();
858   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
859   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
860   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
861   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
862 
863   const Tensor& data = context->input(idx_data_in);
864   const Tensor& meta = context->input(idx_meta_in);
865   Tensor output(data.dtype());
866   Tensor meta_output(meta.dtype());
867 
868   // TODO(intel-tf): alternatively, call forward_input_to_output_with_shape(...)
869   CHECK(output.CopyFrom(data, data.shape()));
870   CHECK(meta_output.CopyFrom(meta, meta.shape()));
871   context->set_output(idx_data_out, output);
872   context->set_output(idx_meta_out, meta_output);
873 }
874 
875 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
876                                          int idx_out,
877                                          const TensorShape& shape) {
878   int num_inputs = context->num_inputs();
879   int num_outputs = context->num_outputs();
880   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
881   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
882 
883   const Tensor& data = context->input(idx_data_in);
884   MklDnnShape mkl_shape_output;
885   mkl_shape_output.SetMklTensor(false);
886   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
887   Tensor output(data.dtype());
888   // TODO(intel-tf): alternatively, call forward_input_to_output_with_shape(...)
889   CHECK(output.CopyFrom(data, shape));
890   context->set_output(idx_data_out, output);
891 }
892 
893 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
894                                    int idx_out) {
895   int num_inputs = context->num_inputs();
896   int num_outputs = context->num_outputs();
897   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
898   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
899 
900   MklDnnShape dnn_shape_output;
901   dnn_shape_output.SetMklTensor(false);
902   AllocateOutputSetMklShape(context, idx_out, dnn_shape_output);
903   if (IsRefType(context->input_dtype(idx_data_in))) {
904     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
905   } else {
906     context->set_output(idx_data_out, context->input(idx_data_in));
907   }
908 }
909 
910 inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
911                                     int idx_out) {
912   int num_inputs = context->num_inputs();
913   int num_outputs = context->num_outputs();
914   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
915   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
916   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
917   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
918 
919   if (IsRefType(context->input_dtype(idx_data_in))) {
920     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
921     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
922   } else {
923     context->set_output(idx_data_out, context->input(idx_data_in));
924     context->set_output(idx_meta_out, context->input(idx_meta_in));
925   }
926 }
927 
928 // Set a dummy oneDNN shape (called when the output is in TF format)
929 inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
930                                       uint32 idx_data_out) {
931   MklDnnShape mkl_shape_output;
932   mkl_shape_output.SetMklTensor(false);
933   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
934 }
935 
936 // If the input tensor has ref count as 1, it is forwarded to the desired
937 // output port and the function returns true. In that case, it also allocates
938 // the serialized MklDnnShape object. Otherwise, the function returns false.
939 inline bool ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
940                                                 int idx_in, int idx_out,
941                                                 Tensor** output,
942                                                 const MklDnnShape& mkl_shape,
943                                                 bool always_forward = true) {
944   int num_inputs = context->num_inputs();
945   int num_outputs = context->num_outputs();
946   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
947   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
948   bool is_forwarded = false;
949   const Tensor& input_tensor = context->input(idx_data_in);
950   const auto output_shape = input_tensor.shape();
951   if (always_forward) {
952     if (IsRefType(context->input_dtype(idx_data_in))) {
953       context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
954     } else {
955       context->set_output(idx_data_out, input_tensor);
956     }
957   } else {
958     is_forwarded = context->forward_input_to_output_with_shape(
959         idx_data_in, idx_data_out, output_shape, output);
960   }
961   if (is_forwarded || always_forward) {
962     AllocateOutputSetMklShape(context, idx_out, mkl_shape);
963     return true;
964   }
965   return false;
966 }
967 
968 // Forward the MKL shape ONLY (used in elementwise and other ops where
969 // we call the eigen implementation and MKL shape is not used)
970 inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
971                                       uint32 idx_data_in,
972                                       uint32_t idx_data_out) {
973   uint32 idx_meta_in =
974       GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
975   uint32 idx_meta_out =
976       GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
977 
978   if (IsRefType(context->input_dtype(idx_data_in))) {
979     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
980   } else {
981     context->set_output(idx_meta_out, context->input(idx_meta_in));
982   }
983 }
984 
985 // -------------------------------------------------------------------
986 //          Common utility functions used by MKL unit tests
987 
988 inline Tensor GetMklMetaTensor() {
989   MklDnnShape non_mkl_shape;
990   non_mkl_shape.SetMklTensor(false);
991 
992   auto size = static_cast<int64_t>(non_mkl_shape.GetSerializeBufferSize());
993   Tensor tensor(DT_UINT8, {size});
994 
995   non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
996                                      size * sizeof(uint8));
997   return tensor;
998 }
999 
1000 // -------------------------------------------------------------------
1001 
1002 /// Return oneDNN data type (memory::data_type) for input type T
1003 ///
1004 /// @input None
1005 /// @return memory::data_type corresponding to type T
1006 template <typename T>
1007 static memory::data_type MklDnnType();
1008 
1009 /// Instantiation for float type. Add similar instantiations for other
1010 /// types if needed.
1011 template <>
1012 memory::data_type MklDnnType<float>() {
1013   return memory::data_type::f32;
1014 }
1015 
1016 template <>
1017 memory::data_type MklDnnType<quint8>() {
1018   return memory::data_type::u8;
1019 }
1020 
1021 template <>
1022 memory::data_type MklDnnType<uint8>() {
1023   return memory::data_type::u8;
1024 }
1025 
1026 template <>
1027 memory::data_type MklDnnType<qint8>() {
1028   return memory::data_type::s8;
1029 }
1030 
1031 template <>
1032 memory::data_type MklDnnType<qint32>() {
1033   return memory::data_type::s32;
1034 }
1035 template <>
1036 memory::data_type MklDnnType<bfloat16>() {
1037   return memory::data_type::bf16;
1038 }
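// Illustrative usage (not part of the original header): MklDnnType<T>() is
// typically combined with dims and a format tag to build a memory::desc, e.g.
//
//   memory::dims dims = {8, 3, 224, 224};
//   auto md = memory::desc(dims, MklDnnType<float>(), memory::format_tag::nchw);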
1039 
1040 // Map MklTensorFormat to oneDNN format tag
1041 //
1042 // @input: MklTensorFormat i.e. TensorFlow data format
1043 // @return: oneDNN's memory format tag corresponding to MklTensorFormat.
1044 //          Fails with an error if invalid data format.
1045 inline memory::format_tag MklTensorFormatToMklDnnDataFormat(
1046     MklTensorFormat format) {
1047   if (format == MklTensorFormat::FORMAT_NHWC) return memory::format_tag::nhwc;
1048   if (format == MklTensorFormat::FORMAT_NCHW) return memory::format_tag::nchw;
1049   if (format == MklTensorFormat::FORMAT_NDHWC) return memory::format_tag::ndhwc;
1050   if (format == MklTensorFormat::FORMAT_NCDHW) return memory::format_tag::ncdhw;
1051   if (format == MklTensorFormat::FORMAT_X) return memory::format_tag::x;
1052   if (format == MklTensorFormat::FORMAT_NC) return memory::format_tag::nc;
1053   if (format == MklTensorFormat::FORMAT_TNC) return memory::format_tag::tnc;
1054   return memory::format_tag::undef;
1055 }
1056 
1057 /// Map TensorFlow data format into oneDNN 3D data format
1058 /// @input: TensorFlow data format
1059 /// @return: oneDNN 3D data format corresponding to TensorFlow data format;
1060 ///          Fails with an error if invalid data format.
1061 inline MklTensorFormat TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
1062   if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NDHWC;
1063   if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCDHW;
1064   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1065   return MklTensorFormat::FORMAT_INVALID;
1066 }
1067 
1068 /// Map TensorFlow data format into oneDNN data format
1069 ///
1070 /// @input: TensorFlow data format
1071 /// @return: oneDNN data format corresponding to TensorFlow data format;
1072 ///          Fails with an error if invalid data format.
1073 inline MklTensorFormat TFDataFormatToMklDnnDataFormat(TensorFormat format) {
1074   if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NHWC;
1075   if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCHW;
1076   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1077   return MklTensorFormat::FORMAT_INVALID;
1078 }
1079 
1080 /// Map oneDNN data format into TensorFlow data format
1081 ///
1082 /// @input: oneDNN data format
1083 /// @return: Tensorflow data format corresponding to oneDNN data format;
1084 ///          Fails with an error if invalid data format.
1085 inline TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format) {
1086   if (format == MklTensorFormat::FORMAT_NHWC ||
1087       format == MklTensorFormat::FORMAT_NDHWC)
1088     return FORMAT_NHWC;
1089   if (format == MklTensorFormat::FORMAT_NCHW ||
1090       format == MklTensorFormat::FORMAT_NCDHW)
1091     return FORMAT_NCHW;
1092   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1093 
1094   // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure
1095   // that we don't come here.
1096   return FORMAT_NHWC;
1097 }
1098 
1099 /// Map TensorShape object into memory::dims required by oneDNN
1100 ///
1101 /// This function will simply map input TensorShape into oneDNN dims
1102 /// naively. So it will preserve the order of dimensions. E.g., if
1103 /// input tensor is in NHWC format, then dims will be in NHWC format also.
1104 ///
1105 /// @input TensorShape object in shape
1106 /// @return memory::dims corresponding to TensorShape
1107 inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
1108   memory::dims dims(shape.dims());
1109   for (int d = 0; d < shape.dims(); ++d) {
1110     dims[d] = shape.dim_size(d);
1111   }
1112   return dims;
1113 }
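// Illustrative sketch (hypothetical helper, not part of the original header):
// combining TFShapeToMklDnnDims() with MklDnnType<T>() to describe a TF
// tensor whose dimension order matches the given oneDNN format tag.
template <typename T>
inline memory::desc ExampleDescFromTfShape(const TensorShape& shape,
                                           memory::format_tag tag) {
  return memory::desc(TFShapeToMklDnnDims(shape), MklDnnType<T>(), tag);
}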
1114 
1115 /// Map TensorShape object into memory::dims in NCHW format required by oneDNN
1116 ///
1117 /// This function is more specific than the one above. It maps the input
1118 /// TensorShape into oneDNN dims in NCHW format, so it may not preserve the
1119 /// order of dimensions. E.g., if the input tensor is in NHWC format, the
1120 /// returned dims will be in NCHW order, not NHWC.
1121 ///
1122 /// @input TensorShape object in shape
1123 /// @return memory::dims in oneDNN required NCHW format
1124 inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
1125                                               TensorFormat format) {
1126   // Check validity of format.
1127   DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1128             MklTensorFormat::FORMAT_INVALID);
1129 
1130   int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
1131   int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
1132   int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
1133   int w = shape.dim_size(GetTensorDimIndex(format, 'W'));
1134 
1135   // oneDNN requires dimensions in NCHW format.
1136   return memory::dims({n, c, h, w});
1137 }
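// Worked example (illustrative, not part of the original header): for an NHWC
// TensorShape of {8, 224, 224, 3}, TFShapeToMklDnnDimsInNCHW(shape,
// FORMAT_NHWC) returns {8, 3, 224, 224}, i.e. the dimensions reordered into
// the NCHW order that oneDNN expects.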
1138 
1139 inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
1140                                                TensorFormat format) {
1141   // Validate format.
1142   DCHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
1143             MklTensorFormat::FORMAT_INVALID);
1144 
1145   int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
1146   int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
1147   int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
1148   int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
1149   int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
1150 
1151   // oneDNN requires dimensions in NCDHW format.
1152   return memory::dims({n, c, d, h, w});
1153 }
1154 
1155 /// Overloaded version of function TFShapeToMklDnnDimsInNCHW above.
1156 /// Input parameters are self-explanatory.
1157 inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
1158                                      TensorFormat format) {
1159   // Validate format.
1160   DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1161             MklTensorFormat::FORMAT_INVALID);
1162 
1163   int n = in_dims[GetTensorDimIndex(format, 'N')];
1164   int c = in_dims[GetTensorDimIndex(format, 'C')];
1165   int h = in_dims[GetTensorDimIndex(format, 'H')];
1166   int w = in_dims[GetTensorDimIndex(format, 'W')];
1167 
1168   // oneDNN requires dimensions in NCHW format.
1169   return memory::dims({n, c, h, w});
1170 }
1171 
1172 /// Overloaded version of function TFShapeToMklDnnDimsInNCDHW above.
1173 /// Input parameters are self-explanatory.
1174 inline memory::dims MklDnnDimsInNCDHW(const memory::dims& in_dims,
1175                                       TensorFormat format) {
1176   // Validate format.
1177   DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1178             MklTensorFormat::FORMAT_INVALID);
1179 
1180   int n = in_dims[GetTensorDimIndex<3>(format, 'N')];
1181   int c = in_dims[GetTensorDimIndex<3>(format, 'C')];
1182   int d = in_dims[GetTensorDimIndex<3>(format, '0')];
1183   int h = in_dims[GetTensorDimIndex<3>(format, '1')];
1184   int w = in_dims[GetTensorDimIndex<3>(format, '2')];
1185 
1186   // oneDNN requires dimensions in NCDHW format.
1187   return memory::dims({n, c, d, h, w});
1188 }
1189 
1190 /// Map MklDnn memory::dims object into TensorShape object.
1191 /// Map a oneDNN memory::dims object into a TensorShape object.
1192 ///
1193 /// This function simply maps the input shape, given in oneDNN memory::dims
1194 /// format, into Tensorflow's TensorShape object, preserving dimension order.
1195 /// @input oneDNN memory::dims object
1196 /// @output TensorShape corresponding to memory::dims
1197 inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) {
1198   std::vector<int32> shape(dims.size(), -1);
1199   for (int d = 0; d < dims.size(); d++) {
1200     shape[d] = dims[d];
1201   }
1202 
1203   TensorShape ret;
1204   CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true);
1205   return ret;
1206 }
1207 
1208 /// Function to calculate strides given tensor shape in Tensorflow order
1209 /// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention,
1210 /// dimension with size 1 is outermost dimension; while dimension with size 4 is
1211 /// innermost dimension. So strides for this tensor would be {4 * 3 * 2,
1212 /// 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}.
1213 ///
1214 /// @input Tensorflow shape in memory::dims type
1215 /// @return memory::dims containing strides for the tensor.
1216 inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
1217   CHECK_GT(dims_tf_order.size(), 0);
1218   memory::dims strides(dims_tf_order.size());
1219   int last_dim_idx = dims_tf_order.size() - 1;
1220   strides[last_dim_idx] = 1;
1221   for (int d = last_dim_idx - 1; d >= 0; d--) {
1222     strides[d] = strides[d + 1] * dims_tf_order[d + 1];
1223   }
1224   return strides;
1225 }
1226 
1227 /// Helper function to create memory descriptor in Blocked format
1228 ///
1229 /// @input: Tensor dimensions
1230 /// @input: strides corresponding to dimensions. One can use utility
1231 ///         function such as CalculateTFStrides to compute strides
1232 ///         for given dimensions.
1233 /// @output: dnnl_memory_desc_t object corresponding to blocked memory
1234 ///          format for given dimensions and strides.
1235 /// @return: Status indicating whether the blocked memory descriptor
1236 ///          was successfully created.
1237 inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
1238                                          const memory::dims& strides,
1239                                          memory::data_type dtype,
1240                                          dnnl_memory_desc_t* blocked_md) {
1241   DCHECK_EQ(dim.size(), strides.size());
1242   const int kNumDims = dim.size();
1243   dnnl_dim_t* input_dims = new dnnl_dim_t[kNumDims];
1244   dnnl_dim_t* input_strides = new dnnl_dim_t[kNumDims];
1245   for (int i = 0; i < kNumDims; ++i) {
1246     input_dims[i] = dim[i];
1247     input_strides[i] = strides[i];
1248   }
1249   try {
1250     dnnl_memory_desc_init_by_strides(blocked_md, kNumDims, input_dims,
1251                                      memory::convert_to_c(dtype),
1252                                      input_strides);
1253     delete[] input_dims;
1254     delete[] input_strides;
1255   } catch (dnnl::error& e) {
1256     delete[] input_dims;
1257     delete[] input_strides;
1258     return Status(error::Code::INTERNAL,
1259                   tensorflow::strings::StrCat(
1260                       "Failed to create blocked memory descriptor.",
1261                       "Status: ", e.status, ", message: ", e.message));
1262   }
1263   return Status::OK();
1264 }
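// Illustrative sketch (hypothetical helper, not part of the original header):
// building a blocked memory::desc for a tensor stored in TF (row-major) order
// by combining CalculateTFStrides() with CreateBlockedMemDescHelper().
// MklDnnData<T>::CreateBlockedMemDesc() below is the type-parameterized
// equivalent.
inline memory::desc ExampleBlockedDescForTfOrder(const memory::dims& dims,
                                                 memory::data_type dtype) {
  dnnl_memory_desc_t blocked_md;
  TF_CHECK_OK(CreateBlockedMemDescHelper(dims, CalculateTFStrides(dims), dtype,
                                         &blocked_md));
  return memory::desc(blocked_md);
}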
1265 
1266 inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
1267                                     const memory& src_mem,
1268                                     const memory& dst_mem, const engine& engine,
1269                                     OpKernelContext* ctx = nullptr) {
1270   std::vector<primitive> net;
1271   net.push_back(dnnl::reorder(reorder_desc));
1272   std::vector<MemoryArgsMap> net_args;
1273   net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}});
1274   ExecutePrimitive(net, &net_args, engine, ctx);
1275 }
1276 
1277 class MklReorderPrimitive;
1278 
1279 template <typename T>
1280 inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
1281                                                 const memory* to);
1282 
1283 // Class to represent all the resources corresponding to a tensor in TensorFlow
1284 // that are required to execute an operation (such as Convolution).
1285 template <typename T>
1286 class MklDnnData {
1287  private:
1288   /// oneDNN memory primitive for input user memory
1289   memory* user_memory_;
1290 
1291   /// oneDNN memory primitive in case input or output reorder is needed.
1292   memory* reorder_memory_;
1293 
1294   /// Memory descriptor for the operation
1295   memory::desc* op_md_;
1296   // Flag to indicate whether the data is 3D or not.
1297   bool bIs3D;
1298   /// Temporary buffer for the operation
1299   void* allocated_buffer_;
1300   /// CPU engine on which operation will be executed
1301   const engine* cpu_engine_;
1302 
1303  public:
1304   explicit MklDnnData(const engine* e)
1305       : user_memory_(nullptr),
1306         reorder_memory_(nullptr),
1307         op_md_(nullptr),
1308         bIs3D(false),
1309         allocated_buffer_(nullptr),
1310         cpu_engine_(e) {}
1311 
1312   // MklDnnData does not use any smart pointers,
1313   // hence the default operator= would result in a memory leak if
1314   // user_memory_ was already initialized. See
1315   // https://github.com/tensorflow/tensorflow/pull/45593 for an example of
1316   // such a leak.
1317   MklDnnData(const MklDnnData&) = default;
1318   MklDnnData& operator=(const MklDnnData&) = delete;
1319 
1320   ~MklDnnData() {
1321     if (allocated_buffer_ != nullptr) {
1322       cpu_allocator()->DeallocateRaw(allocated_buffer_);
1323     }
1324     cpu_engine_ = nullptr;  // We don't own this.
1325     delete (user_memory_);
1326     delete (reorder_memory_);
1327     delete (op_md_);
1328   }
1329 
1330   inline void* GetTensorBuffer(const Tensor* tensor) const {
1331     CHECK_NOTNULL(tensor);
1332     return const_cast<void*>(
1333         static_cast<const void*>(tensor->flat<T>().data()));
1334   }
1335 
1336   void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
1337   bool GetIs3D() { return bIs3D; }
1338 
1339   /// Set user memory primitive using specified dimensions, memory format tag
1340   /// and data_buffer. The function automatically derives the element data
1341   /// type from the template type T used to create the calling object.
1342   ///
1343   /// In a nutshell, this function allows the user to describe the input
1344   /// tensor to an operation. E.g., the filter of Conv2D has shape
1345   /// {1, 2, 3, 4} and memory format tag HWIO, and the buffer that contains
1346   /// the actual values is pointed to by data_buffer.
1347   inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
1348                         void* data_buffer = nullptr) {
1349     auto md = memory::desc(dim, MklDnnType<T>(), fm);
1350     SetUsrMem(md, data_buffer);
1351   }
1352 
1353   inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
1354                         const Tensor* tensor) {
1355     DCHECK(tensor);
1356     SetUsrMem(dim, fm, GetTensorBuffer(tensor));
1357   }
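
  // Example (illustrative sketch) of describing a user input tensor with the
  // SetUsrMem overloads above; `cpu_engine` and `input_tensor` are assumed to
  // exist in the calling kernel and the dimension values are hypothetical.
  //
  //   MklDnnData<float> input(&cpu_engine);
  //   memory::dims input_dims = {8, 3, 224, 224};  // logical NCHW order
  //   input.SetUsrMem(input_dims, memory::format_tag::nhwc, input_tensor);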
1358 
1359   /// Helper function to create memory descriptor in Blocked format
1360   ///
1361   /// @input: Tensor dimensions
1362   /// @input: strides corresponding to dimensions. One can use utility
1363   ///         function such as CalculateTFStrides to compute strides
1364   ///         for given dimensions.
1365   /// @return: memory::desc object corresponding to blocked memory format
1366   ///          for given dimensions and strides.
1367   static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim,
1368                                                   const memory::dims& strides) {
1369     dnnl_memory_desc_t blocked_md;
1370     TF_CHECK_OK(
1371         CreateBlockedMemDescHelper(dim, strides, MklDnnType<T>(), &blocked_md));
1372     return memory::desc(blocked_md);
1373   }
1374 
1375   /// A version of SetUsrMem that allows the user to create memory in blocked
1376   /// format. In addition to accepting dimensions, it also accepts strides.
1377   /// This allows the user to create memory for a tensor in a format that
1378   /// oneDNN does not support natively. E.g., oneDNN has no native format for
1379   /// a 6-dimensional tensor, but by using the blocked format a user can
1380   /// still create memory for a 6D tensor.
1381   inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
1382                         void* data_buffer = nullptr) {
1383     CHECK_EQ(dim.size(), strides.size());
1384     auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides);
1385     SetUsrMem(blocked_md, data_buffer);
1386   }
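
  // Example (illustrative sketch) of describing a 6D tensor via the blocked
  // format, since oneDNN has no native 6D format tag; `cpu_engine` and
  // `buffer` are assumed to exist and the dimension values are hypothetical.
  //
  //   MklDnnData<float> data(&cpu_engine);
  //   memory::dims dims6d = {2, 3, 4, 5, 6, 7};
  //   memory::dims strides6d = CalculateTFStrides(dims6d);
  //   data.SetUsrMem(dims6d, strides6d, buffer);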
1387 
1388   inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
1389                         const Tensor* tensor) {
1390     CHECK_NOTNULL(tensor);
1391     SetUsrMem(dim, strides, GetTensorBuffer(tensor));
1392   }
1393 
1394   /// A version of SetUsrMem with memory descriptor and tensor
1395   inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) {
1396     CHECK_NOTNULL(tensor);
1397     SetUsrMem(md, GetTensorBuffer(tensor));
1398   }
1399 
1400   /// A version of the user-memory setter that accepts a memory descriptor
1401   /// directly, instead of dimensions and a format tag. This function is more
1402   /// generic than the ones above, but the functions above are sufficient in
1403   /// most cases.
1404   inline void SetUsrMem(const memory::desc& pd, void* data_buffer = nullptr) {
1405     DCHECK(cpu_engine_);
1406     if (user_memory_) delete user_memory_;
1407     // TODO(intel-tf): can we remove dynamic memory allocation?
1408     if (data_buffer) {
1409       user_memory_ = new memory(pd, *cpu_engine_, data_buffer);
1410     } else {
1411       user_memory_ = new memory(pd, *cpu_engine_);
1412     }
1413   }
1414 
1415   /// Get function for user memory primitive.
1416   inline const memory* GetUsrMem() const { return user_memory_; }
1417 
1418   /// Get function for descriptor of user memory.
1419   inline memory::desc GetUsrMemDesc() const {
1420     DCHECK(user_memory_);
1421     return user_memory_->get_desc();
1422   }
1423 
1424   /// Get function for data buffer of user memory primitive.
1425   inline void* GetUsrMemDataHandle() const {
1426     CHECK_NOTNULL(user_memory_);
1427     return user_memory_->get_data_handle();
1428   }
1429 
1430   /// Set function for data buffer of user memory primitive.
1431   inline void SetUsrMemDataHandle(void* data_buffer,
1432                                   std::shared_ptr<stream> t_stream = nullptr) {
1433     CHECK_NOTNULL(user_memory_);
1434     CHECK_NOTNULL(data_buffer);
1435 #ifndef ENABLE_ONEDNN_OPENMP
1436     user_memory_->set_data_handle(data_buffer, *t_stream);
1437 #else
1438     user_memory_->set_data_handle(data_buffer);
1439 #endif  // !ENABLE_ONEDNN_OPENMP
1440   }
1441 
1442   /// Set function for data buffer of user memory primitive.
1443   inline void SetUsrMemDataHandle(const Tensor* tensor,
1444                                   std::shared_ptr<stream> t_stream = nullptr) {
1445     SetUsrMemDataHandle(GetTensorBuffer(tensor), t_stream);
1446   }
1447 
1448   /// allocate function for data buffer
1449   inline void AllocateBuffer(size_t size) {
1450     const int64 kMemoryAlignment = 64;  // For AVX512 memory alignment.
1451     allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlignment, size);
1452   }
1453 
1454   inline void* GetAllocatedBuffer() { return allocated_buffer_; }
1455 
1456   /// Get the memory primitive for input and output of an op. If inputs
1457   /// to an op require reorders, then this function returns memory primitive
1458   /// for reorder. Otherwise, it will return memory primitive for user memory.
1459   ///
1460   /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
1461   /// execute Conv2D, we need memory primitive for I and F. But if reorder is
1462   /// required for I and F (say I_r is reorder primitive for I; F_r is reorder
1463   /// primitive for F), then we need I_r and F_r to perform Conv2D.
1464   inline const memory& GetOpMem() const {
1465     return reorder_memory_ ? *reorder_memory_ : *user_memory_;
1466   }
1467 
1468   /// Set memory descriptor of an operation in terms of dimensions and memory
1469   /// format. E.g., for Conv2D, the dimensions would be the same as the user
1470   /// dimensions, but memory::format_tag would be dnnl::any because we want
1471   /// oneDNN to choose the best layout/format for the given input dimensions.
1472   inline void SetOpMemDesc(const memory::dims& dim, memory::format_tag fm) {
1473     // TODO(intel-tf): can we remove dynamic memory allocation?
1474     op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
1475   }
1476 
1477   /// Get function for memory descriptor for an operation
1478   inline const memory::desc& GetOpMemDesc() const { return *op_md_; }
1479 
1480   /// Predicate that checks if we need to reorder user's memory into memory
1481   /// pointed by op_md.
1482   ///
1483   /// @input: op_md - memory descriptor of the given input of an operation.
1484   /// @return: true in case reorder of input is needed; false, otherwise.
1485   inline bool IsReorderNeeded(const memory::desc& op_pd) const {
1486     DCHECK(user_memory_);
1487     return op_pd != user_memory_->get_desc();
1488   }
1489 
1490   /// Function to create a reorder from the memory pointed to by `from` to the
1491   /// memory pointed to by `to`. Returns the created primitive.
1492   inline primitive CreateReorder(const memory* from, const memory* to) const {
1493     CHECK_NOTNULL(from);
1494     CHECK_NOTNULL(to);
1495     return reorder(*from, *to);
1496   }
1497 
1498   /// Function to handle input reordering
1499   ///
1500   /// Check if we need to reorder this input of an operation.
1501   /// Return true and allocate reorder memory primitive if reorder is needed.
1502   /// Otherwise, return false and do not allocate reorder memory primitive.
1503   ///
1504   /// To check if reorder is needed, this function compares memory primitive
1505   /// descriptor (memory descriptor for v1.x) of an operation (op_pd) for
1506   /// the given input with the user-specified memory descriptor.
1507   ///
1508   /// @input: op_pd - memory primitive descriptor of the given input of an
1509   ///                 operation
1510   /// @input: net - net to which to add reorder primitive in case it is needed.
1511   /// @input: net_args - net to which user and reorder memories are added if
1512   ///                    needed. Each entry is a key-value pair of the form
1513   ///                    <argument-type, dnnl::memory>.
1514   /// @return: true in case reorder of input is needed; false, otherwise.
1515   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1516                                   std::vector<primitive>& net,
1517                                   std::vector<MemoryArgsMap>& net_args,
1518                                   const engine& engine) {
1519     DCHECK(user_memory_);
1520     DCHECK_EQ(net.size(), net_args.size());
1521     if (IsReorderNeeded(op_md)) {
1522       // TODO(intel-tf): can we remove dynamic memory allocation?
1523       reorder_memory_ = new memory(op_md, engine);
1524       net.push_back(CreateReorder(user_memory_, reorder_memory_));
1525       net_args.push_back(MemoryArgsMap{{DNNL_ARG_FROM, *user_memory_},
1526                                        {DNNL_ARG_TO, *reorder_memory_}});
1527       return true;
1528     }
1529     return false;
1530   }
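
  // Example (illustrative sketch) of the input-reorder check above within a
  // kernel; `conv_prim_desc`, `cpu_engine` and `ctx` are hypothetical names
  // from the calling kernel.
  //
  //   std::vector<primitive> net;
  //   std::vector<MemoryArgsMap> net_args;
  //   input.CheckReorderToOpMem(conv_prim_desc.src_desc(), net, net_args,
  //                             cpu_engine);
  //   // Use input.GetOpMem() as the primitive's DNNL_ARG_SRC, then:
  //   ExecutePrimitive(net, &net_args, cpu_engine, ctx);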
1531 
1532   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1533                                   const engine& engine,
1534                                   OpKernelContext* context = nullptr) {
1535     DCHECK(user_memory_);
1536     if (IsReorderNeeded(op_md)) {
1537       // TODO(intel-tf): can we remove dynamic memory allocation?
1538       // Primitive reuse does not allow two identical reorder primitives in
1539       // one stream, so submit it immediately.
1540       reorder_memory_ = new memory(op_md, engine);
1541       auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
1542       std::shared_ptr<stream> cpu_stream;
1543       MklDnnThreadPool eigen_tp;
1544       if (context != nullptr) {
1545         eigen_tp = MklDnnThreadPool(context);
1546         cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
1547       } else {
1548         cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
1549       }
1550       std::vector<primitive> net;
1551       net.push_back(*(prim->GetPrimitive()));
1552       std::vector<MemoryArgsMap> net_args;
1553       net_args.push_back(
1554           {{DNNL_ARG_FROM, *user_memory_}, {DNNL_ARG_TO, *reorder_memory_}});
1555       execute_primitives(net, cpu_stream, net_args);
1556       return true;
1557     }
1558     return false;
1559   }
1560 
1561   /// Overloaded version of above function that accepts memory buffer
1562   /// where output of reorder needs to be stored.
1563   ///
1564   /// @input: op_pd - memory primitive descriptor (memory descriptor for v1.x)
1565   ///                 of the given input of an operation
1566   /// @reorder_data_handle - memory buffer where output of reorder needs to be
1567   ///                        stored. Primitive does not check if buffer has
1568   ///                        enough size to write.
1569   /// @input: net - net to which to add reorder primitive in case it is needed.
1570   /// @input: net_args - net to which user and reorder memories are added if
1571   ///                    needed. Each entry is a key-value pair of the form
1572   ///                    <argument-type, dnnl::memory>.
1573   /// @input: engine - oneDNN's abstraction of a computational device
1574   /// @return: true in case reorder of input is needed; false, otherwise.
1575   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1576                                   void* reorder_data_handle,
1577                                   std::vector<primitive>& net,
1578                                   std::vector<MemoryArgsMap>& net_args,
1579                                   const engine& engine) {
1580     DCHECK(reorder_data_handle);
1581     DCHECK(user_memory_);
1582     if (IsReorderNeeded(op_md)) {
1583       // TODO(intel-tf): can we remove dynamic memory allocation?
1584       reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
1585       net.push_back(CreateReorder(user_memory_, reorder_memory_));
1586       net_args.push_back(MemoryArgsMap{{DNNL_ARG_FROM, *user_memory_},
1587                                        {DNNL_ARG_TO, *reorder_memory_}});
1588       return true;
1589     }
1590     return false;
1591   }
1592 
1593   /// This is a faster path, with a reorder primitive cache, compared with
1594   /// CheckReorderToOpMem(..., std::vector<primitive>& net).
1595   /// The slower path will be removed in the future.
1596   /// TODO(intel-tf): Need to use reorder cache here for better performance.
1597   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1598                                   void* reorder_data_handle,
1599                                   const engine& engine,
1600                                   OpKernelContext* context = nullptr) {
1601     DCHECK(reorder_data_handle);
1602     DCHECK(user_memory_);
1603     if (IsReorderNeeded(op_md)) {
1604       // TODO(intel-tf): can we remove dynamic memory allocation?
1605       // Primitive reuse does not allow two identical reorder primitives in
1606       // one stream, so submit it immediately.
1607       reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
1608       auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
1609       std::shared_ptr<stream> cpu_stream;
1610       MklDnnThreadPool eigen_tp;
1611       if (context != nullptr) {
1612         eigen_tp = MklDnnThreadPool(context);
1613         cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
1614       } else {
1615         cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
1616       }
1617       std::vector<primitive> net;
1618       net.push_back(*(prim->GetPrimitive()));
1619       std::vector<MemoryArgsMap> net_args;
1620       net_args.push_back(
1621           {{DNNL_ARG_FROM, *user_memory_}, {DNNL_ARG_TO, *reorder_memory_}});
1622       execute_primitives(net, cpu_stream, net_args);
1623       return true;
1624     }
1625     return false;
1626   }
1627 
1628   /// Another overloaded version of CheckReorderToOpMem that accepts Tensor
1629   /// where output of reorder needs to be stored.
1630   ///
1631   /// @input: op_md - memory primitive descriptor (memory descriptor for v1.x)
1632   ///                 of the given input of an operation
1633   /// @reorder_tensor - Tensor whose buffer is to be used to store the output
1634   ///                   of the reorder. The primitive does not check if the
1635   ///                   buffer has enough size to write.
1636   /// @input: net - net to which to add reorder primitive in case it is needed.
1637   /// @input: net_args - net to which user and reorder memories are added if
1638   ///                    needed. Each entry is a key-value pair of the form
1639   ///                    <argument-type, dnnl::memory>.
1640   /// @input: engine - oneDNN's abstraction of a computational device
1641   /// @return: true in case reorder of input is needed; false, otherwise.
1642   inline bool CheckReorderToOpMem(const memory::desc& op_md,
1643                                   Tensor* reorder_tensor,
1644                                   std::vector<primitive>& net,
1645                                   std::vector<MemoryArgsMap>& net_args,
1646                                   const engine& engine) {
1647     DCHECK(reorder_tensor);
1648     return CheckReorderToOpMem(op_md, GetTensorBuffer(reorder_tensor), net,
1649                                net_args, engine);
1650   }
1651 
1652   /// TODO(intel-tf): this is a faster path with reorder primitive cache
1653   /// compared with CheckReorderToOpMem(op_md, reorder_tensor, net, net_args,
1654   /// engine), will remove slow path in the future.
1655   inline bool CheckReorderToOpMem(const memory::desc& op_pd,
1656                                   Tensor* reorder_tensor,
1657                                   OpKernelContext* ctx = nullptr) {
1658     DCHECK(reorder_tensor);
1659     return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
1660                                *cpu_engine_, ctx);
1661   }
1662 
1663   /// Function to handle output reorder
1664   ///
1665   /// This function performs functionality very similar to the input
1666   /// reordering function above. The only difference is that this function
1667   /// does not add the reorder primitive to the net, because the reorder
1668   /// primitive for the output can be added to the list only after the
1669   /// operation has executed. However, we still need to prepare a temporary
1670   /// buffer in case an output reorder is needed; this temporary buffer holds
1671   /// the output of the operation before it is fed to the reorder primitive.
1672   ///
1673   /// @input - memory primitive descriptor (memory descriptor for v1.x) for the
1674   ///          given output of an operation
1675   /// @return: true in case reorder of output is needed; false, otherwise.
1676   inline bool PrepareReorderToUserMemIfReq(const memory::desc& op_pd) {
1677     DCHECK(user_memory_);
1678     if (IsReorderNeeded(op_pd)) {
1679       // TODO(intel-tf): can we remove dynamic memory allocation?
1680       reorder_memory_ = new memory(op_pd, *cpu_engine_);
1681       return true;
1682     }
1683     return false;
1684   }
1685 
1686   /// Function to actually insert reorder primitive in the net
1687   ///
1688   /// This function completes the remaining part of output reordering. It
1689   /// inserts a reorder primitive from the temporary buffer that holds the
1690   /// output to the user-specified output buffer.
1691   ///
1692   /// @input: net - net to which to add reorder primitive
1693   /// @input: net_args - net to which user and reorder memories are added if
1694   ///                    needed. Each entry is a key-value pair of the form
1695   ///                    <argument-type, dnnl::memory>.
1696   inline void InsertReorderToUserMem(std::vector<primitive>& net,
1697                                      std::vector<MemoryArgsMap>& net_args) {
1698     DCHECK(user_memory_);
1699     DCHECK(reorder_memory_);
1700     net.push_back(CreateReorder(reorder_memory_, user_memory_));
1701     net_args.push_back(MemoryArgsMap{{DNNL_ARG_FROM, *reorder_memory_},
1702                                      {DNNL_ARG_TO, *user_memory_}});
1703   }
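
  // Example (illustrative sketch) of the output-reorder flow using
  // PrepareReorderToUserMemIfReq and InsertReorderToUserMem above;
  // `cpu_engine`, `user_output_md`, `output_tensor`, `conv_prim_desc`, `net`
  // and `net_args` are hypothetical names from the calling kernel.
  //
  //   MklDnnData<float> output(&cpu_engine);
  //   output.SetUsrMem(user_output_md, output_tensor);
  //   if (output.PrepareReorderToUserMemIfReq(conv_prim_desc.dst_desc())) {
  //     // Run the op so that it writes into output.GetOpMem(), then:
  //     output.InsertReorderToUserMem(net, net_args);
  //   }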
1704 
1705   /// TODO(intel-tf): this is a faster path with reorder primitive cache
1706   ///     compared with InsertReorderToUserMem(net, net_args), will remove
1707   ///     slow path in the future
1708   inline void InsertReorderToUserMem(OpKernelContext* ctx = nullptr) {
1709     DCHECK(user_memory_);
1710     DCHECK(reorder_memory_);
1711     DCHECK(cpu_engine_);
1712     // Primitive reuse does not allow two identical reorder primitives in
1713     // one stream, so submit it immediately.
1714     std::vector<primitive> net;
1715     auto* prim = FindOrCreateReorder<T>(reorder_memory_, user_memory_);
1716     net.push_back(*(prim->GetPrimitive()));
1717     std::vector<MemoryArgsMap> net_args;
1718     net_args.push_back(
1719         {{DNNL_ARG_FROM, *reorder_memory_}, {DNNL_ARG_TO, *user_memory_}});
1720     std::shared_ptr<stream> cpu_stream;
1721     MklDnnThreadPool eigen_tp;
1722     if (ctx != nullptr) {
1723       eigen_tp = MklDnnThreadPool(ctx);
1724       cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
1725     } else {
1726       cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
1727     }
1728     execute_primitives(net, cpu_stream, net_args);
1729   }
1730 };
1731 
1732 /// Base class for operations with reuse of primitives
1733 class MklPrimitive {
1734  public:
1735   virtual ~MklPrimitive() {}
1736   MklPrimitive() {}
1737   MklPrimitive(const engine& cpu_engine) { cpu_engine_ = cpu_engine; }
1738   // Dummy data which MKL DNN never operates on
1739   unsigned char* DummyData = nullptr;
1740   engine cpu_engine_ = engine(engine::kind::cpu, 0);
1741   const engine& GetEngine() { return cpu_engine_; }
1742 };
1743 
1744 const dnnl::memory::dims NONE_DIMS = {};
1745 
1746 //
1747 // LRUCache is a class which implements LRU (Least Recently Used) cache.
1748 // The implementation is similar to that of
1749 //    tensorflow/core/platform/cloud/expiring_lru_cache.h
1750 // without its thread-safe part because the cache is supposed to be
1751 // used as thread local (for instance, MklPrimitive caching).
1752 //
1753   // The LRU list maintains objects in order of most recent access,
1754   // with the least recently accessed object at the tail of the
1755   // LRU list and the most recently accessed object at the head of
1756   // the LRU list.
1757 //
1758 // This class is used to maintain an upper bound on the total number of
1759 // cached items. When the cache reaches its capacity, the LRU item will
1760 // be removed and replaced by a new one from SetOp call.
1761 //
1762 template <typename T>
1763 class LRUCache {
1764  public:
1765   explicit LRUCache(size_t capacity) {
1766     capacity_ = capacity;
1767     Clear();
1768   }
1769 
1770   T* GetOp(const string& key) {
1771 #ifdef DNNL_AARCH64_USE_ACL
1772     mutex_lock lock(lru_mu_);
1773 #endif
1774     auto it = cache_.find(key);
1775     if (it == cache_.end()) {
1776       return nullptr;
1777     }
1778 
1779     // Move to the front of LRU list as the most recently accessed.
1780     lru_list_.erase(it->second.lru_iterator);
1781     lru_list_.push_front(it->first);
1782     it->second.lru_iterator = lru_list_.begin();
1783     return it->second.op;
1784   }
1785 
1786   void SetOp(const string& key, T* op) {
1787 #ifdef DNNL_AARCH64_USE_ACL
1788     mutex_lock lock(lru_mu_);
1789 #endif
1790     if (lru_list_.size() >= capacity_) {
1791       Delete();
1792     }
1793 
1794     // Insert an entry to the front of the LRU list
1795     lru_list_.push_front(key);
1796     Entry entry(op, lru_list_.begin());
1797     cache_.emplace(std::make_pair(key, std::move(entry)));
1798 #ifdef DNNL_AARCH64_USE_ACL
1799     FinishedAllocation(key);
1800 #endif
1801   }
1802 
1803   void Clear() {
1804     if (lru_list_.empty()) return;
1805 
1806     // Clean up the cache
1807     cache_.clear();
1808     lru_list_.clear();
1809   }
1810 
1811 #ifdef DNNL_AARCH64_USE_ACL
1812   bool IsAllocating(const string& key) {
1813     mutex_lock lock(in_flight_mu_);
1814     return in_flight_.find(key) != in_flight_.end();
1815   }
1816 
1817   void Allocate(const string& key) {
1818     mutex_lock lock(in_flight_mu_);
1819     in_flight_.insert(key);
1820   }
1821 
1822   void FinishedAllocation(const string& key) {
1823     mutex_lock lock(in_flight_mu_);
1824     in_flight_.erase(key);
1825   }
1826 #endif
1827 
1828  private:
1829   struct Entry {
1830     // The entry's value.
1831     T* op;
1832 
1833     // A list iterator pointing to the entry's position in the LRU list.
1834     std::list<string>::iterator lru_iterator;
1835 
1836     // Constructor
1837     Entry(T* op, std::list<string>::iterator it) {
1838       this->op = op;
1839       this->lru_iterator = it;
1840     }
1841 
1842     // Move constructor
1843     Entry(Entry&& source) noexcept
1844         : lru_iterator(std::move(source.lru_iterator)) {
1845       op = std::move(source.op);
1846       source.op = std::forward<T*>(nullptr);
1847     }
1848 
1849     // Destructor
1850     ~Entry() {
1851       if (op != nullptr) delete op;
1852     }
1853   };
1854 
1855   // Remove the least recently accessed entry from LRU list, which
1856   // is the tail of lru_list_. Update cache_ correspondingly.
1857   bool Delete() {
1858     if (lru_list_.empty()) return false;
1859     string key = lru_list_.back();
1860     lru_list_.pop_back();
1861     cache_.erase(key);
1862     return true;
1863   }
1864 
1865   // Cache capacity
1866   size_t capacity_;
1867 
1868   // The cache, a map from string key to a LRU entry.
1869   std::unordered_map<string, Entry> cache_;
1870 
1871   // The LRU list of entries.
1872   // The front of the list contains the key of the most recently accessed
1873   // entry, while the back of the list is the least recently accessed entry.
1874   std::list<string> lru_list_;
1875 
1876 #ifdef DNNL_AARCH64_USE_ACL
1877   // Guards access to the cache and LRU list
1878   mutex lru_mu_;
1879 
1880   // The keys that are currently under creation
1881   std::set<string> in_flight_ TF_GUARDED_BY(in_flight_mu_);
1883   mutex in_flight_mu_;
1884 #endif
1885 };
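
// Example (illustrative sketch) of using LRUCache directly; `some_primitive`
// is a hypothetical MklPrimitive*. Note that the cache takes ownership of the
// stored pointer and deletes it on eviction.
//
//   LRUCache<MklPrimitive> cache(8);
//   cache.SetOp("conv2d_fwd_f32", some_primitive);
//   MklPrimitive* hit = cache.GetOp("conv2d_fwd_f32");  // nullptr on a miss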
1886 
1887 template <typename T>
1888 class MklPrimitiveFactory {
1889  public:
1890   MklPrimitiveFactory() {}
1891 
1892   ~MklPrimitiveFactory() {}
1893 
1894   MklPrimitive* GetOp(const string& key) {
1895 #ifndef DNNL_AARCH64_USE_ACL
1896     auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1897     return lru_cache.GetOp(key);
1898 #else
1899     while (true) {
1900       // TODO(milpuz01): Consider if it is possible to narrow scope to be
1901       // only around checks for allocations and conditional wait.
1902       mutex_lock lock(primitive_creation_mu_);
1903       auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1904 
1905       // Check to see whether primitive already exists.
1906       MklPrimitive* primitive = lru_cache.GetOp(key);
1907       if (primitive != nullptr) {
1908         return primitive;
1909       }
1910 
1911       // Now check whether some other thread is creating this primitive.
1912       if (!lru_cache.IsAllocating(key)) {
1913         // This thread is going to pick it up and create the primitive.
1914         lru_cache.Allocate(key);
1915         return nullptr;
1916         // Now we release lock as primitive creation might take long time.
1917       }
1918 
1919       // At this point we cannot create primitive as other thread is creating
1920       // it. We should wait for primitive to get created.
1921       primitive_creation_cv_.wait(lock);
1922 
1923       // The primitive is created and is in the cache so we are going to try
1924       // retrieve it again after getting a lock on it as multiple threads might
1925       // be waiting for the primitive.
1926     }
1927 #endif
1928   }
1929 
1930   void SetOp(const string& key, MklPrimitive* op) {
1931 #ifndef DNNL_AARCH64_USE_ACL
1932     auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1933     lru_cache.SetOp(key, op);
1934 #else
1935     {
1936       mutex_lock lock(primitive_creation_mu_);
1937       auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1938       lru_cache.SetOp(key, op);
1939     }
1940 
1941     // Now we can inform all waiting threads that primitive is created.
1942     primitive_creation_cv_.notify_all();
1943 #endif
1944   }
1945 
1946   /// Function to decide whether the HW has AVX512 or AVX2.
1947   /// For legacy devices (without AVX512 and AVX2),
1948   /// MKL-DNN GEMM will be used.
1949   static inline bool IsLegacyPlatform() {
1950     static const bool is_legacy_platform =
1951         (!port::TestCPUFeature(port::CPUFeature::AVX512F) &&
1952          !port::TestCPUFeature(port::CPUFeature::AVX2));
1953     return is_legacy_platform;
1954   }
1955 
1956   /// Function to check whether primitive memory optimization is enabled
1957   static inline bool IsPrimitiveMemOptEnabled() {
1958     static const bool is_primitive_mem_opt_enabled = [] {
1959       bool value = true;
1960       TF_CHECK_OK(
1961           ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true, &value));
1962       return value;
1963     }();
1964     return is_primitive_mem_opt_enabled;
1965   }
1966 
1967 #ifdef DNNL_AARCH64_USE_ACL
1968   static int IncrementCounter() {
1969     static std::atomic_int counter{1};
1970     return counter.fetch_add(1);
1971   }
1972 #endif
1973 
1974  private:
1975   static inline LRUCache<MklPrimitive>& GetLRUCache() {
1976     static const int kCapacity = 1024;  // cache capacity
1977 #ifndef DNNL_AARCH64_USE_ACL
1978     static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
1979 #else
1980     static LRUCache<MklPrimitive> lru_cache_(kCapacity);
1981     TF_GUARDED_BY(lru_mu_)
1982 #endif
1983     return lru_cache_;
1984   }
1985 
1986 #ifdef DNNL_AARCH64_USE_ACL
1987   mutex primitive_creation_mu_;
1988   condition_variable primitive_creation_cv_;
1989 #endif
1990 };
1991 
1992 // Utility class for creating keys for the MKL primitive pool.
1993 class FactoryKeyCreator {
1994  public:
1995   FactoryKeyCreator() { key_.reserve(kMaxKeyLength); }
1996 
1997   ~FactoryKeyCreator() {}
1998 
1999   void AddAsKey(const string& str) { Append(str); }
2000 
2001   void AddAsKey(const dnnl::memory::dims& dims) {
2002     for (unsigned int i = 0; i < dims.size(); i++) {
2003       AddAsKey<int>(dims[i]);
2004     }
2005   }
2006 
2007   template <typename T>
2008   void AddAsKey(const T data) {
2009     auto buffer = reinterpret_cast<const char*>(&data);
2010     Append(StringPiece(buffer, sizeof(T)));
2011   }
2012 
2013   // generalisation to handle pointers
2014   void AddAsKey(const void* data) {
2015     auto buffer = reinterpret_cast<const char*>(&data);
2016     Append(StringPiece(buffer, sizeof(data)));
2017   }
2018 
2019   string GetKey() { return key_; }
2020 
2021  private:
2022   string key_;
2023   const char delimiter = 'x';
2024   const int kMaxKeyLength = 256;
2025   void Append(StringPiece s) {
2026     key_.append(string(s));
2027     key_.append(1, delimiter);
2028   }
2029 };
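
// Example (illustrative sketch) of building a primitive-pool key with
// FactoryKeyCreator above; the parameter values are hypothetical.
//
//   FactoryKeyCreator key_creator;
//   string prefix = "conv2d_fwd";
//   memory::dims src_dims = {8, 3, 224, 224};
//   key_creator.AddAsKey(prefix);
//   key_creator.AddAsKey(src_dims);
//   key_creator.AddAsKey<int>(1);  // e.g. a stride
//   string key = key_creator.GetKey();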
2030 
2031 class MklReorderPrimitive : public MklPrimitive {
2032  public:
2033   explicit MklReorderPrimitive(const memory* from, const memory* to)
2034       : MklPrimitive(engine(engine::kind::cpu, 0)) {
2035     Setup(from, to);
2036   }
2037   ~MklReorderPrimitive() {}
2038 
2039   std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
2040 
2041   void SetMemory(const memory* from, const memory* to) {
2042     context_.src_mem->set_data_handle(from->get_data_handle());
2043     context_.dst_mem->set_data_handle(to->get_data_handle());
2044   }
2045 
2046   std::shared_ptr<dnnl::stream> GetStream() { return stream_; }
2047 
2048  private:
2049   struct ReorderContext {
2050     std::shared_ptr<dnnl::memory> src_mem;
2051     std::shared_ptr<dnnl::memory> dst_mem;
2052     std::shared_ptr<primitive> reorder_prim;
2053     ReorderContext()
2054         : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
2055   } context_;
2056 
2057   std::shared_ptr<dnnl::stream> stream_;
2058 
2059   void Setup(const memory* from, const memory* to) {
2060     context_.src_mem.reset(
2061         new memory(from->get_desc(), cpu_engine_, DummyData));
2062     context_.dst_mem.reset(new memory(to->get_desc(), cpu_engine_, DummyData));
2063     context_.reorder_prim = std::make_shared<dnnl::reorder>(
2064         reorder(*context_.src_mem, *context_.dst_mem));
2065     stream_.reset(new stream(cpu_engine_));
2066   }
2067 };
2068 
2069 template <typename T>
2070 class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
2071  public:
2072   static MklReorderPrimitive* Get(const memory* from, const memory* to) {
2073     auto reorderPrim = static_cast<MklReorderPrimitive*>(
2074         MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
2075     if (reorderPrim == nullptr) {
2076       reorderPrim = new MklReorderPrimitive(from, to);
2077       MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to,
2078                                                               reorderPrim);
2079     }
2080     reorderPrim->SetMemory(from, to);
2081     return reorderPrim;
2082   }
2083 
2084   static MklReorderPrimitiveFactory& GetInstance() {
2085     static MklReorderPrimitiveFactory instance_;
2086     return instance_;
2087   }
2088 
2089   static string CreateKey(const memory* from, const memory* to) {
2090     string prefix = "reorder";
2091     FactoryKeyCreator key_creator;
2092     auto const& from_desc = from->get_desc().data;
2093     auto const& to_desc = to->get_desc().data;
2094     memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
2095     memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
2096     auto from_strides = from_desc.format_desc.blocking.strides;
2097 
2098     // The DNNL memory descriptor uses C-style arrays and only initializes
2099     // the used part, so only the valid part is used as the key.
2100     auto from_inner_nblks = from_desc.format_desc.blocking.inner_nblks;
2101     auto from_inner_blks = from_desc.format_desc.blocking.inner_blks;
2102     auto from_inner_idxs = from_desc.format_desc.blocking.inner_idxs;
2103     memory::dims from_inner_blks_1(from_inner_blks,
2104                                    &from_inner_blks[from_inner_nblks]);
2105     memory::dims from_inner_idxs_1(from_inner_idxs,
2106                                    &from_inner_idxs[from_inner_nblks]);
2107     auto to_inner_nblks = to_desc.format_desc.blocking.inner_nblks;
2108     auto to_inner_blks = to_desc.format_desc.blocking.inner_blks;
2109     auto to_inner_idxs = to_desc.format_desc.blocking.inner_idxs;
2110     memory::dims to_inner_blks_1(to_inner_blks, &to_inner_blks[to_inner_nblks]);
2111     memory::dims to_inner_idxs_1(to_inner_idxs, &to_inner_idxs[to_inner_nblks]);
2112 
2113     auto to_strides = to_desc.format_desc.blocking.strides;
2114     memory::dims from_strides_outer_blocks(from_strides,
2115                                            &from_strides[from_desc.ndims]);
2116     memory::dims to_strides_outer_blocks(to_strides,
2117                                          &to_strides[to_desc.ndims]);
2118 
2119     key_creator.AddAsKey(prefix);
2120 #ifdef DNNL_AARCH64_USE_ACL
2121     // The reorder primitives have local memory (calls to SetMemory) so we
2122     // need to make sure that memory for those primitives is cached per thread.
2123     key_creator.AddAsKey(std::this_thread::get_id());
2124 #endif
2125     key_creator.AddAsKey(static_cast<int>(from_desc.extra.flags));
2126     key_creator.AddAsKey(static_cast<int>(from_inner_nblks));
2127     key_creator.AddAsKey(from_inner_blks_1);
2128     key_creator.AddAsKey(from_inner_idxs_1);
2129     key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
2130     key_creator.AddAsKey(from_dims);
2131     key_creator.AddAsKey(from_strides_outer_blocks);
2132     key_creator.AddAsKey(static_cast<int>(to_desc.extra.flags));
2133     key_creator.AddAsKey(static_cast<int>(to_inner_nblks));
2134     key_creator.AddAsKey(to_inner_blks_1);
2135     key_creator.AddAsKey(to_inner_idxs_1);
2136     key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
2137     key_creator.AddAsKey(to_dims);
2138     key_creator.AddAsKey(to_strides_outer_blocks);
2139     return key_creator.GetKey();
2140   }
2141 
2142  private:
2143   MklReorderPrimitiveFactory() {}
2144   ~MklReorderPrimitiveFactory() {}
2145 
2146   MklPrimitive* GetReorder(const memory* from, const memory* to) {
2147     string key = CreateKey(from, to);
2148     return this->GetOp(key);
2149   }
2150 
2151   void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
2152     string key = CreateKey(from, to);
2153     this->SetOp(key, op);
2154   }
2155 };
2156 
2157 /// Function to find (or create) a reorder from the memory pointed to by
2158 /// `from` to the memory pointed to by `to`; it creates the primitive or
2159 /// fetches it from the pool if it is cached.
2160 /// Returns the primitive.
2161 template <typename T>
2162 inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
2163                                                 const memory* to) {
2164   CHECK_NOTNULL(from);
2165   CHECK_NOTNULL(to);
2166   MklReorderPrimitive* reorder_prim =
2167       MklReorderPrimitiveFactory<T>::Get(from, to);
2168   return reorder_prim;
2169 }
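
// Example (illustrative sketch) of using FindOrCreateReorder above;
// `src_mem`, `dst_mem` and `cpu_stream` are hypothetical names from the
// calling kernel.
//
//   MklReorderPrimitive* reorder_prim =
//       FindOrCreateReorder<float>(&src_mem, &dst_mem);
//   std::vector<primitive> net = {*(reorder_prim->GetPrimitive())};
//   std::vector<MemoryArgsMap> net_args = {
//       {{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}}};
//   execute_primitives(net, cpu_stream, net_args);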
2170 
2171 // Utility function to determine whether a convolution is 1x1 with
2172 // stride != 1, for the purpose of temporarily disabling primitive reuse.
2173 inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
2174                                 memory::dims strides) {
2175   if (filter_dims.size() != 4 || strides.size() != 2) return false;
2176 
2177   return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
2178           ((strides[0] != 1) || (strides[1] != 1)));
2179 }
2180 
2181 }  // namespace tensorflow
2182 
2183 /////////////////////////////////////////////////////////////////////
2184 // Macros for handling registration for various types
2185 /////////////////////////////////////////////////////////////////////
2186 
2187 #define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input);
2188 
2189 #define REGISTER_TEST_BFLOAT16(TEST) \
2190   REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input);
2191 
2192 #define REGISTER_TEST_ALL_TYPES(TEST) \
2193   REGISTER_TEST_FLOAT32(TEST);        \
2194   REGISTER_TEST_BFLOAT16(TEST);
2195 #else
2196 #define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST);
2197 
2198 #endif  // INTEL_MKL
2199 #endif  // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
2200