1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/math_ops.cc. 17 18 // This file uses oneDNN library for acceleration of Batch Matrix-Matrix 19 // Multiplication (MatMul) operations. We currently register this kernel only 20 // for oneDNN supported data types (float, bfloat16). The maximum number of 21 // dimensions (rank) for output tensor is DNNL_MAX_NDIMS = 12 in oneDNN. 22 // If output tensor rank exceeds 12, we exit with reporting an error message. 23 24 #define EIGEN_USE_THREADS 25 26 #if defined(INTEL_MKL) 27 28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 29 #include "tensorflow/core/framework/register_types.h" 30 #include "tensorflow/core/framework/tensor.h" 31 #include "tensorflow/core/framework/tensor_shape.h" 32 #include "tensorflow/core/framework/type_traits.h" 33 #include "tensorflow/core/framework/types.h" 34 #include "tensorflow/core/kernels/fill_functor.h" 35 #include "tensorflow/core/kernels/matmul_op_impl.h" 36 #include "tensorflow/core/kernels/mkl/mkl_batch_matmul_helper.h" 37 #include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h" 38 #include "tensorflow/core/platform/types.h" 39 #include "tensorflow/core/util/matmul_bcast.h" 40 41 namespace tensorflow { 42 43 typedef Eigen::ThreadPoolDevice CPUDevice; 44 45 // The third parameter v2_bcast is set to true if we are using V2 otherwise 46 // we set it to false. 47 template <typename Device, typename Tlhs, typename Trhs, typename Toutput, 48 bool v2_bcast> 49 class BatchMatMulMkl : public OpKernel { 50 public: BatchMatMulMkl(OpKernelConstruction * context)51 explicit BatchMatMulMkl(OpKernelConstruction* context) : OpKernel(context) { 52 OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_)); 53 OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_)); 54 } 55 ~BatchMatMulMkl()56 virtual ~BatchMatMulMkl() {} 57 Compute(OpKernelContext * ctx)58 void Compute(OpKernelContext* ctx) override { 59 const Tensor& lhs = ctx->input(0); 60 const Tensor& rhs = ctx->input(1); 61 62 if (!v2_bcast) { 63 // Using V1, so check to make sure lhs and rhs dimensions are correct and 64 // no broadcasting is needed. 65 OP_REQUIRES(ctx, lhs.dims() == rhs.dims(), 66 errors::InvalidArgument("lhs and rhs has different ndims: ", 67 lhs.shape().DebugString(), " vs. ", 68 rhs.shape().DebugString())); 69 const int ndims = lhs.dims(); 70 OP_REQUIRES( 71 ctx, ndims >= 2, 72 errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims)); 73 for (int i = 0; i < ndims - 2; ++i) { 74 OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i), 75 errors::InvalidArgument( 76 "lhs.dim(", i, ") and rhs.dim(", i, 77 ") must be the same: ", lhs.shape().DebugString(), 78 " vs ", rhs.shape().DebugString())); 79 } 80 } else { 81 OP_REQUIRES( 82 ctx, lhs.dims() >= 2, 83 errors::InvalidArgument("In[0] ndims must be >= 2: ", lhs.dims())); 84 OP_REQUIRES( 85 ctx, rhs.dims() >= 2, 86 errors::InvalidArgument("In[1] ndims must be >= 2: ", rhs.dims())); 87 } 88 89 // lhs and rhs can have different dimensions 90 const auto ndims_lhs = lhs.dims(); 91 const auto ndims_rhs = rhs.dims(); 92 93 // Get broadcast info 94 MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes()); 95 OP_REQUIRES( 96 ctx, bcast.IsValid(), 97 errors::InvalidArgument( 98 "In[0] and In[1] must have compatible batch dimensions: ", 99 lhs.shape().DebugString(), " vs. ", rhs.shape().DebugString())); 100 101 TensorShape out_shape = bcast.output_batch_shape(); 102 103 auto lhs_rows = lhs.dim_size(ndims_lhs - 2); 104 auto lhs_cols = lhs.dim_size(ndims_lhs - 1); 105 auto rhs_rows = rhs.dim_size(ndims_rhs - 2); 106 auto rhs_cols = rhs.dim_size(ndims_rhs - 1); 107 108 if (adj_x_) std::swap(lhs_rows, lhs_cols); 109 if (adj_y_) std::swap(rhs_rows, rhs_cols); 110 OP_REQUIRES(ctx, lhs_cols == rhs_rows, 111 errors::InvalidArgument( 112 "lhs mismatch rhs shape: ", lhs_cols, " vs. ", rhs_rows, 113 ": ", lhs.shape().DebugString(), " ", 114 rhs.shape().DebugString(), " ", adj_x_, " ", adj_y_)); 115 116 out_shape.AddDim(lhs_rows); 117 out_shape.AddDim(rhs_cols); 118 // The maximum number of DNNL tensor dimensions is DNNL_MAX_NDIMS = 12. 119 OP_REQUIRES( 120 ctx, out_shape.dims() <= DNNL_MAX_NDIMS, 121 errors::InvalidArgument( 122 "Rank of output tensor must be <= 12, but is ", out_shape.dims(), 123 ". Current implementation supports upto rank 12 tensors.")); 124 125 Tensor* out = nullptr; 126 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); 127 if (out->NumElements() == 0) { 128 return; 129 } 130 if (lhs.NumElements() == 0 || rhs.NumElements() == 0) { 131 functor::SetZeroFunctor<Device, Toutput> f; 132 f(ctx->eigen_device<Device>(), out->flat<Toutput>()); 133 return; 134 } 135 136 // Compute parameters for DNNL matmul primitive. 137 MklBatchMatMulHelper bmm; 138 string prefix = "batchmatmul"; 139 auto params = bmm.CreateMatMulParams(prefix, lhs.shape(), rhs.shape(), 140 out_shape, adj_x_, adj_y_); 141 142 #ifdef DNNL_AARCH64_USE_ACL 143 // ACL does not support reuse of primitives with different data. 144 // For matmul, the previous approach (PR #47775) of using Tensor addresses 145 // does not work, as the addresses are re-used in matmul with different data 146 // The counter ensure we still benefit from caching via SetMklMatmul(). 147 params->aarch64_counter = 148 MklMatMulPrimitiveFactory<float, Tlhs, Trhs, 149 Toutput>::IncrementCounter(); 150 #endif 151 this->ExtendMklMatMulParams(ctx, *params); 152 153 // Create or retrieve matmul primitive from cache. 154 MklMatMulPrimitive<Tlhs, Trhs, Toutput>* matmul_prim = 155 MklMatMulPrimitiveFactory<float, Tlhs, Trhs, Toutput>::Get( 156 *params, false /* value for do_not_cache */); 157 158 UserScratchPad<unsigned char> scratch_pad; 159 scratch_pad.AllocateSPTensor(matmul_prim, ctx); 160 // Execute matmul primitive. 161 std::shared_ptr<stream> cpu_stream; 162 MklDnnThreadPool eigen_tp(ctx); 163 cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine())); 164 if (fused_ops_.size() > 0) { 165 void* mul_data = nullptr; 166 void* add_data = nullptr; 167 if (fused_ops_.at(0) == "Mul") { 168 const Tensor& mul_tensor = ctx->input(2); 169 mul_data = static_cast<void*>( 170 const_cast<Toutput*>(mul_tensor.flat<Toutput>().data())); 171 } 172 if (fused_ops_.size() > 1 && fused_ops_.at(1) == "Add") { 173 const Tensor& add_tensor = ctx->input(3); 174 add_data = static_cast<void*>( 175 const_cast<Toutput*>(add_tensor.flat<Toutput>().data())); 176 } 177 matmul_prim->Execute(cpu_stream, lhs.flat<Tlhs>().data(), 178 rhs.flat<Trhs>().data(), out->flat<Toutput>().data(), 179 scratch_pad.Get(), mul_data, add_data); 180 } else { 181 matmul_prim->Execute(cpu_stream, lhs.flat<Tlhs>().data(), 182 rhs.flat<Trhs>().data(), out->flat<Toutput>().data(), 183 scratch_pad.Get()); 184 } 185 } 186 187 protected: ExtendMklMatMulParams(OpKernelContext * ctx,MklMatMulParams & params)188 virtual void ExtendMklMatMulParams(OpKernelContext* ctx, 189 MklMatMulParams& params) {} 190 std::vector<string> fused_ops_; 191 192 private: 193 bool adj_x_; 194 bool adj_y_; 195 }; 196 197 template <typename Device, typename Tlhs, typename Trhs, typename Toutput, 198 bool v2_bcast> 199 class FusedBatchMatMulMkl 200 : public BatchMatMulMkl<Device, Tlhs, Trhs, Toutput, v2_bcast> { 201 public: FusedBatchMatMulMkl(OpKernelConstruction * context)202 explicit FusedBatchMatMulMkl(OpKernelConstruction* context) 203 : BatchMatMulMkl<Device, Tlhs, Trhs, Toutput, v2_bcast>(context) { 204 OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &this->fused_ops_)); 205 OP_REQUIRES(context, !this->fused_ops_.empty(), 206 errors::InvalidArgument( 207 "Fused BatchMatMul must have at least one fused op.")); 208 209 int num_args; 210 OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); 211 212 if (this->fused_ops_ == std::vector<string>{"Mul"} || 213 this->fused_ops_ == std::vector<string>{"Mul", "Add"}) { 214 OP_REQUIRES(context, num_args == this->fused_ops_.size(), 215 errors::InvalidArgument( 216 "Fused BatchMatmul should have same number of additional " 217 "inputs as the number of fusions")); 218 } else { 219 OP_REQUIRES( 220 context, false, 221 errors::Unimplemented("Fusion is not implemented: [", 222 absl::StrJoin(this->fused_ops_, ","), "]")); 223 } 224 } 225 ~FusedBatchMatMulMkl()226 virtual ~FusedBatchMatMulMkl() {} 227 228 protected: ExtendMklMatMulParams(OpKernelContext * ctx,MklMatMulParams & params)229 virtual void ExtendMklMatMulParams(OpKernelContext* ctx, 230 MklMatMulParams& params) { 231 if (this->fused_ops_.size() > 0) { 232 const Tensor& scale_tensor = ctx->input(2); 233 OP_REQUIRES(ctx, scale_tensor.NumElements() == 1, 234 errors::InvalidArgument("Scale tensor must be a scalar")); 235 236 memory::data_type data_type = MklDnnType<Toutput>(); 237 memory::format_tag format_tag; 238 switch (params.c_dims.size()) { 239 case 3: 240 format_tag = memory::format_tag::abc; 241 break; 242 case 4: 243 format_tag = memory::format_tag::abcd; 244 break; 245 default: 246 OP_REQUIRES(ctx, false, errors::Unimplemented("Unimplemented")); 247 } 248 if (this->fused_ops_.at(0) == "Mul") { 249 memory::dims mul_dims(params.c_dims.size(), 1); 250 params.post_op_params.push_back( 251 {"mul", {}, mul_dims, data_type, format_tag}); 252 } else { 253 OP_REQUIRES(ctx, false, 254 errors::InvalidArgument( 255 "Currently first fusion is supported only for Mul", 256 ", but it is ", this->fused_ops_.at(0), " op.")); 257 } 258 if (this->fused_ops_.size() > 1 && this->fused_ops_.at(1) == "Add") { 259 auto add_shape = ctx->input(3).shape(); 260 memory::dims add_dims = {add_shape.dim_size(0), add_shape.dim_size(1), 261 add_shape.dim_size(2), add_shape.dim_size(3)}; 262 params.post_op_params.push_back( 263 {"add", {}, add_dims, data_type, format_tag}); 264 } else { 265 OP_REQUIRES(ctx, false, 266 errors::InvalidArgument( 267 "Currently second fusion is supported only for Add", 268 ", but it is ", this->fused_ops_.at(1), " op.")); 269 } 270 } 271 } 272 }; 273 274 #define REGISTER_BATCH_MATMUL_MKL(TYPE) \ 275 REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMul") \ 276 .Device(DEVICE_CPU) \ 277 .TypeConstraint<TYPE>("T") \ 278 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 279 BatchMatMulMkl<CPUDevice, TYPE, TYPE, TYPE, false>) 280 281 #define REGISTER_BATCH_MATMUL_MKL_V2(TYPE) \ 282 REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMulV2") \ 283 .Device(DEVICE_CPU) \ 284 .TypeConstraint<TYPE>("T") \ 285 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 286 BatchMatMulMkl<CPUDevice, TYPE, TYPE, TYPE, true>) 287 288 #define REGISTER_FUSED_BATCH_MATMUL_MKL(TYPE) \ 289 REGISTER_KERNEL_BUILDER( \ 290 Name("_MklFusedBatchMatMulV2") \ 291 .Device(DEVICE_CPU) \ 292 .TypeConstraint<TYPE>("T"), \ 293 FusedBatchMatMulMkl<CPUDevice, TYPE, TYPE, TYPE, true>) 294 295 #ifdef INTEL_MKL 296 TF_CALL_float(REGISTER_BATCH_MATMUL_MKL); 297 TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2); 298 TF_CALL_float(REGISTER_FUSED_BATCH_MATMUL_MKL); 299 TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL); 300 TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2); 301 TF_CALL_bfloat16(REGISTER_FUSED_BATCH_MATMUL_MKL); 302 #endif // INTEL_MKL 303 304 } // end namespace tensorflow 305 #endif 306