xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/types.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/types.h>
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/tensorexpr/exceptions.h>
5 
6 #include <c10/util/Logging.h>
7 
8 namespace torch::jit::tensorexpr {
9 
scalar_dtype() const10 Dtype Dtype::scalar_dtype() const {
11   return ToDtype(scalar_type_);
12 }
13 
14 #define DTYPE_DEFINE(_1, n) TORCH_API Dtype k##n(ScalarType::n, 1);
15 
16 AT_FORALL_SCALAR_TYPES_AND7(
17     Bool,
18     Half,
19     BFloat16,
20     Float8_e5m2,
21     Float8_e5m2fnuz,
22     Float8_e4m3fn,
23     Float8_e4m3fnuz,
24     DTYPE_DEFINE)
25 DTYPE_DEFINE(c10::quint8, QUInt8);
26 DTYPE_DEFINE(c10::qint8, QInt8);
27 
28 #undef DTYPE_DEFINE
29 
30 TORCH_API Dtype kHandle(ScalarType::Undefined, 1);
31 
ToDtype(ScalarType type)32 Dtype ToDtype(ScalarType type) {
33   switch (type) {
34 // NOLINTNEXTLINE
35 #define TYPE_CASE(_1, n) \
36   case ScalarType::n:    \
37     return k##n;
38     AT_FORALL_SCALAR_TYPES_AND7(
39         Bool,
40         Half,
41         BFloat16,
42         Float8_e5m2,
43         Float8_e5m2fnuz,
44         Float8_e4m3fn,
45         Float8_e4m3fnuz,
46         TYPE_CASE)
47     TYPE_CASE(c10::quint8, QUInt8);
48     TYPE_CASE(c10::qint8, QInt8);
49 #undef TYPE_CASE
50 
51     case ScalarType::Undefined:
52       return kHandle;
53     default:
54       throw unsupported_dtype();
55   }
56 }
57 
operator <<(std::ostream & stream,const Dtype & dtype)58 TORCH_API std::ostream& operator<<(std::ostream& stream, const Dtype& dtype) {
59   stream << dtype.scalar_type_;
60   if (dtype.lanes() > 1) {
61     stream << "x" << dtype.lanes();
62     ;
63   }
64   return stream;
65 }
66 
byte_size() const67 int Dtype::byte_size() const {
68   int scalar_size = -1;
69   switch (scalar_type_) {
70 #define TYPE_CASE(Type, Name)   \
71   case ScalarType::Name:        \
72     scalar_size = sizeof(Type); \
73     break;
74 
75     AT_FORALL_SCALAR_TYPES_AND7(
76         Bool,
77         Half,
78         BFloat16,
79         Float8_e5m2,
80         Float8_e4m3fn,
81         Float8_e5m2fnuz,
82         Float8_e4m3fnuz,
83         TYPE_CASE);
84     TYPE_CASE(c10::quint8, QUInt8);
85     TYPE_CASE(c10::qint8, QInt8);
86 #undef TYPE_CASE
87     default:
88       throw std::runtime_error(
89           "invalid scalar type; " + std::to_string(scalar_type_));
90   }
91   return static_cast<int>(scalar_size * lanes());
92 }
93 
ToCppString() const94 std::string Dtype::ToCppString() const {
95   switch (scalar_type_) {
96 // NOLINTNEXTLINE
97 #define TYPE_CASE(t, n) \
98   case ScalarType::n:   \
99     return #t;
100     AT_FORALL_SCALAR_TYPES(TYPE_CASE);
101 #undef TYPE_CASE
102     case ScalarType::Bool:
103       return "bool";
104     case ScalarType::Half:
105       return "half";
106     case ScalarType::BFloat16:
107       return "bfloat16";
108     case ScalarType::Float8_e5m2:
109       return "float8_e5m2";
110     case ScalarType::Float8_e4m3fn:
111       return "float8_e4m3fn";
112     case ScalarType::Float8_e5m2fnuz:
113       return "float8_e5m2fnuz";
114     case ScalarType::Float8_e4m3fnuz:
115       return "float8_e4m3fnuz";
116     case ScalarType::QInt8:
117       return "qint8";
118     case ScalarType::QUInt8:
119       return "quint8";
120     default:
121       throw unsupported_dtype();
122   }
123   return "invalid";
124 }
125 
126 } // namespace torch::jit::tensorexpr
127 
128 namespace std {
129 
to_string(const Dtype & dtype)130 std::string to_string(const Dtype& dtype) {
131   std::ostringstream oss;
132   oss << dtype;
133   return oss.str();
134 }
135 
to_string(const ScalarType & type)136 std::string to_string(const ScalarType& type) {
137   std::ostringstream oss;
138   oss << type;
139   return oss.str();
140 }
141 
142 } // namespace std
143