xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/list_model_parameters.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/error_report.h>
2 #include <torch/csrc/jit/jit_log.h>
3 #include <torch/csrc/jit/passes/dead_code_elimination.h>
4 #include <torch/csrc/jit/passes/onnx/helper.h>
5 #include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
6 
7 namespace torch::jit {
8 
9 namespace onnx {
10 using namespace ::c10::onnx;
11 }
12 
// findSubModuleAttr function chases getAttr chains backwards to locate the
// submodules. For example: module M {
//   attributes {
//     A = <SubModule at ...>
//   }
//   ...
//   %A = prim::GetAttr[name="A"](%self)
//   ...
//   %B = prim::GetAttr[name="B"](%A)
//   ...
//   %weight = prim::GetAttr[name="scale"](%B)
//   ...
findSubModuleAttr(Value * input,std::string & name,Module & attrModule,std::shared_ptr<Graph> & graph)25 std::deque<std::string> findSubModuleAttr(
26     Value* input,
27     std::string& name,
28     Module& attrModule,
29     std::shared_ptr<Graph>& graph) {
30   Node* node = input->node();
31   std::deque<std::string> moduleNames;
32 
33   // Loop starts from inner submodule and follows the chain until reaches the
34   // top module.
35 
36   while (node->outputs().at(0)->type() != graph->inputs().at(0)->type()) {
37     if (node->kind() == prim::GetAttr) {
38       moduleNames.push_front(node->s(attr::name));
39       node = node->inputs()[0]->node();
40     } else {
41       return moduleNames;
42     }
43   }
44   // Assign the inner module to attrModule.
45   for (auto& moduleName : moduleNames) {
46     attrModule = attrModule.attr(moduleName).toModule();
47   }
48   return moduleNames;
49 }
50 
addParamAsArgument(Function * function,std::string & name,IValue & attr)51 Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) {
52   auto schema = function->getSchema();
53   auto args = schema.arguments();
54   args.emplace_back(name, nullptr, std::nullopt, attr);
55   auto new_schema = FunctionSchema(
56       schema.name(),
57       schema.overload_name(),
58       args,
59       schema.returns(),
60       schema.is_vararg(),
61       schema.is_varret());
62   function->setSchema(new_schema);
63   return toGraphFunction(*function).graph()->addInput(name)->setType(
64       attr.type());
65 }
66 
getParamAttributes(Block * block,std::shared_ptr<Graph> & graph,const Module & module_,Function * function_,std::unordered_map<std::string,Value * > & attrValues)67 std::vector<IValue> getParamAttributes(
68     Block* block,
69     std::shared_ptr<Graph>& graph,
70     const Module& module_,
71     Function* function_,
72     std::unordered_map<std::string, Value*>& attrValues) {
73   auto isEval = !module_.hasattr("training") || !module_.is_training();
74 
75   Node* m = *block->nodes().begin();
76   WithInsertPoint guard(m);
77 
78   std::vector<IValue> parameterIValues = {};
79   std::unordered_set<Node*> nodesToDestroy;
80   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
81     Node* n = *it;
82     it++; // node n can be destroyed
83 
84     if (n->kind() == prim::GetAttr || n->kind() == prim::SetAttr) {
85       if (n->kind() == prim::GetAttr) {
86         for (auto use : n->output()->uses()) {
87           if (use.user->kind() == prim::PythonOp)
88             throw ErrorReport(n->sourceRange())
89                 << "Couldn't export Python method.";
90         }
91       }
92 
93       auto name = n->s(attr::name);
94       auto attrModule = module_;
95       auto input = n->inputs()[0];
96 
97       auto moduleNames = findSubModuleAttr(input, name, attrModule, graph);
98       if (!attrModule.hasattr(name))
99         continue;
100       auto attr = attrModule.attr(name);
101       Value* paramConst = nullptr;
102 
103       std::string fullName("");
104       for (auto& name : moduleNames) {
105         fullName += name + '.';
106       }
107       fullName += name;
108 
109       auto type = attrModule.type();
110       auto slot = *type->findAttributeSlot(name);
111 
112       // Add model_parameters and model_buffers as model inputs. Order is
113       // preserved based on the appearance in the graph.
114       if (type->is_parameter(slot) || type->is_buffer(slot) ||
115           (attr.isObject() && !attr.toObjectRef().type()->is_module()) ||
116           attr.isBool()) {
117         if (attrValues.find(fullName) == attrValues.end() &&
118             attr.isTensor()) { // TODO: Handle float/int
119           TORCH_INTERNAL_ASSERT(attr.isTensor());
120           auto tensor_ = attr.toTensor();
121           if (isEval && tensor_.requires_grad()) {
122             tensor_ = tensor_.detach();
123             tensor_.set_requires_grad(false);
124             attr = IValue(tensor_);
125           }
126           parameterIValues.emplace_back(attr.toTensor());
127           paramConst = addParamAsArgument(function_, fullName, attr);
128           attrValues.insert({fullName, paramConst});
129         } else if (attr.isObject() && !attr.toObjectRef().type()->is_module()) {
130           // Only below registered torch classes are supported.
131           try {
132             parameterIValues.emplace_back(
133                 script::Object(attr.toObject()).run_method("__getstate__"));
134             paramConst = addParamAsArgument(function_, fullName, attr);
135             attrValues.insert({fullName, paramConst});
136           } catch (const std::exception&) {
137             throw ErrorReport(n->sourceRange())
138                 << "Unknown type " << attr.type()->repr_str()
139                 << " encountered in handling model params."
140                 << " This class type does not extend __getstate__ method.";
141           }
142         } else if (attr.isNone() || (attr.isBool() && name == "training")) {
143           // This attr is constant for ONNX.
144           auto attrVal = tryInsertConstant(*graph, attr);
145           n->output()->replaceAllUsesWith(*attrVal);
146           nodesToDestroy.emplace(n);
147         }
148       }
149     }
150 
151     for (Block* sub_block : n->blocks()) {
152       auto nextParameterIValues =
153           getParamAttributes(sub_block, graph, module_, function_, attrValues);
154       parameterIValues.insert(
155           std::end(parameterIValues),
156           std::begin(nextParameterIValues),
157           std::end(nextParameterIValues));
158     }
159   }
160   for (auto n : nodesToDestroy) {
161     n->destroy();
162   }
163   return parameterIValues;
164 }
165 
insertMainModuleAsConstant(const std::shared_ptr<Graph> & graph)166 void insertMainModuleAsConstant(const std::shared_ptr<Graph>& graph) {
167   auto* constNode = graph->create(prim::CreateObject);
168   constNode->output()->setType(graph->inputs().at(0)->type());
169   auto it = graph->nodes().begin();
170   constNode->insertBefore(*it);
171   graph->inputs().at(0)->replaceAllUsesWith(constNode->output());
172   graph->eraseInput(0);
173 }
174 
list_module_parameters(const Module & module)175 std::pair<Module, std::vector<IValue>> list_module_parameters(
176     const Module& module) {
177   Module moduleClone = module.clone(true);
178   Method method = moduleClone.get_method("forward");
179   auto function = &method.function();
180   auto graph = toGraphFunction(*function).graph();
181   // A map of names and values of referenced attributes, to avoid duplicates.
182   std::unordered_map<std::string, Value*> attrValues = {};
183 
184   GRAPH_DEBUG("Fetch attributes for function: " + function->name());
185   std::vector<IValue> parameterIValues = getParamAttributes(
186       graph->block(), graph, moduleClone, function, attrValues);
187   insertMainModuleAsConstant(graph);
188   GRAPH_DEBUG("Listed parameters as inputs: ", *graph);
189 
190   return std::make_pair(moduleClone, parameterIValues);
191 }
192 
193 } // namespace torch::jit
194