xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/byte_order.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/BFloat16.h>
2 #include <c10/util/irange.h>
3 #include <torch/csrc/utils/byte_order.h>
4 
5 #include <cstring>
6 #include <vector>
7 
8 #if defined(_MSC_VER)
9 #include <stdlib.h>
10 #endif
11 
namespace {

// ---------------------------------------------------------------------------
// In-place byte reversal.
//
// Each swapBytesN reads an N-bit word from `ptr` with memcpy (so `ptr` does
// not need to be aligned for the word type), reverses the order of its bytes,
// and stores the result back to the same address. A compiler intrinsic is
// used when the toolchain offers one; otherwise a portable shift/mask
// sequence produces the same result.
// ---------------------------------------------------------------------------

inline void swapBytes16(void* ptr) {
  uint16_t word{};
  std::memcpy(&word, ptr, sizeof(word));
#if defined(_MSC_VER) && !defined(_DEBUG)
  word = _byteswap_ushort(word);
#elif defined(__llvm__) || (defined(__GNUC__) && !defined(__ICC))
  word = __builtin_bswap16(word);
#else
  // High byte moves down, low byte moves up; the cast drops the bits that
  // integer promotion pushed above bit 15.
  word = static_cast<uint16_t>((word >> 8) | (word << 8));
#endif
  std::memcpy(ptr, &word, sizeof(word));
}

inline void swapBytes32(void* ptr) {
  uint32_t word{};
  std::memcpy(&word, ptr, sizeof(word));
#if defined(_MSC_VER) && !defined(_DEBUG)
  word = _byteswap_ulong(word);
#elif defined(__llvm__) || (defined(__GNUC__) && !defined(__ICC))
  word = __builtin_bswap32(word);
#else
  word = ((word & 0x000000FFu) << 24) | ((word & 0x0000FF00u) << 8) |
      ((word >> 8) & 0x0000FF00u) | (word >> 24);
#endif
  std::memcpy(ptr, &word, sizeof(word));
}

inline void swapBytes64(void* ptr) {
  uint64_t word{};
  std::memcpy(&word, ptr, sizeof(word));
#if defined(_MSC_VER)
  word = _byteswap_uint64(word);
#elif defined(__llvm__) || (defined(__GNUC__) && !defined(__ICC))
  word = __builtin_bswap64(word);
#else
  // Peel bytes off the low end of `word` and push them onto the low end of
  // `reversed`; after eight rounds the first byte peeled sits at the top.
  uint64_t reversed = 0;
  for (int round = 0; round < 8; ++round) {
    reversed = (reversed << 8) | (word & 0xFFu);
    word >>= 8;
  }
  word = reversed;
#endif
  std::memcpy(ptr, &word, sizeof(word));
}

// ---------------------------------------------------------------------------
// Unaligned loads.
//
// decodeUIntN returns the N-bit word stored at `data` in host byte order;
// decodeUIntNByteSwapped returns the same word with its bytes reversed.
// ---------------------------------------------------------------------------

// Copies sizeof(T) bytes starting at `data` into a T (no alignment needed).
template <typename T>
inline T loadNative(const uint8_t* data) {
  T value{};
  std::memcpy(&value, data, sizeof(T));
  return value;
}

inline uint16_t decodeUInt16(const uint8_t* data) {
  return loadNative<uint16_t>(data);
}

inline uint16_t decodeUInt16ByteSwapped(const uint8_t* data) {
  auto value = loadNative<uint16_t>(data);
  swapBytes16(&value);
  return value;
}

inline uint32_t decodeUInt32(const uint8_t* data) {
  return loadNative<uint32_t>(data);
}

inline uint32_t decodeUInt32ByteSwapped(const uint8_t* data) {
  auto value = loadNative<uint32_t>(data);
  swapBytes32(&value);
  return value;
}

inline uint64_t decodeUInt64(const uint8_t* data) {
  return loadNative<uint64_t>(data);
}

inline uint64_t decodeUInt64ByteSwapped(const uint8_t* data) {
  auto value = loadNative<uint64_t>(data);
  swapBytes64(&value);
  return value;
}

} // anonymous namespace
112 
113 namespace torch::utils {
114 
THP_nativeByteOrder()115 THPByteOrder THP_nativeByteOrder() {
116   uint32_t x = 1;
117   return *(uint8_t*)&x ? THP_LITTLE_ENDIAN : THP_BIG_ENDIAN;
118 }
119 
THP_decodeInt16Buffer(int16_t * dst,const uint8_t * src,bool do_byte_swap,size_t len)120 void THP_decodeInt16Buffer(
121     int16_t* dst,
122     const uint8_t* src,
123     bool do_byte_swap,
124     size_t len) {
125   for (const auto i : c10::irange(len)) {
126     dst[i] = (int16_t)(do_byte_swap ? decodeUInt16ByteSwapped(src)
127                                     : decodeUInt16(src));
128     src += sizeof(int16_t);
129   }
130 }
131 
THP_decodeInt32Buffer(int32_t * dst,const uint8_t * src,bool do_byte_swap,size_t len)132 void THP_decodeInt32Buffer(
133     int32_t* dst,
134     const uint8_t* src,
135     bool do_byte_swap,
136     size_t len) {
137   for (const auto i : c10::irange(len)) {
138     dst[i] = (int32_t)(do_byte_swap ? decodeUInt32ByteSwapped(src)
139                                     : decodeUInt32(src));
140     src += sizeof(int32_t);
141   }
142 }
143 
THP_decodeInt64Buffer(int64_t * dst,const uint8_t * src,bool do_byte_swap,size_t len)144 void THP_decodeInt64Buffer(
145     int64_t* dst,
146     const uint8_t* src,
147     bool do_byte_swap,
148     size_t len) {
149   for (const auto i : c10::irange(len)) {
150     dst[i] = (int64_t)(do_byte_swap ? decodeUInt64ByteSwapped(src)
151                                     : decodeUInt64(src));
152     src += sizeof(int64_t);
153   }
154 }
155 
THP_decodeHalfBuffer(c10::Half * dst,const uint8_t * src,bool do_byte_swap,size_t len)156 void THP_decodeHalfBuffer(
157     c10::Half* dst,
158     const uint8_t* src,
159     bool do_byte_swap,
160     size_t len) {
161   for (const auto i : c10::irange(len)) {
162     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
163     union {
164       uint16_t x;
165       c10::Half f;
166     };
167     x = (do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src));
168     dst[i] = f;
169     src += sizeof(uint16_t);
170   }
171 }
172 
THP_decodeBFloat16Buffer(at::BFloat16 * dst,const uint8_t * src,bool do_byte_swap,size_t len)173 void THP_decodeBFloat16Buffer(
174     at::BFloat16* dst,
175     const uint8_t* src,
176     bool do_byte_swap,
177     size_t len) {
178   for (const auto i : c10::irange(len)) {
179     uint16_t x =
180         (do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src));
181     std::memcpy(&dst[i], &x, sizeof(dst[i]));
182     src += sizeof(uint16_t);
183   }
184 }
185 
THP_decodeBoolBuffer(bool * dst,const uint8_t * src,size_t len)186 void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, size_t len) {
187   for (const auto i : c10::irange(len)) {
188     dst[i] = (int)src[i] != 0 ? true : false;
189   }
190 }
191 
THP_decodeFloatBuffer(float * dst,const uint8_t * src,bool do_byte_swap,size_t len)192 void THP_decodeFloatBuffer(
193     float* dst,
194     const uint8_t* src,
195     bool do_byte_swap,
196     size_t len) {
197   for (const auto i : c10::irange(len)) {
198     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
199     union {
200       uint32_t x;
201       float f;
202     };
203     x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
204     dst[i] = f;
205     src += sizeof(float);
206   }
207 }
208 
THP_decodeDoubleBuffer(double * dst,const uint8_t * src,bool do_byte_swap,size_t len)209 void THP_decodeDoubleBuffer(
210     double* dst,
211     const uint8_t* src,
212     bool do_byte_swap,
213     size_t len) {
214   for (const auto i : c10::irange(len)) {
215     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
216     union {
217       uint64_t x;
218       double d;
219     };
220     x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
221     dst[i] = d;
222     src += sizeof(double);
223   }
224 }
225 
THP_decodeComplexFloatBuffer(c10::complex<float> * dst,const uint8_t * src,bool do_byte_swap,size_t len)226 void THP_decodeComplexFloatBuffer(
227     c10::complex<float>* dst,
228     const uint8_t* src,
229     bool do_byte_swap,
230     size_t len) {
231   for (const auto i : c10::irange(len)) {
232     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
233     union {
234       uint32_t x;
235       float re;
236     };
237     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
238     union {
239       uint32_t y;
240       float im;
241     };
242 
243     x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
244     src += sizeof(float);
245     y = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
246     src += sizeof(float);
247 
248     dst[i] = c10::complex<float>(re, im);
249   }
250 }
251 
THP_decodeComplexDoubleBuffer(c10::complex<double> * dst,const uint8_t * src,bool do_byte_swap,size_t len)252 void THP_decodeComplexDoubleBuffer(
253     c10::complex<double>* dst,
254     const uint8_t* src,
255     bool do_byte_swap,
256     size_t len) {
257   for (const auto i : c10::irange(len)) {
258     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
259     union {
260       uint64_t x;
261       double re;
262     };
263     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
264     union {
265       uint64_t y;
266       double im;
267     };
268     static_assert(sizeof(uint64_t) == sizeof(double));
269 
270     x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
271     src += sizeof(double);
272     y = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
273     src += sizeof(double);
274 
275     dst[i] = c10::complex<double>(re, im);
276   }
277 }
278 
THP_decodeInt16Buffer(int16_t * dst,const uint8_t * src,THPByteOrder order,size_t len)279 void THP_decodeInt16Buffer(
280     int16_t* dst,
281     const uint8_t* src,
282     THPByteOrder order,
283     size_t len) {
284   THP_decodeInt16Buffer(dst, src, (order != THP_nativeByteOrder()), len);
285 }
286 
THP_decodeInt32Buffer(int32_t * dst,const uint8_t * src,THPByteOrder order,size_t len)287 void THP_decodeInt32Buffer(
288     int32_t* dst,
289     const uint8_t* src,
290     THPByteOrder order,
291     size_t len) {
292   THP_decodeInt32Buffer(dst, src, (order != THP_nativeByteOrder()), len);
293 }
294 
THP_decodeInt64Buffer(int64_t * dst,const uint8_t * src,THPByteOrder order,size_t len)295 void THP_decodeInt64Buffer(
296     int64_t* dst,
297     const uint8_t* src,
298     THPByteOrder order,
299     size_t len) {
300   THP_decodeInt64Buffer(dst, src, (order != THP_nativeByteOrder()), len);
301 }
302 
THP_decodeHalfBuffer(c10::Half * dst,const uint8_t * src,THPByteOrder order,size_t len)303 void THP_decodeHalfBuffer(
304     c10::Half* dst,
305     const uint8_t* src,
306     THPByteOrder order,
307     size_t len) {
308   THP_decodeHalfBuffer(dst, src, (order != THP_nativeByteOrder()), len);
309 }
310 
THP_decodeBFloat16Buffer(at::BFloat16 * dst,const uint8_t * src,THPByteOrder order,size_t len)311 void THP_decodeBFloat16Buffer(
312     at::BFloat16* dst,
313     const uint8_t* src,
314     THPByteOrder order,
315     size_t len) {
316   THP_decodeBFloat16Buffer(dst, src, (order != THP_nativeByteOrder()), len);
317 }
318 
THP_decodeFloatBuffer(float * dst,const uint8_t * src,THPByteOrder order,size_t len)319 void THP_decodeFloatBuffer(
320     float* dst,
321     const uint8_t* src,
322     THPByteOrder order,
323     size_t len) {
324   THP_decodeFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len);
325 }
326 
THP_decodeDoubleBuffer(double * dst,const uint8_t * src,THPByteOrder order,size_t len)327 void THP_decodeDoubleBuffer(
328     double* dst,
329     const uint8_t* src,
330     THPByteOrder order,
331     size_t len) {
332   THP_decodeDoubleBuffer(dst, src, (order != THP_nativeByteOrder()), len);
333 }
334 
THP_decodeComplexFloatBuffer(c10::complex<float> * dst,const uint8_t * src,THPByteOrder order,size_t len)335 void THP_decodeComplexFloatBuffer(
336     c10::complex<float>* dst,
337     const uint8_t* src,
338     THPByteOrder order,
339     size_t len) {
340   THP_decodeComplexFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len);
341 }
342 
THP_decodeComplexDoubleBuffer(c10::complex<double> * dst,const uint8_t * src,THPByteOrder order,size_t len)343 void THP_decodeComplexDoubleBuffer(
344     c10::complex<double>* dst,
345     const uint8_t* src,
346     THPByteOrder order,
347     size_t len) {
348   THP_decodeComplexDoubleBuffer(
349       dst, src, (order != THP_nativeByteOrder()), len);
350 }
351 
THP_encodeInt16Buffer(uint8_t * dst,const int16_t * src,THPByteOrder order,size_t len)352 void THP_encodeInt16Buffer(
353     uint8_t* dst,
354     const int16_t* src,
355     THPByteOrder order,
356     size_t len) {
357   memcpy(dst, src, sizeof(int16_t) * len);
358   if (order != THP_nativeByteOrder()) {
359     for (const auto i : c10::irange(len)) {
360       (void)i;
361       swapBytes16(dst);
362       dst += sizeof(int16_t);
363     }
364   }
365 }
366 
THP_encodeInt32Buffer(uint8_t * dst,const int32_t * src,THPByteOrder order,size_t len)367 void THP_encodeInt32Buffer(
368     uint8_t* dst,
369     const int32_t* src,
370     THPByteOrder order,
371     size_t len) {
372   memcpy(dst, src, sizeof(int32_t) * len);
373   if (order != THP_nativeByteOrder()) {
374     for (const auto i : c10::irange(len)) {
375       (void)i;
376       swapBytes32(dst);
377       dst += sizeof(int32_t);
378     }
379   }
380 }
381 
THP_encodeInt64Buffer(uint8_t * dst,const int64_t * src,THPByteOrder order,size_t len)382 void THP_encodeInt64Buffer(
383     uint8_t* dst,
384     const int64_t* src,
385     THPByteOrder order,
386     size_t len) {
387   memcpy(dst, src, sizeof(int64_t) * len);
388   if (order != THP_nativeByteOrder()) {
389     for (const auto i : c10::irange(len)) {
390       (void)i;
391       swapBytes64(dst);
392       dst += sizeof(int64_t);
393     }
394   }
395 }
396 
THP_encodeFloatBuffer(uint8_t * dst,const float * src,THPByteOrder order,size_t len)397 void THP_encodeFloatBuffer(
398     uint8_t* dst,
399     const float* src,
400     THPByteOrder order,
401     size_t len) {
402   memcpy(dst, src, sizeof(float) * len);
403   if (order != THP_nativeByteOrder()) {
404     for (const auto i : c10::irange(len)) {
405       (void)i;
406       swapBytes32(dst);
407       dst += sizeof(float);
408     }
409   }
410 }
411 
THP_encodeDoubleBuffer(uint8_t * dst,const double * src,THPByteOrder order,size_t len)412 void THP_encodeDoubleBuffer(
413     uint8_t* dst,
414     const double* src,
415     THPByteOrder order,
416     size_t len) {
417   memcpy(dst, src, sizeof(double) * len);
418   if (order != THP_nativeByteOrder()) {
419     for (const auto i : c10::irange(len)) {
420       (void)i;
421       swapBytes64(dst);
422       dst += sizeof(double);
423     }
424   }
425 }
426 
427 template <typename T>
complex_to_float(const c10::complex<T> * src,size_t len)428 std::vector<T> complex_to_float(const c10::complex<T>* src, size_t len) {
429   std::vector<T> new_src;
430   new_src.reserve(2 * len);
431   for (const auto i : c10::irange(len)) {
432     auto elem = src[i];
433     new_src.emplace_back(elem.real());
434     new_src.emplace_back(elem.imag());
435   }
436   return new_src;
437 }
438 
THP_encodeComplexFloatBuffer(uint8_t * dst,const c10::complex<float> * src,THPByteOrder order,size_t len)439 void THP_encodeComplexFloatBuffer(
440     uint8_t* dst,
441     const c10::complex<float>* src,
442     THPByteOrder order,
443     size_t len) {
444   auto new_src = complex_to_float(src, len);
445   memcpy(dst, static_cast<void*>(&new_src), 2 * sizeof(float) * len);
446   if (order != THP_nativeByteOrder()) {
447     for (const auto i : c10::irange(2 * len)) {
448       (void)i; // Suppress unused variable warning
449       swapBytes32(dst);
450       dst += sizeof(float);
451     }
452   }
453 }
454 
THP_encodeComplexDoubleBuffer(uint8_t * dst,const c10::complex<double> * src,THPByteOrder order,size_t len)455 void THP_encodeComplexDoubleBuffer(
456     uint8_t* dst,
457     const c10::complex<double>* src,
458     THPByteOrder order,
459     size_t len) {
460   auto new_src = complex_to_float(src, len);
461   memcpy(dst, static_cast<void*>(&new_src), 2 * sizeof(double) * len);
462   if (order != THP_nativeByteOrder()) {
463     for (const auto i : c10::irange(2 * len)) {
464       (void)i; // Suppress unused variable warning
465       swapBytes64(dst);
466       dst += sizeof(double);
467     }
468   }
469 }
470 
471 } // namespace torch::utils
472