/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring> // For std::memcpy.

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h>
#include <executorch/kernels/optimized/utils/llvmMathExtras.h>
#include <executorch/kernels/portable/cpu/util/reduce_util.h> // For apply_over_dim.
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

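// Out variant of the fast Hadamard transform: copies `mat` into `out`, then
// transforms each row of the last dimension in place. The last dimension must
// have size 2^k or 28 * 2^k (the 28N variant used by SpinQuant).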
Tensor& fast_hadamard_transform_out(
    RuntimeContext& ctx,
    const Tensor& mat,
    Tensor& out) {
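  // Resize the output to match the input's shape.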
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(out, mat.sizes()) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

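  // Input and output must have the same dtype.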
  ET_KERNEL_CHECK(
      ctx, mat.scalar_type() == out.scalar_type(), InvalidArgument, out);

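  // Nothing to do for zero-dimensional or empty inputs.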
  if (mat.dim() == 0 || mat.numel() == 0) {
    return out;
  }

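  // Both tensors must use the default contiguous dim order so that rows of
  // the last dimension are dense in memory.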
  ET_KERNEL_CHECK(
      ctx,
      is_contiguous_dim_order(mat.dim_order().data(), mat.dim()),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      is_contiguous_dim_order(out.dim_order().data(), out.dim()),
      InvalidArgument,
      out);

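  // The transform walks each row with unit stride, so the last dimension
  // must be contiguous.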
  ET_KERNEL_CHECK_MSG(
      ctx,
      mat.strides().back() == 1,
      InvalidArgument,
      out,
      "input matrix that is not contiguous in the last dimension is not supported!");

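  // The kernel supports last-dimension sizes of 2^k and 28 * 2^k. Factor out
  // the 28 (if present) and verify the remainder is a power of two via the
  // standard (x & (x - 1)) == 0 test; x is at least 1 here since the input
  // is non-empty.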
  const auto last_dim_size = mat.sizes().back();
  const auto divisible_by_28 = last_dim_size % 28 == 0;
  auto power_of_two_size = divisible_by_28 ? last_dim_size / 28 : last_dim_size;
  ET_KERNEL_CHECK_MSG(
      ctx,
      (power_of_two_size & (power_of_two_size - 1)) == 0,
      InvalidArgument,
      out,
      "This implementation requires a power-of-2 (or power-of-2 * 28) input size in the last dimension!");

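  // For a power of two, the number of trailing zero bits equals log2.
  // ZB_Undefined is safe because power_of_two_size is nonzero here.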
  const auto log2_power_of_two_size = executorch::llvm::countTrailingZeros(
      static_cast<unsigned int>(power_of_two_size),
      executorch::llvm::ZeroBehavior::ZB_Undefined);

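  // Dispatch on the floating-point dtype (including Half): copy the input
  // into the output, then transform each last-dimension row in place.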
  ET_SWITCH_FLOATH_TYPES(mat.scalar_type(), ctx, __func__, CTYPE, [&] {
    const CTYPE* const mat_data = mat.const_data_ptr<CTYPE>();
    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

    std::memcpy(out_data, mat_data, mat.numel() * sizeof(CTYPE));

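    // apply_over_dim invokes the lambda once per slice along the last
    // dimension; `base` is the flat offset of that row's first element.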
    if (divisible_by_28) {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform_28N(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    } else {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    }
  });
  return out;
}
} // namespace native
} // namespace executor
} // namespace torch

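// Registers this kernel with the ExecuTorch runtime under the "llama"
// namespace as the out variant "fast_hadamard_transform.out".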
EXECUTORCH_LIBRARY(
    llama,
    "fast_hadamard_transform.out",
    torch::executor::native::fast_hadamard_transform_out);