1 #pragma once 2 #include <unordered_set> 3 #include <vector> 4 #include <ATen/core/symbol.h> 5 #include <c10/util/Exception.h> 6 #include <c10/util/hash.h> 7 8 namespace c10 { 9 /** 10 * class AliasInfo 11 * 12 * Data structure to hold aliasing information for an `Argument`. They can be 13 * nested to represent aliasing information on contained types. 14 * 15 * There is a `beforeSet` which describes the aliasing information before the 16 * operator executes, and an `afterSet` that describes aliasing info 17 * after execution. 18 */ 19 class AliasInfo { 20 public: 21 // Symbol for the set that can alias anything wildcardSet()22 static Symbol wildcardSet() { 23 static const Symbol wc = Symbol::fromQualString("alias::*"); 24 return wc; 25 } 26 setIsWrite(bool isWrite)27 void setIsWrite(bool isWrite) { 28 isWrite_ = isWrite; 29 } 30 isWrite()31 bool isWrite() const { 32 return isWrite_; 33 } 34 addBeforeSet(Symbol aliasSet)35 void addBeforeSet(Symbol aliasSet) { 36 beforeSets_.insert(aliasSet); 37 } 38 addAfterSet(Symbol aliasSet)39 void addAfterSet(Symbol aliasSet) { 40 afterSets_.insert(aliasSet); 41 } 42 beforeSets()43 const std::unordered_set<Symbol>& beforeSets() const { 44 return beforeSets_; 45 } 46 afterSets()47 const std::unordered_set<Symbol>& afterSets() const { 48 return afterSets_; 49 } 50 beforeSet()51 Symbol beforeSet() const { 52 AT_ASSERT(beforeSets_.size() == 1); 53 return *beforeSets_.begin(); 54 } 55 isWildcardBefore()56 bool isWildcardBefore() const { 57 return beforeSets_.count(wildcardSet()) != 0; 58 } 59 isWildcardAfter()60 bool isWildcardAfter() const { 61 return afterSets_.count(wildcardSet()) != 0; 62 } 63 64 // the alias info for the contained types of the type 65 // e.g. if this is an annotation on List[T], `sets` refers to 66 // the alias sets that the list may be in 67 // while containedTypes()[0] refers to the sets that members of the list 68 // may be in addContainedType(AliasInfo aliasInfo)69 void addContainedType(AliasInfo aliasInfo) { 70 containedTypes_.push_back(std::move(aliasInfo)); 71 } containedTypes()72 const std::vector<AliasInfo>& containedTypes() const { 73 return containedTypes_; 74 } 75 76 private: 77 std::unordered_set<Symbol> beforeSets_; 78 std::unordered_set<Symbol> afterSets_; 79 std::vector<AliasInfo> containedTypes_; 80 bool isWrite_ = false; 81 }; 82 83 inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) { 84 return lhs.isWrite() == rhs.isWrite() 85 && lhs.beforeSets() == rhs.beforeSets() 86 && lhs.afterSets() == rhs.afterSets() 87 && lhs.containedTypes() == rhs.containedTypes(); 88 } 89 90 // this does match the way things are represented in the schema 91 inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) { 92 out << "("; 93 bool first = true; 94 for (const auto& set : aliasInfo.beforeSets()) { 95 if (first) { 96 first = false; 97 } else { 98 out << "|"; 99 } 100 out << set.toUnqualString(); 101 } 102 if (aliasInfo.isWrite()) { 103 out << "!"; 104 } 105 if (aliasInfo.beforeSets() != aliasInfo.afterSets()) { 106 out << " -> "; 107 first = true; 108 for (const auto& set : aliasInfo.afterSets()) { 109 if (first) { 110 first = false; 111 } else { 112 out << "|"; 113 } 114 out << set.toUnqualString(); 115 } 116 } 117 out << ")"; 118 return out; 119 } 120 } // namespace c10 121 122 namespace std { 123 template <> 124 struct hash<c10::AliasInfo> { 125 size_t operator()(const c10::AliasInfo& aliasInfo) const { 126 auto hash = std::hash<bool>()(aliasInfo.isWrite()); 127 128 // NOTE: for unordered_set hashes, we couldn't use hash_combine 129 // because hash_combine is order dependent. Instead, we choose to 130 // use XOR as the combining function as XOR is commutative. 131 size_t before_set_hash_seed = 0; 132 for (auto &e: aliasInfo.beforeSets()) { 133 auto symbol_hash = std::hash<c10::Symbol>()(e); 134 before_set_hash_seed = before_set_hash_seed ^ symbol_hash; 135 } 136 size_t after_set_hash_seed = 0; 137 for (auto &e: aliasInfo.afterSets()) { 138 auto symbol_hash = std::hash<c10::Symbol>()(e); 139 after_set_hash_seed = after_set_hash_seed ^ symbol_hash; 140 } 141 142 hash = c10::hash_combine(hash, before_set_hash_seed); 143 hash = c10::hash_combine(hash, after_set_hash_seed); 144 for (auto &e: aliasInfo.containedTypes()) { 145 auto contained_type_hash = std::hash<c10::AliasInfo>()(e); 146 hash = c10::hash_combine(hash, contained_type_hash); 147 } 148 return hash; 149 } 150 }; 151 } 152