1 #pragma once 2 3 #include <cstdint> 4 #include <stdexcept> 5 #include <type_traits> 6 #include <utility> 7 8 #include <c10/core/OptionalRef.h> 9 #include <c10/core/ScalarType.h> 10 #include <c10/core/SymBool.h> 11 #include <c10/core/SymFloat.h> 12 #include <c10/core/SymInt.h> 13 #include <c10/core/SymNodeImpl.h> 14 #include <c10/macros/Export.h> 15 #include <c10/macros/Macros.h> 16 #include <c10/util/Deprecated.h> 17 #include <c10/util/Exception.h> 18 #include <c10/util/Half.h> 19 #include <c10/util/TypeCast.h> 20 #include <c10/util/complex.h> 21 #include <c10/util/intrusive_ptr.h> 22 23 namespace c10 { 24 25 /** 26 * Scalar represents a 0-dimensional tensor which contains a single element. 27 * Unlike a tensor, numeric literals (in C++) are implicitly convertible to 28 * Scalar (which is why, for example, we provide both add(Tensor) and 29 * add(Scalar) overloads for many operations). It may also be used in 30 * circumstances where you statically know a tensor is 0-dim and single size, 31 * but don't know its type. 32 */ 33 class C10_API Scalar { 34 public: Scalar()35 Scalar() : Scalar(int64_t(0)) {} 36 destroy()37 void destroy() { 38 if (Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag) { 39 raw::intrusive_ptr::decref(v.p); 40 v.p = nullptr; 41 } 42 } 43 ~Scalar()44 ~Scalar() { 45 destroy(); 46 } 47 48 #define DEFINE_IMPLICIT_CTOR(type, name) \ 49 Scalar(type vv) : Scalar(vv, true) {} 50 51 AT_FORALL_SCALAR_TYPES_AND7( 52 Half, 53 BFloat16, 54 Float8_e5m2, 55 Float8_e4m3fn, 56 Float8_e5m2fnuz, 57 Float8_e4m3fnuz, 58 ComplexHalf, 59 DEFINE_IMPLICIT_CTOR) 60 AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR) 61 62 // Helper constructors to allow Scalar creation from long and long long types 63 // As std::is_same_v<long, long long> is false(except Android), one needs to 64 // provide a constructor from either long or long long in addition to one from 65 // int64_t 66 #if defined(__APPLE__) || defined(__MACOSX) 67 static_assert( 68 std::is_same_v<long long, int64_t>, 69 "int64_t is the same as long long on MacOS"); Scalar(long vv)70 Scalar(long vv) : Scalar(vv, true) {} 71 #endif 72 #if defined(_MSC_VER) 73 static_assert( 74 std::is_same_v<long long, int64_t>, 75 "int64_t is the same as long long on Windows"); Scalar(long vv)76 Scalar(long vv) : Scalar(vv, true) {} 77 #endif 78 #if defined(__linux__) && !defined(__ANDROID__) 79 static_assert( 80 std::is_same_v<long, int64_t>, 81 "int64_t is the same as long on Linux"); Scalar(long long vv)82 Scalar(long long vv) : Scalar(vv, true) {} 83 #endif 84 Scalar(uint16_t vv)85 Scalar(uint16_t vv) : Scalar(vv, true) {} Scalar(uint32_t vv)86 Scalar(uint32_t vv) : Scalar(vv, true) {} Scalar(uint64_t vv)87 Scalar(uint64_t vv) { 88 if (vv > static_cast<uint64_t>(INT64_MAX)) { 89 tag = Tag::HAS_u; 90 v.u = vv; 91 } else { 92 tag = Tag::HAS_i; 93 // NB: no need to use convert, we've already tested convertibility 94 v.i = static_cast<int64_t>(vv); 95 } 96 } 97 98 #undef DEFINE_IMPLICIT_CTOR 99 100 // Value* is both implicitly convertible to SymbolicVariable and bool which 101 // causes ambiguity error. Specialized constructor for bool resolves this 102 // problem. 103 template < 104 typename T, 105 typename std::enable_if_t<std::is_same_v<T, bool>, bool>* = nullptr> Scalar(T vv)106 Scalar(T vv) : tag(Tag::HAS_b) { 107 v.i = convert<int64_t, bool>(vv); 108 } 109 110 template < 111 typename T, 112 typename std::enable_if_t<std::is_same_v<T, c10::SymBool>, bool>* = 113 nullptr> Scalar(T vv)114 Scalar(T vv) : tag(Tag::HAS_sb) { 115 v.i = convert<int64_t, c10::SymBool>(vv); 116 } 117 118 #define DEFINE_ACCESSOR(type, name) \ 119 type to##name() const { \ 120 if (Tag::HAS_d == tag) { \ 121 return checked_convert<type, double>(v.d, #type); \ 122 } else if (Tag::HAS_z == tag) { \ 123 return checked_convert<type, c10::complex<double>>(v.z, #type); \ 124 } \ 125 if (Tag::HAS_b == tag) { \ 126 return checked_convert<type, bool>(v.i, #type); \ 127 } else if (Tag::HAS_i == tag) { \ 128 return checked_convert<type, int64_t>(v.i, #type); \ 129 } else if (Tag::HAS_u == tag) { \ 130 return checked_convert<type, uint64_t>(v.u, #type); \ 131 } else if (Tag::HAS_si == tag) { \ 132 return checked_convert<type, int64_t>( \ 133 toSymInt().guard_int(__FILE__, __LINE__), #type); \ 134 } else if (Tag::HAS_sd == tag) { \ 135 return checked_convert<type, int64_t>( \ 136 toSymFloat().guard_float(__FILE__, __LINE__), #type); \ 137 } else if (Tag::HAS_sb == tag) { \ 138 return checked_convert<type, int64_t>( \ 139 toSymBool().guard_bool(__FILE__, __LINE__), #type); \ 140 } \ 141 TORCH_CHECK(false) \ 142 } 143 144 // TODO: Support ComplexHalf accessor 145 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR) DEFINE_ACCESSOR(uint16_t,UInt16)146 DEFINE_ACCESSOR(uint16_t, UInt16) 147 DEFINE_ACCESSOR(uint32_t, UInt32) 148 DEFINE_ACCESSOR(uint64_t, UInt64) 149 150 #undef DEFINE_ACCESSOR 151 152 SymInt toSymInt() const { 153 if (Tag::HAS_si == tag) { 154 return c10::SymInt(intrusive_ptr<SymNodeImpl>::reclaim_copy( 155 static_cast<SymNodeImpl*>(v.p))); 156 } else { 157 return toLong(); 158 } 159 } 160 toSymFloat()161 SymFloat toSymFloat() const { 162 if (Tag::HAS_sd == tag) { 163 return c10::SymFloat(intrusive_ptr<SymNodeImpl>::reclaim_copy( 164 static_cast<SymNodeImpl*>(v.p))); 165 } else { 166 return toDouble(); 167 } 168 } 169 toSymBool()170 SymBool toSymBool() const { 171 if (Tag::HAS_sb == tag) { 172 return c10::SymBool(intrusive_ptr<SymNodeImpl>::reclaim_copy( 173 static_cast<SymNodeImpl*>(v.p))); 174 } else { 175 return toBool(); 176 } 177 } 178 179 // also support scalar.to<int64_t>(); 180 // Deleted for unsupported types, but specialized below for supported types 181 template <typename T> 182 T to() const = delete; 183 184 // audit uses of data_ptr data_ptr()185 const void* data_ptr() const { 186 TORCH_INTERNAL_ASSERT(!isSymbolic()); 187 return static_cast<const void*>(&v); 188 } 189 isFloatingPoint()190 bool isFloatingPoint() const { 191 return Tag::HAS_d == tag || Tag::HAS_sd == tag; 192 } 193 194 C10_DEPRECATED_MESSAGE( 195 "isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.") isIntegral()196 bool isIntegral() const { 197 return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag; 198 } isIntegral(bool includeBool)199 bool isIntegral(bool includeBool) const { 200 return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag || 201 (includeBool && isBoolean()); 202 } 203 isComplex()204 bool isComplex() const { 205 return Tag::HAS_z == tag; 206 } isBoolean()207 bool isBoolean() const { 208 return Tag::HAS_b == tag || Tag::HAS_sb == tag; 209 } 210 211 // you probably don't actually want these; they're mostly for testing isSymInt()212 bool isSymInt() const { 213 return Tag::HAS_si == tag; 214 } isSymFloat()215 bool isSymFloat() const { 216 return Tag::HAS_sd == tag; 217 } isSymBool()218 bool isSymBool() const { 219 return Tag::HAS_sb == tag; 220 } 221 isSymbolic()222 bool isSymbolic() const { 223 return Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag; 224 } 225 226 C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) noexcept { 227 if (&other == this) { 228 return *this; 229 } 230 231 destroy(); 232 moveFrom(std::move(other)); 233 return *this; 234 } 235 236 C10_ALWAYS_INLINE Scalar& operator=(const Scalar& other) { 237 if (&other == this) { 238 return *this; 239 } 240 241 *this = Scalar(other); 242 return *this; 243 } 244 245 Scalar operator-() const; 246 Scalar conj() const; 247 Scalar log() const; 248 249 template < 250 typename T, 251 typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0> equal(T num)252 bool equal(T num) const { 253 if (isComplex()) { 254 TORCH_INTERNAL_ASSERT(!isSymbolic()); 255 auto val = v.z; 256 return (val.real() == num) && (val.imag() == T()); 257 } else if (isFloatingPoint()) { 258 TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality"); 259 return v.d == num; 260 } else if (tag == Tag::HAS_i) { 261 if (overflows<T>(v.i, /* strict_unsigned */ true)) { 262 return false; 263 } else { 264 return static_cast<T>(v.i) == num; 265 } 266 } else if (tag == Tag::HAS_u) { 267 if (overflows<T>(v.u, /* strict_unsigned */ true)) { 268 return false; 269 } else { 270 return static_cast<T>(v.u) == num; 271 } 272 } else if (tag == Tag::HAS_si) { 273 TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality"); 274 } else if (isBoolean()) { 275 // boolean scalar does not equal to a non boolean value 276 TORCH_INTERNAL_ASSERT(!isSymbolic()); 277 return false; 278 } else { 279 TORCH_INTERNAL_ASSERT(false); 280 } 281 } 282 283 template < 284 typename T, 285 typename std::enable_if_t<c10::is_complex<T>::value, int> = 0> equal(T num)286 bool equal(T num) const { 287 if (isComplex()) { 288 TORCH_INTERNAL_ASSERT(!isSymbolic()); 289 return v.z == num; 290 } else if (isFloatingPoint()) { 291 TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality"); 292 return (v.d == num.real()) && (num.imag() == T()); 293 } else if (tag == Tag::HAS_i) { 294 if (overflows<T>(v.i, /* strict_unsigned */ true)) { 295 return false; 296 } else { 297 return static_cast<T>(v.i) == num.real() && num.imag() == T(); 298 } 299 } else if (tag == Tag::HAS_u) { 300 if (overflows<T>(v.u, /* strict_unsigned */ true)) { 301 return false; 302 } else { 303 return static_cast<T>(v.u) == num.real() && num.imag() == T(); 304 } 305 } else if (tag == Tag::HAS_si) { 306 TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality"); 307 } else if (isBoolean()) { 308 // boolean scalar does not equal to a non boolean value 309 TORCH_INTERNAL_ASSERT(!isSymbolic()); 310 return false; 311 } else { 312 TORCH_INTERNAL_ASSERT(false); 313 } 314 } 315 equal(bool num)316 bool equal(bool num) const { 317 if (isBoolean()) { 318 TORCH_INTERNAL_ASSERT(!isSymbolic()); 319 return static_cast<bool>(v.i) == num; 320 } else { 321 return false; 322 } 323 } 324 type()325 ScalarType type() const { 326 if (isComplex()) { 327 return ScalarType::ComplexDouble; 328 } else if (isFloatingPoint()) { 329 return ScalarType::Double; 330 } else if (isIntegral(/*includeBool=*/false)) { 331 // Represent all integers as long, UNLESS it is unsigned and therefore 332 // unrepresentable as long 333 if (Tag::HAS_u == tag) { 334 return ScalarType::UInt64; 335 } 336 return ScalarType::Long; 337 } else if (isBoolean()) { 338 return ScalarType::Bool; 339 } else { 340 throw std::runtime_error("Unknown scalar type."); 341 } 342 } 343 Scalar(Scalar && rhs)344 Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) { 345 moveFrom(std::move(rhs)); 346 } 347 Scalar(const Scalar & rhs)348 Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) { 349 if (isSymbolic()) { 350 c10::raw::intrusive_ptr::incref(v.p); 351 } 352 } 353 Scalar(c10::SymInt si)354 Scalar(c10::SymInt si) { 355 if (auto m = si.maybe_as_int()) { 356 tag = Tag::HAS_i; 357 v.i = *m; 358 } else { 359 tag = Tag::HAS_si; 360 v.p = std::move(si).release(); 361 } 362 } 363 Scalar(c10::SymFloat sd)364 Scalar(c10::SymFloat sd) { 365 if (sd.is_symbolic()) { 366 tag = Tag::HAS_sd; 367 v.p = std::move(sd).release(); 368 } else { 369 tag = Tag::HAS_d; 370 v.d = sd.as_float_unchecked(); 371 } 372 } 373 Scalar(c10::SymBool sb)374 Scalar(c10::SymBool sb) { 375 if (auto m = sb.maybe_as_bool()) { 376 tag = Tag::HAS_b; 377 v.i = *m; 378 } else { 379 tag = Tag::HAS_sb; 380 v.p = std::move(sb).release(); 381 } 382 } 383 384 // We can't set v in the initializer list using the 385 // syntax v{ .member = ... } because it doesn't work on MSVC 386 private: 387 enum class Tag { HAS_d, HAS_i, HAS_u, HAS_z, HAS_b, HAS_sd, HAS_si, HAS_sb }; 388 389 // Note [Meaning of HAS_u] 390 // ~~~~~~~~~~~~~~~~~~~~~~~ 391 // HAS_u is a bit special. On its face, it just means that we 392 // are holding an unsigned integer. However, we generally don't 393 // distinguish between different bit sizes in Scalar (e.g., we represent 394 // float as double), instead, it represents a mathematical notion 395 // of some quantity (integral versus floating point). So actually, 396 // HAS_u is used solely to represent unsigned integers that could 397 // not be represented as a signed integer. That means only uint64_t 398 // potentially can get this tag; smaller types like uint8_t fits into a 399 // regular int and so for BC reasons we keep as an int. 400 401 // NB: assumes that self has already been cleared 402 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) moveFrom(Scalar && rhs)403 C10_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept { 404 v = rhs.v; 405 tag = rhs.tag; 406 if (rhs.tag == Tag::HAS_si || rhs.tag == Tag::HAS_sd || 407 rhs.tag == Tag::HAS_sb) { 408 // Move out of scalar 409 rhs.tag = Tag::HAS_i; 410 rhs.v.i = 0; 411 } 412 } 413 414 Tag tag; 415 416 union v_t { 417 double d{}; 418 int64_t i; 419 // See Note [Meaning of HAS_u] 420 uint64_t u; 421 c10::complex<double> z; 422 c10::intrusive_ptr_target* p; 423 // NOLINTNEXTLINE(modernize-use-equals-default) v_t()424 v_t() {} // default constructor 425 } v; 426 427 template < 428 typename T, 429 typename std::enable_if_t< 430 std::is_integral_v<T> && !std::is_same_v<T, bool>, 431 bool>* = nullptr> Scalar(T vv,bool)432 Scalar(T vv, bool) : tag(Tag::HAS_i) { 433 v.i = convert<decltype(v.i), T>(vv); 434 } 435 436 template < 437 typename T, 438 typename std::enable_if_t< 439 !std::is_integral_v<T> && !c10::is_complex<T>::value, 440 bool>* = nullptr> Scalar(T vv,bool)441 Scalar(T vv, bool) : tag(Tag::HAS_d) { 442 v.d = convert<decltype(v.d), T>(vv); 443 } 444 445 template < 446 typename T, 447 typename std::enable_if_t<c10::is_complex<T>::value, bool>* = nullptr> Scalar(T vv,bool)448 Scalar(T vv, bool) : tag(Tag::HAS_z) { 449 v.z = convert<decltype(v.z), T>(vv); 450 } 451 }; 452 453 using OptionalScalarRef = c10::OptionalRef<Scalar>; 454 455 // define the scalar.to<int64_t>() specializations 456 #define DEFINE_TO(T, name) \ 457 template <> \ 458 inline T Scalar::to<T>() const { \ 459 return to##name(); \ 460 } 461 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO) 462 DEFINE_TO(uint16_t, UInt16) 463 DEFINE_TO(uint32_t, UInt32) 464 DEFINE_TO(uint64_t, UInt64) 465 #undef DEFINE_TO 466 467 } // namespace c10 468