#include <torch/csrc/jit/passes/utils/memory_dag.h>

#include <c10/util/flat_hash_map.h>
#include <algorithm>
#include <queue>

namespace torch {
namespace jit {
namespace {

void makePointerToImpl(Element* from, Element* to) {
  from->pointsTo.set(to->index);
  to->pointedFrom.set(from->index);
}
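
// Note on the representation: a pointer edge is recorded on both endpoints as
// sparse bitsets of element indices (`pointsTo` on the source, `pointedFrom`
// on the destination), so the points-to graph can be walked cheaply in either
// direction.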

Element* makeFreshValueImpl(
    const Value* v,
    std::vector<std::unique_ptr<Element>>& indexToElementMap_) {
  if (v == nullptr) {
    // Create a wildcard element, with no corresponding value
    indexToElementMap_.emplace_back(
        std::make_unique<Element>(indexToElementMap_.size()));
    return indexToElementMap_.back().get();
  }
  indexToElementMap_.emplace_back(
      std::make_unique<Element>(v, indexToElementMap_.size()));
  return indexToElementMap_.back().get();
}
} // namespace

Element::Element(const Value* value_, unsigned index_)
    : index(index_), values({value_}) {}
Element::Element(unsigned index_) : index(index_), values({}) {}

const Element* MemoryDAG::fromIndex(unsigned x) const {
  TORCH_INTERNAL_ASSERT(x < indexToElementMap_.size());
  return indexToElementMap_[x].get();
}

Element* MemoryDAG::fromIndex(unsigned x) {
  TORCH_INTERNAL_ASSERT(x < indexToElementMap_.size());
  return indexToElementMap_[x].get();
}

bool MemoryDAG::mayAlias(const Element* a, const Element* b) const {
  const auto& aMemLoc = getMemoryLocations(a);
  const auto& bMemLoc = getMemoryLocations(b);

  return aMemLoc.intersects(bMemLoc);
}
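
// A minimal usage sketch (illustrative only; `builder` is a hypothetical
// MemoryDAGBuilder and `dag` the immutable MemoryDAG it is eventually
// consumed into, e.g. by AliasDb):
//
//   Element* a = builder->makeFreshValue(valueA);
//   Element* b = builder->makeFreshValue(valueB);
//   builder->makePointerTo(a, b);
//   // The memory locations of `a` now include `b`'s location, so
//   // dag->mayAlias(a, b) returns true, while a third, unrelated fresh
//   // value would not alias either of them.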

bool MemoryDAG::mayContainAlias(const Element* a, const Element* b) const {
  return getAllContainedMemoryLocations(a).intersects(
      getAllContainedMemoryLocations(b));
}

const MemoryLocations& MemoryDAG::getAllContainedMemoryLocations(
    const Element* elem) const {
  if (C10_UNLIKELY(!elem->cachedAllContainedMemoryLocations_.has_value())) {
    elem->cachedAllContainedMemoryLocations_ = MemoryLocations();
    collectAllContainedMemoryLocationsImpl(
        elem, *elem->cachedAllContainedMemoryLocations_);
  }
  return *elem->cachedAllContainedMemoryLocations_;
}
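
// Contained memory locations differ from direct memory locations: a container
// element (e.g. one representing a List[Tensor]) transitively pulls in the
// locations of everything registered via addToContainedElements. So for a
// list `l` holding a tensor `t`, mayContainAlias(l, t) can be true even when
// mayAlias(l, t) is false, because `l` and `t` are distinct memory locations
// but `t`'s location is contained in `l`.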
67 
collectAllContainedMemoryLocations(const Element * elem,MemoryLocations & cont) const68 void MemoryDAG::collectAllContainedMemoryLocations(
69     const Element* elem,
70     MemoryLocations& cont) const {
71   // we have already recursed on this element
72   unsigned compIdx = elem->index;
73   if (cont.test(compIdx)) {
74     return;
75   }
76 
77   if (C10_UNLIKELY(!elem->cachedAllContainedMemoryLocations_.has_value())) {
78     MemoryLocations cache;
79     collectAllContainedMemoryLocationsImpl(elem, cache);
80     elem->cachedAllContainedMemoryLocations_ = std::move(cache);
81   }
82   cont |= *elem->cachedAllContainedMemoryLocations_;
83 }

void MemoryDAG::collectAllContainedMemoryLocationsImpl(
    const Element* elem,
    MemoryLocations& cont) const {
  unsigned compIdx = elem->index;
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!cont.test(compIdx));
  cont.set(compIdx);

  for (const auto& mem_loc : getMemoryLocations(elem)) {
    collectAllContainedMemoryLocations(fromIndex(mem_loc), cont);
  }

  for (const auto& contained : elem->containedElements) {
    collectAllContainedMemoryLocations(fromIndex(contained), cont);
  }
}

bool MemoryDAG::mayContainAlias(
    const Element* a,
    const at::ArrayRef<Element*> b) const {
  if (b.empty()) {
    return false;
  }

  const auto& a_contained = getAllContainedMemoryLocations(a);
  return std::any_of(b.begin(), b.end(), [this, &a_contained](Element* b_elem) {
    return a_contained.intersects(this->getAllContainedMemoryLocations(b_elem));
  });
}

bool MemoryDAG::mayContainAlias(
    const at::ArrayRef<Element*> a,
    const at::ArrayRef<Element*> b) const {
  if (a.empty() || b.empty()) {
    return false;
  }

  MemoryLocations all_a_mlocs;
  for (const auto& elem : a) {
    collectAllContainedMemoryLocations(elem, all_a_mlocs);
  }

  MemoryLocations all_b_mlocs;
  for (const auto& elem : b) {
    collectAllContainedMemoryLocations(elem, all_b_mlocs);
  }

  return all_a_mlocs.intersects(all_b_mlocs);
}

void MemoryDAGBuilder::makePointerTo(Element* from, Element* to) {
  makePointerToImpl(from, to);
}

void MemoryDAGBuilder::addToContainedElements(
    Element* elem,
    Element* container) {
  TORCH_INTERNAL_ASSERT(
      elem != container, "Elements cannot contain themselves");
  container->containedElements.set(elem->index);
}
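
// Sketch of how containment is typically registered (illustrative only;
// `listElem` and `tensorElem` are hypothetical elements created elsewhere):
//
//   // A list value contains the tensors stored in it, so queries like
//   // mayContainAlias(listElem, tensorElem) account for the tensor's
//   // memory locations as well.
//   builder->addToContainedElements(tensorElem, listElem);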

// Give `v` a fresh alias (i.e. it does not point to any value)
Element* MemoryDAGBuilder::makeFreshValue(const Value* v) {
  return makeFreshValueImpl(v, indexToElementMap_);
}

// This function builds up a bitset representing the "alias set" for
// `e` (`MemoryLocations` is just a typedef'd c10::SparseBitVector).
const MemoryLocations& MemoryDAG::getMemoryLocations(const Element* e) const {
  // Note on cache invalidation: all mutation should occur through
  // MemoryDAGBuilder. Thus, once we consume the builder to create an
  // immutable MemoryDAG, we can cache here without worrying that we
  // might potentially get invalidated.
  if (e->cachedMemoryLocations_) {
    return *e->cachedMemoryLocations_;
  }

  MemoryLocations ret;
  if (e->pointsTo.empty()) {
    // Base case: if we don't point to anything, this element is a memory
    // location. Return itself.
    ret.set(e->index);
  } else {
    for (auto el : e->pointsTo) {
      ret |= getMemoryLocations(fromIndex(el));
    }
  }

  e->cachedMemoryLocations_ = std::move(ret);
  return *e->cachedMemoryLocations_;
}
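
// A worked example of the recursion above (illustrative only): with pointer
// edges x -> y and y -> z, where z points to nothing, z is the sole memory
// location, so getMemoryLocations(x), getMemoryLocations(y), and
// getMemoryLocations(z) all reduce to the bitset {z.index}.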

void MemoryDAG::setWildcards(
    const std::unordered_set<const Value*>& wildcards,
    const ska::flat_hash_map<const Value*, Element*>& elementMap,
    const std::function<Element*(const Value*)>& getWildcardElement) {
  std::unordered_map<Element*, MemoryLocations> cacheUpdates;
  // If an element is set as a wildcard, that means that all its memory
  // locations must point to the wildcard element.
  for (const Value* v : wildcards) {
    auto wildcardElement = getWildcardElement(v);
    TORCH_INTERNAL_ASSERT(wildcardElement);

    const MemoryLocations& pointeeSet = getMemoryLocations(elementMap.at(v));
    for (const auto& pointee : pointeeSet) {
      auto from = this->fromIndex(pointee);
      // avoid cycles where the wildcard points to itself
      if (from != wildcardElement) {
        makePointerToImpl(from, wildcardElement);
      }
    }
    // Track which memory locations we edited with a new pointer to the
    // wildcard element.
    cacheUpdates[wildcardElement] |= pointeeSet;
  }

  // Update caches in-place.
  // We take advantage of the fact that we only edited memory locations.
  //
  // Say we added a pointer from `MemoryLocationFoo -> WildcardBar`.
  // For every element, if the cache contains `MemoryLocationFoo`, then we must
  // add `WildcardBar` to it.
  for (const std::unique_ptr<Element>& e : this->indexToElementMap_) {
    e->cachedAllContainedMemoryLocations_.reset();
    if (e->values.empty()) {
      // This element is a wildcard element, we can skip it.
      continue;
    }

    auto wildcardElement = getWildcardElement(*(e->values.begin()));
    if (!wildcardElement) {
      // This value is not a wildcard.
      continue;
    }
    auto it = cacheUpdates.find(wildcardElement);
    if (it == cacheUpdates.end()) {
      // We didn't rewrite any MemoryLocations to point to this element.
      continue;
    }
    // If this element contains an edited memory location, update the cache to
    // contain the pointed-to wildcard element as well.
    if (getMemoryLocations(e.get()).intersects(it->second)) {
      e->cachedMemoryLocations_->set(wildcardElement->index);
    }
  }
}
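
// For intuition (illustrative only): if the analysis cannot track where a
// value `v` may point (for example, `v` flows through an op whose aliasing
// behavior is unknown), `v` is registered as a wildcard. Every memory
// location reachable from `v` then gains a pointer edge to the shared
// wildcard element, which conservatively makes `v` alias anything else
// tainted by the same wildcard.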

Element* MemoryDAG::unsafeMakeFreshValue(const Value* v) {
  return makeFreshValueImpl(v, indexToElementMap_);
}
} // namespace jit
} // namespace torch