xref: /aosp_15_r20/external/pytorch/c10/test/util/accumulate_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright 2004-present Facebook. All Rights Reserved.
2 
3 #include <c10/util/accumulate.h>
4 
5 #include <gtest/gtest.h>
6 
7 #include <list>
8 #include <vector>
9 
10 using namespace ::testing;
11 
// Exercises every accumulate helper against a contiguous std::vector:
// whole-container overloads, iterator-pair overloads (full and interior
// ranges), and the numelements_* dimension helpers.
TEST(accumulateTest, vector_test) {
  const std::vector<int> ints{1, 2, 3, 4, 5};
  constexpr int kTotal = 1 + 2 + 3 + 4 + 5;
  constexpr int kProduct = 1 * 2 * 3 * 4 * 5;

  // Whole-container overloads.
  EXPECT_EQ(c10::sum_integers(ints), kTotal);
  EXPECT_EQ(c10::multiply_integers(ints), kProduct);

  // Iterator-pair overloads over the full range must agree with the above.
  EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), kTotal);
  EXPECT_EQ(c10::multiply_integers(ints.begin(), ints.end()), kProduct);

  // Interior range [2, 3, 4]: drop the first and last element.
  const auto first = ints.begin() + 1;
  const auto last = ints.end() - 1;
  EXPECT_EQ(c10::sum_integers(first, last), 2 + 3 + 4);
  EXPECT_EQ(c10::multiply_integers(first, last), 2 * 3 * 4);

  // Dimension helpers; between_dim is expected to accept its bounds in
  // either order.
  EXPECT_EQ(c10::numelements_from_dim(2, ints), 3 * 4 * 5);
  EXPECT_EQ(c10::numelements_to_dim(3, ints), 1 * 2 * 3);
  EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3 * 4);
  EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3 * 4);
}
31 
// Same coverage as vector_test, but through std::list, whose iterators are
// only bidirectional. This checks that the iterator-pair overloads and the
// numelements_* helpers do not rely on random-access iterators.
TEST(accumulateTest, list_test) {
  std::list<int> ints = {1, 2, 3, 4, 5};

  EXPECT_EQ(c10::sum_integers(ints), 1 + 2 + 3 + 4 + 5);
  EXPECT_EQ(c10::multiply_integers(ints), 1 * 2 * 3 * 4 * 5);

  EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 1 + 2 + 3 + 4 + 5);
  EXPECT_EQ(
      c10::multiply_integers(ints.begin(), ints.end()), 1 * 2 * 3 * 4 * 5);

  // Interior range [2, 3, 4]. `begin() + 1` / `end() - 1` do not compile for
  // list iterators, so step with ++/-- instead.
  auto first = ints.begin();
  ++first;
  auto last = ints.end();
  --last;
  EXPECT_EQ(c10::sum_integers(first, last), 2 + 3 + 4);
  EXPECT_EQ(c10::multiply_integers(first, last), 2 * 3 * 4);

  EXPECT_EQ(c10::numelements_from_dim(2, ints), 3 * 4 * 5);
  EXPECT_EQ(c10::numelements_to_dim(3, ints), 1 * 2 * 3);
  EXPECT_EQ(c10::numelements_between_dim(2, 4, ints), 3 * 4);
  EXPECT_EQ(c10::numelements_between_dim(4, 2, ints), 3 * 4);
}
47 
// Empty inputs must yield the identity of each operation: 0 for sums and
// 1 for products. This holds for whole-container overloads, iterator-pair
// overloads, and zero-dimension requests, for both contiguous and
// node-based containers.
TEST(accumulateTest, base_cases) {
  std::vector<int> ints = {};

  EXPECT_EQ(c10::sum_integers(ints), 0);
  EXPECT_EQ(c10::multiply_integers(ints), 1);

  // Iterator-pair overloads over an empty range.
  EXPECT_EQ(c10::sum_integers(ints.begin(), ints.end()), 0);
  EXPECT_EQ(c10::multiply_integers(ints.begin(), ints.end()), 1);

  // Selecting zero dimensions of an empty shape yields an empty product.
  EXPECT_EQ(c10::numelements_from_dim(0, ints), 1);
  EXPECT_EQ(c10::numelements_to_dim(0, ints), 1);

  // Same identities through a non-random-access container.
  std::list<int> empty_list = {};
  EXPECT_EQ(c10::sum_integers(empty_list), 0);
  EXPECT_EQ(c10::multiply_integers(empty_list), 1);
  EXPECT_EQ(c10::sum_integers(empty_list.begin(), empty_list.end()), 0);
  EXPECT_EQ(c10::multiply_integers(empty_list.begin(), empty_list.end()), 1);
}
54 
// Invalid dimension arguments must raise c10::Error. The one exception is
// numelements_from_dim past the end, which yields the empty product 1.
TEST(accumulateTest, errors) {
  const std::vector<int> ints{1, 2, 3, 4, 5};
  constexpr int kNegativeDim = -1;
  constexpr int kPastEndDim = 10;

  // --- Negative dimensions -------------------------------------------------
#ifndef NDEBUG
  // Guarded by NDEBUG: the negative-index check in numelements_from_dim is
  // only expected to fire in debug builds.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_from_dim(kNegativeDim, ints), c10::Error);
#endif

  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_to_dim(kNegativeDim, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(
      c10::numelements_between_dim(kNegativeDim, kPastEndDim, ints),
      c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(
      c10::numelements_between_dim(kPastEndDim, kNegativeDim, ints),
      c10::Error);

  // --- Dimensions past the end of the 5-element container -------------------
  EXPECT_EQ(c10::numelements_from_dim(kPastEndDim, ints), 1);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_to_dim(kPastEndDim, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(kPastEndDim, 4, ints), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(c10::numelements_between_dim(4, kPastEndDim, ints), c10::Error);
}
77