#include <torch/csrc/jit/frontend/concrete_module_type.h>

#include <c10/util/irange.h>
#include <torch/csrc/jit/python/pybind_utils.h>

#include <algorithm>
#include <iostream>
7
8 namespace torch::jit {
9
createTypeFromThis() const10 ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
11 auto cu = get_python_cu();
12 py::object pyQualName = py::module::import("torch._jit_internal")
13 .attr("_qualified_name")(pyClass_);
14
15 auto className = c10::QualifiedName(py::cast<std::string>(pyQualName));
16 if (className.prefix().empty()) {
17 className = c10::QualifiedName("__torch__", className.name());
18 }
19 if (cu->get_class(className) != nullptr) {
20 className = cu->mangle(className);
21 }
22 auto cls = ClassType::create(std::move(className), cu, /*is_module=*/true);
23 cu->register_type(cls);
24
25 // populate type with info from the concrete type information
26 for (const auto& pr : attributes_) {
27 const auto& name = pr.key();
28 const auto& type = pr.value().type_;
29 const auto& isParameter = pr.value().isParam_;
30 const auto& isBuffer = pr.value().isBuffer_;
31 cls->addAttribute(name, type, isParameter, isBuffer);
32 }
33
34 for (const auto& pr : constants_) {
35 cls->addConstant(pr.first, pr.second);
36 }
37
38 for (const auto& moduleInfo : modules_) {
39 cls->addAttribute(
40 moduleInfo.name_,
41 moduleInfo.meta_->getJitType(),
42 /*is_parameter=*/false);
43 }
44
45 return cls;
46 }
47
fromJitType(TypePtr type)48 std::shared_ptr<ConcreteModuleType> ConcreteModuleType::fromJitType(
49 TypePtr type) {
50 ConcreteModuleTypeBuilder builder;
51 builder.setPoisoned();
52
53 // `type` should either be a module interface or a class type
54 if (auto interface = type->cast<InterfaceType>()) {
55 TORCH_INTERNAL_ASSERT(interface->is_module());
56 } else {
57 const auto classType = type->expect<ClassType>();
58
59 // Populate the builder metadata from the JIT type. This is to ensure
60 // ConcreteModuleTypes produced from Python and ones produced from a JIT
61 // type directly behave the same to the rest of the system.
62 for (const auto i : c10::irange(classType->numAttributes())) {
63 const auto& attrName = classType->getAttributeName(i);
64 const auto& attrType = classType->getAttribute(i);
65 if (attrType->is_module()) {
66 builder.addModule(attrName, ConcreteModuleType::fromJitType(attrType));
67 } else {
68 builder.addAttribute(
69 attrName,
70 attrType,
71 classType->is_parameter(i),
72 classType->is_buffer(i));
73 }
74 }
75
76 for (const auto i : c10::irange(classType->numConstants())) {
77 builder.addConstant(
78 classType->getConstantName(i), classType->getConstant(i));
79 }
80 }
81
82 // Not make_shared because the constructor is private.
83 auto ret = std::shared_ptr<ConcreteModuleType>(new ConcreteModuleType());
84 ret->jitType_ = std::move(type);
85 ret->data_ = builder;
86
87 return ret;
88 }
89
ConcreteModuleType(ConcreteModuleTypeBuilder data)90 ConcreteModuleType::ConcreteModuleType(ConcreteModuleTypeBuilder data)
91 : data_(std::move(data)) {
92 jitType_ = data_.createTypeFromThis();
93 }
94
operator ==(const ConcreteModuleTypeBuilder::ModuleInfo & lhs,const ConcreteModuleTypeBuilder::ModuleInfo & rhs)95 bool operator==(
96 const ConcreteModuleTypeBuilder::ModuleInfo& lhs,
97 const ConcreteModuleTypeBuilder::ModuleInfo& rhs) {
98 return lhs.name_ == rhs.name_ && lhs.meta_->equals(*rhs.meta_);
99 }
100
equals(const ConcreteModuleTypeBuilder & other) const101 bool ConcreteModuleTypeBuilder::equals(
102 const ConcreteModuleTypeBuilder& other) const {
103 if (isPoisoned_ || other.isPoisoned_) {
104 return false;
105 }
106
107 // clang-format off
108 // These are vaguely ordered so that cheap, discriminating checks happen first.
109 bool equal =
110 pyClass_.is(other.pyClass_) &&
111 iterableModuleKind_ == other.iterableModuleKind_ &&
112 ignoredAttributes_ == other.ignoredAttributes_ &&
113 constants_ == other.constants_ &&
114 attributes_ == other.attributes_ &&
115 overloads_ == other.overloads_ &&
116 functionAttributes_ == other.functionAttributes_ &&
117 builtinFunctions_ == other.builtinFunctions_ &&
118 forwardHooks_ == other.forwardHooks_ &&
119 forwardPreHooks_ == other.forwardPreHooks_;
120 // clang-format on
121 if (!equal) {
122 return false;
123 }
124
125 // We store modules in order of insertion (to make compilation
126 // deterministic). However, for the purposes of equality, insertion order
127 // should not matter, so sort them by name.
128 // We put this check last because it involves the most work.
129 auto thisSorted = modules_;
130 std::sort(
131 thisSorted.begin(),
132 thisSorted.end(),
133 [](const ModuleInfo& a, const ModuleInfo& b) {
134 return a.name_ < b.name_;
135 });
136
137 auto otherSorted = other.modules_;
138 std::sort(
139 otherSorted.begin(),
140 otherSorted.end(),
141 [](const ModuleInfo& a, const ModuleInfo& b) {
142 return a.name_ < b.name_;
143 });
144
145 return thisSorted == otherSorted;
146 }
147
getJitType() const148 TypePtr ConcreteModuleType::getJitType() const {
149 return jitType_;
150 }
151
getPyClass() const152 std::optional<py::object> ConcreteModuleType::getPyClass() const {
153 if (!data_.pyClass_) {
154 return std::nullopt;
155 }
156 return data_.pyClass_;
157 }
158
findOverloads(const std::string & name) const159 std::optional<std::vector<std::string>> ConcreteModuleType::findOverloads(
160 const std::string& name) const {
161 const auto it = data_.overloads_.find(name);
162 if (it != data_.overloads_.end()) {
163 return it->second;
164 }
165 return std::nullopt;
166 }
167
findFunctionAttribute(const std::string & name) const168 std::optional<Function*> ConcreteModuleType::findFunctionAttribute(
169 const std::string& name) const {
170 const auto it = data_.functionAttributes_.find(name);
171 if (it != data_.functionAttributes_.end()) {
172 return it->second.function_->function();
173 }
174 return std::nullopt;
175 }
176
findBuiltinFunction(const std::string & name) const177 std::optional<c10::Symbol> ConcreteModuleType::findBuiltinFunction(
178 const std::string& name) const {
179 const auto it = data_.builtinFunctions_.find(name);
180 if (it != data_.builtinFunctions_.end()) {
181 return it->second;
182 }
183 return std::nullopt;
184 }
185
findFailedAttribute(const std::string & name) const186 std::optional<std::string> ConcreteModuleType::findFailedAttribute(
187 const std::string& name) const {
188 const auto it = data_.failedAttributes_.find(name);
189 if (it != data_.failedAttributes_.end()) {
190 return it->second;
191 }
192 return std::nullopt;
193 }
194
isIgnoredAttribute(const std::string & name) const195 bool ConcreteModuleType::isIgnoredAttribute(const std::string& name) const {
196 return data_.ignoredAttributes_.count(name) > 0;
197 }
198
199 std::shared_ptr<ConcreteModuleType> ConcreteModuleType::
findSubmoduleConcreteType(const std::string & name) const200 findSubmoduleConcreteType(const std::string& name) const {
201 const auto it = std::find_if(
202 data_.modules_.cbegin(),
203 data_.modules_.cend(),
204 [&](const ConcreteModuleTypeBuilder::ModuleInfo& info) {
205 return info.name_ == name;
206 });
207 TORCH_INTERNAL_ASSERT(it != data_.modules_.end());
208 return it->meta_;
209 }
210
setIterableModuleKind(IterableModuleKind kind)211 void ConcreteModuleTypeBuilder::setIterableModuleKind(IterableModuleKind kind) {
212 iterableModuleKind_ = kind;
213 }
214
getIterableModuleKind() const215 IterableModuleKind ConcreteModuleType::getIterableModuleKind() const {
216 return data_.iterableModuleKind_;
217 }
218
setPoisoned()219 void ConcreteModuleTypeBuilder::setPoisoned() {
220 isPoisoned_ = true;
221 }
222
addConstant(std::string name,py::object value)223 void ConcreteModuleTypeBuilder::addConstant(
224 std::string name,
225 py::object value) {
226 auto match = tryToInferType(value);
227 if (!match.success()) {
228 TORCH_INTERNAL_ASSERT(
229 false,
230 "We need to infer the type of constant to convert the python value to IValue,"
231 " but failed to infer type of ",
232 py::str(value),
233 "\n:",
234 match.reason());
235 }
236 constants_.emplace(std::move(name), toIValue(std::move(value), match.type()));
237 }
238
addConstant(std::string name,IValue value)239 void ConcreteModuleTypeBuilder::addConstant(std::string name, IValue value) {
240 constants_.emplace(std::move(name), std::move(value));
241 }
242
addAttribute(std::string name,const TypePtr & type,bool isParameter,bool isBuffer)243 void ConcreteModuleTypeBuilder::addAttribute(
244 std::string name,
245 const TypePtr& type,
246 bool isParameter,
247 bool isBuffer) {
248 TORCH_INTERNAL_ASSERT(type);
249 // Function attributes should be handled separately
250 TORCH_INTERNAL_ASSERT(type->cast<FunctionType>() == nullptr);
251 attributes_.insert(
252 std::move(name),
253 ConcreteModuleTypeBuilder::Attribute(
254 unshapedType(type), isParameter, isBuffer));
255 }
256
addFunctionAttribute(std::string name,const TypePtr & type,py::object pyFunction)257 void ConcreteModuleTypeBuilder::addFunctionAttribute(
258 std::string name,
259 const TypePtr& type,
260 py::object pyFunction) {
261 TORCH_INTERNAL_ASSERT(type);
262 functionAttributes_.emplace(
263 std::move(name),
264 ConcreteModuleTypeBuilder::FunctionAttribute{
265 type->expect<FunctionType>(), std::move(pyFunction)});
266 }
267
addBuiltinFunction(std::string name,const std::string & symbol_name)268 void ConcreteModuleTypeBuilder::addBuiltinFunction(
269 std::string name,
270 const std::string& symbol_name) {
271 builtinFunctions_.emplace(
272 std::move(name), c10::Symbol::fromQualString(symbol_name));
273 }
274
addModule(std::string name,std::shared_ptr<ConcreteModuleType> meta)275 void ConcreteModuleTypeBuilder::addModule(
276 std::string name,
277 std::shared_ptr<ConcreteModuleType> meta) {
278 modules_.emplace_back(std::move(name), std::move(meta));
279 }
280
addForwardHook(py::object hook)281 void ConcreteModuleTypeBuilder::addForwardHook(py::object hook) {
282 forwardHooks_.emplace_back(std::move(hook));
283 }
284
addForwardPreHook(py::object pre_hook)285 void ConcreteModuleTypeBuilder::addForwardPreHook(py::object pre_hook) {
286 forwardPreHooks_.emplace_back(std::move(pre_hook));
287 }
288
addOverload(std::string methodName,std::vector<std::string> overloadedMethodNames)289 void ConcreteModuleTypeBuilder::addOverload(
290 std::string methodName,
291 std::vector<std::string> overloadedMethodNames) {
292 overloads_.emplace(std::move(methodName), std::move(overloadedMethodNames));
293 }
294
addFailedAttribute(std::string name,std::string failureReason)295 void ConcreteModuleTypeBuilder::addFailedAttribute(
296 std::string name,
297 std::string failureReason) {
298 failedAttributes_.emplace(std::move(name), std::move(failureReason));
299 }
300
addIgnoredAttribute(std::string name)301 void ConcreteModuleTypeBuilder::addIgnoredAttribute(std::string name) {
302 ignoredAttributes_.emplace(std::move(name));
303 }
304
dump() const305 void ConcreteModuleType::dump() const {
306 std::cout << "ConcreteModuleType for: "
307 << py::getattr(data_.pyClass_, "__name__") << "\n";
308 std::cout << "Constants: \n";
309 for (const auto& pr : data_.constants_) {
310 std::cout << "\t" << pr.first << ": " << pr.second << "\n";
311 }
312 std::cout << "\nAttributes: \n";
313 for (const auto& pr : data_.attributes_) {
314 std::cout << "\t" << pr.key() << ": " << pr.value().type_->annotation_str()
315 << "\n";
316 }
317 std::cout << "\nSubmodules: \n";
318 for (const auto& info : data_.modules_) {
319 std::cout << "\t" << info.name_ << ": "
320 << info.meta_->getJitType()->annotation_str() << "\n";
321 }
322 std::cout << "\nForward Pre-Hooks: \n";
323 for (const auto& pre_hook_id : data_.forwardPreHooks_) {
324 std::cout << "\t"
325 << "pre_hook id: " << pre_hook_id << "\n";
326 }
327 std::cout << "\nForward Hooks: \n";
328 for (const auto& hook_id : data_.forwardHooks_) {
329 std::cout << "\t"
330 << "hook id: " << hook_id << "\n";
331 }
332 std::cout << "\nOverloads: \n";
333 for (const auto& pr : data_.overloads_) {
334 std::cout << "\t" << pr.first << ": " << pr.second << "\n";
335 }
336 std::string isPoisoned = data_.isPoisoned_ ? "true" : "false";
337 std::cout << "isPoisoned: " << isPoisoned << "\n";
338 if (jitType_) {
339 std::cout << "jit type: " << jitType_->annotation_str() << "\n";
340 }
341 }
342
getConstantsPy() const343 std::unordered_map<std::string, py::object> ConcreteModuleType::getConstantsPy()
344 const {
345 // Convert to a more pybind-friendly representation, so we don't
346 // need to bind ConcreteModuleType::Constant as well.
347 std::unordered_map<std::string, py::object> ret;
348 for (const auto& pr : data_.constants_) {
349 ret.emplace(pr.first, toPyObject(pr.second));
350 }
351 return ret;
352 }
353
354 std::unordered_map<std::string, std::pair<TypePtr, bool>> ConcreteModuleType::
getAttributesPy() const355 getAttributesPy() const {
356 // Convert to a more pybind-friendly representation, so we don't
357 // need to bind ConcreteModuleType::Attribute as well.
358 std::unordered_map<std::string, std::pair<TypePtr, bool>> ret;
359 for (auto& pr : data_.attributes_) {
360 ret.emplace(
361 pr.key(),
362 std::pair<TypePtr, bool>(pr.value().type_, pr.value().isParam_));
363 }
364 return ret;
365 }
366
367 std::vector<std::pair<std::string, std::shared_ptr<ConcreteModuleType>>>
getModulesPy() const368 ConcreteModuleType::getModulesPy() const {
369 std::vector<std::pair<std::string, std::shared_ptr<ConcreteModuleType>>> ret;
370
371 ret.reserve(data_.modules_.size());
372 for (const auto& info : data_.modules_) {
373 ret.emplace_back(info.name_, info.meta_);
374 }
375 return ret;
376 }
377
378 } // namespace torch::jit
379