#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/list_model_parameters.h>

namespace torch::jit {

namespace onnx {
using namespace ::c10::onnx;
}

// findSubModuleAttr function chases getAttr chains backwards to locate the
// submodules. For example: module M {
//    attributes {
//      A = <SubModule at ...>
//    }
//    ...
//    %A = prim::GetAttr[name="A"](%self)
//    ...
//    %B = prim::GetAttr[name="B"](%A)
//    ...
//    %weight = prim::GetAttr[name="scale"](%B)
//    ...
std::deque<std::string> findSubModuleAttr(
    Value* input,
    std::string& name,
    Module& attrModule,
    std::shared_ptr<Graph>& graph) {
  Node* node = input->node();
  std::deque<std::string> moduleNames;

  // The loop starts from the inner submodule and follows the chain until it
  // reaches the top module.

  while (node->outputs().at(0)->type() != graph->inputs().at(0)->type()) {
    if (node->kind() == prim::GetAttr) {
      moduleNames.push_front(node->s(attr::name));
      node = node->inputs()[0]->node();
    } else {
      return moduleNames;
    }
  }
  // Assign the inner module to attrModule.
  for (auto& moduleName : moduleNames) {
    attrModule = attrModule.attr(moduleName).toModule();
  }
  return moduleNames;
}

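// addParamAsArgument appends `name` as a new argument (with default value
// `attr`) to the function's schema and adds a matching graph input typed
// after the attribute, returning the new input Value.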
Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) {
  auto schema = function->getSchema();
  auto args = schema.arguments();
  args.emplace_back(name, nullptr, std::nullopt, attr);
  auto new_schema = FunctionSchema(
      schema.name(),
      schema.overload_name(),
      args,
      schema.returns(),
      schema.is_vararg(),
      schema.is_varret());
  function->setSchema(new_schema);
  return toGraphFunction(*function).graph()->addInput(name)->setType(
      attr.type());
}

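// getParamAttributes walks the block, resolves prim::GetAttr chains to the
// owning submodule, and turns referenced parameters, buffers, and supported
// attributes into graph inputs (via addParamAsArgument). None and `training`
// bool attributes are inlined as constants instead. The returned IValues
// match the order in which the inputs were added.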
std::vector<IValue> getParamAttributes(
    Block* block,
    std::shared_ptr<Graph>& graph,
    const Module& module_,
    Function* function_,
    std::unordered_map<std::string, Value*>& attrValues) {
  auto isEval = !module_.hasattr("training") || !module_.is_training();

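  // Constants materialized below are inserted before the first node of the
  // block.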
  Node* m = *block->nodes().begin();
  WithInsertPoint guard(m);

  std::vector<IValue> parameterIValues = {};
  std::unordered_set<Node*> nodesToDestroy;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    Node* n = *it;
    it++; // node n can be destroyed

    if (n->kind() == prim::GetAttr || n->kind() == prim::SetAttr) {
      if (n->kind() == prim::GetAttr) {
        for (auto use : n->output()->uses()) {
          if (use.user->kind() == prim::PythonOp)
            throw ErrorReport(n->sourceRange())
                << "Couldn't export Python method.";
        }
      }

      auto name = n->s(attr::name);
      auto attrModule = module_;
      auto input = n->inputs()[0];

      auto moduleNames = findSubModuleAttr(input, name, attrModule, graph);
      if (!attrModule.hasattr(name))
        continue;
      auto attr = attrModule.attr(name);
      Value* paramConst = nullptr;

      std::string fullName("");
      for (auto& name : moduleNames) {
        fullName += name + '.';
      }
      fullName += name;

      auto type = attrModule.type();
      auto slot = *type->findAttributeSlot(name);

      // Add model_parameters and model_buffers as model inputs. Order is
      // preserved based on the appearance in the graph.
      if (type->is_parameter(slot) || type->is_buffer(slot) ||
          (attr.isObject() && !attr.toObjectRef().type()->is_module()) ||
          attr.isBool()) {
        if (attrValues.find(fullName) == attrValues.end() &&
            attr.isTensor()) { // TODO: Handle float/int
          TORCH_INTERNAL_ASSERT(attr.isTensor());
          auto tensor_ = attr.toTensor();
          if (isEval && tensor_.requires_grad()) {
            tensor_ = tensor_.detach();
            tensor_.set_requires_grad(false);
            attr = IValue(tensor_);
          }
          parameterIValues.emplace_back(attr.toTensor());
          paramConst = addParamAsArgument(function_, fullName, attr);
          attrValues.insert({fullName, paramConst});
        } else if (attr.isObject() && !attr.toObjectRef().type()->is_module()) {
          // Only torch classes that implement __getstate__ are supported here.
          try {
            parameterIValues.emplace_back(
                script::Object(attr.toObject()).run_method("__getstate__"));
            paramConst = addParamAsArgument(function_, fullName, attr);
            attrValues.insert({fullName, paramConst});
          } catch (const std::exception&) {
            throw ErrorReport(n->sourceRange())
                << "Unknown type " << attr.type()->repr_str()
                << " encountered in handling model params."
                << " This class type does not extend __getstate__ method.";
          }
        } else if (attr.isNone() || (attr.isBool() && name == "training")) {
          // This attr is constant for ONNX.
          auto attrVal = tryInsertConstant(*graph, attr);
          n->output()->replaceAllUsesWith(*attrVal);
          nodesToDestroy.emplace(n);
        }
      }
    }

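    // Recurse into nested blocks (e.g. the bodies of prim::If and prim::Loop)
    // and collect the parameters referenced there as well.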
    for (Block* sub_block : n->blocks()) {
      auto nextParameterIValues =
          getParamAttributes(sub_block, graph, module_, function_, attrValues);
      parameterIValues.insert(
          std::end(parameterIValues),
          std::begin(nextParameterIValues),
          std::end(nextParameterIValues));
    }
  }
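  // Destruction is deferred until after the walk so the node iterator above
  // is never invalidated.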
  for (auto n : nodesToDestroy) {
    n->destroy();
  }
  return parameterIValues;
}

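// insertMainModuleAsConstant replaces the graph's `self` input with a freshly
// created object of the same type, so the graph no longer takes the module as
// an argument.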
void insertMainModuleAsConstant(const std::shared_ptr<Graph>& graph) {
  auto* constNode = graph->create(prim::CreateObject);
  constNode->output()->setType(graph->inputs().at(0)->type());
  auto it = graph->nodes().begin();
  constNode->insertBefore(*it);
  graph->inputs().at(0)->replaceAllUsesWith(constNode->output());
  graph->eraseInput(0);
}

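// list_module_parameters clones the module, rewrites its forward graph so
// that every referenced parameter/buffer becomes an explicit graph input, and
// returns the clone together with the matching IValues. A minimal usage
// sketch (assuming a scripted `module` is in scope):
//
//   auto [moduleClone, params] = list_module_parameters(module);
//   // moduleClone's forward graph now takes the params as trailing inputs,
//   // and `params` supplies their values in the same order.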
std::pair<Module, std::vector<IValue>> list_module_parameters(
    const Module& module) {
  Module moduleClone = module.clone(true);
  Method method = moduleClone.get_method("forward");
  auto function = &method.function();
  auto graph = toGraphFunction(*function).graph();
  // A map of names and values of referenced attributes, to avoid duplicates.
  std::unordered_map<std::string, Value*> attrValues = {};

  GRAPH_DEBUG("Fetch attributes for function: " + function->name());
  std::vector<IValue> parameterIValues = getParamAttributes(
      graph->block(), graph, moduleClone, function, attrValues);
  insertMainModuleAsConstant(graph);
  GRAPH_DEBUG("Listed parameters as inputs: ", *graph);

  return std::make_pair(moduleClone, parameterIValues);
}

} // namespace torch::jit