xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/api/object.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/functional.h>
4 #include <ATen/core/ivalue.h>
5 #include <torch/csrc/jit/api/method.h>
6 #include <optional>
7 
8 #include <utility>
9 
10 namespace torch::jit {
11 
12 struct Resolver;
13 using ResolverPtr = std::shared_ptr<Resolver>;
14 
15 using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
16 
// Throw this in C++ land if `attr` fails. The Python binding code converts it
// into a Python AttributeError.
class ObjectAttributeError : public std::runtime_error {
 public:
  // Explicit so that a std::string never silently converts into an exception
  // object (e.g. when passed to a function taking ObjectAttributeError).
  explicit ObjectAttributeError(const std::string& what)
      : std::runtime_error(what) {}

  // Also accept C strings (and any other std::runtime_error constructor)
  // without first materializing a temporary std::string.
  using std::runtime_error::runtime_error;
};
23 
24 struct TORCH_API Object {
25   Object() = default;
26   Object(const Object&) = default;
27   Object& operator=(const Object&) = default;
28   Object(Object&&) noexcept = default;
29   Object& operator=(Object&&) noexcept = default;
ObjectObject30   Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {}
31   Object(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
32   Object(
33       c10::QualifiedName,
34       std::shared_ptr<CompilationUnit> cu,
35       bool shouldMangle = false);
36 
_ivalueObject37   ObjectPtr _ivalue() const {
38     TORCH_INTERNAL_ASSERT(_ivalue_);
39     return _ivalue_;
40   }
41 
typeObject42   c10::ClassTypePtr type() const {
43     return _ivalue()->type();
44   }
45 
46   struct Property {
47     std::string name;
48     Method getter_func;
49     std::optional<Method> setter_func;
50   };
51 
setattrObject52   void setattr(const std::string& name, c10::IValue v) {
53     if (_ivalue()->type()->hasConstant(name)) {
54       TORCH_CHECK(
55           false,
56           "Can't set constant '",
57           name,
58           "' which has value:",
59           _ivalue()->type()->getConstant(name));
60     } else if (auto slot = _ivalue()->type()->findAttributeSlot(name)) {
61       const c10::TypePtr& expected = _ivalue()->type()->getAttribute(*slot);
62       TORCH_CHECK(
63           v.type()->isSubtypeOf(*expected),
64           "Expected a value of type '",
65           expected->repr_str(),
66           "' for field '",
67           name,
68           "', but found '",
69           v.type()->repr_str(),
70           "'");
71       _ivalue()->setSlot(*slot, std::move(v));
72     } else {
73       TORCH_CHECK(false, "Module has no attribute '", name, "'");
74     }
75   }
76 
attrObject77   c10::IValue attr(const std::string& name) const {
78     if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
79       return _ivalue()->getSlot(*r);
80     }
81     if (auto r = _ivalue()->type()->findConstantSlot(name)) {
82       return _ivalue()->type()->getConstant(*r);
83     }
84     std::stringstream err;
85     err << _ivalue()->type()->repr_str() << " does not have a field with name '"
86         << name.c_str() << "'";
87     throw ObjectAttributeError(err.str());
88   }
89 
attrObject90   c10::IValue attr(const std::string& name, c10::IValue or_else) const {
91     if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
92       return _ivalue()->getSlot(*r);
93     }
94     if (auto r = _ivalue()->type()->findConstantSlot(name)) {
95       return _ivalue()->type()->getConstant(*r);
96     }
97     return or_else;
98   }
99 
hasattrObject100   bool hasattr(const std::string& name) const {
101     return _ivalue()->type()->hasAttribute(name) ||
102         _ivalue()->type()->hasConstant(name);
103   }
104 
105   // each object owns its methods. The reference returned here
106   // is guaranteed to stay valid until this module has been destroyed
get_methodObject107   Method get_method(const std::string& name) const {
108     if (auto method = find_method(name)) {
109       return *method;
110     }
111     AT_ERROR("Method '", name, "' is not defined.");
112   }
113 
get_methodsObject114   const std::vector<Method> get_methods() const {
115     return c10::fmap(type()->methods(), [&](Function* func) {
116       return Method(_ivalue(), func);
117     });
118   }
119 
has_propertyObject120   bool has_property(const std::string& name) const {
121     for (const auto& prop : type()->properties()) {
122       if (prop.name == name) {
123         return true;
124       }
125     }
126     return false;
127   }
128 
get_propertyObject129   const Property get_property(const std::string& name) const {
130     for (const auto& prop : type()->properties()) {
131       if (prop.name == name) {
132         std::optional<Method> setter = std::nullopt;
133         if (prop.setter) {
134           setter = Method(_ivalue(), prop.setter);
135         }
136         return Property{
137             prop.name, Method(_ivalue(), prop.getter), std::move(setter)};
138       }
139     }
140     AT_ERROR("Property '", name, "' is not defined.");
141   }
142 
get_propertiesObject143   const std::vector<Property> get_properties() const {
144     return c10::fmap(type()->properties(), [&](ClassType::Property prop) {
145       std::optional<Method> setter = std::nullopt;
146       if (prop.setter) {
147         setter = Method(_ivalue(), prop.setter);
148       }
149       return Property{
150           std::move(prop.name),
151           Method(_ivalue(), prop.getter),
152           std::move(setter)};
153     });
154   }
155 
156   std::optional<Method> find_method(const std::string& basename) const;
157 
158   /// Run a method from this module.
159   ///
160   /// For example:
161   /// @code
162   ///   IValue output = module->run("relu_script", a, b);
163   /// @endcode
164   ///
165   /// To get a compile a module from a source string, see torch::jit::compile
166   ///
167   /// @param method_name The name of the method to run
168   /// @param args Arguments to be passed to the method
169   /// @return An IValue containing the return value (or values if it is a tuple)
170   /// from the method
171   template <typename... Types>
run_methodObject172   IValue run_method(const std::string& method_name, Types&&... args) {
173     return get_method(method_name)({IValue(std::forward<Types>(args))...});
174   }
175 
176   // so that C++ users can easily add methods
177   void define(const std::string& src, const ResolverPtr& resolver = nullptr);
178 
num_slotsObject179   size_t num_slots() const {
180     return _ivalue()->slots().size();
181   }
182 
183   // shallow copy the object
184   Object copy() const;
185 
186   // Copies all the attributes of the object recursively without creating new
187   // `ClassType`, including deepcopy of Tensors
188   Object deepcopy() const;
189 
190  private:
191   // mutable be we lazily initialize in module_object.
192   mutable ObjectPtr _ivalue_;
193 };
194 
195 namespace script {
196 // Backward-compatibility shim: a `script::` namespace used to exist and was
197 // deleted. This alias keeps the old public spelling `script::Object` working;
// new code should use `torch::jit::Object` directly.
198 using Object = ::torch::jit::Object;
199 } // namespace script
200 } // namespace torch::jit
201