1 #include <torch/extension.h> 2 #include <torch/torch.h> 3 4 using namespace torch::autograd; 5 6 class Identity : public Function<Identity> { 7 public: forward(AutogradContext * ctx,torch::Tensor input)8 static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input) { 9 return input; 10 } 11 backward(AutogradContext * ctx,tensor_list grad_outputs)12 static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) { 13 return {grad_outputs[0]}; 14 } 15 }; 16 identity(torch::Tensor input)17torch::Tensor identity(torch::Tensor input) { 18 return Identity::apply(input); 19 } 20 PYBIND11_MODULE(TORCH_EXTENSION_NAME,m)21PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 m.def("identity", &identity, "identity"); 23 } 24