xref: /aosp_15_r20/external/executorch/extension/training/optimizer/test/sgd_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/extension/training/optimizer/sgd.h>
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12 #include <executorch/runtime/core/span.h>
13 #include <executorch/runtime/platform/runtime.h>
14 
15 #include <gtest/gtest.h>
16 
17 // @lint-ignore-every CLANGTIDY facebook-hte-CArray
18 
19 using namespace ::testing;
20 using exec_aten::ScalarType;
21 using exec_aten::Tensor;
22 using ::executorch::extension::training::optimizer::SGD;
23 using ::executorch::extension::training::optimizer::SGDOptions;
24 using ::executorch::extension::training::optimizer::SGDParamState;
25 using ::executorch::runtime::Error;
26 using ::executorch::runtime::testing::TensorFactory;
27 
28 class SGDOptimizerTest : public ::testing::Test {
29  protected:
SetUp()30   void SetUp() override {
31     torch::executor::runtime_init();
32   }
33 };
34 
// Builds an SGDParamState from a 2x2 int tensor and checks that
// momentum_buffer() exposes the same four element values, in order.
TEST_F(SGDOptimizerTest, SGDParamStateTest) {
  TensorFactory<ScalarType::Int> int_factory;
  Tensor buffer = int_factory.make({2, 2}, {1, 2, 3, 4});
  SGDParamState param_state(buffer);

  const auto* stored =
      param_state.momentum_buffer().const_data_ptr<int32_t>();

  // ASSERT_EQ aborts the test on the first mismatching element.
  const int32_t expected[] = {1, 2, 3, 4};
  for (int idx = 0; idx < 4; ++idx) {
    ASSERT_EQ(stored[idx], expected[idx]);
  }
}
47 
// Constructs SGDOptions with every argument given explicitly and checks that
// each accessor returns the value that was passed in.
TEST_F(SGDOptimizerTest, SGDOptionsNonDefaultValuesTest) {
  const double lr = 0.1;
  const double momentum = 1.0;
  const double dampening = 2.0;
  const double weight_decay = 3.0;

  SGDOptions options(lr, momentum, dampening, weight_decay, true);

  EXPECT_EQ(options.lr(), lr);
  EXPECT_EQ(options.momentum(), momentum);
  EXPECT_EQ(options.dampening(), dampening);
  EXPECT_EQ(options.weight_decay(), weight_decay);
  EXPECT_TRUE(options.nesterov());
}
57 
// Constructs SGDOptions from a learning rate only and checks the values the
// remaining accessors report: zero momentum, dampening and weight decay, and
// nesterov disabled.
TEST_F(SGDOptimizerTest, SGDOptionsDefaultValuesTest) {
  SGDOptions options(0.1);

  EXPECT_EQ(options.lr(), 0.1);
  EXPECT_EQ(options.momentum(), 0.0);
  EXPECT_EQ(options.dampening(), 0.0);
  EXPECT_EQ(options.weight_decay(), 0.0);
  // EXPECT_FALSE reports the failing expression directly instead of a
  // negated EXPECT_TRUE.
  EXPECT_FALSE(options.nesterov());
}
67 
// Runs ten SGD steps on a single 1x1 parameter with a constant gradient and
// checks the resulting parameter value.
TEST_F(SGDOptimizerTest, SGDOptimizerSimple) {
  TensorFactory<ScalarType::Float> tf;

  std::map<exec_aten::string_view, exec_aten::Tensor> named_parameters;
  std::map<exec_aten::string_view, exec_aten::Tensor> named_gradients;

  named_parameters.insert({"param1", tf.make({1, 1}, {1})});

  // dummy gradient of -1 for all epochs
  named_gradients.insert({"param1", tf.make({1, 1}, {-1})});

  SGD optimizer(named_parameters, SGDOptions{0.1});

  for (int i = 0; i < 10; ++i) {
    optimizer.step(named_gradients);
  }

  // Read the result through the typed accessor, as the other tests in this
  // file do, rather than reaching into the tensor implementation.
  const float* p1 = named_parameters.at("param1").const_data_ptr<float>();
  // Starting at 1 with lr 0.1 and gradient -1, ten plain SGD steps are
  // expected to land near 1 + 10 * 0.1 = 2.
  EXPECT_NEAR(p1[0], 2.0, 0.1);
}
89 
// Runs ten SGD steps on two 1x1 parameters using non-default options
// (lr 0.1, momentum 0.1, dampening 0, weight decay 2, nesterov enabled) and
// compares both results against precomputed reference values.
TEST_F(SGDOptimizerTest, SGDOptimizerComplex) {
  TensorFactory<ScalarType::Float> float_factory;

  std::map<exec_aten::string_view, exec_aten::Tensor> named_parameters;
  named_parameters.emplace("param1", float_factory.make({1, 1}, {1.0}));
  named_parameters.emplace("param2", float_factory.make({1, 1}, {2.0}));

  SGD optimizer(named_parameters, SGDOptions{0.1, 0.1, 0, 2, true});

  for (int epoch = 0; epoch < 10; ++epoch) {
    // A fresh gradient map (dummy gradient of -1) is built for every epoch.
    // NOTE(review): presumably step() may modify gradient tensors in place
    // with these options, so the tensors are not reused -- confirm before
    // hoisting them out of the loop.
    std::map<exec_aten::string_view, exec_aten::Tensor> epoch_gradients;
    epoch_gradients.emplace("param1", float_factory.make({1, 1}, {-1}));
    epoch_gradients.emplace("param2", float_factory.make({1, 1}, {-1}));
    optimizer.step(epoch_gradients);
  }

  const float* param1_data =
      named_parameters.at("param1").const_data_ptr<float>();
  const float* param2_data =
      named_parameters.at("param2").const_data_ptr<float>();
  EXPECT_NEAR(param1_data[0], 0.540303, 0.1);
  EXPECT_NEAR(param2_data[0], 0.620909, 0.1);
}
115