xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/eliminate_no_ops.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/eliminate_no_ops.h>
2 
3 #include <torch/csrc/jit/jit_log.h>
4 #include <torch/csrc/jit/passes/dead_code_elimination.h>
5 #include <torch/csrc/jit/runtime/graph_iterator.h>
6 
7 namespace torch::jit {
8 
9 namespace {
10 
allInputsAreTensors(Node * node)11 bool allInputsAreTensors(Node* node) {
12   for (const auto* value : node->inputs()) {
13     const auto& type = value->type();
14     if (!type->castRaw<TensorType>()) {
15       return false;
16     }
17   }
18   return true;
19 }
20 
cannotOptimize(Node * node)21 bool cannotOptimize(Node* node) {
22   const auto kind = node->kind();
23   if (kind == aten::__is__ || kind == aten::__isnot__) {
24     return allInputsAreTensors(node);
25   }
26   return false;
27 }
28 
29 // Certain ops can make this optimization unsound. For example,
30 // consider the following graph:
31 //   %y : Tensor = aten::detach(%x)
32 //   %b : bool = aten::__is__(%y, %x) (= False)
33 // After remove detach, we would get
34 //   %b : bool = aten::__is__(%x, %x) (= True!)
containsInvalidOp(std::shared_ptr<Graph> & graph)35 bool containsInvalidOp(std::shared_ptr<Graph>& graph) {
36   for (auto* node : graph->nodes()) {
37     if (cannotOptimize(node)) {
38       return true;
39     }
40   }
41   return false;
42 }
43 
44 } // namespace
45 
EliminateNoOps(std::shared_ptr<Graph> & graph,std::unordered_set<c10::Symbol> custom_ops)46 bool EliminateNoOps(
47     std::shared_ptr<Graph>& graph,
48     std::unordered_set<c10::Symbol> custom_ops) {
49   GRAPH_DUMP("Before EliminateNoOps: ", graph);
50   if (containsInvalidOp(graph)) {
51     return false;
52   }
53   // Ops here should be of the form x = f(x, ...)
54   std::unordered_set<c10::Symbol> no_ops{aten::detach};
55   no_ops.insert(custom_ops.begin(), custom_ops.end());
56 
57   bool changed = false;
58 
59   auto graph_it = DepthFirstGraphNodeIterator(graph);
60   for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
61     auto it = no_ops.find(node->kind());
62     if (it == no_ops.end()) {
63       continue;
64     }
65 
66     changed = true;
67     node->output()->replaceAllUsesWith(node->input(0));
68   }
69 
70   if (changed) {
71     EliminateDeadCode(graph);
72   }
73 
74   GRAPH_DUMP("After EliminateNoOps: ", graph);
75   return changed;
76 }
77 
78 } // namespace torch::jit
79