xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/utils/memory_dag.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/core/jit_type.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/sparse_bitset.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <functional>
#include <memory>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <torch/csrc/Export.h>
16 
17 // Uses a compressed index representation for faster comparisons
18 typedef c10::SparseBitVector<256> MemoryLocations;
19 namespace torch {
20 namespace jit {
21 
22 struct Value;
23 
24 using AliasTypeSet = std::vector<TypePtr>;
25 
26 // `Element` represents a vertex in the points-to graph. It represents
27 // anything that could have an aliasing relationship--mostly IR
28 // `Value`s, but also wildcards or the type inside a container (e.g. `T`
29 // in `List[T]`)
30 struct Element {
31   Element(const Value* value_, unsigned index_);
32   // wildcard constructor
33   explicit Element(unsigned index_);
34 
35   // Index into the owning DAG's bit vector that represents this element.
36   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
37   unsigned index;
38 
39   // All elements that this element *may* point to. It's possible to have
40   // multiple elements that you might point to due to control flow/complex ops
41   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
42   MemoryLocations pointsTo;
43   // Backreference for points-to.
44   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
45   MemoryLocations pointedFrom;
46 
47   // Elements can contain other elements (e.g. List[Tensor])
48   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
49   MemoryLocations containedElements;
50 
51   // The values that this element corresponds to. May be empty if this element
52   // doesn't represent a first-class value.
53   // This is for debug information only.
54   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
55   std::unordered_set<const Value*> values;
56 
57  private:
58   // Make `from` point at `to`.
59   void makePointerTo(Element* from, Element* to);
60 
61   friend class MemoryDAG;
62   // We memoize the results of `getMemoryLocations` to speed up queries.
63   // A nullopt means that this cache is not yet populated. Since `MemoryDAG` is
64   // immutable, this cache should never need to be invalidated.
65   mutable std::optional<MemoryLocations> cachedMemoryLocations_;
66 
67   mutable std::optional<MemoryLocations> cachedAllContainedMemoryLocations_;
68 };
69 
70 // class MemoryDAG
71 //
72 // This class tracks the "A points to B" graph for all values. It is used by
73 // AliasDb to provide a higher-level API.
74 //
75 // We maintain a DAG where:
76 //   - Vertices (called "Elements") represent Values and
77 //     other aliasing entities (e.g. the stuff inside a list)
78 //   - Edges represent a "points-to" relationship.
79 //
80 // Leaves in this DAG are entities that don't point to anything, and thus
81 // correspond to unique "memory locations".
82 //
83 // So, by traversing the "points-to" graph to the leaves, you can determine
84 // which memory locations an element may point to.
85 class TORCH_API MemoryDAG {
86  public:
MemoryDAG(std::vector<std::unique_ptr<Element>> indexToElementMap)87   explicit MemoryDAG(std::vector<std::unique_ptr<Element>> indexToElementMap)
88       : indexToElementMap_(std::move(indexToElementMap)) {}
89   // explicitly delete copy constructor because otherwise windows build is
90   // confused for an exported class see
91   // https://stackoverflow.com/a/51033485/105137
92   MemoryDAG(const MemoryDAG&) = delete;
93   MemoryDAG& operator=(const MemoryDAG&) = delete;
94 
95   // Return the unique memory locations that `Element` might represent.
96   const MemoryLocations& getMemoryLocations(const Element* e) const;
97 
98   // Do `a` and `b` potentially share a memory location?
99   bool mayAlias(const Element* a, const Element* b) const;
100 
101   // Does `a` hold reference to any memory that is stored in `b`, or vice versa?
102   bool mayContainAlias(const Element* a, const Element* b) const;
103 
104   bool mayContainAlias(const Element* a, const at::ArrayRef<Element*> b) const;
105 
106   bool mayContainAlias(
107       const at::ArrayRef<Element*> a,
108       const at::ArrayRef<Element*> b) const;
109 
110   // Converts from the compressed index representation
111   const Element* fromIndex(unsigned x) const;
112   Element* fromIndex(unsigned x);
113   void collectAllContainedMemoryLocations(
114       const Element* elem,
115       MemoryLocations& cont) const;
116 
117   /**
118    * The following methods are special cases where we need to mutate the
119    * internals of MemoryDAG for efficiency reasons. Don't call them unless you
120    * know what you're doing! In particular, don't add new mutating methods
121    * without ensuring that you are maintaining cache consistency for memory
122    * locations.
123    */
124 
125   // Adding wildcards can trigger extremely expensive cache invalidations. This
126   // method adds them in a more efficient cache-aware way.
127   void setWildcards(
128       const std::unordered_set<const Value*>& wildcards,
129       const ska::flat_hash_map<const Value*, Element*>& elementMap,
130       const std::function<Element*(const Value*)>& getWildcardElement);
131   Element* unsafeMakeFreshValue(const Value* v);
132 
133  private:
134   const MemoryLocations& getAllContainedMemoryLocations(
135       const Element* elem) const;
136   void collectAllContainedMemoryLocationsImpl(
137       const Element* elem,
138       MemoryLocations& cont) const;
139   std::vector<std::unique_ptr<Element>> indexToElementMap_;
140 };
141 
142 /**
143  * Helper to build up the points-to graph.
144  *
145  * We separate the "building" into a different class because it allows us to
146  * cache internally to MemoryDAG without worrying about how the DAG structure
147  * is mutated.
148  */
149 class TORCH_API MemoryDAGBuilder {
150  public:
151   MemoryDAGBuilder() = default;
152   MemoryDAGBuilder(const MemoryDAGBuilder&) = delete;
153   MemoryDAGBuilder& operator=(const MemoryDAGBuilder&) = delete;
154 
155   // Make `from` point at `to`.
156   void makePointerTo(Element* from, Element* to);
157 
158   void addToContainedElements(Element* contained, Element* container);
159 
createMemoryDAG()160   std::unique_ptr<MemoryDAG> createMemoryDAG() && {
161     return std::make_unique<MemoryDAG>(std::move(indexToElementMap_));
162   }
163 
164   // Make a fresh Element (i.e. an Element that doesn't point to anything) and
165   // return it.
166   Element* makeFreshValue(const Value* v);
167 
168   friend MemoryDAG;
169 
170  private:
171   // `MemoryDAGBuilder` builds up `indexToElementMap_`, then uses
172   // the map to construct the `MemoryDAG`
173   std::vector<std::unique_ptr<Element>> indexToElementMap_;
174 };
175 } // namespace jit
176 } // namespace torch
177