1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/ir/subgraph_matcher.h>
3 #include <torch/csrc/jit/jit_log.h>
4
5 #include <regex>
6 #include <stack>
7
8 namespace torch::jit {
9 namespace {
10
11 /**
12 * \brief A class implementing an API for comparing subgraphs.
13 */
14 class SubgraphMatcher {
15 public:
SubgraphMatcher(const Graph & pattern)16 explicit SubgraphMatcher(const Graph& pattern) : pattern_(pattern) {}
17
18 /**
19 * \brief Compare matchGraph with the part of the graph denoted by a node \p
20 * ANCHOR.
21 *
22 * The anchor node would be compared against the deepest node in the
23 * match-graph. A node is considered matching if its number of inputs/outputs
24 * is the same as in the corresponding matchGraph node, its type is the same,
25 * and all nodes producing input-values also match.
26 */
27 bool matchesSubgraphFromAnchorNode(Node* anchor);
28
29 /** \brief Return match map for nodes. */
nodes_map() const30 std::unordered_map<const Node*, Node*> nodes_map() const {
31 return nodes_map_;
32 }
33
34 /** \brief Return match map for values. */
values_map() const35 std::unordered_map<const Value*, Value*> values_map() const {
36 return values_map_;
37 }
38
39 private:
40 bool matchValues(const Value* v1, Value* v2);
41 bool matchNodes(const Node* n1, Node* n2);
42 bool matchAttributes(const Node* n1, Node* n2);
43
44 static bool isInput(const Value* v);
45 static bool isOutput(const Value* v);
46
47 std::unordered_map<const Node*, Node*> nodes_map_;
48 std::unordered_map<const Value*, Value*> values_map_;
49
50 const Graph& pattern_;
51 const Node* anchor_ = nullptr;
52 };
53
54 /**
55 * \brief A function to verify that \p PATTERN is valid. Concrete requirements
56 * for validity can be found in subgraph_matcher.h.
57 */
patternGraphIsValid(const Graph & pattern)58 bool patternGraphIsValid(const Graph& pattern) {
59 // Verify that pattern graph has a single block.
60 for (const Node* n : pattern.nodes()) {
61 if (!n->blocks().empty()) {
62 return false;
63 }
64 }
65
66 // TODO: Verify that nodes in the pattern don't alias.
67 return true;
68 }
69
isInput(const Value * v)70 bool SubgraphMatcher::isInput(const Value* v) {
71 return v->node()->kind() == prim::Param;
72 }
73
isOutput(const Value * v)74 bool SubgraphMatcher::isOutput(const Value* v) {
75 for (const Value* output : v->owningGraph()->outputs()) {
76 if (v == output) {
77 return true;
78 }
79 }
80 return false;
81 }
82
83 /**
84 * Compare two Values. V1 is from pattern, V2 is from the actual graph.
85 *
86 * The values are considered matching if:
87 * 1) the nodes defining them match
88 * 2) they have the same number of uses, except they are entry or exit nodes.
89 */
matchValues(const Value * v1,Value * v2)90 bool SubgraphMatcher::matchValues(const Value* v1, Value* v2) {
91 // Check if we've already visited these values.
92 if (values_map_.count(v1)) {
93 if (values_map_.at(v1) != v2) {
94 GRAPH_DEBUG(
95 "Values %",
96 v1->debugName(),
97 " and %",
98 v2->debugName(),
99 " did not match because %",
100 v1->debugName(),
101 " has already been matched with %",
102 values_map_.at(v1)->debugName(),
103 ".\n");
104 return false;
105 }
106 return true;
107 }
108
109 // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is
110 // PARAM, we're comparing entering values - in these two cases the number of
111 // uses don't need to be the same.
112 if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) {
113 GRAPH_DEBUG(
114 "Values %",
115 v1->debugName(),
116 " and %",
117 v2->debugName(),
118 " did not match because number of their uses is different.\n");
119 return false;
120 }
121
122 // Add the values to the map before calling matchNodes to avoid infinite
123 // recursion.
124 GRAPH_DEBUG(
125 "Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n");
126 values_map_[v1] = v2;
127 return matchNodes(v1->node(), v2->node());
128 }
129
matchAttributes(const Node * n1,Node * n2)130 bool SubgraphMatcher::matchAttributes(const Node* n1, Node* n2) {
131 if (n1->numAttributes() != n2->numAttributes()) {
132 GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2);
133 return false;
134 }
135 for (const Symbol& attr_name : n1->attributeNames()) {
136 if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) {
137 GRAPH_DEBUG(
138 "Nodes did not match because type of attribute '",
139 attr_name.toQualString(),
140 "' did not match:\n",
141 *n1,
142 *n2);
143 return false;
144 }
145 switch (n1->kindOf(attr_name)) {
146 case AttributeKind::s:
147 if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) {
148 GRAPH_DEBUG(
149 "Nodes did not match because attribute '",
150 attr_name.toQualString(),
151 "' did not match: ",
152 n1->s(attr_name),
153 " != ",
154 n2->s(attr_name),
155 " \n",
156 *n1,
157 *n2);
158 return false;
159 }
160 break;
161 case AttributeKind::c:
162 if (n1->c(attr_name) != n2->c(attr_name)) {
163 GRAPH_DEBUG(
164 "Nodes did not match because attribute '",
165 attr_name.toQualString(),
166 "' did not match:",
167 n1->c(attr_name),
168 " != ",
169 n2->c(attr_name),
170 " \n",
171 *n1,
172 *n2);
173 return false;
174 }
175 break;
176 case AttributeKind::f:
177 if (n1->f(attr_name) != n2->f(attr_name)) {
178 GRAPH_DEBUG(
179 "Nodes did not match because attribute '",
180 attr_name.toQualString(),
181 "' did not match:",
182 n1->f(attr_name),
183 " != ",
184 n2->f(attr_name),
185 " \n",
186 *n1,
187 *n2);
188 return false;
189 }
190 break;
191 case AttributeKind::i:
192 if (n1->i(attr_name) != n2->i(attr_name)) {
193 GRAPH_DEBUG(
194 "Nodes did not match because attribute '",
195 attr_name.toQualString(),
196 "' did not match:",
197 n1->i(attr_name),
198 " != ",
199 n2->i(attr_name),
200 " \n",
201 *n1,
202 *n2);
203 return false;
204 }
205 break;
206 default: {
207 // Other attributes types not supported yet
208 GRAPH_DEBUG(
209 "Nodes did not match because type of attribute '",
210 attr_name.toQualString(),
211 "' is not supported.\n",
212 *n1,
213 *n2);
214 return false;
215 }
216 }
217 }
218 return true;
219 }
220
// Returns true when \p str ends with \p suffix (an empty suffix always does).
static bool endsWith(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) {
    return false;
  }
  const auto offset = str.size() - suffix.size();
  return str.compare(offset, std::string::npos, suffix) == 0;
}
225
226 /**
227 * Compare two Nodes. N1 is from pattern, N2 is from the actual graph.
228 *
229 * The nodes are considered matching if:
230 * 1) N1 and N2 are of the same kind.
231 * 2) Number of inputs and outputs is the same.
232 * 3) All input and output values match.
233 *
234 * A special case is when N1 is PARAM - this is considered outside the pattern,
235 * so it matches everything.
236 */
matchNodes(const Node * n1,Node * n2)237 bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) {
238 // Check if we've already visited these nodes.
239 if (nodes_map_.count(n1)) {
240 return nodes_map_.at(n1) == n2;
241 }
242
243 // Param node in pattern graph matches everything.
244 if (n1->kind() == prim::Param) {
245 GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
246 return true;
247 }
248
249 // We don't allow matches to span across blocks, so check if N2 is in the same
250 // block as the first (anchor) node.
251 if (n2->owningBlock() != anchor_->owningBlock()) {
252 GRAPH_DEBUG(
253 "Nodes did not match because it is in the different block:\n",
254 *n1,
255 *n2);
256 return false;
257 }
258
259 // Special handling for matching modules
260 if (n1->kind() == Symbol::fromQualString("match::module")) {
261 if (n2->kind() == prim::GetAttr) {
262 if (!n1->hasAttributeS("name")) {
263 GRAPH_DEBUG(
264 "Nodes did not match because special node match::module does not have 'name' attribute:\n",
265 *n1,
266 *n2);
267 return false;
268 }
269 auto t = n2->output()->type()->expect<ClassType>();
270 auto real_typename = t->name()->qualifiedName();
271 auto pattern_typename = n1->s(attr::name);
272 if (!endsWith(real_typename, pattern_typename)) {
273 GRAPH_DEBUG(
274 "Nodes did not match because expected module type is different:\n");
275 GRAPH_DEBUG(" actualtype: ", real_typename, "\n");
276 GRAPH_DEBUG(" expected type: ", pattern_typename, "\n");
277 GRAPH_DEBUG("Nodes:", *n1, *n2);
278 return false;
279 }
280 }
281 } else {
282 if (n1->kind() != n2->kind() ||
283 n1->outputs().size() != n2->outputs().size() ||
284 n1->inputs().size() != n2->inputs().size()) {
285 GRAPH_DEBUG(
286 "Nodes did not match in their kind or number of inputs/outputs:\n",
287 *n1,
288 *n2);
289 return false;
290 }
291 if (!matchAttributes(n1, n2)) {
292 return false;
293 }
294 }
295
296 // Add nodes to the map before calling matchValues to avoid infinite
297 // recursion.
298 nodes_map_[n1] = n2;
299 for (const auto i : c10::irange(n1->outputs().size())) {
300 if (!matchValues(n1->outputs()[i], n2->outputs()[i])) {
301 return false;
302 }
303 }
304 for (const auto i : c10::irange(n1->inputs().size())) {
305 if (!matchValues(n1->inputs()[i], n2->inputs()[i])) {
306 return false;
307 }
308 }
309
310 GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
311 return true;
312 }
313
314 /**
315 * Recursively try to match pattern with the actual graph starting from the
316 * exiting node in the pattern and anchor node in the actual graph.
317 */
matchesSubgraphFromAnchorNode(Node * anchor)318 bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
319 GRAPH_UPDATE("Starting match from a new anchor: ", *anchor);
320 nodes_map_.clear();
321 values_map_.clear();
322 anchor_ = anchor;
323
324 const Node* bottom_node = *(pattern_.nodes().end());
325 bottom_node = bottom_node->input(0)->node();
326
327 if (!matchNodes(bottom_node, anchor)) {
328 return false;
329 }
330
331 for (const Value* output : pattern_.outputs()) {
332 AT_ASSERT(values_map_.count(output));
333 }
334
335 GRAPH_UPDATE("Pattern matched!\n");
336 return true;
337 }
338
339 } // unnamed namespace
340
341 // Main entry point for the subgraph matching.
findPatternMatches(const Graph & pattern,Graph & graph)342 std::vector<Match> findPatternMatches(const Graph& pattern, Graph& graph) {
343 AT_ASSERT(patternGraphIsValid(pattern));
344 GRAPH_DUMP("Pattern graph: ", &pattern);
345 GRAPH_DUMP("Target graph: ", &graph);
346
347 SubgraphMatcher m(pattern);
348 std::vector<Match> matches;
349 std::stack<Block*> blocks_to_visit;
350
351 // Iterate over all nodes in the graph (including nodes in subblocks) trying
352 // to match the pattern each node.
353 blocks_to_visit.push(graph.block());
354 while (!blocks_to_visit.empty()) {
355 Block* block = blocks_to_visit.top();
356 blocks_to_visit.pop();
357 for (Node* n : block->nodes()) {
358 if (m.matchesSubgraphFromAnchorNode(n)) {
359 matches.push_back({n, m.nodes_map(), m.values_map()});
360 }
361 for (Block* subblock : n->blocks()) {
362 blocks_to_visit.push(subblock);
363 }
364 }
365 }
366 return matches;
367 }
368
369 } // namespace torch::jit
370