1 #include <c10/util/BFloat16.h>
2 #include <c10/util/irange.h>
3 #include <torch/csrc/utils/byte_order.h>
4
5 #include <cstring>
6 #include <vector>
7
8 #if defined(_MSC_VER)
9 #include <stdlib.h>
10 #endif
11
12 namespace {
13
// Reverses, in place, the two bytes of the 16-bit value stored at `ptr`.
// The value is moved in and out with memcpy, so `ptr` is only ever accessed
// as raw bytes.
static inline void swapBytes16(void* ptr) {
  uint16_t value = 0;
  std::memcpy(&value, ptr, sizeof(value));
#if defined(_MSC_VER) && !defined(_DEBUG)
  value = _byteswap_ushort(value);
#elif defined(__llvm__) || (defined(__GNUC__) && !defined(__ICC))
  value = __builtin_bswap16(value);
#else
  // Portable fallback: exchange the high and low byte.
  value = static_cast<uint16_t>((value >> 8) | (value << 8));
#endif
  std::memcpy(ptr, &value, sizeof(value));
}
29
// Reverses, in place, the four bytes of the 32-bit value stored at `ptr`.
// The value is moved in and out with memcpy, so `ptr` is only ever accessed
// as raw bytes.
static inline void swapBytes32(void* ptr) {
  uint32_t value = 0;
  std::memcpy(&value, ptr, sizeof(value));
#if defined(_MSC_VER) && !defined(_DEBUG)
  value = _byteswap_ulong(value);
#elif defined(__llvm__) || (defined(__GNUC__) && !defined(__ICC))
  value = __builtin_bswap32(value);
#else
  // Portable fallback: mask out each byte and move it to the mirrored slot.
  value = ((value & 0x000000FFu) << 24) | ((value & 0x0000FF00u) << 8) |
      ((value & 0x00FF0000u) >> 8) | ((value & 0xFF000000u) >> 24);
#endif
  std::memcpy(ptr, &value, sizeof(value));
}
47
// Reverses, in place, the eight bytes of the 64-bit value stored at `ptr`.
// The value is moved in and out with memcpy, so `ptr` is only ever accessed
// as raw bytes.
static inline void swapBytes64(void* ptr) {
  uint64_t value = 0;
  std::memcpy(&value, ptr, sizeof(value));
#if defined(_MSC_VER)
  value = _byteswap_uint64(value);
#elif defined(__llvm__) || (defined(__GNUC__) && !defined(__ICC))
  value = __builtin_bswap64(value);
#else
  // Portable fallback: peel bytes off the low end of `value` and push them
  // onto the low end of `reversed`, which mirrors the byte sequence.
  uint64_t reversed = 0;
  for (int byte = 0; byte < 8; ++byte) {
    reversed = (reversed << 8) | (value & 0xFF);
    value >>= 8;
  }
  value = reversed;
#endif
  std::memcpy(ptr, &value, sizeof(value));
}
71
// Loads a 16-bit unsigned integer from the two bytes at `data`, keeping the
// bytes in the order they appear in memory (no byte swapping).
static inline uint16_t decodeUInt16(const uint8_t* data) {
  uint16_t value = 0;
  std::memcpy(&value, data, sizeof(value));
  return value;
}
78
decodeUInt16ByteSwapped(const uint8_t * data)79 static inline uint16_t decodeUInt16ByteSwapped(const uint8_t* data) {
80 uint16_t output = decodeUInt16(data);
81 swapBytes16(&output);
82 return output;
83 }
84
// Loads a 32-bit unsigned integer from the four bytes at `data`, keeping the
// bytes in the order they appear in memory (no byte swapping).
static inline uint32_t decodeUInt32(const uint8_t* data) {
  uint32_t value = 0;
  std::memcpy(&value, data, sizeof(value));
  return value;
}
91
decodeUInt32ByteSwapped(const uint8_t * data)92 static inline uint32_t decodeUInt32ByteSwapped(const uint8_t* data) {
93 uint32_t output = decodeUInt32(data);
94 swapBytes32(&output);
95 return output;
96 }
97
// Loads a 64-bit unsigned integer from the eight bytes at `data`, keeping the
// bytes in the order they appear in memory (no byte swapping).
static inline uint64_t decodeUInt64(const uint8_t* data) {
  uint64_t value = 0;
  std::memcpy(&value, data, sizeof(value));
  return value;
}
104
decodeUInt64ByteSwapped(const uint8_t * data)105 static inline uint64_t decodeUInt64ByteSwapped(const uint8_t* data) {
106 uint64_t output = decodeUInt64(data);
107 swapBytes64(&output);
108 return output;
109 }
110
111 } // anonymous namespace
112
113 namespace torch::utils {
114
THP_nativeByteOrder()115 THPByteOrder THP_nativeByteOrder() {
116 uint32_t x = 1;
117 return *(uint8_t*)&x ? THP_LITTLE_ENDIAN : THP_BIG_ENDIAN;
118 }
119
THP_decodeInt16Buffer(int16_t * dst,const uint8_t * src,bool do_byte_swap,size_t len)120 void THP_decodeInt16Buffer(
121 int16_t* dst,
122 const uint8_t* src,
123 bool do_byte_swap,
124 size_t len) {
125 for (const auto i : c10::irange(len)) {
126 dst[i] = (int16_t)(do_byte_swap ? decodeUInt16ByteSwapped(src)
127 : decodeUInt16(src));
128 src += sizeof(int16_t);
129 }
130 }
131
THP_decodeInt32Buffer(int32_t * dst,const uint8_t * src,bool do_byte_swap,size_t len)132 void THP_decodeInt32Buffer(
133 int32_t* dst,
134 const uint8_t* src,
135 bool do_byte_swap,
136 size_t len) {
137 for (const auto i : c10::irange(len)) {
138 dst[i] = (int32_t)(do_byte_swap ? decodeUInt32ByteSwapped(src)
139 : decodeUInt32(src));
140 src += sizeof(int32_t);
141 }
142 }
143
THP_decodeInt64Buffer(int64_t * dst,const uint8_t * src,bool do_byte_swap,size_t len)144 void THP_decodeInt64Buffer(
145 int64_t* dst,
146 const uint8_t* src,
147 bool do_byte_swap,
148 size_t len) {
149 for (const auto i : c10::irange(len)) {
150 dst[i] = (int64_t)(do_byte_swap ? decodeUInt64ByteSwapped(src)
151 : decodeUInt64(src));
152 src += sizeof(int64_t);
153 }
154 }
155
THP_decodeHalfBuffer(c10::Half * dst,const uint8_t * src,bool do_byte_swap,size_t len)156 void THP_decodeHalfBuffer(
157 c10::Half* dst,
158 const uint8_t* src,
159 bool do_byte_swap,
160 size_t len) {
161 for (const auto i : c10::irange(len)) {
162 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
163 union {
164 uint16_t x;
165 c10::Half f;
166 };
167 x = (do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src));
168 dst[i] = f;
169 src += sizeof(uint16_t);
170 }
171 }
172
THP_decodeBFloat16Buffer(at::BFloat16 * dst,const uint8_t * src,bool do_byte_swap,size_t len)173 void THP_decodeBFloat16Buffer(
174 at::BFloat16* dst,
175 const uint8_t* src,
176 bool do_byte_swap,
177 size_t len) {
178 for (const auto i : c10::irange(len)) {
179 uint16_t x =
180 (do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src));
181 std::memcpy(&dst[i], &x, sizeof(dst[i]));
182 src += sizeof(uint16_t);
183 }
184 }
185
// Decodes `len` booleans from `src`, one byte per element: a zero byte
// becomes false and any non-zero byte becomes true.
void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, size_t len) {
  for (size_t idx = 0; idx < len; ++idx) {
    dst[idx] = (src[idx] != 0);
  }
}
191
THP_decodeFloatBuffer(float * dst,const uint8_t * src,bool do_byte_swap,size_t len)192 void THP_decodeFloatBuffer(
193 float* dst,
194 const uint8_t* src,
195 bool do_byte_swap,
196 size_t len) {
197 for (const auto i : c10::irange(len)) {
198 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
199 union {
200 uint32_t x;
201 float f;
202 };
203 x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
204 dst[i] = f;
205 src += sizeof(float);
206 }
207 }
208
THP_decodeDoubleBuffer(double * dst,const uint8_t * src,bool do_byte_swap,size_t len)209 void THP_decodeDoubleBuffer(
210 double* dst,
211 const uint8_t* src,
212 bool do_byte_swap,
213 size_t len) {
214 for (const auto i : c10::irange(len)) {
215 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
216 union {
217 uint64_t x;
218 double d;
219 };
220 x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
221 dst[i] = d;
222 src += sizeof(double);
223 }
224 }
225
THP_decodeComplexFloatBuffer(c10::complex<float> * dst,const uint8_t * src,bool do_byte_swap,size_t len)226 void THP_decodeComplexFloatBuffer(
227 c10::complex<float>* dst,
228 const uint8_t* src,
229 bool do_byte_swap,
230 size_t len) {
231 for (const auto i : c10::irange(len)) {
232 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
233 union {
234 uint32_t x;
235 float re;
236 };
237 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
238 union {
239 uint32_t y;
240 float im;
241 };
242
243 x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
244 src += sizeof(float);
245 y = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
246 src += sizeof(float);
247
248 dst[i] = c10::complex<float>(re, im);
249 }
250 }
251
THP_decodeComplexDoubleBuffer(c10::complex<double> * dst,const uint8_t * src,bool do_byte_swap,size_t len)252 void THP_decodeComplexDoubleBuffer(
253 c10::complex<double>* dst,
254 const uint8_t* src,
255 bool do_byte_swap,
256 size_t len) {
257 for (const auto i : c10::irange(len)) {
258 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
259 union {
260 uint64_t x;
261 double re;
262 };
263 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
264 union {
265 uint64_t y;
266 double im;
267 };
268 static_assert(sizeof(uint64_t) == sizeof(double));
269
270 x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
271 src += sizeof(double);
272 y = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
273 src += sizeof(double);
274
275 dst[i] = c10::complex<double>(re, im);
276 }
277 }
278
THP_decodeInt16Buffer(int16_t * dst,const uint8_t * src,THPByteOrder order,size_t len)279 void THP_decodeInt16Buffer(
280 int16_t* dst,
281 const uint8_t* src,
282 THPByteOrder order,
283 size_t len) {
284 THP_decodeInt16Buffer(dst, src, (order != THP_nativeByteOrder()), len);
285 }
286
THP_decodeInt32Buffer(int32_t * dst,const uint8_t * src,THPByteOrder order,size_t len)287 void THP_decodeInt32Buffer(
288 int32_t* dst,
289 const uint8_t* src,
290 THPByteOrder order,
291 size_t len) {
292 THP_decodeInt32Buffer(dst, src, (order != THP_nativeByteOrder()), len);
293 }
294
THP_decodeInt64Buffer(int64_t * dst,const uint8_t * src,THPByteOrder order,size_t len)295 void THP_decodeInt64Buffer(
296 int64_t* dst,
297 const uint8_t* src,
298 THPByteOrder order,
299 size_t len) {
300 THP_decodeInt64Buffer(dst, src, (order != THP_nativeByteOrder()), len);
301 }
302
THP_decodeHalfBuffer(c10::Half * dst,const uint8_t * src,THPByteOrder order,size_t len)303 void THP_decodeHalfBuffer(
304 c10::Half* dst,
305 const uint8_t* src,
306 THPByteOrder order,
307 size_t len) {
308 THP_decodeHalfBuffer(dst, src, (order != THP_nativeByteOrder()), len);
309 }
310
THP_decodeBFloat16Buffer(at::BFloat16 * dst,const uint8_t * src,THPByteOrder order,size_t len)311 void THP_decodeBFloat16Buffer(
312 at::BFloat16* dst,
313 const uint8_t* src,
314 THPByteOrder order,
315 size_t len) {
316 THP_decodeBFloat16Buffer(dst, src, (order != THP_nativeByteOrder()), len);
317 }
318
THP_decodeFloatBuffer(float * dst,const uint8_t * src,THPByteOrder order,size_t len)319 void THP_decodeFloatBuffer(
320 float* dst,
321 const uint8_t* src,
322 THPByteOrder order,
323 size_t len) {
324 THP_decodeFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len);
325 }
326
THP_decodeDoubleBuffer(double * dst,const uint8_t * src,THPByteOrder order,size_t len)327 void THP_decodeDoubleBuffer(
328 double* dst,
329 const uint8_t* src,
330 THPByteOrder order,
331 size_t len) {
332 THP_decodeDoubleBuffer(dst, src, (order != THP_nativeByteOrder()), len);
333 }
334
THP_decodeComplexFloatBuffer(c10::complex<float> * dst,const uint8_t * src,THPByteOrder order,size_t len)335 void THP_decodeComplexFloatBuffer(
336 c10::complex<float>* dst,
337 const uint8_t* src,
338 THPByteOrder order,
339 size_t len) {
340 THP_decodeComplexFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len);
341 }
342
THP_decodeComplexDoubleBuffer(c10::complex<double> * dst,const uint8_t * src,THPByteOrder order,size_t len)343 void THP_decodeComplexDoubleBuffer(
344 c10::complex<double>* dst,
345 const uint8_t* src,
346 THPByteOrder order,
347 size_t len) {
348 THP_decodeComplexDoubleBuffer(
349 dst, src, (order != THP_nativeByteOrder()), len);
350 }
351
THP_encodeInt16Buffer(uint8_t * dst,const int16_t * src,THPByteOrder order,size_t len)352 void THP_encodeInt16Buffer(
353 uint8_t* dst,
354 const int16_t* src,
355 THPByteOrder order,
356 size_t len) {
357 memcpy(dst, src, sizeof(int16_t) * len);
358 if (order != THP_nativeByteOrder()) {
359 for (const auto i : c10::irange(len)) {
360 (void)i;
361 swapBytes16(dst);
362 dst += sizeof(int16_t);
363 }
364 }
365 }
366
THP_encodeInt32Buffer(uint8_t * dst,const int32_t * src,THPByteOrder order,size_t len)367 void THP_encodeInt32Buffer(
368 uint8_t* dst,
369 const int32_t* src,
370 THPByteOrder order,
371 size_t len) {
372 memcpy(dst, src, sizeof(int32_t) * len);
373 if (order != THP_nativeByteOrder()) {
374 for (const auto i : c10::irange(len)) {
375 (void)i;
376 swapBytes32(dst);
377 dst += sizeof(int32_t);
378 }
379 }
380 }
381
THP_encodeInt64Buffer(uint8_t * dst,const int64_t * src,THPByteOrder order,size_t len)382 void THP_encodeInt64Buffer(
383 uint8_t* dst,
384 const int64_t* src,
385 THPByteOrder order,
386 size_t len) {
387 memcpy(dst, src, sizeof(int64_t) * len);
388 if (order != THP_nativeByteOrder()) {
389 for (const auto i : c10::irange(len)) {
390 (void)i;
391 swapBytes64(dst);
392 dst += sizeof(int64_t);
393 }
394 }
395 }
396
THP_encodeFloatBuffer(uint8_t * dst,const float * src,THPByteOrder order,size_t len)397 void THP_encodeFloatBuffer(
398 uint8_t* dst,
399 const float* src,
400 THPByteOrder order,
401 size_t len) {
402 memcpy(dst, src, sizeof(float) * len);
403 if (order != THP_nativeByteOrder()) {
404 for (const auto i : c10::irange(len)) {
405 (void)i;
406 swapBytes32(dst);
407 dst += sizeof(float);
408 }
409 }
410 }
411
THP_encodeDoubleBuffer(uint8_t * dst,const double * src,THPByteOrder order,size_t len)412 void THP_encodeDoubleBuffer(
413 uint8_t* dst,
414 const double* src,
415 THPByteOrder order,
416 size_t len) {
417 memcpy(dst, src, sizeof(double) * len);
418 if (order != THP_nativeByteOrder()) {
419 for (const auto i : c10::irange(len)) {
420 (void)i;
421 swapBytes64(dst);
422 dst += sizeof(double);
423 }
424 }
425 }
426
427 template <typename T>
complex_to_float(const c10::complex<T> * src,size_t len)428 std::vector<T> complex_to_float(const c10::complex<T>* src, size_t len) {
429 std::vector<T> new_src;
430 new_src.reserve(2 * len);
431 for (const auto i : c10::irange(len)) {
432 auto elem = src[i];
433 new_src.emplace_back(elem.real());
434 new_src.emplace_back(elem.imag());
435 }
436 return new_src;
437 }
438
THP_encodeComplexFloatBuffer(uint8_t * dst,const c10::complex<float> * src,THPByteOrder order,size_t len)439 void THP_encodeComplexFloatBuffer(
440 uint8_t* dst,
441 const c10::complex<float>* src,
442 THPByteOrder order,
443 size_t len) {
444 auto new_src = complex_to_float(src, len);
445 memcpy(dst, static_cast<void*>(&new_src), 2 * sizeof(float) * len);
446 if (order != THP_nativeByteOrder()) {
447 for (const auto i : c10::irange(2 * len)) {
448 (void)i; // Suppress unused variable warning
449 swapBytes32(dst);
450 dst += sizeof(float);
451 }
452 }
453 }
454
THP_encodeComplexDoubleBuffer(uint8_t * dst,const c10::complex<double> * src,THPByteOrder order,size_t len)455 void THP_encodeComplexDoubleBuffer(
456 uint8_t* dst,
457 const c10::complex<double>* src,
458 THPByteOrder order,
459 size_t len) {
460 auto new_src = complex_to_float(src, len);
461 memcpy(dst, static_cast<void*>(&new_src), 2 * sizeof(double) * len);
462 if (order != THP_nativeByteOrder()) {
463 for (const auto i : c10::irange(2 * len)) {
464 (void)i; // Suppress unused variable warning
465 swapBytes64(dst);
466 dst += sizeof(double);
467 }
468 }
469 }
470
471 } // namespace torch::utils
472