# Portions (c) Meta Platforms, Inc. and affiliates.
# This file is adapted from
# https://github.com/Dao-AILab/fast-hadamard-transform/blob/master/csrc/code_gen.py .

# BSD 3-Clause License

# Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.

# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.

# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from pathlib import Path

import numpy as np

# Each matrix below is written as one text row per matrix row, where "+" stands
# for +1 and "-" stands for -1.

# From https://en.wikipedia.org/wiki/Paley_construction (construction II for q = 5)

had_12_paley = """
+-++++++++++
--+-+-+-+-+-
+++-++----++
+---+--+-++-
+++++-++----
+-+---+--+-+
++--+++-++--
+--++---+--+
++----+++-++
+--+-++---+-
++++----+++-
+-+--+-++---
"""

# From http://neilsloane.com/hadamard/

had_12 = """
+-----------
++-+---+++-+
+++-+---+++-
+-++-+---+++
++-++-+---++
+++-++-+---+
++++-++-+---
+-+++-++-+--
+--+++-++-+-
+---+++-++-+
++---+++-++-
+-+---+++-++
"""

had_20_will = """
+----+----++--++-++-
-+----+---+++---+-++
--+----+---+++-+-+-+
---+----+---+++++-+-
----+----++--++-++-+
-+++++-----+--+++--+
+-+++-+---+-+--+++--
++-++--+---+-+--+++-
+++-+---+---+-+--+++
++++-----++--+-+--++
--++-+-++-+-----++++
---++-+-++-+---+-+++
+---++-+-+--+--++-++
++---++-+----+-+++-+
-++---++-+----+++++-
-+--+--++-+----+----
+-+-----++-+----+---
-+-+-+---+--+----+--
--+-+++------+----+-
+--+--++------+----+
"""


had_28_will = """
+------++----++-+--+-+--++--
-+-----+++-----+-+--+-+--++-
--+-----+++---+-+-+----+--++
---+-----+++---+-+-+-+--+--+
----+-----+++---+-+-+++--+--
-----+-----++++--+-+--++--+-
------++----++-+--+-+--++--+
--++++-+-------++--+++-+--+-
---++++-+-----+-++--+-+-+--+
+---+++--+----++-++--+-+-+--
++---++---+----++-++--+-+-+-
+++---+----+----++-++--+-+-+
++++--------+-+--++-++--+-+-
-++++--------+++--++--+--+-+
-+-++-++--++--+--------++++-
+-+-++--+--++--+--------++++
-+-+-++--+--++--+----+---+++
+-+-+-++--+--+---+---++---++
++-+-+-++--+------+--+++---+
-++-+-+-++--+------+-++++---
+-++-+---++--+------+-++++--
-++--++-+-++-+++----++------
+-++--++-+-++-+++-----+-----
++-++---+-+-++-+++-----+----
-++-++-+-+-+-+--+++-----+---
--++-++++-+-+----+++-----+--
+--++-+-++-+-+----+++-----+-
++--++-+-++-+-+----++------+
"""


had_40_tpal = """
+-------------------+-------------------
++-++----+-+-++++--+++-++----+-+-++++--+
+++-++----+-+-++++--+++-++----+-+-++++--
+-++-++----+-+-++++-+-++-++----+-+-++++-
+--++-++----+-+-+++++--++-++----+-+-++++
++--++-++----+-+-+++++--++-++----+-+-+++
+++--++-++----+-+-+++++--++-++----+-+-++
++++--++-++----+-+-+++++--++-++----+-+-+
+++++--++-++----+-+-+++++--++-++----+-+-
+-++++--++-++----+-++-++++--++-++----+-+
++-++++--++-++----+-++-++++--++-++----+-
+-+-++++--++-++----++-+-++++--++-++----+
++-+-++++--++-++----++-+-++++--++-++----
+-+-+-++++--++-++---+-+-+-++++--++-++---
+--+-+-++++--++-++--+--+-+-++++--++-++--
+---+-+-++++--++-++-+---+-+-++++--++-++-
+----+-+-++++--++-+++----+-+-++++--++-++
++----+-+-++++--++-+++----+-+-++++--++-+
+++----+-+-++++--++-+++----+-+-++++--++-
+-++----+-+-++++--+++-++----+-+-++++--++
+--------------------+++++++++++++++++++
++-++----+-+-++++--+--+--++++-+-+----++-
+++-++----+-+-++++-----+--++++-+-+----++
+-++-++----+-+-++++--+--+--++++-+-+----+
+--++-++----+-+-++++-++--+--++++-+-+----
++--++-++----+-+-+++--++--+--++++-+-+---
+++--++-++----+-+-++---++--+--++++-+-+--
++++--++-++----+-+-+----++--+--++++-+-+-
+++++--++-++----+-+------++--+--++++-+-+
+-++++--++-++----+-+-+----++--+--++++-+-
++-++++--++-++----+---+----++--+--++++-+
+-+-++++--++-++----+-+-+----++--+--++++-
++-+-++++--++-++------+-+----++--+--++++
+-+-+-++++--++-++----+-+-+----++--+--+++
+--+-+-++++--++-++---++-+-+----++--+--++
+---+-+-++++--++-++--+++-+-+----++--+--+
+----+-+-++++--++-++-++++-+-+----++--+--
++----+-+-++++--++-+--++++-+-+----++--+-
+++----+-+-++++--++----++++-+-+----++--+
+-++----+-+-++++--++-+--++++-+-+----++--
"""

# NOTE: the original Dao-AILab/fast-hadamard-transform uses had_12_paley rather than
# had_12 here. However, SpinQuant and QuaRot seem to use had_12, so we follow them here.
had_strings = [had_12, had_20_will, had_28_will, had_40_tpal]

# Text written right after the "generated by" line of the output header file.
header = """

#pragma once

"""


# Each template below is filled in by array_code_gen() via str.format(), so
# literal C/C++ braces are doubled. {N} is the matrix size and {code} is the
# list of "out[i] = ...;" statements.
TEMPLATE = """
__device__ __forceinline__ void hadamard_mult_thread_{N}(float x[{N}]) {{
    float out[{N}];
    {code}
    #pragma unroll
    for (int i = 0; i < {N}; i++) {{ x[i] = out[i]; }}
}}

"""


CPU_TEMPLATE = """
template <typename T>
void hadamard_mult_{N}(T* x) {{
    float out[{N}];
    {code}
    #pragma unroll
    for (int i = 0; i < {N}; i++) {{ x[i] = out[i]; }}
}}

"""

# Same as CPU_TEMPLATE, but first gathers the inputs through {strided_load_code}
# and scatters the results back with the same stride.
STRIDED_CPU_TEMPLATE = """
template <typename T>
void hadamard_mult_{N}_strided(T* input, int stride) {{
    T x[{N}];
    T out[{N}];
    {strided_load_code}
    {code}
    #pragma unroll
    for (int ii = 0; ii < {N}; ++ii) {{ input[stride * ii] = out[ii]; }}
}}

"""

# Numeric value of each character allowed in the matrix strings.
_SIGN_VALUES = {"+": 1, "-": -1}


def string_to_array(string):
    """Convert a block of ``+``/``-`` rows into an int32 matrix of +1/-1.

    Rows are separated by whitespace (normally newlines); leading and
    trailing whitespace is ignored.

    The characters are mapped to integers directly instead of rewriting the
    text and re-parsing it as numbers, so the result does not depend on how
    a numeric text parser treats a sign that is separated from its digit.

    Args:
        string: Text made of rows of ``+`` and ``-`` characters.

    Returns:
        A 2-D ``np.int32`` array with one row per text row, holding 1 for
        ``+`` and -1 for ``-``.

    Raises:
        ValueError: If there are no rows, the rows differ in length, or a
            character other than ``+``/``-`` is present.
    """
    rows = string.split()
    if not rows:
        raise ValueError("expected at least one row of '+'/'-' characters")
    width = len(rows[0])
    if any(len(row) != width for row in rows):
        raise ValueError("all rows must contain the same number of characters")
    try:
        values = [[_SIGN_VALUES[ch] for ch in row] for row in rows]
    except KeyError as err:
        raise ValueError(
            f"unexpected character {err.args[0]!r}; only '+' and '-' are allowed"
        ) from None
    return np.array(values, dtype=np.int32)


def strided_load_code_gen(N):
    """Return C statements copying ``input[i * stride]`` into ``x[i]`` for i < N.

    The statements are joined with a newline plus indentation so they line up
    when substituted into ``STRIDED_CPU_TEMPLATE``.
    """
    return "\n    ".join(f"x[{i}] = input[{i} * stride];" for i in range(N))


def array_code_gen(arr, template):
    """Render ``template`` with the unrolled matrix-vector product for ``arr``.

    Args:
        arr: Square 2-D array. An entry equal to 1 contributes ``+ x[j]``;
            any other entry contributes ``- x[j]``.
        template: ``str.format`` template that may reference ``{N}``,
            ``{code}`` and ``{strided_load_code}``.

    Returns:
        The formatted template text.
    """
    N = arr.shape[0]
    assert arr.shape[0] == arr.shape[1], "Hadamard matrix must be square"
    statements = []
    for i in range(N):
        terms = " ".join(
            f"{'+' if arr[i, j] == 1 else '-'} x[{j}]" for j in range(N)
        )
        statements.append(f"out[{i}] = {terms};")
    return template.format(
        N=str(N),
        code="\n    ".join(statements),
        strided_load_code=strided_load_code_gen(N),
    )


# Maps the command-line target name to the template used for every matrix.
OPTION_TO_TEMPLATE = {
    "cuda": TEMPLATE,
    "cpu": CPU_TEMPLATE,
    "strided_cpu": STRIDED_CPU_TEMPLATE,
}


def main(option="cuda"):
    """Write ``fast_hadamard_transform_special.h`` next to this script.

    The file contains a "generated by" line, ``header``, and one function per
    matrix in ``had_strings`` rendered with the template chosen by ``option``.

    Args:
        option: One of the keys of ``OPTION_TO_TEMPLATE``.

    Raises:
        ValueError: If ``option`` is not a known target.
    """
    try:
        template = OPTION_TO_TEMPLATE[option]
    except KeyError:
        # "from None": the KeyError adds nothing beyond the message below.
        raise ValueError(
            f"bad target option {option}; options are {', '.join(OPTION_TO_TEMPLATE)}"
        ) from None
    output_path = Path(__file__).parent / "fast_hadamard_transform_special.h"
    # The marker word is spliced in through the f-string so it does not appear
    # verbatim in this script -- presumably so that tools scanning for the
    # marker do not treat this source file itself as generated (confirm).
    generated_line = f"// @{'generated'} by special_hadamard_code_gen.py {option}\n"
    functions = "".join(
        array_code_gen(string_to_array(s), template) for s in had_strings
    )
    output_path.write_text(generated_line + header + functions)


if __name__ == "__main__":
    import sys

    # The optional first command-line argument selects the target.
    option = "cuda"
    if len(sys.argv) > 1:
        option = sys.argv[1]
    main(option)