1 // Copyright 2004-present Facebook. All Rights Reserved.
2
3 #include <c10/util/accumulate.h>
4
5 #include <gtest/gtest.h>
6
7 #include <list>
8 #include <vector>
9
10 using namespace ::testing;
11
// Checks the c10 sum/product helpers and the numelements_* dimension helpers
// when the input is a std::vector<int>.
TEST(accumulateTest, vector_test) {
  std::vector<int> values{1, 2, 3, 4, 5};

  // Expected results over the complete sequence.
  constexpr int kTotalSum = 1 + 2 + 3 + 4 + 5;
  constexpr int kTotalProduct = 1 * 2 * 3 * 4 * 5;

  // Whole-container overloads.
  EXPECT_EQ(c10::sum_integers(values), kTotalSum);
  EXPECT_EQ(c10::multiply_integers(values), kTotalProduct);

  // Iterator-pair overloads spanning the full range.
  EXPECT_EQ(c10::sum_integers(values.begin(), values.end()), kTotalSum);
  EXPECT_EQ(
      c10::multiply_integers(values.begin(), values.end()), kTotalProduct);

  // Iterator-pair overloads over the interior, skipping the first and the
  // last element.
  const auto inner_first = values.begin() + 1;
  const auto inner_last = values.end() - 1;
  EXPECT_EQ(c10::sum_integers(inner_first, inner_last), 2 + 3 + 4);
  EXPECT_EQ(c10::multiply_integers(inner_first, inner_last), 2 * 3 * 4);

  // Dimension-based products; the last two expectations show that the two
  // bounds of numelements_between_dim may be passed in either order.
  EXPECT_EQ(c10::numelements_from_dim(2, values), 3 * 4 * 5);
  EXPECT_EQ(c10::numelements_to_dim(3, values), 1 * 2 * 3);
  EXPECT_EQ(c10::numelements_between_dim(2, 4, values), 3 * 4);
  EXPECT_EQ(c10::numelements_between_dim(4, 2, values), 3 * 4);
}
31
// Repeats the sum/product and numelements_* checks with a std::list<int> as
// the input container.
TEST(accumulateTest, list_test) {
  std::list<int> values{1, 2, 3, 4, 5};

  // Expected results over the complete sequence.
  constexpr int kTotalSum = 1 + 2 + 3 + 4 + 5;
  constexpr int kTotalProduct = 1 * 2 * 3 * 4 * 5;

  // Whole-container overloads.
  EXPECT_EQ(c10::sum_integers(values), kTotalSum);
  EXPECT_EQ(c10::multiply_integers(values), kTotalProduct);

  // Iterator-pair overloads spanning the full range.
  EXPECT_EQ(c10::sum_integers(values.begin(), values.end()), kTotalSum);
  EXPECT_EQ(
      c10::multiply_integers(values.begin(), values.end()), kTotalProduct);

  // Dimension-based products; the last two expectations show that the two
  // bounds of numelements_between_dim may be passed in either order.
  EXPECT_EQ(c10::numelements_from_dim(2, values), 3 * 4 * 5);
  EXPECT_EQ(c10::numelements_to_dim(3, values), 1 * 2 * 3);
  EXPECT_EQ(c10::numelements_between_dim(2, 4, values), 3 * 4);
  EXPECT_EQ(c10::numelements_between_dim(4, 2, values), 3 * 4);
}
47
// An empty container is expected to sum to 0 and multiply to 1.
TEST(accumulateTest, base_cases) {
  std::vector<int> empty_values{};

  constexpr int kEmptySum = 0;
  constexpr int kEmptyProduct = 1;

  EXPECT_EQ(c10::sum_integers(empty_values), kEmptySum);
  EXPECT_EQ(c10::multiply_integers(empty_values), kEmptyProduct);
}
54
// Exercises the numelements_* helpers with dimension arguments that lie
// outside the five-element input.
TEST(accumulateTest, errors) {
  std::vector<int> dims{1, 2, 3, 4, 5};

#ifndef NDEBUG
  // This expectation is only compiled when NDEBUG is not defined.
  EXPECT_THROW(c10::numelements_from_dim(-1, dims), c10::Error);
#endif

  // A negative dimension is expected to raise c10::Error, whichever argument
  // position it is passed in.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_to_dim(-1, dims), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(-1, 10, dims), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(10, -1, dims), c10::Error);

  // A dimension past the end: numelements_from_dim is expected to return 1,
  // while the other helpers are expected to raise c10::Error.
  EXPECT_EQ(c10::numelements_from_dim(10, dims), 1);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_to_dim(10, dims), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(10, 4, dims), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(4, 10, dims), c10::Error);
}
77