/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// (c) Meta Platforms, Inc. and affiliates.
#pragma once

#include <cassert>
#include <cmath>
#include <cstdint>
#include <memory>
#include <type_traits>

#include <executorch/extension/llm/custom_ops/spinquant/third-party/FFHT/fht.h>

#include "fast_hadamard_transform_special.h"

namespace executorch {
namespace internal {

// Square root of 1 << log2_n.
template <typename T>
T fast_sqrt_of_power_of_2(int log2_n) {
  // The square root of 2**N is, by definition, 2**(N/2), which is
  // trivial to compute for even N using a left shift.
  //
  // For odd N, 2**(N/2) = 2**(floor(N/2) + 1/2)
  //                     = 2**(floor(N/2)) * (2 ** (1/2))
  //                     = 2**(floor(N/2)) * sqrt(2)
  // which is again fast to compute.
  return T(1 << (log2_n / 2)) * ((log2_n % 2) ? T(std::sqrt(2)) : T(1));
}
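
// Worked example for fast_sqrt_of_power_of_2 (illustrative, added for
// clarity): fast_sqrt_of_power_of_2<float>(5) computes sqrt(2**5) = sqrt(32)
// as float(1 << 2) * sqrt(2) = 4 * sqrt(2) ~= 5.657.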

template <typename T>
void normalize_after_fht(T* out, int log2_vec_size) {
  const T inv_sqrt = T(1) / fast_sqrt_of_power_of_2<T>(log2_vec_size);
  const int vec_size = 1 << log2_vec_size;
  for (int ii = 0; ii < vec_size; ++ii) {
    out[ii] *= inv_sqrt;
  }
}

template <typename T>
void fast_hadamard_transform_unnormalized_simple_impl(
    T* vec,
    int log2_vec_size) {
  // NOTE: If you're here because you're profiling a model and this is
  // slow, consider updating FFHT to generate efficient assembly for
  // your data type!
  if (log2_vec_size == 0) {
    return;
  }

  int step = 1;
  const auto vec_size = 1 << log2_vec_size;
  while (step < vec_size) {
    for (int ii = 0; ii < vec_size; ii += step * 2) {
      for (int jj = ii; jj < ii + step; ++jj) {
        auto x = vec[jj];
        auto y = vec[jj + step];
        vec[jj] = x + y;
        vec[jj + step] = x - y;
      }
    }
    step *= 2;
  }
}
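
// Worked example for the butterfly passes above (illustrative, added for
// clarity): with vec = {1, 2, 3, 4} and log2_vec_size = 2, the first pass
// (step = 1) yields {3, -1, 7, -1} and the second pass (step = 2) yields
// {10, -2, -4, 0}, which equals H_4 * (1, 2, 3, 4)^T in Sylvester (natural)
// ordering, before the 1/sqrt(4) normalization applied by callers.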

template <typename T>
void fast_hadamard_transform_simple_impl(T* vec, int log2_vec_size) {
  fast_hadamard_transform_unnormalized_simple_impl(vec, log2_vec_size);
  normalize_after_fht(vec, log2_vec_size);
}

inline void fast_hadamard_transform_ffht_impl(float* vec, int log2_vec_size) {
#if defined(__aarch64__) || defined(__x86_64__)
  if (log2_vec_size <= 0) {
    return;
  }

  fht_float(vec, log2_vec_size);
  normalize_after_fht(vec, log2_vec_size);
#else
  fast_hadamard_transform_simple_impl(vec, log2_vec_size);
#endif
}

} // namespace internal

// Compute the fast Walsh-Hadamard transform
// (https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform)
// of vec, which must be of length (1 << log2_vec_size).
template <typename T>
void fast_hadamard_transform(T* vec, int log2_vec_size) {
  if constexpr (std::is_same_v<T, float>) {
    internal::fast_hadamard_transform_ffht_impl(vec, log2_vec_size);
  } else {
    internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size);
  }
}
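
// Example usage (illustrative sketch, not part of the original header):
//
//   float data[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
//   executorch::fast_hadamard_transform(data, /*log2_vec_size=*/3);
//   // Every element of data is now 1/sqrt(8) ~= 0.3536: the unnormalized
//   // transform of a unit impulse is all ones, scaled by
//   // 1/sqrt(2**log2_vec_size).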

// Compute a quantized fast Walsh-Hadamard transform of vec, which
// must be of length (1 << log2_vec_size) and symmetrically quantized.
//
// Note that we do not need to know the quantization scale, because
// the Fast Hadamard transform is a series of additions and
// subtractions with a final multiplication step, and we have the
// following trivial identities:
//
// scale * a + scale * b = scale * (a + b)  (addition doesn't need the scale)
// alpha * (scale * a) = scale * (alpha * a) (multiplication doesn't need the
// scale)
void fast_hadamard_transform_symmetric_quantized_s16(
    int16_t* vec,
    int log2_vec_size);
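
// Worked example of the scale identities above (illustrative, added for
// clarity): with scale = 0.5, real values a = 3.0 and b = 1.0 are stored as
// qa = 6 and qb = 2. Operating directly on the quantized values gives
// qa + qb = 8 and qa - qb = 4, which dequantize to 4.0 and 2.0, i.e. exactly
// (a + b) and (a - b), so the transform can run entirely in the quantized
// domain.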

// Like fast_hadamard_transform, but vec must be of length 28 * (1 <<
// log2_vec_size) and the transform is computed by interpreting vec as
// a (28, 1 << log2_vec_size) matrix and performing 28 FHTs, followed
// by (1 << log2_vec_size) multiplications by a particular Hadamard
// matrix of size 28x28 (see special_hadamard_code_gen.py for the
// exact matrix).
template <typename T>
void fast_hadamard_transform_28N(T* vec, int log2_vec_size) {
  const int vec_size = (1 << log2_vec_size);
  for (int ii = 0; ii < 28; ++ii) {
    fast_hadamard_transform(&vec[ii * vec_size], log2_vec_size);
  }
  for (int ii = 0; ii < vec_size; ++ii) {
    hadamard_mult_28_strided(&vec[ii], vec_size);
  }
}
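
// Example usage (illustrative sketch, not part of the original header):
//
//   // 28 rows of 64 contiguous floats, i.e. a row-major (28, 64) matrix.
//   std::vector<float> data(28 * 64);
//   // ... fill data ...
//   executorch::fast_hadamard_transform_28N(data.data(), /*log2_vec_size=*/6);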

// We don't need the quantization scale; see the function-level
// comment on fast_hadamard_transform_symmetric_quantized_s16 for
// details.
void fast_hadamard_transform_symmetric_quantized_s16_28N(
    int16_t* vec,
    int log2_vec_size);

} // namespace executorch