xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Operators.h>
2 #include <ATen/native/CPUFallback.h>
3 #include <torch/csrc/lazy/ts_backend/ts_autograd_functions.h>
4 #include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h>
5 
6 namespace torch {
7 namespace lazy {
8 
forward(torch::autograd::AutogradContext * ctx,at::Tensor self,at::IntArrayRef kernel_size,at::IntArrayRef stride,at::IntArrayRef padding,at::IntArrayRef dilation,bool ceil_mode)9 at::Tensor MaxPool3dAutogradFunctionTS::forward(
10     torch::autograd::AutogradContext* ctx,
11     at::Tensor self,
12     at::IntArrayRef kernel_size,
13     at::IntArrayRef stride,
14     at::IntArrayRef padding,
15     at::IntArrayRef dilation,
16     bool ceil_mode) {
17   ctx->saved_data["kernel_size"] = kernel_size;
18   ctx->saved_data["stride"] = stride;
19   ctx->saved_data["padding"] = padding;
20   ctx->saved_data["dilation"] = dilation;
21   ctx->saved_data["ceil_mode"] = ceil_mode;
22   auto results = at::native::
23       call_fallback_fn<&ltc_eager_fallback, ATEN_OP(max_pool3d_with_indices)>::
24           call(self, kernel_size, stride, padding, dilation, ceil_mode);
25   ctx->save_for_backward({self, std::get<1>(results)});
26   return std::get<0>(results);
27 }
28 
backward(torch::autograd::AutogradContext * ctx,torch::autograd::variable_list grad_output)29 torch::autograd::variable_list MaxPool3dAutogradFunctionTS::backward(
30     torch::autograd::AutogradContext* ctx,
31     torch::autograd::variable_list grad_output) {
32   auto kernel_size = ctx->saved_data["kernel_size"].toIntList().vec();
33   auto stride = ctx->saved_data["stride"].toIntList().vec();
34   auto padding = ctx->saved_data["padding"].toIntList().vec();
35   auto dilation = ctx->saved_data["dilation"].toIntList().vec();
36   auto ceil_mode = ctx->saved_data["ceil_mode"].toBool();
37   auto saved = ctx->get_saved_variables();
38   auto self = saved[0];
39   at::Tensor grad;
40   auto indices = saved[1];
41   grad = at::native::call_fallback_fn<
42       &ltc_eager_fallback,
43       ATEN_OP(max_pool3d_with_indices_backward)>::
44       call(
45           grad_output[0],
46           self,
47           kernel_size,
48           stride,
49           padding,
50           dilation,
51           ceil_mode,
52           indices);
53 
54   at::Tensor undef;
55   torch::autograd::variable_list grad_inputs = {
56       grad, undef, undef, undef, undef, undef};
57   return grad_inputs;
58 }
59 
60 } // namespace lazy
61 } // namespace torch
62