1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 // (c) Meta Platforms, Inc. and affiliates.
10 #pragma once
11
#include <cassert>
#include <cmath>
#include <cstdint>
#include <memory>
#include <type_traits>

#include <executorch/extension/llm/custom_ops/spinquant/third-party/FFHT/fht.h>

#include "fast_hadamard_transform_special.h"
20
21 namespace executorch {
22 namespace internal {
23
// Returns the square root of (1 << log2_n), converted to T.
template <typename T>
T fast_sqrt_of_power_of_2(int log2_n) {
  // sqrt(2**N) is 2**(N/2). Split the exponent into its integral half
  // and a possible leftover of 1/2:
  //
  //   N even: 2**(N/2)                      -> a plain left shift
  //   N odd:  2**(floor(N/2)) * 2**(1/2)
  //         = 2**(floor(N/2)) * sqrt(2)     -> shift, then one multiply
  const T power_of_two = T(1 << (log2_n / 2));
  const bool has_half_exponent = (log2_n % 2) != 0;
  const T correction = has_half_exponent ? T(std::sqrt(2)) : T(1);
  return power_of_two * correction;
}
36
37 template <typename T>
normalize_after_fht(T * out,int log2_vec_size)38 void normalize_after_fht(T* out, int log2_vec_size) {
39 const T inv_sqrt = T(1) / fast_sqrt_of_power_of_2<T>(log2_vec_size);
40 const int vec_size = 1 << log2_vec_size;
41 for (int ii = 0; ii < vec_size; ++ii) {
42 out[ii] *= inv_sqrt;
43 }
44 }
45
// Portable in-place fast Walsh-Hadamard transform of the
// (1 << log2_vec_size) elements of vec, WITHOUT the final
// normalization step.
//
// The transform runs log2_vec_size butterfly passes; each pass
// replaces pairs (x, y) that are `step` apart with (x + y, x - y), so
// magnitudes can grow by up to a factor of (1 << log2_vec_size).
//
// A log2_vec_size <= 0 is a no-op: a length-1 transform is the
// identity, and negative sizes are rejected here (matching
// fast_hadamard_transform_ffht_impl) instead of reaching a shift by a
// negative count, which is undefined behavior.
template <typename T>
void fast_hadamard_transform_unnormalized_simple_impl(
    T* vec,
    int log2_vec_size) {
  // NOTE: If you're here because you're profiling a model and this is
  // slow, consider updating FFHT to generate efficient assembly for
  // your data type!
  if (log2_vec_size <= 0) {
    return;
  }

  const int vec_size = 1 << log2_vec_size;
  // `step` is the distance between the two elements of each butterfly;
  // it doubles every pass until it spans the whole vector.
  for (int step = 1; step < vec_size; step *= 2) {
    // Each block of 2 * step elements is processed independently.
    for (int ii = 0; ii < vec_size; ii += step * 2) {
      for (int jj = ii; jj < ii + step; ++jj) {
        // Read both inputs before writing either output.
        const auto x = vec[jj];
        const auto y = vec[jj + step];
        vec[jj] = x + y;
        vec[jj + step] = x - y;
      }
    }
  }
}
71
// Portable in-place fast Walsh-Hadamard transform of the
// (1 << log2_vec_size) elements of vec, including the final
// normalization step (see normalize_after_fht).
//
// A log2_vec_size <= 0 leaves vec untouched, exactly like the FFHT
// path in fast_hadamard_transform_ffht_impl: a length-1 transform is
// the identity, and a negative size must not reach the
// (1 << log2_vec_size) computations below, where shifting by a
// negative count is undefined behavior.
template <typename T>
void fast_hadamard_transform_simple_impl(T* vec, int log2_vec_size) {
  if (log2_vec_size <= 0) {
    return;
  }

  fast_hadamard_transform_unnormalized_simple_impl(vec, log2_vec_size);
  normalize_after_fht(vec, log2_vec_size);
}
77
fast_hadamard_transform_ffht_impl(float * vec,int log2_vec_size)78 inline void fast_hadamard_transform_ffht_impl(float* vec, int log2_vec_size) {
79 #if defined(__aarch64__) || defined(__x86_64__)
80 if (log2_vec_size <= 0) {
81 return;
82 }
83
84 fht_float(vec, log2_vec_size);
85 normalize_after_fht(vec, log2_vec_size);
86 #else
87 fast_hadamard_transform_simple_impl(vec, log2_vec_size);
88 #endif
89 }
90
91 } // namespace internal
92
93 // Compute the fast Walsh-Hadamard transform
94 // (https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform)
95 // of vec, which must be of length (1 << log2_vec_size).
96 template <typename T>
fast_hadamard_transform(T * vec,int log2_vec_size)97 void fast_hadamard_transform(T* vec, int log2_vec_size) {
98 if constexpr (std::is_same_v<T, float>) {
99 internal::fast_hadamard_transform_ffht_impl(vec, log2_vec_size);
100 } else {
101 internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size);
102 }
103 }
104
// Compute a quantized fast Walsh-Hadamard transform of vec, which
// must be of length (1 << log2_vec_size) and symmetrically quantized.
//
// The quantization scale is not a parameter because it is not needed:
// the Fast Hadamard transform is a series of additions and
// subtractions with a final multiplication step, and both kinds of
// operation commute with a common scale factor, as the following
// trivial identities show:
//
// scale * a + scale * b = scale * (a + b) (addition doesn't need the scale)
// alpha * (scale * a) = scale * (alpha * a) (multiplication doesn't need the
// scale)
//
// Only the declaration lives in this header; the definition is not
// visible here.
void fast_hadamard_transform_symmetric_quantized_s16(
    int16_t* vec,
    int log2_vec_size);
119
// Variant of fast_hadamard_transform for vectors of length
// 28 * (1 << log2_vec_size). vec is treated as a row-major
// (28, 1 << log2_vec_size) matrix and transformed in place in two
// phases:
//
//   1. each of the 28 rows gets its own fast_hadamard_transform;
//   2. each of the (1 << log2_vec_size) columns is multiplied by a
//      particular 28x28 Hadamard matrix (see
//      special_hadamard_code_gen.py for the exact matrix).
template <typename T>
void fast_hadamard_transform_28N(T* vec, int log2_vec_size) {
  const int row_length = 1 << log2_vec_size;

  // Phase 1: walk the rows, which are row_length elements apart.
  T* row = vec;
  for (int row_idx = 0; row_idx < 28; ++row_idx, row += row_length) {
    fast_hadamard_transform(row, log2_vec_size);
  }

  // Phase 2: a column starts at vec + col and its consecutive
  // entries are row_length elements apart.
  for (int col = 0; col < row_length; ++col) {
    hadamard_mult_28_strided(vec + col, row_length);
  }
}
136
// Judging by its name, the symmetrically quantized int16_t counterpart
// of fast_hadamard_transform_28N; presumably vec must hold
// 28 * (1 << log2_vec_size) elements -- confirm against the
// definition, which is not visible in this header.
//
// We don't need the quantization scale; see the function-level
// comment on fast_hadamard_transform_symmetric_quantized_s16 for
// details.
void fast_hadamard_transform_symmetric_quantized_s16_28N(
    int16_t* vec,
    int log2_vec_size);
143
144 } // namespace executorch
145