1 // Copyright (c) 2023 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include <algorithm> 16 #include <cassert> 17 #include <cstdint> 18 #include <functional> 19 #include <initializer_list> 20 #include <limits> 21 #include <type_traits> 22 #include <vector> 23 24 #ifndef SOURCE_ENUM_SET_H_ 25 #define SOURCE_ENUM_SET_H_ 26 27 #include "source/latest_version_spirv_header.h" 28 29 namespace spvtools { 30 31 // This container is optimized to store and retrieve unsigned enum values. 32 // The base model for this implementation is an open-addressing hashtable with 33 // linear probing. For small enums (max index < 64), all operations are O(1). 34 // 35 // - Enums are stored in buckets (64 contiguous values max per bucket) 36 // - Buckets ranges don't overlap, but don't have to be contiguous. 37 // - Enums are packed into 64-bits buckets, using 1 bit per enum value. 38 // 39 // Example: 40 // - MyEnum { A = 0, B = 1, C = 64, D = 65 } 41 // - 2 buckets are required: 42 // - bucket 0, storing values in the range [ 0; 64[ 43 // - bucket 1, storing values in the range [64; 128[ 44 // 45 // - Buckets are stored in a sorted vector (sorted by bucket range). 46 // - Retrieval is done by computing the theoretical bucket index using the enum 47 // value, and 48 // doing a linear scan from this position. 49 // - Insertion is done by retrieving the bucket and either: 50 // - inserting a new bucket in the sorted vector when no buckets has a 51 // compatible range. 52 // - setting the corresponding bit in the bucket. 53 // This means insertion in the middle/beginning can cause a memmove when no 54 // bucket is available. In our case, this happens at most 23 times for the 55 // largest enum we have (Opcodes). 56 template <typename T> 57 class EnumSet { 58 private: 59 using BucketType = uint64_t; 60 using ElementType = std::underlying_type_t<T>; 61 static_assert(std::is_enum_v<T>, "EnumSets only works with enums."); 62 static_assert(std::is_signed_v<ElementType> == false, 63 "EnumSet doesn't supports signed enums."); 64 65 // Each bucket can hold up to `kBucketSize` distinct, contiguous enum values. 66 // The first value a bucket can hold must be aligned on `kBucketSize`. 67 struct Bucket { 68 // bit mask to store `kBucketSize` enums. 69 BucketType data; 70 // 1st enum this bucket can represent. 71 T start; 72 73 friend bool operator==(const Bucket& lhs, const Bucket& rhs) { 74 return lhs.start == rhs.start && lhs.data == rhs.data; 75 } 76 }; 77 78 // How many distinct values can a bucket hold? 1 bit per value. 79 static constexpr size_t kBucketSize = sizeof(BucketType) * 8ULL; 80 81 public: 82 class Iterator { 83 public: 84 typedef Iterator self_type; 85 typedef T value_type; 86 typedef T& reference; 87 typedef T* pointer; 88 typedef std::forward_iterator_tag iterator_category; 89 typedef size_t difference_type; 90 Iterator(const Iterator & other)91 Iterator(const Iterator& other) 92 : set_(other.set_), 93 bucketIndex_(other.bucketIndex_), 94 bucketOffset_(other.bucketOffset_) {} 95 96 Iterator& operator++() { 97 do { 98 if (bucketIndex_ >= set_->buckets_.size()) { 99 bucketIndex_ = set_->buckets_.size(); 100 bucketOffset_ = 0; 101 break; 102 } 103 104 if (bucketOffset_ + 1 == kBucketSize) { 105 bucketOffset_ = 0; 106 ++bucketIndex_; 107 } else { 108 ++bucketOffset_; 109 } 110 111 } while (bucketIndex_ < set_->buckets_.size() && 112 !set_->HasEnumAt(bucketIndex_, bucketOffset_)); 113 return *this; 114 } 115 116 Iterator operator++(int) { 117 Iterator old = *this; 118 operator++(); 119 return old; 120 } 121 122 T operator*() const { 123 assert(set_->HasEnumAt(bucketIndex_, bucketOffset_) && 124 "operator*() called on an invalid iterator."); 125 return GetValueFromBucket(set_->buckets_[bucketIndex_], bucketOffset_); 126 } 127 128 bool operator!=(const Iterator& other) const { 129 return set_ != other.set_ || bucketOffset_ != other.bucketOffset_ || 130 bucketIndex_ != other.bucketIndex_; 131 } 132 133 bool operator==(const Iterator& other) const { 134 return !(operator!=(other)); 135 } 136 137 Iterator& operator=(const Iterator& other) { 138 set_ = other.set_; 139 bucketIndex_ = other.bucketIndex_; 140 bucketOffset_ = other.bucketOffset_; 141 return *this; 142 } 143 144 private: Iterator(const EnumSet * set,size_t bucketIndex,ElementType bucketOffset)145 Iterator(const EnumSet* set, size_t bucketIndex, ElementType bucketOffset) 146 : set_(set), bucketIndex_(bucketIndex), bucketOffset_(bucketOffset) {} 147 148 private: 149 const EnumSet* set_ = nullptr; 150 // Index of the bucket in the vector. 151 size_t bucketIndex_ = 0; 152 // Offset in bits in the current bucket. 153 ElementType bucketOffset_ = 0; 154 155 friend class EnumSet; 156 }; 157 158 // Required to allow the use of std::inserter. 159 using value_type = T; 160 using const_iterator = Iterator; 161 using iterator = Iterator; 162 163 public: cbegin()164 iterator cbegin() const noexcept { 165 auto it = iterator(this, /* bucketIndex= */ 0, /* bucketOffset= */ 0); 166 if (buckets_.size() == 0) { 167 return it; 168 } 169 170 // The iterator has the logic to find the next valid bit. If the value 0 171 // is not stored, use it to find the next valid bit. 172 if (!HasEnumAt(it.bucketIndex_, it.bucketOffset_)) { 173 ++it; 174 } 175 176 return it; 177 } 178 begin()179 iterator begin() const noexcept { return cbegin(); } 180 cend()181 iterator cend() const noexcept { 182 return iterator(this, buckets_.size(), /* bucketOffset= */ 0); 183 } 184 end()185 iterator end() const noexcept { return cend(); } 186 187 // Creates an empty set. EnumSet()188 EnumSet() : buckets_(0), size_(0) {} 189 190 // Creates a set and store `value` in it. EnumSet(T value)191 EnumSet(T value) : EnumSet() { insert(value); } 192 193 // Creates a set and stores each `values` in it. EnumSet(std::initializer_list<T> values)194 EnumSet(std::initializer_list<T> values) : EnumSet() { 195 for (auto item : values) { 196 insert(item); 197 } 198 } 199 200 // Creates a set, and insert `count` enum values pointed by `array` in it. EnumSet(ElementType count,const T * array)201 EnumSet(ElementType count, const T* array) : EnumSet() { 202 for (ElementType i = 0; i < count; i++) { 203 insert(array[i]); 204 } 205 } 206 207 // Creates a set initialized with the content of the range [begin; end[. 208 template <class InputIt> EnumSet(InputIt begin,InputIt end)209 EnumSet(InputIt begin, InputIt end) : EnumSet() { 210 for (; begin != end; ++begin) { 211 insert(*begin); 212 } 213 } 214 215 // Copies the EnumSet `other` into a new EnumSet. EnumSet(const EnumSet & other)216 EnumSet(const EnumSet& other) 217 : buckets_(other.buckets_), size_(other.size_) {} 218 219 // Moves the EnumSet `other` into a new EnumSet. EnumSet(EnumSet && other)220 EnumSet(EnumSet&& other) 221 : buckets_(std::move(other.buckets_)), size_(other.size_) {} 222 223 // Deep-copies the EnumSet `other` into this EnumSet. 224 EnumSet& operator=(const EnumSet& other) { 225 buckets_ = other.buckets_; 226 size_ = other.size_; 227 return *this; 228 } 229 230 // Matches std::unordered_set::insert behavior. insert(const T & value)231 std::pair<iterator, bool> insert(const T& value) { 232 const size_t index = FindBucketForValue(value); 233 const ElementType offset = ComputeBucketOffset(value); 234 235 if (index >= buckets_.size() || 236 buckets_[index].start != ComputeBucketStart(value)) { 237 size_ += 1; 238 InsertBucketFor(index, value); 239 return std::make_pair(Iterator(this, index, offset), true); 240 } 241 242 auto& bucket = buckets_[index]; 243 const auto mask = ComputeMaskForValue(value); 244 if (bucket.data & mask) { 245 return std::make_pair(Iterator(this, index, offset), false); 246 } 247 248 size_ += 1; 249 bucket.data |= ComputeMaskForValue(value); 250 return std::make_pair(Iterator(this, index, offset), true); 251 } 252 253 // Inserts `value` in the set if possible. 254 // Similar to `std::unordered_set::insert`, except the hint is ignored. 255 // Returns an iterator to the inserted element, or the element preventing 256 // insertion. insert(const_iterator,const T & value)257 iterator insert(const_iterator, const T& value) { 258 return insert(value).first; 259 } 260 261 // Inserts `value` in the set if possible. 262 // Similar to `std::unordered_set::insert`, except the hint is ignored. 263 // Returns an iterator to the inserted element, or the element preventing 264 // insertion. insert(const_iterator,T && value)265 iterator insert(const_iterator, T&& value) { return insert(value).first; } 266 267 // Inserts all the values in the range [`first`; `last[. 268 // Similar to `std::unordered_set::insert`. 269 template <class InputIt> insert(InputIt first,InputIt last)270 void insert(InputIt first, InputIt last) { 271 for (auto it = first; it != last; ++it) { 272 insert(*it); 273 } 274 } 275 276 // Removes the value `value` into the set. 277 // Similar to `std::unordered_set::erase`. 278 // Returns the number of erased elements. erase(const T & value)279 size_t erase(const T& value) { 280 const size_t index = FindBucketForValue(value); 281 if (index >= buckets_.size() || 282 buckets_[index].start != ComputeBucketStart(value)) { 283 return 0; 284 } 285 286 auto& bucket = buckets_[index]; 287 const auto mask = ComputeMaskForValue(value); 288 if (!(bucket.data & mask)) { 289 return 0; 290 } 291 292 size_ -= 1; 293 bucket.data &= ~mask; 294 if (bucket.data == 0) { 295 buckets_.erase(buckets_.cbegin() + index); 296 } 297 return 1; 298 } 299 300 // Returns true if `value` is present in the set. contains(T value)301 bool contains(T value) const { 302 const size_t index = FindBucketForValue(value); 303 if (index >= buckets_.size() || 304 buckets_[index].start != ComputeBucketStart(value)) { 305 return false; 306 } 307 auto& bucket = buckets_[index]; 308 return bucket.data & ComputeMaskForValue(value); 309 } 310 311 // Returns the 1 if `value` is present in the set, `0` otherwise. count(T value)312 inline size_t count(T value) const { return contains(value) ? 1 : 0; } 313 314 // Returns true if the set is holds no values. empty()315 inline bool empty() const { return size_ == 0; } 316 317 // Returns the number of enums stored in this set. size()318 size_t size() const { return size_; } 319 320 // Returns true if this set contains at least one value contained in `in_set`. 321 // Note: If `in_set` is empty, this function returns true. HasAnyOf(const EnumSet<T> & in_set)322 bool HasAnyOf(const EnumSet<T>& in_set) const { 323 if (in_set.empty()) { 324 return true; 325 } 326 327 auto lhs = buckets_.cbegin(); 328 auto rhs = in_set.buckets_.cbegin(); 329 330 while (lhs != buckets_.cend() && rhs != in_set.buckets_.cend()) { 331 if (lhs->start == rhs->start) { 332 if (lhs->data & rhs->data) { 333 // At least 1 bit is shared. Early return. 334 return true; 335 } 336 337 lhs++; 338 rhs++; 339 continue; 340 } 341 342 // LHS bucket is smaller than the current RHS bucket. Catching up on RHS. 343 if (lhs->start < rhs->start) { 344 lhs++; 345 continue; 346 } 347 348 // Otherwise, RHS needs to catch up on LHS. 349 rhs++; 350 } 351 352 return false; 353 } 354 355 private: 356 // Returns the index of the last bucket in which `value` could be stored. ComputeLargestPossibleBucketIndexFor(T value)357 static constexpr inline size_t ComputeLargestPossibleBucketIndexFor(T value) { 358 return static_cast<size_t>(value) / kBucketSize; 359 } 360 361 // Returns the smallest enum value that could be contained in the same bucket 362 // as `value`. ComputeBucketStart(T value)363 static constexpr inline T ComputeBucketStart(T value) { 364 return static_cast<T>(kBucketSize * 365 ComputeLargestPossibleBucketIndexFor(value)); 366 } 367 368 // Returns the index of the bit that corresponds to `value` in the bucket. ComputeBucketOffset(T value)369 static constexpr inline ElementType ComputeBucketOffset(T value) { 370 return static_cast<ElementType>(value) % kBucketSize; 371 } 372 373 // Returns the bitmask used to represent the enum `value` in its bucket. ComputeMaskForValue(T value)374 static constexpr inline BucketType ComputeMaskForValue(T value) { 375 return 1ULL << ComputeBucketOffset(value); 376 } 377 378 // Returns the `enum` stored in `bucket` at `offset`. 379 // `offset` is the bit-offset in the bucket storage. GetValueFromBucket(const Bucket & bucket,BucketType offset)380 static constexpr inline T GetValueFromBucket(const Bucket& bucket, 381 BucketType offset) { 382 return static_cast<T>(static_cast<ElementType>(bucket.start) + offset); 383 } 384 385 // For a given enum `value`, finds the bucket index that could contain this 386 // value. If no such bucket is found, the index at which the new bucket should 387 // be inserted is returned. FindBucketForValue(T value)388 size_t FindBucketForValue(T value) const { 389 // Set is empty, insert at 0. 390 if (buckets_.size() == 0) { 391 return 0; 392 } 393 394 const T wanted_start = ComputeBucketStart(value); 395 assert(buckets_.size() > 0 && 396 "Size must not be 0 here. Has the code above changed?"); 397 size_t index = std::min(buckets_.size() - 1, 398 ComputeLargestPossibleBucketIndexFor(value)); 399 400 // This loops behaves like std::upper_bound with a reverse iterator. 401 // Buckets are sorted. 3 main cases: 402 // - The bucket matches 403 // => returns the bucket index. 404 // - The found bucket is larger 405 // => scans left until it finds the correct bucket, or insertion point. 406 // - The found bucket is smaller 407 // => We are at the end, so we return past-end index for insertion. 408 for (; buckets_[index].start >= wanted_start; index--) { 409 if (index == 0) { 410 return 0; 411 } 412 } 413 414 return index + 1; 415 } 416 417 // Creates a new bucket to store `value` and inserts it at `index`. 418 // If the `index` is past the end, the bucket is inserted at the end of the 419 // vector. InsertBucketFor(size_t index,T value)420 void InsertBucketFor(size_t index, T value) { 421 const T bucket_start = ComputeBucketStart(value); 422 Bucket bucket = {1ULL << ComputeBucketOffset(value), bucket_start}; 423 auto it = buckets_.emplace(buckets_.begin() + index, std::move(bucket)); 424 #if defined(NDEBUG) 425 (void)it; // Silencing unused variable warning. 426 #else 427 assert(std::next(it) == buckets_.end() || 428 std::next(it)->start > bucket_start); 429 assert(it == buckets_.begin() || std::prev(it)->start < bucket_start); 430 #endif 431 } 432 433 // Returns true if the bucket at `bucketIndex/ stores the enum at 434 // `bucketOffset`, false otherwise. HasEnumAt(size_t bucketIndex,BucketType bucketOffset)435 bool HasEnumAt(size_t bucketIndex, BucketType bucketOffset) const { 436 assert(bucketIndex < buckets_.size()); 437 assert(bucketOffset < kBucketSize); 438 return buckets_[bucketIndex].data & (1ULL << bucketOffset); 439 } 440 441 // Returns true if `lhs` and `rhs` hold the exact same values. 442 friend bool operator==(const EnumSet& lhs, const EnumSet& rhs) { 443 if (lhs.size_ != rhs.size_) { 444 return false; 445 } 446 447 if (lhs.buckets_.size() != rhs.buckets_.size()) { 448 return false; 449 } 450 return lhs.buckets_ == rhs.buckets_; 451 } 452 453 // Returns true if `lhs` and `rhs` hold at least 1 different value. 454 friend bool operator!=(const EnumSet& lhs, const EnumSet& rhs) { 455 return !(lhs == rhs); 456 } 457 458 // Storage for the buckets. 459 std::vector<Bucket> buckets_; 460 // How many enums is this set storing. 461 size_t size_ = 0; 462 }; 463 464 // A set of spv::Capability. 465 using CapabilitySet = EnumSet<spv::Capability>; 466 467 } // namespace spvtools 468 469 #endif // SOURCE_ENUM_SET_H_ 470