1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/training/optimizer/sgd.h>
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12 #include <executorch/runtime/core/span.h>
13 #include <executorch/runtime/platform/runtime.h>
14
15 #include <gtest/gtest.h>
16
17 // @lint-ignore-every CLANGTIDY facebook-hte-CArray
18
19 using namespace ::testing;
20 using exec_aten::ScalarType;
21 using exec_aten::Tensor;
22 using ::executorch::extension::training::optimizer::SGD;
23 using ::executorch::extension::training::optimizer::SGDOptions;
24 using ::executorch::extension::training::optimizer::SGDParamState;
25 using ::executorch::runtime::Error;
26 using ::executorch::runtime::testing::TensorFactory;
27
28 class SGDOptimizerTest : public ::testing::Test {
29 protected:
SetUp()30 void SetUp() override {
31 torch::executor::runtime_init();
32 }
33 };
34
/**
 * Verifies that an SGDParamState hands back the momentum buffer it was
 * constructed with: the values read through momentum_buffer() must match the
 * contents of the tensor passed to the constructor.
 */
TEST_F(SGDOptimizerTest, SGDParamStateTest) {
  TensorFactory<ScalarType::Int> tf;
  Tensor momentum_buffer = tf.make({2, 2}, {1, 2, 3, 4});
  SGDParamState state(momentum_buffer);

  const auto* stored = state.momentum_buffer().const_data_ptr<int32_t>();

  // Same element values, in the same order, as the tensor built above.
  const int32_t expected[] = {1, 2, 3, 4};
  for (int i = 0; i < 4; ++i) {
    // ASSERT (not EXPECT): stop at the first mismatching element.
    ASSERT_EQ(stored[i], expected[i]);
  }
}
47
/**
 * Verifies that every value passed explicitly to the SGDOptions constructor
 * is returned unchanged by the matching accessor.
 */
TEST_F(SGDOptimizerTest, SGDOptionsNonDefaultValuesTest) {
  // Distinct values per field so that a mixed-up argument order is detected.
  const double kLr = 0.1;
  const double kMomentum = 1.0;
  const double kDampening = 2.0;
  const double kWeightDecay = 3.0;
  const bool kNesterov = true;

  SGDOptions options(kLr, kMomentum, kDampening, kWeightDecay, kNesterov);

  EXPECT_EQ(options.lr(), kLr);
  EXPECT_EQ(options.momentum(), kMomentum);
  EXPECT_EQ(options.dampening(), kDampening);
  EXPECT_EQ(options.weight_decay(), kWeightDecay);
  EXPECT_TRUE(options.nesterov());
}
57
/**
 * Verifies the defaults of SGDOptions when only the learning rate is given:
 * momentum, dampening and weight decay are zero and nesterov is disabled.
 */
TEST_F(SGDOptimizerTest, SGDOptionsDefaultValuesTest) {
  SGDOptions options(0.1);

  EXPECT_EQ(options.lr(), 0.1);
  EXPECT_EQ(options.momentum(), 0);
  EXPECT_EQ(options.dampening(), 0);
  EXPECT_EQ(options.weight_decay(), 0);
  // EXPECT_FALSE reports "expected false, actual true" on failure, which is
  // clearer than the negated-expression message of EXPECT_TRUE(!...).
  EXPECT_FALSE(options.nesterov());
}
67
/**
 * Runs ten optimizer steps on a single scalar parameter with a constant
 * gradient and only the learning rate configured, then checks the final
 * parameter value.
 */
TEST_F(SGDOptimizerTest, SGDOptimizerSimple) {
  TensorFactory<ScalarType::Float> tf;

  std::map<exec_aten::string_view, exec_aten::Tensor> named_parameters;
  std::map<exec_aten::string_view, exec_aten::Tensor> named_gradients;

  named_parameters.insert({"param1", tf.make({1, 1}, {1})});

  // dummy gradient of -1 for all epochs
  named_gradients.insert({"param1", tf.make({1, 1}, {-1})});

  SGD optimizer(named_parameters, SGDOptions{0.1});

  for (int i = 0; i < 10; ++i) {
    optimizer.step(named_gradients);
  }

  // Read the result through the typed accessor, as the other tests in this
  // file do, instead of reaching into the unsafe TensorImpl.
  const float* p1 = named_parameters.at("param1").const_data_ptr<float>();

  // Expected value assumes the plain SGD rule p <- p - lr * grad (momentum and
  // weight decay are left at their zero defaults): 1 + 10 * 0.1 = 2.
  // The tolerance is kept well below one step (0.1) so that a missed or extra
  // update cannot slip through, while still absorbing float rounding.
  EXPECT_NEAR(p1[0], 2.0, 1e-3);
}
89
/**
 * Runs ten optimizer steps on two scalar parameters with every SGD option set
 * (momentum, weight decay and nesterov enabled) and checks the final values
 * against precomputed references.
 */
TEST_F(SGDOptimizerTest, SGDOptimizerComplex) {
  using TensorMap = std::map<exec_aten::string_view, exec_aten::Tensor>;

  TensorFactory<ScalarType::Float> tf;

  TensorMap named_parameters;
  named_parameters.insert({"param1", tf.make({1, 1}, {1.0})});
  named_parameters.insert({"param2", tf.make({1, 1}, {2.0})});

  // Argument order: lr, momentum, dampening, weight_decay, nesterov.
  SGD optimizer(named_parameters, SGDOptions{0.1, 0.1, 0, 2, true});

  for (int epoch = 0; epoch < 10; ++epoch) {
    // The gradient tensors are rebuilt for every step rather than reused.
    // NOTE(review): presumably step() may modify gradients in place when
    // weight decay is enabled -- confirm against the optimizer source.
    TensorMap named_gradients;
    // dummy gradient of -1 for all epochs
    named_gradients.insert({"param1", tf.make({1, 1}, {-1})});
    named_gradients.insert({"param2", tf.make({1, 1}, {-1})});
    optimizer.step(named_gradients);
  }

  const float* param1_data =
      named_parameters.at("param1").const_data_ptr<float>();
  const float* param2_data =
      named_parameters.at("param2").const_data_ptr<float>();

  EXPECT_NEAR(param1_data[0], 0.540303, 0.1);
  EXPECT_NEAR(param2_data[0], 0.620909, 0.1);
}
115