xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/container/any_value.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/detail/static.h>
4 #include <torch/nn/module.h>
5 #include <torch/nn/pimpl.h>
6 #include <torch/types.h>
7 
8 #include <torch/csrc/autograd/variable.h>
9 #include <torch/csrc/utils/variadic.h>
10 
11 #include <memory>
12 #include <type_traits>
13 #include <typeinfo>
14 #include <utility>
15 
16 namespace torch {
17 namespace nn {
18 
19 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyValue ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
20 
21 /// An implementation of `std::any` which stores
22 /// a type erased object, whose concrete value can be retrieved at runtime by
23 /// checking if the `typeid()` of a requested type matches the `typeid()` of
24 /// the object stored.
25 class AnyValue {
26  public:
27   /// Move construction and assignment is allowed, and follows the default
28   /// behavior of move for `std::unique_ptr`.
29   AnyValue(AnyValue&&) = default;
30   AnyValue& operator=(AnyValue&&) = default;
31 
32   /// Copy construction and assignment is allowed.
AnyValue(const AnyValue & other)33   AnyValue(const AnyValue& other) : content_(other.content_->clone()) {}
34   AnyValue& operator=(const AnyValue& other) {
35     content_ = other.content_->clone();
36     return *this;
37   }
38 
39   /// Constructs the `AnyValue` from value type.
40   template <typename T>
41   // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
AnyValue(T && value)42   explicit AnyValue(T&& value)
43       : content_(
44             std::make_unique<Holder<std::decay_t<T>>>(std::forward<T>(value))) {
45   }
46 
47   /// Returns a pointer to the value contained in the `AnyValue` if the type
48   /// passed as template parameter matches the type of the value stored, and
49   /// returns a null pointer otherwise.
50   template <typename T>
try_get()51   T* try_get() {
52     static_assert(
53         !std::is_reference<T>::value,
54         "AnyValue stores decayed types, you cannot cast it to a reference type");
55     static_assert(
56         !std::is_array<T>::value,
57         "AnyValue stores decayed types, you must cast it to T* instead of T[]");
58     if (typeid(T).hash_code() == type_info().hash_code()) {
59       return &static_cast<Holder<T>&>(*content_).value;
60     }
61     return nullptr;
62   }
63 
64   /// Returns the value contained in the `AnyValue` if the type passed as
65   /// template parameter matches the type of the value stored, and throws an
66   /// exception otherwise.
67   template <typename T>
get()68   T get() {
69     if (auto* maybe_value = try_get<T>()) {
70       return *maybe_value;
71     }
72     AT_ERROR(
73         "Attempted to cast AnyValue to ",
74         c10::demangle(typeid(T).name()),
75         ", but its actual type is ",
76         c10::demangle(type_info().name()));
77   }
78 
79   /// Returns the `type_info` object of the contained value.
type_info()80   const std::type_info& type_info() const noexcept {
81     return content_->type_info;
82   }
83 
84  private:
85   friend struct AnyModulePlaceholder;
86   friend struct TestAnyValue;
87 
88   /// \internal
89   /// The static type of the object we store in the `AnyValue`, which erases the
90   /// actual object's type, allowing us only to check the `type_info` of the
91   /// type stored in the dynamic type.
92   struct Placeholder {
PlaceholderPlaceholder93     explicit Placeholder(const std::type_info& type_info_) noexcept
94         : type_info(type_info_) {}
95     Placeholder(const Placeholder&) = default;
96     Placeholder(Placeholder&&) = default;
97     virtual ~Placeholder() = default;
clonePlaceholder98     virtual std::unique_ptr<Placeholder> clone() const {
99       TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`");
100     }
101     const std::type_info& type_info;
102   };
103 
104   /// \internal
105   /// The dynamic type of the object we store in the `AnyValue`, which hides the
106   /// actual object we have erased in this `AnyValue`.
107   template <typename T>
108   struct Holder : public Placeholder {
109     /// A template because T&& would not be universal reference here.
110     template <typename U>
111     // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
HolderHolder112     explicit Holder(U&& value_) noexcept
113         : Placeholder(typeid(T)), value(std::forward<U>(value_)) {}
cloneHolder114     std::unique_ptr<Placeholder> clone() const override {
115       return std::make_unique<Holder<T>>(value);
116     }
117     T value;
118   };
119 
120   /// The type erased object.
121   std::unique_ptr<Placeholder> content_;
122 };
123 
124 } // namespace nn
125 } // namespace torch
126