xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/canonicalize.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/canonicalize.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/csrc/jit/ir/ir_views.h>
5 
6 namespace torch::jit {
7 
8 // Canonicalize a graph, renumbering it so that all structurally equivalent
9 // graphs have same numbers.
10 // keep_unique_names: If false, canonicalizes unique names by removing them
11 //   and replacing them with normal value names.
12 //   Otherwise, ignores values with unique names.
Canonicalize(const std::shared_ptr<Graph> & graph,bool keep_unique_names)13 std::shared_ptr<Graph> Canonicalize(
14     const std::shared_ptr<Graph>& graph,
15     bool keep_unique_names) {
16   auto r = std::make_shared<Graph>(graph->current_scope());
17   std::unordered_map<Value*, Value*> rn_env;
18   auto rn_fn = [&](Value* v) { return rn_env.at(v); };
19   for (auto* input : graph->inputs()) {
20     auto* r_input = r->addInput();
21     r_input->copyMetadata(input);
22     if (!keep_unique_names)
23       r_input->setDebugName("");
24     rn_env[input] = r_input;
25   }
26   for (auto* node : graph->nodes()) {
27     auto* r_node = r->createClone(node, rn_fn);
28     if (!keep_unique_names) {
29       for (auto* output : r_node->outputs()) {
30         output->setDebugName("");
31       }
32     }
33     r->appendNode(r_node);
34     auto outputs = node->outputs();
35     auto r_outputs = r_node->outputs();
36     for (const auto i : c10::irange(outputs.size())) {
37       rn_env[outputs.at(i)] = r_outputs.at(i);
38     }
39     if (node->hasAttribute(attr::Subgraph)) {
40       r_node->g_(
41           attr::Subgraph,
42           Canonicalize(node->g(attr::Subgraph), keep_unique_names));
43     }
44   }
45   for (auto* output : graph->outputs()) {
46     r->registerOutput(rn_fn(output));
47   }
48 
49   return r;
50 }
51 
52 // Which index in b's owning Node is b
blockIndex(const Block * b)53 static size_t blockIndex(const Block* b) {
54   auto n = b->owningNode();
55   AT_ASSERT(n);
56   for (size_t i = 0; i < n->blocks().size(); ++i) {
57     if (n->blocks()[i] == b) {
58       return i;
59     }
60   }
61   AT_ASSERT(false);
62 }
63 
64 /*
65  * This establishes a canonical ordering of nodes.
66  * If n1 and n2 are in the same block, whichever node appears first
67  * is before the other.
68  * If n1 and n2 are contained in different blocks of an if node,
69  * then whichever block is in the true block is ordered before the other.
70  * If n1 contains n2, then n1 is before n2. This has the nice property that
71  * whichever node appears first in a dump of the graph is before the other.
72  * NB: this is not a topological index. Topologically, two nodes in
73  * different blocks of an if node are not topologically < or > each other.
74  */
isBefore(Node * n1,Node * n2)75 static bool isBefore(Node* n1, Node* n2) {
76   // Invalid to call with the same node as both args
77   AT_ASSERT(n1 != n2);
78 
79   // Set n1 and n2 to be the number of blocks from the Graph block
80   size_t d_1 = n1->blocksFromGraphBlock();
81   size_t d_2 = n2->blocksFromGraphBlock();
82 
83   for (; d_1 > d_2; --d_1) {
84     n1 = n1->owningBlock()->owningNode();
85     // n2 contains n1
86     if (n1 == n2) {
87       return false;
88     }
89   }
90 
91   for (; d_2 > d_1; --d_2) {
92     n2 = n2->owningBlock()->owningNode();
93     // n1 contains n2
94     if (n2 == n1) {
95       return true;
96     }
97   }
98 
99   // Now they are the same numer of blocks from the graph block,
100   // recurse upwards, checking if they are on the same block
101   while (true) {
102     if (n1->owningBlock() == n2->owningBlock()) {
103       return n1->isBefore(n2);
104     }
105 
106     auto new_n1 = n1->owningBlock()->owningNode();
107     auto new_n2 = n2->owningBlock()->owningNode();
108 
109     AT_ASSERT(new_n1 != nullptr);
110     AT_ASSERT(new_n2 != nullptr);
111 
112     if (new_n1 == new_n2) {
113       // take whichever node is in the earlier block
114       auto index_1 = blockIndex(n1->owningBlock());
115       auto index_2 = blockIndex(n2->owningBlock());
116       return index_1 < index_2;
117     }
118 
119     n1 = new_n1;
120     n2 = new_n2;
121   }
122 }
123 
isBefore(const Use & a,const Use & b)124 static bool isBefore(const Use& a, const Use& b) {
125   // If two uses are the same node, we order on offset
126   if (a.user == b.user) {
127     return a.offset < b.offset;
128   }
129 
130   return isBefore(a.user, b.user);
131 }
132 
isAfter(const Use & a,const Use & b)133 static bool isAfter(const Use& a, const Use& b) {
134   if (a.user == b.user && a.offset == b.offset) {
135     return false;
136   }
137   return !isBefore(a, b);
138 }
139 
isBeforeOrAfter(const Use & a,const Use & b,bool checking_before)140 bool isBeforeOrAfter(const Use& a, const Use& b, bool checking_before) {
141   return checking_before ? isBefore(a, b) : isAfter(a, b);
142 }
143 
firstOrLastUse(Value * v,bool find_first)144 std::optional<const Use> firstOrLastUse(Value* v, bool find_first) {
145   if (v->uses().empty()) {
146     return std::nullopt;
147   }
148   Use extreme_use = v->uses()[0];
149   for (size_t i = 1; i < v->uses().size(); ++i) {
150     auto n_use = v->uses()[i];
151     if (!isBeforeOrAfter(extreme_use, n_use, find_first)) {
152       extreme_use = n_use;
153     }
154   }
155 
156   return extreme_use;
157 }
158 
gatherFirstUses(at::ArrayRef<Value * > values)159 static std::vector<std::optional<const Use>> gatherFirstUses(
160     at::ArrayRef<Value*> values) {
161   return fmap(values, [&](Value* v) -> std::optional<const Use> {
162     return firstOrLastUse(v, true);
163   });
164 }
165 
sort_indexes(at::ArrayRef<Value * > values)166 static std::vector<size_t> sort_indexes(at::ArrayRef<Value*> values) {
167   // initialize original index locations
168   std::vector<size_t> idx(values.size());
169   std::iota(idx.begin(), idx.end(), 0);
170 
171   std::vector<std::optional<const Use>> first_uses = gatherFirstUses(values);
172 
173   // Sort values based on canonical ordering of their first usage
174   std::sort(idx.begin(), idx.end(), [&first_uses](size_t i1, size_t i2) {
175     // if neither has any uses, use original ordering. Since the
176     // only values that jitter are ones added by the compiler and are guaranteed
177     // to have uses, original ordering is fine.
178     if (first_uses[i1] == std::nullopt && first_uses[i2] == std::nullopt) {
179       return i1 < i2;
180     }
181     if (first_uses[i1] == std::nullopt) {
182       return false;
183     } else if (first_uses[i2] == std::nullopt) {
184       return true;
185     }
186 
187     auto fst_v1 = *first_uses[i1];
188     auto fst_v2 = *first_uses[i2];
189 
190     return isBefore(fst_v1, fst_v2);
191   });
192 
193   return idx;
194 }
195 
CanonicalizeLoopOutputs(Node * n)196 static void CanonicalizeLoopOutputs(Node* n) {
197   auto new_indices = sort_indexes(n->outputs());
198   LoopView(n).permuteLoopCarried(new_indices);
199 }
200 
CanonicalizeIfOutputs(Node * n)201 static void CanonicalizeIfOutputs(Node* n) {
202   auto new_indices = sort_indexes(n->outputs());
203   IfView(n).permuteOutputs(new_indices);
204 }
205 
CanonicalizeOutputs(Block * block)206 static void CanonicalizeOutputs(Block* block) {
207   // We iterate in reverse since ordering of a node's outputs is dependent on
208   // the value use following it in the graph
209   for (Node* n : block->nodes().reverse()) {
210     switch (n->kind()) {
211       case prim::Loop: {
212         CanonicalizeLoopOutputs(n);
213       } break;
214       case prim::If: {
215         CanonicalizeIfOutputs(n);
216       } break;
217     }
218     // Since an a control flow node's outputs are after
219     // the values outputted within its blocks, first canonicalize
220     // the nodes outputs and then recurse on its blocks
221     for (Block* b : n->blocks()) {
222       CanonicalizeOutputs(b);
223     }
224   }
225 }
226 
227 // Canonicalize a graph's control flow node outputs. We do this to solve jitter
228 // issues with outputs added to control flow nodes after the first pass of
229 // compilation in ir_emitter.cpp
CanonicalizeOutputs(std::shared_ptr<Graph> & graph)230 void CanonicalizeOutputs(std::shared_ptr<Graph>& graph) {
231   CanonicalizeOutputs(graph->block());
232 }
233 } // namespace torch::jit
234