xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/mkl/mkl_concat_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5     http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12 
13 #ifdef INTEL_MKL
14 #define EIGEN_USE_THREADS
15 
16 #include <limits>
17 #include <unordered_map>
18 #include <vector>
19 
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "dnnl.hpp"
22 #include "tensorflow/core/framework/bounds_check.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_types.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/kernels/concat_lib.h"
29 #include "tensorflow/core/kernels/concat_lib_cpu.h"
30 #include "tensorflow/core/kernels/no_op.h"
31 #include "tensorflow/core/kernels/quantization_utils.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/core/util/mkl_util.h"
35 #ifdef DNNL_AARCH64_USE_ACL
36 #include "tensorflow/core/platform/mutex.h"
37 #endif
38 
39 using dnnl::concat;
40 using dnnl::stream;
41 
42 namespace tensorflow {
43 typedef Eigen::ThreadPoolDevice CPUDevice;
44 
45 // List of TensorShape objects. Used in Concat/Split layers.
46 typedef std::vector<TensorShape> TensorShapeList;
47 
48 enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
49 
50 // TODO(intel-tf) Check if we can reuse existing EigenConcatOp using Mutable
51 // reference inputs.
52 // --------------------------------------------------------------------------
53 //                      Eigen Concat Op
54 // --------------------------------------------------------------------------
55 namespace {
56 template <typename T>
57 struct RequantizeCopier {
58   RequantizeCopier(
59       const std::vector<std::pair<float, float>>* input_min_and_max,
60       float output_min, float output_max)
61       : output_min(output_min), output_max(output_max) {
62     DCHECK(input_min_and_max);
63     this->input_min_and_max = input_min_and_max;
64   }
65 
66   inline void Copy(T* dst, const T* src, int input_index, size_t n) {
67     const float input_min = (*input_min_and_max)[input_index].first;
68     const float input_max = (*input_min_and_max)[input_index].second;
69     if (input_min == output_min && input_max == output_max) {
70       DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum<T>::v()));
71       memcpy(dst, src, n * sizeof(T));
72     } else {
73       Eigen::array<Eigen::DenseIndex, 1> dims;
74       dims[0] = n;
75       typename TTypes<T, 1>::UnalignedConstTensor input_array(src, dims);
76       typename TTypes<T, 1>::UnalignedTensor output_array(dst, dims);
77 
78       QuantizedToFloatStruct<T> q2f(input_min, input_max);
79       auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f);
80       FloatToQuantizedStruct<T> f2q(output_min, output_max);
81       // RequantizeCopier::Copy is called from within a shard of computation, so
82       // don't use the threadpool device here, simply assign with default CPU
83       // device.
84       output_array = QUANTIZE_WITH_EIGEN(input_float, f2q, T);
85     }
86   }
87 
88   float output_min;
89   float output_max;
90   const std::vector<std::pair<float, float>>* input_min_and_max;
91 };
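// Illustrative sketch (not from the original source): a quick worked example
// of what Copy() above does. Assume a quint8 input with range [0.0f, 1.0f]
// that stores the value 200 (~0.784f after dequantization). If the output
// range is [0.0f, 2.0f], requantization yields roughly 100, since the output
// range is twice as wide. When the input and output ranges match exactly,
// Copy() degenerates to a plain memcpy.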
92 }  // namespace
93 
94 template <typename Device, typename T, AxisArgumentName AxisArgName>
95 class EigenConcatBaseOp : public OpKernel {
96  public:
97   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
98       ConstMatrixVector;
99 
100   explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
101 
102   void CalculateInputAndOutputRange(
103       const OpInputList& input_mins, const OpInputList& input_maxes,
104       const size_t N,
105       std::vector<std::pair<float, float>>* input_mins_and_maxes,
106       float* output_min, float* output_max) {
107     input_mins_and_maxes->reserve(N);
108     float overall_min = std::numeric_limits<float>::max();
109     float overall_max = std::numeric_limits<float>::lowest();
110     for (int i = 0; i < N; ++i) {
111       const float input_min = input_mins[i].flat<float>()(0);
112       const float input_max = input_maxes[i].flat<float>()(0);
113       input_mins_and_maxes->emplace_back(input_min, input_max);
114       overall_min = std::min(overall_min, input_min);
115       overall_max = std::max(overall_max, input_max);
116     }
117     if (std::is_signed<T>::value) {
118       // For signed, we want a symmetrical distribution including zero for the
119       // output, so pick a range that meets that need.
120       const float largest_value =
121           std::max(std::abs(overall_min), std::abs(overall_max));
122       *output_min = -largest_value;
123       *output_max = largest_value;
124     } else {
125       // For MKL quantization, we only support scaled mode, so the range is
126       // [0, m] for unsigned data, where m is the range maximum.
127       *output_min = 0.0f;
128       *output_max = overall_max;
129     }
130   }
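  // Worked example (illustrative, not part of the original code): given two
  // inputs with ranges [-1.0f, 2.0f] and [-3.0f, 1.5f], overall_min = -3.0f
  // and overall_max = 2.0f. For a signed T the output range is made symmetric
  // around zero, [-3.0f, 3.0f]; for an unsigned T it would be [0.0f, 2.0f].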
131 
132   // Although we modify Compute for this op to accept extra parameters, we
133   // still need an empty Compute() here, because Compute is a pure virtual function.
134   void Compute(OpKernelContext* c) {}
135 
136   void Compute(OpKernelContext* c, const std::vector<Tensor>& values,
137                const TensorShapeList& input_shapes,
138                const OpInputList& input_mins, const OpInputList& input_maxes,
139                bool quantized_input) {
140     const Tensor* concat_dim_tensor;
141     const char* axis_attribute_name =
142         AxisArgName == NAME_IS_AXIS
143             ? "axis"
144             : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
145     OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
146     OP_REQUIRES(c, TensorShapeUtils::IsScalar(concat_dim_tensor->shape()),
147                 errors::InvalidArgument(
148                     axis_attribute_name,
149                     " tensor should be a scalar integer, but got shape ",
150                     concat_dim_tensor->shape().DebugString()));
151     const int32 concat_dim =
152         internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
153     // Instead of accessing the values from the context, we use the inputs passed to Compute.
154     const int N = values.size();
155     const int input_dims = input_shapes[0].dims();
156     const TensorShape& input_shape = input_shapes[0];
157 
158     int32 axis = (concat_dim < 0) ? (concat_dim + input_dims) : concat_dim;
159     OP_REQUIRES(
160         c, (0 <= axis && axis < input_dims),
161         errors::InvalidArgument(
162             "ConcatOp : Expected concatenating dimensions in the range [",
163             -input_dims, ", ", input_dims, "), but got ", concat_dim));
164 
165     float output_min = std::numeric_limits<float>::max();
166     float output_max = std::numeric_limits<float>::lowest();
167     std::vector<std::pair<float, float>> input_mins_and_maxes;
168     if (quantized_input) {
169       CalculateInputAndOutputRange(input_mins, input_maxes, N,
170                                    &input_mins_and_maxes, &output_min,
171                                    &output_max);
172     }
173     // Note that we reduce the concat of n-dimensional tensors into a two
174     // dimensional concat. Assuming the dimensions of any input/output
175     // tensor are {x_0, x_1,...,x_n-1, y_0, y_1,...,y_m-1}, where the
176     // concat is along the dimension indicated with size y_0, we flatten it
177     // to {x, y}, where y = Prod_i(y_i) and x = ((n > 0) ? Prod_i(x_i) : 1).
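    // Illustrative example: concatenating two {2, 3, 5} tensors along axis 1
    // gives inputs_flat_dim0 = 2, so each input is viewed as a {2, 15} matrix
    // and the {2, 6, 5} output as a {2, 30} matrix.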
178     ConstMatrixVector inputs_flat;
179     inputs_flat.reserve(N);
180     int64 inputs_flat_dim0 = 1;
181     for (int d = 0; d < axis; ++d) {
182       inputs_flat_dim0 *= input_shape.dim_size(d);
183     }
184     int64 output_concat_dim = 0;
185     const bool input_is_scalar = TensorShapeUtils::IsScalar(input_shape);
186     for (int i = 0; i < N; ++i) {
187       const auto in = values[i];
188       const bool in_is_scalar = TensorShapeUtils::IsScalar(input_shapes[i]);
189       OP_REQUIRES(
190           c,
191           (input_shapes[i].dims() == input_dims) ||
192               (input_is_scalar && in_is_scalar),
193           errors::InvalidArgument(
194               "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
195               input_shape.DebugString(), " vs. shape[", i,
196               "] = ", input_shapes[i].DebugString()));
197       if (in.NumElements() > 0) {
198         int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
199         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
200             in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
201       }
202       output_concat_dim +=
203           input_shapes[i].dims() > 0 ? input_shapes[i].dim_size(axis) : 1;
204     }
205 
206     TensorShape output_shape(input_shape);
207     if (output_shape.dims() == 0) {
208       output_shape.AddDim(output_concat_dim);
209     } else {
210       output_shape.set_dim(axis, output_concat_dim);
211     }
212     Tensor* output = nullptr;
213     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
214     if (output->NumElements() > 0) {
215       int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
216       auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
217       if (!quantized_input) {
218         ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
219       } else {
220         ConcatCPUImpl<T>(
221             c->device(), inputs_flat, sizeof(T) /* cost_per_unit */,
222             RequantizeCopier<T>(&input_mins_and_maxes, output_min, output_max),
223             &output_flat);
224       }
225     }
226 
227     if (quantized_input) {
228       Tensor* output_min_tensor = nullptr;
229       OP_REQUIRES_OK(c, c->allocate_output(1, {}, &output_min_tensor));
230       output_min_tensor->flat<float>()(0) = output_min;
231 
232       Tensor* output_max_tensor = nullptr;
233       OP_REQUIRES_OK(c, c->allocate_output(2, {}, &output_max_tensor));
234       output_max_tensor->flat<float>()(0) = output_max;
235     }
236   }
237 };
238 // --------------------------------------------------------------------------
239 //                      Mkl Concat Op
240 // --------------------------------------------------------------------------
241 // This structure aggregates multiple inputs to MklConcat* methods.
242 struct MklConcatFwdParams {
243   std::vector<memory::dims> src_dims;
244   memory::dims dst_dims;
245   int num_inputs;
246   int concat_dims;
247   memory::format_tag mkl_common_format;
248 
249   MklConcatFwdParams(std::vector<memory::dims>& src_dims_pt,
250                      memory::dims dst_dims, int num_inputs, int concat_dims,
251                      memory::format_tag mkl_common_format)
252       : dst_dims(dst_dims),
253         num_inputs(num_inputs),
254         concat_dims(concat_dims),
255         mkl_common_format(mkl_common_format) {
256     for (int k = 0; k < num_inputs; ++k) {
257       src_dims.push_back(src_dims_pt[k]);
258     }
259   }
260 };
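// Illustrative example (assumed shapes, not from the original source): for two
// TF-format inputs of shape {2, 4, 4, 3} concatenated along dimension 3, the
// caller would build MklConcatFwdParams with src_dims = {{2,4,4,3},
// {2,4,4,3}}, dst_dims = {2,4,4,6}, num_inputs = 2, concat_dims = 3, and
// mkl_common_format = memory::format_tag::nchw (the plain tag this file picks
// for 4-D TF-format inputs).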
261 
262 // TODO(intel-tf): The template type "T" is currently used to match the
263 // templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
264 // In the future, with the removal of "T" from MklPrimitiveFactory, this class
265 // needs to drop "T".
266 template <typename T>
267 class MklConcatFwdPrimitive : public MklPrimitive {
268  public:
269   explicit MklConcatFwdPrimitive(const MklConcatFwdParams& concat_fwd_dims,
270                                  const std::vector<memory::desc>& srcs_md)
271       : MklPrimitive(engine(engine::kind::cpu, 0)) {
272     // Create concat primitive
273     Setup(concat_fwd_dims, srcs_md);
274   }
275 
276   ~MklConcatFwdPrimitive() {}
277 
278   // Concat forward execute
279   //   in_data:     input data buffers of the srcs
280   //   dst_data:    output data buffer of dst
281   void Execute(const std::vector<dnnl::memory>& in_data,
282                const dnnl::memory& dst_data,
283                const MklConcatFwdParams& concat_fwd_dims,
284                std::shared_ptr<stream> fwd_stream) {
285 #ifdef DNNL_AARCH64_USE_ACL
286     mutex_lock lock(primitive_execution_mu_);
287 #endif
288     DCHECK_EQ(in_data.size(), context_.data_mem.size());
289     for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
290 #ifndef ENABLE_ONEDNN_OPENMP
291       context_.data_mem_shdptr[i]->set_data_handle(
292           static_cast<void*>(in_data[i].get_data_handle()), *fwd_stream);
293     }
294     context_.dst_mem->set_data_handle(
295         static_cast<void*>(dst_data.get_data_handle()), *fwd_stream);
296 #else
297       context_.data_mem_shdptr[i]->set_data_handle(
298           static_cast<void*>(in_data[i].get_data_handle()));
299     }
300     context_.dst_mem->set_data_handle(
301         static_cast<void*>(dst_data.get_data_handle()));
302 #endif  // !ENABLE_ONEDNN_OPENMP
303 
304     for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
305       context_.data_mem[i] = *context_.data_mem_shdptr[i];
306     }
307 
308     execute_primitives(context_.fwd_primitives, fwd_stream,
309                        context_.fwd_primitives_args);
310 
311     // After execution, set the data handles back to DummyData.
312     context_.dst_mem->set_data_handle(DummyData);
313     for (int k = 0; k < concat_fwd_dims.num_inputs; k++) {
314       context_.data_mem_shdptr[k]->set_data_handle(DummyData);
315     }
316 
317     for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
318       context_.data_mem[i] = *context_.data_mem_shdptr[i];
319     }
320   }
321 
322  private:
323   // Primitive reuse context for concat Fwd op
324   struct ConcatFwdContext {
325     // oneDNN memory
326     std::vector<dnnl::memory> data_mem;
327     std::vector<std::shared_ptr<dnnl::memory>> data_mem_shdptr;
328     std::shared_ptr<dnnl::memory> dst_mem;
329 
330     // Memory descriptor
331     std::vector<dnnl::memory::desc> src_md;
332     std::shared_ptr<dnnl::memory::desc> dst_md;
333 
334     // Concat primitive descriptor
335     std::shared_ptr<dnnl::concat::primitive_desc> fwd_pd;
336     std::shared_ptr<dnnl::primitive> concat_fwd;
337 
338     std::vector<dnnl::primitive> fwd_primitives;
339 
340     std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
341 
342     ConcatFwdContext()
343         : dst_mem(nullptr), fwd_pd(nullptr), concat_fwd(nullptr) {}
344   };
345 
346   // Creates the src and dst memory descriptors for mkl concat
347   // and also creates the concat primitive and primitive descriptor.
348   void Setup(const MklConcatFwdParams& concat_fwd_dims,
349              const std::vector<memory::desc>& srcs_md) {
350     // Create memory descriptors for concat with specified srcs format
351     for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
352       dnnl::memory::desc source_md(memory::desc(srcs_md[i].data));
353       context_.src_md.push_back(source_md);
354       std::shared_ptr<dnnl::memory> src_mem(
355           new dnnl::memory(source_md, cpu_engine_, DummyData));
356       context_.data_mem_shdptr.push_back(src_mem);
357       context_.data_mem.push_back(*context_.data_mem_shdptr[i]);
358     }
359     // Store the expected memory format
360     context_.dst_md.reset(new memory::desc({concat_fwd_dims.dst_dims},
361                                            MklDnnType<T>(),
362                                            concat_fwd_dims.mkl_common_format));
363     // Create a concat primitive descriptor
364     context_.fwd_pd.reset(new concat::primitive_desc(
365         *context_.dst_md, concat_fwd_dims.concat_dims, context_.src_md,
366         cpu_engine_));
367 
368     // Create memory primitive based on dummy data
369     context_.dst_mem.reset(
370         new memory(*context_.dst_md, cpu_engine_, DummyData));
371 
372     context_.concat_fwd.reset(new concat(*context_.fwd_pd));
373     std::unordered_map<int, memory> net_args = {
374         {DNNL_ARG_DST, *context_.dst_mem}};
375     for (int i = 0; i < concat_fwd_dims.num_inputs; ++i) {
376       net_args.insert({DNNL_ARG_MULTIPLE_SRC + i, context_.data_mem[i]});
377     }
378 
379     context_.fwd_primitives_args.push_back(net_args);
380     context_.fwd_primitives.push_back(*context_.concat_fwd);
381   }
382 
383   struct ConcatFwdContext context_;
384 
385 #ifdef DNNL_AARCH64_USE_ACL
386   mutex primitive_execution_mu_;
387 #endif
388 };
389 
390 // Class to create/cache the mkl concat primitives based on the
391 // input and output parameters
392 template <typename T>
393 class MklConcatFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
394  public:
395   static MklConcatFwdPrimitive<T>* Get(
396       const MklConcatFwdParams& concat_fwd_dims,
397       const std::vector<memory::desc>& srcs_md, bool do_not_cache) {
398     MklConcatFwdPrimitive<T>* concat_fwd = nullptr;
399 
400     if (do_not_cache) {
401       // Always create new primitive
402       concat_fwd = new MklConcatFwdPrimitive<T>(concat_fwd_dims, srcs_md);
403     } else {
404       // Try to find a suitable one in pool
405       concat_fwd = dynamic_cast<MklConcatFwdPrimitive<T>*>(
406           MklConcatFwdPrimitiveFactory<T>::GetInstance().GetConcatFwd(
407               concat_fwd_dims));
408       if (concat_fwd == nullptr) {
409         concat_fwd = new MklConcatFwdPrimitive<T>(concat_fwd_dims, srcs_md);
410         MklConcatFwdPrimitiveFactory<T>::GetInstance().SetConcatFwd(
411             concat_fwd_dims, concat_fwd);
412       }
413     }
414 
415     return concat_fwd;
416   }
417 
418  private:
419   MklConcatFwdPrimitiveFactory() {}
420   ~MklConcatFwdPrimitiveFactory() {}
421 
422   static MklConcatFwdPrimitiveFactory& GetInstance() {
423     static MklConcatFwdPrimitiveFactory instance_;
424     return instance_;
425   }
426 
427   static string CreateKey(const MklConcatFwdParams& concat_fwd_dims) {
428     string prefix = "concat_fwd_";
429     FactoryKeyCreator key_creator;
430     key_creator.AddAsKey(prefix);
431     for (int k = 0; k < concat_fwd_dims.num_inputs; k++) {
432       key_creator.AddAsKey(concat_fwd_dims.src_dims[k]);
433     }
434     key_creator.AddAsKey(concat_fwd_dims.concat_dims);
435     return key_creator.GetKey();
436   }
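  // Note (derived from CreateKey above, added for clarity): the cache key is
  // composed of the "concat_fwd_" prefix, every input's dims, and the concat
  // axis, so a cached primitive is reused only when all source shapes and the
  // concat dimension match exactly.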
437 
438   MklPrimitive* GetConcatFwd(const MklConcatFwdParams& concat_fwd_dims) {
439     string key = CreateKey(concat_fwd_dims);
440     return this->GetOp(key);
441   }
442 
443   void SetConcatFwd(const MklConcatFwdParams& concat_fwd_dims,
444                     MklPrimitive* op) {
445     string key = CreateKey(concat_fwd_dims);
446     this->SetOp(key, op);
447   }
448 };
449 
450 template <typename Device, typename T, AxisArgumentName AxisArgName,
451           bool native_format = false>
452 class MklConcatOp : public OpKernel {
453  private:
454   TensorFormat data_format_;
455   EigenConcatBaseOp<Device, T, AxisArgName> eigen_concat_op_;
456 
457  public:
458   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
459       ConstMatrixVector;
460 
461   explicit MklConcatOp(OpKernelConstruction* c)
462       : OpKernel(c),
463         data_format_(TensorFormat::FORMAT_NCHW),
464         eigen_concat_op_(c) {}
465 
466   void Compute(OpKernelContext* context) override {
467     try {
468       auto cpu_engine = engine(engine::kind::cpu, 0);
469       OpInputList input_tensors;
470       GetMklInputList(context, "values", &input_tensors);
471       const int N = input_tensors.size();
472       // Get Tensor shapes.
473       std::vector<MklDnnShape> mkl_input_shapes(N);
474       GetMklShapeList(context, "values", &mkl_input_shapes, native_format);
475 
476       const Tensor& concat_dim_tensor = (AxisArgName == NAME_IS_CONCAT_DIM)
477                                             ? MklGetInput(context, 0)
478                                             : MklGetInput(context, N);
479       // Sanity checks
480       OP_REQUIRES(
481           context, TensorShapeUtils::IsScalar(concat_dim_tensor.shape()),
482           errors::InvalidArgument(
483               "Concat dim tensor should be a scalar integer, but got shape ",
484               concat_dim_tensor.shape().DebugString()));
485       int32 concat_dim =
486           internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
487 
488       // check that ranks of all tensors match
489       // and that their shapes match except for concat_dim.
490       int i = 0;
491       int num_of_empty_inputs = 0;
492       bool invoke_eigen = false;
493       bool are_all_mkl_inputs = true, are_all_tf_inputs = true;
494       const TensorShape expected_shape = mkl_input_shapes[0].IsMklTensor()
495                                              ? mkl_input_shapes[0].GetTfShape()
496                                              : input_tensors[0].shape();
497       size_t expected_dims = expected_shape.dims();
498 
499       if (concat_dim < 0) concat_dim = expected_dims + concat_dim;
500 
501       for (auto& s : mkl_input_shapes) {
502         TensorShape s_shape =
503             s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape();
504         size_t s_dims = s_shape.dims();
505 
506         OP_REQUIRES(
507             context, s_dims == expected_dims,
508             errors::InvalidArgument(
509                 "_MklConcatOp : Ranks of all input tensors should match:"
510                 " input dimensions = ",
511                 s_dims, " vs. expected rank = ", expected_dims));
512 
513         for (int d = 0; d < expected_dims; ++d) {
514           if (d == concat_dim) continue;
515 
516           size_t expected_size = expected_shape.dim_size(d);
517           size_t s_size = s_shape.dim_size(d);
518           OP_REQUIRES(
519               context, expected_size == s_size,
520               errors::InvalidArgument("_MklConcatOp : Dimensions of inputs "
521                                       "should match: shape[0][",
522                                       d, "]= ", expected_size, " vs. shape[", i,
523                                       "][", d, "] = ", s_size));
524         }
525 
526         if (s.IsMklTensor())
527           are_all_tf_inputs = false;
528         else
529           are_all_mkl_inputs = false;
530 
531         if (s_dims != 4 && s_dims != 2) invoke_eigen = true;
532 
533         if (input_tensors[i].NumElements() == 0) num_of_empty_inputs++;
534 
535         ++i;
536       }
537 
538       if (num_of_empty_inputs == i) invoke_eigen = true;
539 
540       // Not all inputs are in one format (TF or MKL). This is the mixed-input
541       // case. We can potentially optimize this case by converting all TF
542       // inputs to Mkl format, but currently we fall back to Eigen for it.
543       // It may be possible to convert the inputs that are in TF format to Mkl
544       // format and avoid calling the Eigen version.
545       if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;
546 
547       // Temporarily call Eigen if the number of input dimensions is 2.
548       // That is due to incorrect output results in the DNNL 1.2 path.
549       if (expected_dims == 2) invoke_eigen = true;
550 
551       OpInputList input_mins, input_maxes;
552       bool quantized_input =
553           std::is_same<T, qint8>::value || std::is_same<T, quint8>::value;
554       if (quantized_input) {
555         // oneDNN concat does not support input tensors that have different
556         // ranges. Check if the ranges of all the input tensors are the same.
557         // If not, forward it to the Eigen implementation.
558 
559         OP_REQUIRES_OK(context, context->input_list("input_mins", &input_mins));
560         OP_REQUIRES(context, (input_mins.size() == N),
561                     errors::InvalidArgument(
562                         "QuantizedConcatOp : Expected mins input list length ",
563                         input_mins.size(), " to equal values length ", N));
564 
565         OP_REQUIRES_OK(context,
566                        context->input_list("input_maxes", &input_maxes));
567         OP_REQUIRES(context, (input_maxes.size() == N),
568                     errors::InvalidArgument(
569                         "QuantizedConcatOp : Expected maxes input list length ",
570                         input_maxes.size(), " to equal values length ", N));
571         float input_min = input_mins[0].flat<float>()(0);
572         float input_max = input_maxes[0].flat<float>()(0);
573         const float eps = 1.0e-6;
574         for (int i = 1; i < N; ++i) {
575           float min = input_mins[i].flat<float>()(0);
576           float max = input_maxes[i].flat<float>()(0);
577 
578           if (fabs(input_min - min) > eps || fabs(input_max - max) > eps) {
579             invoke_eigen = true;
580             break;
581           }
582         }
583       }
584 
585       // Call Eigen library
586       if (invoke_eigen) {
587         CallEigenVersion(context, input_tensors, input_mins, input_maxes,
588                          mkl_input_shapes, quantized_input);
589         return;
590       }
591 
592       memory::dims dst_dims;
593 
594       if (are_all_mkl_inputs)
595         dst_dims = TFShapeToMklDnnDims(mkl_input_shapes[0].GetTfShape());
596       else
597         // When all the inputs are in Tensorflow format, we don't know what
598         // the input data format is. In that case, we just use an output
599         // format that is the same as the input format.
600         dst_dims = TFShapeToMklDnnDims(input_tensors[0].shape());
601 
602       std::vector<memory::desc> srcs_pd;
603       std::vector<MklDnnData<T>> srcs(N, MklDnnData<T>(&cpu_engine));
604       int64 dst_concat_dim_size = 0;
605 
606       bool isMklReorderNeeded = false;
607       memory::format_tag mkl_common_format = memory::format_tag::any;
608       std::vector<memory> inputs;
609       std::vector<memory::dims> src_dims_pt;
610       std::vector<dnnl::memory> srcs_mem;
611       std::vector<memory::desc> srcs_md;
612 
613       if (are_all_mkl_inputs) {
614         mkl_common_format =
615             FindMklCommonFormat(mkl_input_shapes, concat_dim,
616                                 &isMklReorderNeeded, &dst_concat_dim_size);
617 
618         if (!isMklReorderNeeded) {
619           // All MKL tensors have the same format. Reorder is not needed.
620           for (int k = 0; k < N; k++) {
621             if (input_tensors[k].NumElements() == 0) continue;
622             auto src_md = mkl_input_shapes[k].GetMklLayout();
623             srcs[k].SetUsrMem(src_md, &input_tensors[k]);
624             auto src_mpd = srcs[k].GetUsrMemDesc();
625             srcs_pd.push_back(src_mpd);
626             inputs.push_back(srcs[k].GetOpMem());
627           }
628         } else {
629           // MKL tensors have different formats.
630           // Reorder them to the most common format.
631           for (int k = 0; k < N; k++) {
632             if (input_tensors[k].NumElements() == 0) continue;
633             auto src_md = mkl_input_shapes[k].GetMklLayout();
634             srcs[k].SetUsrMem(src_md, &input_tensors[k]);
635             auto src_tf_fmt = MklTensorFormatToMklDnnDataFormat(
636                 mkl_input_shapes[k].GetTfDataFormat());
637             if (src_tf_fmt != mkl_common_format) {
638               memory::dims src_dims(src_md.data.dims,
639                                     &src_md.data.dims[src_md.data.ndims]);
640               src_md =
641                   memory::desc(src_dims, MklDnnType<T>(), mkl_common_format);
642             }
643             srcs_pd.push_back(memory::desc(src_md));
644           }
645         }
646       } else {  // All TF inputs
647         for (int k = 0; k < N; k++) {
648           if (input_tensors[k].NumElements() == 0) continue;
649           TensorShape s_shape = input_tensors[k].shape();
650           memory::dims src_dims = TFShapeToMklDnnDims(s_shape);
651           dst_concat_dim_size += src_dims[concat_dim];
652           size_t s_dims = s_shape.dims();
653 
654           // It does not matter which data format is used (NHWC versus NCHW).
655           // We just need the output to use the same data format as the inputs.
656           if (s_dims == 4)
657             mkl_common_format = memory::format_tag::nchw;
658           else if (s_dims == 2)
659             mkl_common_format = memory::format_tag::nc;
660 
661           auto src_md =
662               memory::desc(src_dims, MklDnnType<T>(), mkl_common_format);
663 
664           srcs[k].SetUsrMem(src_md, &input_tensors[k]);
665           auto src_mpd = srcs[k].GetUsrMemDesc();
666           srcs_pd.push_back(src_mpd);
667           inputs.push_back(srcs[k].GetOpMem());
668           src_dims_pt.push_back(src_dims);
669           srcs_md.push_back(src_md);
670           srcs_mem.push_back(srcs[k].GetOpMem());
671         }
672       }
673       dst_dims[concat_dim] = dst_concat_dim_size;
674 
675       MklDnnData<T> dst(&cpu_engine);
676       memory::desc dst_md({}, memory::data_type::undef,
677                           memory::format_tag::undef);
678       memory::dims dst_dims_in_nchw;
679       if (are_all_mkl_inputs) {
680         // Since we are passing a specific format for destination,
681         // we need to have dst_dims in MklDnn order (NCHW).
682         auto orig_tf_format = mkl_input_shapes[0].GetTfDataFormat();
683         if (dst_dims.size() == 4) {
684           dst_dims_in_nchw = MklDnnDimsInNCHW(
685               dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format));
686           // Set the output format the same as the most common input format
687           // to avoid layout conversions.
688           // DNNL 1.0: the internal format is always blocked;
689           //          format_tag does not have a "blocked" field.
690           VLOG(1) << "mkl_common_format == memory::format_tag::blocked";
691           dst_md = MklDnnData<T>::CreateBlockedMemDesc(
692               dst_dims_in_nchw, CalculateTFStrides(dst_dims_in_nchw));
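          // Illustrative example (assumed shape): dst_dims_in_nchw =
          // {2, 6, 4, 4} gives strides {96, 16, 4, 1}, i.e. the blocked
          // memory::desc describes a plain row-major NCHW layout.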
693         } else if (dst_dims.size() == 2 &&
694                    mkl_common_format == memory::format_tag::nc) {
695           // When memory::format_tag::nc, dst_dims are already in oneDNN order
696           dst_md = memory::desc(dst_dims, MklDnnType<T>(), mkl_common_format);
697         } else {
698           TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
699                              "Unsupported tensor dimension or"
700                              "oneDNN memory format"));
701         }
702       } else {
703         // All inputs are TF tensors.
704         // Set the output format the same as the input format (nchw/nc).
705         dst_md = memory::desc(dst_dims, MklDnnType<T>(), mkl_common_format);
706       }
707 
708       if (isMklReorderNeeded) {
709         for (int k = 0; k < input_tensors.size(); k++) {
710           if (input_tensors[k].NumElements() > 0) {
711             srcs[k].CheckReorderToOpMem(srcs_pd[k], cpu_engine, context);
712             inputs.push_back(srcs[k].GetOpMem());
713           }
714         }
715       }
716 
717       // If all inputs are in MKL format, then the meaning of concat_dim needs
718       // to change. The value of concat_dim is tied to the input Tensorflow data
719       // format (NHWC or NCHW). MklDnn dimensions are in NCHW order. So if the
720       // Tensorflow tensors are in NCHW order, the concat_dim semantics are
721       // preserved. But if the input tensors are in NHWC order, the semantics
722       // need to change. E.g., if we are concatenating over Channel (dimension 3
723       // for NHWC), then since the MklDnn order is NCHW, concat_dim needs to be 1.
724       if (are_all_mkl_inputs)
725         concat_dim = mkl_input_shapes[0].TfDimIdx(concat_dim);
726 
727       if (!inputs.empty()) {
728         if (are_all_mkl_inputs) {
729           auto concat_pd =
730               concat::primitive_desc(concat_dim, srcs_pd, cpu_engine);
731           auto dst_pd = concat_pd.dst_desc();
732 
733           MklDnnShape dnn_shape_dst;
734           TensorShape tf_shape_dst;
735           Tensor* dst_tensor = nullptr;
736           dnn_shape_dst.SetMklTensor(true);
737           dnn_shape_dst.SetMklLayout(&dst_pd);
738           dnn_shape_dst.SetElemType(MklDnnType<T>());
739           dnn_shape_dst.SetTfLayout(dst_dims.size(), dst_dims_in_nchw,
740                                     mkl_input_shapes[0].GetTfDataFormat());
741           tf_shape_dst.AddDim((dst_pd.get_size() / sizeof(T)));
742           AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
743                                     dnn_shape_dst);
744           DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
745 
746           std::shared_ptr<stream> fwd_cpu_stream;
747           MklDnnThreadPool eigen_tp(context);
748           fwd_cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));
749 
750           if (dnn_shape_dst.IsMklTensor())
751             dst_md = dnn_shape_dst.GetMklLayout();
752           dst.SetUsrMem(dst_md, dst_tensor);
753           dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
754 
755           auto concat_op = concat(concat_pd);
756           std::unordered_map<int, memory> net_args = {
757               {DNNL_ARG_DST, dst.GetOpMem()}};
758           for (int i = 0; i < inputs.size(); ++i) {
759             net_args.insert({DNNL_ARG_MULTIPLE_SRC + i, inputs[i]});
760           }
761           concat_op.execute(*fwd_cpu_stream, net_args);
762         } else {
763           MklConcatFwdPrimitive<T>* concat_fwd = nullptr;
764 
765           MklConcatFwdParams concat_fwd_dims(src_dims_pt, dst_dims,
766                                              (N - num_of_empty_inputs),
767                                              concat_dim, mkl_common_format);
768           // Get a concat fwd from primitive pool
769           concat_fwd =
770               MklConcatFwdPrimitiveFactory<T>::Get(concat_fwd_dims, srcs_md, 0);
771 
772           // Allocate output tensor.
773           MklDnnShape dnn_shape_dst;
774           TensorShape tf_shape_dst;
775           Tensor* dst_tensor = nullptr;
776           dnn_shape_dst.SetMklTensor(false);
777           tf_shape_dst = MklDnnDimsToTFShape(dst_dims);
778           AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
779                                     dnn_shape_dst, native_format);
780           DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
781 
782           dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
783                                                : dst_md;
784           std::shared_ptr<stream> fwd_cpu_stream;
785           MklDnnThreadPool eigen_tp(context);
786           fwd_cpu_stream.reset(
787               CreateStream(&eigen_tp, concat_fwd->GetEngine()));
788           dst.SetUsrMem(dst_md, dst_tensor);
789           dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
790           // Execute concat
791           concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims,
792                               fwd_cpu_stream);
793         }
794 
795         // For quantized concat, min and max outputs are also computed.
796         if (quantized_input) {
797           Tensor* output_min = nullptr;
798           Tensor* output_max = nullptr;
799           MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
800           output_min_mkl_shape.SetMklTensor(false);
801           output_max_mkl_shape.SetMklTensor(false);
802           AllocateOutputSetMklShape(context, 1, &output_min, {},
803                                     output_min_mkl_shape, native_format);
804           AllocateOutputSetMklShape(context, 2, &output_max, {},
805                                     output_max_mkl_shape, native_format);
806           // All input tensors should have the same range, so just use the
807           // first one.
808           output_min->flat<float>()(0) = input_mins[0].flat<float>()(0);
809           output_max->flat<float>()(0) = input_maxes[0].flat<float>()(0);
810         }
811       } else {
812         MklDnnShape dnn_shape_dst;
813         TensorShape tf_shape_dst;
814         Tensor* dst_tensor = nullptr;
815         dnn_shape_dst.SetMklTensor(false);
816         tf_shape_dst = MklDnnDimsToTFShape(dst_dims);
817 
818         AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
819                                   dnn_shape_dst, native_format);
820         DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
821       }
822     } catch (dnnl::error& e) {
823       string error_msg = "Status: " + std::to_string(e.status) +
824                          ", message: " + string(e.message) + ", in file " +
825                          string(__FILE__) + ":" + std::to_string(__LINE__);
826       OP_REQUIRES_OK(
827           context,
828           errors::Aborted("Operation received an exception:", error_msg));
829     }
830   }
831 
832   void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
833                         const OpInputList& input_mins,
834                         const OpInputList& input_maxes,
835                         const MklDnnShapeList& mkl_input_shapes,
836                         bool quantized_input) {
837     size_t num_mkl_input_shapes = mkl_input_shapes.size();
838     DCHECK_EQ(values.size(), num_mkl_input_shapes);
839     std::vector<Tensor> converted_values(num_mkl_input_shapes);
840     TensorShapeList tf_input_shapes;
841     for (size_t i = 0; i < num_mkl_input_shapes; ++i) {
842       if (mkl_input_shapes[i].IsMklTensor()) {
843         // Do conversion from MKL to TF
844         OP_REQUIRES_OK(
845             context, ConvertMklToTF<T>(context, values[i], mkl_input_shapes[i],
846                                        &converted_values[i]));
847         tf_input_shapes.push_back(mkl_input_shapes[i].GetTfShape());
848       } else {
849         // No conversion since it is TF tensor already
850         converted_values[i] = values[i];
851         tf_input_shapes.push_back(values[i].shape());
852       }
853     }
854 
855     // Call Eigen concat.
856     eigen_concat_op_.Compute(context, converted_values, tf_input_shapes,
857                              input_mins, input_maxes, quantized_input);
858 
859     if (!native_format) {
860       // Get the number of dims from the first input since all input tensors
861       // should have the same rank.
862       size_t dims = values[0].shape().dims();
863       MklDnnShape output_data_mkl_shape;
864       output_data_mkl_shape.SetMklTensor(false);
865       output_data_mkl_shape.SetDimensions(dims);
866       AllocateOutputSetMklShape(context, 0, output_data_mkl_shape);
867       if (quantized_input) {
868         MklDnnShape output_min_max_mkl_shape;
869         output_min_max_mkl_shape.SetMklTensor(false);
870         AllocateOutputSetMklShape(context, 1, output_min_max_mkl_shape);
871         AllocateOutputSetMklShape(context, 2, output_min_max_mkl_shape);
872       }
873     }
874   }
875 
876   // This method finds the most common format across all MKL inputs
877   // Inputs:
878   //   1. input_shapes: shapes of input (MKL) tensors.
879   //   2. concat_dim: concat dimension.
880   // Outputs:
881   //   1. is_reorder_needed is set to true if inputs have different formats;
882   //      it is set to false otherwise.
883   //   2. concat_dim_size is the size of concat_dim.
884   // Returns:
885   //   the most common MKL format.
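  // Illustrative example (assumed formats): for three MKL inputs with formats
  // {nchw, nhwc, nchw}, nchw is the most frequent, so nchw is returned and
  // *is_reorder_needed is set to true. If all three inputs were nchw, nchw
  // would be returned with *is_reorder_needed left as false.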
886   memory::format_tag FindMklCommonFormat(const MklDnnShapeList& input_shapes,
887                                          int concat_dim,
888                                          bool* is_reorder_needed,
889                                          int64* concat_dim_size) {
890     *is_reorder_needed = false;
891     *concat_dim_size = 0;
892     std::unordered_map<int, int> occurrence_map;
893     if (input_shapes.size() == 0) return memory::format_tag::any;
894 
895     // Count occurrences of each format across all inputs.
896     for (int k = 0; k < input_shapes.size(); k++) {
897       auto src_dims = TFShapeToMklDnnDims(input_shapes[k].GetTfShape());
898       *concat_dim_size += src_dims[concat_dim];
899       int fmt = static_cast<int>(
900           MklTensorFormatToMklDnnDataFormat(input_shapes[k].GetTfDataFormat()));
901       occurrence_map[fmt] += 1;
902     }
903 
904     if (occurrence_map.size() == 1) {
905       // This means that all inputs have the same format;
906       // return it with is_reorder_needed set to false.
907       return static_cast<memory::format_tag>(
908           MklTensorFormatToMklDnnDataFormat(input_shapes[0].GetTfDataFormat()));
909     }
910 
911     // Input tensors have different formats, so a reorder is needed.
912     // We pick the most common format to minimize the total
913     // number of input reorders.
914     memory::format_tag commonest_format = memory::format_tag::any;
915     int max_occurrence = 0;
916     *is_reorder_needed = true;
917     for (auto item : occurrence_map) {
918       if (item.second > max_occurrence) {
919         commonest_format = static_cast<memory::format_tag>(item.first);
920         max_occurrence = item.second;
921       }
922     }
923     return commonest_format;
924   }
925 };
926 
927 /* Use the optimized concat for float and bfloat16 types only */
928 #define REGISTER_MKL_CPU(type)                                 \
929   REGISTER_KERNEL_BUILDER(                                     \
930       Name("_MklConcat")                                       \
931           .Device(DEVICE_CPU)                                  \
932           .TypeConstraint<type>("T")                           \
933           .HostMemory("concat_dim")                            \
934           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
935       MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>);       \
936   REGISTER_KERNEL_BUILDER(                                     \
937       Name("_MklConcatV2")                                     \
938           .Device(DEVICE_CPU)                                  \
939           .TypeConstraint<type>("T")                           \
940           .TypeConstraint<int32>("Tidx")                       \
941           .HostMemory("axis")                                  \
942           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
943       MklConcatOp<CPUDevice, type, NAME_IS_AXIS>);
944 
945 TF_CALL_float(REGISTER_MKL_CPU);
946 TF_CALL_bfloat16(REGISTER_MKL_CPU);
947 
948 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
949                             .Device(DEVICE_CPU)
950                             .TypeConstraint<quint8>("T")
951                             .HostMemory("axis")
952                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
953                         MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS, true>);
954 
955 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
956                             .Device(DEVICE_CPU)
957                             .TypeConstraint<qint8>("T")
958                             .HostMemory("axis")
959                             .Label(mkl_op_registry::kMklQuantizedOpLabel),
960                         MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS, true>);
961 
962 #define REGISTER_QUANTIZED_CONCATV2(type)                \
963   REGISTER_KERNEL_BUILDER(Name("QuantizedConcatV2")      \
964                               .Device(DEVICE_CPU)        \
965                               .TypeConstraint<type>("T") \
966                               .HostMemory("axis"),       \
967                           NoOp)
968 
969 REGISTER_QUANTIZED_CONCATV2(quint8);
970 REGISTER_QUANTIZED_CONCATV2(qint8);
971 
972 #undef REGISTER_MKL_CPU
973 }  // namespace tensorflow
974 
975 #endif  // INTEL_MKL
976