xref: /aosp_15_r20/external/angle/third_party/spirv-tools/src/source/enum_set.h (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 // Copyright (c) 2023 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>
23 
24 #ifndef SOURCE_ENUM_SET_H_
25 #define SOURCE_ENUM_SET_H_
26 
27 #include "source/latest_version_spirv_header.h"
28 
29 namespace spvtools {
30 
// This container is optimized to store and retrieve unsigned enum values.
// The base model for this implementation is an open-addressing hashtable with
// linear probing. For small enums (max index < 64), all operations are O(1).
//
// - Enums are stored in buckets (64 contiguous values max per bucket)
// - Bucket ranges don't overlap, but don't have to be contiguous.
// - Enums are packed into 64-bit buckets, using 1 bit per enum value.
//
// Example:
//  - MyEnum { A = 0, B = 1, C = 64, D = 65 }
//  - 2 buckets are required:
//      - bucket 0, storing values in the range [ 0;  64[
//      - bucket 1, storing values in the range [64; 128[
//
// - Buckets are stored in a sorted vector (sorted by bucket range).
// - Retrieval is done by computing the theoretical bucket index using the enum
//   value, and doing a linear scan from this position.
// - Insertion is done by retrieving the bucket and either:
//   - inserting a new bucket in the sorted vector when no bucket has a
//     compatible range.
//   - setting the corresponding bit in the bucket.
//   This means insertion in the middle/beginning can cause a memmove when no
//   bucket is available. In our case, this happens at most 23 times for the
//   largest enum we have (Opcodes).
template <typename T>
class EnumSet {
 private:
  // Storage word of a bucket: one bit per representable enum value.
  using BucketType = uint64_t;
  // Integer type underlying the enum `T`.
  using ElementType = std::underlying_type_t<T>;
  static_assert(std::is_enum_v<T>, "EnumSets only works with enums.");
  static_assert(std::is_signed_v<ElementType> == false,
                "EnumSet doesn't supports signed enums.");

  // Each bucket can hold up to `kBucketSize` distinct, contiguous enum values.
  // The first value a bucket can hold must be aligned on `kBucketSize`.
  struct Bucket {
    // Bit mask storing `kBucketSize` enums; bit `i` stands for `start + i`.
    BucketType data;
    // 1st enum this bucket can represent.
    T start;

    friend bool operator==(const Bucket& lhs, const Bucket& rhs) {
      return lhs.start == rhs.start && lhs.data == rhs.data;
    }
  };

  // How many distinct values can a bucket hold? 1 bit per value.
  static constexpr size_t kBucketSize = sizeof(BucketType) * 8ULL;

 public:
  // Forward iterator over the values stored in an EnumSet, in increasing
  // order (buckets are sorted, and bits are visited from lowest to highest).
  //
  // An iterator is a (bucket index, bit offset) pair into the owning set, so
  // any insert/erase that adds or removes a bucket changes what an existing
  // iterator designates.
  class Iterator {
   public:
    using self_type = Iterator;
    using value_type = T;
    using reference = T&;
    using pointer = T*;
    using iterator_category = std::forward_iterator_tag;
    using difference_type = size_t;

    Iterator(const Iterator& other) = default;
    Iterator& operator=(const Iterator& other) = default;

    // Advances to the next stored value, or to the past-the-end position
    // (bucket index == number of buckets, offset == 0) when there is none.
    Iterator& operator++() {
      do {
        // Already at (or beyond) the end: normalize and stop.
        if (bucketIndex_ >= set_->buckets_.size()) {
          bucketIndex_ = set_->buckets_.size();
          bucketOffset_ = 0;
          break;
        }

        // Step one bit forward, rolling over to the next bucket when the
        // current one is exhausted.
        if (bucketOffset_ + 1 == kBucketSize) {
          bucketOffset_ = 0;
          ++bucketIndex_;
        } else {
          ++bucketOffset_;
        }

      } while (bucketIndex_ < set_->buckets_.size() &&
               !set_->HasEnumAt(bucketIndex_, bucketOffset_));
      return *this;
    }

    // Post-increment: advances the iterator, returns its previous position.
    Iterator operator++(int) {
      Iterator old = *this;
      operator++();
      return old;
    }

    // Returns (by value) the enum designated by this iterator.
    T operator*() const {
      assert(set_->HasEnumAt(bucketIndex_, bucketOffset_) &&
             "operator*() called on an invalid iterator.");
      return GetValueFromBucket(set_->buckets_[bucketIndex_], bucketOffset_);
    }

    bool operator!=(const Iterator& other) const {
      return set_ != other.set_ || bucketOffset_ != other.bucketOffset_ ||
             bucketIndex_ != other.bucketIndex_;
    }

    bool operator==(const Iterator& other) const {
      return !(operator!=(other));
    }

   private:
    // Only EnumSet can create an iterator at an arbitrary position.
    Iterator(const EnumSet* set, size_t bucketIndex, ElementType bucketOffset)
        : set_(set), bucketIndex_(bucketIndex), bucketOffset_(bucketOffset) {}

   private:
    // Set being iterated over.
    const EnumSet* set_ = nullptr;
    // Index of the bucket in the vector.
    size_t bucketIndex_ = 0;
    // Offset in bits in the current bucket.
    ElementType bucketOffset_ = 0;

    friend class EnumSet;
  };

  // Required to allow the use of std::inserter.
  using value_type = T;
  using const_iterator = Iterator;
  using iterator = Iterator;

 public:
  // Returns an iterator to the smallest stored value, or cend() if the set
  // holds no bucket.
  iterator cbegin() const noexcept {
    auto it = iterator(this, /* bucketIndex= */ 0, /* bucketOffset= */ 0);
    if (buckets_.size() == 0) {
      return it;
    }

    // The iterator has the logic to find the next valid bit. If the first bit
    // of the first bucket is not set, use it to find the next valid bit.
    if (!HasEnumAt(it.bucketIndex_, it.bucketOffset_)) {
      ++it;
    }

    return it;
  }

  iterator begin() const noexcept { return cbegin(); }

  // Returns the past-the-end iterator: (number of buckets, offset 0).
  iterator cend() const noexcept {
    return iterator(this, buckets_.size(), /* bucketOffset= */ 0);
  }

  iterator end() const noexcept { return cend(); }

  // Creates an empty set.
  EnumSet() = default;

  // Creates a set and store `value` in it.
  EnumSet(T value) : EnumSet() { insert(value); }

  // Creates a set and stores each `values` in it.
  EnumSet(std::initializer_list<T> values) : EnumSet() {
    for (auto item : values) {
      insert(item);
    }
  }

  // Creates a set, and insert `count` enum values pointed by `array` in it.
  EnumSet(ElementType count, const T* array) : EnumSet() {
    for (ElementType i = 0; i < count; i++) {
      insert(array[i]);
    }
  }

  // Creates a set initialized with the content of the range [begin; end[.
  template <class InputIt>
  EnumSet(InputIt begin, InputIt end) : EnumSet() {
    for (; begin != end; ++begin) {
      insert(*begin);
    }
  }

  // Copies the EnumSet `other` into a new EnumSet.
  EnumSet(const EnumSet& other) = default;

  // Moves the EnumSet `other` into a new EnumSet. `other` is left empty:
  // its element count is reset together with its buckets, so that `size()`,
  // `empty()` and iteration keep agreeing on the moved-from set.
  EnumSet(EnumSet&& other) noexcept
      : buckets_(std::move(other.buckets_)), size_(other.size_) {
    other.buckets_.clear();
    other.size_ = 0;
  }

  // Deep-copies the EnumSet `other` into this EnumSet.
  EnumSet& operator=(const EnumSet& other) = default;

  // Moves the content of `other` into this EnumSet. `other` is left empty.
  // Self-assignment is a no-op.
  EnumSet& operator=(EnumSet&& other) noexcept {
    if (this != &other) {
      buckets_ = std::move(other.buckets_);
      size_ = other.size_;
      other.buckets_.clear();
      other.size_ = 0;
    }
    return *this;
  }

  // Matches std::unordered_set::insert behavior.
  // Returns an iterator to `value` in the set, and `true` if the value was
  // inserted, `false` if it was already present.
  std::pair<iterator, bool> insert(const T& value) {
    const size_t index = FindBucketForValue(value);
    const ElementType offset = ComputeBucketOffset(value);

    // No bucket covers `value` yet: create one at the sorted position.
    if (index >= buckets_.size() ||
        buckets_[index].start != ComputeBucketStart(value)) {
      InsertBucketFor(index, value);
      // Only count the element once the bucket is actually stored, so a
      // failed allocation cannot leave `size_` out of sync.
      size_ += 1;
      return std::make_pair(Iterator(this, index, offset), true);
    }

    auto& bucket = buckets_[index];
    const BucketType mask = ComputeMaskForValue(value);
    if ((bucket.data & mask) != 0) {
      return std::make_pair(Iterator(this, index, offset), false);
    }

    size_ += 1;
    bucket.data |= mask;
    return std::make_pair(Iterator(this, index, offset), true);
  }

  // Inserts `value` in the set if possible.
  // Similar to `std::unordered_set::insert`, except the hint is ignored.
  // Returns an iterator to the inserted element, or the element preventing
  // insertion.
  iterator insert(const_iterator, const T& value) {
    return insert(value).first;
  }

  // Inserts `value` in the set if possible.
  // Similar to `std::unordered_set::insert`, except the hint is ignored.
  // Returns an iterator to the inserted element, or the element preventing
  // insertion.
  iterator insert(const_iterator, T&& value) { return insert(value).first; }

  // Inserts all the values in the range [`first`; `last`[.
  // Similar to `std::unordered_set::insert`.
  template <class InputIt>
  void insert(InputIt first, InputIt last) {
    for (auto it = first; it != last; ++it) {
      insert(*it);
    }
  }

  // Removes the value `value` from the set.
  // Similar to `std::unordered_set::erase`.
  // Returns the number of erased elements.
  size_t erase(const T& value) {
    const size_t index = FindBucketForValue(value);
    if (index >= buckets_.size() ||
        buckets_[index].start != ComputeBucketStart(value)) {
      return 0;
    }

    auto& bucket = buckets_[index];
    const BucketType mask = ComputeMaskForValue(value);
    if ((bucket.data & mask) == 0) {
      return 0;
    }

    size_ -= 1;
    bucket.data &= ~mask;
    // Empty buckets are never kept: operator== compares the bucket vectors
    // directly and relies on this.
    if (bucket.data == 0) {
      buckets_.erase(buckets_.cbegin() + index);
    }
    return 1;
  }

  // Returns true if `value` is present in the set.
  bool contains(T value) const {
    const size_t index = FindBucketForValue(value);
    if (index >= buckets_.size() ||
        buckets_[index].start != ComputeBucketStart(value)) {
      return false;
    }
    return (buckets_[index].data & ComputeMaskForValue(value)) != 0;
  }

  // Returns 1 if `value` is present in the set, 0 otherwise.
  inline size_t count(T value) const { return contains(value) ? 1 : 0; }

  // Returns true if the set holds no values.
  inline bool empty() const { return size_ == 0; }

  // Returns the number of enums stored in this set.
  size_t size() const { return size_; }

  // Returns true if this set contains at least one value contained in `in_set`.
  // Note: If `in_set` is empty, this function returns true.
  bool HasAnyOf(const EnumSet<T>& in_set) const {
    if (in_set.empty()) {
      return true;
    }

    // Both bucket vectors are sorted by `start`: walk them in lockstep.
    auto lhs = buckets_.cbegin();
    auto rhs = in_set.buckets_.cbegin();

    while (lhs != buckets_.cend() && rhs != in_set.buckets_.cend()) {
      if (lhs->start == rhs->start) {
        if ((lhs->data & rhs->data) != 0) {
          // At least 1 bit is shared. Early return.
          return true;
        }

        ++lhs;
        ++rhs;
        continue;
      }

      // LHS bucket is smaller than the current RHS bucket. Catching up on RHS.
      if (lhs->start < rhs->start) {
        ++lhs;
        continue;
      }

      // Otherwise, RHS needs to catch up on LHS.
      ++rhs;
    }

    return false;
  }

 private:
  // Returns the index of the last bucket in which `value` could be stored.
  static constexpr inline size_t ComputeLargestPossibleBucketIndexFor(T value) {
    return static_cast<size_t>(value) / kBucketSize;
  }

  // Returns the smallest enum value that could be contained in the same bucket
  // as `value`.
  static constexpr inline T ComputeBucketStart(T value) {
    return static_cast<T>(kBucketSize *
                          ComputeLargestPossibleBucketIndexFor(value));
  }

  // Returns the index of the bit that corresponds to `value` in the bucket.
  static constexpr inline ElementType ComputeBucketOffset(T value) {
    return static_cast<ElementType>(value) % kBucketSize;
  }

  // Returns the bitmask used to represent the enum `value` in its bucket.
  static constexpr inline BucketType ComputeMaskForValue(T value) {
    return 1ULL << ComputeBucketOffset(value);
  }

  // Returns the `enum` stored in `bucket` at `offset`.
  // `offset` is the bit-offset in the bucket storage.
  static constexpr inline T GetValueFromBucket(const Bucket& bucket,
                                               BucketType offset) {
    return static_cast<T>(static_cast<ElementType>(bucket.start) + offset);
  }

  // For a given enum `value`, finds the bucket index that could contain this
  // value. If no such bucket is found, the index at which the new bucket should
  // be inserted is returned.
  size_t FindBucketForValue(T value) const {
    // Set is empty, insert at 0.
    if (buckets_.size() == 0) {
      return 0;
    }

    const T wanted_start = ComputeBucketStart(value);
    assert(buckets_.size() > 0 &&
           "Size must not be 0 here. Has the code above changed?");
    // Bucket starts are distinct multiples of `kBucketSize` kept in sorted
    // order, so the wanted bucket can never sit after its theoretical index.
    size_t index = std::min(buckets_.size() - 1,
                            ComputeLargestPossibleBucketIndexFor(value));

    // This loops behaves like std::upper_bound with a reverse iterator.
    // Buckets are sorted. 3 main cases:
    //  - The bucket matches
    //    => returns the bucket index.
    //  - The found bucket is larger
    //    => scans left until it finds the correct bucket, or insertion point.
    //  - The found bucket is smaller
    //    => We are at the end, so we return past-end index for insertion.
    for (; buckets_[index].start >= wanted_start; index--) {
      if (index == 0) {
        return 0;
      }
    }

    return index + 1;
  }

  // Creates a new bucket to store `value` and inserts it at `index`.
  // If the `index` is past the end, the bucket is inserted at the end of the
  // vector.
  void InsertBucketFor(size_t index, T value) {
    const T bucket_start = ComputeBucketStart(value);
    Bucket bucket = {ComputeMaskForValue(value), bucket_start};
    auto it = buckets_.emplace(buckets_.begin() + index, std::move(bucket));
    (void)it;  // Only read by the asserts below; unused when NDEBUG is set.
    // The vector must remain strictly sorted by bucket start.
    assert(std::next(it) == buckets_.end() ||
           std::next(it)->start > bucket_start);
    assert(it == buckets_.begin() || std::prev(it)->start < bucket_start);
  }

  // Returns true if the bucket at `bucketIndex` stores the enum at
  // `bucketOffset`, false otherwise.
  bool HasEnumAt(size_t bucketIndex, BucketType bucketOffset) const {
    assert(bucketIndex < buckets_.size());
    assert(bucketOffset < kBucketSize);
    return (buckets_[bucketIndex].data & (1ULL << bucketOffset)) != 0;
  }

  // Returns true if `lhs` and `rhs` hold the exact same values.
  friend bool operator==(const EnumSet& lhs, const EnumSet& rhs) {
    // Comparing the bucket vectors is enough to compare contents because
    // empty buckets are never stored; the size check is a cheap early-out.
    return lhs.size_ == rhs.size_ && lhs.buckets_ == rhs.buckets_;
  }

  // Returns true if `lhs` and `rhs` hold at least 1 different value.
  friend bool operator!=(const EnumSet& lhs, const EnumSet& rhs) {
    return !(lhs == rhs);
  }

  // Storage for the buckets, sorted by `start`. Never holds an empty bucket.
  std::vector<Bucket> buckets_;
  // How many enums is this set storing.
  size_t size_ = 0;
};
463 
// A set of spv::Capability values, stored in an EnumSet.
using CapabilitySet = EnumSet<spv::Capability>;
466 
467 }  // namespace spvtools
468 
469 #endif  // SOURCE_ENUM_SET_H_
470