#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/UnfoldBackward.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/unfold_backward_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

DEFINE_DISPATCH(unfold_backward_stub);

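// Computes the gradient of Tensor::unfold(dim, size, step).
//
// `grad` has the shape of the unfolded output: `input_sizes` with
// dimension `dim` replaced by the number of windows,
// (input_sizes[dim] - size) / step + 1, plus a trailing dimension of
// length `size`. The returned gradient has shape `input_sizes`.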
Tensor unfold_backward(
    const Tensor& grad,
    IntArrayRef input_sizes,
    int64_t dim,
    int64_t size,
    int64_t step
) {
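  // Start from zeros so that input elements not covered by any window
  // (possible when step > size) keep a zero gradient.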
  auto grad_input = at::zeros(input_sizes, grad.options());
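  // Fast path: with step >= size the unfolded windows are disjoint, so
  // each input element receives at most one gradient contribution and a
  // plain copy into the unfolded view of grad_input is sufficient.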
  if (step >= size) {
    auto gI_unfolded = grad_input.unfold(dim, size, step);
    gI_unfolded.copy_(grad);
    return grad_input;
  }

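  // Overlapping windows (step < size): defer to the device-specific
  // kernel, which accumulates the contributions of every window that
  // covers a given input element.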
  unfold_backward_stub(
      grad.device().type(),
      grad_input,
      grad,
      dim, size, step
  );

  return grad_input;
}

} // namespace at::native