xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/class_type.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/class_type.h>
2 
3 #include <ATen/core/Dict.h>
4 #include <ATen/core/Tensor.h>
5 #include <ATen/core/function_schema.h>
6 #include <ATen/core/ivalue.h>
7 #include <c10/macros/Macros.h>
8 #include <c10/util/irange.h>
9 #include <ATen/core/grad_mode.h>
10 #include <ATen/core/function.h>
11 
12 namespace c10 {
13 
addMethod(torch::jit::Function * method)14 void ClassType::addMethod(torch::jit::Function* method) {
15   TORCH_CHECK(
16       findMethod(method->name()) == nullptr,
17       "Can't redefine method: ",
18       method->name(),
19       " on class: ",
20       repr_str());
21   methods_.push_back(method);
22 }
23 
getForwardHooks() const24 const std::vector<torch::jit::Function*>& ClassType::getForwardHooks() const {
25     return forward_hooks_;
26 }
27 
getForwardPreHooks() const28 const std::vector<torch::jit::Function*>& ClassType::getForwardPreHooks() const {
29     return forward_pre_hooks_;
30 }
31 
addForwardPreHook(torch::jit::Function * pre_hook_ptr)32 void ClassType::addForwardPreHook(torch::jit::Function* pre_hook_ptr) {
33     forward_pre_hooks_.emplace_back(pre_hook_ptr);
34 }
35 
addForwardHook(torch::jit::Function * hook_ptr)36 void ClassType::addForwardHook(torch::jit::Function* hook_ptr) {
37     forward_hooks_.emplace_back(hook_ptr);
38 }
39 
findForwardPreHook(const std::string & name) const40 torch::jit::Function* ClassType::findForwardPreHook(const std::string& name) const {
41   for (const auto& pre_hook : forward_pre_hooks_) {
42     if (name == pre_hook->name()) {
43       return pre_hook;
44     }
45   }
46   return nullptr;
47 }
48 
findForwardHook(const std::string & name) const49 torch::jit::Function* ClassType::findForwardHook(const std::string& name) const {
50   for (const auto& hook : forward_hooks_) {
51     if (name == hook->name()) {
52       return hook;
53     }
54   }
55   return nullptr;
56 }
57 
getSchemaInputTypesString(const FunctionSchema & schema)58 static std::string getSchemaInputTypesString(const FunctionSchema& schema) {
59   std::stringstream input_types;
60   const std::vector<Argument>& forward_args = schema.arguments();
61   for (const auto i : c10::irange(1, forward_args.size())) {
62     input_types << forward_args[i].type()->annotation_str();
63     if (forward_args.size() - 1 != i) {
64       input_types << ", ";
65     }
66   }
67   if (forward_args.size() == 1) {
68     input_types << "()";
69   }
70   return input_types.str();
71 }
72 
getForwardPreHookErrorMessage(size_t pre_hook_idx) const73 std::string ClassType::getForwardPreHookErrorMessage(size_t pre_hook_idx) const {
74   const std::string& pre_hook_name = forward_pre_hooks_[pre_hook_idx]->name();
75   const FunctionSchema& forward_schema = getMethod("forward").getSchema();
76   std::string input_types = getSchemaInputTypesString(forward_schema);
77   const std::vector<Argument>& forward_args = forward_schema.arguments();
78 
79   std::string single_output = "";
80   if (forward_args.size() == 2 &&
81       forward_args[1].type()->cast<TupleType>() == nullptr) {
82     // if the output type is a single tuple, it needs to be wrapped in an outer tuple
83     // to match eager's behavior
84     single_output = ", '" + forward_args[1].type()->annotation_str() + "',";
85   }
86   std::string pre_hook_schema =
87       pre_hook_name + "(self, input: Tuple[" + input_types + "])";
88   std::string return_string =
89       "This error occurred while scripting the forward pre-hook '" +
90       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
91       pre_hook_name + "' on module '" + name()->name() +
92       "'. If you did not want to script this pre-hook remove it from the "
93       "original NN module before scripting. Pre-hooks for module '" +
94     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
95       name()->name() + "' are expected to have the following signature: "
96       + pre_hook_schema + " with a return type of either 'None'" +
97       single_output + " or 'Tuple[" + input_types + "]'.";
98   return return_string;
99 }
100 
getForwardHookErrorMessage(size_t hook_idx) const101 std::string ClassType::getForwardHookErrorMessage(size_t hook_idx) const {
102   const std::string& hook_name = forward_hooks_[hook_idx]->name();
103   const FunctionSchema& forward_schema = getMethod("forward").getSchema();
104   std::string input_types = getSchemaInputTypesString(forward_schema);
105 
106   // create expected output types string
107   const Argument& pre_output =
108       (hook_idx == 0)
109           ? forward_schema.returns()[0]
110           : forward_hooks_[hook_idx - 1]->getSchema().returns()[0];
111   std::string output_types = pre_output.type()->annotation_str();
112   // create error message
113   std::string hook_schema = hook_name + "(self, input: Tuple[" +
114                             input_types + "], output: " + output_types + ")";
115   std::string return_string =
116       "This error occurred while scripting the forward hook '"
117       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
118       + hook_name + "' on module " + name()->name() +
119       ". If you did not want to script this hook remove it from" +
120       " the original NN module before scripting. This hook was" +
121       " expected to have the following signature: " + hook_schema +
122       ". The type of the output arg is the returned type from" +
123       " either the forward method or the previous hook if it exists. " +
124       "Note that hooks can return anything, but if the hook is " +
125       "on a submodule the outer module is expecting" +
126       " the same return type as the submodule's forward.";
127   return return_string;
128 }
129 
isUnresolvedClassAttribute(const std::string & name) const130 bool ClassType::isUnresolvedClassAttribute(const std::string& name) const {
131   return std::find(
132       unresolved_class_attributes_.begin(),
133       unresolved_class_attributes_.end(),
134       name) != unresolved_class_attributes_.end();
135 }
136 
checkForwardHookInputArguments(const FunctionSchema & forward_schema,const FunctionSchema & hook_schema,const std::string & hook_id,const std::string & hook_err_msg)137 static void checkForwardHookInputArguments(
138     const FunctionSchema& forward_schema,
139     const FunctionSchema& hook_schema,
140     const std::string& hook_id,
141     const std::string& hook_err_msg) {
142   // check for proper tuple input types
143   const std::vector<Argument>& forward_args = forward_schema.arguments();
144   const Argument input_arg = hook_schema.arguments()[1];
145   TORCH_CHECK(
146       input_arg.type()->cast<TupleType>() != nullptr,
147       hook_id,
148       "expected the input argument to be typed as a Tuple but found type: '",
149       input_arg.type()->annotation_str(),
150       "' instead.\n",
151       hook_err_msg
152    );
153 
154   const at::ArrayRef<TypePtr> input_tuple_types = input_arg.type()->castRaw<TupleType>()->elements();
155   if (forward_args.size() == 1) {
156     // check for empty forward case
157     TORCH_CHECK(
158         input_tuple_types.empty(),
159         hook_id,
160         "was expecting Tuple[()] as the input type. Received type: '",
161         input_arg.type()->annotation_str(),
162         "'.\n",
163         hook_err_msg
164       );
165   } else {
166     // check input tuple for correct size and correct contained types
167     TORCH_CHECK(
168         input_tuple_types.size() == forward_args.size() - 1,
169         hook_id,
170         "has the wrong number of contained types for the",
171         " input argument's Tuple. Received type: '",
172         input_arg.type()->annotation_str(),
173         "'.\n",
174         hook_err_msg
175     );
176 
177     for (const auto i : c10::irange(1, forward_args.size())) {
178       if (*forward_args[i].type() != *input_tuple_types[i - 1]) {
179         TORCH_CHECK(
180             false,
181             hook_id,
182             "has the wrong inner types for the input tuple argument. Received type: '",
183             input_arg.type()->annotation_str(),
184             "'.\n",
185             hook_err_msg
186         );
187       }
188     }
189   }
190 }
191 
checkForwardPreHookSchema(size_t pre_hook_idx,const FunctionSchema & pre_hook_schema) const192 void ClassType::checkForwardPreHookSchema(
193     size_t pre_hook_idx,
194     const FunctionSchema& pre_hook_schema) const {
195   const torch::jit::Function* pre_hook = forward_pre_hooks_[pre_hook_idx];
196   std::string hook_id =
197       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
198       "Pre-hook '" + pre_hook->name() + "' on module '" + name()->name() + "' ";
199   std::string pre_hook_err_msg = getForwardPreHookErrorMessage(pre_hook_idx) + "\n";
200 
201   // Pre-hooks are expecting two inputs: self, and a Tuple containing the
202   // non-self arguments passed to Forward
203   TORCH_CHECK(
204       pre_hook_schema.arguments().size() == 2,
205       hook_id,
206       "was expected to only have exactly 2 inputs but it had ",
207       pre_hook_schema.arguments().size(),
208       " inputs. ",
209       pre_hook_err_msg
210    );
211 
212   const FunctionSchema& forward_schema = getMethod("forward").getSchema();
213   const std::vector<Argument>& forward_args = forward_schema.arguments();
214   checkForwardHookInputArguments(forward_schema, pre_hook_schema, hook_id, pre_hook_err_msg);
215 
216   // check return type, expected to be either None, the same type as the input,
217   // or the contained single type if the input was a tuple containing a single
218   // type.
219   TORCH_CHECK(
220             !pre_hook_schema.returns().empty(),
221             hook_id,
222             "is missing a return annotation. Return annotations are required, please add one.\n",
223             pre_hook_err_msg
224   );
225   const Argument return_arg = pre_hook_schema.returns()[0];
226   std::string wrong_type_returned_err_msg = hook_id +
227       "returned the wrong type of: '" +
228       return_arg.type()->annotation_str() + "'.";
229 
230   if (return_arg.type()->kind() == NoneType::get()->kind()) {
231     return;
232   }
233   if (forward_args.size() == 2 && *forward_args[1].type() == *return_arg.type()) {
234     // TORCH_CHECK below is for the edge case where forward's input is a tuple and the
235     // pre-hook returns a matching tuple. Eager doesn't support this- the working eager return
236     // for a tuple type is the forward's input tuple wrapped inside of another tuple.
237     TORCH_CHECK(
238         return_arg.type()->cast<TupleType>() == nullptr,
239         wrong_type_returned_err_msg,
240         " When forward has a single tuple input argument, the return needs",
241         " to be 'None' or a nested tuple containing forward's input tuple",
242         " argument as in: 'Tuple[",
243         forward_args[1].type()->annotation_str(),
244         "]'.\n",
245         pre_hook_err_msg
246     );
247     return;
248   }
249   // return can only be tuple of nested types now
250   // check to make sure return is of tuple type
251   TORCH_CHECK(
252       return_arg.type()->cast<TupleType>() != nullptr,
253       wrong_type_returned_err_msg,
254       pre_hook_err_msg
255   );
256   const at::ArrayRef<TypePtr> return_tuple_types =
257       return_arg.type()->castRaw<TupleType>()->elements();
258   // check for edge case of Tuple[()] for when forward has no arguments
259   if (forward_args.size() == 1) {
260     TORCH_CHECK(
261         return_tuple_types.empty(),
262         wrong_type_returned_err_msg,
263         " Was expecting either 'None' or 'Tuple[()]' since forward had ",
264         "no arguments.\n",
265         pre_hook_err_msg
266     );
267     return;
268   }
269 
270   // check that tuple has proper number of contained types
271   TORCH_CHECK(
272       return_tuple_types.size() == forward_args.size() - 1,
273       wrong_type_returned_err_msg,
274       " The returned tuple contains the wrong number of contained types.\n",
275       pre_hook_err_msg
276   );
277   // check that contained types match forward types
278   for (const auto i : c10::irange(1, forward_args.size())) {
279     if (*forward_args[i].type() != *return_tuple_types[i - 1]) {
280       TORCH_CHECK(
281           false,
282           wrong_type_returned_err_msg,
283           " The returned tuple contains the wrong inner types.\n",
284           pre_hook_err_msg);
285     }
286   }
287 }
288 
checkForwardHookSchema(size_t hook_idx,const FunctionSchema & hook_schema) const289 void ClassType::checkForwardHookSchema(
290       size_t hook_idx,
291       const FunctionSchema& hook_schema) const {
292   const torch::jit::Function* hook = forward_hooks_[hook_idx];
293   std::string hook_id =
294       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
295       "Hook '" + hook->name() + "' on module '" + name()->name() + "' ";
296   std::string hook_err_msg = getForwardHookErrorMessage(hook_idx) + "\n";
297   // Hooks are expecting three inputs: self, a Tuple containing the non-self
298   // arguments passed to Forward, and the output of either Forward or the
299   // previous hook
300   TORCH_CHECK(
301       hook_schema.arguments().size() == 3,
302       hook_id,
303       "was expected to only have exactly 3 inputs but it had ",
304       hook_schema.arguments().size(),
305       " inputs. ",
306       hook_err_msg
307   );
308 
309   const FunctionSchema& forward_schema = getMethod("forward").getSchema();
310   checkForwardHookInputArguments(forward_schema, hook_schema, hook_id, hook_err_msg);
311 
312   // check output tuple
313   const Argument& prev_output = (hook_idx == 0)
314             ? forward_schema.returns()[0]
315             : forward_hooks_[hook_idx - 1]->getSchema().returns()[0];
316   const Argument return_arg = hook_schema.arguments()[2];
317 
318   // output tuple needs to match prev_output's return exactly
319   TORCH_CHECK(
320       *prev_output.type() == *return_arg.type(),
321       hook_id,
322       "has the wrong type for the output argument. Received type: '",
323       return_arg.type()->annotation_str(),
324       "'. Expected type: '",
325       prev_output.type()->annotation_str(),
326       "'.\n",
327       hook_err_msg
328   );
329 }
330 
findMethod(const std::string & name) const331 torch::jit::Function* ClassType::findMethod(const std::string& name) const {
332   for (auto method : methods_) {
333     if (name == method->name()) {
334       return method;
335     }
336   }
337   return nullptr;
338 }
getMethod(const std::string & name) const339 torch::jit::Function& ClassType::getMethod(const std::string& name) const {
340   auto method = findMethod(name);
341   TORCH_CHECK(
342       method != nullptr,
343       "Couldn't find method: '",
344       name,
345       "' on class: '",
346       repr_str(),
347       "'");
348   return *method;
349 }
350 
findHook(const std::string & name) const351 torch::jit::Function* ClassType::findHook(const std::string& name) const {
352   auto hook = findForwardHook(name);
353   if (hook == nullptr) {
354     hook = findForwardPreHook(name);
355   }
356   return hook;
357 }
358 
getHook(const std::string & name) const359 torch::jit::Function& ClassType::getHook(const std::string& name) const {
360   torch::jit::Function* function = findHook(name);
361   TORCH_CHECK(
362       function != nullptr,
363       "Couldn't find: '",
364       name,
365       "' on class: '",
366       repr_str(),
367       "'as forward hook or forward pre_hook.");
368   return *function;
369 }
370 
hasMethod(const std::string & name) const371 bool ClassType::hasMethod(const std::string& name) const {
372   return findMethod(name) != nullptr;
373 }
374 
addStaticMethod(torch::jit::Function * method)375 void ClassType::addStaticMethod(torch::jit::Function* method) {
376   TORCH_CHECK(
377       findStaticMethod(method->name()) == nullptr &&
378           findMethod(method->name()) == nullptr, "Can't redefine method: ",
379       method->name(),
380       " on class: ",
381       repr_str());
382   staticmethods_.emplace_back(method);
383 }
384 
findStaticMethod(const std::string & name) const385 torch::jit::Function* ClassType::findStaticMethod(const std::string& name) const {
386   for (auto method : staticmethods_) {
387     if (name == method->name()) {
388       return method;
389     }
390   }
391   return nullptr;
392 }
393 
unsafeRemoveMethod(const std::string & name)394 void ClassType::unsafeRemoveMethod(const std::string& name) {
395   size_t slot = 0;
396   for (auto method : methods_) {
397     if (method->name() == name) {
398       methods_.erase(methods_.begin() + static_cast<std::ptrdiff_t>(slot));
399       return;
400     }
401     slot++;
402   }
403   TORCH_CHECK(
404       false,
405       "Can't delete undefined method ",
406       name,
407       " on class: ",
408       repr_str());
409 }
410 
refine(at::ArrayRef<TypePtr> refined_slots) const411 ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
412   auto ptr = ClassType::create(name(), compilation_unit_, is_module());
413   AT_ASSERT(numAttributes() == refined_slots.size());
414   for (size_t i = 0; i < attributes_.size(); ++i) {
415     AT_ASSERT(refined_slots[i]->isSubtypeOf(*attributes_[i].getType()));
416     ptr->addAttribute(attributes_[i].getName(), refined_slots[i], (attributes_[i].getKind() == AttributeKind::PARAMETER),
417     (attributes_[i].getKind() == AttributeKind::BUFFER));
418   }
419   // Copy methods over
420   for (const auto& method : methods()) {
421     ptr->addMethod(method);
422   }
423   return ptr;
424 }
425 
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const426 bool ClassType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
427   if (rhs.castRaw<AnyClassType>()) {
428     return true;
429   }
430   // to improve performance, this check can be cached
431   if (auto iface = rhs.cast<InterfaceType>()) {
432     // ClassType is not a subtype of InterfaceType if the InterfaceType is a
433     // Module Interface Type but the Class Type is not a Module Class Type
434     if (!is_module() && iface->is_module()) {
435       if (why_not) {
436         *why_not << "Class '" << repr_str() << "' is not a subtype of "
437                  << "the module interface '" << rhs.repr_str()
438                  << "' , only ScriptModule class can be subtype of module"
439                  << " interface.\n";
440       }
441       return false;
442     }
443     for (const FunctionSchema& schema : iface->methods()) {
444       auto self_method = findMethod(schema.name());
445       if (!self_method) {
446         if (why_not) {
447           *why_not << "Class '" << repr_str() << "' does not have method '"
448                    << schema.name() << "' but '" << rhs.repr_str()
449                    << "' does.\n";
450         }
451         return false;
452       }
453       if (!self_method->getSchema().isSubtypeOf(
454               schema, /*as_method=*/true, why_not)) {
455         if (why_not) {
456           *why_not << "Method on class '" << repr_str()
457                    << "' (1) is not compatible with interface '"
458                    << rhs.repr_str() << "' (2)\n"
459                    << "  (1) " << self_method->getSchema() << "\n"
460                    << "  (2) " << schema << "\n";
461         }
462         return false;
463       }
464     }
465     return true;
466   }
467   return Type::isSubtypeOfExt(rhs, why_not);
468 }
469 
create(std::optional<QualifiedName> qualifiedName,std::weak_ptr<CompilationUnit> cu,bool is_module,std::string doc_string,std::vector<std::string> unresolved_class_attributes)470 ClassTypePtr ClassType::create(
471     std::optional<QualifiedName> qualifiedName,
472     std::weak_ptr<CompilationUnit> cu,
473     bool is_module,
474     std::string doc_string,
475     std::vector<std::string> unresolved_class_attributes) {
476   return ClassTypePtr(new ClassType(
477       std::move(qualifiedName),
478       std::move(cu),
479       is_module,
480       std::move(doc_string),
481       std::move(unresolved_class_attributes)));
482 }
483 
ClassType(std::optional<QualifiedName> name,std::weak_ptr<CompilationUnit> cu,bool is_module,std::string doc_string,std::vector<std::string> unresolved_class_attributes)484 ClassType::ClassType(
485     std::optional<QualifiedName> name,
486     std::weak_ptr<CompilationUnit> cu,
487     bool is_module,
488     std::string doc_string,
489     std::vector<std::string> unresolved_class_attributes)
490     : NamedType(TypeKind::ClassType, std::move(name)),
491       compilation_unit_(std::move(cu)),
492       isModule_(is_module),
493       doc_string_(std::move(doc_string)),
494       unresolved_class_attributes_(std::move(unresolved_class_attributes)) {}
495 
methods() const496 const std::vector<torch::jit::Function*>& ClassType::methods() const {
497   return methods_;
498 }
499 
checkNotExist(const std::string & name,const std::string & what) const500 void ClassType::checkNotExist(const std::string& name, const std::string& what) const {
501   // Check no overlap with existing constants
502   for (size_t i = 0; i < constantNames_.size(); ++i) {
503     TORCH_CHECK(
504         name != constantNames_[i],
505         "attempting to add ",
506         what,
507         " '",
508         name,
509         "' to ",
510         repr_str(),
511         " but a constant field of the same name already exists with value ",
512         constantValues_[i]);
513   }
514 
515   // Check no overlap with existing attributes
516   for (const auto & attribute : attributes_) {
517     TORCH_CHECK(
518         name != attribute.getName(),
519         "attempting to add ",
520         what,
521         " '",
522         name,
523         "' to ",
524         repr_str(),
525         " but an attribute field of the same name already exists with type ",
526         attribute.getType()->repr_str());
527   }
528 }
529 
addAttribute(ClassAttribute classAttribute)530 void ClassType::addAttribute(ClassAttribute classAttribute) {
531     AT_ASSERT(attributes_.size() == attributeTypes_.size());
532     attributeTypes_.emplace_back(classAttribute.getType());
533     attributes_.emplace_back(std::move(classAttribute));
534 }
535 
addAttribute(const std::string & name,TypePtr type,bool is_parameter,bool is_buffer)536 size_t ClassType::addAttribute(
537     const std::string& name,
538     TypePtr type,
539     bool is_parameter,
540     bool is_buffer) {
541   if (is_parameter && is_buffer){
542     TORCH_INTERNAL_ASSERT(false, "Attribute cannot be both a parameter and a buffer!");
543   }
544 
545   std::string what = is_parameter ? "parameter" : "attribute";
546   what += (is_buffer? "buffer" : "not buffer");
547   checkNotExist(name, what);
548 
549   size_t slot = attributes_.size();
550 
551   AttributeKind kind = AttributeKind::REGULAR_ATTRIBUTE;
552   if (is_parameter) {
553     kind = AttributeKind::PARAMETER;
554   } else if (is_buffer) {
555     kind = AttributeKind::BUFFER;
556   }
557 
558 
559   if (is_parameter || is_buffer) {
560     TORCH_INTERNAL_ASSERT(is_module(), "adding a parameter or buffer to a non module");
561     TORCH_CHECK(
562         (type->kind() == TensorType::Kind) ||
563             (type->kind() == OptionalType::Kind &&
564             type->expectRef<OptionalType>().getElementType()->kind() ==
565                 TensorType::Kind) ||
566             (type->kind() == UnionType::Kind &&
567             TensorType::get()->isSubtypeOf(type->expectRef<UnionType>())) ||
568             (type->kind() == NoneType::Kind),
569         "Expecting parameter or buffer to have either None, Tensor or Optional[Tensor] type, but got: ",
570         toString(type));
571   }
572 
573   addAttribute(ClassAttribute(kind, std::move(type), name));
574 
575   return slot;
576 }
577 
unsafeRemoveAttribute(const std::string & name)578 void ClassType::unsafeRemoveAttribute(const std::string& name) {
579   auto slot = getAttributeSlot(name);
580   attributes_.erase(attributes_.begin() + static_cast<std::ptrdiff_t>(slot));
581   attributeTypes_.erase(attributeTypes_.begin() + static_cast<std::ptrdiff_t>(slot));
582   AT_ASSERT(attributes_.size() == attributeTypes_.size());
583 }
584 
unsafeChangeAttributeType(const std::string & name,const TypePtr & new_ty)585 void ClassType::unsafeChangeAttributeType(const std::string& name, const TypePtr& new_ty) {
586   auto slot = getAttributeSlot(name);
587   auto old_attr_info = attributes_[slot];
588   AT_ASSERT(old_attr_info.getKind() == AttributeKind::REGULAR_ATTRIBUTE);
589   attributes_[slot] = ClassAttribute(old_attr_info.getKind(), new_ty, old_attr_info.getName());
590   attributeTypes_[slot] = new_ty;
591 }
592 
addConstant(const std::string & name,const IValue & value)593 size_t ClassType::addConstant(const std::string& name, const IValue& value) {
594   checkNotExist(name, "constant");
595   size_t slot = constantNames_.size();
596   constantNames_.push_back(name);
597   constantValues_.push_back(value);
598   return slot;
599 }
600 
getConstant(const std::string & name) const601 IValue ClassType::getConstant(const std::string& name) const {
602   const auto& v = findConstant(name);
603   TORCH_CHECK(
604       v.has_value(),
605       repr_str(),
606       " does not have a constant field with name '",
607       name,
608       "'");
609   return *v;
610 }
611 
getConstant(size_t slot) const612 IValue ClassType::getConstant(size_t slot) const {
613   TORCH_INTERNAL_ASSERT(constantNames_.size() == constantValues_.size());
614   TORCH_CHECK(
615       slot < constantValues_.size(),
616       repr_str(),
617       " does not have a constant slot of index ",
618       slot);
619   return constantValues_[slot];
620 }
621 
findConstant(const std::string & name) const622 std::optional<IValue> ClassType::findConstant(const std::string& name) const {
623   TORCH_INTERNAL_ASSERT(constantNames_.size() == constantValues_.size());
624   size_t pos = 0;
625   for (const auto& c : constantNames_) {
626     if (name == c) {
627       break;
628     }
629     ++pos;
630   }
631 
632   if (pos >= constantNames_.size()) {
633     return std::nullopt;
634   }
635   return constantValues_[pos];
636 }
637 
unsafeRemoveConstant(const std::string & name)638 void ClassType::unsafeRemoveConstant(const std::string& name) {
639   auto slot = getConstantSlot(name);
640   constantNames_.erase(constantNames_.begin() + static_cast<std::ptrdiff_t>(slot));
641   constantValues_.erase(constantValues_.begin() + static_cast<std::ptrdiff_t>(slot));
642 }
643 
compilation_unit()644 std::shared_ptr<CompilationUnit> ClassType::compilation_unit() {
645   auto cu = compilation_unit_.lock();
646   return cu;
647 }
648 
compilation_unit() const649 std::shared_ptr<const CompilationUnit> ClassType::compilation_unit() const {
650   auto cu = compilation_unit_.lock();
651   return cu;
652 }
653 
getProperty(const std::string & name)654 std::optional<ClassType::Property> ClassType::getProperty(const std::string& name) {
655   for (auto& prop : properties_) {
656     if (name == prop.name) {
657       return prop;
658     }
659   }
660 
661   return std::nullopt;
662 }
663 
addProperty(const std::string & name,torch::jit::Function * getter,torch::jit::Function * setter)664 void ClassType::addProperty(const std::string& name, torch::jit::Function* getter, torch::jit::Function* setter) {
665   TORCH_INTERNAL_ASSERT(!getProperty(name), "Property named ", name, " already exists!");
666   properties_.push_back({name, getter, setter});
667 }
668 
findConstantSlot(const std::string & name) const669 std::optional<size_t> ClassType::findConstantSlot(const std::string& name) const {
670   TORCH_CHECK(constantNames_.size() == constantValues_.size());
671   size_t slot = 0;
672   for (const auto& constant : constantNames_) {
673     if (name == constant) {
674       return slot;
675     }
676     slot++;
677   }
678   return std::nullopt;
679 }
680 
getConstantName(size_t slot) const681 const std::string& ClassType::getConstantName(size_t slot) const {
682   TORCH_CHECK(constantNames_.size() == constantValues_.size());
683   TORCH_CHECK(slot < constantNames_.size());
684   return constantNames_[slot];
685 }
686 
numConstants() const687 size_t ClassType::numConstants() const {
688   TORCH_INTERNAL_ASSERT(constantNames_.size() == constantValues_.size());
689   return constantNames_.size();
690 }
691 
constantValues() const692 at::ArrayRef<IValue> ClassType::constantValues() const {
693   return constantValues_;
694 }
695 
696 } // namespace c10
697