/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/blas/CPUBlas.h>
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <array>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

Tensor& opt_mm_out(
    RuntimeContext& ctx,
    const Tensor& in,
    const Tensor& mat2,
    Tensor& out) {
  ET_KERNEL_CHECK(ctx, check_mm_args(in, mat2, out), InvalidArgument, out);

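  // Compute the expected (in.size(0), mat2.size(1)) output shape and resize
  // out to match it before writing any data.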
  size_t output_ndim = 0;
  std::array<exec_aten::SizesType, kTensorDimensionLimit> output_sizes;
  get_mm_out_target_size(in, mat2, output_sizes.data(), &output_ndim);
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {output_sizes.data(), output_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  if (out.numel() == 0) {
    return out;
  }
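  // Dispatch over all real dtypes plus Half and BFloat16; CTYPE is bound to
  // the matching C++ scalar type inside the lambda.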
  ET_SWITCH_REAL_TYPES_AND2(
      Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
        size_t n = in.size(0);
        size_t k = in.size(1);
        size_t m = mat2.size(1);
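        // Row-major shapes: in is (n, k), mat2 is (k, m), out is (n, m).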

        // gemm expects column-major inputs and produces column-major
        // output. So, we take advantage of the identity (A @ B).t()
        // = B.t() @ A.t() here; row-major B is B.t() from gemm's
        // column-major perspective, etc.
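        // Concretely: gemm computes the column-major (m x n) matrix
        // out.t() = mat2.t() (m x k) @ in.t() (k x n), whose memory layout
        // is exactly row-major out (n x m). The leading dimensions are the
        // row-major row strides: m for mat2 and out, k for in.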
        executorch::cpublas::gemm(
            executorch::cpublas::TransposeType::NoTranspose,
            executorch::cpublas::TransposeType::NoTranspose,
            m,
            n,
            k,
            static_cast<CTYPE>(1),
            mat2.const_data_ptr<CTYPE>(),
            m,
            in.const_data_ptr<CTYPE>(),
            k,
            static_cast<CTYPE>(0),
            out.mutable_data_ptr<CTYPE>(),
            m);
      });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch