// xref: /aosp_15_r20/external/pytorch/c10/util/Load.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <c10/macros/Macros.h>
#include <cstring>

5*da0073e9SAndroid Build Coastguard Worker namespace c10 {
6*da0073e9SAndroid Build Coastguard Worker namespace detail {
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker template <typename T>
9*da0073e9SAndroid Build Coastguard Worker struct LoadImpl {
applyLoadImpl10*da0073e9SAndroid Build Coastguard Worker   C10_HOST_DEVICE static T apply(const void* src) {
11*da0073e9SAndroid Build Coastguard Worker     return *reinterpret_cast<const T*>(src);
12*da0073e9SAndroid Build Coastguard Worker   }
13*da0073e9SAndroid Build Coastguard Worker };
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker template <>
16*da0073e9SAndroid Build Coastguard Worker struct LoadImpl<bool> {
17*da0073e9SAndroid Build Coastguard Worker   C10_HOST_DEVICE static bool apply(const void* src) {
18*da0073e9SAndroid Build Coastguard Worker     static_assert(sizeof(bool) == sizeof(char));
19*da0073e9SAndroid Build Coastguard Worker     // NOTE: [Loading boolean values]
20*da0073e9SAndroid Build Coastguard Worker     // Protect against invalid boolean values by loading as a byte
21*da0073e9SAndroid Build Coastguard Worker     // first, then converting to bool (see gh-54789).
22*da0073e9SAndroid Build Coastguard Worker     return *reinterpret_cast<const unsigned char*>(src);
23*da0073e9SAndroid Build Coastguard Worker   }
24*da0073e9SAndroid Build Coastguard Worker };
25*da0073e9SAndroid Build Coastguard Worker 
26*da0073e9SAndroid Build Coastguard Worker } // namespace detail
27*da0073e9SAndroid Build Coastguard Worker 
28*da0073e9SAndroid Build Coastguard Worker template <typename T>
29*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE T load(const void* src) {
30*da0073e9SAndroid Build Coastguard Worker   return c10::detail::LoadImpl<T>::apply(src);
31*da0073e9SAndroid Build Coastguard Worker }
32*da0073e9SAndroid Build Coastguard Worker 
33*da0073e9SAndroid Build Coastguard Worker template <typename scalar_t>
34*da0073e9SAndroid Build Coastguard Worker C10_HOST_DEVICE scalar_t load(const scalar_t* src) {
35*da0073e9SAndroid Build Coastguard Worker   return c10::detail::LoadImpl<scalar_t>::apply(src);
36*da0073e9SAndroid Build Coastguard Worker }
37*da0073e9SAndroid Build Coastguard Worker 
38*da0073e9SAndroid Build Coastguard Worker } // namespace c10
39