xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/python_scalars.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <c10/util/TypeCast.h>
5 #include <torch/csrc/python_headers.h>
6 
7 #include <torch/csrc/Exceptions.h>
8 #include <torch/csrc/utils/python_numbers.h>
9 
10 namespace torch::utils {
11 
12 template <typename T>
unpackIntegral(PyObject * obj,const char * type)13 inline T unpackIntegral(PyObject* obj, const char* type) {
14 #if PY_VERSION_HEX >= 0x030a00f0
15   // In Python-3.10 floats can no longer be silently converted to integers
16   // Keep backward compatible behavior for now
17   if (PyFloat_Check(obj)) {
18     return c10::checked_convert<T>(THPUtils_unpackDouble(obj), type);
19   }
20   return c10::checked_convert<T>(THPUtils_unpackLong(obj), type);
21 #else
22   return static_cast<T>(THPUtils_unpackLong(obj));
23 #endif
24 }
25 
store_scalar(void * data,at::ScalarType scalarType,PyObject * obj)26 inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
27   switch (scalarType) {
28     case at::kByte:
29       *(uint8_t*)data = unpackIntegral<uint8_t>(obj, "uint8");
30       break;
31     case at::kUInt16:
32       *(uint16_t*)data = unpackIntegral<uint16_t>(obj, "uint16");
33       break;
34     case at::kUInt32:
35       *(uint32_t*)data = unpackIntegral<uint32_t>(obj, "uint32");
36       break;
37     case at::kUInt64:
38       // NB: This doesn't allow implicit conversion of float to int
39       *(uint64_t*)data = THPUtils_unpackUInt64(obj);
40       break;
41     case at::kChar:
42       *(int8_t*)data = unpackIntegral<int8_t>(obj, "int8");
43       break;
44     case at::kShort:
45       *(int16_t*)data = unpackIntegral<int16_t>(obj, "int16");
46       break;
47     case at::kInt:
48       *(int32_t*)data = unpackIntegral<int32_t>(obj, "int32");
49       break;
50     case at::kLong:
51       *(int64_t*)data = unpackIntegral<int64_t>(obj, "int64");
52       break;
53     case at::kHalf:
54       *(at::Half*)data =
55           at::convert<at::Half, double>(THPUtils_unpackDouble(obj));
56       break;
57     case at::kFloat:
58       *(float*)data = (float)THPUtils_unpackDouble(obj);
59       break;
60     case at::kDouble:
61       *(double*)data = THPUtils_unpackDouble(obj);
62       break;
63     case at::kComplexHalf:
64       *(c10::complex<at::Half>*)data =
65           (c10::complex<at::Half>)static_cast<c10::complex<float>>(
66               THPUtils_unpackComplexDouble(obj));
67       break;
68     case at::kComplexFloat:
69       *(c10::complex<float>*)data =
70           (c10::complex<float>)THPUtils_unpackComplexDouble(obj);
71       break;
72     case at::kComplexDouble:
73       *(c10::complex<double>*)data = THPUtils_unpackComplexDouble(obj);
74       break;
75     case at::kBool:
76       *(bool*)data = THPUtils_unpackNumberAsBool(obj);
77       break;
78     case at::kBFloat16:
79       *(at::BFloat16*)data =
80           at::convert<at::BFloat16, double>(THPUtils_unpackDouble(obj));
81       break;
82     case at::kFloat8_e5m2:
83       *(at::Float8_e5m2*)data =
84           at::convert<at::Float8_e5m2, double>(THPUtils_unpackDouble(obj));
85       break;
86     case at::kFloat8_e5m2fnuz:
87       *(at::Float8_e5m2fnuz*)data =
88           at::convert<at::Float8_e5m2fnuz, double>(THPUtils_unpackDouble(obj));
89       break;
90     case at::kFloat8_e4m3fn:
91       *(at::Float8_e4m3fn*)data =
92           at::convert<at::Float8_e4m3fn, double>(THPUtils_unpackDouble(obj));
93       break;
94     case at::kFloat8_e4m3fnuz:
95       *(at::Float8_e4m3fnuz*)data =
96           at::convert<at::Float8_e4m3fnuz, double>(THPUtils_unpackDouble(obj));
97       break;
98     default:
99       throw std::runtime_error("invalid type");
100   }
101 }
102 
load_scalar(const void * data,at::ScalarType scalarType)103 inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) {
104   switch (scalarType) {
105     case at::kByte:
106       return THPUtils_packInt64(*(uint8_t*)data);
107     case at::kUInt16:
108       return THPUtils_packInt64(*(uint16_t*)data);
109     case at::kUInt32:
110       return THPUtils_packUInt32(*(uint32_t*)data);
111     case at::kUInt64:
112       return THPUtils_packUInt64(*(uint64_t*)data);
113     case at::kChar:
114       return THPUtils_packInt64(*(int8_t*)data);
115     case at::kShort:
116       return THPUtils_packInt64(*(int16_t*)data);
117     case at::kInt:
118       return THPUtils_packInt64(*(int32_t*)data);
119     case at::kLong:
120       return THPUtils_packInt64(*(int64_t*)data);
121     case at::kHalf:
122       return PyFloat_FromDouble(
123           at::convert<double, at::Half>(*(at::Half*)data));
124     case at::kFloat:
125       return PyFloat_FromDouble(*(float*)data);
126     case at::kDouble:
127       return PyFloat_FromDouble(*(double*)data);
128     case at::kComplexHalf: {
129       auto data_ = reinterpret_cast<const c10::complex<at::Half>*>(data);
130       return PyComplex_FromDoubles(data_->real(), data_->imag());
131     }
132     case at::kComplexFloat: {
133       auto data_ = reinterpret_cast<const c10::complex<float>*>(data);
134       return PyComplex_FromDoubles(data_->real(), data_->imag());
135     }
136     case at::kComplexDouble:
137       return PyComplex_FromCComplex(
138           *reinterpret_cast<Py_complex*>((c10::complex<double>*)data));
139     case at::kBool:
140       return PyBool_FromLong(*(bool*)data);
141     case at::kBFloat16:
142       return PyFloat_FromDouble(
143           at::convert<double, at::BFloat16>(*(at::BFloat16*)data));
144     case at::kFloat8_e5m2:
145       return PyFloat_FromDouble(
146           at::convert<double, at::Float8_e5m2>(*(at::Float8_e5m2*)data));
147     case at::kFloat8_e4m3fn:
148       return PyFloat_FromDouble(
149           at::convert<double, at::Float8_e4m3fn>(*(at::Float8_e4m3fn*)data));
150     case at::kFloat8_e5m2fnuz:
151       return PyFloat_FromDouble(at::convert<double, at::Float8_e5m2fnuz>(
152           *(at::Float8_e5m2fnuz*)data));
153     case at::kFloat8_e4m3fnuz:
154       return PyFloat_FromDouble(at::convert<double, at::Float8_e4m3fnuz>(
155           *(at::Float8_e4m3fnuz*)data));
156     default:
157       throw std::runtime_error("invalid type");
158   }
159 }
160 
161 } // namespace torch::utils
162