// Source: /aosp_15_r20/external/pytorch/test/cpp_extensions/identity.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/extension.h>
2 #include <torch/torch.h>
3 
4 using namespace torch::autograd;
5 
6 class Identity : public Function<Identity> {
7  public:
forward(AutogradContext * ctx,torch::Tensor input)8   static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input) {
9     return input;
10   }
11 
backward(AutogradContext * ctx,tensor_list grad_outputs)12   static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
13     return {grad_outputs[0]};
14   }
15 };
16 
identity(torch::Tensor input)17 torch::Tensor identity(torch::Tensor input) {
18   return Identity::apply(input);
19 }
20 
// Python bindings for this extension: registers the C++ identity() function
// under the Python name "identity", with "identity" as its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "identity",
      &identity,
      "identity");
}
24