xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

// This file uses the oneDNN library to accelerate Batch Matrix-Matrix
// Multiplication (MatMul) operations. We currently register this kernel only
// for oneDNN-supported data types (float, bfloat16). The maximum number of
// dimensions (rank) supported for the output tensor is DNNL_MAX_NDIMS = 12
// in oneDNN. If the output tensor rank exceeds 12, we report an error.
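//
// Two kernel classes are defined below: BatchMatMulMkl covers the plain (V1)
// and broadcasting (V2) variants, and FusedBatchMatMulMkl extends it to
// handle Mul/Add post ops supplied via the "fused_ops" attribute.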

#define EIGEN_USE_THREADS

#if defined(INTEL_MKL)

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/mkl/mkl_batch_matmul_helper.h"
#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// The third parameter v2_bcast is set to true if we are using V2; otherwise
// we set it to false.
template <typename Device, typename Tlhs, typename Trhs, typename Toutput,
          bool v2_bcast>
class BatchMatMulMkl : public OpKernel {
 public:
  explicit BatchMatMulMkl(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  virtual ~BatchMatMulMkl() {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& lhs = ctx->input(0);
    const Tensor& rhs = ctx->input(1);

    if (!v2_bcast) {
      // Using V1, so check to make sure lhs and rhs dimensions are correct and
      // no broadcasting is needed.
      OP_REQUIRES(ctx, lhs.dims() == rhs.dims(),
                  errors::InvalidArgument("lhs and rhs have different ndims: ",
                                          lhs.shape().DebugString(), " vs. ",
                                          rhs.shape().DebugString()));
      const int ndims = lhs.dims();
      OP_REQUIRES(
          ctx, ndims >= 2,
          errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims));
      for (int i = 0; i < ndims - 2; ++i) {
        OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
                    errors::InvalidArgument(
                        "lhs.dim(", i, ") and rhs.dim(", i,
                        ") must be the same: ", lhs.shape().DebugString(),
                        " vs ", rhs.shape().DebugString()));
      }
    } else {
      OP_REQUIRES(
          ctx, lhs.dims() >= 2,
          errors::InvalidArgument("In[0] ndims must be >= 2: ", lhs.dims()));
      OP_REQUIRES(
          ctx, rhs.dims() >= 2,
          errors::InvalidArgument("In[1] ndims must be >= 2: ", rhs.dims()));
    }

    // lhs and rhs can have different dimensions
    const auto ndims_lhs = lhs.dims();
    const auto ndims_rhs = rhs.dims();

    // Get broadcast info
    MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            lhs.shape().DebugString(), " vs. ", rhs.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();

    auto lhs_rows = lhs.dim_size(ndims_lhs - 2);
    auto lhs_cols = lhs.dim_size(ndims_lhs - 1);
    auto rhs_rows = rhs.dim_size(ndims_rhs - 2);
    auto rhs_cols = rhs.dim_size(ndims_rhs - 1);

    if (adj_x_) std::swap(lhs_rows, lhs_cols);
    if (adj_y_) std::swap(rhs_rows, rhs_cols);
    OP_REQUIRES(ctx, lhs_cols == rhs_rows,
                errors::InvalidArgument(
                    "lhs mismatch rhs shape: ", lhs_cols, " vs. ", rhs_rows,
                    ": ", lhs.shape().DebugString(), " ",
                    rhs.shape().DebugString(), " ", adj_x_, " ", adj_y_));

    out_shape.AddDim(lhs_rows);
    out_shape.AddDim(rhs_cols);
    // The maximum number of DNNL tensor dimensions is DNNL_MAX_NDIMS = 12.
    OP_REQUIRES(
        ctx, out_shape.dims() <= DNNL_MAX_NDIMS,
        errors::InvalidArgument(
            "Rank of output tensor must be <= 12, but is ", out_shape.dims(),
            ". Current implementation supports up to rank 12 tensors."));

    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (lhs.NumElements() == 0 || rhs.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Toutput> f;
      f(ctx->eigen_device<Device>(), out->flat<Toutput>());
      return;
    }

    // Compute parameters for DNNL matmul primitive.
    MklBatchMatMulHelper bmm;
    string prefix = "batchmatmul";
    auto params = bmm.CreateMatMulParams(prefix, lhs.shape(), rhs.shape(),
                                         out_shape, adj_x_, adj_y_);

#ifdef DNNL_AARCH64_USE_ACL
    // ACL does not support reuse of primitives with different data.
    // For matmul, the previous approach (PR #47775) of using Tensor addresses
    // does not work, as the addresses are reused in matmul with different
    // data. The counter ensures we still benefit from caching via
    // SetMklMatmul().
    params->aarch64_counter =
        MklMatMulPrimitiveFactory<float, Tlhs, Trhs,
                                  Toutput>::IncrementCounter();
#endif
    this->ExtendMklMatMulParams(ctx, *params);

    // Create or retrieve matmul primitive from cache.
    MklMatMulPrimitive<Tlhs, Trhs, Toutput>* matmul_prim =
        MklMatMulPrimitiveFactory<float, Tlhs, Trhs, Toutput>::Get(
            *params, false /* value for do_not_cache */);

    UserScratchPad<unsigned char> scratch_pad;
    scratch_pad.AllocateSPTensor(matmul_prim, ctx);
    // Execute matmul primitive.
    std::shared_ptr<stream> cpu_stream;
    MklDnnThreadPool eigen_tp(ctx);
    cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine()));
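    // For the fused variant (see FusedBatchMatMulMkl below), the extra
    // operands arrive as additional inputs: input(2) holds the Mul scale and
    // input(3) holds the Add tensor. Their raw buffers are handed to the
    // primitive as post-op data.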
    if (fused_ops_.size() > 0) {
      void* mul_data = nullptr;
      void* add_data = nullptr;
      if (fused_ops_.at(0) == "Mul") {
        const Tensor& mul_tensor = ctx->input(2);
        mul_data = static_cast<void*>(
            const_cast<Toutput*>(mul_tensor.flat<Toutput>().data()));
      }
      if (fused_ops_.size() > 1 && fused_ops_.at(1) == "Add") {
        const Tensor& add_tensor = ctx->input(3);
        add_data = static_cast<void*>(
            const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
      }
      matmul_prim->Execute(cpu_stream, lhs.flat<Tlhs>().data(),
                           rhs.flat<Trhs>().data(), out->flat<Toutput>().data(),
                           scratch_pad.Get(), mul_data, add_data);
    } else {
      matmul_prim->Execute(cpu_stream, lhs.flat<Tlhs>().data(),
                           rhs.flat<Trhs>().data(), out->flat<Toutput>().data(),
                           scratch_pad.Get());
    }
  }

 protected:
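  // Hook that lets derived kernels (e.g. FusedBatchMatMulMkl) append post-op
  // parameters to the matmul primitive; the base kernel adds none.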
  virtual void ExtendMklMatMulParams(OpKernelContext* ctx,
                                     MklMatMulParams& params) {}
  std::vector<string> fused_ops_;

 private:
  bool adj_x_;
  bool adj_y_;
};

template <typename Device, typename Tlhs, typename Trhs, typename Toutput,
          bool v2_bcast>
class FusedBatchMatMulMkl
    : public BatchMatMulMkl<Device, Tlhs, Trhs, Toutput, v2_bcast> {
 public:
  explicit FusedBatchMatMulMkl(OpKernelConstruction* context)
      : BatchMatMulMkl<Device, Tlhs, Trhs, Toutput, v2_bcast>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &this->fused_ops_));
    OP_REQUIRES(context, !this->fused_ops_.empty(),
                errors::InvalidArgument(
                    "Fused BatchMatMul must have at least one fused op."));

    int num_args;
    OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));

    if (this->fused_ops_ == std::vector<string>{"Mul"} ||
        this->fused_ops_ == std::vector<string>{"Mul", "Add"}) {
      OP_REQUIRES(context, num_args == this->fused_ops_.size(),
                  errors::InvalidArgument(
                      "Fused BatchMatMul should have the same number of "
                      "additional inputs as the number of fusions"));
    } else {
      OP_REQUIRES(
          context, false,
          errors::Unimplemented("Fusion is not implemented: [",
                                absl::StrJoin(this->fused_ops_, ","), "]"));
    }
  }

  virtual ~FusedBatchMatMulMkl() {}

 protected:
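  // Translates the validated fused_ops_ list into oneDNN post-op parameters:
  // a "mul" entry for the scalar scale and, if present, an "add" entry for
  // the addend tensor.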
  virtual void ExtendMklMatMulParams(OpKernelContext* ctx,
                                     MklMatMulParams& params) {
    if (this->fused_ops_.size() > 0) {
      const Tensor& scale_tensor = ctx->input(2);
      OP_REQUIRES(ctx, scale_tensor.NumElements() == 1,
                  errors::InvalidArgument("Scale tensor must be a scalar"));

      memory::data_type data_type = MklDnnType<Toutput>();
      memory::format_tag format_tag;
      switch (params.c_dims.size()) {
        case 3:
          format_tag = memory::format_tag::abc;
          break;
        case 4:
          format_tag = memory::format_tag::abcd;
          break;
        default:
          OP_REQUIRES(ctx, false,
                      errors::Unimplemented("Output tensor rank ",
                                            params.c_dims.size(),
                                            " is not supported for fusion"));
      }
      if (this->fused_ops_.at(0) == "Mul") {
        memory::dims mul_dims(params.c_dims.size(), 1);
        params.post_op_params.push_back(
            {"mul", {}, mul_dims, data_type, format_tag});
      } else {
        OP_REQUIRES(ctx, false,
                    errors::InvalidArgument(
                        "Currently first fusion is supported only for Mul",
                        ", but it is ", this->fused_ops_.at(0), " op."));
      }
      // Only check for an "Add" fusion when a second fused op is present;
      // otherwise fused_ops_.at(1) would be out of range.
      if (this->fused_ops_.size() > 1) {
        if (this->fused_ops_.at(1) == "Add") {
          // The Add fusion expects a 4D addend tensor as the second extra
          // input.
          auto add_shape = ctx->input(3).shape();
          memory::dims add_dims = {
              add_shape.dim_size(0), add_shape.dim_size(1),
              add_shape.dim_size(2), add_shape.dim_size(3)};
          params.post_op_params.push_back(
              {"add", {}, add_dims, data_type, format_tag});
        } else {
          OP_REQUIRES(ctx, false,
                      errors::InvalidArgument(
                          "Currently second fusion is supported only for Add",
                          ", but it is ", this->fused_ops_.at(1), " op."));
        }
      }
    }
  }
};

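// Registration macros: the plain and broadcasting kernels are registered
// under the MKL name-change label, while the fused kernel is registered
// without a label. All are instantiated for float and bfloat16 below.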
#define REGISTER_BATCH_MATMUL_MKL(TYPE)                                       \
  REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMul")                             \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<TYPE>("T")                      \
                              .Label(mkl_op_registry::kMklNameChangeOpLabel), \
                          BatchMatMulMkl<CPUDevice, TYPE, TYPE, TYPE, false>)

#define REGISTER_BATCH_MATMUL_MKL_V2(TYPE)                                    \
  REGISTER_KERNEL_BUILDER(Name("_MklBatchMatMulV2")                           \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<TYPE>("T")                      \
                              .Label(mkl_op_registry::kMklNameChangeOpLabel), \
                          BatchMatMulMkl<CPUDevice, TYPE, TYPE, TYPE, true>)

#define REGISTER_FUSED_BATCH_MATMUL_MKL(TYPE) \
  REGISTER_KERNEL_BUILDER(                    \
      Name("_MklFusedBatchMatMulV2")          \
          .Device(DEVICE_CPU)                 \
          .TypeConstraint<TYPE>("T"),         \
      FusedBatchMatMulMkl<CPUDevice, TYPE, TYPE, TYPE, true>)

#ifdef INTEL_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2);
TF_CALL_float(REGISTER_FUSED_BATCH_MATMUL_MKL);
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2);
TF_CALL_bfloat16(REGISTER_FUSED_BATCH_MATMUL_MKL);
#endif  // INTEL_MKL

}  // end namespace tensorflow
#endif  // INTEL_MKL