/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/extension/training/module/training_module.h>
#include <executorch/extension/training/optimizer/sgd.h>
#include <gflags/gflags.h>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>

#pragma clang diagnostic ignored \
    "-Wbraced-scalar-init" // {0} below upsets clang.

using executorch::extension::FileDataLoader;
using executorch::extension::training::optimizer::SGD;
using executorch::extension::training::optimizer::SGDOptions;
using executorch::runtime::Error;
using executorch::runtime::Result;
DEFINE_string(model_path, "xor.pte", "Model serialized in flatbuffer format.");

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (argc != 1) {
    std::string msg = "Extra commandline args: ";
    for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {
      msg += argv[i];
    }
    ET_LOG(Error, "%s", msg.c_str());
    return 1;
  }

  // Load the model file.
  executorch::runtime::Result<executorch::extension::FileDataLoader>
      loader_res =
          executorch::extension::FileDataLoader::from(FLAGS_model_path.c_str());
  if (loader_res.error() != Error::Ok) {
    ET_LOG(Error, "Failed to open model file: %s", FLAGS_model_path.c_str());
    return 1;
  }
  auto loader = std::make_unique<executorch::extension::FileDataLoader>(
      std::move(loader_res.get()));

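  // Wrap the loader in a TrainingModule, which can run the model's forward and
  // backward passes and exposes its named parameters and gradients.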
  auto mod = executorch::extension::training::TrainingModule(std::move(loader));

  // Create full data set of input and labels.
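  // Each sample pairs a {1, 2} float input tensor with a {1} int64 class label.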
  std::vector<std::pair<
      executorch::extension::TensorPtr,
      executorch::extension::TensorPtr>>
      data_set;
  data_set.push_back( // XOR(1, 1) = 0
      {executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 1}),
       executorch::extension::make_tensor_ptr<int64_t>({1}, {0})});
  data_set.push_back( // XOR(0, 0) = 0
      {executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 0}),
       executorch::extension::make_tensor_ptr<int64_t>({1}, {0})});
  data_set.push_back( // XOR(1, 0) = 1
      {executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 0}),
       executorch::extension::make_tensor_ptr<int64_t>({1}, {1})});
  data_set.push_back( // XOR(0, 1) = 1
      {executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 1}),
       executorch::extension::make_tensor_ptr<int64_t>({1}, {1})});

  // Create optimizer.
  // Get the params and names
  auto param_res = mod.named_parameters("forward");
  if (param_res.error() != Error::Ok) {
    ET_LOG(Error, "Failed to get named parameters");
    return 1;
  }

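  // Plain SGD over the named parameters, configured with a 0.1 learning rate.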
  SGDOptions options{0.1};
  SGD optimizer(param_res.get(), options);

  // Randomness to sample the data set.
  std::default_random_engine URBG{std::random_device{}()};
  std::uniform_int_distribution<int> dist{
      0, static_cast<int>(data_set.size()) - 1};

  // Train the model.
  size_t num_epochs = 5000;
  for (size_t i = 0; i < num_epochs; i++) {
    int index = dist(URBG);
    auto& data = data_set[index];
    const auto& results =
        mod.execute_forward_backward("forward", {*data.first, *data.second});
    if (results.error() != Error::Ok) {
      ET_LOG(Error, "Failed to execute forward_backward");
      return 1;
    }
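    // results[0] holds the loss tensor and results[1] the predicted class
    // (logged below alongside the input and label).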
    if (i % 500 == 0 || i == num_epochs - 1) {
      ET_LOG(
          Info,
          "Step %zu, Loss %f, Input [%.0f, %.0f], Prediction %ld, Label %ld",
          i,
          results.get()[0].toTensor().const_data_ptr<float>()[0],
          data.first->const_data_ptr<float>()[0],
          data.first->const_data_ptr<float>()[1],
          results.get()[1].toTensor().const_data_ptr<int64_t>()[0],
          data.second->const_data_ptr<int64_t>()[0]);
    }
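    // Apply the gradients from this step's backward pass to the parameters.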
    optimizer.step(mod.named_gradients("forward").get());
  }
}