#include <ATen/core/class_type.h>

#include <ATen/core/Dict.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/function.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/grad_mode.h>
#include <ATen/core/ivalue.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <sstream>

namespace c10 {

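// Methods are registered by raw pointer and looked up by name; registering two
// methods with the same name on a class is rejected (see the TORCH_CHECK below).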
void ClassType::addMethod(torch::jit::Function* method) {
  TORCH_CHECK(
      findMethod(method->name()) == nullptr,
      "Can't redefine method: ",
      method->name(),
      " on class: ",
      repr_str());
  methods_.push_back(method);
}

const std::vector<torch::jit::Function*>& ClassType::getForwardHooks() const {
  return forward_hooks_;
}

const std::vector<torch::jit::Function*>& ClassType::getForwardPreHooks() const {
  return forward_pre_hooks_;
}

void ClassType::addForwardPreHook(torch::jit::Function* pre_hook_ptr) {
  forward_pre_hooks_.emplace_back(pre_hook_ptr);
}

void ClassType::addForwardHook(torch::jit::Function* hook_ptr) {
  forward_hooks_.emplace_back(hook_ptr);
}

torch::jit::Function* ClassType::findForwardPreHook(const std::string& name) const {
  for (const auto& pre_hook : forward_pre_hooks_) {
    if (name == pre_hook->name()) {
      return pre_hook;
    }
  }
  return nullptr;
}

torch::jit::Function* ClassType::findForwardHook(const std::string& name) const {
  for (const auto& hook : forward_hooks_) {
    if (name == hook->name()) {
      return hook;
    }
  }
  return nullptr;
}

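// Renders the non-self argument types of the given schema (in practice
// forward's schema) as a comma-separated list of annotation strings,
// e.g. "Tensor, int" for forward(self, x: Tensor, y: int), or "()" when the
// schema takes no arguments besides self.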
static std::string getSchemaInputTypesString(const FunctionSchema& schema) {
  std::stringstream input_types;
  const std::vector<Argument>& forward_args = schema.arguments();
  for (const auto i : c10::irange(1, forward_args.size())) {
    input_types << forward_args[i].type()->annotation_str();
    if (forward_args.size() - 1 != i) {
      input_types << ", ";
    }
  }
  if (forward_args.size() == 1) {
    input_types << "()";
  }
  return input_types.str();
}

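// Builds the error message emitted when scripting a forward pre-hook fails.
// Illustrative sketch of the expected shape (names are hypothetical): for a
// module whose forward is forward(self, x: Tensor, y: int), a pre-hook named
// "my_pre_hook" is expected to be scriptable as
//   my_pre_hook(self, input: Tuple[Tensor, int])
// with a return type of either 'None' or 'Tuple[Tensor, int]'.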
std::string ClassType::getForwardPreHookErrorMessage(size_t pre_hook_idx) const {
  const std::string& pre_hook_name = forward_pre_hooks_[pre_hook_idx]->name();
  const FunctionSchema& forward_schema = getMethod("forward").getSchema();
  std::string input_types = getSchemaInputTypesString(forward_schema);
  const std::vector<Argument>& forward_args = forward_schema.arguments();

  std::string single_output = "";
  if (forward_args.size() == 2 &&
      forward_args[1].type()->cast<TupleType>() == nullptr) {
    // If forward takes a single non-tuple argument, the pre-hook may also
    // return that type directly. When the argument is itself a tuple, eager
    // expects the return to be wrapped in an outer tuple, so the unwrapped
    // form is not offered as an option.
    single_output = ", '" + forward_args[1].type()->annotation_str() + "',";
  }
  std::string pre_hook_schema =
      pre_hook_name + "(self, input: Tuple[" + input_types + "])";
  std::string return_string =
      "This error occurred while scripting the forward pre-hook '" +
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      pre_hook_name + "' on module '" + name()->name() +
      "'. If you did not want to script this pre-hook, remove it from the "
      "original NN module before scripting. Pre-hooks for module '" +
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      name()->name() + "' are expected to have the following signature: " +
      pre_hook_schema + " with a return type of either 'None'" +
      single_output + " or 'Tuple[" + input_types + "]'.";
  return return_string;
}

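// Builds the error message emitted when scripting a forward hook fails.
// Illustrative sketch (names are hypothetical): for forward(self, x: Tensor)
// returning Tensor, a hook named "my_hook" is expected to be scriptable as
//   my_hook(self, input: Tuple[Tensor], output: Tensor)
// where 'output' has the return type of forward or of the previous hook.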
std::string ClassType::getForwardHookErrorMessage(size_t hook_idx) const {
  const std::string& hook_name = forward_hooks_[hook_idx]->name();
  const FunctionSchema& forward_schema = getMethod("forward").getSchema();
  std::string input_types = getSchemaInputTypesString(forward_schema);

  // create expected output types string
  const Argument& pre_output =
      (hook_idx == 0)
          ? forward_schema.returns()[0]
          : forward_hooks_[hook_idx - 1]->getSchema().returns()[0];
  std::string output_types = pre_output.type()->annotation_str();
  // create error message
  std::string hook_schema = hook_name + "(self, input: Tuple[" +
      input_types + "], output: " + output_types + ")";
  std::string return_string =
      "This error occurred while scripting the forward hook '"
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      + hook_name + "' on module '" + name()->name() +
      "'. If you did not want to script this hook, remove it from" +
      " the original NN module before scripting. This hook was" +
      " expected to have the following signature: " + hook_schema +
      ". The type of the output arg is the returned type from" +
      " either the forward method or the previous hook if it exists. " +
      "Note that hooks can return anything, but if the hook is " +
      "on a submodule the outer module is expecting" +
      " the same return type as the submodule's forward.";
  return return_string;
}

bool ClassType::isUnresolvedClassAttribute(const std::string& name) const {
  return std::find(
             unresolved_class_attributes_.begin(),
             unresolved_class_attributes_.end(),
             name) != unresolved_class_attributes_.end();
}

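// Shared validation for the 'input' argument of forward hooks and pre-hooks:
// it must be annotated as a Tuple whose element types exactly match forward's
// non-self argument types (Tuple[()] when forward takes no arguments).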
static void checkForwardHookInputArguments(
    const FunctionSchema& forward_schema,
    const FunctionSchema& hook_schema,
    const std::string& hook_id,
    const std::string& hook_err_msg) {
  // check for proper tuple input types
  const std::vector<Argument>& forward_args = forward_schema.arguments();
  const Argument input_arg = hook_schema.arguments()[1];
  TORCH_CHECK(
      input_arg.type()->cast<TupleType>() != nullptr,
      hook_id,
      "expected the input argument to be typed as a Tuple but found type: '",
      input_arg.type()->annotation_str(),
      "' instead.\n",
      hook_err_msg
  );

  const at::ArrayRef<TypePtr> input_tuple_types = input_arg.type()->castRaw<TupleType>()->elements();
  if (forward_args.size() == 1) {
    // check for empty forward case
    TORCH_CHECK(
        input_tuple_types.empty(),
        hook_id,
        "was expecting Tuple[()] as the input type. Received type: '",
        input_arg.type()->annotation_str(),
        "'.\n",
        hook_err_msg
    );
  } else {
    // check input tuple for correct size and correct contained types
    TORCH_CHECK(
        input_tuple_types.size() == forward_args.size() - 1,
        hook_id,
        "has the wrong number of contained types for the",
        " input argument's Tuple. Received type: '",
        input_arg.type()->annotation_str(),
        "'.\n",
        hook_err_msg
    );

    for (const auto i : c10::irange(1, forward_args.size())) {
      if (*forward_args[i].type() != *input_tuple_types[i - 1]) {
        TORCH_CHECK(
            false,
            hook_id,
            "has the wrong inner types for the input tuple argument. Received type: '",
            input_arg.type()->annotation_str(),
            "'.\n",
            hook_err_msg
        );
      }
    }
  }
}

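// Validates a scripted forward pre-hook against this module's forward schema:
// exactly two inputs (self and a tuple of forward's non-self arguments), and a
// return annotation of 'None', a tuple matching forward's non-self arguments,
// or, when forward takes a single non-tuple argument, that argument's type.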
void ClassType::checkForwardPreHookSchema(
    size_t pre_hook_idx,
    const FunctionSchema& pre_hook_schema) const {
  const torch::jit::Function* pre_hook = forward_pre_hooks_[pre_hook_idx];
  std::string hook_id =
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      "Pre-hook '" + pre_hook->name() + "' on module '" + name()->name() + "' ";
  std::string pre_hook_err_msg = getForwardPreHookErrorMessage(pre_hook_idx) + "\n";

  // Pre-hooks are expecting two inputs: self, and a Tuple containing the
  // non-self arguments passed to forward
  TORCH_CHECK(
      pre_hook_schema.arguments().size() == 2,
      hook_id,
      "was expected to have exactly 2 inputs but it had ",
      pre_hook_schema.arguments().size(),
      " inputs. ",
      pre_hook_err_msg
  );

  const FunctionSchema& forward_schema = getMethod("forward").getSchema();
  const std::vector<Argument>& forward_args = forward_schema.arguments();
  checkForwardHookInputArguments(forward_schema, pre_hook_schema, hook_id, pre_hook_err_msg);

  // check return type, expected to be either None, the same type as the input,
  // or the contained single type if the input was a tuple containing a single
  // type.
  TORCH_CHECK(
      !pre_hook_schema.returns().empty(),
      hook_id,
      "is missing a return annotation. Return annotations are required, please add one.\n",
      pre_hook_err_msg
  );
  const Argument return_arg = pre_hook_schema.returns()[0];
  std::string wrong_type_returned_err_msg = hook_id +
      "returned the wrong type: '" +
      return_arg.type()->annotation_str() + "'.";

  if (return_arg.type()->kind() == NoneType::get()->kind()) {
    return;
  }
  if (forward_args.size() == 2 && *forward_args[1].type() == *return_arg.type()) {
    // The TORCH_CHECK below is for the edge case where forward's input is a tuple
    // and the pre-hook returns a matching tuple. Eager doesn't support this: the
    // working eager return for a tuple type is forward's input tuple wrapped
    // inside another tuple.
    TORCH_CHECK(
        return_arg.type()->cast<TupleType>() == nullptr,
        wrong_type_returned_err_msg,
        " When forward has a single tuple input argument, the return needs",
        " to be 'None' or a nested tuple containing forward's input tuple",
        " argument as in: 'Tuple[",
        forward_args[1].type()->annotation_str(),
        "]'.\n",
        pre_hook_err_msg
    );
    return;
  }
  // the return can only be a tuple of nested types now;
  // check to make sure the return is of tuple type
  TORCH_CHECK(
      return_arg.type()->cast<TupleType>() != nullptr,
      wrong_type_returned_err_msg,
      pre_hook_err_msg
  );
  const at::ArrayRef<TypePtr> return_tuple_types =
      return_arg.type()->castRaw<TupleType>()->elements();
  // check for edge case of Tuple[()] for when forward has no arguments
  if (forward_args.size() == 1) {
    TORCH_CHECK(
        return_tuple_types.empty(),
        wrong_type_returned_err_msg,
        " Was expecting either 'None' or 'Tuple[()]' since forward had ",
        "no arguments.\n",
        pre_hook_err_msg
    );
    return;
  }

  // check that the tuple has the proper number of contained types
  TORCH_CHECK(
      return_tuple_types.size() == forward_args.size() - 1,
      wrong_type_returned_err_msg,
      " The returned tuple contains the wrong number of contained types.\n",
      pre_hook_err_msg
  );
  // check that contained types match forward types
  for (const auto i : c10::irange(1, forward_args.size())) {
    if (*forward_args[i].type() != *return_tuple_types[i - 1]) {
      TORCH_CHECK(
          false,
          wrong_type_returned_err_msg,
          " The returned tuple contains the wrong inner types.\n",
          pre_hook_err_msg);
    }
  }
}

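// Validates a scripted forward hook against this module's forward schema:
// exactly three inputs (self, a tuple of forward's non-self arguments, and an
// output argument), where the output argument's type must exactly match the
// return type of forward or of the previous hook.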
void ClassType::checkForwardHookSchema(
    size_t hook_idx,
    const FunctionSchema& hook_schema) const {
  const torch::jit::Function* hook = forward_hooks_[hook_idx];
  std::string hook_id =
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      "Hook '" + hook->name() + "' on module '" + name()->name() + "' ";
  std::string hook_err_msg = getForwardHookErrorMessage(hook_idx) + "\n";
  // Hooks are expecting three inputs: self, a Tuple containing the non-self
  // arguments passed to forward, and the output of either forward or the
  // previous hook
  TORCH_CHECK(
      hook_schema.arguments().size() == 3,
      hook_id,
      "was expected to have exactly 3 inputs but it had ",
      hook_schema.arguments().size(),
      " inputs. ",
      hook_err_msg
  );

  const FunctionSchema& forward_schema = getMethod("forward").getSchema();
  checkForwardHookInputArguments(forward_schema, hook_schema, hook_id, hook_err_msg);

  // check the output argument
  const Argument& prev_output = (hook_idx == 0)
      ? forward_schema.returns()[0]
      : forward_hooks_[hook_idx - 1]->getSchema().returns()[0];
  const Argument return_arg = hook_schema.arguments()[2];

  // the output argument needs to match prev_output's type exactly
  TORCH_CHECK(
      *prev_output.type() == *return_arg.type(),
      hook_id,
      "has the wrong type for the output argument. Received type: '",
      return_arg.type()->annotation_str(),
      "'. Expected type: '",
      prev_output.type()->annotation_str(),
      "'.\n",
      hook_err_msg
  );
}

torch::jit::Function* ClassType::findMethod(const std::string& name) const {
  for (auto method : methods_) {
    if (name == method->name()) {
      return method;
    }
  }
  return nullptr;
}

torch::jit::Function& ClassType::getMethod(const std::string& name) const {
  auto method = findMethod(name);
  TORCH_CHECK(
      method != nullptr,
      "Couldn't find method: '",
      name,
      "' on class: '",
      repr_str(),
      "'");
  return *method;
}

torch::jit::Function* ClassType::findHook(const std::string& name) const {
  auto hook = findForwardHook(name);
  if (hook == nullptr) {
    hook = findForwardPreHook(name);
  }
  return hook;
}

torch::jit::Function& ClassType::getHook(const std::string& name) const {
  torch::jit::Function* function = findHook(name);
  TORCH_CHECK(
      function != nullptr,
      "Couldn't find: '",
      name,
      "' on class: '",
      repr_str(),
      "' as a forward hook or forward pre_hook.");
  return *function;
}

bool ClassType::hasMethod(const std::string& name) const {
  return findMethod(name) != nullptr;
}

void ClassType::addStaticMethod(torch::jit::Function* method) {
  TORCH_CHECK(
      findStaticMethod(method->name()) == nullptr &&
          findMethod(method->name()) == nullptr,
      "Can't redefine method: ",
      method->name(),
      " on class: ",
      repr_str());
  staticmethods_.emplace_back(method);
}

torch::jit::Function* ClassType::findStaticMethod(const std::string& name) const {
  for (auto method : staticmethods_) {
    if (name == method->name()) {
      return method;
    }
  }
  return nullptr;
}

void ClassType::unsafeRemoveMethod(const std::string& name) {
  size_t slot = 0;
  for (auto method : methods_) {
    if (method->name() == name) {
      methods_.erase(methods_.begin() + static_cast<std::ptrdiff_t>(slot));
      return;
    }
    slot++;
  }
  TORCH_CHECK(
      false,
      "Can't delete undefined method ",
      name,
      " on class: ",
      repr_str());
}

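// Returns a copy of this class type whose attribute slots are narrowed to the
// given refined types. Each refined type must be a subtype of the declared
// attribute type; methods are shared with the original type.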
ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
  auto ptr = ClassType::create(name(), compilation_unit_, is_module());
  AT_ASSERT(numAttributes() == refined_slots.size());
  for (size_t i = 0; i < attributes_.size(); ++i) {
    AT_ASSERT(refined_slots[i]->isSubtypeOf(*attributes_[i].getType()));
    ptr->addAttribute(
        attributes_[i].getName(),
        refined_slots[i],
        (attributes_[i].getKind() == AttributeKind::PARAMETER),
        (attributes_[i].getKind() == AttributeKind::BUFFER));
  }
  // Copy methods over
  for (const auto& method : methods()) {
    ptr->addMethod(method);
  }
  return ptr;
}

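// Subtyping rules for classes: every class is a subtype of AnyClassType, and a
// class is a subtype of an interface if it defines every method the interface
// declares with a compatible (subtype) schema. Module interfaces are only
// satisfied by module classes.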
bool ClassType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
  if (rhs.castRaw<AnyClassType>()) {
    return true;
  }
  // to improve performance, this check can be cached
  if (auto iface = rhs.cast<InterfaceType>()) {
    // ClassType is not a subtype of InterfaceType if the InterfaceType is a
    // module interface type but the ClassType is not a module class type
    if (!is_module() && iface->is_module()) {
      if (why_not) {
        *why_not << "Class '" << repr_str() << "' is not a subtype of "
                 << "the module interface '" << rhs.repr_str()
                 << "'; only ScriptModule classes can be subtypes of module"
                 << " interfaces.\n";
      }
      return false;
    }
    for (const FunctionSchema& schema : iface->methods()) {
      auto self_method = findMethod(schema.name());
      if (!self_method) {
        if (why_not) {
          *why_not << "Class '" << repr_str() << "' does not have method '"
                   << schema.name() << "' but '" << rhs.repr_str()
                   << "' does.\n";
        }
        return false;
      }
      if (!self_method->getSchema().isSubtypeOf(
              schema, /*as_method=*/true, why_not)) {
        if (why_not) {
          *why_not << "Method on class '" << repr_str()
                   << "' (1) is not compatible with interface '"
                   << rhs.repr_str() << "' (2)\n"
                   << " (1) " << self_method->getSchema() << "\n"
                   << " (2) " << schema << "\n";
        }
        return false;
      }
    }
    return true;
  }
  return Type::isSubtypeOfExt(rhs, why_not);
}

ClassTypePtr ClassType::create(
    std::optional<QualifiedName> qualifiedName,
    std::weak_ptr<CompilationUnit> cu,
    bool is_module,
    std::string doc_string,
    std::vector<std::string> unresolved_class_attributes) {
  return ClassTypePtr(new ClassType(
      std::move(qualifiedName),
      std::move(cu),
      is_module,
      std::move(doc_string),
      std::move(unresolved_class_attributes)));
}

ClassType::ClassType(
    std::optional<QualifiedName> name,
    std::weak_ptr<CompilationUnit> cu,
    bool is_module,
    std::string doc_string,
    std::vector<std::string> unresolved_class_attributes)
    : NamedType(TypeKind::ClassType, std::move(name)),
      compilation_unit_(std::move(cu)),
      isModule_(is_module),
      doc_string_(std::move(doc_string)),
      unresolved_class_attributes_(std::move(unresolved_class_attributes)) {}

const std::vector<torch::jit::Function*>& ClassType::methods() const {
  return methods_;
}

void ClassType::checkNotExist(const std::string& name, const std::string& what) const {
  // Check no overlap with existing constants
  for (size_t i = 0; i < constantNames_.size(); ++i) {
    TORCH_CHECK(
        name != constantNames_[i],
        "attempting to add ",
        what,
        " '",
        name,
        "' to ",
        repr_str(),
        " but a constant field of the same name already exists with value ",
        constantValues_[i]);
  }

  // Check no overlap with existing attributes
  for (const auto& attribute : attributes_) {
    TORCH_CHECK(
        name != attribute.getName(),
        "attempting to add ",
        what,
        " '",
        name,
        "' to ",
        repr_str(),
        " but an attribute field of the same name already exists with type ",
        attribute.getType()->repr_str());
  }
}

void ClassType::addAttribute(ClassAttribute classAttribute) {
  AT_ASSERT(attributes_.size() == attributeTypes_.size());
  attributeTypes_.emplace_back(classAttribute.getType());
  attributes_.emplace_back(std::move(classAttribute));
}

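// Adds an attribute slot and returns its index. Parameters and buffers may
// only be added to module class types and must be typed as Tensor,
// Optional[Tensor] (or a Union containing Tensor), or None.
//
// Illustrative usage sketch (identifiers are hypothetical, not from this file):
//   auto cls = ClassType::create(
//       QualifiedName("__torch__.MyModule"), cu, /*is_module=*/true);
//   size_t slot = cls->addAttribute(
//       "weight", TensorType::get(), /*is_parameter=*/true, /*is_buffer=*/false);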
size_t ClassType::addAttribute(
    const std::string& name,
    TypePtr type,
    bool is_parameter,
    bool is_buffer) {
  if (is_parameter && is_buffer) {
    TORCH_INTERNAL_ASSERT(false, "Attribute cannot be both a parameter and a buffer!");
  }

  // description used only for error messages in checkNotExist
  std::string what =
      is_parameter ? "parameter" : (is_buffer ? "buffer" : "attribute");
  checkNotExist(name, what);

  size_t slot = attributes_.size();

  AttributeKind kind = AttributeKind::REGULAR_ATTRIBUTE;
  if (is_parameter) {
    kind = AttributeKind::PARAMETER;
  } else if (is_buffer) {
    kind = AttributeKind::BUFFER;
  }

  if (is_parameter || is_buffer) {
    TORCH_INTERNAL_ASSERT(is_module(), "adding a parameter or buffer to a non-module");
    TORCH_CHECK(
        (type->kind() == TensorType::Kind) ||
            (type->kind() == OptionalType::Kind &&
             type->expectRef<OptionalType>().getElementType()->kind() ==
                 TensorType::Kind) ||
            (type->kind() == UnionType::Kind &&
             TensorType::get()->isSubtypeOf(type->expectRef<UnionType>())) ||
            (type->kind() == NoneType::Kind),
        "Expecting parameter or buffer to have either None, Tensor or Optional[Tensor] type, but got: ",
        toString(type));
  }

  addAttribute(ClassAttribute(kind, std::move(type), name));

  return slot;
}

void ClassType::unsafeRemoveAttribute(const std::string& name) {
  auto slot = getAttributeSlot(name);
  attributes_.erase(attributes_.begin() + static_cast<std::ptrdiff_t>(slot));
  attributeTypes_.erase(attributeTypes_.begin() + static_cast<std::ptrdiff_t>(slot));
  AT_ASSERT(attributes_.size() == attributeTypes_.size());
}

void ClassType::unsafeChangeAttributeType(const std::string& name, const TypePtr& new_ty) {
  auto slot = getAttributeSlot(name);
  auto old_attr_info = attributes_[slot];
  AT_ASSERT(old_attr_info.getKind() == AttributeKind::REGULAR_ATTRIBUTE);
  attributes_[slot] = ClassAttribute(old_attr_info.getKind(), new_ty, old_attr_info.getName());
  attributeTypes_[slot] = new_ty;
}

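// Constants are stored in two parallel vectors (constantNames_ and
// constantValues_); addConstant returns the slot index of the new constant.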
size_t ClassType::addConstant(const std::string& name, const IValue& value) {
  checkNotExist(name, "constant");
  size_t slot = constantNames_.size();
  constantNames_.push_back(name);
  constantValues_.push_back(value);
  return slot;
}

IValue ClassType::getConstant(const std::string& name) const {
  const auto& v = findConstant(name);
  TORCH_CHECK(
      v.has_value(),
      repr_str(),
      " does not have a constant field with name '",
      name,
      "'");
  return *v;
}

IValue ClassType::getConstant(size_t slot) const {
  TORCH_INTERNAL_ASSERT(constantNames_.size() == constantValues_.size());
  TORCH_CHECK(
      slot < constantValues_.size(),
      repr_str(),
      " does not have a constant slot of index ",
      slot);
  return constantValues_[slot];
}

std::optional<IValue> ClassType::findConstant(const std::string& name) const {
  TORCH_INTERNAL_ASSERT(constantNames_.size() == constantValues_.size());
  size_t pos = 0;
  for (const auto& c : constantNames_) {
    if (name == c) {
      break;
    }
    ++pos;
  }

  if (pos >= constantNames_.size()) {
    return std::nullopt;
  }
  return constantValues_[pos];
}

void ClassType::unsafeRemoveConstant(const std::string& name) {
  auto slot = getConstantSlot(name);
  constantNames_.erase(constantNames_.begin() + static_cast<std::ptrdiff_t>(slot));
  constantValues_.erase(constantValues_.begin() + static_cast<std::ptrdiff_t>(slot));
}

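// The compilation unit is held as a weak_ptr; these accessors lock it, so the
// returned shared_ptr may be null if the owning CompilationUnit has been
// destroyed.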
std::shared_ptr<CompilationUnit> ClassType::compilation_unit() {
  auto cu = compilation_unit_.lock();
  return cu;
}

std::shared_ptr<const CompilationUnit> ClassType::compilation_unit() const {
  auto cu = compilation_unit_.lock();
  return cu;
}

std::optional<ClassType::Property> ClassType::getProperty(const std::string& name) {
  for (auto& prop : properties_) {
    if (name == prop.name) {
      return prop;
    }
  }

  return std::nullopt;
}

void ClassType::addProperty(const std::string& name, torch::jit::Function* getter, torch::jit::Function* setter) {
  TORCH_INTERNAL_ASSERT(!getProperty(name), "Property named ", name, " already exists!");
  properties_.push_back({name, getter, setter});
}

std::optional<size_t> ClassType::findConstantSlot(const std::string& name) const {
  TORCH_CHECK(constantNames_.size() == constantValues_.size());
  size_t slot = 0;
  for (const auto& constant : constantNames_) {
    if (name == constant) {
      return slot;
    }
    slot++;
  }
  return std::nullopt;
}

const std::string& ClassType::getConstantName(size_t slot) const {
  TORCH_CHECK(constantNames_.size() == constantValues_.size());
  TORCH_CHECK(slot < constantNames_.size());
  return constantNames_[slot];
}

size_t ClassType::numConstants() const {
  TORCH_INTERNAL_ASSERT(constantNames_.size() == constantValues_.size());
  return constantNames_.size();
}

at::ArrayRef<IValue> ClassType::constantValues() const {
  return constantValues_;
}

} // namespace c10