xref: /aosp_15_r20/external/pytorch/test/cpp/api/parameterlist.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/torch.h>
5 
6 #include <algorithm>
7 #include <memory>
8 #include <vector>
9 
10 #include <test/cpp/api/support.h>
11 
12 using namespace torch::nn;
13 using namespace torch::test;
14 
15 struct ParameterListTest : torch::test::SeedingFixture {};
16 
TEST_F(ParameterListTest, ConstructsFromSharedPointer) {
  // Three parameters: one explicitly requiring grad, one explicitly not, and
  // one created without an explicit requires_grad option.
  torch::Tensor with_grad = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor without_grad =
      torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor unspecified = torch::randn({1, 2});

  // Sanity-check the inputs before handing them to the list.
  ASSERT_TRUE(with_grad.requires_grad());
  ASSERT_FALSE(without_grad.requires_grad());

  // Building the list directly from three tensors must register all three.
  ParameterList list(with_grad, without_grad, unspecified);
  ASSERT_EQ(list->size(), 3);
}
26 
TEST_F(ParameterListTest, isEmpty) {
  torch::Tensor param = torch::randn({1, 2}, torch::requires_grad(true));

  // A default-constructed list holds no parameters...
  ParameterList list;
  ASSERT_TRUE(list->is_empty());

  // ...and is no longer empty once a single tensor has been appended.
  list->append(param);
  ASSERT_FALSE(list->is_empty());
  ASSERT_EQ(list->size(), 1);
}
35 
TEST_F(ParameterListTest, PushBackAddsAnElement) {
  ParameterList list;
  torch::Tensor first = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor second = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor third = torch::randn({1, 2});
  torch::Tensor fourth = torch::randn({1, 2, 3});

  // Nothing has been appended yet.
  ASSERT_EQ(list->size(), 0);
  ASSERT_TRUE(list->is_empty());

  // Every append must grow the list by exactly one entry. The inputs differ
  // in shape and requires_grad setting, and none of that should matter.
  const std::vector<torch::Tensor> pending = {first, second, third, fourth};
  std::size_t expected_size = 0;
  for (const auto& tensor : pending) {
    list->append(tensor);
    ++expected_size;
    ASSERT_EQ(list->size(), expected_size);
  }
}
TEST_F(ParameterListTest, ForEachLoop) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  ParameterList list(ta, tb, tc, td);
  std::vector<torch::Tensor> params = {ta, tb, tc, td};
  ASSERT_EQ(list->size(), 4);

  // Range-for over the list must visit every stored parameter, in insertion
  // order, exactly once.
  size_t idx = 0;
  for (const auto& pair : *list) {
    // Guard the index so an over-long iteration fails cleanly instead of
    // reading past the end of `params`.
    ASSERT_LT(idx, params.size());
    ASSERT_TRUE(torch::all(torch::eq(pair.value(), params[idx])).item<bool>());
    ++idx;
  }
  // Without this, an iterator that stopped early would pass silently.
  ASSERT_EQ(idx, params.size());
}
67 
TEST_F(ParameterListTest, AccessWithAt) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  std::vector<torch::Tensor> params = {ta, tb, tc, td};

  ParameterList list;
  for (const auto& param : params) {
    list->append(param);
  }
  ASSERT_EQ(list->size(), 4);

  // at() returns the parameter stored at each valid index.
  for (const auto i : c10::irange(params.size())) {
    ASSERT_TRUE(torch::all(torch::eq(list->at(i), params[i])).item<bool>());
  }

  // operator[] returns the same parameters for every valid index.
  for (const auto i : c10::irange(params.size())) {
    ASSERT_TRUE(torch::all(torch::eq(list[i], params[i])).item<bool>());
  }

  // Both accessors throw for a bad index. The checks start at size() itself,
  // the first invalid index, so an off-by-one in the bounds check is caught.
  ASSERT_THROWS_WITH(list->at(params.size()), "Index out of range");
  ASSERT_THROWS_WITH(list->at(params.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(list->at(params.size() + 100), "Index out of range");
  ASSERT_THROWS_WITH(list[params.size()], "Index out of range");
  ASSERT_THROWS_WITH(list[params.size() + 1], "Index out of range");
}
95 
TEST_F(ParameterListTest, ExtendPushesParametersFromOtherParameterList) {
  // Element-wise equality of two tensors, collapsed to a single bool.
  const auto same = [](const torch::Tensor& lhs, const torch::Tensor& rhs) {
    return torch::all(torch::eq(lhs, rhs)).item<bool>();
  };

  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  torch::Tensor te = torch::randn({1, 2});
  torch::Tensor tf = torch::randn({1, 2, 3});
  ParameterList first_list(ta, tb);
  ParameterList second_list(tc, td);

  // Extending from another ParameterList appends its entries after the
  // existing ones, preserving order...
  first_list->extend(*second_list);
  ASSERT_EQ(first_list->size(), 4);
  ASSERT_TRUE(same(first_list[0], ta));
  ASSERT_TRUE(same(first_list[1], tb));
  ASSERT_TRUE(same(first_list[2], tc));
  ASSERT_TRUE(same(first_list[3], td));

  // ...and leaves the source list as it was.
  ASSERT_EQ(second_list->size(), 2);
  ASSERT_TRUE(same(second_list[0], tc));
  ASSERT_TRUE(same(second_list[1], td));

  // Extending from a plain std::vector of tensors behaves the same way.
  std::vector<torch::Tensor> extra = {te, tf};
  second_list->extend(extra);
  ASSERT_EQ(second_list->size(), 4);
  ASSERT_TRUE(same(second_list[0], tc));
  ASSERT_TRUE(same(second_list[1], td));
  ASSERT_TRUE(same(second_list[2], te));
  ASSERT_TRUE(same(second_list[3], tf));
}
126 
TEST_F(ParameterListTest, PrettyPrintParameterList) {
  torch::Tensor p0 = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor p1 = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor p2 = torch::randn({1, 2});
  ParameterList list(p0, p1, p2);

  // One line per entry, keyed by its index. The requires_grad setting is not
  // part of the output, so the three lines differ only in their key.
  const char* const expected =
      "torch::nn::ParameterList(\n"
      "(0): Parameter containing: [Float of size [1, 2]]\n"
      "(1): Parameter containing: [Float of size [1, 2]]\n"
      "(2): Parameter containing: [Float of size [1, 2]]\n"
      ")";
  ASSERT_EQ(c10::str(list), expected);
}
140 
TEST_F(ParameterListTest, IncrementAdd) {
  torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
  torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
  torch::Tensor tc = torch::randn({1, 2});
  torch::Tensor td = torch::randn({1, 2, 3});
  torch::Tensor te = torch::randn({1, 2});
  torch::Tensor tf = torch::randn({1, 2, 3});
  ParameterList listA(ta, tb, tc);
  ParameterList listB(td, te, tf);
  std::vector<torch::Tensor> tensors{ta, tb, tc, td, te, tf};

  // `+=` appends every entry of listB to the end of listA, in order.
  *listA += *listB;

  // Check the size first, so a short result fails here with a clear message
  // rather than through an "Index out of range" exception from operator[].
  ASSERT_EQ(listA->size(), tensors.size());
  // The right-hand operand must not be modified.
  ASSERT_EQ(listB->size(), 3);

  for (const auto i : c10::irange(tensors.size())) {
    ASSERT_TRUE(torch::all(torch::eq(listA[i], tensors[i])).item<bool>());
  }

  // The combined entries are also visible, in the same order, through the
  // module's (non-recursive) named_parameters().
  size_t idx = 0;
  for (const auto& item : listA->named_parameters(false)) {
    ASSERT_LT(idx, tensors.size());
    ASSERT_TRUE(
        torch::all(torch::eq(item.value(), tensors[idx])).item<bool>());
    ++idx;
  }
  ASSERT_EQ(idx, tensors.size());
}
164