// xref: /aosp_15_r20/external/executorch/extension/training/pybindings/_training_lib.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>

#include <ATen/Tensor.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <torch/csrc/utils/pybind.h>
#include "executorch/extension/tensor/tensor.h"
#include "executorch/extension/training/optimizer/sgd.h"
#ifndef USE_ATEN_LIB
#include <executorch/extension/aten_util/aten_bridge.h>
#endif

namespace py = pybind11;

namespace executorch {
namespace extension {
namespace training {

namespace {

30 struct PySGD final {
PySGDexecutorch::extension::training::__anon6d04ad4f0111::PySGD31   explicit PySGD(
32       const py::dict& named_params,
33       double lr,
34       double momentum,
35       double dampening,
36       double weight_decay,
37       bool nesterov)
38       : sgd_(nullptr),
39         fqns_()
40 #ifndef USE_ATEN_LIB
41         ,
42         params_()
43 #endif
44   {
45     std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;
46     auto py_named_params =
47         py::cast<std::unordered_map<std::string, at::Tensor>>(named_params);
48     const auto params_size = py::len(named_params);
49     fqns_ = std::vector<std::string>();
50     fqns_.reserve(params_size);
51 
52     for (auto pair : py_named_params) {
53       fqns_.push_back(pair.first);
54       exec_aten::string_view v{fqns_.back().c_str(), pair.first.size()};
55 #ifndef USE_ATEN_LIB
56       // convert at::Tensor to torch::executor::Tensor
57       params_.emplace_back(alias_tensor_ptr_to_attensor(pair.second));
58       cpp_inputs.insert({v, *params_.back()});
59 #else
60       cpp_inputs.insert({v, pair.second});
61 #endif
62     }
63     sgd_ = std::make_unique<optimizer::SGD>(
64         cpp_inputs,
65         extension::training::optimizer::SGDOptions(
66             lr, momentum, dampening, weight_decay, nesterov));
67   }
68 
69   // Not needed for now, so just delete.
70   PySGD(const PySGD&) = delete;
71   PySGD& operator=(const PySGD&) = delete;
72   PySGD(PySGD&&) = delete;
73   PySGD& operator=(PySGD&&) = delete;
74 
stepexecutorch::extension::training::__anon6d04ad4f0111::PySGD75   void step(const py::dict& py_dict) {
76     auto py_named_gradients =
77         py::cast<std::unordered_map<std::string, at::Tensor>>(py_dict);
78     std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;
79 
80     std::vector<std::string> fqn;
81 #ifndef USE_ATEN_LIB
82     std::vector<TensorPtr> et_tensors;
83 #endif
84 
85     // Convert python objects into cpp.
86     for (const auto& pair : py_named_gradients) {
87       fqn.push_back(pair.first);
88       auto at_tensor = pair.second;
89       // alias_etensor_to_attensor will assert on this later, so to better
90       // propogate up to python we check early and throw an exception.
91       if (!at_tensor.is_contiguous()) {
92         auto error_msg = "Gradient is not contiguous.";
93         throw std::runtime_error(error_msg);
94       }
95 #ifndef USE_ATEN_LIB
96       // convert at::Tensor to torch::executor::Tensor
97       auto temp = alias_tensor_ptr_to_attensor(at_tensor);
98       et_tensors.push_back(temp);
99       cpp_inputs.insert({pair.first.c_str(), *et_tensors.back()});
100 #else
101       cpp_inputs.insert({pair.first.c_str(), at_tensor});
102 #endif
103     }
104 
105     auto err = sgd_->step(cpp_inputs);
106     if (err != runtime::Error::Ok) {
107       throw std::runtime_error("SGD step failed");
108     }
109   }
110 
111  private:
112   // TODO(jakeszwe): Write an optimizer interface and use it here instead of SGD
113   // specifically.
114   std::unique_ptr<optimizer::SGD> sgd_ = nullptr;
115   std::vector<std::string> fqns_;
116 
117 #ifndef USE_ATEN_LIB // Portable mode
118   std::vector<TensorPtr> params_;
119 #endif
120   ;
121 };
get_sgd_optimizer(const py::dict & named_params,double lr,double momentum=0,double dampening=0,double weight_decay=0,bool nesterov=false)123 static std::unique_ptr<PySGD> get_sgd_optimizer(
124     const py::dict& named_params,
125     double lr,
126     double momentum = 0,
127     double dampening = 0,
128     double weight_decay = 0,
129     bool nesterov = false) {
130   return std::make_unique<PySGD>(
131       named_params, lr, momentum, dampening, weight_decay, nesterov);
132 }

} // namespace

PYBIND11_MODULE(_training_lib, m) {
  // Register the optimizer wrapper type, then the factory that produces it.
  py::class_<PySGD>(m, "ExecuTorchSGD").def("step", &PySGD::step);
  m.def(
      "get_sgd_optimizer",
      &get_sgd_optimizer,
      py::arg("named_params"),
      py::arg("lr") = 0.1,
      py::arg("momentum") = 0.0,
      py::arg("dampening") = 0.0,
      py::arg("weight_decay") = 0.0,
      py::arg("nesterov") = false);
}

} // namespace training
} // namespace extension
} // namespace executorch