1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
8*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker template <typename R, typename Func>
test_RNN_xor(Func && model_maker,bool cuda=false)11*da0073e9SAndroid Build Coastguard Worker bool test_RNN_xor(Func&& model_maker, bool cuda = false) {
12*da0073e9SAndroid Build Coastguard Worker torch::manual_seed(0);
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker auto nhid = 32;
15*da0073e9SAndroid Build Coastguard Worker auto model = std::make_shared<SimpleContainer>();
16*da0073e9SAndroid Build Coastguard Worker auto l1 = model->add(Linear(1, nhid), "l1");
17*da0073e9SAndroid Build Coastguard Worker auto rnn_model = model_maker(nhid);
18*da0073e9SAndroid Build Coastguard Worker auto rnn = model->add(rnn_model, "rnn");
19*da0073e9SAndroid Build Coastguard Worker auto nout = nhid;
20*da0073e9SAndroid Build Coastguard Worker if (rnn_model.get()->options_base.proj_size() > 0) {
21*da0073e9SAndroid Build Coastguard Worker nout = rnn_model.get()->options_base.proj_size();
22*da0073e9SAndroid Build Coastguard Worker }
23*da0073e9SAndroid Build Coastguard Worker auto lo = model->add(Linear(nout, 1), "lo");
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker torch::optim::Adam optimizer(model->parameters(), 1e-2);
26*da0073e9SAndroid Build Coastguard Worker auto forward_op = [&](torch::Tensor x) {
27*da0073e9SAndroid Build Coastguard Worker auto T = x.size(0);
28*da0073e9SAndroid Build Coastguard Worker auto B = x.size(1);
29*da0073e9SAndroid Build Coastguard Worker x = x.view({T * B, 1});
30*da0073e9SAndroid Build Coastguard Worker x = l1->forward(x).view({T, B, nhid}).tanh_();
31*da0073e9SAndroid Build Coastguard Worker x = std::get<0>(rnn->forward(x))[T - 1];
32*da0073e9SAndroid Build Coastguard Worker x = lo->forward(x);
33*da0073e9SAndroid Build Coastguard Worker return x;
34*da0073e9SAndroid Build Coastguard Worker };
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker if (cuda) {
37*da0073e9SAndroid Build Coastguard Worker model->to(torch::kCUDA);
38*da0073e9SAndroid Build Coastguard Worker }
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker float running_loss = 1;
41*da0073e9SAndroid Build Coastguard Worker int epoch = 0;
42*da0073e9SAndroid Build Coastguard Worker auto max_epoch = 1500;
43*da0073e9SAndroid Build Coastguard Worker while (running_loss > 1e-2) {
44*da0073e9SAndroid Build Coastguard Worker auto bs = 16U;
45*da0073e9SAndroid Build Coastguard Worker auto nlen = 5U;
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker const auto backend = cuda ? torch::kCUDA : torch::kCPU;
48*da0073e9SAndroid Build Coastguard Worker auto inputs =
49*da0073e9SAndroid Build Coastguard Worker torch::rand({nlen, bs, 1}, backend).round().to(torch::kFloat32);
50*da0073e9SAndroid Build Coastguard Worker auto labels = inputs.sum(0).detach();
51*da0073e9SAndroid Build Coastguard Worker inputs.set_requires_grad(true);
52*da0073e9SAndroid Build Coastguard Worker auto outputs = forward_op(inputs);
53*da0073e9SAndroid Build Coastguard Worker torch::Tensor loss = torch::mse_loss(outputs, labels);
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad();
56*da0073e9SAndroid Build Coastguard Worker loss.backward();
57*da0073e9SAndroid Build Coastguard Worker optimizer.step();
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions)
60*da0073e9SAndroid Build Coastguard Worker running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
61*da0073e9SAndroid Build Coastguard Worker if (epoch > max_epoch) {
62*da0073e9SAndroid Build Coastguard Worker return false;
63*da0073e9SAndroid Build Coastguard Worker }
64*da0073e9SAndroid Build Coastguard Worker epoch++;
65*da0073e9SAndroid Build Coastguard Worker }
66*da0073e9SAndroid Build Coastguard Worker return true;
67*da0073e9SAndroid Build Coastguard Worker };
68*da0073e9SAndroid Build Coastguard Worker
check_lstm_sizes(std::tuple<torch::Tensor,std::tuple<torch::Tensor,torch::Tensor>> lstm_output)69*da0073e9SAndroid Build Coastguard Worker void check_lstm_sizes(
70*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>
71*da0073e9SAndroid Build Coastguard Worker lstm_output) {
72*da0073e9SAndroid Build Coastguard Worker // Expect the LSTM to have 64 outputs and 3 layers, with an input of batch
73*da0073e9SAndroid Build Coastguard Worker // 10 and 16 time steps (10 x 16 x n)
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker torch::Tensor output = std::get<0>(lstm_output);
76*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, torch::Tensor> state = std::get<1>(lstm_output);
77*da0073e9SAndroid Build Coastguard Worker torch::Tensor hx = std::get<0>(state);
78*da0073e9SAndroid Build Coastguard Worker torch::Tensor cx = std::get<1>(state);
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(output.ndimension(), 3);
81*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(output.size(0), 10);
82*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(output.size(1), 16);
83*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(output.size(2), 64);
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(hx.ndimension(), 3);
86*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(hx.size(0), 3); // layers
87*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(hx.size(1), 16); // Batchsize
88*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(hx.size(2), 64); // 64 hidden dims
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(cx.ndimension(), 3);
91*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(cx.size(0), 3); // layers
92*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(cx.size(1), 16); // Batchsize
93*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(cx.size(2), 64); // 64 hidden dims
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker // Something is in the hiddens
96*da0073e9SAndroid Build Coastguard Worker ASSERT_GT(hx.norm().item<float>(), 0);
97*da0073e9SAndroid Build Coastguard Worker ASSERT_GT(cx.norm().item<float>(), 0);
98*da0073e9SAndroid Build Coastguard Worker }
99*da0073e9SAndroid Build Coastguard Worker
check_lstm_sizes_proj(std::tuple<torch::Tensor,std::tuple<torch::Tensor,torch::Tensor>> lstm_output)100*da0073e9SAndroid Build Coastguard Worker void check_lstm_sizes_proj(
101*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>
102*da0073e9SAndroid Build Coastguard Worker lstm_output) {
103*da0073e9SAndroid Build Coastguard Worker // Expect the LSTM to have 32 outputs and 3 layers, with an input of batch
104*da0073e9SAndroid Build Coastguard Worker // 10 and 16 time steps (10 x 16 x n)
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker torch::Tensor output = std::get<0>(lstm_output);
107*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, torch::Tensor> state = std::get<1>(lstm_output);
108*da0073e9SAndroid Build Coastguard Worker torch::Tensor hx = std::get<0>(state);
109*da0073e9SAndroid Build Coastguard Worker torch::Tensor cx = std::get<1>(state);
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(output.ndimension(), 3);
112*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(output.size(0), 10);
113*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(output.size(1), 16);
114*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(output.size(2), 32);
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(hx.ndimension(), 3);
117*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(hx.size(0), 3); // layers
118*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(hx.size(1), 16); // Batchsize
119*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(hx.size(2), 32); // 32 hidden dims
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(cx.ndimension(), 3);
122*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(cx.size(0), 3); // layers
123*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(cx.size(1), 16); // Batchsize
124*da0073e9SAndroid Build Coastguard Worker ASSERT_EQ(cx.size(2), 64); // 64 cell dims
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker // Something is in the hiddens
127*da0073e9SAndroid Build Coastguard Worker ASSERT_GT(hx.norm().item<float>(), 0);
128*da0073e9SAndroid Build Coastguard Worker ASSERT_GT(cx.norm().item<float>(), 0);
129*da0073e9SAndroid Build Coastguard Worker }
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker struct RNNTest : torch::test::SeedingFixture {};
132*da0073e9SAndroid Build Coastguard Worker
// Shapes from a 3-layer LSTM must be correct on the first forward pass and
// after feeding the returned state back in; the re-fed state must also
// produce different hidden values.
TEST_F(RNNTest, CheckOutputSizes) {
  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
  // Input size is: sequence length, batch size, input size
  auto x = torch::randn({10, 16, 128}, torch::requires_grad());
  auto first = model->forward(x);
  auto y = x.mean();

  y.backward();
  check_lstm_sizes(first);

  // Second pass, seeded with the state from the first pass.
  auto second = model->forward(x, std::get<1>(first));
  check_lstm_sizes(second);

  auto [first_hx, first_cx] = std::get<1>(first);
  auto [second_hx, second_cx] = std::get<1>(second);

  torch::Tensor diff = torch::cat({second_hx, second_cx}, 0) -
      torch::cat({first_hx, first_cx}, 0);

  // Hiddens changed
  ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
}
159*da0073e9SAndroid Build Coastguard Worker
// Same as CheckOutputSizes, but with projections enabled (proj_size=32).
// hx and cx have different sizes here, so the change check is done per
// tensor instead of on a concatenation.
TEST_F(RNNTest, CheckOutputSizesProj) {
  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32));
  // Input size is: sequence length, batch size, input size
  auto x = torch::randn({10, 16, 128}, torch::requires_grad());
  auto first = model->forward(x);
  auto y = x.mean();

  y.backward();
  check_lstm_sizes_proj(first);

  // Second pass, seeded with the state from the first pass.
  auto second = model->forward(x, std::get<1>(first));
  check_lstm_sizes_proj(second);

  auto [first_hx, first_cx] = std::get<1>(first);
  auto [second_hx, second_cx] = std::get<1>(second);

  // Hiddens changed
  ASSERT_GT((second_hx - first_hx).abs().sum().item<float>(), 1e-3);
  ASSERT_GT((second_cx - first_cx).abs().sum().item<float>(), 1e-3);
}
186*da0073e9SAndroid Build Coastguard Worker
// Compares LSTM outputs against hard-coded reference values. Parameters and
// inputs are filled with deterministic ramps so the result is reproducible;
// the expected constants were presumably captured from the Python PyTorch
// implementation with the same fill (per the comment below) — confirm
// against the generating script if they ever need regeneration.
TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) {
  torch::manual_seed(0);
  // Make sure the outputs match pytorch outputs
  LSTM model(2, 2);
  // Fill every parameter tensor with the ramp i/size in [0, 1).
  // `size` is float on purpose: it keeps `i / size` a float division.
  for (auto& v : model->parameters()) {
    float size = v.numel();
    auto p = static_cast<float*>(v.storage().mutable_data());
    for (size_t i = 0; i < size; i++) {
      p[i] = i / size;
    }
  }

  // Input: 3 time steps, batch 4, 2 features, filled with a descending ramp.
  auto x = torch::empty({3, 4, 2}, torch::requires_grad());
  float size = x.numel();
  auto p = static_cast<float*>(x.storage().mutable_data());
  for (size_t i = 0; i < size; i++) {
    p[i] = (size - i) / size;
  }

  auto out = model->forward(x);
  ASSERT_EQ(std::get<0>(out).ndimension(), 3);
  ASSERT_EQ(std::get<0>(out).size(0), 3);
  ASSERT_EQ(std::get<0>(out).size(1), 4);
  ASSERT_EQ(std::get<0>(out).size(2), 2);

  // Flatten the (3 x 4 x 2) output and compare element-wise to the
  // reference values within 1e-3.
  auto flat = std::get<0>(out).view(3 * 4 * 2);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
                   0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
                   0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
                   0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
  for (size_t i = 0; i < 3 * 4 * 2; i++) {
    ASSERT_LT(std::abs(flat[i].item<float>() - c_out[i]), 1e-3);
  }

  // Check the final hidden and cell states as well.
  auto hx = std::get<0>(std::get<1>(out));
  auto cx = std::get<1>(std::get<1>(out));

  ASSERT_EQ(hx.ndimension(), 3); // layers x B x 2
  ASSERT_EQ(hx.size(0), 1);
  ASSERT_EQ(hx.size(1), 4);
  ASSERT_EQ(hx.size(2), 2);

  ASSERT_EQ(cx.ndimension(), 3); // layers x B x 2
  ASSERT_EQ(cx.size(0), 1);
  ASSERT_EQ(cx.size(1), 4);
  ASSERT_EQ(cx.size(2), 2);

  // hx and cx concatenated: 8 hidden values followed by 8 cell values.
  flat = torch::cat({hx, cx}, 0).view(16);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  float h_out[] = {
      0.7889,
      0.9003,
      0.7769,
      0.8905,
      0.7635,
      0.8794,
      0.7484,
      0.8666,
      1.1647,
      1.6106,
      1.1425,
      1.5726,
      1.1187,
      1.5329,
      1.0931,
      1.4911};
  for (size_t i = 0; i < 16; i++) {
    ASSERT_LT(std::abs(flat[i].item<float>() - h_out[i]), 1e-3);
  }
}
258*da0073e9SAndroid Build Coastguard Worker
// A 2-layer LSTM must learn the sequence-sum task on CPU.
TEST_F(RNNTest, EndToEndLSTM) {
  const auto make_model = [](int size) {
    return LSTM(LSTMOptions(size, size).num_layers(2));
  };
  ASSERT_TRUE(test_RNN_xor<LSTM>(make_model));
}
263*da0073e9SAndroid Build Coastguard Worker
// A 2-layer LSTM with projections must learn the sequence-sum task on CPU.
TEST_F(RNNTest, EndToEndLSTMProj) {
  const auto make_model = [](int size) {
    return LSTM(LSTMOptions(size, size).num_layers(2).proj_size(size / 2));
  };
  ASSERT_TRUE(test_RNN_xor<LSTM>(make_model));
}
269*da0073e9SAndroid Build Coastguard Worker
// A 2-layer GRU must learn the sequence-sum task on CPU.
TEST_F(RNNTest, EndToEndGRU) {
  const auto make_model = [](int size) {
    return GRU(GRUOptions(size, size).num_layers(2));
  };
  ASSERT_TRUE(test_RNN_xor<GRU>(make_model));
}
274*da0073e9SAndroid Build Coastguard Worker
// A 2-layer ReLU RNN must learn the sequence-sum task on CPU.
TEST_F(RNNTest, EndToEndRNNRelu) {
  const auto make_model = [](int size) {
    return RNN(RNNOptions(size, size).nonlinearity(torch::kReLU).num_layers(2));
  };
  ASSERT_TRUE(test_RNN_xor<RNN>(make_model));
}
280*da0073e9SAndroid Build Coastguard Worker
// A 2-layer tanh RNN must learn the sequence-sum task on CPU.
TEST_F(RNNTest, EndToEndRNNTanh) {
  const auto make_model = [](int size) {
    return RNN(RNNOptions(size, size).nonlinearity(torch::kTanh).num_layers(2));
  };
  ASSERT_TRUE(test_RNN_xor<RNN>(make_model));
}
286*da0073e9SAndroid Build Coastguard Worker
// CUDA variant of CheckOutputSizes: shapes must hold on the first pass and
// when the returned state is fed back in, and the hiddens must change.
TEST_F(RNNTest, Sizes_CUDA) {
  torch::manual_seed(0);
  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
  model->to(torch::kCUDA);
  auto x =
      torch::randn({10, 16, 128}, torch::requires_grad().device(torch::kCUDA));
  auto first = model->forward(x);
  auto y = x.mean();

  y.backward();
  check_lstm_sizes(first);

  // Second pass, seeded with the state from the first pass.
  auto second = model->forward(x, std::get<1>(first));
  check_lstm_sizes(second);

  auto [first_hx, first_cx] = std::get<1>(first);
  auto [second_hx, second_cx] = std::get<1>(second);

  torch::Tensor diff = torch::cat({second_hx, second_cx}, 0) -
      torch::cat({first_hx, first_cx}, 0);

  // Hiddens changed
  ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
}
315*da0073e9SAndroid Build Coastguard Worker
// CUDA variant of CheckOutputSizesProj. hx and cx have different sizes, so
// the change check is done per tensor.
TEST_F(RNNTest, SizesProj_CUDA) {
  torch::manual_seed(0);
  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32));
  model->to(torch::kCUDA);
  auto x =
      torch::randn({10, 16, 128}, torch::requires_grad().device(torch::kCUDA));
  auto first = model->forward(x);
  auto y = x.mean();

  y.backward();
  check_lstm_sizes_proj(first);

  // Second pass, seeded with the state from the first pass.
  auto second = model->forward(x, std::get<1>(first));
  check_lstm_sizes_proj(second);

  auto [first_hx, first_cx] = std::get<1>(first);
  auto [second_hx, second_cx] = std::get<1>(second);

  // Hiddens changed
  ASSERT_GT((second_hx - first_hx).abs().sum().item<float>(), 1e-3);
  ASSERT_GT((second_cx - first_cx).abs().sum().item<float>(), 1e-3);
}
344*da0073e9SAndroid Build Coastguard Worker
// A 2-layer LSTM must learn the sequence-sum task on CUDA.
TEST_F(RNNTest, EndToEndLSTM_CUDA) {
  const auto make_model = [](int size) {
    return LSTM(LSTMOptions(size, size).num_layers(2));
  };
  ASSERT_TRUE(test_RNN_xor<LSTM>(make_model, /*cuda=*/true));
}
349*da0073e9SAndroid Build Coastguard Worker
// A 2-layer LSTM with projections must learn the sequence-sum task on CUDA.
TEST_F(RNNTest, EndToEndLSTMProj_CUDA) {
  const auto make_model = [](int size) {
    return LSTM(LSTMOptions(size, size).num_layers(2).proj_size(size / 2));
  };
  ASSERT_TRUE(test_RNN_xor<LSTM>(make_model, /*cuda=*/true));
}
357*da0073e9SAndroid Build Coastguard Worker
// A 2-layer GRU must learn the sequence-sum task on CUDA.
TEST_F(RNNTest, EndToEndGRU_CUDA) {
  const auto make_model = [](int size) {
    return GRU(GRUOptions(size, size).num_layers(2));
  };
  ASSERT_TRUE(test_RNN_xor<GRU>(make_model, /*cuda=*/true));
}
362*da0073e9SAndroid Build Coastguard Worker
// A 2-layer ReLU RNN must learn the sequence-sum task on CUDA.
TEST_F(RNNTest, EndToEndRNNRelu_CUDA) {
  const auto make_model = [](int size) {
    return RNN(RNNOptions(size, size).nonlinearity(torch::kReLU).num_layers(2));
  };
  ASSERT_TRUE(test_RNN_xor<RNN>(make_model, /*cuda=*/true));
}
// A 2-layer tanh RNN must learn the sequence-sum task on CUDA.
TEST_F(RNNTest, EndToEndRNNTanh_CUDA) {
  const auto make_model = [](int size) {
    return RNN(RNNOptions(size, size).nonlinearity(torch::kTanh).num_layers(2));
  };
  ASSERT_TRUE(test_RNN_xor<RNN>(make_model, /*cuda=*/true));
}
377*da0073e9SAndroid Build Coastguard Worker
// Pretty-printing of LSTM/GRU/RNN must include every configured option;
// proj_size appears only when it is set.
TEST_F(RNNTest, PrettyPrintRNNs) {
  const auto lstm_opts = LSTMOptions(128, 64).num_layers(3).dropout(0.2);
  ASSERT_EQ(
      c10::str(LSTM(lstm_opts)),
      "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");

  const auto lstm_proj_opts =
      LSTMOptions(128, 64).num_layers(3).dropout(0.2).proj_size(32);
  ASSERT_EQ(
      c10::str(LSTM(lstm_proj_opts)),
      "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false, proj_size=32)");

  const auto gru_opts = GRUOptions(128, 64).num_layers(3).dropout(0.5);
  ASSERT_EQ(
      c10::str(GRU(gru_opts)),
      "torch::nn::GRU(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.5, bidirectional=false)");

  const auto rnn_opts =
      RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh);
  ASSERT_EQ(
      c10::str(RNN(rnn_opts)),
      "torch::nn::RNN(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");
}
394*da0073e9SAndroid Build Coastguard Worker
395*da0073e9SAndroid Build Coastguard Worker // This test assures that flatten_parameters does not crash,
396*da0073e9SAndroid Build Coastguard Worker // when bidirectional is set to true
397*da0073e9SAndroid Build Coastguard Worker // https://github.com/pytorch/pytorch/issues/19545
// flatten_parameters must not crash when bidirectional is set to true.
// Regression test for https://github.com/pytorch/pytorch/issues/19545
TEST_F(RNNTest, BidirectionalFlattenParameters) {
  GRU bidirectional_gru(GRUOptions(100, 256).num_layers(2).bidirectional(true));
  bidirectional_gru->flatten_parameters();
}
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker template <typename Impl>
copyParameters(torch::nn::ModuleHolder<Impl> & target,std::string t_suffix,const torch::nn::ModuleHolder<Impl> & source,std::string s_suffix)404*da0073e9SAndroid Build Coastguard Worker void copyParameters(
405*da0073e9SAndroid Build Coastguard Worker torch::nn::ModuleHolder<Impl>& target,
406*da0073e9SAndroid Build Coastguard Worker std::string t_suffix,
407*da0073e9SAndroid Build Coastguard Worker const torch::nn::ModuleHolder<Impl>& source,
408*da0073e9SAndroid Build Coastguard Worker std::string s_suffix) {
409*da0073e9SAndroid Build Coastguard Worker at::NoGradGuard guard;
410*da0073e9SAndroid Build Coastguard Worker target->named_parameters()["weight_ih_l" + t_suffix].copy_(
411*da0073e9SAndroid Build Coastguard Worker source->named_parameters()["weight_ih_l" + s_suffix]);
412*da0073e9SAndroid Build Coastguard Worker target->named_parameters()["weight_hh_l" + t_suffix].copy_(
413*da0073e9SAndroid Build Coastguard Worker source->named_parameters()["weight_hh_l" + s_suffix]);
414*da0073e9SAndroid Build Coastguard Worker target->named_parameters()["bias_ih_l" + t_suffix].copy_(
415*da0073e9SAndroid Build Coastguard Worker source->named_parameters()["bias_ih_l" + s_suffix]);
416*da0073e9SAndroid Build Coastguard Worker target->named_parameters()["bias_hh_l" + t_suffix].copy_(
417*da0073e9SAndroid Build Coastguard Worker source->named_parameters()["bias_hh_l" + s_suffix]);
418*da0073e9SAndroid Build Coastguard Worker }
419*da0073e9SAndroid Build Coastguard Worker
gru_output_to_device(std::tuple<torch::Tensor,torch::Tensor> gru_output,torch::Device device)420*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, torch::Tensor> gru_output_to_device(
421*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, torch::Tensor> gru_output,
422*da0073e9SAndroid Build Coastguard Worker torch::Device device) {
423*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(
424*da0073e9SAndroid Build Coastguard Worker std::get<0>(gru_output).to(device), std::get<1>(gru_output).to(device));
425*da0073e9SAndroid Build Coastguard Worker }
426*da0073e9SAndroid Build Coastguard Worker
427*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>
lstm_output_to_device(std::tuple<torch::Tensor,std::tuple<torch::Tensor,torch::Tensor>> lstm_output,torch::Device device)428*da0073e9SAndroid Build Coastguard Worker lstm_output_to_device(
429*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>
430*da0073e9SAndroid Build Coastguard Worker lstm_output,
431*da0073e9SAndroid Build Coastguard Worker torch::Device device) {
432*da0073e9SAndroid Build Coastguard Worker auto hidden_states = std::get<1>(lstm_output);
433*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(
434*da0073e9SAndroid Build Coastguard Worker std::get<0>(lstm_output).to(device),
435*da0073e9SAndroid Build Coastguard Worker std::make_tuple(
436*da0073e9SAndroid Build Coastguard Worker std::get<0>(hidden_states).to(device),
437*da0073e9SAndroid Build Coastguard Worker std::get<1>(hidden_states).to(device)));
438*da0073e9SAndroid Build Coastguard Worker }
439*da0073e9SAndroid Build Coastguard Worker
440*da0073e9SAndroid Build Coastguard Worker // This test is a port of python code introduced here:
441*da0073e9SAndroid Build Coastguard Worker // https://towardsdatascience.com/understanding-bidirectional-rnn-in-pytorch-5bd25a5dd66
442*da0073e9SAndroid Build Coastguard Worker // Reverse forward of bidirectional GRU should act
443*da0073e9SAndroid Build Coastguard Worker // as regular forward of unidirectional GRU
// Runs the check on CPU, or on CUDA when `cuda` is true: the reverse
// direction of a bidirectional GRU fed the forward sequence should match a
// unidirectional GRU fed the reversed sequence, once the unidirectional GRU
// borrows the bidirectional GRU's reverse-direction weights.
void BidirectionalGRUReverseForward(bool cuda) {
  auto opt = torch::TensorOptions()
                 .dtype(torch::kFloat32)
                 .requires_grad(false)
                 .device(cuda ? torch::kCUDA : torch::kCPU);
  // Shape {seq_len=5, batch=1, features=1}; sequence-major (batch_first=false).
  auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
  auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});

  auto gru_options = GRUOptions(1, 1).num_layers(1).batch_first(false);
  // NOTE(review): the fluent setters appear to mutate `gru_options` in place,
  // so the two constructions below rely on this exact ordering — confirm
  // before reusing this pattern elsewhere.
  GRU bi_grus{gru_options.bidirectional(true)};
  GRU reverse_gru{gru_options.bidirectional(false)};

  if (cuda) {
    bi_grus->to(torch::kCUDA);
    reverse_gru->to(torch::kCUDA);
  }

  // Now make sure the weights of the reverse gru layer match
  // ones of the (reversed) bidirectional's:
  copyParameters(reverse_gru, "0", bi_grus, "0_reverse");

  auto bi_output = bi_grus->forward(input);
  auto reverse_output = reverse_gru->forward(input_reversed);

  // Move both results to CPU before the element-wise comparisons below.
  if (cuda) {
    bi_output = gru_output_to_device(bi_output, torch::kCPU);
    reverse_output = gru_output_to_device(reverse_output, torch::kCPU);
  }

  // Both runs must produce sequences of the same length.
  ASSERT_EQ(
      std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
  auto size = std::get<0>(bi_output).size(0);
  // Feature slot 1 of the bidirectional output holds the reverse direction;
  // element i there is compared against element (size - 1 - i) of the
  // unidirectional run, using exact (not approximate) float equality.
  for (int i = 0; i < size; i++) {
    ASSERT_EQ(
        std::get<0>(bi_output)[i][0][1].item<float>(),
        std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
  }
  // The hidden states of the reversed GRUs sits
  // in the odd indices in the first dimension.
  ASSERT_EQ(
      std::get<1>(bi_output)[1][0][0].item<float>(),
      std::get<1>(reverse_output)[0][0][0].item<float>());
}
487*da0073e9SAndroid Build Coastguard Worker
// CPU variant of the bidirectional-GRU reverse-forward equivalence check.
TEST_F(RNNTest, BidirectionalGRUReverseForward) {
  const bool use_cuda = false;
  BidirectionalGRUReverseForward(use_cuda);
}
491*da0073e9SAndroid Build Coastguard Worker
// CUDA variant of the bidirectional-GRU reverse-forward equivalence check.
TEST_F(RNNTest, BidirectionalGRUReverseForward_CUDA) {
  const bool use_cuda = true;
  BidirectionalGRUReverseForward(use_cuda);
}
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker // Reverse forward of bidirectional LSTM should act
497*da0073e9SAndroid Build Coastguard Worker // as regular forward of unidirectional LSTM
// Runs the check on CPU, or on CUDA when `cuda` is true: the reverse
// direction of a bidirectional LSTM fed the forward sequence should match a
// unidirectional LSTM fed the reversed sequence, once the unidirectional
// LSTM borrows the bidirectional LSTM's reverse-direction weights.
void BidirectionalLSTMReverseForwardTest(bool cuda) {
  auto opt = torch::TensorOptions()
                 .dtype(torch::kFloat32)
                 .requires_grad(false)
                 .device(cuda ? torch::kCUDA : torch::kCPU);
  // Shape {seq_len=5, batch=1, features=1}; sequence-major (batch_first=false).
  auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
  auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});

  auto lstm_opt = LSTMOptions(1, 1).num_layers(1).batch_first(false);

  // NOTE(review): the fluent setters appear to mutate `lstm_opt` in place,
  // so the two constructions below rely on this exact ordering — confirm
  // before reusing this pattern elsewhere.
  LSTM bi_lstm{lstm_opt.bidirectional(true)};
  LSTM reverse_lstm{lstm_opt.bidirectional(false)};

  if (cuda) {
    bi_lstm->to(torch::kCUDA);
    reverse_lstm->to(torch::kCUDA);
  }

  // Now make sure the weights of the reverse lstm layer match
  // ones of the (reversed) bidirectional's:
  copyParameters(reverse_lstm, "0", bi_lstm, "0_reverse");

  auto bi_output = bi_lstm->forward(input);
  auto reverse_output = reverse_lstm->forward(input_reversed);

  // Move both results to CPU before the element-wise comparisons below.
  if (cuda) {
    bi_output = lstm_output_to_device(bi_output, torch::kCPU);
    reverse_output = lstm_output_to_device(reverse_output, torch::kCPU);
  }

  // Both runs must produce sequences of the same length.
  ASSERT_EQ(
      std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
  auto size = std::get<0>(bi_output).size(0);
  // Feature slot 1 of the bidirectional output holds the reverse direction;
  // element i there is compared against element (size - 1 - i) of the
  // unidirectional run, using exact (not approximate) float equality.
  for (int i = 0; i < size; i++) {
    ASSERT_EQ(
        std::get<0>(bi_output)[i][0][1].item<float>(),
        std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
  }
  // The hidden states of the reversed LSTM sits
  // in the odd indices in the first dimension.
  // Both tensors of the state tuple are checked.
  ASSERT_EQ(
      std::get<0>(std::get<1>(bi_output))[1][0][0].item<float>(),
      std::get<0>(std::get<1>(reverse_output))[0][0][0].item<float>());
  ASSERT_EQ(
      std::get<1>(std::get<1>(bi_output))[1][0][0].item<float>(),
      std::get<1>(std::get<1>(reverse_output))[0][0][0].item<float>());
}
545*da0073e9SAndroid Build Coastguard Worker
// CPU variant of the bidirectional-LSTM reverse-forward equivalence check.
TEST_F(RNNTest, BidirectionalLSTMReverseForward) {
  const bool use_cuda = false;
  BidirectionalLSTMReverseForwardTest(use_cuda);
}
549*da0073e9SAndroid Build Coastguard Worker
// CUDA variant of the bidirectional-LSTM reverse-forward equivalence check.
TEST_F(RNNTest, BidirectionalLSTMReverseForward_CUDA) {
  const bool use_cuda = true;
  BidirectionalLSTMReverseForwardTest(use_cuda);
}
553*da0073e9SAndroid Build Coastguard Worker
// Verifies that a bidirectional, 3-layer GRU produces numerically matching
// outputs on CPU and CUDA when both modules start from identical weights.
TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
  // Create two GRUs with the same options
  auto opt =
      GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
  GRU gru_cpu{opt};
  GRU gru_cuda{opt};

  // Copy weights and biases from CPU GRU to CUDA GRU
  // (the copy happens while both modules still live on the CPU; the CUDA
  // module is only moved to the device afterwards).
  {
    at::NoGradGuard guard;
    for (const auto& param : gru_cpu->named_parameters(/*recurse=*/false)) {
      gru_cuda->named_parameters()[param.key()].copy_(
          gru_cpu->named_parameters()[param.key()]);
    }
  }

  gru_cpu->flatten_parameters();
  gru_cuda->flatten_parameters();

  // Move GRU to CUDA
  gru_cuda->to(torch::kCUDA);

  // Create the same inputs
  // Shape {seq_len=3, batch=1, features=2}, one copy per device.
  auto input_opt =
      torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false);
  auto input_cpu =
      torch::tensor({1, 2, 3, 4, 5, 6}, input_opt).reshape({3, 1, 2});
  auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, input_opt)
                        .reshape({3, 1, 2})
                        .to(torch::kCUDA);

  // Call forward on both GRUs
  auto output_cpu = gru_cpu->forward(input_cpu);
  auto output_cuda = gru_cuda->forward(input_cuda);

  // Bring the CPU-side tuple into a comparable (CPU) form; the CUDA output
  // tensors are read element-wise via item<float>() below.
  output_cpu = gru_output_to_device(output_cpu, torch::kCPU);

  // Assert that the output and state are equal on CPU and CUDA
  // First the shapes must agree exactly ...
  ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
  for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
    ASSERT_EQ(
        std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
  }
  // ... then every element must match within a small tolerance, since CPU
  // and CUDA kernels may differ in floating-point rounding.
  for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
    for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
      for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
        ASSERT_NEAR(
            std::get<0>(output_cpu)[i][j][k].item<float>(),
            std::get<0>(output_cuda)[i][j][k].item<float>(),
            1e-5);
      }
    }
  }
}
608*da0073e9SAndroid Build Coastguard Worker
// Verifies that a bidirectional, 3-layer LSTM produces numerically matching
// outputs on CPU and CUDA when both modules start from identical weights.
TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
  // Create two LSTMs with the same options
  auto opt =
      LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
  LSTM lstm_cpu{opt};
  LSTM lstm_cuda{opt};

  // Copy weights and biases from CPU LSTM to CUDA LSTM
  // (the copy happens while both modules still live on the CPU; the CUDA
  // module is only moved to the device afterwards).
  {
    at::NoGradGuard guard;
    for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) {
      lstm_cuda->named_parameters()[param.key()].copy_(
          lstm_cpu->named_parameters()[param.key()]);
    }
  }

  lstm_cpu->flatten_parameters();
  lstm_cuda->flatten_parameters();

  // Move LSTM to CUDA
  lstm_cuda->to(torch::kCUDA);

  // Same input sequence on both devices: {seq_len=3, batch=1, features=2}.
  auto options =
      torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false);
  auto input_cpu =
      torch::tensor({1, 2, 3, 4, 5, 6}, options).reshape({3, 1, 2});
  auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, options)
                        .reshape({3, 1, 2})
                        .to(torch::kCUDA);

  // Call forward on both LSTMs
  auto output_cpu = lstm_cpu->forward(input_cpu);
  auto output_cuda = lstm_cuda->forward(input_cuda);

  // Bring the CPU-side tuple into a comparable (CPU) form; the CUDA output
  // tensors are read element-wise via item<float>() below.
  output_cpu = lstm_output_to_device(output_cpu, torch::kCPU);

  // Assert that the output and state are equal on CPU and CUDA
  // First the shapes must agree exactly ...
  ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
  for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
    ASSERT_EQ(
        std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
  }
  // ... then every element must match within a small tolerance, since CPU
  // and CUDA kernels may differ in floating-point rounding.
  for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
    for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
      for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
        ASSERT_NEAR(
            std::get<0>(output_cpu)[i][j][k].item<float>(),
            std::get<0>(output_cuda)[i][j][k].item<float>(),
            1e-5);
      }
    }
  }
}
662*da0073e9SAndroid Build Coastguard Worker
// Same CPU-vs-CUDA comparison as above, but for an LSTM with output
// projections enabled (proj_size=2), which exercises the extra
// weight_hr_l* projection parameters.
TEST_F(RNNTest, BidirectionalMultilayerLSTMProj_CPU_vs_CUDA) {
  // Create two LSTMs with the same options
  auto opt = LSTMOptions(2, 4)
                 .num_layers(3)
                 .batch_first(false)
                 .bidirectional(true)
                 .proj_size(2);
  LSTM lstm_cpu{opt};
  LSTM lstm_cuda{opt};

  // Copy weights and biases from CPU LSTM to CUDA LSTM
  // (the copy happens while both modules still live on the CPU; the CUDA
  // module is only moved to the device afterwards).
  {
    at::NoGradGuard guard;
    for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) {
      lstm_cuda->named_parameters()[param.key()].copy_(
          lstm_cpu->named_parameters()[param.key()]);
    }
  }

  lstm_cpu->flatten_parameters();
  lstm_cuda->flatten_parameters();

  // Move LSTM to CUDA
  lstm_cuda->to(torch::kCUDA);

  // Same input sequence on both devices: {seq_len=3, batch=1, features=2}.
  auto options =
      torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false);
  auto input_cpu =
      torch::tensor({1, 2, 3, 4, 5, 6}, options).reshape({3, 1, 2});
  auto input_cuda = torch::tensor({1, 2, 3, 4, 5, 6}, options)
                        .reshape({3, 1, 2})
                        .to(torch::kCUDA);

  // Call forward on both LSTMs
  auto output_cpu = lstm_cpu->forward(input_cpu);
  auto output_cuda = lstm_cuda->forward(input_cuda);

  // Bring the CPU-side tuple into a comparable (CPU) form; the CUDA output
  // tensors are read element-wise via item<float>() below.
  output_cpu = lstm_output_to_device(output_cpu, torch::kCPU);

  // Assert that the output and state are equal on CPU and CUDA
  // First the shapes must agree exactly ...
  ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
  for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
    ASSERT_EQ(
        std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
  }
  // ... then every element must match within a small tolerance, since CPU
  // and CUDA kernels may differ in floating-point rounding.
  for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
    for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
      for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
        ASSERT_NEAR(
            std::get<0>(output_cpu)[i][j][k].item<float>(),
            std::get<0>(output_cuda)[i][j][k].item<float>(),
            1e-5);
      }
    }
  }
}
719*da0073e9SAndroid Build Coastguard Worker
// Each of RNN, LSTM, and GRU must accept a PackedSequence through
// forward_with_packed_input and, under a fixed seed, reproduce the recorded
// reference output — both with and without the optional initial-state
// argument.
TEST_F(RNNTest, UsePackedSequenceAsInput) {
  {
    // Seed so the randomly initialized weights (and thus the expected
    // outputs below) are reproducible.
    torch::manual_seed(0);
    auto m = RNN(2, 3);
    torch::nn::utils::rnn::PackedSequence packed_input =
        torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
    auto rnn_output = m->forward_with_packed_input(packed_input);
    // Reference values recorded for seed 0.
    auto expected_output = torch::tensor(
        {{-0.0645, -0.7274, 0.4531},
         {-0.3970, -0.6950, 0.6009},
         {-0.3877, -0.7310, 0.6806}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));

    // Test passing optional argument to `RNN::forward_with_packed_input`
    rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
  }
  {
    torch::manual_seed(0);
    auto m = LSTM(2, 3);
    torch::nn::utils::rnn::PackedSequence packed_input =
        torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
    auto rnn_output = m->forward_with_packed_input(packed_input);
    // Reference values recorded for seed 0.
    auto expected_output = torch::tensor(
        {{-0.2693, -0.1240, 0.0744},
         {-0.3889, -0.1919, 0.1183},
         {-0.4425, -0.2314, 0.1386}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));

    // Test passing optional argument to `LSTM::forward_with_packed_input`
    rnn_output = m->forward_with_packed_input(packed_input, torch::nullopt);
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
  }
  {
    torch::manual_seed(0);
    auto m = GRU(2, 3);
    torch::nn::utils::rnn::PackedSequence packed_input =
        torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
    auto rnn_output = m->forward_with_packed_input(packed_input);
    // Reference values recorded for seed 0.
    auto expected_output = torch::tensor(
        {{-0.1134, 0.0467, 0.2336},
         {-0.1189, 0.0502, 0.2960},
         {-0.1138, 0.0484, 0.3110}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));

    // Test passing optional argument to `GRU::forward_with_packed_input`
    rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
  }
}
776*da0073e9SAndroid Build Coastguard Worker
// Each module family (RNN, LSTM, GRU) must reject a zero hidden size and a
// zero layer count at construction time, with a descriptive error message.
TEST_F(RNNTest, CheckErrorInfos) {
  {
    const auto zero_hidden = torch::nn::RNNOptions(1, 0).num_layers(1);
    ASSERT_THROWS_WITH(
        RNN(zero_hidden), "hidden_size must be greater than zero");

    const auto zero_layers = torch::nn::RNNOptions(1, 1).num_layers(0);
    ASSERT_THROWS_WITH(
        RNN(zero_layers), "num_layers must be greater than zero");
  }
  {
    const auto zero_hidden = torch::nn::LSTMOptions(1, 0).num_layers(1);
    ASSERT_THROWS_WITH(
        LSTM(zero_hidden), "hidden_size must be greater than zero");

    const auto zero_layers = torch::nn::LSTMOptions(1, 1).num_layers(0);
    ASSERT_THROWS_WITH(
        LSTM(zero_layers), "num_layers must be greater than zero");
  }
  {
    const auto zero_hidden = torch::nn::GRUOptions(1, 0).num_layers(1);
    ASSERT_THROWS_WITH(
        GRU(zero_hidden), "hidden_size must be greater than zero");

    const auto zero_layers = torch::nn::GRUOptions(1, 1).num_layers(0);
    ASSERT_THROWS_WITH(
        GRU(zero_layers), "num_layers must be greater than zero");
  }
}
800*da0073e9SAndroid Build Coastguard Worker
801*da0073e9SAndroid Build Coastguard Worker // This test assures that pad_packed_sequence does not crash when packed with
802*da0073e9SAndroid Build Coastguard Worker // cuda tensors, https://github.com/pytorch/pytorch/issues/115027
TEST_F(RNNTest, CheckPadPackedSequenceWithCudaTensors_CUDA) {
  // Create input on the GPU, sample 5x5
  auto input = torch::randn({5, 5}).to(at::ScalarType::Float).cuda();
  // Five sequences, each of full length 5 (no actual padding needed).
  auto lengths = torch::full({5}, 5);

  auto packed =
      torch::nn::utils::rnn::pack_padded_sequence(input, lengths, false, false);

  // The call itself is the assertion: unpacking a sequence packed from CUDA
  // tensors must not crash (regression test for issue #115027). The result
  // is deliberately unused — it was previously (confusingly) named `error`.
  [[maybe_unused]] auto padded =
      torch::nn::utils::rnn::pad_packed_sequence(packed);
}
813