xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/subgraph_matcher.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/ir/subgraph_matcher.h>
3 #include <torch/csrc/jit/jit_log.h>
4 
5 #include <regex>
6 #include <stack>
7 
8 namespace torch::jit {
9 namespace {
10 
11 /**
12  * \brief A class implementing an API for comparing subgraphs.
13  */
14 class SubgraphMatcher {
15  public:
SubgraphMatcher(const Graph & pattern)16   explicit SubgraphMatcher(const Graph& pattern) : pattern_(pattern) {}
17 
18   /**
19    * \brief Compare matchGraph with the part of the graph denoted by a node \p
20    * ANCHOR.
21    *
22    * The anchor node would be compared against the deepest node in the
23    * match-graph. A node is considered matching if its number of inputs/outputs
24    * is the same as in the corresponding matchGraph node, its type is the same,
25    * and all nodes producing input-values also match.
26    */
27   bool matchesSubgraphFromAnchorNode(Node* anchor);
28 
29   /** \brief Return match map for nodes. */
nodes_map() const30   std::unordered_map<const Node*, Node*> nodes_map() const {
31     return nodes_map_;
32   }
33 
34   /** \brief Return match map for values. */
values_map() const35   std::unordered_map<const Value*, Value*> values_map() const {
36     return values_map_;
37   }
38 
39  private:
40   bool matchValues(const Value* v1, Value* v2);
41   bool matchNodes(const Node* n1, Node* n2);
42   bool matchAttributes(const Node* n1, Node* n2);
43 
44   static bool isInput(const Value* v);
45   static bool isOutput(const Value* v);
46 
47   std::unordered_map<const Node*, Node*> nodes_map_;
48   std::unordered_map<const Value*, Value*> values_map_;
49 
50   const Graph& pattern_;
51   const Node* anchor_ = nullptr;
52 };
53 
54 /**
55  * \brief A function to verify that \p PATTERN is valid. Concrete requirements
56  * for validity can be found in subgraph_matcher.h.
57  */
patternGraphIsValid(const Graph & pattern)58 bool patternGraphIsValid(const Graph& pattern) {
59   // Verify that pattern graph has a single block.
60   for (const Node* n : pattern.nodes()) {
61     if (!n->blocks().empty()) {
62       return false;
63     }
64   }
65 
66   // TODO: Verify that nodes in the pattern don't alias.
67   return true;
68 }
69 
isInput(const Value * v)70 bool SubgraphMatcher::isInput(const Value* v) {
71   return v->node()->kind() == prim::Param;
72 }
73 
isOutput(const Value * v)74 bool SubgraphMatcher::isOutput(const Value* v) {
75   for (const Value* output : v->owningGraph()->outputs()) {
76     if (v == output) {
77       return true;
78     }
79   }
80   return false;
81 }
82 
83 /**
84  * Compare two Values. V1 is from pattern, V2 is from the actual graph.
85  *
86  * The values are considered matching if:
87  * 1) the nodes defining them match
88  * 2) they have the same number of uses, except they are entry or exit nodes.
89  */
matchValues(const Value * v1,Value * v2)90 bool SubgraphMatcher::matchValues(const Value* v1, Value* v2) {
91   // Check if we've already visited these values.
92   if (values_map_.count(v1)) {
93     if (values_map_.at(v1) != v2) {
94       GRAPH_DEBUG(
95           "Values %",
96           v1->debugName(),
97           " and %",
98           v2->debugName(),
99           " did not match because %",
100           v1->debugName(),
101           " has already been matched with %",
102           values_map_.at(v1)->debugName(),
103           ".\n");
104       return false;
105     }
106     return true;
107   }
108 
109   // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is
110   // PARAM, we're comparing entering values - in these two cases the number of
111   // uses don't need to be the same.
112   if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) {
113     GRAPH_DEBUG(
114         "Values %",
115         v1->debugName(),
116         " and %",
117         v2->debugName(),
118         " did not match because number of their uses is different.\n");
119     return false;
120   }
121 
122   // Add the values to the map before calling matchNodes to avoid infinite
123   // recursion.
124   GRAPH_DEBUG(
125       "Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n");
126   values_map_[v1] = v2;
127   return matchNodes(v1->node(), v2->node());
128 }
129 
matchAttributes(const Node * n1,Node * n2)130 bool SubgraphMatcher::matchAttributes(const Node* n1, Node* n2) {
131   if (n1->numAttributes() != n2->numAttributes()) {
132     GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2);
133     return false;
134   }
135   for (const Symbol& attr_name : n1->attributeNames()) {
136     if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) {
137       GRAPH_DEBUG(
138           "Nodes did not match because type of attribute '",
139           attr_name.toQualString(),
140           "' did not match:\n",
141           *n1,
142           *n2);
143       return false;
144     }
145     switch (n1->kindOf(attr_name)) {
146       case AttributeKind::s:
147         if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) {
148           GRAPH_DEBUG(
149               "Nodes did not match because attribute '",
150               attr_name.toQualString(),
151               "' did not match: ",
152               n1->s(attr_name),
153               " != ",
154               n2->s(attr_name),
155               " \n",
156               *n1,
157               *n2);
158           return false;
159         }
160         break;
161       case AttributeKind::c:
162         if (n1->c(attr_name) != n2->c(attr_name)) {
163           GRAPH_DEBUG(
164               "Nodes did not match because attribute '",
165               attr_name.toQualString(),
166               "' did not match:",
167               n1->c(attr_name),
168               " != ",
169               n2->c(attr_name),
170               " \n",
171               *n1,
172               *n2);
173           return false;
174         }
175         break;
176       case AttributeKind::f:
177         if (n1->f(attr_name) != n2->f(attr_name)) {
178           GRAPH_DEBUG(
179               "Nodes did not match because attribute '",
180               attr_name.toQualString(),
181               "' did not match:",
182               n1->f(attr_name),
183               " != ",
184               n2->f(attr_name),
185               " \n",
186               *n1,
187               *n2);
188           return false;
189         }
190         break;
191       case AttributeKind::i:
192         if (n1->i(attr_name) != n2->i(attr_name)) {
193           GRAPH_DEBUG(
194               "Nodes did not match because attribute '",
195               attr_name.toQualString(),
196               "' did not match:",
197               n1->i(attr_name),
198               " != ",
199               n2->i(attr_name),
200               " \n",
201               *n1,
202               *n2);
203           return false;
204         }
205         break;
206       default: {
207         // Other attributes types not supported yet
208         GRAPH_DEBUG(
209             "Nodes did not match because type of attribute '",
210             attr_name.toQualString(),
211             "' is not supported.\n",
212             *n1,
213             *n2);
214         return false;
215       }
216     }
217   }
218   return true;
219 }
220 
// Returns true when `suffix` occurs at the very end of `str`. An empty
// suffix is a suffix of every string.
static bool endsWith(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) {
    return false;
  }
  const std::string::size_type tail_start = str.size() - suffix.size();
  return str.compare(tail_start, std::string::npos, suffix) == 0;
}
225 
226 /**
227  * Compare two Nodes. N1 is from pattern, N2 is from the actual graph.
228  *
229  * The nodes are considered matching if:
230  * 1) N1 and N2 are of the same kind.
231  * 2) Number of inputs and outputs is the same.
232  * 3) All input and output values match.
233  *
234  * A special case is when N1 is PARAM - this is considered outside the pattern,
235  * so it matches everything.
236  */
matchNodes(const Node * n1,Node * n2)237 bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) {
238   // Check if we've already visited these nodes.
239   if (nodes_map_.count(n1)) {
240     return nodes_map_.at(n1) == n2;
241   }
242 
243   // Param node in pattern graph matches everything.
244   if (n1->kind() == prim::Param) {
245     GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
246     return true;
247   }
248 
249   // We don't allow matches to span across blocks, so check if N2 is in the same
250   // block as the first (anchor) node.
251   if (n2->owningBlock() != anchor_->owningBlock()) {
252     GRAPH_DEBUG(
253         "Nodes did not match because it is in the different block:\n",
254         *n1,
255         *n2);
256     return false;
257   }
258 
259   // Special handling for matching modules
260   if (n1->kind() == Symbol::fromQualString("match::module")) {
261     if (n2->kind() == prim::GetAttr) {
262       if (!n1->hasAttributeS("name")) {
263         GRAPH_DEBUG(
264             "Nodes did not match because special node match::module does not have 'name' attribute:\n",
265             *n1,
266             *n2);
267         return false;
268       }
269       auto t = n2->output()->type()->expect<ClassType>();
270       auto real_typename = t->name()->qualifiedName();
271       auto pattern_typename = n1->s(attr::name);
272       if (!endsWith(real_typename, pattern_typename)) {
273         GRAPH_DEBUG(
274             "Nodes did not match because expected module type is different:\n");
275         GRAPH_DEBUG("  actualtype:    ", real_typename, "\n");
276         GRAPH_DEBUG("  expected type: ", pattern_typename, "\n");
277         GRAPH_DEBUG("Nodes:", *n1, *n2);
278         return false;
279       }
280     }
281   } else {
282     if (n1->kind() != n2->kind() ||
283         n1->outputs().size() != n2->outputs().size() ||
284         n1->inputs().size() != n2->inputs().size()) {
285       GRAPH_DEBUG(
286           "Nodes did not match in their kind or number of inputs/outputs:\n",
287           *n1,
288           *n2);
289       return false;
290     }
291     if (!matchAttributes(n1, n2)) {
292       return false;
293     }
294   }
295 
296   // Add nodes to the map before calling matchValues to avoid infinite
297   // recursion.
298   nodes_map_[n1] = n2;
299   for (const auto i : c10::irange(n1->outputs().size())) {
300     if (!matchValues(n1->outputs()[i], n2->outputs()[i])) {
301       return false;
302     }
303   }
304   for (const auto i : c10::irange(n1->inputs().size())) {
305     if (!matchValues(n1->inputs()[i], n2->inputs()[i])) {
306       return false;
307     }
308   }
309 
310   GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
311   return true;
312 }
313 
314 /**
315  * Recursively try to match pattern with the actual graph starting from the
316  * exiting node in the pattern and anchor node in the actual graph.
317  */
matchesSubgraphFromAnchorNode(Node * anchor)318 bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
319   GRAPH_UPDATE("Starting match from a new anchor: ", *anchor);
320   nodes_map_.clear();
321   values_map_.clear();
322   anchor_ = anchor;
323 
324   const Node* bottom_node = *(pattern_.nodes().end());
325   bottom_node = bottom_node->input(0)->node();
326 
327   if (!matchNodes(bottom_node, anchor)) {
328     return false;
329   }
330 
331   for (const Value* output : pattern_.outputs()) {
332     AT_ASSERT(values_map_.count(output));
333   }
334 
335   GRAPH_UPDATE("Pattern matched!\n");
336   return true;
337 }
338 
339 } // unnamed namespace
340 
341 // Main entry point for the subgraph matching.
findPatternMatches(const Graph & pattern,Graph & graph)342 std::vector<Match> findPatternMatches(const Graph& pattern, Graph& graph) {
343   AT_ASSERT(patternGraphIsValid(pattern));
344   GRAPH_DUMP("Pattern graph: ", &pattern);
345   GRAPH_DUMP("Target graph: ", &graph);
346 
347   SubgraphMatcher m(pattern);
348   std::vector<Match> matches;
349   std::stack<Block*> blocks_to_visit;
350 
351   // Iterate over all nodes in the graph (including nodes in subblocks) trying
352   // to match the pattern each node.
353   blocks_to_visit.push(graph.block());
354   while (!blocks_to_visit.empty()) {
355     Block* block = blocks_to_visit.top();
356     blocks_to_visit.pop();
357     for (Node* n : block->nodes()) {
358       if (m.matchesSubgraphFromAnchorNode(n)) {
359         matches.push_back({n, m.nodes_map(), m.values_map()});
360       }
361       for (Block* subblock : n->blocks()) {
362         blocks_to_visit.push(subblock);
363       }
364     }
365   }
366   return matches;
367 }
368 
369 } // namespace torch::jit
370