1 #pragma once 2 3 #include <torch/detail/static.h> 4 #include <torch/nn/module.h> 5 #include <torch/nn/pimpl.h> 6 #include <torch/types.h> 7 8 #include <torch/csrc/autograd/variable.h> 9 #include <torch/csrc/utils/variadic.h> 10 11 #include <memory> 12 #include <type_traits> 13 #include <typeinfo> 14 #include <utility> 15 16 namespace torch { 17 namespace nn { 18 19 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyValue ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 20 21 /// An implementation of `std::any` which stores 22 /// a type erased object, whose concrete value can be retrieved at runtime by 23 /// checking if the `typeid()` of a requested type matches the `typeid()` of 24 /// the object stored. 25 class AnyValue { 26 public: 27 /// Move construction and assignment is allowed, and follows the default 28 /// behavior of move for `std::unique_ptr`. 29 AnyValue(AnyValue&&) = default; 30 AnyValue& operator=(AnyValue&&) = default; 31 32 /// Copy construction and assignment is allowed. AnyValue(const AnyValue & other)33 AnyValue(const AnyValue& other) : content_(other.content_->clone()) {} 34 AnyValue& operator=(const AnyValue& other) { 35 content_ = other.content_->clone(); 36 return *this; 37 } 38 39 /// Constructs the `AnyValue` from value type. 40 template <typename T> 41 // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) AnyValue(T && value)42 explicit AnyValue(T&& value) 43 : content_( 44 std::make_unique<Holder<std::decay_t<T>>>(std::forward<T>(value))) { 45 } 46 47 /// Returns a pointer to the value contained in the `AnyValue` if the type 48 /// passed as template parameter matches the type of the value stored, and 49 /// returns a null pointer otherwise. 50 template <typename T> try_get()51 T* try_get() { 52 static_assert( 53 !std::is_reference<T>::value, 54 "AnyValue stores decayed types, you cannot cast it to a reference type"); 55 static_assert( 56 !std::is_array<T>::value, 57 "AnyValue stores decayed types, you must cast it to T* instead of T[]"); 58 if (typeid(T).hash_code() == type_info().hash_code()) { 59 return &static_cast<Holder<T>&>(*content_).value; 60 } 61 return nullptr; 62 } 63 64 /// Returns the value contained in the `AnyValue` if the type passed as 65 /// template parameter matches the type of the value stored, and throws an 66 /// exception otherwise. 67 template <typename T> get()68 T get() { 69 if (auto* maybe_value = try_get<T>()) { 70 return *maybe_value; 71 } 72 AT_ERROR( 73 "Attempted to cast AnyValue to ", 74 c10::demangle(typeid(T).name()), 75 ", but its actual type is ", 76 c10::demangle(type_info().name())); 77 } 78 79 /// Returns the `type_info` object of the contained value. type_info()80 const std::type_info& type_info() const noexcept { 81 return content_->type_info; 82 } 83 84 private: 85 friend struct AnyModulePlaceholder; 86 friend struct TestAnyValue; 87 88 /// \internal 89 /// The static type of the object we store in the `AnyValue`, which erases the 90 /// actual object's type, allowing us only to check the `type_info` of the 91 /// type stored in the dynamic type. 92 struct Placeholder { PlaceholderPlaceholder93 explicit Placeholder(const std::type_info& type_info_) noexcept 94 : type_info(type_info_) {} 95 Placeholder(const Placeholder&) = default; 96 Placeholder(Placeholder&&) = default; 97 virtual ~Placeholder() = default; clonePlaceholder98 virtual std::unique_ptr<Placeholder> clone() const { 99 TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`"); 100 } 101 const std::type_info& type_info; 102 }; 103 104 /// \internal 105 /// The dynamic type of the object we store in the `AnyValue`, which hides the 106 /// actual object we have erased in this `AnyValue`. 107 template <typename T> 108 struct Holder : public Placeholder { 109 /// A template because T&& would not be universal reference here. 110 template <typename U> 111 // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) HolderHolder112 explicit Holder(U&& value_) noexcept 113 : Placeholder(typeid(T)), value(std::forward<U>(value_)) {} cloneHolder114 std::unique_ptr<Placeholder> clone() const override { 115 return std::make_unique<Holder<T>>(value); 116 } 117 T value; 118 }; 119 120 /// The type erased object. 121 std::unique_ptr<Placeholder> content_; 122 }; 123 124 } // namespace nn 125 } // namespace torch 126