xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/freeze_module.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/freeze_module.h>
2 
3 #include <torch/csrc/jit/jit_log.h>
4 
5 #include <c10/util/irange.h>
6 #include <torch/csrc/jit/api/function_impl.h>
7 #include <torch/csrc/jit/ir/alias_analysis.h>
8 #include <torch/csrc/jit/passes/autocast.h>
9 #include <torch/csrc/jit/passes/clear_profiling.h>
10 #include <torch/csrc/jit/passes/eliminate_no_ops.h>
11 #include <torch/csrc/jit/passes/inliner.h>
12 #include <torch/csrc/jit/passes/lower_tuples.h>
13 #include <torch/csrc/jit/runtime/graph_executor_impl.h>
14 
15 #include <stack>
16 #include <utility>
17 
18 namespace torch::jit {
19 
20 namespace {
21 
// Splits a dotted qualified name ("a.b.c") into its '.'-separated components.
// Tokenization matches std::getline: an empty input yields no components,
// consecutive dots yield empty components, and a single trailing '.' does not
// produce a trailing empty component.
std::vector<std::string> splitName(const std::string& name) {
  std::vector<std::string> parts;
  std::string::size_type start = 0;
  while (start < name.size()) {
    const std::string::size_type dot = name.find('.', start);
    if (dot == std::string::npos) {
      // Last component: everything up to the end of the string.
      parts.push_back(name.substr(start));
      break;
    }
    parts.push_back(name.substr(start, dot - start));
    start = dot + 1;
  }
  return parts;
}
31 
// Joins the strings in [begin, end) with '.'.
// A separator is only written once the accumulated result is non-empty, so
// leading empty components contribute nothing (not even a dot).
template <typename Iter>
std::string concatName(const Iter& begin, const Iter& end) {
  std::string joined;
  Iter cursor = begin;
  while (cursor != end) {
    const std::string& piece = *cursor;
    if (joined.empty()) {
      joined = piece;
    } else {
      joined.append(".");
      joined.append(piece);
    }
    ++cursor;
  }
  return joined;
}
44 
45 class AttributePropagator {
46  public:
// Sets up the propagator for `module`.
//  - `preservedAttrs`: (possibly dotted) names of attributes or methods the
//    caller wants kept; every name must resolve, otherwise TORCH_CHECK fails
//    with "Unknown name: <name>".
//  - `freezeInterfaces` / `preserveParameters`: stored as-is for later phases.
// "forward" is registered as a preserved method whenever the module has one.
AttributePropagator(
    Module& module,
    std::vector<std::string>& preservedAttrs,
    bool freezeInterfaces,
    bool preserveParameters)
    : module_(module),
      freezeInterfaces_(freezeInterfaces),
      preserveParameters_(preserveParameters) {
  // Registers one user-supplied name. Returns false when it does not resolve.
  auto registerPreserved = [this](const std::string& qualified_name) {
    const auto resolved = resolveName(qualified_name);
    if (!resolved) {
      return false;
    }

    const auto& owner = resolved->first;
    const auto& member = resolved->second;
    if (!owner.hasattr(member)) {
      // resolveName only succeeds for attributes or methods, so this is a
      // method name.
      auto method = owner.get_method(member);
      preservedMethods_.insert(&method.function());
      return true;
    }

    auto value = owner.attr(member);
    // The client asked to keep a whole submodule: remember it so that the
    // clean-up phase preserves it entirely.
    if (value.isModule()) {
      preservedSubModule_.insert(value.toModule()._ivalue());
    }
    insertMutableAttr(member, value, owner._ivalue());
    return true;
  };

  // forward is preserved by default, but not every module defines it.
  if (module_.find_method("forward")) {
    auto forward = module_.get_method("forward");
    preservedMethods_.insert(&forward.function());
  }

  for (const auto& name : preservedAttrs) {
    TORCH_CHECK(registerPreserved(name), "Unknown name: " + name);
  }
}
90 
optimizeSubGraphs(std::shared_ptr<Graph> & graph,const std::function<void (std::shared_ptr<Graph> &)> & func)91   void optimizeSubGraphs(
92       std::shared_ptr<Graph>& graph,
93       const std::function<void(std::shared_ptr<Graph>&)>& func) {
94     func(graph);
95     std::stack<Block*> blocks({graph->block()});
96     while (!blocks.empty()) {
97       Block* block = blocks.top();
98       blocks.pop();
99       for (auto n : block->nodes()) {
100         for (Block* sub_block : n->blocks()) {
101           blocks.push(sub_block);
102         }
103         if (n->kind() == prim::fork) {
104           auto subgraph = n->g(attr::Subgraph);
105           optimizeSubGraphs(subgraph, func);
106         }
107       }
108     }
109   }
110 
run()111   void run() {
112     auto applyInline = [](std::shared_ptr<Graph>& subgraph) {
113       Inline(*subgraph);
114       ClearProfilingInformation(subgraph);
115     };
116     auto applyOptimizations = [](std::shared_ptr<Graph>& subgraph) {
117 #ifndef C10_MOBILE
118       Autocast(subgraph);
119 #endif
120       runOptimization(
121           subgraph,
122           /* unroll_non_constant_loops? */ false,
123           /* const_prop_user_classes? */ false);
124       EliminateNoOps(subgraph);
125       LowerSimpleTuples(subgraph);
126     };
127 
128     std::unordered_map<std::string, std::unordered_set<std::string>>
129         interfacesToReassignType;
130 
131     for (auto function : preservedMethods_) {
132       GRAPH_DEBUG("Analyzing function: " + function->name());
133       auto graph = toGraphFunction(*function).graph();
134       optimizeSubGraphs(graph, applyInline);
135       if (freezeInterfaces_) {
136         inlineInterfaceCalls(graph, interfacesToReassignType);
137       }
138     }
139 
140     reassignInterfaceTypes(interfacesToReassignType);
141 
142     for (auto function : preservedMethods_) {
143       GRAPH_DEBUG("Recording mutable attrs for function: " + function->name());
144       auto graph = toGraphFunction(*function).graph();
145       // Record Attributes that are explicitly set in the module.
146       // They cannot be folded.
147       recordMutableAttrs(graph);
148     }
149 
150     for (auto function : preservedMethods_) {
151       GRAPH_DEBUG("Propagating function: " + function->name());
152       auto graph = toGraphFunction(*function).graph();
153       propagateAttributes(graph);
154       optimizeSubGraphs(graph, applyOptimizations);
155     }
156     GRAPH_DEBUG("Cleaning up module");
157     cleanupFrozenModule();
158   }
159 
160  private:
161   using ResolvedName = std::pair<Module, std::string>;
162 
163   // Try to resolve qualified names (submodule1.submodule2.foo). If
164   // the qualified name exists in the root module, return the unqualified
165   // attribute/function name and the parent module. Else, return nullopt.
166   // Examples:
167   // submodule1.submodule2.foo -> {submodule2, "foo"}
168   // submodule1.non_existent_module.foo -> nullopt
resolveName(const std::string & name)169   std::optional<ResolvedName> resolveName(const std::string& name) {
170     auto sub_names = splitName(name);
171     if (sub_names.empty()) {
172       return std::nullopt;
173     }
174     auto& attr_name = sub_names.back();
175     auto cur_module = module_;
176     std::vector<ResolvedName> attr_infos;
177     attr_infos.reserve(sub_names.size() - 1);
178 
179     for (size_t i = 0; i < sub_names.size() - 1; ++i) {
180       bool found = false;
181       const auto& sub_name = sub_names[i];
182       for (const auto& child_module : cur_module.named_children()) {
183         if (child_module.name == sub_name) {
184           attr_infos.emplace_back(cur_module._ivalue(), child_module.name);
185           cur_module = child_module.value;
186           found = true;
187           break;
188         }
189       }
190       if (!found) {
191         return std::nullopt;
192       }
193     }
194 
195     if (cur_module.hasattr(attr_name) || cur_module.find_method(attr_name)) {
196       // We don't want to mark these modules as mutable yet; that could
197       // interfere with the inlining procedure. Instead, we'll record
198       // the fact that the user wants to preserve them. They will be
199       // processed during clean-up preparation (recordReferenceAttrs)
200       for (auto& attr_info : attr_infos) {
201         const auto& parent_module = attr_info.first;
202         auto& sub_name = attr_info.second;
203         userPreservedAttrs_[parent_module._ivalue()].insert(
204             std::move(sub_name));
205       }
206       return std::make_pair(std::move(cur_module), std::move(attr_name));
207     }
208 
209     return std::nullopt;
210   }
211 
_loadModulePath(Value * input,std::shared_ptr<Graph> & graph)212   bool _loadModulePath(Value* input, std::shared_ptr<Graph>& graph) {
213     Node* node = input->node();
214     names_.clear();
215     while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) {
216       if (node->kind() == prim::GetAttr) {
217         names_.push_front(node->s(attr::name));
218         node = node->inputs()[0]->node();
219       } else {
220         return false;
221       }
222     }
223 
224     return true;
225   }
226 
getModulePath(Value * input,std::shared_ptr<Graph> & graph)227   std::optional<std::deque<std::string>> getModulePath(
228       Value* input,
229       std::shared_ptr<Graph>& graph) {
230     bool success = _loadModulePath(input, graph);
231     if (!success) {
232       return std::nullopt;
233     }
234     return names_;
235   }
236 
237   template <typename Iter>
getModuleFromPath(Module & attrModule,const Iter & begin,const Iter & end)238   bool getModuleFromPath(
239       Module& attrModule,
240       const Iter& begin,
241       const Iter& end) {
242     for (Iter it = begin; it != end; ++it) {
243       const std::string& moduleName = *it;
244       if (preservedAttrs_.count(attrModule.attr(moduleName))) {
245         return false;
246       }
247       attrModule = attrModule.attr(moduleName).toModule();
248     }
249     return true;
250   }
251 
252   // findConstantAttr function locates the sub Module where attributes are
253   // defined. The algorithm chases getAttr chains to locate the submodules.
254   // For example:
255   // module M {
256   //   attributes {
257   //     A = <SubModule at ...>
258   //   }
259   //   ...
260   //   %A = prim::GetAttr[name="A"](%self)
261   //   ...
262   //   %B = prim::GetAttr[name="B"](%A)
263   //   ...
264   //   %weight = prim::GetAttr[name="scale"](%B)
265   //   ...
266   //   submodules {
267   //     module SubModule {
268   //       attributes {
269   //          B = <SubModule2 at ...>
270   //       }
271   //       submodules {
272   //         module SubModule2 {
273   //            attributes {
274   //               scale = 2
275   //            }
276   //         }
277   //       }
278   //     }
279   //   }
280   //
281   // findConstantAttr(%B, "scale", M)  returns true because there are no
282   // explicit SetAttr that modifies %B. attrModule points to the module where
283   // attribute lives (in this example it is <SubModule2 at ...>).
284   //
285   // Note inplace mutations to attributes are checked later using alias
286   // analysis.
287   //
288   // We can use a more efficient algorithm to hash each constant GetAttr to its
289   // corresponding value. Based on initial test on resnet50 and other torch
290   // vision tests. GetAttrs are not too frequent so it is ok to chase GetAttr
291   // chain to retrieve their values.
findConstantAttr(Value * input,std::string & name,Module & attrModule,std::shared_ptr<Graph> & graph)292   bool findConstantAttr(
293       Value* input,
294       std::string& name,
295       Module& attrModule,
296       std::shared_ptr<Graph>& graph) {
297     if (!input->type()->cast<InterfaceType>() &&
298         !input->type()->expectRef<ClassType>().is_module()) {
299       return false;
300     }
301 
302     // loads the path into this->names_
303     if (!_loadModulePath(input, graph)) {
304       return false;
305     }
306 
307     // reassigns attrModule to the module in names_
308     if (!getModuleFromPath(attrModule, names_.begin(), names_.end())) {
309       return false;
310     }
311 
312     auto attr = attrModule.attr(name);
313     if (!AliasDb::isMutableType(attr.type())) {
314       auto it = preservedScalarAttrs_.find(attrModule._ivalue());
315       return it == preservedScalarAttrs_.end() || !it->second.count(name);
316     }
317 
318     if (preservedAttrs_.count(attr)) {
319       return false;
320     }
321     if (!attr.type()->cast<ClassType>()) {
322       for (auto& ivalue : preservedAttrs_) {
323         if (!ivalue.isObject() && ivalue.overlaps(attr)) {
324           return false;
325         }
326       }
327     }
328     return true;
329   }
330 
insertMutableAttr(const std::string & name,const IValue & attr,const ModulePtr & attrModule)331   void insertMutableAttr(
332       const std::string& name,
333       const IValue& attr,
334       const ModulePtr& attrModule) {
335     if (AliasDb::isMutableType(attr.type())) {
336       preservedAttrs_.insert(attr);
337     } else {
338       preservedScalarAttrs_[attrModule].insert(name);
339     }
340   }
341 
recordMutableAttrs(std::shared_ptr<Graph> & graph)342   void recordMutableAttrs(std::shared_ptr<Graph>& graph) {
343     std::stack<Block*> blocks({graph->block()});
344     std::unique_ptr<AliasDb> aliasDb =
345         std::make_unique<AliasDb>(graph, /* isFrozen */ true);
346     while (!blocks.empty()) {
347       Block* block = blocks.top();
348       blocks.pop();
349       for (auto n : block->nodes()) {
350         for (Block* sub_block : n->blocks()) {
351           blocks.push(sub_block);
352         }
353 
354         // Modules with prim::ModuleContainerIndex cannot be frozen because they
355         // return InterfaceTypes.
356         TORCH_CHECK(
357             n->kind() != prim::ModuleContainerIndex,
358             "Freezing modules containing prim::ModuleContainerIndex is not supported");
359 
360         if (n->kind() == prim::SetAttr || n->kind() == prim::GetAttr) {
361           // By default if interface attributes are present then fail freezing.
362           // If freezingInterfaces is on then Interfaces are folded similarly
363           // to other attributes.
364           TORCH_CHECK(
365               freezeInterfaces_ ||
366                   !(n->kind() == prim::GetAttr &&
367                     n->output()->type()->cast<InterfaceType>()),
368               "attempted to freeze a module that uses interface attributes");
369           auto name = n->s(attr::name);
370           auto attrModule = module_;
371           if (!findConstantAttr(n->inputs()[0], name, attrModule, graph)) {
372             continue;
373           }
374 
375           auto attr = attrModule.attr(name);
376           if (n->kind() == prim::GetAttr) {
377             auto type = n->output()->type();
378             // Do not record submodules. Their attributes are tracked
379             // individually.
380             if (attr.isObject() || !AliasDb::isMutableType(attr.type())) {
381               continue;
382             }
383             usedAttrs_.insert(attr);
384           }
385 
386           if (n->kind() == prim::SetAttr || aliasDb->hasOutputWriters(n)) {
387             GRAPH_DEBUG(
388                 n->kind() == prim::GetAttr ? "attribute: " + name + " in %" +
389                         n->output()->debugName() + " has inplace writer"
390                                            : "attribute: " + name + " is set");
391             auto mptr = attrModule._ivalue();
392             insertMutableAttr(name, attr, mptr);
393           }
394         } else if (n->kind() == prim::fork) {
395           applyToForkSubgraph(
396               n,
397               graph,
398               // NOLINTNEXTLINE(modernize-avoid-bind)
399               std::bind(
400                   &AttributePropagator::recordMutableAttrs,
401                   *this,
402                   std::placeholders::_1));
403         }
404       }
405     }
406     // FIXME: Current Alias analysis fails to track subvalues.
407     // This is not a common scenario, for freezing, detect and error out.
408     IValue::HashAliasedIValues seen;
409     for (auto& val : usedAttrs_) {
410       IValue::HashAliasedIValues subValues;
411       val.getSubValues(subValues);
412       TORCH_CHECK(
413           std::all_of(
414               subValues.begin(),
415               subValues.end(),
416               [&seen](const IValue& v) { return seen.count(v) == 0; }),
417           "module contains attributes values that overlaps ",
418           val);
419       seen.insert(subValues.begin(), subValues.end());
420     }
421   }
422 
overrideGradient(IValue attr)423   IValue overrideGradient(IValue attr) {
424     if (attr.isTensor()) {
425       auto& t = attr.toTensor();
426       if (t.requires_grad()) {
427         auto detached = t.detach();
428         detached.set_requires_grad(false);
429         attr = IValue(std::move(detached));
430       }
431     } else if (attr.isTuple()) {
432       auto tuple = std::move(attr).toTuple();
433       const auto& elems = tuple->elements();
434       for (const auto idx : c10::irange(elems.size())) {
435         tuple->unsafeSetElement(idx, overrideGradient(elems[idx]));
436       }
437       attr = std::move(tuple);
438     } else if (attr.isList()) {
439       c10::List<IValue> elems = std::move(attr).toList();
440       for (const auto i : c10::irange(elems.size())) {
441         elems.set(i, overrideGradient(elems.extract(i)));
442       }
443       attr = elems;
444     } else if (attr.isGenericDict()) {
445       auto dict = std::move(attr).toGenericDict();
446       for (const auto& pair : dict) {
447         auto val = pair.value();
448         val = overrideGradient(std::move(val));
449       }
450       attr = dict;
451     } else if (attr.isObject() && !attr.toObjectRef().type()->is_module()) {
452       auto obj_type = attr.type()->expect<ClassType>();
453       auto obj_value = std::move(attr).toObject();
454       auto sub_attributes = obj_type->getAttributes();
455       for (const auto& sub_attr : sub_attributes) {
456         auto sub_attr_val = obj_value->getAttr(sub_attr.getName());
457         sub_attr_val = overrideGradient(std::move(sub_attr_val));
458       }
459       return obj_value;
460     }
461 
462     return attr;
463   }
464 
465   // This method is invoked only when 'freezeInterfaces' parameter is on.
466   // The module associated with Interface is retrieved and the invoked method
467   // is inlined.
inlineInterfaceCall(Node * n,const IValue & attr)468   bool inlineInterfaceCall(Node* n, const IValue& attr) {
469     auto class_type = attr.type()->expect<ClassType>();
470     bool inlined = false;
471     for (auto use : n->output()->uses()) {
472       auto user_node = use.user;
473       if (user_node->kind() == prim::CallMethod) {
474         const std::string& methodName = user_node->s(attr::name);
475         Function& function = class_type->getMethod(methodName);
476         if (auto graphFunction = tryToGraphFunction(function)) {
477           GRAPH_UPDATE(
478               "Inlining interface method '",
479               function.name(),
480               "' to ",
481               *user_node);
482 
483           GRAPH_UPDATE("Function body: ", graphFunction->optimized_graph());
484           inlineCallTo(user_node, graphFunction);
485           inlined = true;
486         }
487       }
488     }
489     return inlined;
490   }
491 
492   //   [Note: Inlining interfaces strategy]
493   // There's two structures that are relevant to freezing:
494   // - the graph describing the computation in a method
495   // - the module describing the data structure of the module instance.
496   //
497   // First, in inlineInterfaceCalls, we inline interfaces. This is done in a
498   // separate step from normal inlining because CallMethod on an interface type
499   // requires extra steps compared to inlining a normal CallMethod.
500   //
501   // Next we need to simplify the structure of the module data structure, which
502   // is done for the most part by the usual steps in cleanupFrozenModule.
503   //
504   // However, there's a complication that comes from the fact that within a
505   // method, you can change the value of an interface to another module that
506   // implements that interface.
507   //
508   // For example:
509   //
510   // impl: MyInterface
511   // ...
512   // def forward(self, x):
513   //     if x > 0:
514   //         self.impl = my_interface_impl
515   //
516   // This is disallowed in freezing, because in this case we can't flatten out
517   // the module structure, since the type of self.impl will change.
518   //
519   // To handle this, we do the following:
520   //   1. inlineInterfaceCalls:
521   //     a. inline the graph, and in the process record all interfaces
522   //     b. simultaneously, check (throw error) for disallowed SetAttr calls.
523   //   2. call reassignInterfaceTypes, which reassigns interface types to their
524   //      concrete types. This is done in a separate step to avoid interfering
525   //      with inlineInterfaceCalls (note: this may not need to be done as a
526   //      separate step)
527   //   3. eventually cleanupFrozenModule will reorder the module data structure
528   //      and it will expect that all interface types have been removed.
inlineInterfaceCalls(std::shared_ptr<Graph> & graph,std::unordered_map<std::string,std::unordered_set<std::string>> & interfacesToRetype)529   void inlineInterfaceCalls(
530       std::shared_ptr<Graph>& graph,
531       std::unordered_map<std::string, std::unordered_set<std::string>>&
532           interfacesToRetype) {
533     auto block = graph->block();
534     std::stack<Block*> blocks({block});
535 
536     while (!blocks.empty()) {
537       Block* block = blocks.top();
538       blocks.pop();
539       for (auto n : block->nodes()) {
540         for (Block* sub_block : n->blocks()) {
541           blocks.push(sub_block);
542         }
543         if (n->kind() == prim::GetAttr) {
544           if (!n->output()->type()->cast<InterfaceType>()) {
545             continue;
546           }
547           auto name = n->s(attr::name);
548           auto attrModule = module_;
549           auto input = n->inputs()[0];
550           TORCH_CHECK(
551               findConstantAttr(input, name, attrModule, graph),
552               "failed to freeze interface attribute '" + name + "'");
553           TORCH_INTERNAL_ASSERT(attrModule.hasattr(name));
554           auto attr = attrModule.attr(name);
555           inlineInterfaceCall(n, attr);
556           // Reset the GetAttr to concrete module type.
557           n->output()->setType(attr.type());
558 
559           // Record this so that we can reassign the type later
560           // in reassignInterfaceTypes()
561           // See [Note: Inlining interfaces strategy]
562           auto path = getModulePath(input, graph);
563           TORCH_INTERNAL_ASSERT(path.has_value());
564           auto path_str = concatName(path->begin(), path->end());
565           interfacesToRetype[path_str].insert(name);
566         } else if (n->kind() == prim::SetAttr) {
567           // Check to make sure we're not assigning the value of any parameters
568           // that are interface types.
569           // See [Note: Inlining interfaces strategy]
570           auto name = n->s(attr::name);
571           auto attrModule = module_;
572           auto input = n->inputs()[0];
573 
574           if (!input->type()->cast<InterfaceType>() &&
575               !input->type()->expectRef<ClassType>().is_module()) {
576             // we only care if we're setattr["thing"](%mod) if %mod
577             continue;
578           }
579 
580           // note: this will modify attrModule until it is the parent of the
581           // "name" attr. In other words, attrModule is now the module that
582           // matches "input".
583           // We can't use findConstantAttr in case the base item is an object,
584           // instead of a module/interface.
585           auto path = getModulePath(input, graph);
586           TORCH_INTERNAL_ASSERT(path.has_value());
587           getModuleFromPath(attrModule, path->begin(), path->end());
588 
589           const auto& attrType = attrModule.type()->getAttribute(name);
590           TORCH_INTERNAL_ASSERT(
591               !attrType->cast<InterfaceType>(),
592               "Freezing does not support SetAttr on an interface type. ",
593               "SetAttr is attempted on '",
594               name,
595               "'");
596         } else if (n->kind() == prim::fork) {
597           applyToForkSubgraph(
598               n,
599               graph,
600               // NOLINTNEXTLINE(modernize-avoid-bind)
601               std::bind(
602                   &AttributePropagator::inlineInterfaceCalls,
603                   *this,
604                   std::placeholders::_1,
605                   interfacesToRetype));
606         }
607       }
608     }
609   }
610 
611   // See [Note: Inlining interfaces strategy]
612   // This modifies the internal structure of module types to reassign the
613   // type from an interface type to its concrete type.
reassignInterfaceTypes(const std::unordered_map<std::string,std::unordered_set<std::string>> & interfacesToRetype)614   void reassignInterfaceTypes(
615       const std::unordered_map<std::string, std::unordered_set<std::string>>&
616           interfacesToRetype) {
617     for (const auto& it : interfacesToRetype) {
618       const std::string& modulePath = it.first;
619       const std::vector<std::string>& splitPath = splitName(modulePath);
620       Module attrModule = module_;
621       getModuleFromPath(attrModule, splitPath.begin(), splitPath.end());
622 
623       for (const std::string& name : it.second) {
624         auto subvalue = attrModule.attr(name);
625         auto subvalueType = subvalue.type();
626         attrModule.type()->unsafeChangeAttributeType(name, subvalueType);
627       }
628     }
629   }
630 
propagateAttributes(std::shared_ptr<Graph> & graph)631   void propagateAttributes(std::shared_ptr<Graph>& graph) {
632     std::unordered_map<ModulePtr, std::unordered_map<std::string, Value*>>
633         attrValues;
634     auto isEval = !module_.hasattr("training") || !module_.is_training();
635     GRAPH_DEBUG("Freezing Module: ", module_.type()->name()->name());
636     auto block = graph->block();
637     std::stack<Block*> blocks({block});
638 
639     Node* m = *block->nodes().begin();
640     WithInsertPoint guard(m);
641     while (!blocks.empty()) {
642       Block* block = blocks.top();
643       blocks.pop();
644       for (auto it = block->nodes().begin(); it != block->nodes().end();) {
645         Node* n = *it;
646         it++; // advance iterator bc the current node may be destroyed
647 
648         for (Block* sub_block : n->blocks()) {
649           blocks.push(sub_block);
650         }
651         if (n->kind() == prim::GetAttr) {
652           auto name = n->s(attr::name);
653           auto attrModule = module_;
654           auto input = n->inputs()[0];
655           if (!findConstantAttr(input, name, attrModule, graph)) {
656             GRAPH_DEBUG(
657                 input->type()->cast<InterfaceType>() ||
658                         input->type()->expectRef<ClassType>().is_module()
659                     ? "attribute: " + name + " is mutable."
660                     : "");
661             continue;
662           }
663           TORCH_INTERNAL_ASSERT(attrModule.hasattr(name));
664           Value* paramConst = nullptr;
665           auto iter = attrValues.find(attrModule._ivalue());
666           if (iter != attrValues.end()) {
667             auto iter2 = iter->second.find(name);
668             if (iter2 != iter->second.end())
669               paramConst = iter2->second;
670           }
671           if (!paramConst) {
672             auto attr = attrModule.attr(name);
673             if (!isEval || preserveParameters_) {
674               auto type = attrModule.type();
675               auto slot = *type->findAttributeSlot(name);
676               if (type->is_parameter(slot) || type->is_buffer(slot) ||
677                   (attr.isObject() &&
678                    !attr.toObjectRef().type()->is_module())) {
679                 continue;
680               } else {
681                 attr = overrideGradient(attr);
682               }
683               if (!isEval && name == "training") {
684                 continue;
685               }
686             } else {
687               attr = overrideGradient(attr);
688             }
689             if (attr.isObject()) {
690               if (object_memo_.count(attr.toObject())) {
691                 attr = object_memo_[attr.toObject()];
692               } else {
693                 auto weak_class_obj =
694                     attr.toObject()->copy_to_weak_compilation_ref();
695                 object_memo_[attr.toObject()] = weak_class_obj;
696                 attr = weak_class_obj;
697               }
698             }
699             if (auto attrVal = tryInsertConstant(*graph, attr)) {
700               paramConst = *attrVal;
701             } else {
702               GRAPH_DEBUG(
703                   attr.type()->cast<ClassType>() ? "" : "attribute: ",
704                   name,
705                   " is not materializable.");
706               continue;
707             }
708             std::string fullName("self.");
709             for (auto& name : names_) {
710               fullName += name + '.';
711             }
712             fullName += name;
713             paramConst->setDebugName(fullName);
714             attrValues[attrModule._ivalue()][name] = paramConst;
715           }
716           GRAPH_UPDATE(
717               "Folding GetAttr %",
718               n->outputs()[0]->debugName(),
719               " with ",
720               paramConst->debugName());
721           n->outputs().at(0)->replaceAllUsesWith(paramConst);
722           n->removeAllInputs();
723         } else if (n->kind() == prim::fork) {
724           applyToForkSubgraph(
725               n,
726               graph,
727               // NOLINTNEXTLINE(modernize-avoid-bind)
728               std::bind(
729                   &AttributePropagator::propagateAttributes,
730                   *this,
731                   std::placeholders::_1));
732         }
733       }
734     }
735   }
736 
applyToForkSubgraph(Node * n,std::shared_ptr<Graph> & graph,const std::function<void (std::shared_ptr<Graph> &)> & func)737   void applyToForkSubgraph(
738       Node* n,
739       std::shared_ptr<Graph>& graph,
740       const std::function<void(std::shared_ptr<Graph>&)>& func) {
741     TORCH_CHECK(n->kind() == prim::fork);
742     auto attrModule = module_;
743     auto node = n->inputs()[0]->node();
744     // Check if first parameter of fork is a module. This module is used
745     // as the base module (similar to 'self' in forward) to resolve GetAttrs.
746     //  Otherwise freezing is applied using module_
747     if (node->kind() == prim::GetAttr &&
748         node->output()->type()->cast<ClassType>()) {
749       auto name = node->s(attr::name);
750       auto input = node->inputs()[0];
751       if (!findConstantAttr(input, name, attrModule, graph)) {
752         // Module needs to be preserved.
753         return;
754       }
755       attrModule = attrModule.attr(name).toModule();
756       std::swap(module_, attrModule);
757     }
758 
759     auto subgraph = n->g(attr::Subgraph);
760     func(subgraph);
761     module_ = attrModule;
762   }
763 
moduleEscapes(Module & subModule,std::shared_ptr<Graph> & graph)764   bool moduleEscapes(Module& subModule, std::shared_ptr<Graph>& graph) {
765     for (auto& output : graph->outputs()) {
766       if (subModule.type()->isSubtypeOf(*output->type())) {
767         return true;
768       }
769     }
770     return preservedSubModule_.count(subModule._ivalue());
771   }
772 
removeExtraWaitCalls(Block * b)773   void removeExtraWaitCalls(Block* b) {
774     auto nodes = b->nodes();
775     for (auto it = nodes.begin(); it != nodes.end(); it++) {
776       auto node = *it;
777       if (node->kind() != aten::wait) {
778         continue;
779       }
780       TORCH_INTERNAL_ASSERT(node->inputs().size() == 1);
781       TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
782       // If input type is not a from aten::fork call then the
783       // aten::wait operator can be deleted.
784       if (node->input()->type()->kind() != TypeKind::FutureType) {
785         node->output()->replaceAllUsesWith(node->input());
786         it.destroyCurrent();
787       }
788     }
789     // For the remaining nodes, recurse.
790     for (auto it = nodes.begin(); it != nodes.end(); it++) {
791       auto node = *it;
792       for (auto sub_b : node->blocks()) {
793         removeExtraWaitCalls(sub_b);
794       }
795     }
796   }
797 
798   // cleanupFrozenModule function cleans up the Frozen module. It performs the
799   // following:
800   // 1) Remove unused attributes.
801   // 2) Remove unreferenced submodules
802   // 3) Remove non public unreferenced methods.
cleanupFrozenModule()803   void cleanupFrozenModule() {
804     for (auto function : preservedMethods_) {
805       auto graph = toGraphFunction(*function).graph();
806       recordReferencedAttrs(graph);
807       handleSharedClassType(module_, graph);
808       removeExtraWaitCalls(graph->block());
809       toGraphFunction(*function).clear_optimized_graphs();
810     }
811     removeUnusedAttrs();
812   }
813 
814   // Preparing for clean up phase. At this point, record all subModules that
815   // contains mutable attributes.
recordReferencedAttrs(std::shared_ptr<Graph> & graph)816   void recordReferencedAttrs(std::shared_ptr<Graph>& graph) {
817     std::stack<Block*> blocks({graph->block()});
818     std::set<ModulePtr> modules({module_._ivalue()});
819     while (!blocks.empty()) {
820       Block* block = blocks.top();
821       blocks.pop();
822       for (auto n : block->nodes()) {
823         for (Block* subBlock : n->blocks()) {
824           blocks.push(subBlock);
825         }
826         if (n->kind() == prim::GetAttr) {
827           auto& name = n->s(attr::name);
828           // For now, use all module ivalues which are the same type
829           // and could be the module that this GetAttr resolves to
830           // TODO: we could attempt to follow the GetAttr chain and
831           // find the exact ivalue, we would have to be careful
832           // that the chain does not contain any attributes which
833           // get written to (setAttr calls)
834           for (auto& mptr : modules) {
835             auto module = Module(mptr);
836             if (module.type() == n->inputs()[0]->type()) {
837               TORCH_INTERNAL_ASSERT(module.hasattr(name));
838               auto module = Module(mptr);
839               auto attr = module.attr(name);
840               // TODO: this could be insertReferencedAttr to be more clear,
841               // these are attributes we could not inline, which include
842               // other reasons besides mutation (unsupported constant,
843               // getAttr resolving to non-getAttr node, etc)
844               insertMutableAttr(name, attr, mptr);
845               if (attr.isModule()) {
846                 modules.insert(attr.toModule()._ivalue());
847               }
848             }
849           }
850         } else if (n->kind() == prim::fork) {
851           applyToForkSubgraph(
852               n,
853               graph,
854               // NOLINTNEXTLINE(modernize-avoid-bind)
855               std::bind(
856                   &AttributePropagator::recordReferencedAttrs,
857                   *this,
858                   std::placeholders::_1));
859         }
860       }
861     }
862     // We have to process the attributes that the user wants to preserve
863     // separately since it's possible that the user-preserved module is
864     // never referenced in the graph.
865     for (const auto& attr_info : userPreservedAttrs_) {
866       const auto& parent_module = attr_info.first;
867       for (const auto& attr_name : attr_info.second) {
868         const auto value = parent_module->getAttr(attr_name);
869         insertMutableAttr(attr_name, value, parent_module);
870       }
871     }
872   }
873 
874   // This function recursively iterates over submodules to identify
875   // for each class type the attribute slots that need to be preserved.
876   //
877   // Note 'attrsToKeep[type].insert(type->numAttributes())' means all
878   // attribute slots of 'type' and its methods are preserved. A submodule is
879   // preserved when it escapes (meaning it is returned).
handleSharedClassType(Module & module,std::shared_ptr<Graph> & graph)880   void handleSharedClassType(Module& module, std::shared_ptr<Graph>& graph) {
881     auto type = module.type();
882     size_t N = type->numAttributes();
883     if (moduleEscapes(module, graph)) {
884       // Preserve all its attributes and methods.
885       attrsToKeep_[type].insert(N);
886       return;
887     }
888     auto it2 = preservedScalarAttrs_.find(module._ivalue());
889     SharedTypeSubModules_[type].insert(module._ivalue());
890     attrsToKeep_[type].insert({});
891     for (const auto i : c10::irange(N)) {
892       auto name = type->getAttributeName(i);
893       auto attr = module.attr(name);
894       auto attrTy = attr.type();
895 
896       bool isMutable = false;
897       if (AliasDb::isMutableType(attrTy)) {
898         isMutable = preservedAttrs_.count(attr);
899       } else {
900         isMutable =
901             it2 != preservedScalarAttrs_.end() && it2->second.count(name);
902       }
903       if (isMutable) {
904         attrsToKeep_[type].insert(i);
905         if (attr.isModule()) {
906           // See [Note: Inlining interfaces strategy]
907           TORCH_CHECK(
908               !type->getAttribute(i)->cast<InterfaceType>(),
909               "Unexpected interface attribute '" + name + "' during freezing");
910 
911           auto attrModule = attr.toModule();
912           handleSharedClassType(attrModule, graph);
913         }
914       }
915     }
916   }
917 
918   // Remove unused attributes and methods for each sub module of the frozen
919   // module. This function iterates over the Classtypes of its submodule
920   // attributes including its own type.
removeUnusedAttrs()921   void removeUnusedAttrs() {
922     std::vector<std::string> attrsToRemove;
923     std::vector<Function*> funcsToRemove;
924     for (auto& it : attrsToKeep_) {
925       auto& type = it.first;
926       size_t N = type->numAttributes();
927       if (it.second.count(N)) {
928         continue;
929       }
930       for (const auto i : c10::irange(N)) {
931         if (it.second.count(i) == 0) {
932           attrsToRemove.push_back(type->getAttributeName(i));
933         }
934       }
935       for (auto& fn : type->methods()) {
936         if (preservedMethods_.count(fn)) {
937           continue;
938         }
939         funcsToRemove.push_back(fn);
940       }
941 
942       for (auto& name : attrsToRemove) {
943         for (auto& val : SharedTypeSubModules_[type]) {
944           auto mod = val.toModule();
945           mod._ivalue()->unsafeRemoveAttr(name);
946         }
947         type->unsafeRemoveAttribute(name);
948       }
949       for (auto fn : funcsToRemove) {
950         type->unsafeRemoveMethod(fn->name());
951         auto mod = SharedTypeSubModules_[type].begin()->toModule();
952         mod._ivalue()->compilation_unit()->unsafeRemoveMethod(fn->qualname());
953       }
954 
955       attrsToRemove.clear();
956       funcsToRemove.clear();
957     }
958   }
959 
960   // Contains attributes that can't be folded or user directs to keep them.
961   IValue::HashAliasedIValues preservedAttrs_;
962   // Tracked immutable types (Scalars) by their attribute names not
963   // IValues.
964   std::unordered_map<ModulePtr, std::unordered_set<std::string>>
965       preservedScalarAttrs_;
966 
967   // Contains user specified methods to be preserved in frozen module.
968   std::unordered_set<Function*> preservedMethods_;
969 
970   // Contains user-specified submodules to be preserved in the frozen module.
971   std::unordered_set<ModulePtr> preservedSubModule_;
972 
973   // Track all used attributes ivalues that can be aliased.
974   IValue::HashAliasedIValues usedAttrs_;
975 
976   // Contains the attribute slots that need to be preserved for each ClassType.
977   std::unordered_map<ClassTypePtr, std::unordered_set<size_t>> attrsToKeep_;
978 
979   // Contains the sub modules that share the same ClassType.
980   std::unordered_map<ClassTypePtr, IValue::HashAliasedIValues>
981       SharedTypeSubModules_;
982 
983   Module& module_;
984 
985   // Allow to freeze modules containing interfaces.
986   bool freezeInterfaces_;
987 
988   // Preserve module parameters
989   bool preserveParameters_;
990 
991   // Contains the attribute names (e.g. {"self", "subModule", "a"}).
992   std::deque<std::string> names_;
993 
994   // see [Constant Object Weak CompilationUnit Reference]
995   std::unordered_map<
996       c10::intrusive_ptr<at::ivalue::Object>,
997       c10::intrusive_ptr<at::ivalue::Object>>
998       object_memo_;
999 
1000   // Contains names of attributes that the user wants to preserve with
1001   // their owning modules.
1002   std::unordered_map<ModulePtr, std::unordered_set<std::string>>
1003       userPreservedAttrs_;
1004 
1005 }; // class AttributePropagator
1006 
checkModuleDoesNotReturnSelf(const Module & module)1007 void checkModuleDoesNotReturnSelf(const Module& module) {
1008   if (module.find_method("forward")) {
1009     Method method = module.get_method("forward");
1010     // Check that module does not return itself.
1011     for (auto& output : method.graph()->outputs()) {
1012       TORCH_CHECK(
1013           output->type() != module.type(),
1014           "attempted to freeze a module that return itself");
1015     }
1016   }
1017 }
1018 } // namespace
1019 
// Freezes a clone of `module` and returns that clone. The module is first
// rejected if its forward method returns the module itself; the clone is
// obtained with clone(true) and then processed by an AttributePropagator
// configured with the given options.
Module freeze_module(
    const Module& module,
    std::vector<std::string> preservedAttrs,
    bool freezeInterfaces,
    bool preserveParameters) {
  checkModuleDoesNotReturnSelf(module);

  auto frozen = module.clone(true);
  AttributePropagator(
      frozen, preservedAttrs, freezeInterfaces, preserveParameters)
      .run();
  return frozen;
}
1033 
freeze_module_inplace(Module * module,std::vector<std::string> preservedAttrs,bool freezeInterfaces,bool preserveParameters)1034 void freeze_module_inplace(
1035     Module* module,
1036     std::vector<std::string> preservedAttrs,
1037     bool freezeInterfaces,
1038     bool preserveParameters) {
1039   TORCH_CHECK(module != nullptr, "module cannot be nullptr");
1040   checkModuleDoesNotReturnSelf(*module);
1041   AttributePropagator attrPropagator(
1042       *module, preservedAttrs, freezeInterfaces, preserveParameters);
1043   attrPropagator.run();
1044 }
1045 
1046 } // namespace torch::jit
1047