xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/concrete_module_type.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ivalue.h>
4 #include <torch/csrc/jit/api/module.h>
5 #include <torch/csrc/jit/python/pybind_utils.h>
6 #include <memory>
7 #include <string>
8 #include <vector>
9 
10 namespace torch::jit {
11 
// Classifies whether a module acts as an iterable container of submodules
// and, if so, which flavor. NOTE(review): the names suggest these map to
// nn.ModuleList / nn.ModuleDict / nn.ParameterList / nn.ParameterDict (and
// NONE for ordinary modules) — confirm against the Python-side builder.
enum class IterableModuleKind {
  NONE,
  LIST,
  DICT,
  PARAMLIST,
  PARAMDICT,
};

// Forward declaration: ConcreteModuleTypeBuilder refers to ConcreteModuleType
// before the latter is defined.
class ConcreteModuleType;
14 
15 // You can think of an nn.Module as a template that corresponds to a family of
16 // JIT types. The template "arguments" are things like the constant values.
17 // e.g.
18 //   class M(nn.Module):
19 //        __constants__ = ["const"]
20 //        ...
21 //
22 // Is similar to writing the following in C++:
23 //
24 //    template<TConst>
25 //    class M {
26 //       ...
27 //    }
28 //
29 // We need to consider each different member of the type family a different JIT
30 // type because, e.g. different constant values lead to different versions of
31 // the same method.
32 //
33 // ConcreteModuleType corresponds to a single member of the type family, with
34 // all template arguments fully specified. Two Modules that share a
35 // ConcreteModuleType can share a JIT type, and vice versa.
36 //
// Why not just use a JIT type to represent concrete types? Because constants,
// function attributes, etc. are currently not representable in the type system,
// so this acts as a non-first-class way of tracking concrete types.
40 //
// ConcreteModuleType is also the source of truth for servicing all
// ModuleValue::attr calls. This is so we can guarantee that if two Modules
// share a JIT type (and thus a ConcreteModuleType), then they behave the same
// way when you access attributes on them.
45 
// ConcreteModuleType has two phases.
// 1. Creation: First we build it up during the ScriptModule conversion
//    process. This is represented by ConcreteModuleTypeBuilder.
//    ...then the converter calls ConcreteModuleTypeBuilder::build(), producing
//    a ConcreteModuleType ready for querying.
// 2. Querying: We use ConcreteModuleType as a source of truth for
//    ModuleValue::attr calls during method compilation.
54 
55 // Represents a concrete type during in the process for construction. We use
56 // this to decide whether we can share types between modules.
57 class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
58  public:
ConcreteModuleTypeBuilder(py::object pyClass)59   explicit ConcreteModuleTypeBuilder(py::object pyClass) {
60     TORCH_INTERNAL_ASSERT(pyClass);
61     pyClass_ = std::move(pyClass);
62   }
63 
64   void addConstant(std::string name, py::object value);
65   void addConstant(std::string name, IValue value);
66   void addAttribute(
67       std::string name,
68       const TypePtr& type,
69       bool isParameter,
70       bool isBuffer);
71   void addFunctionAttribute(
72       std::string name,
73       const TypePtr& type,
74       py::object pyFunction);
75 
76   void addModule(std::string name, std::shared_ptr<ConcreteModuleType> meta);
77 
78   void addForwardHook(py::object hook);
79   void addForwardPreHook(py::object pre_hook);
80 
81   void addOverload(
82       std::string methodName,
83       std::vector<std::string> overloadedMethodNames);
84   void addBuiltinFunction(std::string name, const std::string& symbol_name);
85   void addFailedAttribute(std::string name, std::string failureReason);
86   void addIgnoredAttribute(std::string name);
87   void setIterableModuleKind(IterableModuleKind kind);
88 
89   // If a ConcreteModuleType is poisoned, it will never compare equal to any
90   // other concrete type
91   void setPoisoned();
92 
build()93   std::shared_ptr<ConcreteModuleType> build() const {
94     return std::make_shared<ConcreteModuleType>(*this);
95   }
96 
97   // This determines whether two modules can share a type. The container structs
98   // used by ConcreteModuleType have been defined such that operator==
99   // implements a meaningful comparison in that context.
100   bool equals(const ConcreteModuleTypeBuilder& other) const;
101 
102   struct FunctionAttribute {
103     FunctionTypePtr function_;
104     py::object pyFunction_;
105 
106     friend bool operator==(
107         const FunctionAttribute& lhs,
108         const FunctionAttribute& rhs) {
109       // Functions are not first class, so we can't do type comparison like a
110       // regular attribute. So we do a pointer equality check on the actual
111       // Python function object.
112       return lhs.pyFunction_.is(rhs.pyFunction_);
113     }
114   };
115 
116   struct Attribute {
AttributeAttribute117     Attribute(TypePtr type, bool isParam, bool isBuffer)
118         : type_(std::move(type)), isParam_(isParam), isBuffer_(isBuffer) {}
119 
120     friend bool operator==(const Attribute& lhs, const Attribute& rhs) {
121       return *(lhs.type_) == *(rhs.type_) && lhs.isParam_ == rhs.isParam_;
122     }
123     TypePtr type_;
124     bool isParam_;
125     bool isBuffer_;
126   };
127 
128   struct ModuleInfo {
ModuleInfoModuleInfo129     ModuleInfo(std::string name, std::shared_ptr<ConcreteModuleType> meta)
130         : name_(std::move(name)), meta_(std::move(meta)) {}
131 
132     friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs);
133 
134     std::string name_;
135     std::shared_ptr<ConcreteModuleType> meta_;
136   };
137 
138  private:
139   ConcreteModuleTypeBuilder() = default;
140   ClassTypePtr createTypeFromThis() const;
141 
142   // If true, this type will never compare equally to anything else. This is
143   // used if we want to ensure that this type is not shared (for example, if it
144   // came from a traced module)
145   bool isPoisoned_ = false;
146 
147   // The value of any constants defined by the module.
148   std::unordered_map<std::string, IValue> constants_;
149   // The types of any attributes
150   OrderedDict<std::string, Attribute> attributes_;
151   // Overloads, in the same format as `__overloads__` in Python
152   std::unordered_map<std::string, std::vector<std::string>> overloads_;
153   // Any attributes we failed to convert to TorchScript, along with a hint as to
154   // why
155   std::unordered_map<std::string, std::string> failedAttributes_;
156   // Any attributes that were marked as ignored. They cannot be used in
157   // TorchScript but can still be used in ignored function in Python.
158   std::unordered_set<std::string> ignoredAttributes_;
159   // Any function attributes. These are special right now because functions are
160   // not first-class in the type system.
161   std::unordered_map<std::string, FunctionAttribute> functionAttributes_;
162   // Function attributes that are calls to builtin functions. These get
163   // de-sugared directly into the corresponding aten:: call. The map is
164   // attribute name -> aten symbol name
165   std::unordered_map<std::string, c10::Symbol> builtinFunctions_;
166   // The concrete types of any submodules
167   std::vector<ModuleInfo> modules_;
168   // Hooks to be called before/after forward when the module
169   // is called directly. Used to ensure modules have different types
170   // when they have different python hooks
171   // Actual hooks are added to ClassType directly during compilation
172   std::vector<py::object> forwardHooks_;
173   std::vector<py::object> forwardPreHooks_;
174 
175   // If something is a ModuleDict/ModuleList, it means:
176   //   1. The order of the submodules matters for comparing the type
177   //   2. The compiler is allowed to treat it like a dict/tuple
178   IterableModuleKind iterableModuleKind_ = IterableModuleKind::NONE;
179 
180   // The original `nn.Module` class that we derived this ScriptModule from.
181   py::object pyClass_;
182 
183   // NOTE: If you ever add any more state to this struct, you need to make sure
184   // operator== still makes sense!
185   friend ConcreteModuleType;
186 };
187 
188 // Represents a finalized concrete type, used to service ModuleValue::attr calls
189 // during method compilation.
190 class VISIBILITY_HIDDEN ConcreteModuleType {
191  public:
192   explicit ConcreteModuleType(ConcreteModuleTypeBuilder data);
193 
194   static std::shared_ptr<ConcreteModuleType> fromJitType(TypePtr type);
195 
196   TypePtr getJitType() const;
197   std::optional<py::object> getPyClass() const;
198   IterableModuleKind getIterableModuleKind() const;
199   std::optional<std::vector<std::string>> findOverloads(
200       const std::string& name) const;
201   std::optional<Function*> findFunctionAttribute(const std::string& name) const;
202   std::optional<c10::Symbol> findBuiltinFunction(const std::string& name) const;
203   std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
204       const std::string& name) const;
205   std::optional<std::string> findFailedAttribute(const std::string& name) const;
206   bool isIgnoredAttribute(const std::string& name) const;
207 
208   // These getters are only here to return things as types that can be
209   // automatically converted by pybind.
210   std::unordered_map<std::string, py::object> getConstantsPy() const;
211   std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy()
212       const;
213   std::vector<std::pair<std::string, std::shared_ptr<ConcreteModuleType>>>
214   getModulesPy() const;
215 
equals(const ConcreteModuleType & other)216   bool equals(const ConcreteModuleType& other) const {
217     if (jitType_ == other.jitType_) {
218       // If the computed types are the same, these modules can (obviously) share
219       // a type.
220       return true;
221     }
222 
223     return data_.equals(other.data_);
224   }
equals(const ConcreteModuleTypeBuilder & other)225   bool equals(const ConcreteModuleTypeBuilder& other) const {
226     return data_.equals(other);
227   }
228 
229   void dump() const;
230 
231  private:
232   ConcreteModuleType() = default;
233 
234   // The JIT type derived from this ConcreteModuleType.
235   ConcreteModuleTypeBuilder data_;
236   TypePtr jitType_;
237 };
238 
239 } // namespace torch::jit
240