/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/training/module/training_module.h>

namespace executorch {
namespace extension {
namespace training {

namespace {
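// Prefixes of the auxiliary methods expected alongside each trainable user
// method: they report the output index where gradients start, the index
// where parameters start, and the fully qualified names (FQNs) of the
// trained parameters.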
std::string gradients_method_prefix = "__et_training_gradients_index_";
std::string parameters_method_prefix = "__et_training_parameters_index_";
std::string fqn_method_prefix = "__et_training_fqn_";
} // namespace

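// Runs the joint forward/backward method `method_name`, returns only the
// user-visible outputs, and (on the first call) caches the gradient tensors
// keyed by parameter FQN so that named_gradients() can return them.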
runtime::Result<std::vector<runtime::EValue>>
TrainingModule::execute_forward_backward(
    const std::string& method_name,
    const std::vector<runtime::EValue>& input) {
  // Find where the user outputs end.
  const std::string gradients_method_name =
      gradients_method_prefix + method_name;
  auto res = executorch::extension::Module::execute(gradients_method_name);
  if (!res.ok()) {
    return res.error();
  }
  uint64_t grad_start = res.get()[0].toInt();

  const std::string parameters_method_name =
      parameters_method_prefix + method_name;
  // Get the index where the parameters start.
  auto param_res =
      executorch::extension::Module::execute(parameters_method_name);
  if (!param_res.ok()) {
    return param_res.error();
  }

  uint64_t param_start = param_res.get()[0].toInt();

  // Execute the forward and backward pass.
  auto outputs = executorch::extension::Module::execute(method_name, input);
  if (!outputs.ok()) {
    return outputs.error();
  }

  // Extract the user outputs.
  std::vector<runtime::EValue> user_outputs;
  user_outputs.reserve(grad_start);
  for (size_t i = 0; i < grad_start; ++i) {
    user_outputs.push_back(outputs.get().at(i));
  }

  // Extract and store the gradients.
  if (method_named_gradients_.find(method_name) ==
      method_named_gradients_.end()) {
    method_named_gradients_.insert({method_name, {}});

    auto& gradients_map = method_named_gradients_.at(method_name);
    // Get names.
    const std::string fqn_method_name = fqn_method_prefix + method_name;
    auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
    if (!fqn_res.ok()) {
      return fqn_res.error();
    }
    const auto& fqn_list = fqn_res.get();

    // Only have to initialize the dict once because the tensors in the dict and
    // the tensors in the method alias the same TensorImpl, so updating one will
    // update the other.
    size_t name_index = 0;
    for (size_t grad_index = grad_start; grad_index < param_start;
         ++grad_index, ++name_index) {
      exec_aten::string_view fqn = fqn_list.at(name_index).toString();
      gradients_map.insert({fqn, outputs.get().at(grad_index).toTensor()});
    }
  }

  return user_outputs;
}

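// Returns a map from parameter FQN to the parameter tensor for `method_name`.
// Parameters are exposed as trailing outputs of the method, starting at the
// index reported by the parameters auxiliary method.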
runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
TrainingModule::named_parameters(const std::string& method_name) {
  std::map<exec_aten::string_view, exec_aten::Tensor> named_parameters;
  const std::string fqn_method_name = fqn_method_prefix + method_name;
  const std::string parameters_method_name =
      parameters_method_prefix + method_name;

  // Get the parameter names (FQNs).
  auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
  if (!fqn_res.ok()) {
    return fqn_res.error();
  }
  const auto& fqn_list = fqn_res.get();

  // Get the index where the parameters start.
  auto param_res =
      executorch::extension::Module::execute(parameters_method_name);
  if (!param_res.ok()) {
    return param_res.error();
  }

  uint64_t param_start = param_res.get()[0].toInt();

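  // Ensure the method is loaded so its parameter outputs can be read.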
  auto e = executorch::extension::Module::load_method(method_name);
  if (e != runtime::Error::Ok) {
    return e;
  }
  auto& method = methods_.at(method_name).method;

  // Build the FQN -> parameter tensor map.
  size_t name_index = 0;
  for (size_t param_index = param_start; param_index < method->outputs_size();
       ++param_index, ++name_index) {
    exec_aten::string_view fqn = fqn_list.at(name_index).toString();
    exec_aten::Tensor param = method->get_output(param_index).toTensor();
    named_parameters.insert({fqn, param});
  }
  return named_parameters;
}

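// Returns the gradients cached by the most recent execute_forward_backward()
// call for `method_name`; fails with InvalidArgument if that method has not
// been executed yet.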
runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
TrainingModule::named_gradients(const std::string& method_name) {
  if (method_named_gradients_.find(method_name) ==
      method_named_gradients_.end()) {
    ET_LOG(Error, "No gradients found for method %s", method_name.c_str());
    return executorch::runtime::Error::InvalidArgument;
  }
  return method_named_gradients_.at(method_name);
}

} // namespace training
} // namespace extension
} // namespace executorch