#include <torch/csrc/jit/passes/constant_propagation.h>

#include <ATen/core/functional.h>
#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/node_hashing.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

#include <utility>

namespace torch::jit {

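// Attempts to run `n` on constant inputs. Returns the resulting output
// IValues if every input is a constant and the node executes successfully;
// returns std::nullopt otherwise (non-constant inputs, vararg schemas, an op
// that throws, or outputs that cannot legally become constants, such as
// tensors that require grad).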
std::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
    const Node* n,
    bool ignore_custom_classes,
    AliasDb* db) {
  Stack stack;
  for (auto input : n->inputs()) {
    if (auto ival = toIValue(input)) {
      stack.push_back(*ival);
    } else {
      return std::nullopt;
    }
  }

  switch (n->kind()) {
    case prim::ListUnpack: {
      if (stack.back().toList().size() != n->outputs().size()) {
        return std::nullopt;
      }
      listUnpack(stack, n->outputs().size());
    } break;
    case prim::TupleConstruct: {
      auto tt = n->output()->type()->expect<TupleType>();
      if (tt->name()) {
        namedTupleConstruct(stack, std::move(tt), n->inputs().size());
      } else {
        tupleConstruct(stack, n->inputs().size());
      }
    } break;
    case prim::ListConstruct: {
      listConstruct(
          stack,
          n->output()->type()->expectRef<ListType>(),
          n->inputs().size());
    } break;
    case prim::DictConstruct: {
      dictConstruct(
          stack,
          n->output()->type()->expectRef<DictType>(),
          n->inputs().size());
    } break;
    case prim::CreateObject: {
      createObject(
          stack,
          n->output()->type()->expect<ClassType>(),
          /*use_weak_ref*/ true);
    } break;
    case prim::GetAttr: {
      auto attr = pop(stack).toObject()->getAttr(n->s(attr::name));
      push(stack, attr);
    } break;
    case prim::isinstance: {
      isinstance(stack, n->tys(attr::types));
    } break;
    default: {
      const auto maybe_schema = n->maybeSchema();
      if (maybe_schema && maybe_schema->is_vararg()) {
        // vararg schemas require the number of inputs at the top of the stack
        // but this is broken in other places in constant prop, so disable it
        // for now
        return std::nullopt;
      }

      try {
        auto op = n->getOperation();
        op(stack);
      } catch (...) {
        return std::nullopt;
      }
    } break;
  }

  for (IValue& v : stack) {
    if (v.isTensor()) {
      const at::Tensor& t = v.toTensor();
      if (t.defined() && t.requires_grad()) {
        // requires grad tensors cannot be constants
        return std::nullopt;
      }
    }
    // Weak form of const propagation
    if (ignore_custom_classes) {
      if (v.isCustomClass()) {
        return std::nullopt;
      }
    }
    // see [Constant Object Weak CompilationUnit Reference]
    if (v.isCustomClass()) {
      if (v.toObject()->is_weak_compilation_ref()) {
        continue;
      }
      if (!db) {
        continue;
      }
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      Node* n_non_const = const_cast<Node*>(n);
      if (db->mayContainAlias(
              n_non_const->inputs(), {n_non_const->outputs()})) {
        continue;
      }
      auto obj = v.toObject();
      obj->unsafe_make_weak_compilation_ref();
    }
    if (v.isObject()) {
      if (!v.toObject()->is_weak_compilation_ref()) {
        return std::nullopt;
      }
    }
  }
  return stack;
}
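// Example usage (a sketch; as propagateNode below shows, the trailing
// arguments are defaulted in the header, so callers can omit them):
//
//   if (auto outputs = runNodeIfInputsAreConstant(node)) {
//     // All inputs were constants and the node ran successfully;
//     // (*outputs)[i] holds the value node->outputs()[i] would produce.
//   }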

namespace {

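// Node kinds that constant propagation never folds: control flow and
// metadata nodes with non-value semantics, plus ops where folding is known
// to be unprofitable or unsafe.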
std::unordered_set<Symbol> skip_list = {
    prim::If,
    prim::Loop,
    prim::Closure,
    prim::Constant,
    prim::AutogradZero,
    prim::Uninitialized,
    prim::Guard,
    prim::profile,
    prim::profile_ivalue,
    prim::unchecked_unwrap_optional, // TODO remove
    prim::awaitable,
    aten::dequantize,
    // TODO (zach): we should consider skipping tensor factories in the cases
    // where the constant tensor would be large but cheap to create.
};

struct ConstantPropagator {
  // Runs constant propagation with an aliasing db and checks if inputs or
  // outputs might be mutated in the graph
  static ConstantPropagator WithAliasDb(
      std::shared_ptr<Graph> graph,
      bool ignore_custom_classes) {
    return ConstantPropagator(std::move(graph), true, ignore_custom_classes);
  }

  // Runs constant propagation only on ops that clearly do not have aliased
  // inputs or outputs, without computing aliasing information
  static ConstantPropagator NoAliasDb(std::shared_ptr<Graph> graph) {
    return ConstantPropagator(std::move(graph), false, false);
  }

  bool run() {
    ConstantPropagation(graph_->block());
    return made_change_;
  }

 private:
  ConstantPropagator(
      std::shared_ptr<Graph> graph,
      bool aliasing_types,
      bool ignore_custom_classes)
      : graph_(std::move(graph)),
        aliasing_types_(aliasing_types),
        ignore_custom_classes_(ignore_custom_classes) {}

  void propagateNode(Node* n) {
    std::vector<IValue> outputs;
    if (auto outputs_opt =
            runNodeIfInputsAreConstant(n, ignore_custom_classes_)) {
      outputs = std::move(outputs_opt.value());
    } else {
      // The op failed to run, or its outputs cannot legally become
      // constants, so we cannot continue constant-prop for this node.
      return;
    }
    auto graph = n->owningGraph();
    WithInsertPoint guard(n);
    for (const auto i : c10::irange(outputs.size())) {
      auto new_output = tryInsertConstant(*graph, outputs[i]);
      if (new_output) {
        made_change_ = true;
        GRAPH_UPDATE(
            "Folding %",
            n->outputs()[i]->debugName(),
            " with ",
            getHeader((*new_output)->node()));
        if (outputs[i].isNone()) {
          (*new_output)->setType(n->outputs()[i]->type());
        }
        n->outputs()[i]->replaceAllUsesWith(*new_output);
      }
      // If we cannot insert the IValue as a constant, give up replacing the
      // node and let DCE remove it
    }
  }

  void removeLoopNode(Node* n) {
    auto loop_input_offset = 2; // offset of loop carried deps in input list
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      n->outputs().at(i)->replaceAllUsesWith(
          n->inputs().at(i + loop_input_offset));
    }
    made_change_ = true;
    n->destroy();
  }

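  // Returns true only when the loop provably never executes. If the trip
  // count or the start condition is not a constant, the value_or fallbacks
  // below conservatively assume the loop might run.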
  bool loopWillNotRun(Node* node) {
    Value* trip_count = node->inputs().at(0);
    int64_t iter_len = constant_as<int64_t>(trip_count).value_or(1);

    Value* start_cond = node->inputs().at(1);
    bool cond_val = constant_as<bool>(start_cond).value_or(true);

    bool loop_might_run = cond_val && iter_len > 0;
    if (!loop_might_run) {
      GRAPH_UPDATE(
          "Removing unexecuted loop: ",
          *node,
          "\ntripcount: ",
          trip_count,
          " and start_cond: ",
          getHeader(start_cond->node()));
    }
    return !loop_might_run;
  }

  void inlineIfBody(Block* body) {
    Node* n = body->owningNode();
    for (auto it = body->nodes().begin(); it != body->nodes().end();) {
      Node* body_node = *it;
      // advance iterator because after body_node is moved its next pointer
      // will be to n
      it++;
      body_node->moveBefore(n);
    }
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
    }
    // NB: destroy the node here, because it might contain side effects, like
    // print
    n->destroy();
  }

  void inlineIf(Node* n) {
    auto input_bool = constant_as<bool>(n->input());
    AT_ASSERT(input_bool);
    GRAPH_UPDATE(
        "Folding if ",
        getHeader(n->input()->node()),
        " where condition = ",
        *input_bool);
    size_t block_index = *input_bool ? 0 : 1;
    ConstantPropagation(n->blocks().at(block_index));
    inlineIfBody(n->blocks().at(block_index));
    made_change_ = true;
  }
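
  // For example (an illustrative IR sketch), with a constant-true condition
  // inlineIf above rewrites
  //
  //   %y : int = prim::If(%true)
  //     block0(): -> (%a)
  //     block1(): -> (%b)
  //
  // into a plain use of %a: the taken block's nodes are hoisted before the
  // If, uses of %y are rewired, and the If node is destroyed.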

  void replaceAndRemoveIfOutput(Node* n, size_t i, Value* replacement) {
    n->outputs().at(i)->replaceAllUsesWith(replacement);
    n->eraseOutput(i);
    n->blocks().at(0)->eraseOutput(i);
    n->blocks().at(1)->eraseOutput(i);
  }

  // remove extra outputs from the node
  void removeExtraIfOutputs(Node* n) {
    TORCH_CHECK(n->kind() == prim::If, "Only supported for If nodes");
    auto true_block = n->blocks()[0];
    auto false_block = n->blocks()[1];
    auto graph = n->owningGraph();
    auto initial_outputs = true_block->outputs().size();
    WithInsertPoint guard(n);
    for (size_t i = 0; i < true_block->outputs().size();) {
      auto t_out = true_block->outputs().at(i);
      auto f_out = false_block->outputs().at(i);

      // neither block changes the output value
      if (true_block->outputs()[i] == false_block->outputs()[i]) {
        replaceAndRemoveIfOutput(n, i, true_block->outputs()[i]);
        continue;
      }

      // true block output is constant and constant matches false block output
      auto maybe_const = toIValue(t_out);
      auto eq = EqualNode();
      if (maybe_const && eq(t_out->node(), f_out->node())) {
        auto new_const = graph->insertConstant(*maybe_const);
        replaceAndRemoveIfOutput(n, i, new_const);
        continue;
      }

      i++; // increment because we didn't remove current index
    }
    made_change_ |= initial_outputs != true_block->outputs().size();
  }
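
  // An illustrative sketch: when both blocks return the same value %a (or
  // structurally equal constants), as in
  //
  //   %y : int = prim::If(%cond)
  //     block0(): -> (%a)
  //     block1(): -> (%a)
  //
  // the output is redundant, so removeExtraIfOutputs above replaces uses of
  // %y with %a and erases the output from the node and both blocks.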

  // remove extra outputs from the node
  void removeExtraLoopOutputs(Node* node) {
    auto initial_outputs = node->outputs().size();
    auto loop_body = node->blocks().at(0);
    auto loop_input_offset = 2; // offset of loop carried deps in input list
    auto loop_body_offset =
        1; // offset to the loop carried dependencies in block inputs/outputs
    for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
      size_t i = i_1 - 1;
      // if the value is no longer changed remove output
      if (loop_body->inputs().at(loop_body_offset + i) ==
          loop_body->outputs().at(loop_body_offset + i)) {
        auto node_input = node->inputs().at(loop_input_offset + i);
        node->outputs().at(i)->replaceAllUsesWith(node_input);
        loop_body->inputs()
            .at(loop_body_offset + i)
            ->replaceAllUsesWith(node_input);
        node->eraseOutput(i);
        node->removeInput(loop_input_offset + i);
        loop_body->eraseInput(loop_body_offset + i);
        loop_body->eraseOutput(loop_body_offset + i);
      }
    }
    made_change_ |= initial_outputs != node->outputs().size();
  }

  bool noMutableValues(at::ArrayRef<Value*> values) {
    return std::none_of(values.begin(), values.end(), [](Value* v) {
      return AliasDb::isMutableType(v);
    });
  }

  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

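  // A node is a candidate for folding only if nothing writes to its inputs
  // or outputs (checked via the AliasDb when available, otherwise by
  // requiring immutable types), and it is not skipped, nondeterministic,
  // side-effecting, or a node with sub-blocks.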
  bool supportedNode(Node* n) {
    bool no_mutation = false;
    if (aliasing_types_) {
      no_mutation = !getOrCreateAliasDb()->hasWriters(n);
    } else {
      no_mutation =
          noMutableValues(n->inputs()) && noMutableValues(n->outputs());
    }
    return no_mutation && !n->kind().is_onnx() &&
        skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
        !n->hasSideEffects() && n->blocks().empty();
  }

  void ConstantPropagation(at::ArrayRef<Block*> blocks) {
    for (Block* block : blocks) {
      ConstantPropagation(block);
    }
  }

  void ConstantPropagation(Node* n) {
    bool constant_inputs =
        std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
          return v->node()->kind() == prim::Constant;
        });
    if (n->kind() == prim::If) {
      // inline node if we can, otherwise check for simplified outputs
      if (constant_inputs) {
        inlineIf(n);
      } else {
        ConstantPropagation(n->blocks());
        removeExtraIfOutputs(n);
      }
    } else if (n->kind() == prim::Loop) {
      if (loopWillNotRun(n)) {
        removeLoopNode(n);
      } else {
        ConstantPropagation(n->blocks());
        removeExtraLoopOutputs(n);
      }
    } else if (constant_inputs && supportedNode(n)) {
      propagateNode(n);
    } else {
      ConstantPropagation(n->blocks());
    }
  }

  void ConstantPropagation(Block* block) {
    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
      Node* n = *it;
      it++; // advance iterator because the current node may be destroyed
      ConstantPropagation(n);
    }
  }

  std::shared_ptr<Graph> graph_;
  // lazily initialized if using aliasing_types, otherwise not initialized
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
  bool aliasing_types_;
  bool made_change_ = false;
  bool ignore_custom_classes_;
};
} // anonymous namespace

bool ConstantPropagation(
    std::shared_ptr<Graph>& graph,
    bool ignore_custom_classes) {
  ConstantPropagator cp =
      ConstantPropagator::WithAliasDb(graph, ignore_custom_classes);
  bool made_change = cp.run();
  if (made_change) {
    EliminateDeadCode(graph);
  }
  GRAPH_DUMP("After ConstantPropagation: ", graph);
  return made_change;
}
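
// A minimal usage sketch (assuming a graph built with parseIR from
// torch/csrc/jit/ir/irparser.h; the IR string is illustrative):
//
//   auto graph = std::make_shared<Graph>();
//   parseIR(R"IR(
//     graph():
//       %a : int = prim::Constant[value=2]()
//       %b : int = prim::Constant[value=3]()
//       %c : int = aten::add(%a, %b)
//       return (%c)
//   )IR", graph.get());
//   bool changed = ConstantPropagation(graph);
//   // %c should now be folded into a single prim::Constant, and the
//   // aten::add node removed by the EliminateDeadCode call above.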

bool ConstantPropagationImmutableTypes(std::shared_ptr<Graph>& graph) {
  ConstantPropagator cp = ConstantPropagator::NoAliasDb(graph);
  bool made_change = cp.run();
  if (made_change) {
    EliminateDeadCode(graph);
  }
  GRAPH_DUMP("After ConstantPropagationImmutableTypes: ", graph);
  return made_change;
}
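// This variant skips AliasDb construction, so it is cheaper than
// ConstantPropagation above, but it only folds nodes whose inputs and
// outputs are all immutable types, since no mutation analysis is available
// to prove safety otherwise.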

} // namespace torch::jit