xref: /aosp_15_r20/external/pytorch/c10/util/hash.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/Exception.h>
4 #include <cstddef>
5 #include <functional>
6 #include <iomanip>
7 #include <ios>
8 #include <sstream>
9 #include <string>
10 #include <tuple>
11 #include <type_traits>
12 #include <utility>
13 #include <vector>
14 
15 #include <c10/util/ArrayRef.h>
16 #include <c10/util/complex.h>
17 
18 namespace c10 {
19 
20 // NOTE: hash_combine and SHA1 hashing is based on implementation from Boost
21 //
22 // Boost Software License - Version 1.0 - August 17th, 2003
23 //
24 // Permission is hereby granted, free of charge, to any person or organization
25 // obtaining a copy of the software and accompanying documentation covered by
26 // this license (the "Software") to use, reproduce, display, distribute,
27 // execute, and transmit the Software, and to prepare derivative works of the
28 // Software, and to permit third-parties to whom the Software is furnished to
29 // do so, all subject to the following:
30 //
31 // The copyright notices in the Software and this entire statement, including
32 // the above license grant, this restriction and the following disclaimer,
33 // must be included in all copies of the Software, in whole or in part, and
34 // all derivative works of the Software, unless such copies or derivative
35 // works are solely in the form of machine-executable object code generated by
36 // a source language processor.
37 //
38 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
39 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
40 // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
41 // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
42 // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
43 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
44 // DEALINGS IN THE SOFTWARE.
45 
// Folds `value` into the running hash `seed` and returns the new seed.
// The additive constant and the shift amounts come from the Boost-derived
// hash_combine formula referenced in the notice above; all arithmetic wraps
// modulo the width of size_t.
inline size_t hash_combine(size_t seed, size_t value) {
  const size_t seed_spread = (seed << 6u) + (seed >> 2u);
  const size_t mixed = value + 0x9e3779b9 + seed_spread;
  return seed ^ mixed;
}
49 
50 // Creates the SHA1 hash of a string. A 160-bit hash.
51 // Based on the implementation in Boost (see notice above).
52 // Note that SHA1 hashes are no longer considered cryptographically
53 //   secure, but are the standard hash for generating unique ids.
54 // Usage:
55 //   // Let 'code' be a std::string
56 //   c10::sha1 sha1_hash{code};
57 //   const auto hash_code = sha1_hash.str();
58 // TODO: Compare vs OpenSSL and/or CryptoPP implementations
59 struct sha1 {
60   typedef unsigned int(digest_type)[5];
61 
62   sha1(const std::string& s = "") {
63     if (!s.empty()) {
64       reset();
65       process_bytes(s.c_str(), s.size());
66     }
67   }
68 
resetsha169   void reset() {
70     h_[0] = 0x67452301;
71     h_[1] = 0xEFCDAB89;
72     h_[2] = 0x98BADCFE;
73     h_[3] = 0x10325476;
74     h_[4] = 0xC3D2E1F0;
75 
76     block_byte_index_ = 0;
77     bit_count_low = 0;
78     bit_count_high = 0;
79   }
80 
strsha181   std::string str() {
82     unsigned int digest[5];
83     get_digest(digest);
84 
85     std::ostringstream buf;
86     for (unsigned int i : digest) {
87       buf << std::hex << std::setfill('0') << std::setw(8) << i;
88     }
89 
90     return buf.str();
91   }
92 
93  private:
left_rotatesha194   unsigned int left_rotate(unsigned int x, std::size_t n) {
95     return (x << n) ^ (x >> (32 - n));
96   }
97 
process_block_implsha198   void process_block_impl() {
99     unsigned int w[80];
100 
101     for (std::size_t i = 0; i < 16; ++i) {
102       w[i] = (block_[i * 4 + 0] << 24);
103       w[i] |= (block_[i * 4 + 1] << 16);
104       w[i] |= (block_[i * 4 + 2] << 8);
105       w[i] |= (block_[i * 4 + 3]);
106     }
107 
108     for (std::size_t i = 16; i < 80; ++i) {
109       w[i] = left_rotate((w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]), 1);
110     }
111 
112     unsigned int a = h_[0];
113     unsigned int b = h_[1];
114     unsigned int c = h_[2];
115     unsigned int d = h_[3];
116     unsigned int e = h_[4];
117 
118     for (std::size_t i = 0; i < 80; ++i) {
119       unsigned int f = 0;
120       unsigned int k = 0;
121 
122       if (i < 20) {
123         f = (b & c) | (~b & d);
124         k = 0x5A827999;
125       } else if (i < 40) {
126         f = b ^ c ^ d;
127         k = 0x6ED9EBA1;
128       } else if (i < 60) {
129         f = (b & c) | (b & d) | (c & d);
130         k = 0x8F1BBCDC;
131       } else {
132         f = b ^ c ^ d;
133         k = 0xCA62C1D6;
134       }
135 
136       unsigned temp = left_rotate(a, 5) + f + e + k + w[i];
137       e = d;
138       d = c;
139       c = left_rotate(b, 30);
140       b = a;
141       a = temp;
142     }
143 
144     h_[0] += a;
145     h_[1] += b;
146     h_[2] += c;
147     h_[3] += d;
148     h_[4] += e;
149   }
150 
process_byte_implsha1151   void process_byte_impl(unsigned char byte) {
152     block_[block_byte_index_++] = byte;
153 
154     if (block_byte_index_ == 64) {
155       block_byte_index_ = 0;
156       process_block_impl();
157     }
158   }
159 
process_bytesha1160   void process_byte(unsigned char byte) {
161     process_byte_impl(byte);
162 
163     // size_t max value = 0xFFFFFFFF
164     // if (bit_count_low + 8 >= 0x100000000) { // would overflow
165     // if (bit_count_low >= 0x100000000-8) {
166     if (bit_count_low < 0xFFFFFFF8) {
167       bit_count_low += 8;
168     } else {
169       bit_count_low = 0;
170 
171       if (bit_count_high <= 0xFFFFFFFE) {
172         ++bit_count_high;
173       } else {
174         TORCH_CHECK(false, "sha1 too many bytes");
175       }
176     }
177   }
178 
process_blocksha1179   void process_block(void const* bytes_begin, void const* bytes_end) {
180     unsigned char const* begin = static_cast<unsigned char const*>(bytes_begin);
181     unsigned char const* end = static_cast<unsigned char const*>(bytes_end);
182     for (; begin != end; ++begin) {
183       process_byte(*begin);
184     }
185   }
186 
process_bytessha1187   void process_bytes(void const* buffer, std::size_t byte_count) {
188     unsigned char const* b = static_cast<unsigned char const*>(buffer);
189     process_block(b, b + byte_count);
190   }
191 
get_digestsha1192   void get_digest(digest_type& digest) {
193     // append the bit '1' to the message
194     process_byte_impl(0x80);
195 
196     // append k bits '0', where k is the minimum number >= 0
197     // such that the resulting message length is congruent to 56 (mod 64)
198     // check if there is enough space for padding and bit_count
199     if (block_byte_index_ > 56) {
200       // finish this block
201       while (block_byte_index_ != 0) {
202         process_byte_impl(0);
203       }
204 
205       // one more block
206       while (block_byte_index_ < 56) {
207         process_byte_impl(0);
208       }
209     } else {
210       while (block_byte_index_ < 56) {
211         process_byte_impl(0);
212       }
213     }
214 
215     // append length of message (before pre-processing)
216     // as a 64-bit big-endian integer
217     process_byte_impl(
218         static_cast<unsigned char>((bit_count_high >> 24) & 0xFF));
219     process_byte_impl(
220         static_cast<unsigned char>((bit_count_high >> 16) & 0xFF));
221     process_byte_impl(static_cast<unsigned char>((bit_count_high >> 8) & 0xFF));
222     process_byte_impl(static_cast<unsigned char>((bit_count_high) & 0xFF));
223     process_byte_impl(static_cast<unsigned char>((bit_count_low >> 24) & 0xFF));
224     process_byte_impl(static_cast<unsigned char>((bit_count_low >> 16) & 0xFF));
225     process_byte_impl(static_cast<unsigned char>((bit_count_low >> 8) & 0xFF));
226     process_byte_impl(static_cast<unsigned char>((bit_count_low) & 0xFF));
227 
228     // get final digest
229     digest[0] = h_[0];
230     digest[1] = h_[1];
231     digest[2] = h_[2];
232     digest[3] = h_[3];
233     digest[4] = h_[4];
234   }
235 
236   unsigned int h_[5]{};
237   unsigned char block_[64]{};
238   std::size_t block_byte_index_{};
239   std::size_t bit_count_low{};
240   std::size_t bit_count_high{};
241 };
242 
// Scrambles a 64-bit key through a fixed sequence of shift-add and xor-shift
// steps. Every step works on unsigned 64-bit values, so the arithmetic wraps
// modulo 2^64. Usable in constant expressions and never throws.
constexpr uint64_t twang_mix64(uint64_t key) noexcept {
  key = ~key + (key << 21); // same as key * ((1 << 21) - 1) - 1
  key ^= key >> 24;
  key += (key << 3) + (key << 8); // same as key * (1 + (1 << 3) + (1 << 8))
  key ^= key >> 14;
  key += (key << 2) + (key << 4); // same as key * (1 + (1 << 2) + (1 << 4))
  key ^= key >> 28;
  key += key << 31; // same as key * (1 + (1 << 31))
  return key;
}
253 
254 ////////////////////////////////////////////////////////////////////////////////
255 // c10::hash implementation
256 ////////////////////////////////////////////////////////////////////////////////
257 
namespace _hash_detail {

// Declared here and defined further down in the file; lets callers hash a
// value through c10::hash while having the type deduced from the argument.
template <typename T>
size_t simple_get_hash(const T& o);

// Yields V when T is not an enum and is a substitution failure otherwise.
template <typename T, typename V>
using type_if_not_enum = std::enable_if_t<!std::is_enum_v<T>, V>;

// dispatch_hash picks one of three strategies through SFINAE:
//   1. non-enum types that std::hash accepts are hashed with std::hash<T>;
//   2. enum types are converted to their underlying type and hashed with
//      std::hash of that type;
//   3. types exposing a static T::hash(const T&) use that function.
// Overload 1 is explicitly disabled for enums because standard libraries may
// provide std::hash for enum types, which would otherwise collide with
// overload 2.
template <typename T>
auto dispatch_hash(const T& o)
    -> decltype(std::hash<T>{}(o), type_if_not_enum<T, size_t>{}) {
  std::hash<T> std_hasher{};
  return std_hasher(o);
}

template <typename T>
std::enable_if_t<std::is_enum_v<T>, size_t> dispatch_hash(const T& o) {
  using Underlying = std::underlying_type_t<T>;
  const Underlying raw = static_cast<Underlying>(o);
  return std::hash<Underlying>{}(raw);
}

template <typename T>
auto dispatch_hash(const T& o) -> decltype(T::hash(o), size_t{}) {
  return T::hash(o);
}

} // namespace _hash_detail
290 
291 // Hasher struct
292 template <typename T>
293 struct hash {
operatorhash294   size_t operator()(const T& o) const {
295     return _hash_detail::dispatch_hash(o);
296   };
297 };
298 
299 // Specialization for std::tuple
300 template <typename... Types>
301 struct hash<std::tuple<Types...>> {
302   template <size_t idx, typename... Ts>
303   struct tuple_hash {
304     size_t operator()(const std::tuple<Ts...>& t) const {
305       return hash_combine(
306           _hash_detail::simple_get_hash(std::get<idx>(t)),
307           tuple_hash<idx - 1, Ts...>()(t));
308     }
309   };
310 
311   template <typename... Ts>
312   struct tuple_hash<0, Ts...> {
313     size_t operator()(const std::tuple<Ts...>& t) const {
314       return _hash_detail::simple_get_hash(std::get<0>(t));
315     }
316   };
317 
318   size_t operator()(const std::tuple<Types...>& t) const {
319     return tuple_hash<sizeof...(Types) - 1, Types...>()(t);
320   }
321 };
322 
323 template <typename T1, typename T2>
324 struct hash<std::pair<T1, T2>> {
325   size_t operator()(const std::pair<T1, T2>& pair) const {
326     std::tuple<T1, T2> tuple = std::make_tuple(pair.first, pair.second);
327     return _hash_detail::simple_get_hash(tuple);
328   }
329 };
330 
331 template <typename T>
332 struct hash<c10::ArrayRef<T>> {
333   size_t operator()(c10::ArrayRef<T> v) const {
334     size_t seed = 0;
335     for (const auto& elem : v) {
336       seed = hash_combine(seed, _hash_detail::simple_get_hash(elem));
337     }
338     return seed;
339   }
340 };
341 
342 // Specialization for std::vector
343 template <typename T>
344 struct hash<std::vector<T>> {
345   size_t operator()(const std::vector<T>& v) const {
346     return hash<c10::ArrayRef<T>>()(v);
347   }
348 };
349 
350 namespace _hash_detail {
351 
352 template <typename T>
353 size_t simple_get_hash(const T& o) {
354   return c10::hash<T>()(o);
355 }
356 
357 } // namespace _hash_detail
358 
359 // Use this function to actually hash multiple things in one line.
360 // Dispatches to c10::hash, so it can hash containers.
361 // Example:
362 //
363 // static size_t hash(const MyStruct& s) {
364 //   return get_hash(s.member1, s.member2, s.member3);
365 // }
366 template <typename... Types>
367 size_t get_hash(const Types&... args) {
368   return c10::hash<decltype(std::tie(args...))>()(std::tie(args...));
369 }
370 
371 // Specialization for c10::complex
372 template <typename T>
373 struct hash<c10::complex<T>> {
374   size_t operator()(const c10::complex<T>& c) const {
375     return get_hash(c.real(), c.imag());
376   }
377 };
378 
379 } // namespace c10
380