xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/hash.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/**
 * Hash utils in this file are adapted from PyTorch/XLA
 * https://github.com/pytorch/xla/blob/e0e5f937a0ba8d904f9608137dc8c51ba439df2d/third_party/xla_client/util.h
 */
5 #pragma once
6 
7 #include <ATen/Tensor.h>
8 #include <c10/core/Scalar.h>
9 #include <c10/util/int128.h>
10 #include <torch/csrc/Export.h>
11 #include <cstring>
12 #include <set>
13 #include <string>
14 #include <string_view>
15 #include <vector>
16 
17 namespace torch {
18 namespace lazy {
19 
20 using size_t = std::size_t;
21 
22 class TORCH_API hash_t : public c10::uint128 {
23  public:
24   // Swich from typedef hash_t = uint128 to provide explicit casters
hash_t(int8_t val)25   hash_t(int8_t val) : uint128(static_cast<uint32_t>(val)) {}
hash_t(int16_t val)26   hash_t(int16_t val) : uint128(static_cast<uint32_t>(val)) {}
hash_t(int32_t val)27   hash_t(int32_t val) : uint128(static_cast<uint32_t>(val)) {}
hash_t(int64_t val)28   hash_t(int64_t val) : uint128(static_cast<uint64_t>(val)) {}
hash_t(uint32_t val)29   hash_t(uint32_t val) : uint128(val) {}
hash_t(uint64_t val)30   hash_t(uint64_t val) : uint128(val) {}
hash_t(uint128 val)31   hash_t(uint128 val) : uint128(val) {}
hash_t(uint64_t top,uint64_t bottom)32   hash_t(uint64_t top, uint64_t bottom) : uint128(top, bottom) {}
hash_t()33   hash_t() : uint128() {}
34 };
35 
36 // Std* functions use 64-bit hash
37 size_t TORCH_API StdDataHash(const void* data, size_t size);
38 
39 size_t TORCH_API StdHashCombine(uintmax_t a, uintmax_t b);
40 
41 // Other functions are all 128-bit
42 hash_t TORCH_API HashBlock(const void* data, size_t n, const hash_t& seed);
43 
44 hash_t TORCH_API DataHash(const void* data, size_t size);
45 
46 hash_t TORCH_API HashCombine(const hash_t& a, const hash_t& b);
47 
48 size_t TORCH_API HashReduce(const hash_t& a);
49 
50 // Returns a string representation of a hash
51 std::string TORCH_API HashToString(const hash_t& a);
52 
53 struct HashReducer {
operatorHashReducer54   size_t operator()(const hash_t& value) const {
55     return HashReduce(value);
56   }
57 };
58 
StringHash(const char * data)59 static inline hash_t StringHash(const char* data) {
60   return DataHash(data, std::strlen(data));
61 }
62 
63 // Automatic templated implementation for 'arithmetic' types
64 template <
65     typename T,
66     typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
Hash(const T & value)67 hash_t Hash(const T& value) {
68   return DataHash(&value, sizeof(value));
69 }
70 
71 // added because on macos builds the vector<bool> specialization
72 // breaks falling through to the templated arithmetic types above
73 hash_t TORCH_API Hash(const std::vector<bool>& value);
74 
75 // Specialiazed implementations for proprietary types
Hash(const c10::ScalarType & value)76 static inline hash_t Hash(const c10::ScalarType& value) {
77   return DataHash(&value, sizeof(value));
78 }
79 
Hash(const c10::MemoryFormat & value)80 static inline hash_t Hash(const c10::MemoryFormat& value) {
81   return DataHash(&value, sizeof(value));
82 }
83 
Hash(const c10::DeviceType & value)84 static inline hash_t Hash(const c10::DeviceType& value) {
85   return DataHash(&value, sizeof(value));
86 }
87 
Hash(const c10::Device & value)88 static inline hash_t Hash(const c10::Device& value) {
89   return HashCombine(Hash(value.type()), Hash(value.index()));
90 }
91 
Hash(const c10::Layout & value)92 static inline hash_t Hash(const c10::Layout& value) {
93   return DataHash(&value, sizeof(value));
94 }
95 
Hash(const c10::Scalar & value)96 static inline hash_t Hash(const c10::Scalar& value) {
97   switch (value.type()) {
98     case c10::ScalarType::ComplexDouble:
99       return Hash(value.toComplexDouble());
100     case c10::ScalarType::Double:
101       return Hash(value.toDouble());
102     case c10::ScalarType::Long:
103       return Hash(value.toLong());
104     case c10::ScalarType::Bool:
105       return Hash(value.toBool());
106     default:
107       TORCH_INTERNAL_ASSERT(false, "Unknown scalar type.", value.type());
108   }
109 }
110 
TensorHash(const at::Tensor & tensor)111 static inline hash_t TensorHash(const at::Tensor& tensor) {
112   at::Tensor ctensor = tensor.contiguous();
113   int64_t size = ctensor.numel() * ctensor.element_size();
114   switch (ctensor.scalar_type()) {
115     case at::ScalarType::Bool:
116       return DataHash(ctensor.const_data_ptr<bool>(), size);
117     case at::ScalarType::Byte:
118       return DataHash(ctensor.const_data_ptr<uint8_t>(), size);
119     case at::ScalarType::Char:
120       return DataHash(ctensor.const_data_ptr<int8_t>(), size);
121     case at::ScalarType::Short:
122       return DataHash(ctensor.const_data_ptr<int16_t>(), size);
123     case at::ScalarType::Int:
124       return DataHash(ctensor.const_data_ptr<int32_t>(), size);
125     case at::ScalarType::Long:
126       return DataHash(ctensor.const_data_ptr<int64_t>(), size);
127     case at::ScalarType::Float:
128       return DataHash(ctensor.const_data_ptr<float>(), size);
129     case at::ScalarType::Double:
130       return DataHash(ctensor.const_data_ptr<double>(), size);
131     case at::ScalarType::BFloat16:
132       return DataHash(ctensor.const_data_ptr<at::BFloat16>(), size);
133     case at::ScalarType::Half:
134       return DataHash(ctensor.const_data_ptr<at::Half>(), size);
135     case at::ScalarType::ComplexFloat:
136       return DataHash(ctensor.const_data_ptr<c10::complex<float>>(), size);
137     case at::ScalarType::ComplexDouble:
138       return DataHash(ctensor.const_data_ptr<c10::complex<double>>(), size);
139     case at::ScalarType::UInt16:
140       return DataHash(ctensor.const_data_ptr<uint16_t>(), size);
141     case at::ScalarType::UInt32:
142       return DataHash(ctensor.const_data_ptr<uint32_t>(), size);
143     case at::ScalarType::UInt64:
144       return DataHash(ctensor.const_data_ptr<uint64_t>(), size);
145     default:
146       TORCH_INTERNAL_ASSERT(
147           false, "Unsupported scalar type:", ctensor.scalar_type());
148   }
149 }
150 
Hash(const std::string & value)151 static inline hash_t Hash(const std::string& value) {
152   return DataHash(value.data(), value.size());
153 }
154 
Hash(const c10::string_view & value)155 static inline hash_t Hash(const c10::string_view& value) {
156   return DataHash(value.data(), value.size());
157 }
158 
Hash(const std::string_view & value)159 static inline hash_t Hash(const std::string_view& value) {
160   return DataHash(value.data(), value.size());
161 }
162 
Hash(const at::Generator & value)163 static inline hash_t Hash(const at::Generator& value) {
164   return TensorHash(value.get_state());
165 }
166 
// Modeled on GNU libstdc++'s std::hash for std::optional: a disengaged
// optional still contributes a value to the hash. This distinguishes cases
// where one or another option was null; the hope is that the constant does
// not collide with an actual scalar value.
//
// An arbitrary, randomly-selected 64-bit integer is used, rather than a small
// constant hashed at runtime, so that no constant has to be hashed repeatedly.
//
// The literal exceeds INT64_MAX, so the conversion to int64_t is spelled out
// explicitly instead of relying on an implicit narrowing conversion. The
// stored bit pattern is unchanged on two's-complement targets, so existing
// hash values are unaffected.
static constexpr int64_t kNullOpt =
    static_cast<int64_t>(0x8655d738f3678ddaULL);
176 
177 // Hashing for std::optional types contributes to hash
178 // for optionals with null value, important to distinguish
179 // between <nullopt, non-nullopt> and <non-nullopt, nullopt> cases
180 template <typename T>
Hash(const std::optional<T> & value)181 hash_t Hash(const std::optional<T>& value) {
182   if (value.has_value()) {
183     return Hash(value.value());
184   } else {
185     return kNullOpt;
186   }
187 }
188 
189 // Hashing of containers
190 // Forward declare to allow hashes of vectors of vectors to work.
191 template <typename T>
192 hash_t ContainerHash(const T& values);
193 
194 template <typename T>
Hash(const std::vector<T> & values)195 hash_t Hash(const std::vector<T>& values) {
196   return ContainerHash(values);
197 }
198 
199 // Need a special case for std::optional<container>?
200 template <typename T>
Hash(const std::optional<std::vector<T>> & value)201 hash_t Hash(const std::optional<std::vector<T>>& value) {
202   if (value.has_value()) {
203     return ContainerHash(value.value());
204   } else {
205     return kNullOpt;
206   }
207 }
208 
209 template <typename T>
Hash(const std::set<T> & values)210 hash_t Hash(const std::set<T>& values) {
211   return ContainerHash(values);
212 }
213 
214 template <typename T, typename S>
Hash(const std::pair<T,S> & values)215 hash_t Hash(const std::pair<T, S>& values) {
216   return HashCombine(Hash(values.first), Hash(values.second));
217 }
218 
// Identity: an existing hash_t is returned unchanged, so precomputed hashes
// can be passed wherever a hashable value is expected.
static inline hash_t Hash(const hash_t& value) {
  return hash_t(value);
}
222 
223 template <typename T>
Hash(c10::ArrayRef<T> values)224 hash_t Hash(c10::ArrayRef<T> values) {
225   return ContainerHash(values);
226 }
227 
228 template <typename T>
ContainerHash(const T & values)229 hash_t ContainerHash(const T& values) {
230   hash_t h(static_cast<uint64_t>(0x85ebca77c2b2ae63));
231   for (const auto& value : values) {
232     h = HashCombine(h, Hash(value));
233   }
234   return h;
235 }
236 
237 // Varargs hashing
238 template <typename T = void>
MHash()239 hash_t MHash() {
240   return hash_t(static_cast<uint64_t>(0x165667b19e3779f9));
241 }
242 
243 template <typename T, typename... Targs>
MHash(T value,Targs...Fargs)244 hash_t MHash(T value, Targs... Fargs) {
245   return HashCombine(Hash(value), MHash(Fargs...));
246 }
247 
248 } // namespace lazy
249 } // namespace torch
250