// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/class_type.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <algorithm>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type_base.h>
8 
9 
// Forward declarations of JIT runtime types. This header only refers to them
// through pointers, references, and smart pointers, so the full definitions
// are not needed here.
namespace torch {
namespace jit {
struct CompilationUnit;
struct Function;
} // namespace jit
} // namespace torch
14 
15 
16 namespace c10 {
17 
struct FunctionSchema;

// Classifies a class attribute as a buffer, a parameter, or neither.
// The kinds are mutually exclusive; BUFFER and PARAMETER may only appear on
// module types.
enum class AttributeKind {
  BUFFER,            // a buffer (modules only)
  PARAMETER,         // a parameter (modules only)
  REGULAR_ATTRIBUTE, // any other attribute
};
27 
28 // This structure represents all notional booking entities in a class attribute: name, kind (see: AttributeKind), and type (see: TypePtr).
29 // Note: This structure does not represent the value of the attribute.
30 struct TORCH_API ClassAttribute {
31   public:
ClassAttributeClassAttribute32   ClassAttribute(AttributeKind kind,
33   TypePtr attributeType,
34   std::string attributeName) :
35     kind_(kind),
36     attributeType_(std::move(attributeType)),
37     attributeName_(std::move(attributeName)) {}
38 
getKindClassAttribute39   AttributeKind getKind() const {
40     return kind_;
41   }
42 
getTypeClassAttribute43   const TypePtr& getType() const {
44     return attributeType_;
45   }
46 
getNameClassAttribute47   const std::string& getName() const {
48     return attributeName_;
49   }
50 
51   private:
52   AttributeKind kind_;
53   TypePtr attributeType_;
54   std::string attributeName_;
55 };
56 
57 /**
58  * User Defined Types
59  */
60 
61 struct ClassType;
62 using ClassTypePtr = std::shared_ptr<ClassType>;
63 using ::torch::jit::CompilationUnit;
64 
65 // This represents a class in TorchScript.
66 struct TORCH_API ClassType : public NamedType {
67   // This represents an attribute of a class; a name associated with an attribute, and a
68   // getter and (optional) setter for that attribute.
69   struct Property {
70     std::string name;
71     torch::jit::Function* getter;
72     torch::jit::Function* setter;
73   };
74 
75   // Create a class type with name `name` and its methods stored in `cu`.
76   static ClassTypePtr create(
77       std::optional<QualifiedName> qualifiedName,
78       std::weak_ptr<CompilationUnit> cu,
79       bool is_module = false,
80       std::string doc_string = "",
81       std::vector<std::string> unresolved_class_attributes = {});
82 
equalsClassType83   bool equals(const Type& rhs) const override {
84     if (this == &rhs) {
85       return true;
86     }
87     if (auto user_rhs = rhs.castRaw<ClassType>()) {
88       const auto& lhs_name = name().value();
89       const auto& rhs_name = user_rhs->name().value();
90 
91       return lhs_name == rhs_name &&
92           this->compilation_unit() == user_rhs->compilation_unit();
93     }
94     return false;
95   }
96 
strClassType97   std::string str() const override {
98      return annotation_str();
99   }
100 
repr_strClassType101   std::string repr_str() const override {
102     std::stringstream ss;
103     ss << str()
104        << " (of Python compilation unit at: " << compilation_unit().get() << ")";
105     return ss.str();
106   }
107 
108   const std::vector<torch::jit::Function*>& methods() const;
109 
findAttributeClassType110   TypePtr findAttribute(const std::string& name) const {
111     size_t pos = 0;
112     for (const auto& attr : attributes_) {
113       if (name == attr.getName()) {
114         break;
115       }
116       ++pos;
117     }
118 
119     if (pos >= attributes_.size()) {
120       return nullptr;
121     }
122     return attributes_[pos].getType();
123   }
124 
getAttributeClassType125   const TypePtr& getAttribute(const std::string& name) const {
126     auto slot = findAttributeSlot(name);
127     TORCH_CHECK(
128         slot,
129         repr_str(),
130         " does not have an attribute with name '",
131         name,
132         "'");
133     return attributes_[*slot].getType();
134   }
135 
numAttributesClassType136   size_t numAttributes() const {
137     return attributes_.size();
138   }
139 
getAttributeClassType140   const TypePtr& getAttribute(size_t slot) const {
141     AT_ASSERT(slot < attributes_.size());
142     return attributes_.at(slot).getType();
143   }
144 
getAttributeNameClassType145   const std::string getAttributeName(size_t slot) const {
146     AT_ASSERT(slot < attributes_.size());
147     return attributes_[slot].getName();
148   }
149 
150   void checkNotExist(const std::string& name, const std::string& what) const;
151 
152   // Attributes are stored in a specific slot at runtime for effiency.
153   // When emitting instructions we specify the slot so that attribute access is
154   // a constant lookup
findAttributeSlotClassType155   std::optional<size_t> findAttributeSlot(const std::string& name) const {
156     size_t slot = 0;
157     for (const auto& attr : attributes_) {
158       if (name == attr.getName()) {
159         return slot;
160       }
161       slot++;
162     }
163     return std::nullopt;
164   }
getAttributeSlotClassType165   size_t getAttributeSlot(const std::string& name) const {
166     if (auto r = findAttributeSlot(name)) {
167       return *r;
168     }
169     TORCH_CHECK(
170         false,
171         repr_str(),
172         " does not have an attribute with name '",
173         name,
174         "'");
175   }
176 
hasAttributeClassType177   bool hasAttribute(const std::string& name) const {
178     return std::find_if(
179                attributes_.cbegin(),
180                attributes_.cend(),
181                [&](const ClassAttribute& attr) { return attr.getName() == name; }) !=
182         attributes_.cend();
183   }
184 
185   bool isUnresolvedClassAttribute(const std::string& name) const;
186 
containedTypesClassType187   at::ArrayRef<TypePtr> containedTypes() const override {
188     return attributeTypes_;
189   }
190 
191   size_t addAttribute(
192       const std::string& name,
193       TypePtr type,
194       bool is_parameter = false,
195       bool is_buffer = false);
196 
197   // [Internal Only] Remove attribute from the ClassType,
198   // caller is responsible to make sure the modification is safe:
199   // it is unsafe to having existing allocations
200   // of this object around anymore, and any code that works on
201   // the attribute is now invalid. Only newly created code is
202   // valid again.
203   void unsafeRemoveAttribute(const std::string& name);
204 
205   // [Internal Only] Change the type of an attribute of the ClassType,
206   // The caller is responsible to make sure the modification is safe:
207   // it is unsafe to maintain uses of the old type of the attribute,
208   // and any code that works on the attribute is now invalid.
209   // Only newly created code is valid again.
210   void unsafeChangeAttributeType(const std::string& name, const TypePtr& new_ty);
211 
212   // Add attribute \p NAME if it doesn't exist or verify that it has a
213   // compatible type otherwise.
214   size_t addOrCheckAttribute(
215       const std::string& name,
216       TypePtr ty,
217       bool is_parameter = false,
218       bool is_buffer = false) {
219     auto slot_idx = findAttributeSlot(name);
220     if (!slot_idx) {
221       return addAttribute(name, std::move(ty), is_parameter, is_buffer);
222     }
223 
224     TORCH_CHECK(
225         is_parameter == this->is_parameter(*slot_idx),
226         "Parameter field mismatch for the field '",
227         name,
228         "'");
229     const TypePtr& atype = getAttribute(*slot_idx);
230     TORCH_CHECK(
231       ty->isSubtypeOf(*atype),
232       ty->repr_str(),
233       " is not compatible with the type ",
234       atype->repr_str(),
235       " for the field '",
236       name,
237       "'");
238     return *slot_idx;
239   }
240 
241   // Get the property with the given \p name, if it exists on the class.
242   std::optional<ClassType::Property> getProperty(const std::string& name);
243   // Add a property named \p name with \p getter and \p setter as its getter and setter.
244   void addProperty(const std::string& name, torch::jit::Function* getter, torch::jit::Function* setter);
245   // Get a list of all properties.
propertiesClassType246   const std::vector<Property>& properties() const {
247     return properties_;
248   }
249 
hasConstantClassType250   bool hasConstant(const std::string& name) const {
251     return std::find_if(
252                constantNames_.cbegin(),
253                constantNames_.cend(),
254                [&](const std::string& constant) { return constant == name; }) !=
255         constantNames_.cend();
256   }
257 
258   size_t addConstant(const std::string& name, const IValue& value);
259 
260   std::optional<size_t> findConstantSlot(const std::string& name) const;
261 
getConstantSlotClassType262   size_t getConstantSlot(const std::string& name) const {
263     if (auto r = findConstantSlot(name)) {
264       return *r;
265     }
266     TORCH_CHECK(
267         false,
268         repr_str(),
269         " does not have constant field with the name '",
270         name,
271         "'");
272   }
273 
274   const std::string& getConstantName(size_t slot) const;
275 
doc_stringClassType276   const std::string& doc_string() const {
277     return doc_string_;
278   }
279 
280   IValue getConstant(const std::string& name) const;
281 
282   IValue getConstant(size_t slot) const;
283 
284   std::optional<IValue> findConstant(const std::string& name) const;
285 
286   size_t numConstants() const;
287 
constantNamesClassType288   at::ArrayRef<std::string> constantNames() const {
289     return constantNames_;
290   }
291 
292   at::ArrayRef<IValue> constantValues() const;
293 
294   // [Internal Only] Remove constant from the ClassType
295   // caller is responsible to make sure the modification is safe:
296   // it is unsafe to having existing allocations
297   // of this object around anymore, and any code that works on
298   // the attribute is now invalid. Only newly created code is
299   // valid again.
300   void unsafeRemoveConstant(const std::string& name);
301 
createWithContainedClassType302   TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
303     auto ptr = ClassType::create(name(), compilation_unit_, is_module());
304     AT_ASSERT(numAttributes() == contained_types.size());
305     for(size_t i = 0; i < attributes_.size(); ++i) {
306       AT_ASSERT(attributes_[i].getType()->isSubtypeOf(*contained_types[i]));
307       ptr->addAttribute(attributes_[i].getName(), std::move(contained_types[i]));
308     }
309     // Copy methods over
310     for (const auto& method : methods()) {
311       ptr->addMethod(method);
312     }
313     return ptr;
314   }
315 
is_moduleClassType316   bool is_module() const override {
317     return isModule_;
318   }
319 
getAttributesClassType320   const std::vector<ClassAttribute>& getAttributes() const {
321     return attributes_;
322   }
323 
is_parameterClassType324   bool is_parameter(size_t slot) const {
325     TORCH_INTERNAL_ASSERT(
326         is_module(), "asking for parameterSlots of non-Module");
327     return attributes_.at(slot).getKind() == AttributeKind::PARAMETER;
328   }
329 
is_bufferClassType330   bool is_buffer(size_t slot) const {
331     TORCH_INTERNAL_ASSERT(
332         is_module(), "asking for bufferWrittenSlots of non-Module");
333     return attributes_.at(slot).getKind() == AttributeKind::BUFFER;
334   }
335 
336   void addForwardPreHook(torch::jit::Function* pre_hook_ptr);
337   void addForwardHook(torch::jit::Function* hook_ptr);
338   torch::jit::Function* findForwardPreHook(const std::string& name) const;
339   torch::jit::Function* findForwardHook(const std::string& name) const;
340   const std::vector<torch::jit::Function*>& getForwardHooks() const;
341   const std::vector<torch::jit::Function*>& getForwardPreHooks() const;
342 
343   void checkForwardPreHookSchema(
344       size_t pre_hook_idx,
345       const FunctionSchema& pre_hook_schema) const;
346   void checkForwardHookSchema(
347       size_t hook_idx,
348       const FunctionSchema& hook_schema) const;
349 
350   void addMethod(torch::jit::Function* method);
351   torch::jit::Function* findMethod(const std::string& name) const;
352   torch::jit::Function& getMethod(const std::string& name) const;
353   torch::jit::Function* findHook(const std::string& name) const;
354   torch::jit::Function& getHook(const std::string& name) const;
355   bool hasMethod(const std::string& name) const;
356 
357   torch::jit::Function* findStaticMethod(const std::string& name) const;
358   void addStaticMethod(torch::jit::Function* method);
359 
360   // [Internal Only] Remove method from the ClassType
361   // caller is responsible to make sure the modification is safe:
362   // it is unsafe to having existing allocations
363   // of this object around anymore, and any code that works on
364   // the attribute is now invalid. Only newly created code is
365   // valid again.
366   // Note this method is intended for freezing only.
367   void unsafeRemoveMethod(const std::string& name);
368 
369   std::shared_ptr<CompilationUnit> compilation_unit();
370 
371   std::shared_ptr<const CompilationUnit> compilation_unit() const;
372 
373   // generate a refined version of this class.
374   // It has the same name but the slot Types are subtypes of
375   // the original slots. It is only valid to refine a class type in a context
376   // where it is know that there are not assignments to the objects slots
377   // that would invalidate the refinement.
378   // These variants are not registered in the global class table.
379   ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const;
380 
381   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
382 
383   static const TypeKind Kind = TypeKind::ClassType;
384 
385  private:
386   ClassType(
387       std::optional<QualifiedName> name,
388       std::weak_ptr<CompilationUnit> cu,
389       bool is_module = false,
390       std::string doc_string = "",
391       std::vector<std::string> unresolved_class_attributes = {});
392 
393   std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override {
394     const auto& n = name().value();
395     return n.qualifiedName();
396   }
397 
398   void addAttribute(ClassAttribute classAttribute);
399   std::string getForwardPreHookErrorMessage(size_t pre_hook_idx) const;
400   std::string getForwardHookErrorMessage(size_t hook_idx) const;
401 
402   // Mapping of attribute names -> their type.
403   // NOTE: this does not contain methods, which are stored in the module
404   // TODO: once modules support arbitrary ivalue attributes, we don't need this
405   // anymore.
406   // TODO: This is better represented as an OrderedDict, but alas it is not yet
407   // available from c10
408 
409   // Mapping of constant names -> their value.
410   std::vector<std::string> constantNames_;
411   std::vector<IValue> constantValues_;
412   // Holds method attributes
413   std::weak_ptr<CompilationUnit> compilation_unit_;
414 
415   // Holds all atrributes, attribute details are found on ClassAttribute
416   std::vector<ClassAttribute> attributes_;
417   // Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef.
418   // Never fill this without using the appropriate provideNewClassAttribute method
419   std::vector<TypePtr> attributeTypes_;
420 
421   // List of methods associated with this class.
422   std::vector<torch::jit::Function*> methods_;
423   std::vector<torch::jit::Function*> staticmethods_;
424 
425   // List of hooks to be run before/after forward.
426   std::vector<torch::jit::Function*> forward_hooks_;
427   std::vector<torch::jit::Function*> forward_pre_hooks_;
428 
429   // List of properties exposed by this class.
430   std::vector<Property> properties_;
431 
432   bool isModule_ = false;
433 
434   // Doc string of class.
435   std::string doc_string_ = "";
436 
437   // For error reporting accesses to class level attributes.
438   std::vector<std::string> unresolved_class_attributes_;
439 };
440 
441 }
442