1 //===----------------------------------------------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is dual licensed under the MIT and the University of Illinois Open
6 // Source Licenses. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 // <numeric>
11 // UNSUPPORTED: c++98, c++03, c++11, c++14
12 
13 // template<class InputIterator, class OutputIterator, class T,
14 //          class BinaryOperation, class UnaryOperation>
15 //   OutputIterator transform_exclusive_scan(InputIterator first, InputIterator last,
16 //                                           OutputIterator result, T init,
17 //                                           BinaryOperation binary_op,
18 //                                           UnaryOperation unary_op);
19 
20 
21 #include <numeric>
22 #include <algorithm>
23 #include <cassert>
24 #include <functional>
25 #include <iostream>
26 #include <iterator>
27 #include <vector>
28 
29 #include "test_iterators.h"
30 
// Unary operation used by the tests: yields its argument plus one,
// converted back to the argument's own type. Because of that conversion,
// a narrow type such as unsigned char wraps (255 -> 0) instead of being
// promoted to int.
struct add_one {
    template <typename T>
    constexpr T operator()(T value) const noexcept {
        const auto bumped = value + 1;   // may be promoted (e.g. to int)
        return static_cast<T>(bumped);   // narrow back to T
    }
};
37 
// Checks std::transform_exclusive_scan over [first, last) against the
// expected output range [rFirst, rLast), twice:
//   1. out of place, appending through a back_inserter into a fresh vector;
//   2. in place, over a copy of the input (result == first is permitted).
// Parameters:
//   bop  - binary operation used to combine the running value with each element
//   uop  - unary operation applied to each input element before combining
//   init - initial value; it is always the first element of the output
template <class Iter1, class BOp, class UOp, class T, class Iter2>
void
test(Iter1 first, Iter1 last, BOp bop, UOp uop, T init, Iter2 rFirst, Iter2 rLast)
{
    std::vector<typename std::iterator_traits<Iter1>::value_type> v;

//  Test not in-place
    std::transform_exclusive_scan(first, last, std::back_inserter(v), init, bop, uop);
    assert(std::equal(v.begin(), v.end(), rFirst, rLast));

//  Test in-place. assign() replaces the previous contents, so no clear() is needed.
    v.assign(first, last);
    const auto out_end =
        std::transform_exclusive_scan(v.begin(), v.end(), v.begin(), init, bop, uop);
    // The algorithm returns the end of the range it wrote.
    assert(out_end == v.end());
    assert(std::equal(v.begin(), v.end(), rFirst, rLast));
}
53 
54 
55 template <class Iter>
56 void
test()57 test()
58 {
59           int ia[]     = { 1,  3,  5,    7,   9 };
60     const int pResI0[] = { 0,  2,  6,   12,  20 };        // with add_one
61     const int mResI0[] = { 0,  0,  0,    0,   0 };
62     const int pResN0[] = { 0, -1, -4,   -9, -16 };        // with negate
63     const int mResN0[] = { 0,  0,  0,    0,   0 };
64     const int pResI2[] = { 2,  4,  8,   14,  22 };        // with add_one
65     const int mResI2[] = { 2,  4, 16,   96, 768 };
66     const int pResN2[] = { 2,  1,  -2,  -7, -14 };        // with negate
67     const int mResN2[] = { 2, -2,   6, -30, 210 };
68     const unsigned sa = sizeof(ia) / sizeof(ia[0]);
69     static_assert(sa == sizeof(pResI0) / sizeof(pResI0[0]));       // just to be sure
70     static_assert(sa == sizeof(mResI0) / sizeof(mResI0[0]));       // just to be sure
71     static_assert(sa == sizeof(pResN0) / sizeof(pResN0[0]));       // just to be sure
72     static_assert(sa == sizeof(mResN0) / sizeof(mResN0[0]));       // just to be sure
73     static_assert(sa == sizeof(pResI2) / sizeof(pResI2[0]));       // just to be sure
74     static_assert(sa == sizeof(mResI2) / sizeof(mResI2[0]));       // just to be sure
75     static_assert(sa == sizeof(pResN2) / sizeof(pResN2[0]));       // just to be sure
76     static_assert(sa == sizeof(mResN2) / sizeof(mResN2[0]));       // just to be sure
77 
78     for (unsigned int i = 0; i < sa; ++i ) {
79         test(Iter(ia), Iter(ia + i), std::plus<>(),       add_one{},       0, pResI0, pResI0 + i);
80         test(Iter(ia), Iter(ia + i), std::multiplies<>(), add_one{},       0, mResI0, mResI0 + i);
81         test(Iter(ia), Iter(ia + i), std::plus<>(),       std::negate<>(), 0, pResN0, pResN0 + i);
82         test(Iter(ia), Iter(ia + i), std::multiplies<>(), std::negate<>(), 0, mResN0, mResN0 + i);
83         test(Iter(ia), Iter(ia + i), std::plus<>(),       add_one{},       2, pResI2, pResI2 + i);
84         test(Iter(ia), Iter(ia + i), std::multiplies<>(), add_one{},       2, mResI2, mResI2 + i);
85         test(Iter(ia), Iter(ia + i), std::plus<>(),       std::negate<>(), 2, pResN2, pResN2 + i);
86         test(Iter(ia), Iter(ia + i), std::multiplies<>(), std::negate<>(), 2, mResN2, mResN2 + i);
87         }
88 }
89 
// n-th triangular number: 0 + 1 + ... + n, computed as n * (n + 1) / 2
// in size_t (unsigned) arithmetic.
size_t triangle(size_t n)
{
    const size_t successor = n + 1;
    return n * successor / 2;
}
91 
92 //  Basic sanity
basic_tests()93 void basic_tests()
94 {
95     {
96     std::vector<size_t> v(10);
97     std::fill(v.begin(), v.end(), 3);
98     std::transform_exclusive_scan(v.begin(), v.end(), v.begin(), size_t{50}, std::plus<>(), add_one{});
99     for (size_t i = 0; i < v.size(); ++i)
100         assert(v[i] == 50 + i * 4);
101     }
102 
103     {
104     std::vector<size_t> v(10);
105     std::iota(v.begin(), v.end(), 0);
106     std::transform_exclusive_scan(v.begin(), v.end(), v.begin(), size_t{30}, std::plus<>(), add_one{});
107     for (size_t i = 0; i < v.size(); ++i)
108         assert(v[i] == 30 + triangle(i - 1) + i);
109     }
110 
111     {
112     std::vector<size_t> v(10);
113     std::iota(v.begin(), v.end(), 1);
114     std::transform_exclusive_scan(v.begin(), v.end(), v.begin(), size_t{40}, std::plus<>(), add_one{});
115     for (size_t i = 0; i < v.size(); ++i)
116         assert(v[i] == 40 + triangle(i) + i);
117     }
118 
119     {
120     std::vector<size_t> v, res;
121     std::transform_exclusive_scan(v.begin(), v.end(), std::back_inserter(res), size_t{40}, std::plus<>(), add_one{});
122     assert(res.empty());
123     }
124 
125 //  Make sure that the calculations are done using the init typedef
126     {
127     std::vector<unsigned char> v(10);
128     std::iota(v.begin(), v.end(), static_cast<unsigned char>(1));
129     std::vector<size_t> res;
130     std::transform_exclusive_scan(v.begin(), v.end(), std::back_inserter(res), size_t{1}, std::multiplies<>(), add_one{});
131 
132     assert(res.size() == 10);
133     size_t j = 1;
134     assert(res[0] == 1);
135     for (size_t i = 1; i < res.size(); ++i)
136     {
137         j *= i + 1;
138         assert(res[i] == j);
139     }
140     }
141 }
142 
main()143 int main()
144 {
145     basic_tests();
146 
147 //  All the iterator categories
148     test<input_iterator        <const int*> >();
149     test<forward_iterator      <const int*> >();
150     test<bidirectional_iterator<const int*> >();
151     test<random_access_iterator<const int*> >();
152     test<const int*>();
153     test<      int*>();
154 }
155