1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef TEST_LIBCXX_ALGORITHMS_ALG_SORTING_ASSERT_SORT_INVALID_COMPARATOR_INVALID_COMPARATOR_UTILITIES_H
10 #define TEST_LIBCXX_ALGORITHMS_ALG_SORTING_ASSERT_SORT_INVALID_COMPARATOR_INVALID_COMPARATOR_UTILITIES_H
11 
#include <algorithm>
#include <cassert>
#include <charconv>
#include <cstddef>
#include <map>
#include <memory>
#include <ranges>
#include <set>
#include <string>
#include <string_view>
#include <system_error>
#include <vector>
21 
// Lookup table of precomputed comparator answers, parsed from text.
//
// Every non-empty line of the input has the form "<left> <right> <result>":
// two non-negative integers followed by an integer that is read as `false`
// when it is 0 and as `true` otherwise. Fields are separated by a single space
// and anything after the third field is ignored.
class ComparisonResults {
public:
  // Parses `data` line by line and stores one entry per (left, right) pair; a
  // later line for the same pair overwrites an earlier one. Empty lines are
  // skipped. A line with a missing or non-numeric field trips an assertion.
  explicit ComparisonResults(std::string_view data) {
    while (!data.empty()) {
      std::string_view line = next_field(data, '\n');
      if (line.empty())
        continue;

      const std::size_t left  = parse_number<std::size_t>(next_field(line, ' '));
      const std::size_t right = parse_number<std::size_t>(next_field(line, ' '));
      const bool result       = parse_number<long>(next_field(line, ' ')) != 0;
      comparison_results[left][right] = result;
    }
  }

  // Returns the recorded answer for the pair (*left, *right).
  //
  // Both pointers must be non-null and the pair must be present in the table;
  // each of these preconditions is checked with an assertion.
  bool compare(std::size_t* left, std::size_t* right) const {
    assert(left != nullptr && right != nullptr && "something is wrong with the test");
    const auto row = comparison_results.find(*left);
    assert(row != comparison_results.end() && "malformed input data?");
    const auto entry = row->second.find(*right);
    assert(entry != row->second.end() && "malformed input data?");
    return entry->second;
  }

  // Number of distinct left-hand values in the table.
  std::size_t size() const { return comparison_results.size(); }

private:
  // Returns the part of `input` that precedes the first `delimiter` and
  // advances `input` past that delimiter. Without a delimiter the whole of
  // `input` is returned and `input` becomes empty.
  static std::string_view next_field(std::string_view& input, char delimiter) {
    const std::size_t end        = input.find(delimiter);
    const std::string_view field = input.substr(0, end);
    input.remove_prefix(end == std::string_view::npos ? input.size() : end + 1);
    return field;
  }

  // Converts the leading decimal digits of `field` to a `T`. An empty field, a
  // field that does not start with a number, or a value that does not fit in
  // `T` (including a negative value for an unsigned `T`) trips an assertion.
  template <class T>
  static T parse_number(std::string_view field) {
    T value{};
    [[maybe_unused]] const auto parsed = std::from_chars(field.data(), field.data() + field.size(), value);
    assert(parsed.ec == std::errc() && "malformed input data?");
    return value;
  }

  std::map<std::size_t, std::map<std::size_t, bool>>
      comparison_results; // terrible for performance, but really convenient
};
51 
52 class SortingFixture {
53 public:
SortingFixture(std::string_view data)54   explicit SortingFixture(std::string_view data) : comparison_results_(data) {
55     for (std::size_t i = 0; i != comparison_results_.size(); ++i) {
56       elements_.push_back(std::make_unique<std::size_t>(i));
57       valid_ptrs_.insert(elements_.back().get());
58     }
59   }
60 
create_elements()61   std::vector<std::size_t*> create_elements() {
62     std::vector<std::size_t*> copy;
63     for (auto const& e : elements_)
64       copy.push_back(e.get());
65     return copy;
66   }
67 
checked_predicate()68   auto checked_predicate() {
69     return [this](size_t* left, size_t* right) {
70       // If the pointers passed to the comparator are not in the set of pointers we
71       // set up above, then we're being passed garbage values from the algorithm
72       // because we're reading OOB.
73       assert(valid_ptrs_.contains(left));
74       assert(valid_ptrs_.contains(right));
75       return comparison_results_.compare(left, right);
76     };
77   }
78 
79 private:
80   ComparisonResults comparison_results_;
81   std::vector<std::unique_ptr<std::size_t>> elements_;
82   std::set<std::size_t*> valid_ptrs_;
83 };
84 
85 #endif // TEST_LIBCXX_ALGORITHMS_ALG_SORTING_ASSERT_SORT_INVALID_COMPARATOR_INVALID_COMPARATOR_UTILITIES_H
86