1 #include <torch/csrc/jit/passes/canonicalize.h>
2
3 #include <c10/util/irange.h>
4 #include <torch/csrc/jit/ir/ir_views.h>
5
6 namespace torch::jit {
7
// Canonicalize a graph, renumbering it so that all structurally equivalent
// graphs get the same value numbering.
// keep_unique_names: If false, removes the unique (debug) names of values and
// replaces them with normal value names.
// Otherwise, values keep their unique names.
Canonicalize(const std::shared_ptr<Graph> & graph,bool keep_unique_names)13 std::shared_ptr<Graph> Canonicalize(
14 const std::shared_ptr<Graph>& graph,
15 bool keep_unique_names) {
16 auto r = std::make_shared<Graph>(graph->current_scope());
17 std::unordered_map<Value*, Value*> rn_env;
18 auto rn_fn = [&](Value* v) { return rn_env.at(v); };
19 for (auto* input : graph->inputs()) {
20 auto* r_input = r->addInput();
21 r_input->copyMetadata(input);
22 if (!keep_unique_names)
23 r_input->setDebugName("");
24 rn_env[input] = r_input;
25 }
26 for (auto* node : graph->nodes()) {
27 auto* r_node = r->createClone(node, rn_fn);
28 if (!keep_unique_names) {
29 for (auto* output : r_node->outputs()) {
30 output->setDebugName("");
31 }
32 }
33 r->appendNode(r_node);
34 auto outputs = node->outputs();
35 auto r_outputs = r_node->outputs();
36 for (const auto i : c10::irange(outputs.size())) {
37 rn_env[outputs.at(i)] = r_outputs.at(i);
38 }
39 if (node->hasAttribute(attr::Subgraph)) {
40 r_node->g_(
41 attr::Subgraph,
42 Canonicalize(node->g(attr::Subgraph), keep_unique_names));
43 }
44 }
45 for (auto* output : graph->outputs()) {
46 r->registerOutput(rn_fn(output));
47 }
48
49 return r;
50 }
51
// Returns the position of block b within its owning node's list of blocks.
blockIndex(const Block * b)53 static size_t blockIndex(const Block* b) {
54 auto n = b->owningNode();
55 AT_ASSERT(n);
56 for (size_t i = 0; i < n->blocks().size(); ++i) {
57 if (n->blocks()[i] == b) {
58 return i;
59 }
60 }
61 AT_ASSERT(false);
62 }
63
/*
 * This establishes a canonical ordering of nodes.
 * If n1 and n2 are in the same block, whichever node appears first
 * is before the other.
 * If n1 and n2 are contained (possibly deeply nested) in different blocks of
 * the same node, e.g. the true and false blocks of an If node, then the node
 * in the lower-indexed block (for an If, the true block) is ordered first.
 * If n1 contains n2, then n1 is before n2. This has the nice property that
 * whichever node appears first in a dump of the graph is before the other.
 * NB: this is not a topological index. Topologically, two nodes in
 * different blocks of an if node are not topologically < or > each other.
 */
isBefore(Node * n1,Node * n2)75 static bool isBefore(Node* n1, Node* n2) {
76 // Invalid to call with the same node as both args
77 AT_ASSERT(n1 != n2);
78
79 // Set n1 and n2 to be the number of blocks from the Graph block
80 size_t d_1 = n1->blocksFromGraphBlock();
81 size_t d_2 = n2->blocksFromGraphBlock();
82
83 for (; d_1 > d_2; --d_1) {
84 n1 = n1->owningBlock()->owningNode();
85 // n2 contains n1
86 if (n1 == n2) {
87 return false;
88 }
89 }
90
91 for (; d_2 > d_1; --d_2) {
92 n2 = n2->owningBlock()->owningNode();
93 // n1 contains n2
94 if (n2 == n1) {
95 return true;
96 }
97 }
98
99 // Now they are the same numer of blocks from the graph block,
100 // recurse upwards, checking if they are on the same block
101 while (true) {
102 if (n1->owningBlock() == n2->owningBlock()) {
103 return n1->isBefore(n2);
104 }
105
106 auto new_n1 = n1->owningBlock()->owningNode();
107 auto new_n2 = n2->owningBlock()->owningNode();
108
109 AT_ASSERT(new_n1 != nullptr);
110 AT_ASSERT(new_n2 != nullptr);
111
112 if (new_n1 == new_n2) {
113 // take whichever node is in the earlier block
114 auto index_1 = blockIndex(n1->owningBlock());
115 auto index_2 = blockIndex(n2->owningBlock());
116 return index_1 < index_2;
117 }
118
119 n1 = new_n1;
120 n2 = new_n2;
121 }
122 }
123
isBefore(const Use & a,const Use & b)124 static bool isBefore(const Use& a, const Use& b) {
125 // If two uses are the same node, we order on offset
126 if (a.user == b.user) {
127 return a.offset < b.offset;
128 }
129
130 return isBefore(a.user, b.user);
131 }
132
isAfter(const Use & a,const Use & b)133 static bool isAfter(const Use& a, const Use& b) {
134 if (a.user == b.user && a.offset == b.offset) {
135 return false;
136 }
137 return !isBefore(a, b);
138 }
139
isBeforeOrAfter(const Use & a,const Use & b,bool checking_before)140 bool isBeforeOrAfter(const Use& a, const Use& b, bool checking_before) {
141 return checking_before ? isBefore(a, b) : isAfter(a, b);
142 }
143
firstOrLastUse(Value * v,bool find_first)144 std::optional<const Use> firstOrLastUse(Value* v, bool find_first) {
145 if (v->uses().empty()) {
146 return std::nullopt;
147 }
148 Use extreme_use = v->uses()[0];
149 for (size_t i = 1; i < v->uses().size(); ++i) {
150 auto n_use = v->uses()[i];
151 if (!isBeforeOrAfter(extreme_use, n_use, find_first)) {
152 extreme_use = n_use;
153 }
154 }
155
156 return extreme_use;
157 }
158
gatherFirstUses(at::ArrayRef<Value * > values)159 static std::vector<std::optional<const Use>> gatherFirstUses(
160 at::ArrayRef<Value*> values) {
161 return fmap(values, [&](Value* v) -> std::optional<const Use> {
162 return firstOrLastUse(v, true);
163 });
164 }
165
sort_indexes(at::ArrayRef<Value * > values)166 static std::vector<size_t> sort_indexes(at::ArrayRef<Value*> values) {
167 // initialize original index locations
168 std::vector<size_t> idx(values.size());
169 std::iota(idx.begin(), idx.end(), 0);
170
171 std::vector<std::optional<const Use>> first_uses = gatherFirstUses(values);
172
173 // Sort values based on canonical ordering of their first usage
174 std::sort(idx.begin(), idx.end(), [&first_uses](size_t i1, size_t i2) {
175 // if neither has any uses, use original ordering. Since the
176 // only values that jitter are ones added by the compiler and are guaranteed
177 // to have uses, original ordering is fine.
178 if (first_uses[i1] == std::nullopt && first_uses[i2] == std::nullopt) {
179 return i1 < i2;
180 }
181 if (first_uses[i1] == std::nullopt) {
182 return false;
183 } else if (first_uses[i2] == std::nullopt) {
184 return true;
185 }
186
187 auto fst_v1 = *first_uses[i1];
188 auto fst_v2 = *first_uses[i2];
189
190 return isBefore(fst_v1, fst_v2);
191 });
192
193 return idx;
194 }
195
CanonicalizeLoopOutputs(Node * n)196 static void CanonicalizeLoopOutputs(Node* n) {
197 auto new_indices = sort_indexes(n->outputs());
198 LoopView(n).permuteLoopCarried(new_indices);
199 }
200
CanonicalizeIfOutputs(Node * n)201 static void CanonicalizeIfOutputs(Node* n) {
202 auto new_indices = sort_indexes(n->outputs());
203 IfView(n).permuteOutputs(new_indices);
204 }
205
CanonicalizeOutputs(Block * block)206 static void CanonicalizeOutputs(Block* block) {
207 // We iterate in reverse since ordering of a node's outputs is dependent on
208 // the value use following it in the graph
209 for (Node* n : block->nodes().reverse()) {
210 switch (n->kind()) {
211 case prim::Loop: {
212 CanonicalizeLoopOutputs(n);
213 } break;
214 case prim::If: {
215 CanonicalizeIfOutputs(n);
216 } break;
217 }
218 // Since an a control flow node's outputs are after
219 // the values outputted within its blocks, first canonicalize
220 // the nodes outputs and then recurse on its blocks
221 for (Block* b : n->blocks()) {
222 CanonicalizeOutputs(b);
223 }
224 }
225 }
226
227 // Canonicalize a graph's control flow node outputs. We do this to solve jitter
228 // issues with outputs added to control flow nodes after the first pass of
229 // compilation in ir_emitter.cpp
CanonicalizeOutputs(std::shared_ptr<Graph> & graph)230 void CanonicalizeOutputs(std::shared_ptr<Graph>& graph) {
231 CanonicalizeOutputs(graph->block());
232 }
233 } // namespace torch::jit
234