1 #pragma once 2 3 #include <memory> 4 5 #include <ATen/core/ivalue.h> 6 #include <ATen/core/jit_type_base.h> 7 #include <optional> 8 9 10 namespace torch::jit { 11 struct CompilationUnit; 12 struct Function; 13 } // namespace torch::jit 14 15 16 namespace c10 { 17 18 struct FunctionSchema; 19 20 // This enumerator represents the 'kind' of an attribute - a buffer, a parameter, or neither. 21 // This state is mutually exclusive. Buffers and Parameters can only appear on modules. 22 enum class AttributeKind { 23 BUFFER, 24 PARAMETER, 25 REGULAR_ATTRIBUTE 26 }; 27 28 // This structure represents all notional booking entities in a class attribute: name, kind (see: AttributeKind), and type (see: TypePtr). 29 // Note: This structure does not represent the value of the attribute. 30 struct TORCH_API ClassAttribute { 31 public: ClassAttributeClassAttribute32 ClassAttribute(AttributeKind kind, 33 TypePtr attributeType, 34 std::string attributeName) : 35 kind_(kind), 36 attributeType_(std::move(attributeType)), 37 attributeName_(std::move(attributeName)) {} 38 getKindClassAttribute39 AttributeKind getKind() const { 40 return kind_; 41 } 42 getTypeClassAttribute43 const TypePtr& getType() const { 44 return attributeType_; 45 } 46 getNameClassAttribute47 const std::string& getName() const { 48 return attributeName_; 49 } 50 51 private: 52 AttributeKind kind_; 53 TypePtr attributeType_; 54 std::string attributeName_; 55 }; 56 57 /** 58 * User Defined Types 59 */ 60 61 struct ClassType; 62 using ClassTypePtr = std::shared_ptr<ClassType>; 63 using ::torch::jit::CompilationUnit; 64 65 // This represents a class in TorchScript. 66 struct TORCH_API ClassType : public NamedType { 67 // This represents an attribute of a class; a name associated with an attribute, and a 68 // getter and (optional) setter for that attribute. 69 struct Property { 70 std::string name; 71 torch::jit::Function* getter; 72 torch::jit::Function* setter; 73 }; 74 75 // Create a class type with name `name` and its methods stored in `cu`. 76 static ClassTypePtr create( 77 std::optional<QualifiedName> qualifiedName, 78 std::weak_ptr<CompilationUnit> cu, 79 bool is_module = false, 80 std::string doc_string = "", 81 std::vector<std::string> unresolved_class_attributes = {}); 82 equalsClassType83 bool equals(const Type& rhs) const override { 84 if (this == &rhs) { 85 return true; 86 } 87 if (auto user_rhs = rhs.castRaw<ClassType>()) { 88 const auto& lhs_name = name().value(); 89 const auto& rhs_name = user_rhs->name().value(); 90 91 return lhs_name == rhs_name && 92 this->compilation_unit() == user_rhs->compilation_unit(); 93 } 94 return false; 95 } 96 strClassType97 std::string str() const override { 98 return annotation_str(); 99 } 100 repr_strClassType101 std::string repr_str() const override { 102 std::stringstream ss; 103 ss << str() 104 << " (of Python compilation unit at: " << compilation_unit().get() << ")"; 105 return ss.str(); 106 } 107 108 const std::vector<torch::jit::Function*>& methods() const; 109 findAttributeClassType110 TypePtr findAttribute(const std::string& name) const { 111 size_t pos = 0; 112 for (const auto& attr : attributes_) { 113 if (name == attr.getName()) { 114 break; 115 } 116 ++pos; 117 } 118 119 if (pos >= attributes_.size()) { 120 return nullptr; 121 } 122 return attributes_[pos].getType(); 123 } 124 getAttributeClassType125 const TypePtr& getAttribute(const std::string& name) const { 126 auto slot = findAttributeSlot(name); 127 TORCH_CHECK( 128 slot, 129 repr_str(), 130 " does not have an attribute with name '", 131 name, 132 "'"); 133 return attributes_[*slot].getType(); 134 } 135 numAttributesClassType136 size_t numAttributes() const { 137 return attributes_.size(); 138 } 139 getAttributeClassType140 const TypePtr& getAttribute(size_t slot) const { 141 AT_ASSERT(slot < attributes_.size()); 142 return attributes_.at(slot).getType(); 143 } 144 getAttributeNameClassType145 const std::string getAttributeName(size_t slot) const { 146 AT_ASSERT(slot < attributes_.size()); 147 return attributes_[slot].getName(); 148 } 149 150 void checkNotExist(const std::string& name, const std::string& what) const; 151 152 // Attributes are stored in a specific slot at runtime for effiency. 153 // When emitting instructions we specify the slot so that attribute access is 154 // a constant lookup findAttributeSlotClassType155 std::optional<size_t> findAttributeSlot(const std::string& name) const { 156 size_t slot = 0; 157 for (const auto& attr : attributes_) { 158 if (name == attr.getName()) { 159 return slot; 160 } 161 slot++; 162 } 163 return std::nullopt; 164 } getAttributeSlotClassType165 size_t getAttributeSlot(const std::string& name) const { 166 if (auto r = findAttributeSlot(name)) { 167 return *r; 168 } 169 TORCH_CHECK( 170 false, 171 repr_str(), 172 " does not have an attribute with name '", 173 name, 174 "'"); 175 } 176 hasAttributeClassType177 bool hasAttribute(const std::string& name) const { 178 return std::find_if( 179 attributes_.cbegin(), 180 attributes_.cend(), 181 [&](const ClassAttribute& attr) { return attr.getName() == name; }) != 182 attributes_.cend(); 183 } 184 185 bool isUnresolvedClassAttribute(const std::string& name) const; 186 containedTypesClassType187 at::ArrayRef<TypePtr> containedTypes() const override { 188 return attributeTypes_; 189 } 190 191 size_t addAttribute( 192 const std::string& name, 193 TypePtr type, 194 bool is_parameter = false, 195 bool is_buffer = false); 196 197 // [Internal Only] Remove attribute from the ClassType, 198 // caller is responsible to make sure the modification is safe: 199 // it is unsafe to having existing allocations 200 // of this object around anymore, and any code that works on 201 // the attribute is now invalid. Only newly created code is 202 // valid again. 203 void unsafeRemoveAttribute(const std::string& name); 204 205 // [Internal Only] Change the type of an attribute of the ClassType, 206 // The caller is responsible to make sure the modification is safe: 207 // it is unsafe to maintain uses of the old type of the attribute, 208 // and any code that works on the attribute is now invalid. 209 // Only newly created code is valid again. 210 void unsafeChangeAttributeType(const std::string& name, const TypePtr& new_ty); 211 212 // Add attribute \p NAME if it doesn't exist or verify that it has a 213 // compatible type otherwise. 214 size_t addOrCheckAttribute( 215 const std::string& name, 216 TypePtr ty, 217 bool is_parameter = false, 218 bool is_buffer = false) { 219 auto slot_idx = findAttributeSlot(name); 220 if (!slot_idx) { 221 return addAttribute(name, std::move(ty), is_parameter, is_buffer); 222 } 223 224 TORCH_CHECK( 225 is_parameter == this->is_parameter(*slot_idx), 226 "Parameter field mismatch for the field '", 227 name, 228 "'"); 229 const TypePtr& atype = getAttribute(*slot_idx); 230 TORCH_CHECK( 231 ty->isSubtypeOf(*atype), 232 ty->repr_str(), 233 " is not compatible with the type ", 234 atype->repr_str(), 235 " for the field '", 236 name, 237 "'"); 238 return *slot_idx; 239 } 240 241 // Get the property with the given \p name, if it exists on the class. 242 std::optional<ClassType::Property> getProperty(const std::string& name); 243 // Add a property named \p name with \p getter and \p setter as its getter and setter. 244 void addProperty(const std::string& name, torch::jit::Function* getter, torch::jit::Function* setter); 245 // Get a list of all properties. propertiesClassType246 const std::vector<Property>& properties() const { 247 return properties_; 248 } 249 hasConstantClassType250 bool hasConstant(const std::string& name) const { 251 return std::find_if( 252 constantNames_.cbegin(), 253 constantNames_.cend(), 254 [&](const std::string& constant) { return constant == name; }) != 255 constantNames_.cend(); 256 } 257 258 size_t addConstant(const std::string& name, const IValue& value); 259 260 std::optional<size_t> findConstantSlot(const std::string& name) const; 261 getConstantSlotClassType262 size_t getConstantSlot(const std::string& name) const { 263 if (auto r = findConstantSlot(name)) { 264 return *r; 265 } 266 TORCH_CHECK( 267 false, 268 repr_str(), 269 " does not have constant field with the name '", 270 name, 271 "'"); 272 } 273 274 const std::string& getConstantName(size_t slot) const; 275 doc_stringClassType276 const std::string& doc_string() const { 277 return doc_string_; 278 } 279 280 IValue getConstant(const std::string& name) const; 281 282 IValue getConstant(size_t slot) const; 283 284 std::optional<IValue> findConstant(const std::string& name) const; 285 286 size_t numConstants() const; 287 constantNamesClassType288 at::ArrayRef<std::string> constantNames() const { 289 return constantNames_; 290 } 291 292 at::ArrayRef<IValue> constantValues() const; 293 294 // [Internal Only] Remove constant from the ClassType 295 // caller is responsible to make sure the modification is safe: 296 // it is unsafe to having existing allocations 297 // of this object around anymore, and any code that works on 298 // the attribute is now invalid. Only newly created code is 299 // valid again. 300 void unsafeRemoveConstant(const std::string& name); 301 createWithContainedClassType302 TypePtr createWithContained(std::vector<TypePtr> contained_types) const override { 303 auto ptr = ClassType::create(name(), compilation_unit_, is_module()); 304 AT_ASSERT(numAttributes() == contained_types.size()); 305 for(size_t i = 0; i < attributes_.size(); ++i) { 306 AT_ASSERT(attributes_[i].getType()->isSubtypeOf(*contained_types[i])); 307 ptr->addAttribute(attributes_[i].getName(), std::move(contained_types[i])); 308 } 309 // Copy methods over 310 for (const auto& method : methods()) { 311 ptr->addMethod(method); 312 } 313 return ptr; 314 } 315 is_moduleClassType316 bool is_module() const override { 317 return isModule_; 318 } 319 getAttributesClassType320 const std::vector<ClassAttribute>& getAttributes() const { 321 return attributes_; 322 } 323 is_parameterClassType324 bool is_parameter(size_t slot) const { 325 TORCH_INTERNAL_ASSERT( 326 is_module(), "asking for parameterSlots of non-Module"); 327 return attributes_.at(slot).getKind() == AttributeKind::PARAMETER; 328 } 329 is_bufferClassType330 bool is_buffer(size_t slot) const { 331 TORCH_INTERNAL_ASSERT( 332 is_module(), "asking for bufferWrittenSlots of non-Module"); 333 return attributes_.at(slot).getKind() == AttributeKind::BUFFER; 334 } 335 336 void addForwardPreHook(torch::jit::Function* pre_hook_ptr); 337 void addForwardHook(torch::jit::Function* hook_ptr); 338 torch::jit::Function* findForwardPreHook(const std::string& name) const; 339 torch::jit::Function* findForwardHook(const std::string& name) const; 340 const std::vector<torch::jit::Function*>& getForwardHooks() const; 341 const std::vector<torch::jit::Function*>& getForwardPreHooks() const; 342 343 void checkForwardPreHookSchema( 344 size_t pre_hook_idx, 345 const FunctionSchema& pre_hook_schema) const; 346 void checkForwardHookSchema( 347 size_t hook_idx, 348 const FunctionSchema& hook_schema) const; 349 350 void addMethod(torch::jit::Function* method); 351 torch::jit::Function* findMethod(const std::string& name) const; 352 torch::jit::Function& getMethod(const std::string& name) const; 353 torch::jit::Function* findHook(const std::string& name) const; 354 torch::jit::Function& getHook(const std::string& name) const; 355 bool hasMethod(const std::string& name) const; 356 357 torch::jit::Function* findStaticMethod(const std::string& name) const; 358 void addStaticMethod(torch::jit::Function* method); 359 360 // [Internal Only] Remove method from the ClassType 361 // caller is responsible to make sure the modification is safe: 362 // it is unsafe to having existing allocations 363 // of this object around anymore, and any code that works on 364 // the attribute is now invalid. Only newly created code is 365 // valid again. 366 // Note this method is intended for freezing only. 367 void unsafeRemoveMethod(const std::string& name); 368 369 std::shared_ptr<CompilationUnit> compilation_unit(); 370 371 std::shared_ptr<const CompilationUnit> compilation_unit() const; 372 373 // generate a refined version of this class. 374 // It has the same name but the slot Types are subtypes of 375 // the original slots. It is only valid to refine a class type in a context 376 // where it is know that there are not assignments to the objects slots 377 // that would invalidate the refinement. 378 // These variants are not registered in the global class table. 379 ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const; 380 381 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; 382 383 static const TypeKind Kind = TypeKind::ClassType; 384 385 private: 386 ClassType( 387 std::optional<QualifiedName> name, 388 std::weak_ptr<CompilationUnit> cu, 389 bool is_module = false, 390 std::string doc_string = "", 391 std::vector<std::string> unresolved_class_attributes = {}); 392 393 std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override { 394 const auto& n = name().value(); 395 return n.qualifiedName(); 396 } 397 398 void addAttribute(ClassAttribute classAttribute); 399 std::string getForwardPreHookErrorMessage(size_t pre_hook_idx) const; 400 std::string getForwardHookErrorMessage(size_t hook_idx) const; 401 402 // Mapping of attribute names -> their type. 403 // NOTE: this does not contain methods, which are stored in the module 404 // TODO: once modules support arbitrary ivalue attributes, we don't need this 405 // anymore. 406 // TODO: This is better represented as an OrderedDict, but alas it is not yet 407 // available from c10 408 409 // Mapping of constant names -> their value. 410 std::vector<std::string> constantNames_; 411 std::vector<IValue> constantValues_; 412 // Holds method attributes 413 std::weak_ptr<CompilationUnit> compilation_unit_; 414 415 // Holds all atrributes, attribute details are found on ClassAttribute 416 std::vector<ClassAttribute> attributes_; 417 // Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef. 418 // Never fill this without using the appropriate provideNewClassAttribute method 419 std::vector<TypePtr> attributeTypes_; 420 421 // List of methods associated with this class. 422 std::vector<torch::jit::Function*> methods_; 423 std::vector<torch::jit::Function*> staticmethods_; 424 425 // List of hooks to be run before/after forward. 426 std::vector<torch::jit::Function*> forward_hooks_; 427 std::vector<torch::jit::Function*> forward_pre_hooks_; 428 429 // List of properties exposed by this class. 430 std::vector<Property> properties_; 431 432 bool isModule_ = false; 433 434 // Doc string of class. 435 std::string doc_string_ = ""; 436 437 // For error reporting accesses to class level attributes. 438 std::vector<std::string> unresolved_class_attributes_; 439 }; 440 441 } 442