1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/data_loader/file_data_loader.h>
10 #include <executorch/extension/tensor/tensor.h>
11 #include <executorch/extension/training/module/training_module.h>
12 #include <executorch/extension/training/optimizer/sgd.h>
13 #include <gflags/gflags.h>
14 #include <random>
15
16 #pragma clang diagnostic ignored \
17 "-Wbraced-scalar-init" // {0} below upsets clang.
18
19 using executorch::extension::FileDataLoader;
20 using executorch::extension::training::optimizer::SGD;
21 using executorch::extension::training::optimizer::SGDOptions;
22 using executorch::runtime::Error;
23 using executorch::runtime::Result;
// Command-line flag --model_path (default "xor.pte"): path of the serialized
// model file that main() opens through FileDataLoader.
DEFINE_string(model_path, "xor.pte", "Model serialized in flatbuffer format.");
25
main(int argc,char ** argv)26 int main(int argc, char** argv) {
27 gflags::ParseCommandLineFlags(&argc, &argv, true);
28 if (argc != 1) {
29 std::string msg = "Extra commandline args: ";
30 for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {
31 msg += argv[i];
32 }
33 ET_LOG(Error, "%s", msg.c_str());
34 return 1;
35 }
36
37 // Load the model file.
38 executorch::runtime::Result<executorch::extension::FileDataLoader>
39 loader_res =
40 executorch::extension::FileDataLoader::from(FLAGS_model_path.c_str());
41 if (loader_res.error() != Error::Ok) {
42 ET_LOG(Error, "Failed to open model file: %s", FLAGS_model_path.c_str());
43 return 1;
44 }
45 auto loader = std::make_unique<executorch::extension::FileDataLoader>(
46 std::move(loader_res.get()));
47
48 auto mod = executorch::extension::training::TrainingModule(std::move(loader));
49
50 // Create full data set of input and labels.
51 std::vector<std::pair<
52 executorch::extension::TensorPtr,
53 executorch::extension::TensorPtr>>
54 data_set;
55 data_set.push_back( // XOR(1, 1) = 0
56 {executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 1}),
57 executorch::extension::make_tensor_ptr<int64_t>({1}, {0})});
58 data_set.push_back( // XOR(0, 0) = 0
59 {executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 0}),
60 executorch::extension::make_tensor_ptr<int64_t>({1}, {0})});
61 data_set.push_back( // XOR(1, 0) = 1
62 {executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 0}),
63 executorch::extension::make_tensor_ptr<int64_t>({1}, {1})});
64 data_set.push_back( // XOR(0, 1) = 1
65 {executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 1}),
66 executorch::extension::make_tensor_ptr<int64_t>({1}, {1})});
67
68 // Create optimizer.
69 // Get the params and names
70 auto param_res = mod.named_parameters("forward");
71 if (param_res.error() != Error::Ok) {
72 ET_LOG(Error, "Failed to get named parameters");
73 return 1;
74 }
75
76 SGDOptions options{0.1};
77 SGD optimizer(param_res.get(), options);
78
79 // Randomness to sample the data set.
80 std::default_random_engine URBG{std::random_device{}()};
81 std::uniform_int_distribution<int> dist{
82 0, static_cast<int>(data_set.size()) - 1};
83
84 // Train the model.
85 size_t num_epochs = 5000;
86 for (int i = 0; i < num_epochs; i++) {
87 int index = dist(URBG);
88 auto& data = data_set[index];
89 const auto& results =
90 mod.execute_forward_backward("forward", {*data.first, *data.second});
91 if (results.error() != Error::Ok) {
92 ET_LOG(Error, "Failed to execute forward_backward");
93 return 1;
94 }
95 if (i % 500 == 0 || i == num_epochs - 1) {
96 ET_LOG(
97 Info,
98 "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %ld, Label %ld",
99 i,
100 results.get()[0].toTensor().const_data_ptr<float>()[0],
101 data.first->const_data_ptr<float>()[0],
102 data.first->const_data_ptr<float>()[1],
103 results.get()[1].toTensor().const_data_ptr<int64_t>()[0],
104 data.second->const_data_ptr<int64_t>()[0]);
105 }
106 optimizer.step(mod.named_gradients("forward").get());
107 }
108 }
109