xref: /aosp_15_r20/external/stg/unification.h (revision 9e3b08ae94a55201065475453d799e8b1378bea6)
1 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2 // -*- mode: C++ -*-
3 //
4 // Copyright 2022-2023 Google LLC
5 //
6 // Licensed under the Apache License v2.0 with LLVM Exceptions (the
7 // "License"); you may not use this file except in compliance with the
8 // License.  You may obtain a copy of the License at
9 //
10 //     https://llvm.org/LICENSE.txt
11 //
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 //
18 // Author: Giuliano Procida
19 
20 #ifndef STG_UNIFICATION_H_
21 #define STG_UNIFICATION_H_
22 
23 #include <exception>
24 
25 #include "graph.h"
26 #include "runtime.h"
27 #include "substitution.h"
28 
29 namespace stg {
30 
31 // Keep track of which nodes are pending substitution and rewrite the graph on
32 // destruction.
33 class Unification {
34  public:
Unification(Runtime & runtime,Graph & graph,Id start)35   Unification(Runtime& runtime, Graph& graph, Id start)
36       : graph_(graph),
37         start_(start),
38         mapping_(start),
39         runtime_(runtime),
40         find_query_(runtime, "unification.find_query"),
41         find_halved_(runtime, "unification.find_halved"),
42         union_known_(runtime, "unification.union_known"),
43         union_unknown_(runtime, "unification.union_unknown") {}
44 
~Unification()45   ~Unification() {
46     if (std::uncaught_exceptions() > 0) {
47       // abort unification
48       return;
49     }
50     // apply substitutions to the entire graph
51     const Time time(runtime_, "unification.rewrite");
52     Counter removed(runtime_, "unification.removed");
53     Counter retained(runtime_, "unification.retained");
54     const auto remap = [&](Id& id) {
55       Update(id);
56     };
57     const Substitute substitute(graph_, remap);
58     graph_.ForEach(start_, graph_.Limit(), [&](Id id) {
59       if (Find(id) != id) {
60         graph_.Remove(id);
61         ++removed;
62       } else {
63         substitute(id);
64         ++retained;
65       }
66     });
67   }
68 
Reserve(Id limit)69   void Reserve(Id limit) {
70     mapping_.Reserve(limit);
71   }
72 
73   bool Unify(Id id1, Id id2);
74 
Find(Id id)75   Id Find(Id id) {
76     ++find_query_;
77     // path halving - tiny performance gain
78     while (true) {
79       // note: safe to take a reference as mapping cannot grow after this
80       auto& parent = mapping_[id];
81       if (parent == id) {
82         return id;
83       }
84       const auto parent_parent = mapping_[parent];
85       if (parent_parent == parent) {
86         return parent;
87       }
88       id = parent = parent_parent;
89       ++find_halved_;
90     }
91   }
92 
Union(Id id1,Id id2)93   void Union(Id id1, Id id2) {
94     // id2 will always be preferred as a parent node; interpreted as a
95     // substitution, id1 will be replaced by id2
96     const Id fid1 = Find(id1);
97     const Id fid2 = Find(id2);
98     if (fid1 == fid2) {
99       ++union_known_;
100       return;
101     }
102     mapping_[fid1] = fid2;
103     ++union_unknown_;
104   }
105 
106   // update id to representative id
Update(Id & id)107   void Update(Id& id) {
108     const Id fid = Find(id);
109     // avoid silent stores
110     if (fid != id) {
111       id = fid;
112     }
113   }
114 
115  private:
116   Graph& graph_;
117   Id start_;
118   DenseIdMapping mapping_;
119   Runtime& runtime_;
120   Counter find_query_;
121   Counter find_halved_;
122   Counter union_known_;
123   Counter union_unknown_;
124 };
125 
126 }  // namespace stg
127 
128 #endif  // STG_UNIFICATION_H_
129