xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/concrete_module_type.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/concrete_module_type.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/csrc/jit/python/pybind_utils.h>
5 
6 #include <iostream>
7 
8 namespace torch::jit {
9 
createTypeFromThis() const10 ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
11   auto cu = get_python_cu();
12   py::object pyQualName = py::module::import("torch._jit_internal")
13                               .attr("_qualified_name")(pyClass_);
14 
15   auto className = c10::QualifiedName(py::cast<std::string>(pyQualName));
16   if (className.prefix().empty()) {
17     className = c10::QualifiedName("__torch__", className.name());
18   }
19   if (cu->get_class(className) != nullptr) {
20     className = cu->mangle(className);
21   }
22   auto cls = ClassType::create(std::move(className), cu, /*is_module=*/true);
23   cu->register_type(cls);
24 
25   // populate type with info from the concrete type information
26   for (const auto& pr : attributes_) {
27     const auto& name = pr.key();
28     const auto& type = pr.value().type_;
29     const auto& isParameter = pr.value().isParam_;
30     const auto& isBuffer = pr.value().isBuffer_;
31     cls->addAttribute(name, type, isParameter, isBuffer);
32   }
33 
34   for (const auto& pr : constants_) {
35     cls->addConstant(pr.first, pr.second);
36   }
37 
38   for (const auto& moduleInfo : modules_) {
39     cls->addAttribute(
40         moduleInfo.name_,
41         moduleInfo.meta_->getJitType(),
42         /*is_parameter=*/false);
43   }
44 
45   return cls;
46 }
47 
fromJitType(TypePtr type)48 std::shared_ptr<ConcreteModuleType> ConcreteModuleType::fromJitType(
49     TypePtr type) {
50   ConcreteModuleTypeBuilder builder;
51   builder.setPoisoned();
52 
53   // `type` should either be a module interface or a class type
54   if (auto interface = type->cast<InterfaceType>()) {
55     TORCH_INTERNAL_ASSERT(interface->is_module());
56   } else {
57     const auto classType = type->expect<ClassType>();
58 
59     // Populate the builder metadata from the JIT type. This is to ensure
60     // ConcreteModuleTypes produced from Python and ones produced from a JIT
61     // type directly behave the same to the rest of the system.
62     for (const auto i : c10::irange(classType->numAttributes())) {
63       const auto& attrName = classType->getAttributeName(i);
64       const auto& attrType = classType->getAttribute(i);
65       if (attrType->is_module()) {
66         builder.addModule(attrName, ConcreteModuleType::fromJitType(attrType));
67       } else {
68         builder.addAttribute(
69             attrName,
70             attrType,
71             classType->is_parameter(i),
72             classType->is_buffer(i));
73       }
74     }
75 
76     for (const auto i : c10::irange(classType->numConstants())) {
77       builder.addConstant(
78           classType->getConstantName(i), classType->getConstant(i));
79     }
80   }
81 
82   // Not make_shared because the constructor is private.
83   auto ret = std::shared_ptr<ConcreteModuleType>(new ConcreteModuleType());
84   ret->jitType_ = std::move(type);
85   ret->data_ = builder;
86 
87   return ret;
88 }
89 
ConcreteModuleType(ConcreteModuleTypeBuilder data)90 ConcreteModuleType::ConcreteModuleType(ConcreteModuleTypeBuilder data)
91     : data_(std::move(data)) {
92   jitType_ = data_.createTypeFromThis();
93 }
94 
operator ==(const ConcreteModuleTypeBuilder::ModuleInfo & lhs,const ConcreteModuleTypeBuilder::ModuleInfo & rhs)95 bool operator==(
96     const ConcreteModuleTypeBuilder::ModuleInfo& lhs,
97     const ConcreteModuleTypeBuilder::ModuleInfo& rhs) {
98   return lhs.name_ == rhs.name_ && lhs.meta_->equals(*rhs.meta_);
99 }
100 
equals(const ConcreteModuleTypeBuilder & other) const101 bool ConcreteModuleTypeBuilder::equals(
102     const ConcreteModuleTypeBuilder& other) const {
103   if (isPoisoned_ || other.isPoisoned_) {
104     return false;
105   }
106 
107   // clang-format off
108     // These are vaguely ordered so that cheap, discriminating checks happen first.
109     bool equal =
110       pyClass_.is(other.pyClass_) &&
111       iterableModuleKind_ == other.iterableModuleKind_ &&
112       ignoredAttributes_ == other.ignoredAttributes_ &&
113       constants_ == other.constants_ &&
114       attributes_ == other.attributes_ &&
115       overloads_ == other.overloads_ &&
116       functionAttributes_ == other.functionAttributes_ &&
117       builtinFunctions_ == other.builtinFunctions_ &&
118       forwardHooks_ == other.forwardHooks_ &&
119       forwardPreHooks_ == other.forwardPreHooks_;
120   // clang-format on
121   if (!equal) {
122     return false;
123   }
124 
125   // We store modules in order of insertion (to make compilation
126   // deterministic). However, for the purposes of equality, insertion order
127   // should not matter, so sort them by name.
128   // We put this check last because it involves the most work.
129   auto thisSorted = modules_;
130   std::sort(
131       thisSorted.begin(),
132       thisSorted.end(),
133       [](const ModuleInfo& a, const ModuleInfo& b) {
134         return a.name_ < b.name_;
135       });
136 
137   auto otherSorted = other.modules_;
138   std::sort(
139       otherSorted.begin(),
140       otherSorted.end(),
141       [](const ModuleInfo& a, const ModuleInfo& b) {
142         return a.name_ < b.name_;
143       });
144 
145   return thisSorted == otherSorted;
146 }
147 
getJitType() const148 TypePtr ConcreteModuleType::getJitType() const {
149   return jitType_;
150 }
151 
getPyClass() const152 std::optional<py::object> ConcreteModuleType::getPyClass() const {
153   if (!data_.pyClass_) {
154     return std::nullopt;
155   }
156   return data_.pyClass_;
157 }
158 
findOverloads(const std::string & name) const159 std::optional<std::vector<std::string>> ConcreteModuleType::findOverloads(
160     const std::string& name) const {
161   const auto it = data_.overloads_.find(name);
162   if (it != data_.overloads_.end()) {
163     return it->second;
164   }
165   return std::nullopt;
166 }
167 
findFunctionAttribute(const std::string & name) const168 std::optional<Function*> ConcreteModuleType::findFunctionAttribute(
169     const std::string& name) const {
170   const auto it = data_.functionAttributes_.find(name);
171   if (it != data_.functionAttributes_.end()) {
172     return it->second.function_->function();
173   }
174   return std::nullopt;
175 }
176 
findBuiltinFunction(const std::string & name) const177 std::optional<c10::Symbol> ConcreteModuleType::findBuiltinFunction(
178     const std::string& name) const {
179   const auto it = data_.builtinFunctions_.find(name);
180   if (it != data_.builtinFunctions_.end()) {
181     return it->second;
182   }
183   return std::nullopt;
184 }
185 
findFailedAttribute(const std::string & name) const186 std::optional<std::string> ConcreteModuleType::findFailedAttribute(
187     const std::string& name) const {
188   const auto it = data_.failedAttributes_.find(name);
189   if (it != data_.failedAttributes_.end()) {
190     return it->second;
191   }
192   return std::nullopt;
193 }
194 
isIgnoredAttribute(const std::string & name) const195 bool ConcreteModuleType::isIgnoredAttribute(const std::string& name) const {
196   return data_.ignoredAttributes_.count(name) > 0;
197 }
198 
199 std::shared_ptr<ConcreteModuleType> ConcreteModuleType::
findSubmoduleConcreteType(const std::string & name) const200     findSubmoduleConcreteType(const std::string& name) const {
201   const auto it = std::find_if(
202       data_.modules_.cbegin(),
203       data_.modules_.cend(),
204       [&](const ConcreteModuleTypeBuilder::ModuleInfo& info) {
205         return info.name_ == name;
206       });
207   TORCH_INTERNAL_ASSERT(it != data_.modules_.end());
208   return it->meta_;
209 }
210 
setIterableModuleKind(IterableModuleKind kind)211 void ConcreteModuleTypeBuilder::setIterableModuleKind(IterableModuleKind kind) {
212   iterableModuleKind_ = kind;
213 }
214 
getIterableModuleKind() const215 IterableModuleKind ConcreteModuleType::getIterableModuleKind() const {
216   return data_.iterableModuleKind_;
217 }
218 
setPoisoned()219 void ConcreteModuleTypeBuilder::setPoisoned() {
220   isPoisoned_ = true;
221 }
222 
addConstant(std::string name,py::object value)223 void ConcreteModuleTypeBuilder::addConstant(
224     std::string name,
225     py::object value) {
226   auto match = tryToInferType(value);
227   if (!match.success()) {
228     TORCH_INTERNAL_ASSERT(
229         false,
230         "We need to infer the type of constant to convert the python value to IValue,"
231         " but failed to infer type of ",
232         py::str(value),
233         "\n:",
234         match.reason());
235   }
236   constants_.emplace(std::move(name), toIValue(std::move(value), match.type()));
237 }
238 
addConstant(std::string name,IValue value)239 void ConcreteModuleTypeBuilder::addConstant(std::string name, IValue value) {
240   constants_.emplace(std::move(name), std::move(value));
241 }
242 
addAttribute(std::string name,const TypePtr & type,bool isParameter,bool isBuffer)243 void ConcreteModuleTypeBuilder::addAttribute(
244     std::string name,
245     const TypePtr& type,
246     bool isParameter,
247     bool isBuffer) {
248   TORCH_INTERNAL_ASSERT(type);
249   // Function attributes should be handled separately
250   TORCH_INTERNAL_ASSERT(type->cast<FunctionType>() == nullptr);
251   attributes_.insert(
252       std::move(name),
253       ConcreteModuleTypeBuilder::Attribute(
254           unshapedType(type), isParameter, isBuffer));
255 }
256 
addFunctionAttribute(std::string name,const TypePtr & type,py::object pyFunction)257 void ConcreteModuleTypeBuilder::addFunctionAttribute(
258     std::string name,
259     const TypePtr& type,
260     py::object pyFunction) {
261   TORCH_INTERNAL_ASSERT(type);
262   functionAttributes_.emplace(
263       std::move(name),
264       ConcreteModuleTypeBuilder::FunctionAttribute{
265           type->expect<FunctionType>(), std::move(pyFunction)});
266 }
267 
addBuiltinFunction(std::string name,const std::string & symbol_name)268 void ConcreteModuleTypeBuilder::addBuiltinFunction(
269     std::string name,
270     const std::string& symbol_name) {
271   builtinFunctions_.emplace(
272       std::move(name), c10::Symbol::fromQualString(symbol_name));
273 }
274 
addModule(std::string name,std::shared_ptr<ConcreteModuleType> meta)275 void ConcreteModuleTypeBuilder::addModule(
276     std::string name,
277     std::shared_ptr<ConcreteModuleType> meta) {
278   modules_.emplace_back(std::move(name), std::move(meta));
279 }
280 
addForwardHook(py::object hook)281 void ConcreteModuleTypeBuilder::addForwardHook(py::object hook) {
282   forwardHooks_.emplace_back(std::move(hook));
283 }
284 
addForwardPreHook(py::object pre_hook)285 void ConcreteModuleTypeBuilder::addForwardPreHook(py::object pre_hook) {
286   forwardPreHooks_.emplace_back(std::move(pre_hook));
287 }
288 
addOverload(std::string methodName,std::vector<std::string> overloadedMethodNames)289 void ConcreteModuleTypeBuilder::addOverload(
290     std::string methodName,
291     std::vector<std::string> overloadedMethodNames) {
292   overloads_.emplace(std::move(methodName), std::move(overloadedMethodNames));
293 }
294 
addFailedAttribute(std::string name,std::string failureReason)295 void ConcreteModuleTypeBuilder::addFailedAttribute(
296     std::string name,
297     std::string failureReason) {
298   failedAttributes_.emplace(std::move(name), std::move(failureReason));
299 }
300 
addIgnoredAttribute(std::string name)301 void ConcreteModuleTypeBuilder::addIgnoredAttribute(std::string name) {
302   ignoredAttributes_.emplace(std::move(name));
303 }
304 
dump() const305 void ConcreteModuleType::dump() const {
306   std::cout << "ConcreteModuleType for: "
307             << py::getattr(data_.pyClass_, "__name__") << "\n";
308   std::cout << "Constants: \n";
309   for (const auto& pr : data_.constants_) {
310     std::cout << "\t" << pr.first << ": " << pr.second << "\n";
311   }
312   std::cout << "\nAttributes: \n";
313   for (const auto& pr : data_.attributes_) {
314     std::cout << "\t" << pr.key() << ": " << pr.value().type_->annotation_str()
315               << "\n";
316   }
317   std::cout << "\nSubmodules: \n";
318   for (const auto& info : data_.modules_) {
319     std::cout << "\t" << info.name_ << ": "
320               << info.meta_->getJitType()->annotation_str() << "\n";
321   }
322   std::cout << "\nForward Pre-Hooks: \n";
323   for (const auto& pre_hook_id : data_.forwardPreHooks_) {
324     std::cout << "\t"
325               << "pre_hook id: " << pre_hook_id << "\n";
326   }
327   std::cout << "\nForward Hooks: \n";
328   for (const auto& hook_id : data_.forwardHooks_) {
329     std::cout << "\t"
330               << "hook id: " << hook_id << "\n";
331   }
332   std::cout << "\nOverloads: \n";
333   for (const auto& pr : data_.overloads_) {
334     std::cout << "\t" << pr.first << ": " << pr.second << "\n";
335   }
336   std::string isPoisoned = data_.isPoisoned_ ? "true" : "false";
337   std::cout << "isPoisoned: " << isPoisoned << "\n";
338   if (jitType_) {
339     std::cout << "jit type: " << jitType_->annotation_str() << "\n";
340   }
341 }
342 
getConstantsPy() const343 std::unordered_map<std::string, py::object> ConcreteModuleType::getConstantsPy()
344     const {
345   // Convert to a more pybind-friendly representation, so we don't
346   // need to bind ConcreteModuleType::Constant as well.
347   std::unordered_map<std::string, py::object> ret;
348   for (const auto& pr : data_.constants_) {
349     ret.emplace(pr.first, toPyObject(pr.second));
350   }
351   return ret;
352 }
353 
354 std::unordered_map<std::string, std::pair<TypePtr, bool>> ConcreteModuleType::
getAttributesPy() const355     getAttributesPy() const {
356   // Convert to a more pybind-friendly representation, so we don't
357   // need to bind ConcreteModuleType::Attribute as well.
358   std::unordered_map<std::string, std::pair<TypePtr, bool>> ret;
359   for (auto& pr : data_.attributes_) {
360     ret.emplace(
361         pr.key(),
362         std::pair<TypePtr, bool>(pr.value().type_, pr.value().isParam_));
363   }
364   return ret;
365 }
366 
367 std::vector<std::pair<std::string, std::shared_ptr<ConcreteModuleType>>>
getModulesPy() const368 ConcreteModuleType::getModulesPy() const {
369   std::vector<std::pair<std::string, std::shared_ptr<ConcreteModuleType>>> ret;
370 
371   ret.reserve(data_.modules_.size());
372   for (const auto& info : data_.modules_) {
373     ret.emplace_back(info.name_, info.meta_);
374   }
375   return ret;
376 }
377 
378 } // namespace torch::jit
379