1 #include <torch/csrc/jit/passes/eliminate_no_ops.h>
2
3 #include <torch/csrc/jit/jit_log.h>
4 #include <torch/csrc/jit/passes/dead_code_elimination.h>
5 #include <torch/csrc/jit/runtime/graph_iterator.h>
6
7 namespace torch::jit {
8
9 namespace {
10
allInputsAreTensors(Node * node)11 bool allInputsAreTensors(Node* node) {
12 for (const auto* value : node->inputs()) {
13 const auto& type = value->type();
14 if (!type->castRaw<TensorType>()) {
15 return false;
16 }
17 }
18 return true;
19 }
20
cannotOptimize(Node * node)21 bool cannotOptimize(Node* node) {
22 const auto kind = node->kind();
23 if (kind == aten::__is__ || kind == aten::__isnot__) {
24 return allInputsAreTensors(node);
25 }
26 return false;
27 }
28
29 // Certain ops can make this optimization unsound. For example,
30 // consider the following graph:
31 // %y : Tensor = aten::detach(%x)
32 // %b : bool = aten::__is__(%y, %x) (= False)
33 // After remove detach, we would get
34 // %b : bool = aten::__is__(%x, %x) (= True!)
containsInvalidOp(std::shared_ptr<Graph> & graph)35 bool containsInvalidOp(std::shared_ptr<Graph>& graph) {
36 for (auto* node : graph->nodes()) {
37 if (cannotOptimize(node)) {
38 return true;
39 }
40 }
41 return false;
42 }
43
44 } // namespace
45
EliminateNoOps(std::shared_ptr<Graph> & graph,std::unordered_set<c10::Symbol> custom_ops)46 bool EliminateNoOps(
47 std::shared_ptr<Graph>& graph,
48 std::unordered_set<c10::Symbol> custom_ops) {
49 GRAPH_DUMP("Before EliminateNoOps: ", graph);
50 if (containsInvalidOp(graph)) {
51 return false;
52 }
53 // Ops here should be of the form x = f(x, ...)
54 std::unordered_set<c10::Symbol> no_ops{aten::detach};
55 no_ops.insert(custom_ops.begin(), custom_ops.end());
56
57 bool changed = false;
58
59 auto graph_it = DepthFirstGraphNodeIterator(graph);
60 for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
61 auto it = no_ops.find(node->kind());
62 if (it == no_ops.end()) {
63 continue;
64 }
65
66 changed = true;
67 node->output()->replaceAllUsesWith(node->input(0));
68 }
69
70 if (changed) {
71 EliminateDeadCode(graph);
72 }
73
74 GRAPH_DUMP("After EliminateNoOps: ", graph);
75 return changed;
76 }
77
78 } // namespace torch::jit
79