1 #include <torch/extension.h>
2
3 #include <cstddef>
4 #include <string>
5
6 struct Net : torch::nn::Cloneable<Net> {
NetNet7 Net(int64_t in, int64_t out) : in_(in), out_(out) {
8 reset();
9 }
10
resetNet11 void reset() override {
12 fc = register_module("fc", torch::nn::Linear(in_, out_));
13 buffer = register_buffer("buf", torch::eye(5));
14 }
15
forwardNet16 torch::Tensor forward(torch::Tensor x) {
17 return fc->forward(x);
18 }
19
set_biasNet20 void set_bias(torch::Tensor bias) {
21 torch::NoGradGuard guard;
22 fc->bias.set_(bias);
23 }
24
get_biasNet25 torch::Tensor get_bias() const {
26 return fc->bias;
27 }
28
add_new_parameterNet29 void add_new_parameter(const std::string& name, torch::Tensor tensor) {
30 register_parameter(name, tensor);
31 }
32
add_new_bufferNet33 void add_new_buffer(const std::string& name, torch::Tensor tensor) {
34 register_buffer(name, tensor);
35 }
36
add_new_submoduleNet37 void add_new_submodule(const std::string& name) {
38 register_module(name, torch::nn::Linear(fc->options));
39 }
40
41 int64_t in_, out_;
42 torch::nn::Linear fc{nullptr};
43 torch::Tensor buffer;
44 };
45
// Python entry point: exposes `Net` through torch's module binder and adds its
// constructor plus the helper methods, bound in the same order as before.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  auto net_class = torch::python::bind_module<Net>(m, "Net");

  // Net(in, out)
  net_class.def(py::init<int64_t, int64_t>());

  // Bias accessors on the wrapped Linear layer.
  net_class.def("set_bias", &Net::set_bias);
  net_class.def("get_bias", &Net::get_bias);

  // Late registration of extra parameters / buffers / submodules.
  net_class.def("add_new_parameter", &Net::add_new_parameter);
  net_class.def("add_new_buffer", &Net::add_new_buffer);
  net_class.def("add_new_submodule", &Net::add_new_submodule);
}
55