#include <torch/csrc/jit/passes/specialize_autogradzero.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/profiling_record.h>

#include <ATen/core/symbol.h>
#include <c10/util/irange.h>

namespace torch::jit {

static const auto countsAttribute = Symbol::attr("none_counts");

static bool hasGradSumToSizeUses(Value* v) {
  return std::any_of(v->uses().begin(), v->uses().end(), [](const Use& use) {
    return use.user->kind() == aten::_grad_sum_to_size;
  });
}

static void insertProfileNodesForSpecializeAutogradZero(
    Block* block,
    ProfilingRecord* pr) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    auto n = *it;
    for (const auto offset : c10::irange(n->inputs().size())) {
      auto i = n->input(offset);
      if (i->type()->cast<OptionalType>() && hasGradSumToSizeUses(i)) {
        // here we profile the definition instead of the use, because we only
        // optimize in the case of a None value, which is immutable
        auto opt_pn = pr->createProfileIValueNode(i);

        c10::Dict<std::string, int64_t> noneCountsDict;
        noneCountsDict.insert("num_none", 0);
        noneCountsDict.insert("num_present", 0);
        IValue init_val(noneCountsDict);

        opt_pn->ival_(countsAttribute, init_val);

        std::function<void(Stack&)> optional_profiler = [pr,
                                                         opt_pn](Stack& stack) {
          std::lock_guard<std::mutex> lock(pr->mutex_);

          TORCH_INTERNAL_ASSERT(opt_pn->hasAttribute(countsAttribute));
          // frame_id is unused
          int64_t frame_id = 0;
          pop(stack, frame_id);

          const auto& counts_attr = opt_pn->ival(countsAttribute);
          auto noneCounts = c10::impl::toTypedDict<std::string, int64_t>(
              counts_attr.toGenericDict());
          IValue value;
          pop(stack, value);
          if (value.isNone()) {
            noneCounts.insert_or_assign(
                "num_none", noneCounts.at("num_none") + 1);
          } else {
            noneCounts.insert_or_assign(
                "num_present", noneCounts.at("num_present") + 1);
          }
          push(stack, value);
        };
        opt_pn->setCallback(optional_profiler);
        opt_pn->insertAfter(i->node());
        i->replaceAllUsesAfterNodeWith(opt_pn, opt_pn->output());
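        // Illustrative sketch (value names assumed, not verbatim IR): after
        // this rewrite, an optional value feeding aten::_grad_sum_to_size is
        // routed through the counting profile node, e.g.
        //   %sizes : int[]? = ...
        //   %sizes.p : int[]? = prim::profile_ivalue[none_counts=...](%sizes)
        //   %g = aten::_grad_sum_to_size(%grad, %sizes.p)
        // so later runs can observe whether %sizes was ever non-None.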
      }
    }

    for (auto ib : n->blocks()) {
      insertProfileNodesForSpecializeAutogradZero(ib, pr);
    }
  }
}

void InsertProfileNodesForSpecializeAutogradZero(ProfilingRecord* pr) {
  insertProfileNodesForSpecializeAutogradZero(pr->profiled_graph_->block(), pr);
}

struct AutogradZeroSpecializer {
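  // Per-Value facts accumulated while specializing:
  //   Zero    - known to be an undefined (autograd zero) tensor
  //   Nonzero - known to be a defined tensor
  //   Unknown - no profiling/type information; handled conservatively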
  enum class State { Nonzero, Zero, Unknown };

  AutogradZeroSpecializer(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  void run() {
    if (!isBackwardGraph()) {
      return;
    }
    if (getExecutorMode()) {
      if (auto versioning_if = guardSpecializations()) {
        specializeAutogradOps(versioning_if->blocks()[0]);
        GRAPH_DUMP("After versioning graph", graph_);
      }
    } else {
      setStatesOnGraphInputs();
      specializeAutogradOps(graph_->block());
    }
    GRAPH_DUMP("After specializeAutogradOps graph", graph_);
  }

 private:
  bool isBackwardGraph() {
    return std::any_of(
        graph_->nodes().begin(), graph_->nodes().end(), [](Node* n) {
          switch (n->kind()) {
            case prim::AutogradAnyNonZero:
            case prim::AutogradAdd:
            case aten::_grad_sum_to_size:
              return true;
            default:
              return false;
          }
        });
  }

  void replaceBlockInputsWithGraphInputs(Block* b) {
    TORCH_INTERNAL_ASSERT(graph_->inputs().size() == b->inputs().size());
    size_t num_inputs = graph_->inputs().size();
    for (const auto i : c10::irange(num_inputs)) {
      b->inputs().at(i)->replaceAllUsesWith(graph_->inputs().at(i));
    }
    for (const auto i : c10::irange(num_inputs)) {
      b->eraseInput(num_inputs - (1 + i));
    }
  }

  void setStatesOnGraphInputs() {
    for (Value* input : graph_->inputs()) {
      const auto& tp = input->type();
      if (auto tt = tp->cast<TensorType>()) {
        if (tt->undefined()) {
          if (*tt->undefined()) {
            state_[input] = State::Zero;
          } else {
            state_[input] = State::Nonzero;
          }
        } else {
          state_[input] = State::Unknown;
        }
      } else if (
          tp->isSubtypeOf(*TensorType::get()) ||
          tp->isSubtypeOf(*ListType::ofTensors())) {
        state_[input] = State::Nonzero;
      } else {
        state_[input] = State::Unknown;
      }
    }
  }

  static void getUsesWithAttribute_(
      Value* inp,
      Symbol attr,
      std::vector<Node*>& uses) {
    for (auto use : inp->uses()) {
      if (use.user->kind() != prim::profile_ivalue) {
        continue;
      }

      if (use.user->hasAttribute(attr)) {
        uses.push_back(use.user);
      }

      getUsesWithAttribute_(use.user->output(), attr, uses);
    }
  }

  // this is to deal with the fact that there could be other passes that
  // would like to profile this exact same value. this helper walks
  // chains of `prim::profile_ivalue` to locate the one inserted by/for
  // `specializeAutogradZero`
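  // Illustrative chain (value names assumed):
  //   %x : int[]? = ...
  //   %x.p1 : int[]? = prim::profile_ivalue(%x)                    // other pass
  //   %x.p2 : int[]? = prim::profile_ivalue[none_counts=...](%x.p1) // ours
  // getUsesWithAttribute(%x, countsAttribute) returns the node of %x.p2.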
  static std::vector<Node*> getUsesWithAttribute(Value* inp, Symbol attr) {
    std::vector<Node*> uses;
    getUsesWithAttribute_(inp, attr, uses);
    return uses;
  }

  static Node* getUse(Value* inp, Symbol kind) {
    for (auto use : inp->uses()) {
      if (use.user->kind() == kind) {
        return use.user;
      }
    }

    return nullptr;
  }

  void removeProfiledOptionalUses(const std::vector<Node*>& uses) {
    TORCH_INTERNAL_ASSERT(!uses.empty());
    auto inp = uses[0]->input();
    // this removes `prim::profile_ivalue` from the original and to-specialize
    // blocks. N.B. the false block isn't impacted as it has already been
    // encapsulated in a fallback function
    for (auto u : uses) {
      u->output()->replaceAllUsesWith(inp);
    }
  }

  Node* guardSpecializations() {
    auto versioning_if = graph_->create(prim::If, {}, graph_->outputs().size());
    auto value_map = [](Value* v) { return v; };
    auto true_block = versioning_if->addBlock();
    auto false_block = versioning_if->addBlock();

    // we will optimize true_block
    true_block->cloneFrom(graph_->block(), value_map);
    replaceBlockInputsWithGraphInputs(true_block);
    false_block->cloneFrom(graph_->block(), value_map);
    replaceBlockInputsWithGraphInputs(false_block);
    replaceBlockWithFallbackGraph(false_block, graph_->inputs());

    WithInsertPoint wip{graph_->block()->param_node()->next()};
    Value* none_val = graph_->insertConstant(IValue());
    std::vector<Value*> checks;
    std::vector<Value*> zero_values;
    std::vector<Value*> nonzero_values;

    for (auto inp : graph_->inputs()) {
      std::vector<Node*> iprofile_counts_nodes =
          getUsesWithAttribute(inp, countsAttribute);
      if (!iprofile_counts_nodes.empty()) {
        // the original `prim::profile_ivalue[num_present=0,...]` on `inp` is
        // copied into `true_block` and `false_block`.
        auto profile_ivalue_node = iprofile_counts_nodes[0];
        TORCH_INTERNAL_ASSERT(
            profile_ivalue_node->hasAttribute(countsAttribute));
        const auto& counts_attr =
            profile_ivalue_node->ival(countsAttribute).toGenericDict();
        auto num_present = counts_attr.at(IValue{"num_present"}).toInt();
        auto num_none = counts_attr.at(IValue{"num_none"}).toInt();
        if (num_present == 0 && num_none != 0) {
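          // profiling saw this optional as None on every recorded run, so
          // guard the specialized block on it still being None at run time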
          auto check = graph_->insert(aten::__is__, {inp, none_val})->node();
          checks.push_back(check->output());
          profiled_none_.insert(inp);
        }
        removeProfiledOptionalUses(iprofile_counts_nodes);
        continue;
      }

      if (inp->uses().empty() || !inp->type()->cast<TensorType>()) {
        continue;
      }

      // TODO: check multiple uses ?
      auto pout = getUse(inp, prim::profile);
      if (!pout) {
        continue;
      }

      auto pttp = pout->ty(attr::profiled_type)->expect<TensorType>();
      if (!pttp->undefined().has_value()) {
        continue;
      }

      state_[inp] = *pttp->undefined() ? State::Zero : State::Nonzero;

      if (*pttp->undefined()) {
        zero_values.push_back(inp);
      } else {
        nonzero_values.push_back(inp);
      }
    }
    GRAPH_DUMP("After for loop", graph_);
    // unable to specialize any of the inputs
    if (nonzero_values.empty() && zero_values.empty()) {
      GRAPH_DUMP("Unable to add any specialization guards", graph_);
      versioning_if->destroy();
      // the checks we inserted will be cleaned up
      // by any subsequent DCE pass
      return nullptr;
    }

    Node* nonzero_check = graph_->insert(prim::AutogradAllNonZero, {})->node();
    for (Value* v : nonzero_values) {
      nonzero_check->addInput(v);
    }
    checks.push_back(nonzero_check->output());

    Node* zero_check = graph_->insert(prim::AutogradAllZero, {})->node();
    for (Value* v : zero_values) {
      zero_check->addInput(v);
    }
    checks.push_back(zero_check->output());

    Value* bool_list =
        graph_->insertNode(graph_->createList(BoolType::get(), checks))
            ->output();
    Value* conjunction = graph_->insert(aten::all, {bool_list});

    versioning_if->addInput(conjunction);
    graph_->insertNode(versioning_if);
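    // Rough shape of the inserted guard (value names illustrative):
    //   %c0 : bool = aten::__is__(%opt_input, %none)
    //   %c1 : bool = prim::AutogradAllNonZero(%g0, %g1, ...)
    //   %c2 : bool = prim::AutogradAllZero(%z0, ...)
    //   %checks : bool[] = prim::ListConstruct(%c0, %c1, %c2)
    //   %ok : bool = aten::all(%checks)
    //   ... = prim::If(%ok)   (block0: specialized copy, block1: fallback)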

    auto ret = graph_->return_node();
    for (const auto i : c10::irange(ret->inputs().size())) {
      auto ogo = ret->input(i);
      auto ngo = versioning_if->output(i);
      ngo->copyMetadata(ogo);
      ret->replaceInput(i, ngo);
    }

    // We've created:
    // successful_checks = Guards(...)
    // if (successful_checks)
    // -> optimized graph
    // else:
    // -> fallback graph
    // original graph
    //
    // Remove the dead original graph
    for (auto it = graph_->block()->nodes().reverse().begin();
         *it != versioning_if;) {
      Node* n = *it;
      it++;
      n->destroy();
    }

    GRAPH_DUMP("After guardSpecializations", graph_);
    return versioning_if;
  }

  void specializeAutogradOps(Block* block) {
    for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
      auto n = *it;
      switch (n->kind()) {
        case prim::AutogradAdd: {
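          // Illustrative rewrites performed below (value names assumed):
          //   %c = prim::AutogradAdd(%zero, %b)  ==>  uses of %c now use %b
          //   %c = prim::AutogradAdd(%a, %zero)  ==>  uses of %c now use %a
          //   with both inputs known Nonzero:
          //   %c = prim::AutogradAdd(%a, %b)     ==>  %c = aten::add(%a, %b, 1)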
          auto a = n->input(0);
          auto b = n->input(1);
          // if one is Autograd zero, we can just drop the add
          if (state_[a] == State::Zero) {
            // Zero + b == b
            n->output()->replaceAllUsesWith(b);
            it.destroyCurrent();
          } else if (state_[b] == State::Zero) {
            // a + Zero == a
            n->output()->replaceAllUsesWith(a);
            it.destroyCurrent();
          } else if (
              state_[a] == State::Nonzero && state_[b] == State::Nonzero) {
            // when both are Nonzero, we can use a normal, optimizable add
            // instruction
            WithInsertPoint guard(n);
            auto* cOne = graph_->insertConstant(1);
            auto* add_node = graph_->insertNode(graph_->create(aten::add, 1));
            add_node->addInput(a);
            add_node->addInput(b);
            add_node->addInput(cOne);
            auto* add_output = add_node->output();
            add_output->setType(n->output()->type());
            state_[add_output] = State::Nonzero;
            n->output()->replaceAllUsesWith(add_output);
            it.destroyCurrent();
          } else {
            // otherwise we have conditionally-Nonzero things, and we need
            // to actually run an AutogradAdd which will guard for Zeros
            // so we leave the op as is
            state_[n->output()] = State::Unknown;
          }
        } break;
        case prim::AutogradZero: {
          state_[n->output()] = State::Zero;
        } break;
        case prim::profile: {
          // this is a profile node on a tensor use;
          // if we decided to specialize this graph,
          // its input may have undefinedness info,
          // otherwise it should be Unknown
          if (!n->inputs().empty()) {
            state_[n->output()] =
                !state_.count(n->input()) ? State::Unknown : state_[n->input()];
          }
          break;
        }
        case prim::BailOut: {
          if (auto ptt = n->output()->type()->expect<TensorType>()) {
            state_[n->output()] = ptt->undefined()
                ? *ptt->undefined() ? State::Zero : State::Nonzero
                : State::Unknown;
          }
        } break;
        // Lowered GradOf block
        case prim::If: {
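          // Illustrative shape of a lowered GradOf block (names assumed):
          //   %any : bool = prim::AutogradAnyNonZero(%dout)
          //   %r = prim::If(%any)
          //     block0():
          //       %g = aten::mul(%dout, %x)
          //       -> (%g)
          // If every AutogradAnyNonZero input is known Zero, the If is
          // replaced with a single prim::AutogradZero; if every input is
          // known Nonzero, the body is hoisted in front of the If and the
          // If is removed.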
          auto if_input = n->input(0)->node();
          if (if_input->kind() == prim::AutogradAnyNonZero) {
            auto all_zeros = std::all_of(
                if_input->inputs().begin(),
                if_input->inputs().end(),
                [&](Value* v) { return state_[v] == State::Zero; });

            auto all_nonzeros = std::all_of(
                if_input->inputs().begin(),
                if_input->inputs().end(),
                [&](Value* v) { return state_[v] == State::Nonzero; });
            // Property 1: if all the gradInputs to the GradOf are Zero
            // then the gradOutputs are also zero and will be represented as
            // AutogradZero nodes
            if (all_zeros) {
              auto zero =
                  graph_->createAutogradZero()->insertAfter(n)->output();
              state_[zero] = State::Zero;
              for (auto o : n->outputs()) {
                o->replaceAllUsesWith(zero);
              }
              it.destroyCurrent();
              break;
            }

            specializeGradSumToSize(n->blocks().at(0));
            if (all_nonzeros) {
              auto body = n->blocks().at(0);
              // hoist the nodes in the GradOf body to be before the linear
              // block
              for (auto it = body->nodes().begin();
                   it != body->nodes().end();) {
                auto block_node = *it++;
                block_node->moveBefore(n);
              }

              for (size_t i = 0; i < n->outputs().size(); ++i) {
                n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
                state_[body->outputs().at(i)] = State::Nonzero;
              }
              it.destroyCurrent();
              break;
            }
          }

          for (auto o : n->outputs()) {
            state_[o] = State::Unknown;
          }
          break;
        }
        default:
          for (auto o : n->outputs()) {
            state_[o] = State::Unknown;
          }
          break;
      }
    }
  }

  void specializeGradSumToSize(Block* b) {
    for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
      Node* n = *it;
      if (n->kind() == aten::_grad_sum_to_size) {
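        // e.g. (illustrative) when %size is statically None or was profiled
        // as always-None:
        //   %r = aten::_grad_sum_to_size(%grad, %size)
        // all uses of %r are redirected to %grad and the node is destroyed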
        bool profiled_none_flag = profiled_none_.count(n->input(1));
        const Node* node = n->input(1)->node();
        // propagate profiled none through other profile_ivalue nodes
        while (!profiled_none_flag && node->kind() == prim::profile_ivalue) {
          profiled_none_flag =
              profiled_none_flag || profiled_none_.count(node->input(0));
          node = node->input(0)->node();
        }
        if (n->input(1)->mustBeNone() || profiled_none_flag) {
          n->output()->replaceAllUsesWith(n->input(0));
          it.destroyCurrent();
        }
      }
    }
  }

  std::shared_ptr<Graph> graph_;
  std::unordered_set<Value*> profiled_none_;
  std::unordered_map<Value*, State> state_;
};

// propagate autograd zero information through a gradient graph and
// remove grad_of blocks if present.
// Note: this is a very limited pass. It only propagates autograd zeros for
// operations generated by the symbolic autodiff code and cleans up
// AutogradAdds when possible. Outputs of other nodes are conservatively
// marked Unknown and not optimized.
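// Minimal usage sketch (names assumed; the pass mutates the graph in place):
//   std::shared_ptr<Graph> grad_graph = ...; // backward graph from autodiff
//   specializeAutogradZero(grad_graph);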
void specializeAutogradZero(std::shared_ptr<Graph> g) {
  AutogradZeroSpecializer azs(std::move(g));
  azs.run();
}

} // namespace torch::jit