/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/training/module/training_module.h>

namespace executorch {
namespace extension {
namespace training {

namespace {
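// Prefixes of the auxiliary methods baked into a training program. Executing
// "<prefix> + method_name" returns, respectively, the index of the first
// gradient output, the index of the first parameter output, and the fully
// qualified names of the trainable parameters.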
std::string gradients_method_prefix = "__et_training_gradients_index_";
std::string parameters_method_prefix = "__et_training_parameters_index_";
std::string fqn_method_prefix = "__et_training_fqn_";
} // namespace

runtime::Result<std::vector<runtime::EValue>>
TrainingModule::execute_forward_backward(
    const std::string& method_name,
    const std::vector<runtime::EValue>& input) {
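  // The joint forward/backward method lays its outputs out as
  // [user outputs | gradients | parameters]; the auxiliary methods queried
  // below report where the gradient and parameter regions begin.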
  // Find where the user outputs end.
  const std::string gradients_method_name =
      gradients_method_prefix + method_name;
  auto res = executorch::extension::Module::execute(gradients_method_name);
  if (!res.ok()) {
    return res.error();
  }
  uint64_t grad_start = res.get()[0].toInt();

  const std::string parameters_method_name =
      parameters_method_prefix + method_name;
  // Find where the parameters begin (and the gradients end).
  auto param_res =
      executorch::extension::Module::execute(parameters_method_name);
  if (!param_res.ok()) {
    return param_res.error();
  }

  uint64_t param_start = param_res.get()[0].toInt();

  // Execute the forward and backward pass.

  auto outputs = executorch::extension::Module::execute(method_name, input);
  if (!outputs.ok()) {
    return outputs.error();
  }

  // Extract the user outputs.
  std::vector<runtime::EValue> user_outputs;
  user_outputs.reserve(grad_start);
  for (size_t i = 0; i < grad_start; ++i) {
    user_outputs.push_back(outputs.get().at(i));
  }

  // Extract and store the gradients.
  if (method_named_gradients_.find(method_name) ==
      method_named_gradients_.end()) {
    method_named_gradients_.insert({method_name, {}});

    auto& gradients_map = method_named_gradients_.at(method_name);
    // Get names.
    const std::string fqn_method_name = fqn_method_prefix + method_name;
    auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
    if (!fqn_res.ok()) {
      return fqn_res.error();
    }
    const auto& fqn_list = fqn_res.get();

    // Only have to initialize the dict once because the tensors in the dict
    // and the tensors in the method alias the same TensorImpl, so updating
    // one will update the other.
    size_t name_index = 0;
    for (size_t grad_index = grad_start; grad_index < param_start;
         ++grad_index, ++name_index) {
      exec_aten::string_view fqn = fqn_list.at(name_index).toString();
      gradients_map.insert({fqn, outputs.get().at(grad_index).toTensor()});
    }
  }

  return user_outputs;
}

runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
TrainingModule::named_parameters(const std::string& method_name) {
  std::map<exec_aten::string_view, exec_aten::Tensor> named_parameters;
  const std::string fqn_method_name = fqn_method_prefix + method_name;
  const std::string parameters_method_name =
      parameters_method_prefix + method_name;

  // Get the parameter names.
  auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
  if (!fqn_res.ok()) {
    return fqn_res.error();
  }
  const auto& fqn_list = fqn_res.get();

  // Find where the parameters begin.
  auto param_res =
      executorch::extension::Module::execute(parameters_method_name);
  if (!param_res.ok()) {
    return param_res.error();
  }

  uint64_t param_start = param_res.get()[0].toInt();

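  // Load the training method itself; its parameter tensors are exposed as
  // outputs starting at param_start.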
  auto e = executorch::extension::Module::load_method(method_name);
  if (e != runtime::Error::Ok) {
    return e;
  }
  auto& method = methods_.at(method_name).method;

  // Build the FQN -> parameter tensor map.
  size_t name_index = 0;
  for (size_t param_index = param_start; param_index < method->outputs_size();
       ++param_index, ++name_index) {
    exec_aten::string_view fqn = fqn_list.at(name_index).toString();
    exec_aten::Tensor param = method->get_output(param_index).toTensor();
    named_parameters.insert({fqn, param});
  }
  return named_parameters;
}

runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
TrainingModule::named_gradients(const std::string& method_name) {
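  // Gradients are only available after execute_forward_backward() has run at
  // least once for the requested method.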
  if (method_named_gradients_.find(method_name) ==
      method_named_gradients_.end()) {
    ET_LOG(Error, "No gradients found for method %s", method_name.c_str());
    return executorch::runtime::Error::InvalidArgument;
  }
  return method_named_gradients_.at(method_name);
}

} // namespace training
} // namespace extension
} // namespace executorch