xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/UnfoldBackward.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/UnfoldBackward.h>
3 
4 #ifndef AT_PER_OPERATOR_HEADERS
5 #include <ATen/Functions.h>
6 #include <ATen/NativeFunctions.h>
7 #else
8 #include <ATen/ops/empty.h>
9 #include <ATen/ops/unfold_backward_native.h>
10 #include <ATen/ops/zeros.h>
11 #endif
12 
13 namespace at::native {
14 
// Defines the dispatch stub `unfold_backward_stub`, which unfold_backward
// below invokes with a device type as its first argument.
// NOTE(review): the stub's declaration and its per-device kernel
// registrations are not visible here — presumably in
// ATen/native/UnfoldBackward.h and the backend kernel files; confirm.
15 DEFINE_DISPATCH(unfold_backward_stub);
16 
// Computes the gradient with respect to the input of an `unfold` call.
//
// Parameters:
//   grad         gradient tensor to propagate back.
//   input_sizes  sizes of the result; the output is allocated with these
//                sizes and with grad.options().
//   dim, size, step
//                forwarded unchanged to Tensor::unfold or to
//                unfold_backward_stub.
//
// Returns a freshly allocated, zero-initialised tensor of `input_sizes`
// that has been filled in by one of the two paths below.
Tensor unfold_backward(
    const Tensor& grad,
    IntArrayRef input_sizes,
    int64_t dim,
    int64_t size,
    int64_t step) {
  Tensor result = at::zeros(input_sizes, grad.options());

  // NOTE(review): step >= size presumably means the unfolded windows do not
  // overlap, so no accumulation is needed and a plain copy is enough — confirm.
  const bool copy_path = step >= size;

  if (copy_path) {
    // Writing into the unfolded tensor is relied upon to update `result`.
    Tensor result_windows = result.unfold(dim, size, step);
    result_windows.copy_(grad);
  } else {
    // Otherwise hand off to the kernel selected by grad's device type.
    unfold_backward_stub(grad.device().type(), result, grad, dim, size, step);
  }

  return result;
}
40 
41 } // namespace at::native
42