xref: /aosp_15_r20/external/pytorch/c10/core/Scalar.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cstdint>
4 #include <stdexcept>
5 #include <type_traits>
6 #include <utility>
7 
8 #include <c10/core/OptionalRef.h>
9 #include <c10/core/ScalarType.h>
10 #include <c10/core/SymBool.h>
11 #include <c10/core/SymFloat.h>
12 #include <c10/core/SymInt.h>
13 #include <c10/core/SymNodeImpl.h>
14 #include <c10/macros/Export.h>
15 #include <c10/macros/Macros.h>
16 #include <c10/util/Deprecated.h>
17 #include <c10/util/Exception.h>
18 #include <c10/util/Half.h>
19 #include <c10/util/TypeCast.h>
20 #include <c10/util/complex.h>
21 #include <c10/util/intrusive_ptr.h>
22 
23 namespace c10 {
24 
25 /**
26  * Scalar represents a 0-dimensional tensor which contains a single element.
27  * Unlike a tensor, numeric literals (in C++) are implicitly convertible to
28  * Scalar (which is why, for example, we provide both add(Tensor) and
29  * add(Scalar) overloads for many operations). It may also be used in
30  * circumstances where you statically know a tensor is 0-dim and single size,
31  * but don't know its type.
32  */
33 class C10_API Scalar {
34  public:
Scalar()35   Scalar() : Scalar(int64_t(0)) {}
36 
destroy()37   void destroy() {
38     if (Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag) {
39       raw::intrusive_ptr::decref(v.p);
40       v.p = nullptr;
41     }
42   }
43 
~Scalar()44   ~Scalar() {
45     destroy();
46   }
47 
48 #define DEFINE_IMPLICIT_CTOR(type, name) \
49   Scalar(type vv) : Scalar(vv, true) {}
50 
51   AT_FORALL_SCALAR_TYPES_AND7(
52       Half,
53       BFloat16,
54       Float8_e5m2,
55       Float8_e4m3fn,
56       Float8_e5m2fnuz,
57       Float8_e4m3fnuz,
58       ComplexHalf,
59       DEFINE_IMPLICIT_CTOR)
60   AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
61 
62   // Helper constructors to allow Scalar creation from long and long long types
63   // As std::is_same_v<long, long long> is false(except Android), one needs to
64   // provide a constructor from either long or long long in addition to one from
65   // int64_t
66 #if defined(__APPLE__) || defined(__MACOSX)
67   static_assert(
68       std::is_same_v<long long, int64_t>,
69       "int64_t is the same as long long on MacOS");
Scalar(long vv)70   Scalar(long vv) : Scalar(vv, true) {}
71 #endif
72 #if defined(_MSC_VER)
73   static_assert(
74       std::is_same_v<long long, int64_t>,
75       "int64_t is the same as long long on Windows");
Scalar(long vv)76   Scalar(long vv) : Scalar(vv, true) {}
77 #endif
78 #if defined(__linux__) && !defined(__ANDROID__)
79   static_assert(
80       std::is_same_v<long, int64_t>,
81       "int64_t is the same as long on Linux");
Scalar(long long vv)82   Scalar(long long vv) : Scalar(vv, true) {}
83 #endif
84 
Scalar(uint16_t vv)85   Scalar(uint16_t vv) : Scalar(vv, true) {}
Scalar(uint32_t vv)86   Scalar(uint32_t vv) : Scalar(vv, true) {}
Scalar(uint64_t vv)87   Scalar(uint64_t vv) {
88     if (vv > static_cast<uint64_t>(INT64_MAX)) {
89       tag = Tag::HAS_u;
90       v.u = vv;
91     } else {
92       tag = Tag::HAS_i;
93       // NB: no need to use convert, we've already tested convertibility
94       v.i = static_cast<int64_t>(vv);
95     }
96   }
97 
98 #undef DEFINE_IMPLICIT_CTOR
99 
100   // Value* is both implicitly convertible to SymbolicVariable and bool which
101   // causes ambiguity error. Specialized constructor for bool resolves this
102   // problem.
103   template <
104       typename T,
105       typename std::enable_if_t<std::is_same_v<T, bool>, bool>* = nullptr>
Scalar(T vv)106   Scalar(T vv) : tag(Tag::HAS_b) {
107     v.i = convert<int64_t, bool>(vv);
108   }
109 
110   template <
111       typename T,
112       typename std::enable_if_t<std::is_same_v<T, c10::SymBool>, bool>* =
113           nullptr>
Scalar(T vv)114   Scalar(T vv) : tag(Tag::HAS_sb) {
115     v.i = convert<int64_t, c10::SymBool>(vv);
116   }
117 
118 #define DEFINE_ACCESSOR(type, name)                                   \
119   type to##name() const {                                             \
120     if (Tag::HAS_d == tag) {                                          \
121       return checked_convert<type, double>(v.d, #type);               \
122     } else if (Tag::HAS_z == tag) {                                   \
123       return checked_convert<type, c10::complex<double>>(v.z, #type); \
124     }                                                                 \
125     if (Tag::HAS_b == tag) {                                          \
126       return checked_convert<type, bool>(v.i, #type);                 \
127     } else if (Tag::HAS_i == tag) {                                   \
128       return checked_convert<type, int64_t>(v.i, #type);              \
129     } else if (Tag::HAS_u == tag) {                                   \
130       return checked_convert<type, uint64_t>(v.u, #type);             \
131     } else if (Tag::HAS_si == tag) {                                  \
132       return checked_convert<type, int64_t>(                          \
133           toSymInt().guard_int(__FILE__, __LINE__), #type);           \
134     } else if (Tag::HAS_sd == tag) {                                  \
135       return checked_convert<type, int64_t>(                          \
136           toSymFloat().guard_float(__FILE__, __LINE__), #type);       \
137     } else if (Tag::HAS_sb == tag) {                                  \
138       return checked_convert<type, int64_t>(                          \
139           toSymBool().guard_bool(__FILE__, __LINE__), #type);         \
140     }                                                                 \
141     TORCH_CHECK(false)                                                \
142   }
143 
144   // TODO: Support ComplexHalf accessor
145   AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR)
DEFINE_ACCESSOR(uint16_t,UInt16)146   DEFINE_ACCESSOR(uint16_t, UInt16)
147   DEFINE_ACCESSOR(uint32_t, UInt32)
148   DEFINE_ACCESSOR(uint64_t, UInt64)
149 
150 #undef DEFINE_ACCESSOR
151 
152   SymInt toSymInt() const {
153     if (Tag::HAS_si == tag) {
154       return c10::SymInt(intrusive_ptr<SymNodeImpl>::reclaim_copy(
155           static_cast<SymNodeImpl*>(v.p)));
156     } else {
157       return toLong();
158     }
159   }
160 
toSymFloat()161   SymFloat toSymFloat() const {
162     if (Tag::HAS_sd == tag) {
163       return c10::SymFloat(intrusive_ptr<SymNodeImpl>::reclaim_copy(
164           static_cast<SymNodeImpl*>(v.p)));
165     } else {
166       return toDouble();
167     }
168   }
169 
toSymBool()170   SymBool toSymBool() const {
171     if (Tag::HAS_sb == tag) {
172       return c10::SymBool(intrusive_ptr<SymNodeImpl>::reclaim_copy(
173           static_cast<SymNodeImpl*>(v.p)));
174     } else {
175       return toBool();
176     }
177   }
178 
179   // also support scalar.to<int64_t>();
180   // Deleted for unsupported types, but specialized below for supported types
181   template <typename T>
182   T to() const = delete;
183 
184   // audit uses of data_ptr
data_ptr()185   const void* data_ptr() const {
186     TORCH_INTERNAL_ASSERT(!isSymbolic());
187     return static_cast<const void*>(&v);
188   }
189 
isFloatingPoint()190   bool isFloatingPoint() const {
191     return Tag::HAS_d == tag || Tag::HAS_sd == tag;
192   }
193 
194   C10_DEPRECATED_MESSAGE(
195       "isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")
isIntegral()196   bool isIntegral() const {
197     return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag;
198   }
isIntegral(bool includeBool)199   bool isIntegral(bool includeBool) const {
200     return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag ||
201         (includeBool && isBoolean());
202   }
203 
isComplex()204   bool isComplex() const {
205     return Tag::HAS_z == tag;
206   }
isBoolean()207   bool isBoolean() const {
208     return Tag::HAS_b == tag || Tag::HAS_sb == tag;
209   }
210 
211   // you probably don't actually want these; they're mostly for testing
isSymInt()212   bool isSymInt() const {
213     return Tag::HAS_si == tag;
214   }
isSymFloat()215   bool isSymFloat() const {
216     return Tag::HAS_sd == tag;
217   }
isSymBool()218   bool isSymBool() const {
219     return Tag::HAS_sb == tag;
220   }
221 
isSymbolic()222   bool isSymbolic() const {
223     return Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag;
224   }
225 
226   C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) noexcept {
227     if (&other == this) {
228       return *this;
229     }
230 
231     destroy();
232     moveFrom(std::move(other));
233     return *this;
234   }
235 
236   C10_ALWAYS_INLINE Scalar& operator=(const Scalar& other) {
237     if (&other == this) {
238       return *this;
239     }
240 
241     *this = Scalar(other);
242     return *this;
243   }
244 
245   Scalar operator-() const;
246   Scalar conj() const;
247   Scalar log() const;
248 
249   template <
250       typename T,
251       typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
equal(T num)252   bool equal(T num) const {
253     if (isComplex()) {
254       TORCH_INTERNAL_ASSERT(!isSymbolic());
255       auto val = v.z;
256       return (val.real() == num) && (val.imag() == T());
257     } else if (isFloatingPoint()) {
258       TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
259       return v.d == num;
260     } else if (tag == Tag::HAS_i) {
261       if (overflows<T>(v.i, /* strict_unsigned */ true)) {
262         return false;
263       } else {
264         return static_cast<T>(v.i) == num;
265       }
266     } else if (tag == Tag::HAS_u) {
267       if (overflows<T>(v.u, /* strict_unsigned */ true)) {
268         return false;
269       } else {
270         return static_cast<T>(v.u) == num;
271       }
272     } else if (tag == Tag::HAS_si) {
273       TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality");
274     } else if (isBoolean()) {
275       // boolean scalar does not equal to a non boolean value
276       TORCH_INTERNAL_ASSERT(!isSymbolic());
277       return false;
278     } else {
279       TORCH_INTERNAL_ASSERT(false);
280     }
281   }
282 
283   template <
284       typename T,
285       typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
equal(T num)286   bool equal(T num) const {
287     if (isComplex()) {
288       TORCH_INTERNAL_ASSERT(!isSymbolic());
289       return v.z == num;
290     } else if (isFloatingPoint()) {
291       TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
292       return (v.d == num.real()) && (num.imag() == T());
293     } else if (tag == Tag::HAS_i) {
294       if (overflows<T>(v.i, /* strict_unsigned */ true)) {
295         return false;
296       } else {
297         return static_cast<T>(v.i) == num.real() && num.imag() == T();
298       }
299     } else if (tag == Tag::HAS_u) {
300       if (overflows<T>(v.u, /* strict_unsigned */ true)) {
301         return false;
302       } else {
303         return static_cast<T>(v.u) == num.real() && num.imag() == T();
304       }
305     } else if (tag == Tag::HAS_si) {
306       TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality");
307     } else if (isBoolean()) {
308       // boolean scalar does not equal to a non boolean value
309       TORCH_INTERNAL_ASSERT(!isSymbolic());
310       return false;
311     } else {
312       TORCH_INTERNAL_ASSERT(false);
313     }
314   }
315 
equal(bool num)316   bool equal(bool num) const {
317     if (isBoolean()) {
318       TORCH_INTERNAL_ASSERT(!isSymbolic());
319       return static_cast<bool>(v.i) == num;
320     } else {
321       return false;
322     }
323   }
324 
type()325   ScalarType type() const {
326     if (isComplex()) {
327       return ScalarType::ComplexDouble;
328     } else if (isFloatingPoint()) {
329       return ScalarType::Double;
330     } else if (isIntegral(/*includeBool=*/false)) {
331       // Represent all integers as long, UNLESS it is unsigned and therefore
332       // unrepresentable as long
333       if (Tag::HAS_u == tag) {
334         return ScalarType::UInt64;
335       }
336       return ScalarType::Long;
337     } else if (isBoolean()) {
338       return ScalarType::Bool;
339     } else {
340       throw std::runtime_error("Unknown scalar type.");
341     }
342   }
343 
Scalar(Scalar && rhs)344   Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) {
345     moveFrom(std::move(rhs));
346   }
347 
Scalar(const Scalar & rhs)348   Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) {
349     if (isSymbolic()) {
350       c10::raw::intrusive_ptr::incref(v.p);
351     }
352   }
353 
Scalar(c10::SymInt si)354   Scalar(c10::SymInt si) {
355     if (auto m = si.maybe_as_int()) {
356       tag = Tag::HAS_i;
357       v.i = *m;
358     } else {
359       tag = Tag::HAS_si;
360       v.p = std::move(si).release();
361     }
362   }
363 
Scalar(c10::SymFloat sd)364   Scalar(c10::SymFloat sd) {
365     if (sd.is_symbolic()) {
366       tag = Tag::HAS_sd;
367       v.p = std::move(sd).release();
368     } else {
369       tag = Tag::HAS_d;
370       v.d = sd.as_float_unchecked();
371     }
372   }
373 
Scalar(c10::SymBool sb)374   Scalar(c10::SymBool sb) {
375     if (auto m = sb.maybe_as_bool()) {
376       tag = Tag::HAS_b;
377       v.i = *m;
378     } else {
379       tag = Tag::HAS_sb;
380       v.p = std::move(sb).release();
381     }
382   }
383 
384   // We can't set v in the initializer list using the
385   // syntax v{ .member = ... } because it doesn't work on MSVC
386  private:
387   enum class Tag { HAS_d, HAS_i, HAS_u, HAS_z, HAS_b, HAS_sd, HAS_si, HAS_sb };
388 
389   // Note [Meaning of HAS_u]
390   // ~~~~~~~~~~~~~~~~~~~~~~~
391   // HAS_u is a bit special.  On its face, it just means that we
392   // are holding an unsigned integer.  However, we generally don't
393   // distinguish between different bit sizes in Scalar (e.g., we represent
394   // float as double), instead, it represents a mathematical notion
395   // of some quantity (integral versus floating point).  So actually,
396   // HAS_u is used solely to represent unsigned integers that could
397   // not be represented as a signed integer.  That means only uint64_t
398   // potentially can get this tag; smaller types like uint8_t fits into a
399   // regular int and so for BC reasons we keep as an int.
400 
401   // NB: assumes that self has already been cleared
402   // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
moveFrom(Scalar && rhs)403   C10_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept {
404     v = rhs.v;
405     tag = rhs.tag;
406     if (rhs.tag == Tag::HAS_si || rhs.tag == Tag::HAS_sd ||
407         rhs.tag == Tag::HAS_sb) {
408       // Move out of scalar
409       rhs.tag = Tag::HAS_i;
410       rhs.v.i = 0;
411     }
412   }
413 
414   Tag tag;
415 
416   union v_t {
417     double d{};
418     int64_t i;
419     // See Note [Meaning of HAS_u]
420     uint64_t u;
421     c10::complex<double> z;
422     c10::intrusive_ptr_target* p;
423     // NOLINTNEXTLINE(modernize-use-equals-default)
v_t()424     v_t() {} // default constructor
425   } v;
426 
427   template <
428       typename T,
429       typename std::enable_if_t<
430           std::is_integral_v<T> && !std::is_same_v<T, bool>,
431           bool>* = nullptr>
Scalar(T vv,bool)432   Scalar(T vv, bool) : tag(Tag::HAS_i) {
433     v.i = convert<decltype(v.i), T>(vv);
434   }
435 
436   template <
437       typename T,
438       typename std::enable_if_t<
439           !std::is_integral_v<T> && !c10::is_complex<T>::value,
440           bool>* = nullptr>
Scalar(T vv,bool)441   Scalar(T vv, bool) : tag(Tag::HAS_d) {
442     v.d = convert<decltype(v.d), T>(vv);
443   }
444 
445   template <
446       typename T,
447       typename std::enable_if_t<c10::is_complex<T>::value, bool>* = nullptr>
Scalar(T vv,bool)448   Scalar(T vv, bool) : tag(Tag::HAS_z) {
449     v.z = convert<decltype(v.z), T>(vv);
450   }
451 };
452 
453 using OptionalScalarRef = c10::OptionalRef<Scalar>;
454 
455 // define the scalar.to<int64_t>() specializations
456 #define DEFINE_TO(T, name)         \
457   template <>                      \
458   inline T Scalar::to<T>() const { \
459     return to##name();             \
460   }
461 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO)
462 DEFINE_TO(uint16_t, UInt16)
463 DEFINE_TO(uint32_t, UInt32)
464 DEFINE_TO(uint64_t, UInt64)
465 #undef DEFINE_TO
466 
467 } // namespace c10
468