1 #pragma once 2 3 #include <ATen/core/functional.h> 4 #include <ATen/core/ivalue.h> 5 #include <torch/csrc/jit/api/method.h> 6 #include <optional> 7 8 #include <utility> 9 10 namespace torch::jit { 11 12 struct Resolver; 13 using ResolverPtr = std::shared_ptr<Resolver>; 14 15 using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>; 16 17 // Throw this in C++ land if `attr` fails. This will be converted to a Python 18 // AttributeError by the Python binding code 19 class ObjectAttributeError : public std::runtime_error { 20 public: ObjectAttributeError(const std::string & what)21 ObjectAttributeError(const std::string& what) : std::runtime_error(what) {} 22 }; 23 24 struct TORCH_API Object { 25 Object() = default; 26 Object(const Object&) = default; 27 Object& operator=(const Object&) = default; 28 Object(Object&&) noexcept = default; 29 Object& operator=(Object&&) noexcept = default; ObjectObject30 Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {} 31 Object(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type); 32 Object( 33 c10::QualifiedName, 34 std::shared_ptr<CompilationUnit> cu, 35 bool shouldMangle = false); 36 _ivalueObject37 ObjectPtr _ivalue() const { 38 TORCH_INTERNAL_ASSERT(_ivalue_); 39 return _ivalue_; 40 } 41 typeObject42 c10::ClassTypePtr type() const { 43 return _ivalue()->type(); 44 } 45 46 struct Property { 47 std::string name; 48 Method getter_func; 49 std::optional<Method> setter_func; 50 }; 51 setattrObject52 void setattr(const std::string& name, c10::IValue v) { 53 if (_ivalue()->type()->hasConstant(name)) { 54 TORCH_CHECK( 55 false, 56 "Can't set constant '", 57 name, 58 "' which has value:", 59 _ivalue()->type()->getConstant(name)); 60 } else if (auto slot = _ivalue()->type()->findAttributeSlot(name)) { 61 const c10::TypePtr& expected = _ivalue()->type()->getAttribute(*slot); 62 TORCH_CHECK( 63 v.type()->isSubtypeOf(*expected), 64 "Expected a value of type '", 65 expected->repr_str(), 66 "' for field '", 67 name, 68 "', but found '", 69 v.type()->repr_str(), 70 "'"); 71 _ivalue()->setSlot(*slot, std::move(v)); 72 } else { 73 TORCH_CHECK(false, "Module has no attribute '", name, "'"); 74 } 75 } 76 attrObject77 c10::IValue attr(const std::string& name) const { 78 if (auto r = _ivalue()->type()->findAttributeSlot(name)) { 79 return _ivalue()->getSlot(*r); 80 } 81 if (auto r = _ivalue()->type()->findConstantSlot(name)) { 82 return _ivalue()->type()->getConstant(*r); 83 } 84 std::stringstream err; 85 err << _ivalue()->type()->repr_str() << " does not have a field with name '" 86 << name.c_str() << "'"; 87 throw ObjectAttributeError(err.str()); 88 } 89 attrObject90 c10::IValue attr(const std::string& name, c10::IValue or_else) const { 91 if (auto r = _ivalue()->type()->findAttributeSlot(name)) { 92 return _ivalue()->getSlot(*r); 93 } 94 if (auto r = _ivalue()->type()->findConstantSlot(name)) { 95 return _ivalue()->type()->getConstant(*r); 96 } 97 return or_else; 98 } 99 hasattrObject100 bool hasattr(const std::string& name) const { 101 return _ivalue()->type()->hasAttribute(name) || 102 _ivalue()->type()->hasConstant(name); 103 } 104 105 // each object owns its methods. The reference returned here 106 // is guaranteed to stay valid until this module has been destroyed get_methodObject107 Method get_method(const std::string& name) const { 108 if (auto method = find_method(name)) { 109 return *method; 110 } 111 AT_ERROR("Method '", name, "' is not defined."); 112 } 113 get_methodsObject114 const std::vector<Method> get_methods() const { 115 return c10::fmap(type()->methods(), [&](Function* func) { 116 return Method(_ivalue(), func); 117 }); 118 } 119 has_propertyObject120 bool has_property(const std::string& name) const { 121 for (const auto& prop : type()->properties()) { 122 if (prop.name == name) { 123 return true; 124 } 125 } 126 return false; 127 } 128 get_propertyObject129 const Property get_property(const std::string& name) const { 130 for (const auto& prop : type()->properties()) { 131 if (prop.name == name) { 132 std::optional<Method> setter = std::nullopt; 133 if (prop.setter) { 134 setter = Method(_ivalue(), prop.setter); 135 } 136 return Property{ 137 prop.name, Method(_ivalue(), prop.getter), std::move(setter)}; 138 } 139 } 140 AT_ERROR("Property '", name, "' is not defined."); 141 } 142 get_propertiesObject143 const std::vector<Property> get_properties() const { 144 return c10::fmap(type()->properties(), [&](ClassType::Property prop) { 145 std::optional<Method> setter = std::nullopt; 146 if (prop.setter) { 147 setter = Method(_ivalue(), prop.setter); 148 } 149 return Property{ 150 std::move(prop.name), 151 Method(_ivalue(), prop.getter), 152 std::move(setter)}; 153 }); 154 } 155 156 std::optional<Method> find_method(const std::string& basename) const; 157 158 /// Run a method from this module. 159 /// 160 /// For example: 161 /// @code 162 /// IValue output = module->run("relu_script", a, b); 163 /// @endcode 164 /// 165 /// To get a compile a module from a source string, see torch::jit::compile 166 /// 167 /// @param method_name The name of the method to run 168 /// @param args Arguments to be passed to the method 169 /// @return An IValue containing the return value (or values if it is a tuple) 170 /// from the method 171 template <typename... Types> run_methodObject172 IValue run_method(const std::string& method_name, Types&&... args) { 173 return get_method(method_name)({IValue(std::forward<Types>(args))...}); 174 } 175 176 // so that C++ users can easily add methods 177 void define(const std::string& src, const ResolverPtr& resolver = nullptr); 178 num_slotsObject179 size_t num_slots() const { 180 return _ivalue()->slots().size(); 181 } 182 183 // shallow copy the object 184 Object copy() const; 185 186 // Copies all the attributes of the object recursively without creating new 187 // `ClassType`, including deepcopy of Tensors 188 Object deepcopy() const; 189 190 private: 191 // mutable be we lazily initialize in module_object. 192 mutable ObjectPtr _ivalue_; 193 }; 194 195 namespace script { 196 // We once had a `script::` namespace that was deleted. This is for backcompat 197 // of the public API; new code should not use this type alias. 198 using Object = ::torch::jit::Object; 199 } // namespace script 200 } // namespace torch::jit 201