1 #include <torch/csrc/jit/passes/constant_propagation.h>
2
3 #include <ATen/core/functional.h>
4 #include <ATen/core/ivalue.h>
5 #include <c10/util/Exception.h>
6 #include <c10/util/irange.h>
7 #include <torch/csrc/autograd/variable.h>
8 #include <torch/csrc/jit/ir/alias_analysis.h>
9 #include <torch/csrc/jit/ir/constants.h>
10 #include <torch/csrc/jit/ir/ir.h>
11 #include <torch/csrc/jit/ir/node_hashing.h>
12 #include <torch/csrc/jit/jit_log.h>
13 #include <torch/csrc/jit/passes/dead_code_elimination.h>
14 #include <torch/csrc/jit/runtime/operator.h>
15 #include <torch/csrc/jit/runtime/vararg_functions.h>
16
17 #include <utility>
18
19 namespace torch::jit {
20
runNodeIfInputsAreConstant(const Node * n,bool ignore_custom_classes,AliasDb * db)21 std::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
22 const Node* n,
23 bool ignore_custom_classes,
24 AliasDb* db) {
25 Stack stack;
26 for (auto input : n->inputs()) {
27 if (auto ival = toIValue(input)) {
28 stack.push_back(*ival);
29 } else {
30 return std::nullopt;
31 }
32 }
33
34 switch (n->kind()) {
35 case prim::ListUnpack: {
36 if (stack.back().toList().size() != n->outputs().size()) {
37 return std::nullopt;
38 }
39 listUnpack(stack, n->outputs().size());
40 } break;
41 case prim::TupleConstruct: {
42 auto tt = n->output()->type()->expect<TupleType>();
43 if (tt->name()) {
44 namedTupleConstruct(stack, std::move(tt), n->inputs().size());
45 } else {
46 tupleConstruct(stack, n->inputs().size());
47 }
48 } break;
49 case prim::ListConstruct: {
50 listConstruct(
51 stack,
52 n->output()->type()->expectRef<ListType>(),
53 n->inputs().size());
54 } break;
55 case prim::DictConstruct: {
56 dictConstruct(
57 stack,
58 n->output()->type()->expectRef<DictType>(),
59 n->inputs().size());
60 } break;
61 case prim::CreateObject: {
62 createObject(
63 stack,
64 n->output()->type()->expect<ClassType>(),
65 /*use_weak_ref*/ true);
66 } break;
67 case prim::GetAttr: {
68 auto attr = pop(stack).toObject()->getAttr(n->s(attr::name));
69 push(stack, attr);
70 } break;
71 case prim::isinstance: {
72 isinstance(stack, n->tys(attr::types));
73 } break;
74 default: {
75 const auto maybe_schema = n->maybeSchema();
76 if (maybe_schema && maybe_schema->is_vararg()) {
77 // vararg schemas require the number of inputs at the top of the stack
78 // but this is broken in other places in constant prop, so disable it
79 // for now
80 return std::nullopt;
81 }
82
83 try {
84 auto op = n->getOperation();
85 op(stack);
86 } catch (...) {
87 return std::nullopt;
88 }
89 } break;
90 }
91
92 for (IValue& v : stack) {
93 if (v.isTensor()) {
94 const at::Tensor& t = v.toTensor();
95 if (t.defined() && t.requires_grad()) {
96 // requires grad tensors cannot be constants
97 return std::nullopt;
98 }
99 }
100 // Weak form of const propagation
101 if (ignore_custom_classes) {
102 if (v.isCustomClass()) {
103 return std::nullopt;
104 }
105 }
106 // see [Constant Object Weak CompilationUnit Reference]
107 if (v.isCustomClass()) {
108 if (v.toObject()->is_weak_compilation_ref()) {
109 continue;
110 }
111 if (!db) {
112 continue;
113 }
114 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
115 Node* n_non_const = const_cast<Node*>(n);
116 if (db->mayContainAlias(
117 n_non_const->inputs(), {n_non_const->outputs()})) {
118 continue;
119 }
120 auto obj = v.toObject();
121 obj->unsafe_make_weak_compilation_ref();
122 }
123 if (v.isObject()) {
124 if (!v.toObject()->is_weak_compilation_ref()) {
125 return std::nullopt;
126 }
127 }
128 }
129 return stack;
130 }
131
132 namespace {
133
134 std::unordered_set<Symbol> skip_list = {
135 prim::If,
136 prim::Loop,
137 prim::Closure,
138 prim::Constant,
139 prim::AutogradZero,
140 prim::Uninitialized,
141 prim::Guard,
142 prim::profile,
143 prim::profile_ivalue,
144 prim::unchecked_unwrap_optional, // TODO remove
145 prim::awaitable,
146 aten::dequantize,
147 // TODO (zach): we should consider skipping tensor factories in the cases
148 // where the constant tensor would be large but cheap to create.
149 };
150
151 struct ConstantPropagator {
152 // Runs constant propagation with an aliasing db and checks if inputs or
153 // outputs might be mutated in the graph
WithAliasDbtorch::jit::__anon5956705a0111::ConstantPropagator154 static ConstantPropagator WithAliasDb(
155 std::shared_ptr<Graph> graph,
156 bool ignore_custom_classes) {
157 return ConstantPropagator(std::move(graph), true, ignore_custom_classes);
158 }
159
160 // Runs constant propagation only on ops that clearly do not have aliased
161 // inputs or outputs without computing aliasing information
NoAliasDbtorch::jit::__anon5956705a0111::ConstantPropagator162 static ConstantPropagator NoAliasDb(std::shared_ptr<Graph> graph) {
163 return ConstantPropagator(std::move(graph), false, false);
164 }
165
runtorch::jit::__anon5956705a0111::ConstantPropagator166 bool run() {
167 ConstantPropagation(graph_->block());
168 return made_change_;
169 }
170
171 private:
ConstantPropagatortorch::jit::__anon5956705a0111::ConstantPropagator172 ConstantPropagator(
173 std::shared_ptr<Graph> graph,
174 bool aliasing_types,
175 bool ignore_custom_classes)
176 : graph_(std::move(graph)),
177 aliasing_types_(aliasing_types),
178 ignore_custom_classes_(ignore_custom_classes) {}
179
propagateNodetorch::jit::__anon5956705a0111::ConstantPropagator180 void propagateNode(Node* n) {
181 std::vector<IValue> outputs;
182 if (auto outputs_opt =
183 runNodeIfInputsAreConstant(n, ignore_custom_classes_)) {
184 outputs = std::move(outputs_opt.value());
185 } else {
186 // The op failed to run, so we cannot continue constant-prop for it.
187 return;
188 }
189 auto graph = n->owningGraph();
190 WithInsertPoint guard(n);
191 for (const auto i : c10::irange(outputs.size())) {
192 auto new_output = tryInsertConstant(*graph, outputs[i]);
193 if (new_output) {
194 made_change_ = true;
195 GRAPH_UPDATE(
196 "Folding %",
197 n->outputs()[i]->debugName(),
198 " with ",
199 getHeader((*new_output)->node()));
200 if (outputs[i].isNone()) {
201 (*new_output)->setType(n->outputs()[i]->type());
202 }
203 n->outputs()[i]->replaceAllUsesWith(*new_output);
204 }
205 // If we cannot insert the IValue as a constant, give up replacing the
206 // node and let DCE remove it
207 }
208 }
209
removeLoopNodetorch::jit::__anon5956705a0111::ConstantPropagator210 void removeLoopNode(Node* n) {
211 auto loop_input_offset = 2; // offset of loop carried deps in input list
212 for (size_t i = 0; i < n->outputs().size(); ++i) {
213 n->outputs().at(i)->replaceAllUsesWith(
214 n->inputs().at(i + loop_input_offset));
215 }
216 made_change_ = true;
217 n->destroy();
218 }
219
loopWillNotRuntorch::jit::__anon5956705a0111::ConstantPropagator220 bool loopWillNotRun(Node* node) {
221 Value* trip_count = node->inputs().at(0);
222 int64_t iter_len = constant_as<int64_t>(trip_count).value_or(1);
223
224 Value* start_cond = node->inputs().at(1);
225 bool cond_val = constant_as<bool>(start_cond).value_or(true);
226
227 bool loop_might_run = cond_val && iter_len > 0;
228 if (!loop_might_run) {
229 GRAPH_UPDATE(
230 "Removing unexecuted loop: ",
231 *node,
232 "\ntripcount: ",
233 trip_count,
234 " and start_cond: ",
235 getHeader(start_cond->node()));
236 }
237 return !loop_might_run;
238 }
239
inlineIfBodytorch::jit::__anon5956705a0111::ConstantPropagator240 void inlineIfBody(Block* body) {
241 Node* n = body->owningNode();
242 for (auto it = body->nodes().begin(); it != body->nodes().end();) {
243 Node* body_node = *it;
244 // advance iterator because after body_node is moved its next pointer will
245 // be to n
246 it++;
247 body_node->moveBefore(n);
248 }
249 for (size_t i = 0; i < n->outputs().size(); ++i) {
250 n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
251 }
252 // NB: destroy the node here, because it might contain side effects, like
253 // print
254 n->destroy();
255 }
256
inlineIftorch::jit::__anon5956705a0111::ConstantPropagator257 void inlineIf(Node* n) {
258 auto input_bool = constant_as<bool>(n->input());
259 AT_ASSERT(input_bool);
260 GRAPH_UPDATE(
261 "Folding if ",
262 getHeader(n->input()->node()),
263 " where condition = ",
264 *input_bool);
265 size_t block_index = *input_bool ? 0 : 1;
266 ConstantPropagation(n->blocks().at(block_index));
267 inlineIfBody(n->blocks().at(block_index));
268 made_change_ = true;
269 }
270
replaceAndRemoveIfOutputtorch::jit::__anon5956705a0111::ConstantPropagator271 void replaceAndRemoveIfOutput(Node* n, size_t i, Value* replacement) {
272 n->outputs().at(i)->replaceAllUsesWith(replacement);
273 n->eraseOutput(i);
274 n->blocks().at(0)->eraseOutput(i);
275 n->blocks().at(1)->eraseOutput(i);
276 }
277
278 // remove extra outputs from the node
removeExtraIfOutputstorch::jit::__anon5956705a0111::ConstantPropagator279 void removeExtraIfOutputs(Node* n) {
280 TORCH_CHECK(n->kind() == prim::If, "Only supported for If nodes");
281 auto true_block = n->blocks()[0];
282 auto false_block = n->blocks()[1];
283 auto graph = n->owningGraph();
284 auto initial_outputs = true_block->outputs().size();
285 WithInsertPoint guard(n);
286 for (size_t i = 0; i < true_block->outputs().size();) {
287 auto t_out = true_block->outputs().at(i);
288 auto f_out = false_block->outputs().at(i);
289
290 // neither block changes the output value
291 if (true_block->outputs()[i] == false_block->outputs()[i]) {
292 replaceAndRemoveIfOutput(n, i, true_block->outputs()[i]);
293 continue;
294 }
295
296 // true block output is constant and constant matches false block output
297 auto maybe_const = toIValue(t_out);
298 auto eq = EqualNode();
299 if (maybe_const && eq(t_out->node(), f_out->node())) {
300 auto new_const = graph->insertConstant(*maybe_const);
301 replaceAndRemoveIfOutput(n, i, new_const);
302 continue;
303 }
304
305 i++; // increment bc we didn't remove current index
306 }
307 made_change_ |= initial_outputs != true_block->outputs().size();
308 }
309
310 // remove extra outputs from the node
removeExtraLoopOutputstorch::jit::__anon5956705a0111::ConstantPropagator311 void removeExtraLoopOutputs(Node* node) {
312 auto initial_outputs = node->outputs().size();
313 auto loop_body = node->blocks().at(0);
314 auto loop_input_offset = 2; // offset of loop carried deps in input list
315 auto loop_body_offset =
316 1; // offset to the loop carried dependencies in block inputs/outputs
317 for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
318 size_t i = i_1 - 1;
319 // if the value is no longer changed remove output
320 if (loop_body->inputs().at(loop_body_offset + i) ==
321 loop_body->outputs().at(loop_body_offset + i)) {
322 auto node_input = node->inputs().at(loop_input_offset + i);
323 node->outputs().at(i)->replaceAllUsesWith(node_input);
324 loop_body->inputs()
325 .at(loop_body_offset + i)
326 ->replaceAllUsesWith(node_input);
327 node->eraseOutput(i);
328 node->removeInput(loop_input_offset + i);
329 loop_body->eraseInput(loop_body_offset + i);
330 loop_body->eraseOutput(loop_body_offset + i);
331 }
332 }
333 made_change_ |= initial_outputs != node->outputs().size();
334 }
335
noMutableValuestorch::jit::__anon5956705a0111::ConstantPropagator336 bool noMutableValues(at::ArrayRef<Value*> values) {
337 return std::none_of(values.begin(), values.end(), [](Value* v) {
338 return AliasDb::isMutableType(v);
339 });
340 }
341
getOrCreateAliasDbtorch::jit::__anon5956705a0111::ConstantPropagator342 AliasDb* getOrCreateAliasDb() {
343 if (!aliasDb_) {
344 aliasDb_ = std::make_unique<AliasDb>(graph_);
345 }
346 return aliasDb_.get();
347 }
348
supportedNodetorch::jit::__anon5956705a0111::ConstantPropagator349 bool supportedNode(Node* n) {
350 bool no_mutation = false;
351 if (aliasing_types_) {
352 no_mutation = !getOrCreateAliasDb()->hasWriters(n);
353 } else {
354 no_mutation =
355 noMutableValues(n->inputs()) && noMutableValues(n->outputs());
356 }
357 return no_mutation && !n->kind().is_onnx() &&
358 skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
359 !n->hasSideEffects() && n->blocks().empty();
360 }
361
ConstantPropagationtorch::jit::__anon5956705a0111::ConstantPropagator362 void ConstantPropagation(at::ArrayRef<Block*> blocks) {
363 for (Block* block : blocks) {
364 ConstantPropagation(block);
365 }
366 }
367
ConstantPropagationtorch::jit::__anon5956705a0111::ConstantPropagator368 void ConstantPropagation(Node* n) {
369 bool constant_inputs =
370 std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
371 return v->node()->kind() == prim::Constant;
372 });
373 if (n->kind() == prim::If) {
374 // inline node if we can, otherwise check for simplified outputs
375 if (constant_inputs) {
376 inlineIf(n);
377 } else {
378 ConstantPropagation(n->blocks());
379 removeExtraIfOutputs(n);
380 }
381 } else if (n->kind() == prim::Loop) {
382 if (loopWillNotRun(n)) {
383 removeLoopNode(n);
384 } else {
385 ConstantPropagation(n->blocks());
386 removeExtraLoopOutputs(n);
387 }
388 } else if (constant_inputs && supportedNode(n)) {
389 propagateNode(n);
390 } else {
391 ConstantPropagation(n->blocks());
392 }
393 }
394
ConstantPropagationtorch::jit::__anon5956705a0111::ConstantPropagator395 void ConstantPropagation(Block* block) {
396 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
397 Node* n = *it;
398 it++; // advance iterator bc the current node may be destroyed
399 ConstantPropagation(n);
400 }
401 }
402
403 std::shared_ptr<Graph> graph_;
404 // lazily initialized if using aliasing_types, otherwise not initialized
405 std::unique_ptr<AliasDb> aliasDb_ = nullptr;
406 bool aliasing_types_;
407 bool made_change_ = false;
408 bool ignore_custom_classes_;
409 };
410 } // anonymous namespace
411
ConstantPropagation(std::shared_ptr<Graph> & graph,bool ignore_custom_classes)412 bool ConstantPropagation(
413 std::shared_ptr<Graph>& graph,
414 bool ignore_custom_classes) {
415 ConstantPropagator cp =
416 ConstantPropagator::WithAliasDb(graph, ignore_custom_classes);
417 bool made_change = cp.run();
418 if (made_change) {
419 EliminateDeadCode(graph);
420 }
421 GRAPH_DUMP("After ConstantPropagation: ", graph);
422 return made_change;
423 }
424
ConstantPropagationImmutableTypes(std::shared_ptr<Graph> & graph)425 bool ConstantPropagationImmutableTypes(std::shared_ptr<Graph>& graph) {
426 ConstantPropagator cp = ConstantPropagator::NoAliasDb(graph);
427 bool made_change = cp.run();
428 if (made_change) {
429 EliminateDeadCode(graph);
430 }
431 GRAPH_DUMP("After ConstantPropagationImmutableTypes: ", graph);
432 return made_change;
433 }
434
435 } // namespace torch::jit
436