xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/alias_info.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <unordered_set>
3 #include <vector>
4 #include <ATen/core/symbol.h>
5 #include <c10/util/Exception.h>
6 #include <c10/util/hash.h>
7 
8 namespace c10 {
9 /**
10  * class AliasInfo
11  *
12  * Data structure to hold aliasing information for an `Argument`. They can be
13  * nested to represent aliasing information on contained types.
14  *
15  * There is a `beforeSet` which describes the aliasing information before the
16  * operator executes, and an `afterSet` that describes aliasing info
17  * after execution.
18  */
19 class AliasInfo {
20  public:
21   // Symbol for the set that can alias anything
wildcardSet()22   static Symbol wildcardSet() {
23     static const Symbol wc = Symbol::fromQualString("alias::*");
24     return wc;
25   }
26 
setIsWrite(bool isWrite)27   void setIsWrite(bool isWrite) {
28     isWrite_ = isWrite;
29   }
30 
isWrite()31   bool isWrite() const {
32     return isWrite_;
33   }
34 
addBeforeSet(Symbol aliasSet)35   void addBeforeSet(Symbol aliasSet) {
36     beforeSets_.insert(aliasSet);
37   }
38 
addAfterSet(Symbol aliasSet)39   void addAfterSet(Symbol aliasSet) {
40     afterSets_.insert(aliasSet);
41   }
42 
beforeSets()43   const std::unordered_set<Symbol>& beforeSets() const {
44     return beforeSets_;
45   }
46 
afterSets()47   const std::unordered_set<Symbol>& afterSets() const {
48     return afterSets_;
49   }
50 
beforeSet()51   Symbol beforeSet() const {
52     AT_ASSERT(beforeSets_.size() == 1);
53     return *beforeSets_.begin();
54   }
55 
isWildcardBefore()56   bool isWildcardBefore() const {
57     return beforeSets_.count(wildcardSet()) != 0;
58   }
59 
isWildcardAfter()60   bool isWildcardAfter() const {
61     return afterSets_.count(wildcardSet()) != 0;
62   }
63 
64   // the alias info for the contained types of the type
65   // e.g. if this is an annotation on List[T], `sets` refers to
66   // the alias sets that the list may be in
67   // while containedTypes()[0] refers to the sets that members of the list
68   // may be in
addContainedType(AliasInfo aliasInfo)69   void addContainedType(AliasInfo aliasInfo) {
70     containedTypes_.push_back(std::move(aliasInfo));
71   }
containedTypes()72   const std::vector<AliasInfo>& containedTypes() const {
73     return containedTypes_;
74   }
75 
76  private:
77   std::unordered_set<Symbol> beforeSets_;
78   std::unordered_set<Symbol> afterSets_;
79   std::vector<AliasInfo> containedTypes_;
80   bool isWrite_ = false;
81 };
82 
83 inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
84   return lhs.isWrite() == rhs.isWrite()
85       && lhs.beforeSets() == rhs.beforeSets()
86       && lhs.afterSets() == rhs.afterSets()
87       && lhs.containedTypes() == rhs.containedTypes();
88 }
89 
90 // this does match the way things are represented in the schema
91 inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
92   out << "(";
93   bool first = true;
94   for (const auto& set : aliasInfo.beforeSets()) {
95     if (first) {
96       first = false;
97     } else {
98       out << "|";
99     }
100     out << set.toUnqualString();
101   }
102   if (aliasInfo.isWrite()) {
103     out << "!";
104   }
105   if (aliasInfo.beforeSets() != aliasInfo.afterSets()) {
106     out << " -> ";
107     first = true;
108     for (const auto& set : aliasInfo.afterSets()) {
109       if (first) {
110         first = false;
111       } else {
112         out << "|";
113       }
114       out << set.toUnqualString();
115     }
116   }
117   out << ")";
118   return out;
119 }
120 } // namespace c10
121 
122 namespace std {
123 template <>
124   struct hash<c10::AliasInfo> {
125     size_t operator()(const c10::AliasInfo& aliasInfo) const {
126       auto hash = std::hash<bool>()(aliasInfo.isWrite());
127 
128       // NOTE: for unordered_set hashes, we couldn't use hash_combine
129       // because hash_combine is order dependent. Instead, we choose to
130       // use XOR as the combining function as XOR is commutative.
131       size_t before_set_hash_seed = 0;
132       for (auto &e: aliasInfo.beforeSets()) {
133         auto symbol_hash = std::hash<c10::Symbol>()(e);
134         before_set_hash_seed = before_set_hash_seed ^ symbol_hash;
135       }
136       size_t after_set_hash_seed = 0;
137       for (auto &e: aliasInfo.afterSets()) {
138         auto symbol_hash = std::hash<c10::Symbol>()(e);
139         after_set_hash_seed = after_set_hash_seed ^ symbol_hash;
140       }
141 
142       hash = c10::hash_combine(hash, before_set_hash_seed);
143       hash = c10::hash_combine(hash, after_set_hash_seed);
144       for (auto &e: aliasInfo.containedTypes()) {
145         auto contained_type_hash = std::hash<c10::AliasInfo>()(e);
146         hash = c10::hash_combine(hash, contained_type_hash);
147       }
148       return hash;
149     }
150   };
151 }
152