1 #include <torch/csrc/jit/passes/utils/memory_dag.h>
2
3 #include <c10/util/flat_hash_map.h>
4 #include <algorithm>
5 #include <queue>
6
7 namespace torch {
8 namespace jit {
9 namespace {
10
makePointerToImpl(Element * from,Element * to)11 void makePointerToImpl(Element* from, Element* to) {
12 from->pointsTo.set(to->index);
13 to->pointedFrom.set(from->index);
14 }
15
makeFreshValueImpl(const Value * v,std::vector<std::unique_ptr<Element>> & indexToElementMap_)16 Element* makeFreshValueImpl(
17 const Value* v,
18 std::vector<std::unique_ptr<Element>>& indexToElementMap_) {
19 if (v == nullptr) {
20 // Create a wildcard element, with no corresponding value
21 indexToElementMap_.emplace_back(
22 std::make_unique<Element>(indexToElementMap_.size()));
23 return indexToElementMap_.back().get();
24 }
25 indexToElementMap_.emplace_back(
26 std::make_unique<Element>(v, indexToElementMap_.size()));
27 return indexToElementMap_.back().get();
28 }
29 } // namespace
30
Element(const Value * value_,unsigned index_)31 Element::Element(const Value* value_, unsigned index_)
32 : index(index_), values({value_}) {}
Element(unsigned index_)33 Element::Element(unsigned index_) : index(index_), values({}) {}
34
fromIndex(unsigned x) const35 const Element* MemoryDAG::fromIndex(unsigned x) const {
36 TORCH_INTERNAL_ASSERT(x < indexToElementMap_.size());
37 return indexToElementMap_[x].get();
38 }
39
fromIndex(unsigned x)40 Element* MemoryDAG::fromIndex(unsigned x) {
41 TORCH_INTERNAL_ASSERT(x < indexToElementMap_.size());
42 return indexToElementMap_[x].get();
43 }
44
mayAlias(const Element * a,const Element * b) const45 bool MemoryDAG::mayAlias(const Element* a, const Element* b) const {
46 const auto& aMemLoc = getMemoryLocations(a);
47 const auto& bMemLoc = getMemoryLocations(b);
48
49 return aMemLoc.intersects(bMemLoc);
50 }
51
mayContainAlias(const Element * a,const Element * b) const52 bool MemoryDAG::mayContainAlias(const Element* a, const Element* b) const {
53 return getAllContainedMemoryLocations(a).intersects(
54 getAllContainedMemoryLocations(b));
55 }
56
getAllContainedMemoryLocations(const Element * elem) const57 const MemoryLocations& MemoryDAG::getAllContainedMemoryLocations(
58 const Element* elem) const {
59 if (C10_UNLIKELY(!elem->cachedAllContainedMemoryLocations_.has_value())) {
60 MemoryLocations cache;
61 elem->cachedAllContainedMemoryLocations_ = MemoryLocations();
62 collectAllContainedMemoryLocationsImpl(
63 elem, *elem->cachedAllContainedMemoryLocations_);
64 }
65 return *elem->cachedAllContainedMemoryLocations_;
66 }
67
collectAllContainedMemoryLocations(const Element * elem,MemoryLocations & cont) const68 void MemoryDAG::collectAllContainedMemoryLocations(
69 const Element* elem,
70 MemoryLocations& cont) const {
71 // we have already recursed on this element
72 unsigned compIdx = elem->index;
73 if (cont.test(compIdx)) {
74 return;
75 }
76
77 if (C10_UNLIKELY(!elem->cachedAllContainedMemoryLocations_.has_value())) {
78 MemoryLocations cache;
79 collectAllContainedMemoryLocationsImpl(elem, cache);
80 elem->cachedAllContainedMemoryLocations_ = std::move(cache);
81 }
82 cont |= *elem->cachedAllContainedMemoryLocations_;
83 }
84
collectAllContainedMemoryLocationsImpl(const Element * elem,MemoryLocations & cont) const85 void MemoryDAG::collectAllContainedMemoryLocationsImpl(
86 const Element* elem,
87 MemoryLocations& cont) const {
88 unsigned compIdx = elem->index;
89 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!cont.test(compIdx));
90 cont.set(compIdx);
91
92 for (const auto& mem_loc : getMemoryLocations(elem)) {
93 collectAllContainedMemoryLocations(fromIndex(mem_loc), cont);
94 }
95
96 for (const auto& contained : elem->containedElements) {
97 collectAllContainedMemoryLocations(fromIndex(contained), cont);
98 }
99 }
100
mayContainAlias(const Element * a,const at::ArrayRef<Element * > b) const101 bool MemoryDAG::mayContainAlias(
102 const Element* a,
103 const at::ArrayRef<Element*> b) const {
104 if (b.empty()) {
105 return false;
106 }
107
108 const auto& a_contained = getAllContainedMemoryLocations(a);
109 return std::any_of(b.begin(), b.end(), [this, &a_contained](Element* b_elem) {
110 return a_contained.intersects(this->getAllContainedMemoryLocations(b_elem));
111 });
112 }
113
mayContainAlias(const at::ArrayRef<Element * > a,const at::ArrayRef<Element * > b) const114 bool MemoryDAG::mayContainAlias(
115 const at::ArrayRef<Element*> a,
116 const at::ArrayRef<Element*> b) const {
117 if (a.empty() || b.empty()) {
118 return false;
119 }
120
121 MemoryLocations all_a_mlocs;
122 for (const auto& elem : a) {
123 collectAllContainedMemoryLocations(elem, all_a_mlocs);
124 }
125
126 MemoryLocations all_b_mlocs;
127 for (const auto& elem : b) {
128 collectAllContainedMemoryLocations(elem, all_b_mlocs);
129 }
130
131 return all_a_mlocs.intersects(all_b_mlocs);
132 }
133
makePointerTo(Element * from,Element * to)134 void MemoryDAGBuilder::makePointerTo(Element* from, Element* to) {
135 makePointerToImpl(from, to);
136 }
137
addToContainedElements(Element * elem,Element * container)138 void MemoryDAGBuilder::addToContainedElements(
139 Element* elem,
140 Element* container) {
141 TORCH_INTERNAL_ASSERT(
142 elem != container, "Elements cannot contain themselves");
143 container->containedElements.set(elem->index);
144 }
145
146 // Give `v` a fresh alias (i.e. it does not point to any value)
makeFreshValue(const Value * v)147 Element* MemoryDAGBuilder::makeFreshValue(const Value* v) {
148 return makeFreshValueImpl(v, indexToElementMap_);
149 }
150
151 // This function builds up a bitset representing the "alias set" for
152 // `e` (`MemoryLocations` is just a typedef'd c10::SparseBitVector).
getMemoryLocations(const Element * e) const153 const MemoryLocations& MemoryDAG::getMemoryLocations(const Element* e) const {
154 // Note on cache invalidation: all mutation should occur through
155 // MemoryDAGBuilder. Thus, once we consume the builder to create an
156 // immutable MemoryDAG, we can cache here without worrying that we
157 // might potentially get invalidated.
158 if (e->cachedMemoryLocations_) {
159 return *e->cachedMemoryLocations_;
160 }
161
162 MemoryLocations ret;
163 if (e->pointsTo.empty()) {
164 // Base case: if we don't point to anything, this element is a memory
165 // location. Return itself.
166 ret.set(e->index);
167 } else {
168 for (auto el : e->pointsTo) {
169 ret |= getMemoryLocations(fromIndex(el));
170 }
171 }
172
173 e->cachedMemoryLocations_ = std::move(ret);
174 return *e->cachedMemoryLocations_;
175 }
176
setWildcards(const std::unordered_set<const Value * > & wildcards,const ska::flat_hash_map<const Value *,Element * > & elementMap,const std::function<Element * (const Value *)> & getWildcardElement)177 void MemoryDAG::setWildcards(
178 const std::unordered_set<const Value*>& wildcards,
179 const ska::flat_hash_map<const Value*, Element*>& elementMap,
180 const std::function<Element*(const Value*)>& getWildcardElement) {
181 std::unordered_map<Element*, MemoryLocations> cacheUpdates;
182 // If an element is set as a wildcard, that means that all its memory
183 // locations must point to the wildcard element.
184 for (const Value* v : wildcards) {
185 auto wildcardElement = getWildcardElement(v);
186 TORCH_INTERNAL_ASSERT(wildcardElement);
187
188 const MemoryLocations& pointeeSet = getMemoryLocations(elementMap.at(v));
189 for (const auto& pointee : pointeeSet) {
190 auto from = this->fromIndex(pointee);
191 // avoid cycles where the wildcard points to itself
192 if (from != wildcardElement) {
193 makePointerToImpl(from, wildcardElement);
194 }
195 }
196 // Track which memory locations we edited with a new pointer to the wildcard
197 // element.
198 cacheUpdates[wildcardElement] |= pointeeSet;
199 }
200
201 // Update caches in-place.
202 // We take advantage of the fact that we only edited memory locations.
203 //
204 // Say we added a pointer from `MemoryLocationFoo -> WildcardBar`.
205 // For every element, if the cache contains `MemoryLocationFoo`, then we must
206 // add `WildcardBar` to it.
207 for (const std::unique_ptr<Element>& e : this->indexToElementMap_) {
208 e->cachedAllContainedMemoryLocations_.reset();
209 if (e->values.empty()) {
210 // This element is a wildcard element, we can skip it.
211 continue;
212 }
213
214 auto wildcardElement = getWildcardElement(*(e->values.begin()));
215 if (!wildcardElement) {
216 // This value is not a wildcard.
217 continue;
218 }
219 auto it = cacheUpdates.find(wildcardElement);
220 if (it == cacheUpdates.end()) {
221 // We didn't rewrite any MemoryLocations to point to this element.
222 continue;
223 }
224 // If this element contains an edited memory location, update the cache to
225 // contain the pointed-to wildcard element as well.
226 if (getMemoryLocations(e.get()).intersects(it->second)) {
227 e->cachedMemoryLocations_->set(wildcardElement->index);
228 }
229 }
230 }
231
unsafeMakeFreshValue(const Value * v)232 Element* MemoryDAG::unsafeMakeFreshValue(const Value* v) {
233 return makeFreshValueImpl(v, indexToElementMap_);
234 }
235 } // namespace jit
236 } // namespace torch
237