1 #include <torch/csrc/jit/tensorexpr/types.h>
2
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/tensorexpr/exceptions.h>
5
6 #include <c10/util/Logging.h>
7
8 namespace torch::jit::tensorexpr {
9
scalar_dtype() const10 Dtype Dtype::scalar_dtype() const {
11 return ToDtype(scalar_type_);
12 }
13
// Defines one exported global Dtype constant per scalar type, named k<Name>
// (e.g. kFloat, kBool), each constructed with a single lane. ToDtype() below
// maps a ScalarType back to the matching constant.
#define DTYPE_DEFINE(_1, n) TORCH_API Dtype k##n(ScalarType::n, 1);

AT_FORALL_SCALAR_TYPES_AND7(
    Bool,
    Half,
    BFloat16,
    Float8_e5m2,
    Float8_e5m2fnuz,
    Float8_e4m3fn,
    Float8_e4m3fnuz,
    DTYPE_DEFINE)
// The quantized types are defined explicitly here; they are not part of the
// type list expanded above (otherwise these would be duplicate definitions).
DTYPE_DEFINE(c10::quint8, QUInt8);
DTYPE_DEFINE(c10::qint8, QInt8);

#undef DTYPE_DEFINE

// Single-lane Dtype carrying ScalarType::Undefined; ToDtype(Undefined) returns
// it. NOTE(review): presumably the dtype of opaque handle values — confirm.
TORCH_API Dtype kHandle(ScalarType::Undefined, 1);
31
ToDtype(ScalarType type)32 Dtype ToDtype(ScalarType type) {
33 switch (type) {
34 // NOLINTNEXTLINE
35 #define TYPE_CASE(_1, n) \
36 case ScalarType::n: \
37 return k##n;
38 AT_FORALL_SCALAR_TYPES_AND7(
39 Bool,
40 Half,
41 BFloat16,
42 Float8_e5m2,
43 Float8_e5m2fnuz,
44 Float8_e4m3fn,
45 Float8_e4m3fnuz,
46 TYPE_CASE)
47 TYPE_CASE(c10::quint8, QUInt8);
48 TYPE_CASE(c10::qint8, QInt8);
49 #undef TYPE_CASE
50
51 case ScalarType::Undefined:
52 return kHandle;
53 default:
54 throw unsupported_dtype();
55 }
56 }
57
operator <<(std::ostream & stream,const Dtype & dtype)58 TORCH_API std::ostream& operator<<(std::ostream& stream, const Dtype& dtype) {
59 stream << dtype.scalar_type_;
60 if (dtype.lanes() > 1) {
61 stream << "x" << dtype.lanes();
62 ;
63 }
64 return stream;
65 }
66
byte_size() const67 int Dtype::byte_size() const {
68 int scalar_size = -1;
69 switch (scalar_type_) {
70 #define TYPE_CASE(Type, Name) \
71 case ScalarType::Name: \
72 scalar_size = sizeof(Type); \
73 break;
74
75 AT_FORALL_SCALAR_TYPES_AND7(
76 Bool,
77 Half,
78 BFloat16,
79 Float8_e5m2,
80 Float8_e4m3fn,
81 Float8_e5m2fnuz,
82 Float8_e4m3fnuz,
83 TYPE_CASE);
84 TYPE_CASE(c10::quint8, QUInt8);
85 TYPE_CASE(c10::qint8, QInt8);
86 #undef TYPE_CASE
87 default:
88 throw std::runtime_error(
89 "invalid scalar type; " + std::to_string(scalar_type_));
90 }
91 return static_cast<int>(scalar_size * lanes());
92 }
93
ToCppString() const94 std::string Dtype::ToCppString() const {
95 switch (scalar_type_) {
96 // NOLINTNEXTLINE
97 #define TYPE_CASE(t, n) \
98 case ScalarType::n: \
99 return #t;
100 AT_FORALL_SCALAR_TYPES(TYPE_CASE);
101 #undef TYPE_CASE
102 case ScalarType::Bool:
103 return "bool";
104 case ScalarType::Half:
105 return "half";
106 case ScalarType::BFloat16:
107 return "bfloat16";
108 case ScalarType::Float8_e5m2:
109 return "float8_e5m2";
110 case ScalarType::Float8_e4m3fn:
111 return "float8_e4m3fn";
112 case ScalarType::Float8_e5m2fnuz:
113 return "float8_e5m2fnuz";
114 case ScalarType::Float8_e4m3fnuz:
115 return "float8_e4m3fnuz";
116 case ScalarType::QInt8:
117 return "qint8";
118 case ScalarType::QUInt8:
119 return "quint8";
120 default:
121 throw unsupported_dtype();
122 }
123 return "invalid";
124 }
125
126 } // namespace torch::jit::tensorexpr
127
128 namespace std {
129
to_string(const Dtype & dtype)130 std::string to_string(const Dtype& dtype) {
131 std::ostringstream oss;
132 oss << dtype;
133 return oss.str();
134 }
135
to_string(const ScalarType & type)136 std::string to_string(const ScalarType& type) {
137 std::ostringstream oss;
138 oss << type;
139 return oss.str();
140 }
141
142 } // namespace std
143