1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 5 #include <unordered_map> 6 #include <vector> 7 8 namespace torch::jit { 9 10 /** 11 * \brief A structure describing a match of a pattern in a graph. 12 * 13 * The structure contains an anchor node, from which the match was found, and 14 * match-maps for nodes and values. A match-map specifies the correspondance 15 * between nodes in the pattern graph (match-map keys) with nodes in the actual 16 * graph (match-map values). We keep such maps for both nodes and values. 17 */ 18 struct Match { 19 Node* anchor; 20 std::unordered_map<const Node*, Node*> nodes_map; 21 std::unordered_map<const Value*, Value*> values_map; 22 }; 23 24 /** 25 * \brief Find all matches of a \p PATTERN in a \p GRAPH. 26 * 27 * The function returns a vector of match-descriptors (see description of 28 * `struct Match`). 29 * 30 * Matching rules: 31 * - Pattern graph must contain a single block. 32 * - Matched subgraphs do not span across different blocks. 33 * - No uses outside the match are allowed, except for Param and Return nodes. 34 * Basically, we're matching hammocks, not arbitrary subgraphs. 35 * - The pattern graph must return only one value (i.e. it must have a single 36 * node leading to return). 37 * - Nodes that are not used in computation of the return value in the pattern 38 * graph are ignored during matching (IOW, we're essentially performing DCE on 39 * the pattern). 40 * - Pattern graph nodes cannot alias. TODO: the check not implemented yet. 41 * - Aliasing nodes in the graph cannot consitute a match (i.e. through all 42 * found matches, no nodes in the subgraph alias with each other). TODO: check 43 * not implemented yet. 44 * - The matcher will not mutate either the pattern graph or the matched graph. 45 * The matched graph is taken as non-const so that Match may contain non-const 46 * pointers. This enables clients of this API to use Match to drive mutations. 47 * 48 * Note [Multi-output Patterns] 49 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 50 * Subgraph matcher provides limited support for multi-output patterns. With a 51 * single output pattern, a single scan through the graph is sufficient to 52 * find all the matches: given a starting node (an "anchor"), we can 53 * deterministically check whether a pattern matches a subgraph corresponding to 54 * this anchor node. For a general case of multi-output patterns, we would have 55 * N anchors, which would result in M^N comparisons (M is the size of the 56 * graph). Clearly this is computationally prohibitive. 57 * 58 * To overcome this, we impose some constraints on the multi-output patterns 59 * that we accept. We require that checking whether the pattern matches a 60 * subgraph would still be fully determined by a single node in the graph. To 61 * achieve this, we designate the first output in the pattern as the "main" 62 * output and assume that we can traverse up from this node to match the 63 * entire pattern. 64 * 65 * Corrolary 1: the order of outputs in the pattern matters! 66 * Corollary 2: patterns cannot contain any nodes not participating in the main 67 * output computation. 68 */ 69 std::vector<Match> TORCH_API 70 findPatternMatches(const Graph& pattern, Graph& graph); 71 72 } // namespace torch::jit 73