/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h>
#include <executorch/kernels/optimized/utils/llvmMathExtras.h>
#include <executorch/kernels/portable/cpu/util/reduce_util.h> // For apply_over_dim.
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

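// Out-variant fast Hadamard transform kernel: applies the transform along the
// last dimension of mat and writes the result to out. As the checks below
// enforce, the last dimension must have size 2^k or 28 * 2^k, both tensors
// must be contiguous, and their dtypes must match.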
Tensor& fast_hadamard_transform_out(
    RuntimeContext& ctx,
    const Tensor& mat,
    Tensor& out) {
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(out, mat.sizes()) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  ET_KERNEL_CHECK(
      ctx, mat.scalar_type() == out.scalar_type(), InvalidArgument, out);

  if (mat.dim() == 0 || mat.numel() == 0) {
    return out;
  }

  ET_KERNEL_CHECK(
      ctx,
      is_contiguous_dim_order(mat.dim_order().data(), mat.dim()),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      is_contiguous_dim_order(out.dim_order().data(), out.dim()),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK_MSG(
      ctx,
      mat.strides().back() == 1,
      InvalidArgument,
      out,
      "Input matrix that is not contiguous in the last dimension is not supported!");

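  // The last-dimension size must decompose as 2^k or 28 * 2^k. The 28 * 2^k
  // case covers model dimensions that are not powers of two (e.g.,
  // 14336 = 28 * 512); fast_hadamard_transform_28N handles the extra factor
  // of 28 with a 28x28 Hadamard matrix.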
  const auto last_dim_size = mat.sizes().back();
  const auto divisible_by_28 = last_dim_size % 28 == 0;
  auto power_of_two_size = divisible_by_28 ? last_dim_size / 28 : last_dim_size;
  ET_KERNEL_CHECK_MSG(
      ctx,
      (power_of_two_size & (power_of_two_size - 1)) == 0,
      InvalidArgument,
      out,
      "This implementation requires power-of-2 (or power-of-2 * 28) input size in the last dimension!");

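  // power_of_two_size is verified above to be a power of two, so counting its
  // trailing zeros yields exactly log2(power_of_two_size). ZB_Undefined is
  // safe here: the empty-tensor early return above guarantees the value is
  // nonzero.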
  const auto log2_power_of_two_size = executorch::llvm::countTrailingZeros(
      static_cast<unsigned int>(power_of_two_size),
      executorch::llvm::ZeroBehavior::ZB_Undefined);

  ET_SWITCH_FLOATH_TYPES(mat.scalar_type(), ctx, __func__, CTYPE, [&] {
    const CTYPE* const mat_data = mat.const_data_ptr<CTYPE>();
    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

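    // Copy the input into the output buffer; the transform below then runs in
    // place on out_data, leaving mat untouched.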
    std::memcpy(out_data, mat_data, mat.numel() * sizeof(CTYPE));

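    // apply_over_dim invokes the callback once per slice along the last
    // dimension; base is the starting offset of each row, and the stride-1
    // check above guarantees every row is contiguous. The size and stride
    // arguments are unused here.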
    if (divisible_by_28) {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform_28N(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    } else {
      apply_over_dim(
          [log2_power_of_two_size, out_data](
              const size_t size, const size_t stride, const size_t base) {
            executorch::fast_hadamard_transform(
                out_data + base, log2_power_of_two_size);
          },
          out,
          out.dim() - 1);
    }
  });
  return out;
}
} // namespace native
} // namespace executor
} // namespace torch

EXECUTORCH_LIBRARY(
    llama,
    "fast_hadamard_transform.out",
    torch::executor::native::fast_hadamard_transform_out);
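// Usage sketch (a hedged illustration, not part of the kernel): once this file
// is linked into an ExecuTorch runtime, a program that calls the
// llama::fast_hadamard_transform.out op dispatches to the function above.
// Ignoring any normalization the optimized kernels may apply, the transform of
// a row v of length n = 2^k follows the standard butterfly recurrence:
//
//   for (int step = 1; step < n; step *= 2) {
//     for (int i = 0; i < n; i += 2 * step) {
//       for (int j = i; j < i + step; ++j) {
//         float x = v[j], y = v[j + step];
//         v[j] = x + y;        // sum lane
//         v[j + step] = x - y; // difference lane
//       }
//     }
//   }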