/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL
#define EIGEN_USE_THREADS

#include <limits>
#include <unordered_map>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/concat_lib_cpu.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/mkl_util.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::concat;
using dnnl::stream;

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

// List of TensorShape objects. Used in Concat/Split layers.
typedef std::vector<TensorShape> TensorShapeList;

enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };

// TODO(intel-tf) Check if we can reuse existing EigenConcatOp using Mutable
// reference inputs.
// --------------------------------------------------------------------------
// Eigen Concat Op
// --------------------------------------------------------------------------
namespace {
template <typename T>
struct RequantizeCopier {
  RequantizeCopier(
      const std::vector<std::pair<float, float>>* input_min_and_max,
      float output_min, float output_max)
      : output_min(output_min), output_max(output_max) {
    DCHECK(input_min_and_max);
    this->input_min_and_max = input_min_and_max;
  }

  inline void Copy(T* dst, const T* src, int input_index, size_t n) {
    const float input_min = (*input_min_and_max)[input_index].first;
    const float input_max = (*input_min_and_max)[input_index].second;
    if (input_min == output_min && input_max == output_max) {
      DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum<T>::v()));
      memcpy(dst, src, n * sizeof(T));
    } else {
      Eigen::array<Eigen::DenseIndex, 1> dims;
      dims[0] = n;
      typename TTypes<T, 1>::UnalignedConstTensor input_array(src, dims);
      typename TTypes<T, 1>::UnalignedTensor output_array(dst, dims);

      QuantizedToFloatStruct<T> q2f(input_min, input_max);
      auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f);
      FloatToQuantizedStruct<T> f2q(output_min, output_max);
      // RequantizeCopier::Copy is called from within a shard of computation,
      // so don't use the threadpool device here, simply assign with default
      // CPU device.
      output_array = QUANTIZE_WITH_EIGEN(input_float, f2q, T);
    }
  }

  float output_min;
  float output_max;
  const std::vector<std::pair<float, float>>* input_min_and_max;
};
}  // namespace

template <typename Device, typename T, AxisArgumentName AxisArgName>
class EigenConcatBaseOp : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}

  void CalculateInputAndOutputRange(
      const OpInputList& input_mins, const OpInputList& input_maxes,
      const size_t N,
      std::vector<std::pair<float, float>>* input_mins_and_maxes,
      float* output_min, float* output_max) {
    input_mins_and_maxes->reserve(N);
    float overall_min = std::numeric_limits<float>::max();
    float overall_max = std::numeric_limits<float>::lowest();
    for (int i = 0; i < N; ++i) {
      const float input_min = input_mins[i].flat<float>()(0);
      const float input_max = input_maxes[i].flat<float>()(0);
      input_mins_and_maxes->emplace_back(input_min, input_max);
      overall_min = std::min(overall_min, input_min);
      overall_max = std::max(overall_max, input_max);
    }
    if (std::is_signed<T>::value) {
      // For signed, we want a symmetrical distribution including zero for the
      // output, so pick a range that meets that need.
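      // For example, if two inputs have ranges [-1.0, 2.5] and [-3.0, 1.0],
      // the overall range is [-3.0, 2.5] and the symmetric output range
      // chosen below is [-3.0, 3.0].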
      const float largest_value =
          std::max(std::abs(overall_min), std::abs(overall_max));
      *output_min = -largest_value;
      *output_max = largest_value;
    } else {
      // For MKL quantization, we only support scaled mode, so the range is
      // [0, m] for unsigned data, where m is the range maximum.
      *output_min = 0.0f;
      *output_max = overall_max;
    }
  }

  // Although we modify Compute for this call to accept one extra param,
  // we need to keep an empty Compute because Compute is a pure virtual
  // function.
  void Compute(OpKernelContext* c) {}

  void Compute(OpKernelContext* c, const std::vector<Tensor>& values,
               const TensorShapeList& input_shapes,
               const OpInputList& input_mins, const OpInputList& input_maxes,
               bool quantized_input) {
    const Tensor* concat_dim_tensor;
    const char* axis_attribute_name =
        AxisArgName == NAME_IS_AXIS
            ? "axis"
            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
    OP_REQUIRES(c, TensorShapeUtils::IsScalar(concat_dim_tensor->shape()),
                errors::InvalidArgument(
                    axis_attribute_name,
                    " tensor should be a scalar integer, but got shape ",
                    concat_dim_tensor->shape().DebugString()));
    const int32 concat_dim =
        internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
    // Instead of accessing values from context, we use input to Compute.
    const int N = values.size();
    const int input_dims = input_shapes[0].dims();
    const TensorShape& input_shape = input_shapes[0];

    int32 axis = (concat_dim < 0) ? (concat_dim + input_dims) : concat_dim;
    OP_REQUIRES(
        c, (0 <= axis && axis < input_dims),
        errors::InvalidArgument(
            "ConcatOp : Expected concatenating dimensions in the range [",
            -input_dims, ", ", input_dims, "), but got ", concat_dim));

    float output_min = std::numeric_limits<float>::max();
    float output_max = std::numeric_limits<float>::lowest();
    std::vector<std::pair<float, float>> input_mins_and_maxes;
    if (quantized_input) {
      CalculateInputAndOutputRange(input_mins, input_maxes, N,
                                   &input_mins_and_maxes, &output_min,
                                   &output_max);
    }
    // Note that we reduce the concat of n-dimensional tensors into a two
    // dimensional concat. Assuming the dimensions of any input/output
    // tensor are {x_0, x_1,...,x_n-1, y_0, y_1,...,y_m-1}, where the
    // concat is along the dimension indicated with size y_0, we flatten it
    // to {x, y}, where y = Prod_i(y_i) and x = ((n > 0) ? Prod_i(x_i) : 1).
    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(N);
    int64 inputs_flat_dim0 = 1;
    for (int d = 0; d < axis; ++d) {
      inputs_flat_dim0 *= input_shape.dim_size(d);
    }
    int64 output_concat_dim = 0;
    const bool input_is_scalar = TensorShapeUtils::IsScalar(input_shape);
    for (int i = 0; i < N; ++i) {
      const auto in = values[i];
      const bool in_is_scalar = TensorShapeUtils::IsScalar(input_shapes[i]);
      OP_REQUIRES(
          c,
          (input_shapes[i].dims() == input_dims) ||
              (input_is_scalar && in_is_scalar),
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
              input_shape.DebugString(), " vs. shape[", i,
              "] = ", input_shapes[i].DebugString()));
      if (in.NumElements() > 0) {
        int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
      }
      output_concat_dim +=
          input_shapes[i].dims() > 0 ? input_shapes[i].dim_size(axis) : 1;
    }

    TensorShape output_shape(input_shape);
    if (output_shape.dims() == 0) {
      output_shape.AddDim(output_concat_dim);
    } else {
      output_shape.set_dim(axis, output_concat_dim);
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() > 0) {
      int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
      auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
      if (!quantized_input) {
        ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
      } else {
        ConcatCPUImpl<T>(
            c->device(), inputs_flat, sizeof(T) /* cost_per_unit */,
            RequantizeCopier<T>(&input_mins_and_maxes, output_min, output_max),
            &output_flat);
      }
    }

    if (quantized_input) {
      Tensor* output_min_tensor = nullptr;
      OP_REQUIRES_OK(c, c->allocate_output(1, {}, &output_min_tensor));
      output_min_tensor->flat<float>()(0) = output_min;

      Tensor* output_max_tensor = nullptr;
      OP_REQUIRES_OK(c, c->allocate_output(2, {}, &output_max_tensor));
      output_max_tensor->flat<float>()(0) = output_max;
    }
  }
};

// --------------------------------------------------------------------------
// Mkl Concat Op
// --------------------------------------------------------------------------
// This structure aggregates multiple inputs to MklConcat* methods.
struct MklConcatFwdParams {
  std::vector<memory::dims> src_dims;
  memory::dims dst_dims;
  int num_inputs;
  int concat_dims;
  memory::format_tag mkl_common_format;

  MklConcatFwdParams(std::vector<memory::dims>& src_dims_pt,
                     memory::dims dst_dims, int num_inputs, int concat_dims,
                     memory::format_tag mkl_common_format)
      : dst_dims(dst_dims),
        num_inputs(num_inputs),
        concat_dims(concat_dims),
        mkl_common_format(mkl_common_format) {
    for (int k = 0; k < num_inputs; ++k) {
      src_dims.push_back(src_dims_pt[k]);
    }
  }
};

// TODO(intel-tf): The template type "T" is currently used to match the
// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
// In the future, with the removal of "T" from MklPrimitiveFactory, this class
// needs to drop "T".
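
// MklConcatFwdPrimitive wraps a cached oneDNN concat primitive together with
// the memory objects it needs, so the same primitive can be re-executed with
// fresh data handles on each call to Execute().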
template <typename T>
class MklConcatFwdPrimitive : public MklPrimitive {
 public:
  explicit MklConcatFwdPrimitive(const MklConcatFwdParams& concat_fwd_dims,
                                 const std::vector<memory::desc>& srcs_md)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create concat primitive
    Setup(concat_fwd_dims, srcs_md);
  }

  ~MklConcatFwdPrimitive() {}

  // Concat forward execute
  //   src_data: input data buffer of src
  //   dst_data: output data buffer of dst
  void Execute(const std::vector<dnnl::memory>& in_data,
               const dnnl::memory& dst_data,
               const MklConcatFwdParams& concat_fwd_dims,
               std::shared_ptr<stream> fwd_stream) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
    DCHECK_EQ(in_data.size(), context_.data_mem.size());
    for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
#ifndef ENABLE_ONEDNN_OPENMP
      context_.data_mem_shdptr[i]->set_data_handle(
          static_cast<void*>(in_data[i].get_data_handle()), *fwd_stream);
    }
    context_.dst_mem->set_data_handle(
        static_cast<void*>(dst_data.get_data_handle()), *fwd_stream);
#else
      context_.data_mem_shdptr[i]->set_data_handle(
          static_cast<void*>(in_data[i].get_data_handle()));
    }
    context_.dst_mem->set_data_handle(
        static_cast<void*>(dst_data.get_data_handle()));
#endif  // !ENABLE_ONEDNN_OPENMP

    for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
      context_.data_mem[i] = *context_.data_mem_shdptr[i];
    }

    execute_primitives(context_.fwd_primitives, fwd_stream,
                       context_.fwd_primitives_args);

    // After exec, set data handle back
    context_.dst_mem->set_data_handle(DummyData);
    for (int k = 0; k < concat_fwd_dims.num_inputs; k++) {
      context_.data_mem_shdptr[k]->set_data_handle(DummyData);
    }

    for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
      context_.data_mem[i] = *context_.data_mem_shdptr[i];
    }
  }

 private:
  // Primitive reuse context for concat Fwd op
  struct ConcatFwdContext {
    // oneDNN memory
    std::vector<dnnl::memory> data_mem;
    std::vector<std::shared_ptr<dnnl::memory>> data_mem_shdptr;
    std::shared_ptr<dnnl::memory> dst_mem;

    // Memory descriptor
    std::vector<dnnl::memory::desc> src_md;
    std::shared_ptr<dnnl::memory::desc> dst_md;

    // Concat primitive descriptor
    std::shared_ptr<dnnl::concat::primitive_desc> fwd_pd;
    std::shared_ptr<dnnl::primitive> concat_fwd;

    std::vector<dnnl::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> fwd_primitives_args;

    ConcatFwdContext()
        : dst_mem(nullptr), fwd_pd(nullptr), concat_fwd(nullptr) {}
  };

  // Creates the src and dst memory descriptor for mkl concat
  // and also creates the concat primitive and primitive descriptor
  void Setup(const MklConcatFwdParams& concat_fwd_dims,
             const std::vector<memory::desc>& srcs_md) {
    // Create memory descriptors for concat with specified srcs format
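    // The memory objects created here use DummyData placeholders; the real
    // data handles are attached in Execute() just before the cached
    // primitive runs.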
    for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
      dnnl::memory::desc source_md(memory::desc(srcs_md[i].data));
      context_.src_md.push_back(source_md);
      std::shared_ptr<dnnl::memory> src_mem(
          new dnnl::memory(source_md, cpu_engine_, DummyData));
      context_.data_mem_shdptr.push_back(src_mem);
      context_.data_mem.push_back(*context_.data_mem_shdptr[i]);
    }
    // Store the expected memory format
    context_.dst_md.reset(new memory::desc({concat_fwd_dims.dst_dims},
                                           MklDnnType<T>(),
                                           concat_fwd_dims.mkl_common_format));
    // Create a concat primitive descriptor
    context_.fwd_pd.reset(new concat::primitive_desc(
        *context_.dst_md, concat_fwd_dims.concat_dims, context_.src_md,
        cpu_engine_));

    // Create memory primitive based on dummy data
    context_.dst_mem.reset(
        new memory(*context_.dst_md, cpu_engine_, DummyData));

    context_.concat_fwd.reset(new concat(*context_.fwd_pd));
    std::unordered_map<int, memory> net_args = {
        {DNNL_ARG_DST, *context_.dst_mem}};
    for (int i = 0; i < concat_fwd_dims.num_inputs; ++i) {
      net_args.insert({DNNL_ARG_MULTIPLE_SRC + i, context_.data_mem[i]});
    }

    context_.fwd_primitives_args.push_back(net_args);
    context_.fwd_primitives.push_back(*context_.concat_fwd);
  }

  struct ConcatFwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};

// Class to create/cache the mkl concat primitives based on the
// input and output parameters
template <typename T>
class MklConcatFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklConcatFwdPrimitive<T>* Get(
      const MklConcatFwdParams& concat_fwd_dims,
      const std::vector<memory::desc>& srcs_md, bool do_not_cache) {
    MklConcatFwdPrimitive<T>* concat_fwd = nullptr;

    if (do_not_cache) {
      // Always create new primitive
      concat_fwd = new MklConcatFwdPrimitive<T>(concat_fwd_dims, srcs_md);
    } else {
      // Try to find a suitable one in pool
      concat_fwd = dynamic_cast<MklConcatFwdPrimitive<T>*>(
          MklConcatFwdPrimitiveFactory<T>::GetInstance().GetConcatFwd(
              concat_fwd_dims));
      if (concat_fwd == nullptr) {
        concat_fwd = new MklConcatFwdPrimitive<T>(concat_fwd_dims, srcs_md);
        MklConcatFwdPrimitiveFactory<T>::GetInstance().SetConcatFwd(
            concat_fwd_dims, concat_fwd);
      }
    }

    return concat_fwd;
  }

 private:
  MklConcatFwdPrimitiveFactory() {}
  ~MklConcatFwdPrimitiveFactory() {}

  static MklConcatFwdPrimitiveFactory& GetInstance() {
    static MklConcatFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklConcatFwdParams& concat_fwd_dims) {
    string prefix = "concat_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    for (int k = 0; k < concat_fwd_dims.num_inputs; k++) {
      key_creator.AddAsKey(concat_fwd_dims.src_dims[k]);
    }
    key_creator.AddAsKey(concat_fwd_dims.concat_dims);
    return key_creator.GetKey();
  }

  MklPrimitive* GetConcatFwd(const MklConcatFwdParams& concat_fwd_dims) {
    string key = CreateKey(concat_fwd_dims);
    return this->GetOp(key);
  }
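
  // Stores a newly created concat primitive in the cache, keyed by the same
  // dimensions that GetConcatFwd uses for lookup.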
  void SetConcatFwd(const MklConcatFwdParams& concat_fwd_dims,
                    MklPrimitive* op) {
    string key = CreateKey(concat_fwd_dims);
    this->SetOp(key, op);
  }
};

template <typename Device, typename T, AxisArgumentName AxisArgName,
          bool native_format = false>
class MklConcatOp : public OpKernel {
 private:
  TensorFormat data_format_;
  EigenConcatBaseOp<Device, T, AxisArgName> eigen_concat_op_;

 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit MklConcatOp(OpKernelConstruction* c)
      : OpKernel(c),
        data_format_(TensorFormat::FORMAT_NCHW),
        eigen_concat_op_(c) {}

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::kind::cpu, 0);
      OpInputList input_tensors;
      GetMklInputList(context, "values", &input_tensors);
      const int N = input_tensors.size();
      // Get Tensor shapes.
      std::vector<MklDnnShape> mkl_input_shapes(N);
      GetMklShapeList(context, "values", &mkl_input_shapes, native_format);

      const Tensor& concat_dim_tensor = (AxisArgName == NAME_IS_CONCAT_DIM)
                                            ? MklGetInput(context, 0)
                                            : MklGetInput(context, N);
      // Sanity checks
      OP_REQUIRES(
          context, TensorShapeUtils::IsScalar(concat_dim_tensor.shape()),
          errors::InvalidArgument(
              "Concat dim tensor should be a scalar integer, but got shape ",
              concat_dim_tensor.shape().DebugString()));
      int32 concat_dim =
          internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());

      // Check that the ranks of all tensors match and that their shapes match
      // except for concat_dim.
      int i = 0;
      int num_of_empty_inputs = 0;
      bool invoke_eigen = false;
      bool are_all_mkl_inputs = true, are_all_tf_inputs = true;
      const TensorShape expected_shape = mkl_input_shapes[0].IsMklTensor()
                                             ? mkl_input_shapes[0].GetTfShape()
                                             : input_tensors[0].shape();
      size_t expected_dims = expected_shape.dims();

      if (concat_dim < 0) concat_dim = expected_dims + concat_dim;

      for (auto& s : mkl_input_shapes) {
        TensorShape s_shape =
            s.IsMklTensor() ? s.GetTfShape() : input_tensors[i].shape();
        size_t s_dims = s_shape.dims();

        OP_REQUIRES(
            context, s_dims == expected_dims,
            errors::InvalidArgument(
                "_MklConcatOp : Ranks of all input tensors should match:"
                " input dimensions = ",
                s_dims, " vs. expected rank = ", expected_dims));

        for (int d = 0; d < expected_dims; ++d) {
          if (d == concat_dim) continue;

          size_t expected_size = expected_shape.dim_size(d);
          size_t s_size = s_shape.dim_size(d);
          OP_REQUIRES(
              context, expected_size == s_size,
              errors::InvalidArgument("_MklConcatOp : Dimensions of inputs "
                                      "should match: shape[0][",
                                      d, "]= ", expected_size, " vs. shape[", i,
                                      "][", d, "] = ", s_size));
        }

        if (s.IsMklTensor())
          are_all_tf_inputs = false;
        else
          are_all_mkl_inputs = false;

        if (s_dims != 4 && s_dims != 2) invoke_eigen = true;

        if (input_tensors[i].NumElements() == 0) num_of_empty_inputs++;

        ++i;
      }

      if (num_of_empty_inputs == i) invoke_eigen = true;

      // Inputs are not all in one format (TF or MKL). This is the mixed-input
      // case. We could potentially optimize it by converting all TF inputs
      // to Mkl format, but currently we fall back to Eigen for this case.
      // It may be possible to convert the inputs that are in TF format to Mkl
      // format and avoid calling the Eigen version.
      if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;

      // Temporarily call Eigen if the number of input dimensions is 2, due to
      // incorrect output results in the DNNL 1.2 path.
      if (expected_dims == 2) invoke_eigen = true;

      OpInputList input_mins, input_maxes;
      bool quantized_input =
          std::is_same<T, qint8>::value || std::is_same<T, quint8>::value;
      if (quantized_input) {
        // oneDNN concat does not support input tensors that have different
        // ranges. Check if the ranges of all the input tensors are the same.
        // If not, forward to the Eigen implementation.

        OP_REQUIRES_OK(context,
                       context->input_list("input_mins", &input_mins));
        OP_REQUIRES(context, (input_mins.size() == N),
                    errors::InvalidArgument(
                        "QuantizedConcatOp : Expected mins input list length ",
                        input_mins.size(), " to equal values length ", N));

        OP_REQUIRES_OK(context,
                       context->input_list("input_maxes", &input_maxes));
        OP_REQUIRES(context, (input_maxes.size() == N),
                    errors::InvalidArgument(
                        "QuantizedConcatOp : Expected maxes input list length ",
                        input_maxes.size(), " to equal values length ", N));
        float input_min = input_mins[0].flat<float>()(0);
        float input_max = input_maxes[0].flat<float>()(0);
        const float eps = 1.0e-6;
        for (int i = 1; i < N; ++i) {
          float min = input_mins[i].flat<float>()(0);
          float max = input_maxes[i].flat<float>()(0);

          if (fabs(input_min - min) > eps || fabs(input_max - max) > eps) {
            invoke_eigen = true;
            break;
          }
        }
      }

      // Call the Eigen library.
      if (invoke_eigen) {
        CallEigenVersion(context, input_tensors, input_mins, input_maxes,
                         mkl_input_shapes, quantized_input);
        return;
      }

      memory::dims dst_dims;

      if (are_all_mkl_inputs)
        dst_dims = TFShapeToMklDnnDims(mkl_input_shapes[0].GetTfShape());
      else
        // When all the inputs are in Tensorflow format, we don't know what
        // the input data format is. In that case, we just use an output
        // format that is the same as the input format.
        dst_dims = TFShapeToMklDnnDims(input_tensors[0].shape());

      std::vector<memory::desc> srcs_pd;
      std::vector<MklDnnData<T>> srcs(N, MklDnnData<T>(&cpu_engine));
      int64 dst_concat_dim_size = 0;

      bool isMklReorderNeeded = false;
      memory::format_tag mkl_common_format = memory::format_tag::any;
      std::vector<memory> inputs;
      std::vector<memory::dims> src_dims_pt;
      std::vector<dnnl::memory> srcs_mem;
      std::vector<memory::desc> srcs_md;

      if (are_all_mkl_inputs) {
        mkl_common_format =
            FindMklCommonFormat(mkl_input_shapes, concat_dim,
                                &isMklReorderNeeded, &dst_concat_dim_size);

        if (!isMklReorderNeeded) {
          // All MKL tensors have the same format. Reorder is not needed.
          for (int k = 0; k < N; k++) {
            if (input_tensors[k].NumElements() == 0) continue;
            auto src_md = mkl_input_shapes[k].GetMklLayout();
            srcs[k].SetUsrMem(src_md, &input_tensors[k]);
            auto src_mpd = srcs[k].GetUsrMemDesc();
            srcs_pd.push_back(src_mpd);
            inputs.push_back(srcs[k].GetOpMem());
          }
        } else {
          // MKL tensors have different formats.
          // Reorder them to the most common format.
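          // Only the target memory descriptors are recorded in this branch;
          // the actual reorders (and the corresponding entries in `inputs`)
          // are issued further below via CheckReorderToOpMem, after the
          // destination descriptor has been chosen.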
          for (int k = 0; k < N; k++) {
            if (input_tensors[k].NumElements() == 0) continue;
            auto src_md = mkl_input_shapes[k].GetMklLayout();
            srcs[k].SetUsrMem(src_md, &input_tensors[k]);
            auto src_tf_fmt = MklTensorFormatToMklDnnDataFormat(
                mkl_input_shapes[k].GetTfDataFormat());
            if (src_tf_fmt != mkl_common_format) {
              memory::dims src_dims(src_md.data.dims,
                                    &src_md.data.dims[src_md.data.ndims]);
              src_md =
                  memory::desc(src_dims, MklDnnType<T>(), mkl_common_format);
            }
            srcs_pd.push_back(memory::desc(src_md));
          }
        }
      } else {  // All TF inputs
        for (int k = 0; k < N; k++) {
          if (input_tensors[k].NumElements() == 0) continue;
          TensorShape s_shape = input_tensors[k].shape();
          memory::dims src_dims = TFShapeToMklDnnDims(s_shape);
          dst_concat_dim_size += src_dims[concat_dim];
          size_t s_dims = s_shape.dims();

          // It does not matter which data format is used (NHWC versus NCHW).
          // We just need to ensure that the output uses the same data format
          // as the inputs.
          if (s_dims == 4)
            mkl_common_format = memory::format_tag::nchw;
          else if (s_dims == 2)
            mkl_common_format = memory::format_tag::nc;

          auto src_md =
              memory::desc(src_dims, MklDnnType<T>(), mkl_common_format);

          srcs[k].SetUsrMem(src_md, &input_tensors[k]);
          auto src_mpd = srcs[k].GetUsrMemDesc();
          srcs_pd.push_back(src_mpd);
          inputs.push_back(srcs[k].GetOpMem());
          src_dims_pt.push_back(src_dims);
          srcs_md.push_back(src_md);
          srcs_mem.push_back(srcs[k].GetOpMem());
        }
      }
      dst_dims[concat_dim] = dst_concat_dim_size;

      MklDnnData<T> dst(&cpu_engine);
      memory::desc dst_md({}, memory::data_type::undef,
                          memory::format_tag::undef);
      memory::dims dst_dims_in_nchw;
      if (are_all_mkl_inputs) {
        // Since we are passing a specific format for the destination,
        // we need to have dst_dims in MklDnn order (NCHW).
        auto orig_tf_format = mkl_input_shapes[0].GetTfDataFormat();
        if (dst_dims.size() == 4) {
          dst_dims_in_nchw = MklDnnDimsInNCHW(
              dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format));
          // Set the output format to the most common format of the inputs
          // to avoid layout conversions.
          // DNN 1.0: internal format is always blocked;
          //          format_tag does not have a "blocked" field.
          VLOG(1) << "mkl_common_format == memory::format_tag::blocked";
          dst_md = MklDnnData<T>::CreateBlockedMemDesc(
              dst_dims_in_nchw, CalculateTFStrides(dst_dims_in_nchw));
        } else if (dst_dims.size() == 2 &&
                   mkl_common_format == memory::format_tag::nc) {
          // When memory::format_tag::nc, dst_dims are already in oneDNN order.
          dst_md = memory::desc(dst_dims, MklDnnType<T>(), mkl_common_format);
        } else {
          TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
                             "Unsupported tensor dimension or "
                             "oneDNN memory format"));
        }
      } else {
        // All inputs are TF tensors.
        // Set the output format to be the same as the input format (nchw/nc).
        dst_md = memory::desc(dst_dims, MklDnnType<T>(), mkl_common_format);
      }

      if (isMklReorderNeeded) {
        for (int k = 0; k < input_tensors.size(); k++) {
          if (input_tensors[k].NumElements() > 0) {
            srcs[k].CheckReorderToOpMem(srcs_pd[k], cpu_engine, context);
            inputs.push_back(srcs[k].GetOpMem());
          }
        }
      }

      // If all inputs are in MKL format, then the meaning of concat_dim needs
      // to change. The value of concat_dim is tied to the input Tensorflow
      // data format (NHWC or NCHW), while MklDnn dimensions are in NCHW order.
      // So if the Tensorflow tensors are in NCHW order, the concat_dim
      // semantics are preserved, but if the input tensors are in NHWC order,
      // the semantics need to change. E.g., if we are concatenating over
      // Channel (dimension 3 for NHWC), then since the MklDnn order is NCHW,
      // concat_dim needs to be 1.
      if (are_all_mkl_inputs)
        concat_dim = mkl_input_shapes[0].TfDimIdx(concat_dim);

      if (!inputs.empty()) {
        if (are_all_mkl_inputs) {
          auto concat_pd =
              concat::primitive_desc(concat_dim, srcs_pd, cpu_engine);
          auto dst_pd = concat_pd.dst_desc();

          MklDnnShape dnn_shape_dst;
          TensorShape tf_shape_dst;
          Tensor* dst_tensor = nullptr;
          dnn_shape_dst.SetMklTensor(true);
          dnn_shape_dst.SetMklLayout(&dst_pd);
          dnn_shape_dst.SetElemType(MklDnnType<T>());
          dnn_shape_dst.SetTfLayout(dst_dims.size(), dst_dims_in_nchw,
                                    mkl_input_shapes[0].GetTfDataFormat());
          tf_shape_dst.AddDim((dst_pd.get_size() / sizeof(T)));
          AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
                                    dnn_shape_dst);
          DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";

          std::shared_ptr<stream> fwd_cpu_stream;
          MklDnnThreadPool eigen_tp(context);
          fwd_cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));

          if (dnn_shape_dst.IsMklTensor())
            dst_md = dnn_shape_dst.GetMklLayout();
          dst.SetUsrMem(dst_md, dst_tensor);
          dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);

          auto concat_op = concat(concat_pd);
          std::unordered_map<int, memory> net_args = {
              {DNNL_ARG_DST, dst.GetOpMem()}};
          for (int i = 0; i < inputs.size(); ++i) {
            net_args.insert({DNNL_ARG_MULTIPLE_SRC + i, inputs[i]});
          }
          concat_op.execute(*fwd_cpu_stream, net_args);
        } else {
          MklConcatFwdPrimitive<T>* concat_fwd = nullptr;

          MklConcatFwdParams concat_fwd_dims(src_dims_pt, dst_dims,
                                             (N - num_of_empty_inputs),
                                             concat_dim, mkl_common_format);
          // Get a concat fwd from the primitive pool.
          concat_fwd =
              MklConcatFwdPrimitiveFactory<T>::Get(concat_fwd_dims, srcs_md, 0);

          // Allocate output tensor.
          MklDnnShape dnn_shape_dst;
          TensorShape tf_shape_dst;
          Tensor* dst_tensor = nullptr;
          dnn_shape_dst.SetMklTensor(false);
          tf_shape_dst = MklDnnDimsToTFShape(dst_dims);
          AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
                                    dnn_shape_dst, native_format);
          DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";

          dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
                                               : dst_md;
          std::shared_ptr<stream> fwd_cpu_stream;
          MklDnnThreadPool eigen_tp(context);
          fwd_cpu_stream.reset(
              CreateStream(&eigen_tp, concat_fwd->GetEngine()));
          dst.SetUsrMem(dst_md, dst_tensor);
          dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
          // Execute concat.
          concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims,
                              fwd_cpu_stream);
        }

        // For quantized concat, min and max outputs are also computed.
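        // Inputs whose ranges differed by more than eps were already routed
        // to the Eigen path above, so the remaining quantized inputs share a
        // range and the first input's min/max can be forwarded as-is.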
        if (quantized_input) {
          Tensor* output_min = nullptr;
          Tensor* output_max = nullptr;
          MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
          output_min_mkl_shape.SetMklTensor(false);
          output_max_mkl_shape.SetMklTensor(false);
          AllocateOutputSetMklShape(context, 1, &output_min, {},
                                    output_min_mkl_shape, native_format);
          AllocateOutputSetMklShape(context, 2, &output_max, {},
                                    output_max_mkl_shape, native_format);
          // All input tensors should have the same range, so just use the
          // first one.
          output_min->flat<float>()(0) = input_mins[0].flat<float>()(0);
          output_max->flat<float>()(0) = input_maxes[0].flat<float>()(0);
        }
      } else {
        MklDnnShape dnn_shape_dst;
        TensorShape tf_shape_dst;
        Tensor* dst_tensor = nullptr;
        dnn_shape_dst.SetMklTensor(false);
        tf_shape_dst = MklDnnDimsToTFShape(dst_dims);

        AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
                                  dnn_shape_dst, native_format);
        DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
      }
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

  void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
                        const OpInputList& input_mins,
                        const OpInputList& input_maxes,
                        const MklDnnShapeList& mkl_input_shapes,
                        bool quantized_input) {
    size_t num_mkl_input_shapes = mkl_input_shapes.size();
    DCHECK_EQ(values.size(), num_mkl_input_shapes);
    std::vector<Tensor> converted_values(num_mkl_input_shapes);
    TensorShapeList tf_input_shapes;
    for (size_t i = 0; i < num_mkl_input_shapes; ++i) {
      if (mkl_input_shapes[i].IsMklTensor()) {
        // Do conversion from MKL to TF.
        OP_REQUIRES_OK(
            context, ConvertMklToTF<T>(context, values[i], mkl_input_shapes[i],
                                       &converted_values[i]));
        tf_input_shapes.push_back(mkl_input_shapes[i].GetTfShape());
      } else {
        // No conversion needed since it is already a TF tensor.
        converted_values[i] = values[i];
        tf_input_shapes.push_back(values[i].shape());
      }
    }

    // Call Eigen concat.
    eigen_concat_op_.Compute(context, converted_values, tf_input_shapes,
                             input_mins, input_maxes, quantized_input);

    if (!native_format) {
      // Get the number of dims from the first input since all input tensors
      // should have the same rank.
      size_t dims = values[0].shape().dims();
      MklDnnShape output_data_mkl_shape;
      output_data_mkl_shape.SetMklTensor(false);
      output_data_mkl_shape.SetDimensions(dims);
      AllocateOutputSetMklShape(context, 0, output_data_mkl_shape);
      if (quantized_input) {
        MklDnnShape output_min_max_mkl_shape;
        output_min_max_mkl_shape.SetMklTensor(false);
        AllocateOutputSetMklShape(context, 1, output_min_max_mkl_shape);
        AllocateOutputSetMklShape(context, 2, output_min_max_mkl_shape);
      }
    }
  }

  // This method finds the most common format across all MKL inputs.
  // Inputs:
  //   1. input_shapes: shapes of the input (MKL) tensors.
  //   2. concat_dim: concat dimension.
  // Outputs:
  //   1. is_reorder_needed is set to true if the inputs have different
  //      formats; it is set to false otherwise.
  //   2. concat_dim_size is the size of concat_dim.
  // Return:
  //   the common MKL format.
  memory::format_tag FindMklCommonFormat(const MklDnnShapeList& input_shapes,
                                         int concat_dim,
                                         bool* is_reorder_needed,
                                         int64* concat_dim_size) {
    *is_reorder_needed = false;
    *concat_dim_size = 0;
    std::unordered_map<int, int> occurrence_map;
    if (input_shapes.size() == 0) return memory::format_tag::any;

    // Compute the occurrences of each format across all inputs.
    for (int k = 0; k < input_shapes.size(); k++) {
      auto src_dims = TFShapeToMklDnnDims(input_shapes[k].GetTfShape());
      *concat_dim_size += src_dims[concat_dim];
      int fmt = static_cast<int>(
          MklTensorFormatToMklDnnDataFormat(input_shapes[k].GetTfDataFormat()));
      occurrence_map[fmt] += 1;
    }

    if (occurrence_map.size() == 1) {
      // All inputs have the same format;
      // return it with is_reorder_needed left false.
      return static_cast<memory::format_tag>(
          MklTensorFormatToMklDnnDataFormat(input_shapes[0].GetTfDataFormat()));
    }

    // Input tensors have different formats, so a reorder is needed.
    // We pick the most common format to minimize the total number
    // of input reorders.
    memory::format_tag commonest_format = memory::format_tag::any;
    int max_occurrence = 0;
    *is_reorder_needed = true;
    for (auto item : occurrence_map) {
      if (item.second > max_occurrence) {
        commonest_format = static_cast<memory::format_tag>(item.first);
        max_occurrence = item.second;
      }
    }
    return commonest_format;
  }
};

/* Use optimized concat for float type only */
#define REGISTER_MKL_CPU(type)                                 \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklConcat")                                       \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<type>("T")                           \
          .HostMemory("concat_dim")                            \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>);       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklConcatV2")                                     \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<type>("T")                           \
          .TypeConstraint<int32>("Tidx")                       \
          .HostMemory("axis")                                  \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklConcatOp<CPUDevice, type, NAME_IS_AXIS>);

TF_CALL_float(REGISTER_MKL_CPU);
TF_CALL_bfloat16(REGISTER_MKL_CPU);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T")
                            .HostMemory("axis")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS, true>);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("T")
                            .HostMemory("axis")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS, true>);

#define REGISTER_QUANTIZED_CONCATV2(type)                \
  REGISTER_KERNEL_BUILDER(Name("QuantizedConcatV2")      \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("axis"),       \
                          NoOp)

REGISTER_QUANTIZED_CONCATV2(quint8);
REGISTER_QUANTIZED_CONCATV2(qint8);

#undef REGISTER_MKL_CPU
}  // namespace tensorflow

#endif  // INTEL_MKL