xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/byte_order.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/BFloat16.h>
4 #include <c10/util/Float8_e4m3fn.h>
5 #include <c10/util/Float8_e4m3fnuz.h>
6 #include <c10/util/Float8_e5m2.h>
7 #include <c10/util/Float8_e5m2fnuz.h>
8 #include <c10/util/Half.h>
9 #include <torch/csrc/Export.h>
10 #include <cstddef>
11 #include <cstdint>
12 
// thp_bswap16/32/64(x): reverse the byte order of an unsigned 16-, 32- or
// 64-bit integer, mapped onto whichever native primitive the toolchain
// provides. Used below to build the to_be*/to_le* conversion macros.
#ifdef __FreeBSD__
#include <sys/endian.h>
#include <sys/types.h>
#define thp_bswap16(x) bswap16(x)
#define thp_bswap32(x) bswap32(x)
#define thp_bswap64(x) bswap64(x)
#elif defined(__APPLE__)
#include <libkern/OSByteOrder.h>
#define thp_bswap16(x) OSSwapInt16(x)
#define thp_bswap32(x) OSSwapInt32(x)
#define thp_bswap64(x) OSSwapInt64(x)
#elif defined(__GNUC__) && !defined(__MINGW32__)
#include <byteswap.h>
#define thp_bswap16(x) bswap_16(x)
#define thp_bswap32(x) bswap_32(x)
#define thp_bswap64(x) bswap_64(x)
#elif defined _WIN32 || defined _WIN64
// The _byteswap_* intrinsics are declared in <stdlib.h>; include it here
// instead of relying on some other header having pulled it in already.
#include <stdlib.h>
#define thp_bswap16(x) _byteswap_ushort(x)
#define thp_bswap32(x) _byteswap_ulong(x)
#define thp_bswap64(x) _byteswap_uint64(x)
#else
// No known native primitive for this toolchain: fall back to portable
// shift-and-mask implementations so thp_bswap* is always defined.
namespace torch::utils::detail {
inline std::uint16_t thp_bswap16_portable(std::uint16_t v) {
  return static_cast<std::uint16_t>((v >> 8) | (v << 8));
}
inline std::uint32_t thp_bswap32_portable(std::uint32_t v) {
  return ((v & 0x000000FFu) << 24) | ((v & 0x0000FF00u) << 8) |
      ((v & 0x00FF0000u) >> 8) | ((v & 0xFF000000u) >> 24);
}
inline std::uint64_t thp_bswap64_portable(std::uint64_t v) {
  // Swap each 32-bit half, then exchange the halves.
  return (static_cast<std::uint64_t>(
              thp_bswap32_portable(static_cast<std::uint32_t>(v)))
          << 32) |
      thp_bswap32_portable(static_cast<std::uint32_t>(v >> 32));
}
} // namespace torch::utils::detail
#define thp_bswap16(x) ::torch::utils::detail::thp_bswap16_portable(x)
#define thp_bswap32(x) ::torch::utils::detail::thp_bswap32_portable(x)
#define thp_bswap64(x) ::torch::utils::detail::thp_bswap64_portable(x)
#endif
34 
// Host <-> big/little-endian conversions built on thp_bswap*. A conversion
// that matches the host byte order is the identity; the opposite one is a
// byte swap. to_* and from_* expand identically (a swap is its own inverse)
// and exist only so call sites read clearly.
//
// Definedness is checked explicitly: inside #if an undefined identifier
// evaluates to 0, so a bare `__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__` would
// be true on any toolchain that defines neither macro. __BYTE_ORDER__ is a
// GCC/Clang predefined macro; MSVC lacks it, and Windows targets are
// little-endian, so _WIN32/_WIN64 without __BYTE_ORDER__ is treated as
// little-endian. Any other toolchain without it reaches the #error.
#if (defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) && \
     __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) ||                  \
    (!defined(__BYTE_ORDER__) && (defined(_WIN32) || defined(_WIN64)))
#define to_be16(x) thp_bswap16(x)
#define from_be16(x) thp_bswap16(x)
#define to_be32(x) thp_bswap32(x)
#define from_be32(x) thp_bswap32(x)
#define to_be64(x) thp_bswap64(x)
#define from_be64(x) thp_bswap64(x)
#define to_le16(x) (x)
#define from_le16(x) (x)
#define to_le32(x) (x)
#define from_le32(x) (x)
#define to_le64(x) (x)
#define from_le64(x) (x)
#elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define to_be16(x) (x)
#define from_be16(x) (x)
#define to_be32(x) (x)
#define from_be32(x) (x)
#define to_be64(x) (x)
#define from_be64(x) (x)
#define to_le16(x) thp_bswap16(x)
#define from_le16(x) thp_bswap16(x)
#define to_le32(x) thp_bswap32(x)
#define from_le32(x) thp_bswap32(x)
#define to_le64(x) thp_bswap64(x)
#define from_le64(x) thp_bswap64(x)
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
64 
65 namespace torch::utils {
66 
67 enum THPByteOrder { THP_LITTLE_ENDIAN = 0, THP_BIG_ENDIAN = 1 };
68 
69 TORCH_API THPByteOrder THP_nativeByteOrder();
70 
71 TORCH_API void THP_decodeInt16Buffer(
72     int16_t* dst,
73     const uint8_t* src,
74     bool do_byte_swap,
75     size_t len);
76 TORCH_API void THP_decodeInt32Buffer(
77     int32_t* dst,
78     const uint8_t* src,
79     bool do_byte_swap,
80     size_t len);
81 TORCH_API void THP_decodeInt64Buffer(
82     int64_t* dst,
83     const uint8_t* src,
84     bool do_byte_swap,
85     size_t len);
86 TORCH_API void THP_decodeHalfBuffer(
87     c10::Half* dst,
88     const uint8_t* src,
89     bool do_byte_swap,
90     size_t len);
91 TORCH_API void THP_decodeFloatBuffer(
92     float* dst,
93     const uint8_t* src,
94     bool do_byte_swap,
95     size_t len);
96 TORCH_API void THP_decodeDoubleBuffer(
97     double* dst,
98     const uint8_t* src,
99     bool do_byte_swap,
100     size_t len);
101 TORCH_API void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, size_t len);
102 TORCH_API void THP_decodeBFloat16Buffer(
103     at::BFloat16* dst,
104     const uint8_t* src,
105     bool do_byte_swap,
106     size_t len);
107 TORCH_API void THP_decodeComplexFloatBuffer(
108     c10::complex<float>* dst,
109     const uint8_t* src,
110     bool do_byte_swap,
111     size_t len);
112 TORCH_API void THP_decodeComplexDoubleBuffer(
113     c10::complex<double>* dst,
114     const uint8_t* src,
115     bool do_byte_swap,
116     size_t len);
117 
118 TORCH_API void THP_decodeInt16Buffer(
119     int16_t* dst,
120     const uint8_t* src,
121     THPByteOrder order,
122     size_t len);
123 TORCH_API void THP_decodeInt32Buffer(
124     int32_t* dst,
125     const uint8_t* src,
126     THPByteOrder order,
127     size_t len);
128 TORCH_API void THP_decodeInt64Buffer(
129     int64_t* dst,
130     const uint8_t* src,
131     THPByteOrder order,
132     size_t len);
133 TORCH_API void THP_decodeHalfBuffer(
134     c10::Half* dst,
135     const uint8_t* src,
136     THPByteOrder order,
137     size_t len);
138 TORCH_API void THP_decodeFloatBuffer(
139     float* dst,
140     const uint8_t* src,
141     THPByteOrder order,
142     size_t len);
143 TORCH_API void THP_decodeDoubleBuffer(
144     double* dst,
145     const uint8_t* src,
146     THPByteOrder order,
147     size_t len);
148 TORCH_API void THP_decodeBFloat16Buffer(
149     at::BFloat16* dst,
150     const uint8_t* src,
151     THPByteOrder order,
152     size_t len);
153 TORCH_API void THP_decodeFloat8_e5m2Buffer(
154     at::Float8_e5m2* dst,
155     const uint8_t* src,
156     size_t len);
157 TORCH_API void THP_decodeFloat8_e4m3fnBuffer(
158     at::Float8_e4m3fn* dst,
159     const uint8_t* src,
160     size_t len);
161 TORCH_API void THP_decodeFloat8_e5m2fnuzBuffer(
162     at::Float8_e5m2fnuz* dst,
163     const uint8_t* src,
164     size_t len);
165 TORCH_API void THP_decodeFloat8_e4m3fnuzBuffer(
166     at::Float8_e4m3fnuz* dst,
167     const uint8_t* src,
168     size_t len);
169 TORCH_API void THP_decodeComplexFloatBuffer(
170     c10::complex<float>* dst,
171     const uint8_t* src,
172     THPByteOrder order,
173     size_t len);
174 TORCH_API void THP_decodeComplexDoubleBuffer(
175     c10::complex<double>* dst,
176     const uint8_t* src,
177     THPByteOrder order,
178     size_t len);
179 
180 TORCH_API void THP_encodeInt16Buffer(
181     uint8_t* dst,
182     const int16_t* src,
183     THPByteOrder order,
184     size_t len);
185 TORCH_API void THP_encodeInt32Buffer(
186     uint8_t* dst,
187     const int32_t* src,
188     THPByteOrder order,
189     size_t len);
190 TORCH_API void THP_encodeInt64Buffer(
191     uint8_t* dst,
192     const int64_t* src,
193     THPByteOrder order,
194     size_t len);
195 TORCH_API void THP_encodeFloatBuffer(
196     uint8_t* dst,
197     const float* src,
198     THPByteOrder order,
199     size_t len);
200 TORCH_API void THP_encodeDoubleBuffer(
201     uint8_t* dst,
202     const double* src,
203     THPByteOrder order,
204     size_t len);
205 TORCH_API void THP_encodeComplexFloatBuffer(
206     uint8_t* dst,
207     const c10::complex<float>* src,
208     THPByteOrder order,
209     size_t len);
210 TORCH_API void THP_encodeComplexDoubleBuffer(
211     uint8_t* dst,
212     const c10::complex<double>* src,
213     THPByteOrder order,
214     size_t len);
215 
216 } // namespace torch::utils
217