/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>

#include <ATen/Tensor.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <torch/csrc/utils/pybind.h>
#include "executorch/extension/tensor/tensor.h"
#include "executorch/extension/training/optimizer/sgd.h"
#ifndef USE_ATEN_LIB
#include <executorch/extension/aten_util/aten_bridge.h>
#endif

namespace py = pybind11;

namespace executorch {
namespace extension {
namespace training {

namespace {

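// Python-facing wrapper around optimizer::SGD. Owns the parameter
// fully-qualified names (and, in portable mode, the aliased ExecuTorch
// tensors) so that the string_views and tensor references handed to the
// optimizer stay valid for its lifetime.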
struct PySGD final {
  explicit PySGD(
      const py::dict& named_params,
      double lr,
      double momentum,
      double dampening,
      double weight_decay,
      bool nesterov)
      : sgd_(nullptr),
        fqns_()
#ifndef USE_ATEN_LIB
        ,
        params_()
#endif
  {
    std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;
    auto py_named_params =
        py::cast<std::unordered_map<std::string, at::Tensor>>(named_params);
    const auto params_size = py::len(named_params);
    fqns_.reserve(params_size);

    for (auto pair : py_named_params) {
      fqns_.push_back(pair.first);
      exec_aten::string_view v{fqns_.back().c_str(), pair.first.size()};
#ifndef USE_ATEN_LIB
      // Convert at::Tensor to torch::executor::Tensor.
      params_.emplace_back(alias_tensor_ptr_to_attensor(pair.second));
      cpp_inputs.insert({v, *params_.back()});
#else
      cpp_inputs.insert({v, pair.second});
#endif
    }
    sgd_ = std::make_unique<optimizer::SGD>(
        cpp_inputs,
        extension::training::optimizer::SGDOptions(
            lr, momentum, dampening, weight_decay, nesterov));
  }

  // Not needed for now, so just delete.
  PySGD(const PySGD&) = delete;
  PySGD& operator=(const PySGD&) = delete;
  PySGD(PySGD&&) = delete;
  PySGD& operator=(PySGD&&) = delete;

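  // Applies one SGD update. `py_dict` maps parameter FQNs (the same keys
  // passed to the constructor) to gradient tensors, which must be
  // contiguous.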
  void step(const py::dict& py_dict) {
    auto py_named_gradients =
        py::cast<std::unordered_map<std::string, at::Tensor>>(py_dict);
    std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;

    std::vector<std::string> fqn;
#ifndef USE_ATEN_LIB
    std::vector<TensorPtr> et_tensors;
#endif

    // Convert python objects into cpp.
    for (const auto& pair : py_named_gradients) {
      fqn.push_back(pair.first);
      auto at_tensor = pair.second;
      // The ATen bridge asserts on non-contiguous tensors later, so to
      // better propagate the failure up to python we check early and
      // throw an exception.
      if (!at_tensor.is_contiguous()) {
        throw std::runtime_error("Gradient is not contiguous.");
      }
#ifndef USE_ATEN_LIB
      // Convert at::Tensor to torch::executor::Tensor.
      auto temp = alias_tensor_ptr_to_attensor(at_tensor);
      et_tensors.push_back(temp);
      cpp_inputs.insert({pair.first.c_str(), *et_tensors.back()});
#else
      cpp_inputs.insert({pair.first.c_str(), at_tensor});
#endif
    }

    auto err = sgd_->step(cpp_inputs);
    if (err != runtime::Error::Ok) {
      throw std::runtime_error("SGD step failed");
    }
  }

 private:
  // TODO(jakeszwe): Write an optimizer interface and use it here instead of
  // SGD specifically.
  std::unique_ptr<optimizer::SGD> sgd_ = nullptr;
  std::vector<std::string> fqns_;

#ifndef USE_ATEN_LIB // Portable mode
  std::vector<TensorPtr> params_;
#endif
};

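// Factory for the binding below; returning a unique_ptr hands ownership
// of the PySGD instance to pybind11.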
static std::unique_ptr<PySGD> get_sgd_optimizer(
    const py::dict& named_params,
    double lr,
    double momentum = 0,
    double dampening = 0,
    double weight_decay = 0,
    bool nesterov = false) {
  return std::make_unique<PySGD>(
      named_params, lr, momentum, dampening, weight_decay, nesterov);
}

} // namespace

PYBIND11_MODULE(_training_lib, m) {
  m.def(
      "get_sgd_optimizer",
      &get_sgd_optimizer,
      py::arg("named_params"),
      py::arg("lr") = 0.1,
      py::arg("momentum") = 0.0,
      py::arg("dampening") = 0.0,
      py::arg("weight_decay") = 0.0,
      py::arg("nesterov") = false);
  py::class_<PySGD>(m, "ExecuTorchSGD").def("step", &PySGD::step);
}
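
// A minimal usage sketch from the Python side, assuming the extension is
// built and importable as `_training_lib` (the import path and parameter
// names below are illustrative, not guaranteed by this file):
//
//   import torch
//   from executorch.extension.training import _training_lib
//
//   params = {"linear.weight": torch.randn(3, 3)}
//   opt = _training_lib.get_sgd_optimizer(params, lr=0.01)
//   grads = {"linear.weight": torch.ones(3, 3)}  # must be contiguous
//   opt.step(grads)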

} // namespace training
} // namespace extension
} // namespace executorch