1 #pragma once
2
3 #include <ATen/ATen.h>
4 #include <c10/util/TypeCast.h>
5 #include <torch/csrc/python_headers.h>
6
7 #include <torch/csrc/Exceptions.h>
8 #include <torch/csrc/utils/python_numbers.h>
9
10 namespace torch::utils {
11
12 template <typename T>
unpackIntegral(PyObject * obj,const char * type)13 inline T unpackIntegral(PyObject* obj, const char* type) {
14 #if PY_VERSION_HEX >= 0x030a00f0
15 // In Python-3.10 floats can no longer be silently converted to integers
16 // Keep backward compatible behavior for now
17 if (PyFloat_Check(obj)) {
18 return c10::checked_convert<T>(THPUtils_unpackDouble(obj), type);
19 }
20 return c10::checked_convert<T>(THPUtils_unpackLong(obj), type);
21 #else
22 return static_cast<T>(THPUtils_unpackLong(obj));
23 #endif
24 }
25
store_scalar(void * data,at::ScalarType scalarType,PyObject * obj)26 inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
27 switch (scalarType) {
28 case at::kByte:
29 *(uint8_t*)data = unpackIntegral<uint8_t>(obj, "uint8");
30 break;
31 case at::kUInt16:
32 *(uint16_t*)data = unpackIntegral<uint16_t>(obj, "uint16");
33 break;
34 case at::kUInt32:
35 *(uint32_t*)data = unpackIntegral<uint32_t>(obj, "uint32");
36 break;
37 case at::kUInt64:
38 // NB: This doesn't allow implicit conversion of float to int
39 *(uint64_t*)data = THPUtils_unpackUInt64(obj);
40 break;
41 case at::kChar:
42 *(int8_t*)data = unpackIntegral<int8_t>(obj, "int8");
43 break;
44 case at::kShort:
45 *(int16_t*)data = unpackIntegral<int16_t>(obj, "int16");
46 break;
47 case at::kInt:
48 *(int32_t*)data = unpackIntegral<int32_t>(obj, "int32");
49 break;
50 case at::kLong:
51 *(int64_t*)data = unpackIntegral<int64_t>(obj, "int64");
52 break;
53 case at::kHalf:
54 *(at::Half*)data =
55 at::convert<at::Half, double>(THPUtils_unpackDouble(obj));
56 break;
57 case at::kFloat:
58 *(float*)data = (float)THPUtils_unpackDouble(obj);
59 break;
60 case at::kDouble:
61 *(double*)data = THPUtils_unpackDouble(obj);
62 break;
63 case at::kComplexHalf:
64 *(c10::complex<at::Half>*)data =
65 (c10::complex<at::Half>)static_cast<c10::complex<float>>(
66 THPUtils_unpackComplexDouble(obj));
67 break;
68 case at::kComplexFloat:
69 *(c10::complex<float>*)data =
70 (c10::complex<float>)THPUtils_unpackComplexDouble(obj);
71 break;
72 case at::kComplexDouble:
73 *(c10::complex<double>*)data = THPUtils_unpackComplexDouble(obj);
74 break;
75 case at::kBool:
76 *(bool*)data = THPUtils_unpackNumberAsBool(obj);
77 break;
78 case at::kBFloat16:
79 *(at::BFloat16*)data =
80 at::convert<at::BFloat16, double>(THPUtils_unpackDouble(obj));
81 break;
82 case at::kFloat8_e5m2:
83 *(at::Float8_e5m2*)data =
84 at::convert<at::Float8_e5m2, double>(THPUtils_unpackDouble(obj));
85 break;
86 case at::kFloat8_e5m2fnuz:
87 *(at::Float8_e5m2fnuz*)data =
88 at::convert<at::Float8_e5m2fnuz, double>(THPUtils_unpackDouble(obj));
89 break;
90 case at::kFloat8_e4m3fn:
91 *(at::Float8_e4m3fn*)data =
92 at::convert<at::Float8_e4m3fn, double>(THPUtils_unpackDouble(obj));
93 break;
94 case at::kFloat8_e4m3fnuz:
95 *(at::Float8_e4m3fnuz*)data =
96 at::convert<at::Float8_e4m3fnuz, double>(THPUtils_unpackDouble(obj));
97 break;
98 default:
99 throw std::runtime_error("invalid type");
100 }
101 }
102
load_scalar(const void * data,at::ScalarType scalarType)103 inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) {
104 switch (scalarType) {
105 case at::kByte:
106 return THPUtils_packInt64(*(uint8_t*)data);
107 case at::kUInt16:
108 return THPUtils_packInt64(*(uint16_t*)data);
109 case at::kUInt32:
110 return THPUtils_packUInt32(*(uint32_t*)data);
111 case at::kUInt64:
112 return THPUtils_packUInt64(*(uint64_t*)data);
113 case at::kChar:
114 return THPUtils_packInt64(*(int8_t*)data);
115 case at::kShort:
116 return THPUtils_packInt64(*(int16_t*)data);
117 case at::kInt:
118 return THPUtils_packInt64(*(int32_t*)data);
119 case at::kLong:
120 return THPUtils_packInt64(*(int64_t*)data);
121 case at::kHalf:
122 return PyFloat_FromDouble(
123 at::convert<double, at::Half>(*(at::Half*)data));
124 case at::kFloat:
125 return PyFloat_FromDouble(*(float*)data);
126 case at::kDouble:
127 return PyFloat_FromDouble(*(double*)data);
128 case at::kComplexHalf: {
129 auto data_ = reinterpret_cast<const c10::complex<at::Half>*>(data);
130 return PyComplex_FromDoubles(data_->real(), data_->imag());
131 }
132 case at::kComplexFloat: {
133 auto data_ = reinterpret_cast<const c10::complex<float>*>(data);
134 return PyComplex_FromDoubles(data_->real(), data_->imag());
135 }
136 case at::kComplexDouble:
137 return PyComplex_FromCComplex(
138 *reinterpret_cast<Py_complex*>((c10::complex<double>*)data));
139 case at::kBool:
140 return PyBool_FromLong(*(bool*)data);
141 case at::kBFloat16:
142 return PyFloat_FromDouble(
143 at::convert<double, at::BFloat16>(*(at::BFloat16*)data));
144 case at::kFloat8_e5m2:
145 return PyFloat_FromDouble(
146 at::convert<double, at::Float8_e5m2>(*(at::Float8_e5m2*)data));
147 case at::kFloat8_e4m3fn:
148 return PyFloat_FromDouble(
149 at::convert<double, at::Float8_e4m3fn>(*(at::Float8_e4m3fn*)data));
150 case at::kFloat8_e5m2fnuz:
151 return PyFloat_FromDouble(at::convert<double, at::Float8_e5m2fnuz>(
152 *(at::Float8_e5m2fnuz*)data));
153 case at::kFloat8_e4m3fnuz:
154 return PyFloat_FromDouble(at::convert<double, at::Float8_e4m3fnuz>(
155 *(at::Float8_e4m3fnuz*)data));
156 default:
157 throw std::runtime_error("invalid type");
158 }
159 }
160
161 } // namespace torch::utils
162