1 #pragma once
2
3 #include <c10/util/Exception.h>
4 #include <cstddef>
5 #include <functional>
6 #include <iomanip>
7 #include <ios>
8 #include <sstream>
9 #include <string>
10 #include <tuple>
11 #include <type_traits>
12 #include <utility>
13 #include <vector>
14
15 #include <c10/util/ArrayRef.h>
16 #include <c10/util/complex.h>
17
18 namespace c10 {
19
// NOTE: hash_combine and the SHA1 hashing are based on implementations from Boost
21 //
22 // Boost Software License - Version 1.0 - August 17th, 2003
23 //
24 // Permission is hereby granted, free of charge, to any person or organization
25 // obtaining a copy of the software and accompanying documentation covered by
26 // this license (the "Software") to use, reproduce, display, distribute,
27 // execute, and transmit the Software, and to prepare derivative works of the
28 // Software, and to permit third-parties to whom the Software is furnished to
29 // do so, all subject to the following:
30 //
31 // The copyright notices in the Software and this entire statement, including
32 // the above license grant, this restriction and the following disclaimer,
33 // must be included in all copies of the Software, in whole or in part, and
34 // all derivative works of the Software, unless such copies or derivative
35 // works are solely in the form of machine-executable object code generated by
36 // a source language processor.
37 //
38 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
39 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
40 // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
41 // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
42 // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
43 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
44 // DEALINGS IN THE SOFTWARE.
45
hash_combine(size_t seed,size_t value)46 inline size_t hash_combine(size_t seed, size_t value) {
47 return seed ^ (value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u));
48 }
49
50 // Creates the SHA1 hash of a string. A 160-bit hash.
51 // Based on the implementation in Boost (see notice above).
52 // Note that SHA1 hashes are no longer considered cryptographically
53 // secure, but are the standard hash for generating unique ids.
54 // Usage:
55 // // Let 'code' be a std::string
56 // c10::sha1 sha1_hash{code};
57 // const auto hash_code = sha1_hash.str();
58 // TODO: Compare vs OpenSSL and/or CryptoPP implementations
59 struct sha1 {
60 typedef unsigned int(digest_type)[5];
61
62 sha1(const std::string& s = "") {
63 if (!s.empty()) {
64 reset();
65 process_bytes(s.c_str(), s.size());
66 }
67 }
68
resetsha169 void reset() {
70 h_[0] = 0x67452301;
71 h_[1] = 0xEFCDAB89;
72 h_[2] = 0x98BADCFE;
73 h_[3] = 0x10325476;
74 h_[4] = 0xC3D2E1F0;
75
76 block_byte_index_ = 0;
77 bit_count_low = 0;
78 bit_count_high = 0;
79 }
80
strsha181 std::string str() {
82 unsigned int digest[5];
83 get_digest(digest);
84
85 std::ostringstream buf;
86 for (unsigned int i : digest) {
87 buf << std::hex << std::setfill('0') << std::setw(8) << i;
88 }
89
90 return buf.str();
91 }
92
93 private:
left_rotatesha194 unsigned int left_rotate(unsigned int x, std::size_t n) {
95 return (x << n) ^ (x >> (32 - n));
96 }
97
process_block_implsha198 void process_block_impl() {
99 unsigned int w[80];
100
101 for (std::size_t i = 0; i < 16; ++i) {
102 w[i] = (block_[i * 4 + 0] << 24);
103 w[i] |= (block_[i * 4 + 1] << 16);
104 w[i] |= (block_[i * 4 + 2] << 8);
105 w[i] |= (block_[i * 4 + 3]);
106 }
107
108 for (std::size_t i = 16; i < 80; ++i) {
109 w[i] = left_rotate((w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]), 1);
110 }
111
112 unsigned int a = h_[0];
113 unsigned int b = h_[1];
114 unsigned int c = h_[2];
115 unsigned int d = h_[3];
116 unsigned int e = h_[4];
117
118 for (std::size_t i = 0; i < 80; ++i) {
119 unsigned int f = 0;
120 unsigned int k = 0;
121
122 if (i < 20) {
123 f = (b & c) | (~b & d);
124 k = 0x5A827999;
125 } else if (i < 40) {
126 f = b ^ c ^ d;
127 k = 0x6ED9EBA1;
128 } else if (i < 60) {
129 f = (b & c) | (b & d) | (c & d);
130 k = 0x8F1BBCDC;
131 } else {
132 f = b ^ c ^ d;
133 k = 0xCA62C1D6;
134 }
135
136 unsigned temp = left_rotate(a, 5) + f + e + k + w[i];
137 e = d;
138 d = c;
139 c = left_rotate(b, 30);
140 b = a;
141 a = temp;
142 }
143
144 h_[0] += a;
145 h_[1] += b;
146 h_[2] += c;
147 h_[3] += d;
148 h_[4] += e;
149 }
150
process_byte_implsha1151 void process_byte_impl(unsigned char byte) {
152 block_[block_byte_index_++] = byte;
153
154 if (block_byte_index_ == 64) {
155 block_byte_index_ = 0;
156 process_block_impl();
157 }
158 }
159
process_bytesha1160 void process_byte(unsigned char byte) {
161 process_byte_impl(byte);
162
163 // size_t max value = 0xFFFFFFFF
164 // if (bit_count_low + 8 >= 0x100000000) { // would overflow
165 // if (bit_count_low >= 0x100000000-8) {
166 if (bit_count_low < 0xFFFFFFF8) {
167 bit_count_low += 8;
168 } else {
169 bit_count_low = 0;
170
171 if (bit_count_high <= 0xFFFFFFFE) {
172 ++bit_count_high;
173 } else {
174 TORCH_CHECK(false, "sha1 too many bytes");
175 }
176 }
177 }
178
process_blocksha1179 void process_block(void const* bytes_begin, void const* bytes_end) {
180 unsigned char const* begin = static_cast<unsigned char const*>(bytes_begin);
181 unsigned char const* end = static_cast<unsigned char const*>(bytes_end);
182 for (; begin != end; ++begin) {
183 process_byte(*begin);
184 }
185 }
186
process_bytessha1187 void process_bytes(void const* buffer, std::size_t byte_count) {
188 unsigned char const* b = static_cast<unsigned char const*>(buffer);
189 process_block(b, b + byte_count);
190 }
191
get_digestsha1192 void get_digest(digest_type& digest) {
193 // append the bit '1' to the message
194 process_byte_impl(0x80);
195
196 // append k bits '0', where k is the minimum number >= 0
197 // such that the resulting message length is congruent to 56 (mod 64)
198 // check if there is enough space for padding and bit_count
199 if (block_byte_index_ > 56) {
200 // finish this block
201 while (block_byte_index_ != 0) {
202 process_byte_impl(0);
203 }
204
205 // one more block
206 while (block_byte_index_ < 56) {
207 process_byte_impl(0);
208 }
209 } else {
210 while (block_byte_index_ < 56) {
211 process_byte_impl(0);
212 }
213 }
214
215 // append length of message (before pre-processing)
216 // as a 64-bit big-endian integer
217 process_byte_impl(
218 static_cast<unsigned char>((bit_count_high >> 24) & 0xFF));
219 process_byte_impl(
220 static_cast<unsigned char>((bit_count_high >> 16) & 0xFF));
221 process_byte_impl(static_cast<unsigned char>((bit_count_high >> 8) & 0xFF));
222 process_byte_impl(static_cast<unsigned char>((bit_count_high) & 0xFF));
223 process_byte_impl(static_cast<unsigned char>((bit_count_low >> 24) & 0xFF));
224 process_byte_impl(static_cast<unsigned char>((bit_count_low >> 16) & 0xFF));
225 process_byte_impl(static_cast<unsigned char>((bit_count_low >> 8) & 0xFF));
226 process_byte_impl(static_cast<unsigned char>((bit_count_low) & 0xFF));
227
228 // get final digest
229 digest[0] = h_[0];
230 digest[1] = h_[1];
231 digest[2] = h_[2];
232 digest[3] = h_[3];
233 digest[4] = h_[4];
234 }
235
236 unsigned int h_[5]{};
237 unsigned char block_[64]{};
238 std::size_t block_byte_index_{};
239 std::size_t bit_count_low{};
240 std::size_t bit_count_high{};
241 };
242
twang_mix64(uint64_t key)243 constexpr uint64_t twang_mix64(uint64_t key) noexcept {
244 key = (~key) + (key << 21); // key *= (1 << 21) - 1; key -= 1;
245 key = key ^ (key >> 24);
246 key = key + (key << 3) + (key << 8); // key *= 1 + (1 << 3) + (1 << 8)
247 key = key ^ (key >> 14);
248 key = key + (key << 2) + (key << 4); // key *= 1 + (1 << 2) + (1 << 4)
249 key = key ^ (key >> 28);
250 key = key + (key << 31); // key *= 1 + (1 << 31)
251 return key;
252 }
253
254 ////////////////////////////////////////////////////////////////////////////////
255 // c10::hash implementation
256 ////////////////////////////////////////////////////////////////////////////////
257
namespace _hash_detail {

// Hashes `o` with c10::hash<T>, deducing T from the argument so callers
// (e.g. the container specializations) need not spell out element types.
// Only declared here; it is defined after the c10::hash specializations
// further down in this file.
template <typename T>
size_t simple_get_hash(const T& o);

// Alias for V that exists only when T is not an enum type. It removes the
// std::hash overload of dispatch_hash from consideration for enums.
template <typename T, typename V>
using type_if_not_enum = std::enable_if_t<!std::is_enum_v<T>, V>;

// dispatch_hash picks a hashing strategy for T via SFINAE:
//   1. std::hash<T>, when that is well-formed and T is not an enum;
//   2. std::hash of the enum's underlying type, when T is an enum;
//   3. a static member callable as T::hash(o), when that is well-formed.
// C++14 made std::hash available for enum types, and some compilers provide
// it even in earlier modes. Overload 1 is therefore explicitly disabled for
// enums so it cannot compete with overload 2.
// NOTE(review): a type with both a usable std::hash<T> and a static T::hash
// matches overloads 1 and 3, so the call is presumably ambiguous.
template <typename T>
auto dispatch_hash(const T& o)
    -> decltype(std::hash<T>()(o), type_if_not_enum<T, size_t>()) {
  return std::hash<T>()(o);
}

// Enums: hash the value converted to its underlying integer type.
template <typename T>
std::enable_if_t<std::is_enum_v<T>, size_t> dispatch_hash(const T& o) {
  using R = std::underlying_type_t<T>;
  return std::hash<R>()(static_cast<R>(o));
}

// Fallback for types exposing a static member callable as T::hash(o).
template <typename T>
auto dispatch_hash(const T& o) -> decltype(T::hash(o), size_t()) {
  return T::hash(o);
}

} // namespace _hash_detail
290
291 // Hasher struct
292 template <typename T>
293 struct hash {
operatorhash294 size_t operator()(const T& o) const {
295 return _hash_detail::dispatch_hash(o);
296 };
297 };
298
299 // Specialization for std::tuple
300 template <typename... Types>
301 struct hash<std::tuple<Types...>> {
302 template <size_t idx, typename... Ts>
303 struct tuple_hash {
304 size_t operator()(const std::tuple<Ts...>& t) const {
305 return hash_combine(
306 _hash_detail::simple_get_hash(std::get<idx>(t)),
307 tuple_hash<idx - 1, Ts...>()(t));
308 }
309 };
310
311 template <typename... Ts>
312 struct tuple_hash<0, Ts...> {
313 size_t operator()(const std::tuple<Ts...>& t) const {
314 return _hash_detail::simple_get_hash(std::get<0>(t));
315 }
316 };
317
318 size_t operator()(const std::tuple<Types...>& t) const {
319 return tuple_hash<sizeof...(Types) - 1, Types...>()(t);
320 }
321 };
322
323 template <typename T1, typename T2>
324 struct hash<std::pair<T1, T2>> {
325 size_t operator()(const std::pair<T1, T2>& pair) const {
326 std::tuple<T1, T2> tuple = std::make_tuple(pair.first, pair.second);
327 return _hash_detail::simple_get_hash(tuple);
328 }
329 };
330
331 template <typename T>
332 struct hash<c10::ArrayRef<T>> {
333 size_t operator()(c10::ArrayRef<T> v) const {
334 size_t seed = 0;
335 for (const auto& elem : v) {
336 seed = hash_combine(seed, _hash_detail::simple_get_hash(elem));
337 }
338 return seed;
339 }
340 };
341
342 // Specialization for std::vector
343 template <typename T>
344 struct hash<std::vector<T>> {
345 size_t operator()(const std::vector<T>& v) const {
346 return hash<c10::ArrayRef<T>>()(v);
347 }
348 };
349
350 namespace _hash_detail {
351
352 template <typename T>
353 size_t simple_get_hash(const T& o) {
354 return c10::hash<T>()(o);
355 }
356
357 } // namespace _hash_detail
358
359 // Use this function to actually hash multiple things in one line.
360 // Dispatches to c10::hash, so it can hash containers.
361 // Example:
362 //
363 // static size_t hash(const MyStruct& s) {
364 // return get_hash(s.member1, s.member2, s.member3);
365 // }
366 template <typename... Types>
367 size_t get_hash(const Types&... args) {
368 return c10::hash<decltype(std::tie(args...))>()(std::tie(args...));
369 }
370
371 // Specialization for c10::complex
372 template <typename T>
373 struct hash<c10::complex<T>> {
374 size_t operator()(const c10::complex<T>& c) const {
375 return get_hash(c.real(), c.imag());
376 }
377 };
378
379 } // namespace c10
380