1 /**
 * Hash utils in this file are adapted from PyTorch/XLA
3 * https://github.com/pytorch/xla/blob/e0e5f937a0ba8d904f9608137dc8c51ba439df2d/third_party/xla_client/util.h
4 */
5 #pragma once
6
7 #include <ATen/Tensor.h>
8 #include <c10/core/Scalar.h>
9 #include <c10/util/int128.h>
10 #include <torch/csrc/Export.h>
11 #include <cstring>
12 #include <set>
13 #include <string>
14 #include <string_view>
15 #include <vector>
16
17 namespace torch {
18 namespace lazy {
19
20 using size_t = std::size_t;
21
22 class TORCH_API hash_t : public c10::uint128 {
23 public:
24 // Swich from typedef hash_t = uint128 to provide explicit casters
hash_t(int8_t val)25 hash_t(int8_t val) : uint128(static_cast<uint32_t>(val)) {}
hash_t(int16_t val)26 hash_t(int16_t val) : uint128(static_cast<uint32_t>(val)) {}
hash_t(int32_t val)27 hash_t(int32_t val) : uint128(static_cast<uint32_t>(val)) {}
hash_t(int64_t val)28 hash_t(int64_t val) : uint128(static_cast<uint64_t>(val)) {}
hash_t(uint32_t val)29 hash_t(uint32_t val) : uint128(val) {}
hash_t(uint64_t val)30 hash_t(uint64_t val) : uint128(val) {}
hash_t(uint128 val)31 hash_t(uint128 val) : uint128(val) {}
hash_t(uint64_t top,uint64_t bottom)32 hash_t(uint64_t top, uint64_t bottom) : uint128(top, bottom) {}
hash_t()33 hash_t() : uint128() {}
34 };
35
36 // Std* functions use 64-bit hash
37 size_t TORCH_API StdDataHash(const void* data, size_t size);
38
39 size_t TORCH_API StdHashCombine(uintmax_t a, uintmax_t b);
40
41 // Other functions are all 128-bit
42 hash_t TORCH_API HashBlock(const void* data, size_t n, const hash_t& seed);
43
44 hash_t TORCH_API DataHash(const void* data, size_t size);
45
46 hash_t TORCH_API HashCombine(const hash_t& a, const hash_t& b);
47
48 size_t TORCH_API HashReduce(const hash_t& a);
49
50 // Returns a string representation of a hash
51 std::string TORCH_API HashToString(const hash_t& a);
52
53 struct HashReducer {
operatorHashReducer54 size_t operator()(const hash_t& value) const {
55 return HashReduce(value);
56 }
57 };
58
StringHash(const char * data)59 static inline hash_t StringHash(const char* data) {
60 return DataHash(data, std::strlen(data));
61 }
62
63 // Automatic templated implementation for 'arithmetic' types
64 template <
65 typename T,
66 typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
Hash(const T & value)67 hash_t Hash(const T& value) {
68 return DataHash(&value, sizeof(value));
69 }
70
71 // added because on macos builds the vector<bool> specialization
72 // breaks falling through to the templated arithmetic types above
73 hash_t TORCH_API Hash(const std::vector<bool>& value);
74
// Specialized implementations for PyTorch-specific (c10) types
Hash(const c10::ScalarType & value)76 static inline hash_t Hash(const c10::ScalarType& value) {
77 return DataHash(&value, sizeof(value));
78 }
79
Hash(const c10::MemoryFormat & value)80 static inline hash_t Hash(const c10::MemoryFormat& value) {
81 return DataHash(&value, sizeof(value));
82 }
83
Hash(const c10::DeviceType & value)84 static inline hash_t Hash(const c10::DeviceType& value) {
85 return DataHash(&value, sizeof(value));
86 }
87
Hash(const c10::Device & value)88 static inline hash_t Hash(const c10::Device& value) {
89 return HashCombine(Hash(value.type()), Hash(value.index()));
90 }
91
Hash(const c10::Layout & value)92 static inline hash_t Hash(const c10::Layout& value) {
93 return DataHash(&value, sizeof(value));
94 }
95
Hash(const c10::Scalar & value)96 static inline hash_t Hash(const c10::Scalar& value) {
97 switch (value.type()) {
98 case c10::ScalarType::ComplexDouble:
99 return Hash(value.toComplexDouble());
100 case c10::ScalarType::Double:
101 return Hash(value.toDouble());
102 case c10::ScalarType::Long:
103 return Hash(value.toLong());
104 case c10::ScalarType::Bool:
105 return Hash(value.toBool());
106 default:
107 TORCH_INTERNAL_ASSERT(false, "Unknown scalar type.", value.type());
108 }
109 }
110
TensorHash(const at::Tensor & tensor)111 static inline hash_t TensorHash(const at::Tensor& tensor) {
112 at::Tensor ctensor = tensor.contiguous();
113 int64_t size = ctensor.numel() * ctensor.element_size();
114 switch (ctensor.scalar_type()) {
115 case at::ScalarType::Bool:
116 return DataHash(ctensor.const_data_ptr<bool>(), size);
117 case at::ScalarType::Byte:
118 return DataHash(ctensor.const_data_ptr<uint8_t>(), size);
119 case at::ScalarType::Char:
120 return DataHash(ctensor.const_data_ptr<int8_t>(), size);
121 case at::ScalarType::Short:
122 return DataHash(ctensor.const_data_ptr<int16_t>(), size);
123 case at::ScalarType::Int:
124 return DataHash(ctensor.const_data_ptr<int32_t>(), size);
125 case at::ScalarType::Long:
126 return DataHash(ctensor.const_data_ptr<int64_t>(), size);
127 case at::ScalarType::Float:
128 return DataHash(ctensor.const_data_ptr<float>(), size);
129 case at::ScalarType::Double:
130 return DataHash(ctensor.const_data_ptr<double>(), size);
131 case at::ScalarType::BFloat16:
132 return DataHash(ctensor.const_data_ptr<at::BFloat16>(), size);
133 case at::ScalarType::Half:
134 return DataHash(ctensor.const_data_ptr<at::Half>(), size);
135 case at::ScalarType::ComplexFloat:
136 return DataHash(ctensor.const_data_ptr<c10::complex<float>>(), size);
137 case at::ScalarType::ComplexDouble:
138 return DataHash(ctensor.const_data_ptr<c10::complex<double>>(), size);
139 case at::ScalarType::UInt16:
140 return DataHash(ctensor.const_data_ptr<uint16_t>(), size);
141 case at::ScalarType::UInt32:
142 return DataHash(ctensor.const_data_ptr<uint32_t>(), size);
143 case at::ScalarType::UInt64:
144 return DataHash(ctensor.const_data_ptr<uint64_t>(), size);
145 default:
146 TORCH_INTERNAL_ASSERT(
147 false, "Unsupported scalar type:", ctensor.scalar_type());
148 }
149 }
150
Hash(const std::string & value)151 static inline hash_t Hash(const std::string& value) {
152 return DataHash(value.data(), value.size());
153 }
154
Hash(const c10::string_view & value)155 static inline hash_t Hash(const c10::string_view& value) {
156 return DataHash(value.data(), value.size());
157 }
158
Hash(const std::string_view & value)159 static inline hash_t Hash(const std::string_view& value) {
160 return DataHash(value.data(), value.size());
161 }
162
Hash(const at::Generator & value)163 static inline hash_t Hash(const at::Generator& value) {
164 return TensorHash(value.get_state());
165 }
166
// Taken from glibc's implementation of hashing optionals:
// an empty optional must still contribute to the hash so that cases where
// one or another option is null can be told apart, ideally without colliding
// with the hash of an actual scalar value.
//
// An arbitrary, randomly selected 64-bit pattern is used directly instead of
// a small constant that would have to be hashed again on every use.
//
// The pattern does not fit in a positive int64_t, so the conversion from the
// unsigned literal is written out explicitly rather than left as an implicit
// value-changing conversion; the stored bits are unchanged.
static constexpr int64_t kNullOpt =
    static_cast<int64_t>(0x8655d738f3678ddaULL);
176
177 // Hashing for std::optional types contributes to hash
178 // for optionals with null value, important to distinguish
179 // between <nullopt, non-nullopt> and <non-nullopt, nullopt> cases
180 template <typename T>
Hash(const std::optional<T> & value)181 hash_t Hash(const std::optional<T>& value) {
182 if (value.has_value()) {
183 return Hash(value.value());
184 } else {
185 return kNullOpt;
186 }
187 }
188
189 // Hashing of containers
190 // Forward declare to allow hashes of vectors of vectors to work.
191 template <typename T>
192 hash_t ContainerHash(const T& values);
193
194 template <typename T>
Hash(const std::vector<T> & values)195 hash_t Hash(const std::vector<T>& values) {
196 return ContainerHash(values);
197 }
198
199 // Need a special case for std::optional<container>?
200 template <typename T>
Hash(const std::optional<std::vector<T>> & value)201 hash_t Hash(const std::optional<std::vector<T>>& value) {
202 if (value.has_value()) {
203 return ContainerHash(value.value());
204 } else {
205 return kNullOpt;
206 }
207 }
208
209 template <typename T>
Hash(const std::set<T> & values)210 hash_t Hash(const std::set<T>& values) {
211 return ContainerHash(values);
212 }
213
214 template <typename T, typename S>
Hash(const std::pair<T,S> & values)215 hash_t Hash(const std::pair<T, S>& values) {
216 return HashCombine(Hash(values.first), Hash(values.second));
217 }
218
// Identity overload: something that is already a hash_t is returned as-is,
// without being hashed again.
static inline hash_t Hash(const hash_t& value) {
  return value;
}
222
223 template <typename T>
Hash(c10::ArrayRef<T> values)224 hash_t Hash(c10::ArrayRef<T> values) {
225 return ContainerHash(values);
226 }
227
228 template <typename T>
ContainerHash(const T & values)229 hash_t ContainerHash(const T& values) {
230 hash_t h(static_cast<uint64_t>(0x85ebca77c2b2ae63));
231 for (const auto& value : values) {
232 h = HashCombine(h, Hash(value));
233 }
234 return h;
235 }
236
237 // Varargs hashing
238 template <typename T = void>
MHash()239 hash_t MHash() {
240 return hash_t(static_cast<uint64_t>(0x165667b19e3779f9));
241 }
242
243 template <typename T, typename... Targs>
MHash(T value,Targs...Fargs)244 hash_t MHash(T value, Targs... Fargs) {
245 return HashCombine(Hash(value), MHash(Fargs...));
246 }
247
248 } // namespace lazy
249 } // namespace torch
250