1 #pragma once 2 3 #include <ATen/core/ivalue.h> 4 #include <torch/csrc/jit/api/module.h> 5 #include <torch/csrc/jit/python/pybind_utils.h> 6 #include <memory> 7 #include <string> 8 #include <vector> 9 10 namespace torch::jit { 11 12 enum class IterableModuleKind { NONE, LIST, DICT, PARAMLIST, PARAMDICT }; 13 class ConcreteModuleType; 14 15 // You can think of an nn.Module as a template that corresponds to a family of 16 // JIT types. The template "arguments" are things like the constant values. 17 // e.g. 18 // class M(nn.Module): 19 // __constants__ = ["const"] 20 // ... 21 // 22 // Is similar to writing the following in C++: 23 // 24 // template<TConst> 25 // class M { 26 // ... 27 // } 28 // 29 // We need to consider each different member of the type family a different JIT 30 // type because, e.g. different constant values lead to different versions of 31 // the same method. 32 // 33 // ConcreteModuleType corresponds to a single member of the type family, with 34 // all template arguments fully specified. Two Modules that share a 35 // ConcreteModuleType can share a JIT type, and vice versa. 36 // 37 // Why not just use a JIT type to represent concrete types? Because constants, 38 // function attributes, etc. are currently not representable in the type system, 39 // so this acts a non-first-class way of tracking concrete types. 40 // 41 // ConcreteModuleType is also the source of truth for servicing all 42 // ModuleValue::attr calls. This is so we can guarantee that if two Module's 43 // share a JIT type (and thus a ConcreteModuleType), then they behave the same 44 // way when you access attributes on them. 45 46 // ConcreteModuleType has two phases. 47 // 1. Creation: First we build it up, during the ScriptModule conversion 48 // process. This is represented by ConcreteModuleTypeBuilder. 49 // ...then the converter calls ConcreteModuleTypeBuilder::build(), producing 50 // a 51 // ConcreteModuleType ready for querying. 52 // 2. Querying: We use ConcreteModuleType as a source of truth for 53 // ModuleValue::attr calls during method compilation. 54 55 // Represents a concrete type during in the process for construction. We use 56 // this to decide whether we can share types between modules. 57 class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { 58 public: ConcreteModuleTypeBuilder(py::object pyClass)59 explicit ConcreteModuleTypeBuilder(py::object pyClass) { 60 TORCH_INTERNAL_ASSERT(pyClass); 61 pyClass_ = std::move(pyClass); 62 } 63 64 void addConstant(std::string name, py::object value); 65 void addConstant(std::string name, IValue value); 66 void addAttribute( 67 std::string name, 68 const TypePtr& type, 69 bool isParameter, 70 bool isBuffer); 71 void addFunctionAttribute( 72 std::string name, 73 const TypePtr& type, 74 py::object pyFunction); 75 76 void addModule(std::string name, std::shared_ptr<ConcreteModuleType> meta); 77 78 void addForwardHook(py::object hook); 79 void addForwardPreHook(py::object pre_hook); 80 81 void addOverload( 82 std::string methodName, 83 std::vector<std::string> overloadedMethodNames); 84 void addBuiltinFunction(std::string name, const std::string& symbol_name); 85 void addFailedAttribute(std::string name, std::string failureReason); 86 void addIgnoredAttribute(std::string name); 87 void setIterableModuleKind(IterableModuleKind kind); 88 89 // If a ConcreteModuleType is poisoned, it will never compare equal to any 90 // other concrete type 91 void setPoisoned(); 92 build()93 std::shared_ptr<ConcreteModuleType> build() const { 94 return std::make_shared<ConcreteModuleType>(*this); 95 } 96 97 // This determines whether two modules can share a type. The container structs 98 // used by ConcreteModuleType have been defined such that operator== 99 // implements a meaningful comparison in that context. 100 bool equals(const ConcreteModuleTypeBuilder& other) const; 101 102 struct FunctionAttribute { 103 FunctionTypePtr function_; 104 py::object pyFunction_; 105 106 friend bool operator==( 107 const FunctionAttribute& lhs, 108 const FunctionAttribute& rhs) { 109 // Functions are not first class, so we can't do type comparison like a 110 // regular attribute. So we do a pointer equality check on the actual 111 // Python function object. 112 return lhs.pyFunction_.is(rhs.pyFunction_); 113 } 114 }; 115 116 struct Attribute { AttributeAttribute117 Attribute(TypePtr type, bool isParam, bool isBuffer) 118 : type_(std::move(type)), isParam_(isParam), isBuffer_(isBuffer) {} 119 120 friend bool operator==(const Attribute& lhs, const Attribute& rhs) { 121 return *(lhs.type_) == *(rhs.type_) && lhs.isParam_ == rhs.isParam_; 122 } 123 TypePtr type_; 124 bool isParam_; 125 bool isBuffer_; 126 }; 127 128 struct ModuleInfo { ModuleInfoModuleInfo129 ModuleInfo(std::string name, std::shared_ptr<ConcreteModuleType> meta) 130 : name_(std::move(name)), meta_(std::move(meta)) {} 131 132 friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs); 133 134 std::string name_; 135 std::shared_ptr<ConcreteModuleType> meta_; 136 }; 137 138 private: 139 ConcreteModuleTypeBuilder() = default; 140 ClassTypePtr createTypeFromThis() const; 141 142 // If true, this type will never compare equally to anything else. This is 143 // used if we want to ensure that this type is not shared (for example, if it 144 // came from a traced module) 145 bool isPoisoned_ = false; 146 147 // The value of any constants defined by the module. 148 std::unordered_map<std::string, IValue> constants_; 149 // The types of any attributes 150 OrderedDict<std::string, Attribute> attributes_; 151 // Overloads, in the same format as `__overloads__` in Python 152 std::unordered_map<std::string, std::vector<std::string>> overloads_; 153 // Any attributes we failed to convert to TorchScript, along with a hint as to 154 // why 155 std::unordered_map<std::string, std::string> failedAttributes_; 156 // Any attributes that were marked as ignored. They cannot be used in 157 // TorchScript but can still be used in ignored function in Python. 158 std::unordered_set<std::string> ignoredAttributes_; 159 // Any function attributes. These are special right now because functions are 160 // not first-class in the type system. 161 std::unordered_map<std::string, FunctionAttribute> functionAttributes_; 162 // Function attributes that are calls to builtin functions. These get 163 // de-sugared directly into the corresponding aten:: call. The map is 164 // attribute name -> aten symbol name 165 std::unordered_map<std::string, c10::Symbol> builtinFunctions_; 166 // The concrete types of any submodules 167 std::vector<ModuleInfo> modules_; 168 // Hooks to be called before/after forward when the module 169 // is called directly. Used to ensure modules have different types 170 // when they have different python hooks 171 // Actual hooks are added to ClassType directly during compilation 172 std::vector<py::object> forwardHooks_; 173 std::vector<py::object> forwardPreHooks_; 174 175 // If something is a ModuleDict/ModuleList, it means: 176 // 1. The order of the submodules matters for comparing the type 177 // 2. The compiler is allowed to treat it like a dict/tuple 178 IterableModuleKind iterableModuleKind_ = IterableModuleKind::NONE; 179 180 // The original `nn.Module` class that we derived this ScriptModule from. 181 py::object pyClass_; 182 183 // NOTE: If you ever add any more state to this struct, you need to make sure 184 // operator== still makes sense! 185 friend ConcreteModuleType; 186 }; 187 188 // Represents a finalized concrete type, used to service ModuleValue::attr calls 189 // during method compilation. 190 class VISIBILITY_HIDDEN ConcreteModuleType { 191 public: 192 explicit ConcreteModuleType(ConcreteModuleTypeBuilder data); 193 194 static std::shared_ptr<ConcreteModuleType> fromJitType(TypePtr type); 195 196 TypePtr getJitType() const; 197 std::optional<py::object> getPyClass() const; 198 IterableModuleKind getIterableModuleKind() const; 199 std::optional<std::vector<std::string>> findOverloads( 200 const std::string& name) const; 201 std::optional<Function*> findFunctionAttribute(const std::string& name) const; 202 std::optional<c10::Symbol> findBuiltinFunction(const std::string& name) const; 203 std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType( 204 const std::string& name) const; 205 std::optional<std::string> findFailedAttribute(const std::string& name) const; 206 bool isIgnoredAttribute(const std::string& name) const; 207 208 // These getters are only here to return things as types that can be 209 // automatically converted by pybind. 210 std::unordered_map<std::string, py::object> getConstantsPy() const; 211 std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy() 212 const; 213 std::vector<std::pair<std::string, std::shared_ptr<ConcreteModuleType>>> 214 getModulesPy() const; 215 equals(const ConcreteModuleType & other)216 bool equals(const ConcreteModuleType& other) const { 217 if (jitType_ == other.jitType_) { 218 // If the computed types are the same, these modules can (obviously) share 219 // a type. 220 return true; 221 } 222 223 return data_.equals(other.data_); 224 } equals(const ConcreteModuleTypeBuilder & other)225 bool equals(const ConcreteModuleTypeBuilder& other) const { 226 return data_.equals(other); 227 } 228 229 void dump() const; 230 231 private: 232 ConcreteModuleType() = default; 233 234 // The JIT type derived from this ConcreteModuleType. 235 ConcreteModuleTypeBuilder data_; 236 TypePtr jitType_; 237 }; 238 239 } // namespace torch::jit 240