xref: /aosp_15_r20/external/pytorch/c10/util/Load.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/macros/Macros.h>
3 #include <cstring>
4 
5 namespace c10 {
6 namespace detail {
7 
8 template <typename T>
9 struct LoadImpl {
applyLoadImpl10   C10_HOST_DEVICE static T apply(const void* src) {
11     return *reinterpret_cast<const T*>(src);
12   }
13 };
14 
15 template <>
16 struct LoadImpl<bool> {
17   C10_HOST_DEVICE static bool apply(const void* src) {
18     static_assert(sizeof(bool) == sizeof(char));
19     // NOTE: [Loading boolean values]
20     // Protect against invalid boolean values by loading as a byte
21     // first, then converting to bool (see gh-54789).
22     return *reinterpret_cast<const unsigned char*>(src);
23   }
24 };
25 
26 } // namespace detail
27 
28 template <typename T>
29 C10_HOST_DEVICE T load(const void* src) {
30   return c10::detail::LoadImpl<T>::apply(src);
31 }
32 
33 template <typename scalar_t>
34 C10_HOST_DEVICE scalar_t load(const scalar_t* src) {
35   return c10::detail::LoadImpl<scalar_t>::apply(src);
36 }
37 
38 } // namespace c10
39