1 #include <torch/csrc/jit/passes/freeze_module.h>
2
3 #include <torch/csrc/jit/jit_log.h>
4
5 #include <c10/util/irange.h>
6 #include <torch/csrc/jit/api/function_impl.h>
7 #include <torch/csrc/jit/ir/alias_analysis.h>
8 #include <torch/csrc/jit/passes/autocast.h>
9 #include <torch/csrc/jit/passes/clear_profiling.h>
10 #include <torch/csrc/jit/passes/eliminate_no_ops.h>
11 #include <torch/csrc/jit/passes/inliner.h>
12 #include <torch/csrc/jit/passes/lower_tuples.h>
13 #include <torch/csrc/jit/runtime/graph_executor_impl.h>
14
15 #include <stack>
16 #include <utility>
17
18 namespace torch::jit {
19
20 namespace {
21
// Splits a dotted qualified name (e.g. "a.b.c") into its components, in
// order. Splitting follows std::getline semantics: interior empty components
// are kept ("a..b" -> {"a", "", "b"}), a trailing '.' does not add an empty
// component ("a." -> {"a"}), and an empty input yields an empty vector.
std::vector<std::string> splitName(const std::string& name) {
  std::vector<std::string> parts;
  std::istringstream stream(name);
  // getline clears `piece` before every read, so reusing it after the move
  // below is safe.
  for (std::string piece; std::getline(stream, piece, '.');) {
    parts.push_back(std::move(piece));
  }
  return parts;
}
31
// Joins the strings in [begin, end) into one dotted name.
// A '.' is emitted before a component only when the text accumulated so far
// is non-empty, so leading empty components contribute nothing.
template <typename Iter>
std::string concatName(const Iter& begin, const Iter& end) {
  std::string joined;
  for (Iter cur = begin; cur != end; ++cur) {
    const std::string& piece = *cur;
    if (!joined.empty()) {
      joined.push_back('.');
    }
    joined.append(piece);
  }
  return joined;
}
44
45 class AttributePropagator {
46 public:
// Sets up the propagator for `module`.
//
// `forward` (when the module defines it) is always registered as a preserved
// method. Every entry of `preservedAttrs` is resolved with resolveName():
//   - if it names an attribute, the attribute is recorded as preserved (and,
//     when the attribute is a submodule, that submodule is remembered so it
//     is kept entirely during clean-up);
//   - otherwise it is looked up as a method and registered as preserved.
// A name that cannot be resolved fails a TORCH_CHECK with "Unknown name: ...".
AttributePropagator(
    Module& module,
    std::vector<std::string>& preservedAttrs,
    bool freezeInterfaces,
    bool preserveParameters)
    : module_(module),
      freezeInterfaces_(freezeInterfaces),
      preserveParameters_(preserveParameters) {
  // forward is preserved by default, but
  // not all modules have a forward function defined
  if (module_.find_method("forward")) {
    auto forward_method = module_.get_method("forward");
    preservedMethods_.insert(&forward_method.function());
  }

  // Registers one user-requested name; returns false when it does not
  // resolve to anything in the module hierarchy.
  const auto registerPreserved =
      [this](const std::string& qualified_name) -> bool {
    const auto resolved = resolveName(qualified_name);
    if (!resolved) {
      return false;
    }

    const auto& owner = resolved->first;
    const auto& leaf = resolved->second;
    if (!owner.hasattr(leaf)) {
      // Not an attribute: treat the name as a method of `owner`.
      auto method = owner.get_method(leaf);
      preservedMethods_.insert(&method.function());
      return true;
    }

    auto value = owner.attr(leaf);
    // Freezing client wants to preserve this submodule. When cleaning
    // the frozen module, make sure it will be preserved entirely.
    if (value.isModule()) {
      preservedSubModule_.insert(value.toModule()._ivalue());
    }
    insertMutableAttr(leaf, value, owner._ivalue());
    return true;
  };

  for (const std::string& name : preservedAttrs) {
    TORCH_CHECK(registerPreserved(name), "Unknown name: " + name);
  }
}
90
optimizeSubGraphs(std::shared_ptr<Graph> & graph,const std::function<void (std::shared_ptr<Graph> &)> & func)91 void optimizeSubGraphs(
92 std::shared_ptr<Graph>& graph,
93 const std::function<void(std::shared_ptr<Graph>&)>& func) {
94 func(graph);
95 std::stack<Block*> blocks({graph->block()});
96 while (!blocks.empty()) {
97 Block* block = blocks.top();
98 blocks.pop();
99 for (auto n : block->nodes()) {
100 for (Block* sub_block : n->blocks()) {
101 blocks.push(sub_block);
102 }
103 if (n->kind() == prim::fork) {
104 auto subgraph = n->g(attr::Subgraph);
105 optimizeSubGraphs(subgraph, func);
106 }
107 }
108 }
109 }
110
run()111 void run() {
112 auto applyInline = [](std::shared_ptr<Graph>& subgraph) {
113 Inline(*subgraph);
114 ClearProfilingInformation(subgraph);
115 };
116 auto applyOptimizations = [](std::shared_ptr<Graph>& subgraph) {
117 #ifndef C10_MOBILE
118 Autocast(subgraph);
119 #endif
120 runOptimization(
121 subgraph,
122 /* unroll_non_constant_loops? */ false,
123 /* const_prop_user_classes? */ false);
124 EliminateNoOps(subgraph);
125 LowerSimpleTuples(subgraph);
126 };
127
128 std::unordered_map<std::string, std::unordered_set<std::string>>
129 interfacesToReassignType;
130
131 for (auto function : preservedMethods_) {
132 GRAPH_DEBUG("Analyzing function: " + function->name());
133 auto graph = toGraphFunction(*function).graph();
134 optimizeSubGraphs(graph, applyInline);
135 if (freezeInterfaces_) {
136 inlineInterfaceCalls(graph, interfacesToReassignType);
137 }
138 }
139
140 reassignInterfaceTypes(interfacesToReassignType);
141
142 for (auto function : preservedMethods_) {
143 GRAPH_DEBUG("Recording mutable attrs for function: " + function->name());
144 auto graph = toGraphFunction(*function).graph();
145 // Record Attributes that are explicitly set in the module.
146 // They cannot be folded.
147 recordMutableAttrs(graph);
148 }
149
150 for (auto function : preservedMethods_) {
151 GRAPH_DEBUG("Propagating function: " + function->name());
152 auto graph = toGraphFunction(*function).graph();
153 propagateAttributes(graph);
154 optimizeSubGraphs(graph, applyOptimizations);
155 }
156 GRAPH_DEBUG("Cleaning up module");
157 cleanupFrozenModule();
158 }
159
160 private:
161 using ResolvedName = std::pair<Module, std::string>;
162
163 // Try to resolve qualified names (submodule1.submodule2.foo). If
164 // the qualified name exists in the root module, return the unqualified
165 // attribute/function name and the parent module. Else, return nullopt.
166 // Examples:
167 // submodule1.submodule2.foo -> {submodule2, "foo"}
168 // submodule1.non_existent_module.foo -> nullopt
resolveName(const std::string & name)169 std::optional<ResolvedName> resolveName(const std::string& name) {
170 auto sub_names = splitName(name);
171 if (sub_names.empty()) {
172 return std::nullopt;
173 }
174 auto& attr_name = sub_names.back();
175 auto cur_module = module_;
176 std::vector<ResolvedName> attr_infos;
177 attr_infos.reserve(sub_names.size() - 1);
178
179 for (size_t i = 0; i < sub_names.size() - 1; ++i) {
180 bool found = false;
181 const auto& sub_name = sub_names[i];
182 for (const auto& child_module : cur_module.named_children()) {
183 if (child_module.name == sub_name) {
184 attr_infos.emplace_back(cur_module._ivalue(), child_module.name);
185 cur_module = child_module.value;
186 found = true;
187 break;
188 }
189 }
190 if (!found) {
191 return std::nullopt;
192 }
193 }
194
195 if (cur_module.hasattr(attr_name) || cur_module.find_method(attr_name)) {
196 // We don't want to mark these modules as mutable yet; that could
197 // interfere with the inlining procedure. Instead, we'll record
198 // the fact that the user wants to preserve them. They will be
199 // processed during clean-up preparation (recordReferenceAttrs)
200 for (auto& attr_info : attr_infos) {
201 const auto& parent_module = attr_info.first;
202 auto& sub_name = attr_info.second;
203 userPreservedAttrs_[parent_module._ivalue()].insert(
204 std::move(sub_name));
205 }
206 return std::make_pair(std::move(cur_module), std::move(attr_name));
207 }
208
209 return std::nullopt;
210 }
211
_loadModulePath(Value * input,std::shared_ptr<Graph> & graph)212 bool _loadModulePath(Value* input, std::shared_ptr<Graph>& graph) {
213 Node* node = input->node();
214 names_.clear();
215 while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) {
216 if (node->kind() == prim::GetAttr) {
217 names_.push_front(node->s(attr::name));
218 node = node->inputs()[0]->node();
219 } else {
220 return false;
221 }
222 }
223
224 return true;
225 }
226
getModulePath(Value * input,std::shared_ptr<Graph> & graph)227 std::optional<std::deque<std::string>> getModulePath(
228 Value* input,
229 std::shared_ptr<Graph>& graph) {
230 bool success = _loadModulePath(input, graph);
231 if (!success) {
232 return std::nullopt;
233 }
234 return names_;
235 }
236
237 template <typename Iter>
getModuleFromPath(Module & attrModule,const Iter & begin,const Iter & end)238 bool getModuleFromPath(
239 Module& attrModule,
240 const Iter& begin,
241 const Iter& end) {
242 for (Iter it = begin; it != end; ++it) {
243 const std::string& moduleName = *it;
244 if (preservedAttrs_.count(attrModule.attr(moduleName))) {
245 return false;
246 }
247 attrModule = attrModule.attr(moduleName).toModule();
248 }
249 return true;
250 }
251
252 // findConstantAttr function locates the sub Module where attributes are
253 // defined. The algorithm chases getAttr chains to locate the submodules.
254 // For example:
255 // module M {
256 // attributes {
257 // A = <SubModule at ...>
258 // }
259 // ...
260 // %A = prim::GetAttr[name="A"](%self)
261 // ...
262 // %B = prim::GetAttr[name="B"](%A)
263 // ...
264 // %weight = prim::GetAttr[name="scale"](%B)
265 // ...
266 // submodules {
267 // module SubModule {
268 // attributes {
269 // B = <SubModule2 at ...>
270 // }
271 // submodules {
272 // module SubModule2 {
273 // attributes {
274 // scale = 2
275 // }
276 // }
277 // }
278 // }
279 // }
280 //
281 // findConstantAttr(%B, "scale", M) returns true because there are no
282 // explicit SetAttr that modifies %B. attrModule points to the module where
283 // attribute lives (in this example it is <SubModule2 at ...>).
284 //
285 // Note inplace mutations to attributes are checked later using alias
286 // analysis.
287 //
288 // We can use a more efficient algorithm to hash each constant GetAttr to its
289 // corresponding value. Based on initial test on resnet50 and other torch
290 // vision tests. GetAttrs are not too frequent so it is ok to chase GetAttr
291 // chain to retrieve their values.
findConstantAttr(Value * input,std::string & name,Module & attrModule,std::shared_ptr<Graph> & graph)292 bool findConstantAttr(
293 Value* input,
294 std::string& name,
295 Module& attrModule,
296 std::shared_ptr<Graph>& graph) {
297 if (!input->type()->cast<InterfaceType>() &&
298 !input->type()->expectRef<ClassType>().is_module()) {
299 return false;
300 }
301
302 // loads the path into this->names_
303 if (!_loadModulePath(input, graph)) {
304 return false;
305 }
306
307 // reassigns attrModule to the module in names_
308 if (!getModuleFromPath(attrModule, names_.begin(), names_.end())) {
309 return false;
310 }
311
312 auto attr = attrModule.attr(name);
313 if (!AliasDb::isMutableType(attr.type())) {
314 auto it = preservedScalarAttrs_.find(attrModule._ivalue());
315 return it == preservedScalarAttrs_.end() || !it->second.count(name);
316 }
317
318 if (preservedAttrs_.count(attr)) {
319 return false;
320 }
321 if (!attr.type()->cast<ClassType>()) {
322 for (auto& ivalue : preservedAttrs_) {
323 if (!ivalue.isObject() && ivalue.overlaps(attr)) {
324 return false;
325 }
326 }
327 }
328 return true;
329 }
330
insertMutableAttr(const std::string & name,const IValue & attr,const ModulePtr & attrModule)331 void insertMutableAttr(
332 const std::string& name,
333 const IValue& attr,
334 const ModulePtr& attrModule) {
335 if (AliasDb::isMutableType(attr.type())) {
336 preservedAttrs_.insert(attr);
337 } else {
338 preservedScalarAttrs_[attrModule].insert(name);
339 }
340 }
341
recordMutableAttrs(std::shared_ptr<Graph> & graph)342 void recordMutableAttrs(std::shared_ptr<Graph>& graph) {
343 std::stack<Block*> blocks({graph->block()});
344 std::unique_ptr<AliasDb> aliasDb =
345 std::make_unique<AliasDb>(graph, /* isFrozen */ true);
346 while (!blocks.empty()) {
347 Block* block = blocks.top();
348 blocks.pop();
349 for (auto n : block->nodes()) {
350 for (Block* sub_block : n->blocks()) {
351 blocks.push(sub_block);
352 }
353
354 // Modules with prim::ModuleContainerIndex cannot be frozen because they
355 // return InterfaceTypes.
356 TORCH_CHECK(
357 n->kind() != prim::ModuleContainerIndex,
358 "Freezing modules containing prim::ModuleContainerIndex is not supported");
359
360 if (n->kind() == prim::SetAttr || n->kind() == prim::GetAttr) {
361 // By default if interface attributes are present then fail freezing.
362 // If freezingInterfaces is on then Interfaces are folded similarly
363 // to other attributes.
364 TORCH_CHECK(
365 freezeInterfaces_ ||
366 !(n->kind() == prim::GetAttr &&
367 n->output()->type()->cast<InterfaceType>()),
368 "attempted to freeze a module that uses interface attributes");
369 auto name = n->s(attr::name);
370 auto attrModule = module_;
371 if (!findConstantAttr(n->inputs()[0], name, attrModule, graph)) {
372 continue;
373 }
374
375 auto attr = attrModule.attr(name);
376 if (n->kind() == prim::GetAttr) {
377 auto type = n->output()->type();
378 // Do not record submodules. Their attributes are tracked
379 // individually.
380 if (attr.isObject() || !AliasDb::isMutableType(attr.type())) {
381 continue;
382 }
383 usedAttrs_.insert(attr);
384 }
385
386 if (n->kind() == prim::SetAttr || aliasDb->hasOutputWriters(n)) {
387 GRAPH_DEBUG(
388 n->kind() == prim::GetAttr ? "attribute: " + name + " in %" +
389 n->output()->debugName() + " has inplace writer"
390 : "attribute: " + name + " is set");
391 auto mptr = attrModule._ivalue();
392 insertMutableAttr(name, attr, mptr);
393 }
394 } else if (n->kind() == prim::fork) {
395 applyToForkSubgraph(
396 n,
397 graph,
398 // NOLINTNEXTLINE(modernize-avoid-bind)
399 std::bind(
400 &AttributePropagator::recordMutableAttrs,
401 *this,
402 std::placeholders::_1));
403 }
404 }
405 }
406 // FIXME: Current Alias analysis fails to track subvalues.
407 // This is not a common scenario, for freezing, detect and error out.
408 IValue::HashAliasedIValues seen;
409 for (auto& val : usedAttrs_) {
410 IValue::HashAliasedIValues subValues;
411 val.getSubValues(subValues);
412 TORCH_CHECK(
413 std::all_of(
414 subValues.begin(),
415 subValues.end(),
416 [&seen](const IValue& v) { return seen.count(v) == 0; }),
417 "module contains attributes values that overlaps ",
418 val);
419 seen.insert(subValues.begin(), subValues.end());
420 }
421 }
422
overrideGradient(IValue attr)423 IValue overrideGradient(IValue attr) {
424 if (attr.isTensor()) {
425 auto& t = attr.toTensor();
426 if (t.requires_grad()) {
427 auto detached = t.detach();
428 detached.set_requires_grad(false);
429 attr = IValue(std::move(detached));
430 }
431 } else if (attr.isTuple()) {
432 auto tuple = std::move(attr).toTuple();
433 const auto& elems = tuple->elements();
434 for (const auto idx : c10::irange(elems.size())) {
435 tuple->unsafeSetElement(idx, overrideGradient(elems[idx]));
436 }
437 attr = std::move(tuple);
438 } else if (attr.isList()) {
439 c10::List<IValue> elems = std::move(attr).toList();
440 for (const auto i : c10::irange(elems.size())) {
441 elems.set(i, overrideGradient(elems.extract(i)));
442 }
443 attr = elems;
444 } else if (attr.isGenericDict()) {
445 auto dict = std::move(attr).toGenericDict();
446 for (const auto& pair : dict) {
447 auto val = pair.value();
448 val = overrideGradient(std::move(val));
449 }
450 attr = dict;
451 } else if (attr.isObject() && !attr.toObjectRef().type()->is_module()) {
452 auto obj_type = attr.type()->expect<ClassType>();
453 auto obj_value = std::move(attr).toObject();
454 auto sub_attributes = obj_type->getAttributes();
455 for (const auto& sub_attr : sub_attributes) {
456 auto sub_attr_val = obj_value->getAttr(sub_attr.getName());
457 sub_attr_val = overrideGradient(std::move(sub_attr_val));
458 }
459 return obj_value;
460 }
461
462 return attr;
463 }
464
465 // This method is invoked only when 'freezeInterfaces' parameter is on.
466 // The module associated with Interface is retrieved and the invoked method
467 // is inlined.
inlineInterfaceCall(Node * n,const IValue & attr)468 bool inlineInterfaceCall(Node* n, const IValue& attr) {
469 auto class_type = attr.type()->expect<ClassType>();
470 bool inlined = false;
471 for (auto use : n->output()->uses()) {
472 auto user_node = use.user;
473 if (user_node->kind() == prim::CallMethod) {
474 const std::string& methodName = user_node->s(attr::name);
475 Function& function = class_type->getMethod(methodName);
476 if (auto graphFunction = tryToGraphFunction(function)) {
477 GRAPH_UPDATE(
478 "Inlining interface method '",
479 function.name(),
480 "' to ",
481 *user_node);
482
483 GRAPH_UPDATE("Function body: ", graphFunction->optimized_graph());
484 inlineCallTo(user_node, graphFunction);
485 inlined = true;
486 }
487 }
488 }
489 return inlined;
490 }
491
492 // [Note: Inlining interfaces strategy]
493 // There's two structures that are relevant to freezing:
494 // - the graph describing the computation in a method
495 // - the module describing the data structure of the module instance.
496 //
497 // First, in inlineInterfaceCalls, we inline interfaces. This is done in a
498 // separate step from normal inlining because CallMethod on an interface type
499 // requires extra steps compared to inlining a normal CallMethod.
500 //
501 // Next we need to simplify the structure of the module data structure, which
502 // is done for the most part by the usual steps in cleanupFrozenModule.
503 //
504 // However, there's a complication that comes from the fact that within a
505 // method, you can change the value of an interface to another module that
506 // implements that interface.
507 //
508 // For example:
509 //
510 // impl: MyInterface
511 // ...
512 // def forward(self, x):
513 // if x > 0:
514 // self.impl = my_interface_impl
515 //
516 // This is disallowed in freezing, because in this case we can't flatten out
517 // the module structure, since the type of self.impl will change.
518 //
519 // To handle this, we do the following:
520 // 1. inlineInterfaceCalls:
521 // a. inline the graph, and in the process record all interfaces
522 // b. simultaneously, check (throw error) for disallowed SetAttr calls.
523 // 2. call reassignInterfaceTypes, which reassigns interface types to their
524 // concrete types. This is done in a separate step to avoid interfering
525 // with inlineInterfaceCalls (note: this may not need to be done as a
526 // separate step)
527 // 3. eventually cleanupFrozenModule will reorder the module data structure
528 // and it will expect that all interface types have been removed.
inlineInterfaceCalls(std::shared_ptr<Graph> & graph,std::unordered_map<std::string,std::unordered_set<std::string>> & interfacesToRetype)529 void inlineInterfaceCalls(
530 std::shared_ptr<Graph>& graph,
531 std::unordered_map<std::string, std::unordered_set<std::string>>&
532 interfacesToRetype) {
533 auto block = graph->block();
534 std::stack<Block*> blocks({block});
535
536 while (!blocks.empty()) {
537 Block* block = blocks.top();
538 blocks.pop();
539 for (auto n : block->nodes()) {
540 for (Block* sub_block : n->blocks()) {
541 blocks.push(sub_block);
542 }
543 if (n->kind() == prim::GetAttr) {
544 if (!n->output()->type()->cast<InterfaceType>()) {
545 continue;
546 }
547 auto name = n->s(attr::name);
548 auto attrModule = module_;
549 auto input = n->inputs()[0];
550 TORCH_CHECK(
551 findConstantAttr(input, name, attrModule, graph),
552 "failed to freeze interface attribute '" + name + "'");
553 TORCH_INTERNAL_ASSERT(attrModule.hasattr(name));
554 auto attr = attrModule.attr(name);
555 inlineInterfaceCall(n, attr);
556 // Reset the GetAttr to concrete module type.
557 n->output()->setType(attr.type());
558
559 // Record this so that we can reassign the type later
560 // in reassignInterfaceTypes()
561 // See [Note: Inlining interfaces strategy]
562 auto path = getModulePath(input, graph);
563 TORCH_INTERNAL_ASSERT(path.has_value());
564 auto path_str = concatName(path->begin(), path->end());
565 interfacesToRetype[path_str].insert(name);
566 } else if (n->kind() == prim::SetAttr) {
567 // Check to make sure we're not assigning the value of any parameters
568 // that are interface types.
569 // See [Note: Inlining interfaces strategy]
570 auto name = n->s(attr::name);
571 auto attrModule = module_;
572 auto input = n->inputs()[0];
573
574 if (!input->type()->cast<InterfaceType>() &&
575 !input->type()->expectRef<ClassType>().is_module()) {
576 // we only care if we're setattr["thing"](%mod) if %mod
577 continue;
578 }
579
580 // note: this will modify attrModule until it is the parent of the
581 // "name" attr. In other words, attrModule is now the module that
582 // matches "input".
583 // We can't use findConstantAttr in case the base item is an object,
584 // instead of a module/interface.
585 auto path = getModulePath(input, graph);
586 TORCH_INTERNAL_ASSERT(path.has_value());
587 getModuleFromPath(attrModule, path->begin(), path->end());
588
589 const auto& attrType = attrModule.type()->getAttribute(name);
590 TORCH_INTERNAL_ASSERT(
591 !attrType->cast<InterfaceType>(),
592 "Freezing does not support SetAttr on an interface type. ",
593 "SetAttr is attempted on '",
594 name,
595 "'");
596 } else if (n->kind() == prim::fork) {
597 applyToForkSubgraph(
598 n,
599 graph,
600 // NOLINTNEXTLINE(modernize-avoid-bind)
601 std::bind(
602 &AttributePropagator::inlineInterfaceCalls,
603 *this,
604 std::placeholders::_1,
605 interfacesToRetype));
606 }
607 }
608 }
609 }
610
611 // See [Note: Inlining interfaces strategy]
612 // This modifies the internal structure of module types to reassign the
613 // type from an interface type to its concrete type.
reassignInterfaceTypes(const std::unordered_map<std::string,std::unordered_set<std::string>> & interfacesToRetype)614 void reassignInterfaceTypes(
615 const std::unordered_map<std::string, std::unordered_set<std::string>>&
616 interfacesToRetype) {
617 for (const auto& it : interfacesToRetype) {
618 const std::string& modulePath = it.first;
619 const std::vector<std::string>& splitPath = splitName(modulePath);
620 Module attrModule = module_;
621 getModuleFromPath(attrModule, splitPath.begin(), splitPath.end());
622
623 for (const std::string& name : it.second) {
624 auto subvalue = attrModule.attr(name);
625 auto subvalueType = subvalue.type();
626 attrModule.type()->unsafeChangeAttributeType(name, subvalueType);
627 }
628 }
629 }
630
propagateAttributes(std::shared_ptr<Graph> & graph)631 void propagateAttributes(std::shared_ptr<Graph>& graph) {
632 std::unordered_map<ModulePtr, std::unordered_map<std::string, Value*>>
633 attrValues;
634 auto isEval = !module_.hasattr("training") || !module_.is_training();
635 GRAPH_DEBUG("Freezing Module: ", module_.type()->name()->name());
636 auto block = graph->block();
637 std::stack<Block*> blocks({block});
638
639 Node* m = *block->nodes().begin();
640 WithInsertPoint guard(m);
641 while (!blocks.empty()) {
642 Block* block = blocks.top();
643 blocks.pop();
644 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
645 Node* n = *it;
646 it++; // advance iterator bc the current node may be destroyed
647
648 for (Block* sub_block : n->blocks()) {
649 blocks.push(sub_block);
650 }
651 if (n->kind() == prim::GetAttr) {
652 auto name = n->s(attr::name);
653 auto attrModule = module_;
654 auto input = n->inputs()[0];
655 if (!findConstantAttr(input, name, attrModule, graph)) {
656 GRAPH_DEBUG(
657 input->type()->cast<InterfaceType>() ||
658 input->type()->expectRef<ClassType>().is_module()
659 ? "attribute: " + name + " is mutable."
660 : "");
661 continue;
662 }
663 TORCH_INTERNAL_ASSERT(attrModule.hasattr(name));
664 Value* paramConst = nullptr;
665 auto iter = attrValues.find(attrModule._ivalue());
666 if (iter != attrValues.end()) {
667 auto iter2 = iter->second.find(name);
668 if (iter2 != iter->second.end())
669 paramConst = iter2->second;
670 }
671 if (!paramConst) {
672 auto attr = attrModule.attr(name);
673 if (!isEval || preserveParameters_) {
674 auto type = attrModule.type();
675 auto slot = *type->findAttributeSlot(name);
676 if (type->is_parameter(slot) || type->is_buffer(slot) ||
677 (attr.isObject() &&
678 !attr.toObjectRef().type()->is_module())) {
679 continue;
680 } else {
681 attr = overrideGradient(attr);
682 }
683 if (!isEval && name == "training") {
684 continue;
685 }
686 } else {
687 attr = overrideGradient(attr);
688 }
689 if (attr.isObject()) {
690 if (object_memo_.count(attr.toObject())) {
691 attr = object_memo_[attr.toObject()];
692 } else {
693 auto weak_class_obj =
694 attr.toObject()->copy_to_weak_compilation_ref();
695 object_memo_[attr.toObject()] = weak_class_obj;
696 attr = weak_class_obj;
697 }
698 }
699 if (auto attrVal = tryInsertConstant(*graph, attr)) {
700 paramConst = *attrVal;
701 } else {
702 GRAPH_DEBUG(
703 attr.type()->cast<ClassType>() ? "" : "attribute: ",
704 name,
705 " is not materializable.");
706 continue;
707 }
708 std::string fullName("self.");
709 for (auto& name : names_) {
710 fullName += name + '.';
711 }
712 fullName += name;
713 paramConst->setDebugName(fullName);
714 attrValues[attrModule._ivalue()][name] = paramConst;
715 }
716 GRAPH_UPDATE(
717 "Folding GetAttr %",
718 n->outputs()[0]->debugName(),
719 " with ",
720 paramConst->debugName());
721 n->outputs().at(0)->replaceAllUsesWith(paramConst);
722 n->removeAllInputs();
723 } else if (n->kind() == prim::fork) {
724 applyToForkSubgraph(
725 n,
726 graph,
727 // NOLINTNEXTLINE(modernize-avoid-bind)
728 std::bind(
729 &AttributePropagator::propagateAttributes,
730 *this,
731 std::placeholders::_1));
732 }
733 }
734 }
735 }
736
applyToForkSubgraph(Node * n,std::shared_ptr<Graph> & graph,const std::function<void (std::shared_ptr<Graph> &)> & func)737 void applyToForkSubgraph(
738 Node* n,
739 std::shared_ptr<Graph>& graph,
740 const std::function<void(std::shared_ptr<Graph>&)>& func) {
741 TORCH_CHECK(n->kind() == prim::fork);
742 auto attrModule = module_;
743 auto node = n->inputs()[0]->node();
744 // Check if first parameter of fork is a module. This module is used
745 // as the base module (similar to 'self' in forward) to resolve GetAttrs.
746 // Otherwise freezing is applied using module_
747 if (node->kind() == prim::GetAttr &&
748 node->output()->type()->cast<ClassType>()) {
749 auto name = node->s(attr::name);
750 auto input = node->inputs()[0];
751 if (!findConstantAttr(input, name, attrModule, graph)) {
752 // Module needs to be preserved.
753 return;
754 }
755 attrModule = attrModule.attr(name).toModule();
756 std::swap(module_, attrModule);
757 }
758
759 auto subgraph = n->g(attr::Subgraph);
760 func(subgraph);
761 module_ = attrModule;
762 }
763
moduleEscapes(Module & subModule,std::shared_ptr<Graph> & graph)764 bool moduleEscapes(Module& subModule, std::shared_ptr<Graph>& graph) {
765 for (auto& output : graph->outputs()) {
766 if (subModule.type()->isSubtypeOf(*output->type())) {
767 return true;
768 }
769 }
770 return preservedSubModule_.count(subModule._ivalue());
771 }
772
removeExtraWaitCalls(Block * b)773 void removeExtraWaitCalls(Block* b) {
774 auto nodes = b->nodes();
775 for (auto it = nodes.begin(); it != nodes.end(); it++) {
776 auto node = *it;
777 if (node->kind() != aten::wait) {
778 continue;
779 }
780 TORCH_INTERNAL_ASSERT(node->inputs().size() == 1);
781 TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
782 // If input type is not a from aten::fork call then the
783 // aten::wait operator can be deleted.
784 if (node->input()->type()->kind() != TypeKind::FutureType) {
785 node->output()->replaceAllUsesWith(node->input());
786 it.destroyCurrent();
787 }
788 }
789 // For the remaining nodes, recurse.
790 for (auto it = nodes.begin(); it != nodes.end(); it++) {
791 auto node = *it;
792 for (auto sub_b : node->blocks()) {
793 removeExtraWaitCalls(sub_b);
794 }
795 }
796 }
797
798 // cleanupFrozenModule function cleans up the Frozen module. It performs the
799 // following:
800 // 1) Remove unused attributes.
801 // 2) Remove unreferenced submodules
802 // 3) Remove non public unreferenced methods.
cleanupFrozenModule()803 void cleanupFrozenModule() {
804 for (auto function : preservedMethods_) {
805 auto graph = toGraphFunction(*function).graph();
806 recordReferencedAttrs(graph);
807 handleSharedClassType(module_, graph);
808 removeExtraWaitCalls(graph->block());
809 toGraphFunction(*function).clear_optimized_graphs();
810 }
811 removeUnusedAttrs();
812 }
813
// Prepares for the clean-up phase. At this point, record all submodules that
// contain mutable attributes.
recordReferencedAttrs(std::shared_ptr<Graph> & graph)816 void recordReferencedAttrs(std::shared_ptr<Graph>& graph) {
817 std::stack<Block*> blocks({graph->block()});
818 std::set<ModulePtr> modules({module_._ivalue()});
819 while (!blocks.empty()) {
820 Block* block = blocks.top();
821 blocks.pop();
822 for (auto n : block->nodes()) {
823 for (Block* subBlock : n->blocks()) {
824 blocks.push(subBlock);
825 }
826 if (n->kind() == prim::GetAttr) {
827 auto& name = n->s(attr::name);
828 // For now, use all module ivalues which are the same type
829 // and could be the module that this GetAttr resolves to
830 // TODO: we could attempt to follow the GetAttr chain and
831 // find the exact ivalue, we would have to be careful
832 // that the chain does not contain any attributes which
833 // get written to (setAttr calls)
834 for (auto& mptr : modules) {
835 auto module = Module(mptr);
836 if (module.type() == n->inputs()[0]->type()) {
837 TORCH_INTERNAL_ASSERT(module.hasattr(name));
838 auto module = Module(mptr);
839 auto attr = module.attr(name);
840 // TODO: this could be insertReferencedAttr to be more clear,
841 // these are attributes we could not inline, which include
842 // other reasons besides mutation (unsupported constant,
843 // getAttr resolving to non-getAttr node, etc)
844 insertMutableAttr(name, attr, mptr);
845 if (attr.isModule()) {
846 modules.insert(attr.toModule()._ivalue());
847 }
848 }
849 }
850 } else if (n->kind() == prim::fork) {
851 applyToForkSubgraph(
852 n,
853 graph,
854 // NOLINTNEXTLINE(modernize-avoid-bind)
855 std::bind(
856 &AttributePropagator::recordReferencedAttrs,
857 *this,
858 std::placeholders::_1));
859 }
860 }
861 }
862 // We have to process the attributes that the user wants to preserve
863 // separately since it's possible that the user-preserved module is
864 // never referenced in the graph.
865 for (const auto& attr_info : userPreservedAttrs_) {
866 const auto& parent_module = attr_info.first;
867 for (const auto& attr_name : attr_info.second) {
868 const auto value = parent_module->getAttr(attr_name);
869 insertMutableAttr(attr_name, value, parent_module);
870 }
871 }
872 }
873
874 // This function recursively iterates over submodules to identify
875 // for each class type the attribute slots that need to be preserved.
876 //
877 // Note 'attrsToKeep[type].insert(type->numAttributes())' means all
878 // attribute slots of 'type' and its methods are preserved. A submodule is
879 // preserved when it escapes (meaning it is returned).
handleSharedClassType(Module & module,std::shared_ptr<Graph> & graph)880 void handleSharedClassType(Module& module, std::shared_ptr<Graph>& graph) {
881 auto type = module.type();
882 size_t N = type->numAttributes();
883 if (moduleEscapes(module, graph)) {
884 // Preserve all its attributes and methods.
885 attrsToKeep_[type].insert(N);
886 return;
887 }
888 auto it2 = preservedScalarAttrs_.find(module._ivalue());
889 SharedTypeSubModules_[type].insert(module._ivalue());
890 attrsToKeep_[type].insert({});
891 for (const auto i : c10::irange(N)) {
892 auto name = type->getAttributeName(i);
893 auto attr = module.attr(name);
894 auto attrTy = attr.type();
895
896 bool isMutable = false;
897 if (AliasDb::isMutableType(attrTy)) {
898 isMutable = preservedAttrs_.count(attr);
899 } else {
900 isMutable =
901 it2 != preservedScalarAttrs_.end() && it2->second.count(name);
902 }
903 if (isMutable) {
904 attrsToKeep_[type].insert(i);
905 if (attr.isModule()) {
906 // See [Note: Inlining interfaces strategy]
907 TORCH_CHECK(
908 !type->getAttribute(i)->cast<InterfaceType>(),
909 "Unexpected interface attribute '" + name + "' during freezing");
910
911 auto attrModule = attr.toModule();
912 handleSharedClassType(attrModule, graph);
913 }
914 }
915 }
916 }
917
918 // Remove unused attributes and methods for each sub module of the frozen
919 // module. This function iterates over the Classtypes of its submodule
920 // attributes including its own type.
removeUnusedAttrs()921 void removeUnusedAttrs() {
922 std::vector<std::string> attrsToRemove;
923 std::vector<Function*> funcsToRemove;
924 for (auto& it : attrsToKeep_) {
925 auto& type = it.first;
926 size_t N = type->numAttributes();
927 if (it.second.count(N)) {
928 continue;
929 }
930 for (const auto i : c10::irange(N)) {
931 if (it.second.count(i) == 0) {
932 attrsToRemove.push_back(type->getAttributeName(i));
933 }
934 }
935 for (auto& fn : type->methods()) {
936 if (preservedMethods_.count(fn)) {
937 continue;
938 }
939 funcsToRemove.push_back(fn);
940 }
941
942 for (auto& name : attrsToRemove) {
943 for (auto& val : SharedTypeSubModules_[type]) {
944 auto mod = val.toModule();
945 mod._ivalue()->unsafeRemoveAttr(name);
946 }
947 type->unsafeRemoveAttribute(name);
948 }
949 for (auto fn : funcsToRemove) {
950 type->unsafeRemoveMethod(fn->name());
951 auto mod = SharedTypeSubModules_[type].begin()->toModule();
952 mod._ivalue()->compilation_unit()->unsafeRemoveMethod(fn->qualname());
953 }
954
955 attrsToRemove.clear();
956 funcsToRemove.clear();
957 }
958 }
959
960 // Contains attributes that can't be folded or user directs to keep them.
961 IValue::HashAliasedIValues preservedAttrs_;
962 // Tracks preserved attributes of immutable types (scalars) by their
963 // attribute names rather than by IValue.
964 std::unordered_map<ModulePtr, std::unordered_set<std::string>>
965 preservedScalarAttrs_;
966
967 // Contains user specified methods to be preserved in frozen module.
968 std::unordered_set<Function*> preservedMethods_;
969
970 // Contains user-specified submodules to be preserved in the frozen module.
971 std::unordered_set<ModulePtr> preservedSubModule_;
972
973 // Track all used attributes ivalues that can be aliased.
974 IValue::HashAliasedIValues usedAttrs_;
975
976 // Contains the attribute slots that need to be preserved for each ClassType.
977 std::unordered_map<ClassTypePtr, std::unordered_set<size_t>> attrsToKeep_;
978
979 // Contains the sub modules that share the same ClassType.
980 std::unordered_map<ClassTypePtr, IValue::HashAliasedIValues>
981 SharedTypeSubModules_;
982
983 Module& module_;
984
985 // Allow to freeze modules containing interfaces.
986 bool freezeInterfaces_;
987
988 // Preserve module parameters
989 bool preserveParameters_;
990
991 // Contains the attribute names (e.g. {"self", "subModule", "a"}).
992 std::deque<std::string> names_;
993
994 // see [Constant Object Weak CompilationUnit Reference]
995 std::unordered_map<
996 c10::intrusive_ptr<at::ivalue::Object>,
997 c10::intrusive_ptr<at::ivalue::Object>>
998 object_memo_;
999
1000 // Contains names of attributes that the user wants to preserve with
1001 // their owning modules.
1002 std::unordered_map<ModulePtr, std::unordered_set<std::string>>
1003 userPreservedAttrs_;
1004
1005 }; // class AttributePropagator
1006
checkModuleDoesNotReturnSelf(const Module & module)1007 void checkModuleDoesNotReturnSelf(const Module& module) {
1008 if (module.find_method("forward")) {
1009 Method method = module.get_method("forward");
1010 // Check that module does not return itself.
1011 for (auto& output : method.graph()->outputs()) {
1012 TORCH_CHECK(
1013 output->type() != module.type(),
1014 "attempted to freeze a module that return itself");
1015 }
1016 }
1017 }
1018 } // namespace
1019
// Freezes a clone of `module` and returns that clone.
//
// The module is first checked with checkModuleDoesNotReturnSelf(), then
// cloned via `module.clone(true)`; an AttributePropagator is run on the
// clone.
//
// preservedAttrs     - names forwarded to the propagator as attributes to
//                      preserve.
// freezeInterfaces   - allow freezing modules that contain interfaces.
// preserveParameters - preserve module parameters.
Module freeze_module(
    const Module& module,
    std::vector<std::string> preservedAttrs,
    bool freezeInterfaces,
    bool preserveParameters) {
  checkModuleDoesNotReturnSelf(module);

  Module frozen = module.clone(true);
  AttributePropagator propagator(
      frozen, preservedAttrs, freezeInterfaces, preserveParameters);
  propagator.run();
  return frozen;
}
1033
freeze_module_inplace(Module * module,std::vector<std::string> preservedAttrs,bool freezeInterfaces,bool preserveParameters)1034 void freeze_module_inplace(
1035 Module* module,
1036 std::vector<std::string> preservedAttrs,
1037 bool freezeInterfaces,
1038 bool preserveParameters) {
1039 TORCH_CHECK(module != nullptr, "module cannot be nullptr");
1040 checkModuleDoesNotReturnSelf(*module);
1041 AttributePropagator attrPropagator(
1042 *module, preservedAttrs, freezeInterfaces, preserveParameters);
1043 attrPropagator.run();
1044 }
1045
1046 } // namespace torch::jit
1047