xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/cpp_frontend_extension.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/extension.h>

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
5 
6 struct Net : torch::nn::Cloneable<Net> {
NetNet7   Net(int64_t in, int64_t out) : in_(in), out_(out) {
8     reset();
9   }
10 
resetNet11   void reset() override {
12     fc = register_module("fc", torch::nn::Linear(in_, out_));
13     buffer = register_buffer("buf", torch::eye(5));
14   }
15 
forwardNet16   torch::Tensor forward(torch::Tensor x) {
17     return fc->forward(x);
18   }
19 
set_biasNet20   void set_bias(torch::Tensor bias) {
21     torch::NoGradGuard guard;
22     fc->bias.set_(bias);
23   }
24 
get_biasNet25   torch::Tensor get_bias() const {
26     return fc->bias;
27   }
28 
add_new_parameterNet29   void add_new_parameter(const std::string& name, torch::Tensor tensor) {
30     register_parameter(name, tensor);
31   }
32 
add_new_bufferNet33   void add_new_buffer(const std::string& name, torch::Tensor tensor) {
34     register_buffer(name, tensor);
35   }
36 
add_new_submoduleNet37   void add_new_submodule(const std::string& name) {
38     register_module(name, torch::nn::Linear(fc->options));
39   }
40 
41   int64_t in_, out_;
42   torch::nn::Linear fc{nullptr};
43   torch::Tensor buffer;
44 };
45 
// Python entry point for the extension: exposes Net, constructible from two
// int64 feature sizes, together with its helper methods.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // bind_module creates the pybind11 class object for Net. It is kept in a
  // local so that each method binding below is its own statement.
  auto net_class = torch::python::bind_module<Net>(m, "Net");

  // Constructor: Net(in, out).
  net_class.def(py::init<int64_t, int64_t>());

  // Bias accessors.
  net_class.def("set_bias", &Net::set_bias);
  net_class.def("get_bias", &Net::get_bias);

  // Post-construction registration helpers.
  net_class.def("add_new_parameter", &Net::add_new_parameter);
  net_class.def("add_new_buffer", &Net::add_new_buffer);
  net_class.def("add_new_submodule", &Net::add_new_submodule);
}
55