#include <torch/csrc/jit/passes/specialize_autogradzero.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/profiling_record.h>

#include <ATen/core/symbol.h>
#include <c10/util/irange.h>

namespace torch::jit {

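// Attribute under which profile_ivalue nodes accumulate how often the
// profiled optional value was observed as None vs. present at runtime.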
static const auto countsAttribute = Symbol::attr("none_counts");

static bool hasGradSumToSizeUses(Value* v) {
  return std::any_of(v->uses().begin(), v->uses().end(), [](const Use& use) {
    return use.user->kind() == aten::_grad_sum_to_size;
  });
}

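// Attach a profile_ivalue node to every Optional-typed value feeding
// aten::_grad_sum_to_size, so we can record whether it is ever non-None.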
static void insertProfileNodesForSpecializeAutogradZero(
    Block* block,
    ProfilingRecord* pr) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    auto n = *it;
    for (const auto offset : c10::irange(n->inputs().size())) {
      auto i = n->input(offset);
      if (i->type()->cast<OptionalType>() && hasGradSumToSizeUses(i)) {
        // here we profile the definition instead of the use,
        // because we are only optimizing in the case of a None value, which
        // is immutable
        auto opt_pn = pr->createProfileIValueNode(i);

        c10::Dict<std::string, int64_t> noneCountsDict;
        noneCountsDict.insert("num_none", 0);
        noneCountsDict.insert("num_present", 0);
        IValue init_val(noneCountsDict);

        opt_pn->ival_(countsAttribute, init_val);

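        // Profiling callback: pop the (unused) frame id and the optional
        // value, bump the matching counter in the node's none_counts dict,
        // then push the value back so execution continues unchanged.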
        std::function<void(Stack&)> optional_profiler = [pr,
                                                         opt_pn](Stack& stack) {
          std::lock_guard<std::mutex> lock(pr->mutex_);

          TORCH_INTERNAL_ASSERT(opt_pn->hasAttribute(countsAttribute));
          // frame_id is unused
          int64_t frame_id = 0;
          pop(stack, frame_id);

          const auto& counts_attr = opt_pn->ival(countsAttribute);
          auto noneCounts = c10::impl::toTypedDict<std::string, int64_t>(
              counts_attr.toGenericDict());
          IValue value;
          pop(stack, value);
          if (value.isNone()) {
            noneCounts.insert_or_assign(
                "num_none", noneCounts.at("num_none") + 1);
          } else {
            noneCounts.insert_or_assign(
                "num_present", noneCounts.at("num_present") + 1);
          }
          push(stack, value);
        };
        opt_pn->setCallback(optional_profiler);
        opt_pn->insertAfter(i->node());
        i->replaceAllUsesAfterNodeWith(opt_pn, opt_pn->output());
      }
    }

    for (auto ib : n->blocks()) {
      insertProfileNodesForSpecializeAutogradZero(ib, pr);
    }
  }
}

void InsertProfileNodesForSpecializeAutogradZero(ProfilingRecord* pr) {
  insertProfileNodesForSpecializeAutogradZero(pr->profiled_graph_->block(), pr);
}

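// Propagates autograd-zero (undefined gradient) information through a
// backward graph. With the profiling executor enabled, the graph is cloned
// into a versioning if whose true branch is specialized under runtime guards
// while the false branch falls back to the unspecialized graph; otherwise,
// the per-Value states are seeded from the input types and the graph is
// specialized in place.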
struct AutogradZeroSpecializer {
  enum class State { Nonzero, Zero, Unknown };

  AutogradZeroSpecializer(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  void run() {
    if (!isBackwardGraph()) {
      return;
    }
    if (getExecutorMode()) {
      if (auto versioning_if = guardSpecializations()) {
        specializeAutogradOps(versioning_if->blocks()[0]);
        GRAPH_DUMP("After versioning graph", graph_);
      }
    } else {
      setStatesOnGraphInputs();
      specializeAutogradOps(graph_->block());
    }
    GRAPH_DUMP("After specializeAutogradOps graph", graph_);
  }

 private:
  bool isBackwardGraph() {
    return std::any_of(
        graph_->nodes().begin(), graph_->nodes().end(), [](Node* n) {
          switch (n->kind()) {
            case prim::AutogradAnyNonZero:
            case prim::AutogradAdd:
            case aten::_grad_sum_to_size:
              return true;
            default:
              return false;
          }
        });
  }

  void replaceBlockInputsWithGraphInputs(Block* b) {
    TORCH_INTERNAL_ASSERT(graph_->inputs().size() == b->inputs().size());
    size_t num_inputs = graph_->inputs().size();
    for (const auto i : c10::irange(num_inputs)) {
      b->inputs().at(i)->replaceAllUsesWith(graph_->inputs().at(i));
    }
    for (const auto i : c10::irange(num_inputs)) {
      b->eraseInput(num_inputs - (1 + i));
    }
  }

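  // Non-profiling path: seed the state map from static input types.
  // TensorTypes with a known undefined() bit become Zero/Nonzero, other
  // tensors and tensor lists become Nonzero, everything else is Unknown.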
  void setStatesOnGraphInputs() {
    for (Value* input : graph_->inputs()) {
      const auto& tp = input->type();
      if (auto tt = tp->cast<TensorType>()) {
        if (tt->undefined()) {
          if (*tt->undefined()) {
            state_[input] = State::Zero;
          } else {
            state_[input] = State::Nonzero;
          }
        } else {
          state_[input] = State::Unknown;
        }
      } else if (
          tp->isSubtypeOf(*TensorType::get()) ||
          tp->isSubtypeOf(*ListType::ofTensors())) {
        state_[input] = State::Nonzero;
      } else {
        state_[input] = State::Unknown;
      }
    }
  }

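  // Recursively collects the prim::profile_ivalue users of `inp` (and of
  // their outputs) that carry the given attribute.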
  static void getUsesWithAttribute_(
      Value* inp,
      Symbol attr,
      std::vector<Node*>& uses) {
    for (auto use : inp->uses()) {
      if (use.user->kind() != prim::profile_ivalue) {
        continue;
      }

      if (use.user->hasAttribute(attr)) {
        uses.push_back(use.user);
      }

      getUsesWithAttribute_(use.user->output(), attr, uses);
    }
  }

  // This deals with the fact that other passes may want to profile this
  // exact same value. The helper walks chains of `prim::profile_ivalue`
  // nodes to locate the ones inserted by/for `specializeAutogradZero`.
  static std::vector<Node*> getUsesWithAttribute(Value* inp, Symbol attr) {
    std::vector<Node*> uses;
    getUsesWithAttribute_(inp, attr, uses);
    return uses;
  }

  static Node* getUse(Value* inp, Symbol kind) {
    for (auto use : inp->uses()) {
      if (use.user->kind() == kind) {
        return use.user;
      }
    }

    return nullptr;
  }

  void removeProfiledOptionalUses(const std::vector<Node*>& uses) {
    TORCH_INTERNAL_ASSERT(!uses.empty());
    auto inp = uses[0]->input();
    // this removes `prim::profile_ivalue` uses from the original and
    // to-specialize blocks. N.B. the false block isn't affected, as it has
    // already been encapsulated in a fallback function
    for (auto u : uses) {
      u->output()->replaceAllUsesWith(inp);
    }
  }

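  // Clones the graph into a versioning prim::If. The true branch is the copy
  // we will specialize, guarded by runtime checks (aten::__is__ None for
  // inputs profiled as always-None, AutogradAllNonZero over inputs profiled
  // as defined, AutogradAllZero over inputs profiled as undefined); the false
  // branch is replaced with a fallback graph. Returns nullptr if nothing
  // could be specialized.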
  Node* guardSpecializations() {
    auto versioning_if = graph_->create(prim::If, {}, graph_->outputs().size());
    auto value_map = [](Value* v) { return v; };
    auto true_block = versioning_if->addBlock();
    auto false_block = versioning_if->addBlock();

    // we will optimize true_block
    true_block->cloneFrom(graph_->block(), value_map);
    replaceBlockInputsWithGraphInputs(true_block);
    false_block->cloneFrom(graph_->block(), value_map);
    replaceBlockInputsWithGraphInputs(false_block);
    replaceBlockWithFallbackGraph(false_block, graph_->inputs());

    WithInsertPoint wip{graph_->block()->param_node()->next()};
    Value* none_val = graph_->insertConstant(IValue());
    std::vector<Value*> checks;
    std::vector<Value*> zero_values;
    std::vector<Value*> nonzero_values;

    for (auto inp : graph_->inputs()) {
      std::vector<Node*> iprofile_counts_nodes =
          getUsesWithAttribute(inp, countsAttribute);
      if (!iprofile_counts_nodes.empty()) {
        // the original `prim::profile_ivalue[num_present=0,...]` on `inp` is
        // copied into `true_block` and `false_block`.
        auto profile_ivalue_node = iprofile_counts_nodes[0];
        TORCH_INTERNAL_ASSERT(
            profile_ivalue_node->hasAttribute(countsAttribute));
        const auto& counts_attr =
            profile_ivalue_node->ival(countsAttribute).toGenericDict();
        auto num_present = counts_attr.at(IValue{"num_present"}).toInt();
        auto num_none = counts_attr.at(IValue{"num_none"}).toInt();
        if (num_present == 0 && num_none != 0) {
          auto check = graph_->insert(aten::__is__, {inp, none_val})->node();
          checks.push_back(check->output());
          profiled_none_.insert(inp);
        }
        removeProfiledOptionalUses(iprofile_counts_nodes);
        continue;
      }

      if (inp->uses().empty() || !inp->type()->cast<TensorType>()) {
        continue;
      }

      // TODO: check multiple uses ?
      auto pout = getUse(inp, prim::profile);
      if (!pout) {
        continue;
      }

      auto pttp = pout->ty(attr::profiled_type)->expect<TensorType>();
      if (!pttp->undefined().has_value()) {
        continue;
      }

      state_[inp] = *pttp->undefined() ? State::Zero : State::Nonzero;

      if (*pttp->undefined()) {
        zero_values.push_back(inp);
      } else {
        nonzero_values.push_back(inp);
      }
    }
    GRAPH_DUMP("After for loop", graph_);
    // unable to specialize any of the inputs
    if (nonzero_values.empty() && zero_values.empty()) {
      GRAPH_DUMP("Unable to add any specialization guards", graph_);
      versioning_if->destroy();
      // the checks we inserted will be cleaned up
      // by any subsequent DCE pass
      return nullptr;
    }

    Node* nonzero_check = graph_->insert(prim::AutogradAllNonZero, {})->node();
    for (Value* v : nonzero_values) {
      nonzero_check->addInput(v);
    }
    checks.push_back(nonzero_check->output());

    Node* zero_check = graph_->insert(prim::AutogradAllZero, {})->node();
    for (Value* v : zero_values) {
      zero_check->addInput(v);
    }
    checks.push_back(zero_check->output());

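    // Fold all individual checks into a single boolean; the versioning if
    // dispatches on their conjunction.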
    Value* bool_list =
        graph_->insertNode(graph_->createList(BoolType::get(), checks))
            ->output();
    Value* conjunction = graph_->insert(aten::all, {bool_list});

    versioning_if->addInput(conjunction);
    graph_->insertNode(versioning_if);

    auto ret = graph_->return_node();
    for (const auto i : c10::irange(ret->inputs().size())) {
      auto ogo = ret->input(i);
      auto ngo = versioning_if->output(i);
      ngo->copyMetadata(ogo);
      ret->replaceInput(i, ngo);
    }

    // We've created:
    // successful_checks = Guards(...)
    // if (successful_checks)
    // -> optimized graph
    // else:
    // -> fallback graph
    // original graph
    //
    // Remove the dead original graph
    for (auto it = graph_->block()->nodes().reverse().begin();
         *it != versioning_if;) {
      Node* n = *it;
      it++;
      n->destroy();
    }

    GRAPH_DUMP("After guardSpecializations", graph_);
    return versioning_if;
  }

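  // Walks the block and uses the per-Value Zero/Nonzero/Unknown states to
  // simplify autograd ops: AutogradAdd folds to the other operand when one
  // side is Zero (or to a plain aten::add when both are Nonzero), AutogradZero
  // outputs are marked Zero, profile and BailOut outputs pick up state from
  // their input or profiled type, and lowered GradOf blocks (prim::If over
  // AutogradAnyNonZero) are replaced by AutogradZero when all grad inputs are
  // Zero or inlined when all are Nonzero.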
  void specializeAutogradOps(Block* block) {
    for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
      auto n = *it;
      switch (n->kind()) {
        case prim::AutogradAdd: {
          auto a = n->input(0);
          auto b = n->input(1);
          // if one operand is an Autograd zero, we can just drop the add
          if (state_[a] == State::Zero) {
            // Zero + b == b
            n->output()->replaceAllUsesWith(b);
            it.destroyCurrent();
          } else if (state_[b] == State::Zero) {
            // a + Zero == a
            n->output()->replaceAllUsesWith(a);
            it.destroyCurrent();
          } else if (
              state_[a] == State::Nonzero && state_[b] == State::Nonzero) {
            // when both are Nonzero, we can use a normal, optimizable add
            // instruction
            WithInsertPoint guard(n);
            auto* cOne = graph_->insertConstant(1);
            auto* add_node = graph_->insertNode(graph_->create(aten::add, 1));
            add_node->addInput(a);
            add_node->addInput(b);
            add_node->addInput(cOne);
            auto* add_output = add_node->output();
            add_output->setType(n->output()->type());
            state_[add_output] = State::Nonzero;
            n->output()->replaceAllUsesWith(add_output);
            it.destroyCurrent();
          } else {
            // otherwise we have conditionally-Nonzero things, and we need
            // to actually run an AutogradAdd, which will guard for Zeros,
            // so we leave the op as is
            state_[n->output()] = State::Unknown;
          }
        } break;
        case prim::AutogradZero: {
          state_[n->output()] = State::Zero;
        } break;
        case prim::profile: {
          // this is a profile node on a tensor use;
          // if we decided to specialize this graph,
          // its input may have undefinedness info,
          // otherwise it should be Unknown
          if (!n->inputs().empty()) {
            state_[n->output()] = !state_.count(n->input())
                ? State::Unknown
                : state_[n->input()];
          }
          break;
        }
        case prim::BailOut: {
          if (auto ptt = n->output()->type()->expect<TensorType>()) {
            state_[n->output()] = ptt->undefined()
                ? *ptt->undefined() ? State::Zero : State::Nonzero
                : State::Unknown;
          }
        } break;
        // Lowered GradOf block
        case prim::If: {
          auto if_input = n->input(0)->node();
          if (if_input->kind() == prim::AutogradAnyNonZero) {
            auto all_zeros = std::all_of(
                if_input->inputs().begin(),
                if_input->inputs().end(),
                [&](Value* v) { return state_[v] == State::Zero; });

            auto all_nonzeros = std::all_of(
                if_input->inputs().begin(),
                if_input->inputs().end(),
                [&](Value* v) { return state_[v] == State::Nonzero; });
            // Property 1: if all the gradInputs to the GradOf are Zero,
            // then the gradOutputs are also zero and will be represented as
            // AutogradZero nodes
            if (all_zeros) {
              auto zero =
                  graph_->createAutogradZero()->insertAfter(n)->output();
              state_[zero] = State::Zero;
              for (auto o : n->outputs()) {
                o->replaceAllUsesWith(zero);
              }
              it.destroyCurrent();
              break;
            }

            specializeGradSumToSize(n->blocks().at(0));
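            // if all the gradInputs are known Nonzero, the GradOf guard is
            // unnecessary: inline the body before the If and mark its
            // outputs as Nonzero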
            if (all_nonzeros) {
              auto body = n->blocks().at(0);
              // hoist the nodes in the GradOf body to be before the linear
              // block
              for (auto it = body->nodes().begin();
                   it != body->nodes().end();) {
                auto block_node = *it++;
                block_node->moveBefore(n);
              }

              for (size_t i = 0; i < n->outputs().size(); ++i) {
                n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
                state_[body->outputs().at(i)] = State::Nonzero;
              }
              it.destroyCurrent();
              break;
            }
          }

          for (auto o : n->outputs()) {
            state_[o] = State::Unknown;
          }
          break;
        }
        default:
          for (auto o : n->outputs()) {
            state_[o] = State::Unknown;
          }
          break;
      }
    }
  }

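  // Removes aten::_grad_sum_to_size calls whose size argument is statically
  // None (mustBeNone) or was profiled as always None; in that case the op is
  // an identity on the incoming gradient.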
  void specializeGradSumToSize(Block* b) {
    for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
      Node* n = *it;
      if (n->kind() == aten::_grad_sum_to_size) {
        bool profiled_none_flag = profiled_none_.count(n->input(1));
        const Node* node = n->input(1)->node();
        // propagate profiled none through other profile_ivalue nodes
        while (!profiled_none_flag && node->kind() == prim::profile_ivalue) {
          profiled_none_flag =
              profiled_none_flag || profiled_none_.count(node->input(0));
          node = node->input(0)->node();
        }
        if (n->input(1)->mustBeNone() || profiled_none_flag) {
          n->output()->replaceAllUsesWith(n->input(0));
          it.destroyCurrent();
        }
      }
    }
  }

  std::shared_ptr<Graph> graph_;
  std::unordered_set<Value*> profiled_none_;
  std::unordered_map<Value*, State> state_;
};

// propagate autograd zero information through a gradient graph and
// remove grad_of blocks if present.
// Note: this is a very limited pass. It only propagates autograd zeros for
// operations generated by the symbolic autodiff code and cleans up
// AutogradAdds when possible. Outputs of other nodes are conservatively
// marked Unknown and not optimized.
void specializeAutogradZero(std::shared_ptr<Graph> g) {
  AutogradZeroSpecializer azs(std::move(g));
  azs.run();
}

} // namespace torch::jit