#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/Activation.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/glu_backward_native.h>
#include <ATen/ops/glu_backward_jvp_native.h>
#include <ATen/ops/glu_jvp_native.h>
#include <ATen/ops/glu_native.h>
#include <ATen/ops/sigmoid.h>
#endif

namespace at::meta {

TORCH_META_FUNC(glu) (
    const Tensor& self, int64_t dim
) {
  // a 0-dimensional tensor has "size" 1, which can't be evenly halved, so it
  // could not pass the check below anyway; this check just gives a nicer error message.
  TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors");
  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
  const int64_t nIn = self.size(wrap_dim);
  TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ",
              wrap_dim, " is size ", nIn);

  // size output to half of input along the halving dimension
  const int64_t selfSize = nIn / 2;
  Tensor firstHalf = self.narrow(wrap_dim, 0, selfSize);
  Tensor secondHalf = self.narrow(wrap_dim, selfSize, selfSize);
  build_borrowing_binary_op(maybe_get_output(), firstHalf, secondHalf);
}
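// Shape contract established above: the output matches the input shape except
// that the halving dimension is cut in half. For example (illustrative only),
// an input of shape {4, 6} with dim = 1 produces an output of shape {4, 3},
// whose values are firstHalf * sigmoid(secondHalf).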
} // namespace at::meta

namespace at::native {

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(glu_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(glu_backward_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(glu_jvp_stub);
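// Note: DEFINE_DISPATCH only defines the stub objects here; the per-backend
// kernels that actually compute a * sigmoid(b), its backward, and its JVP are
// registered elsewhere (e.g. via REGISTER_DISPATCH in the backend kernel files).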

TORCH_IMPL_FUNC(glu_out) (const Tensor& self, int64_t dim, const Tensor& out) {
  glu_stub(device_type(), *this);
}
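// Illustrative usage of the operator implemented above (a sketch, not part of
// this file; it assumes the public ATen C++ API):
//
//   at::Tensor x = at::randn({4, 6});
//   at::Tensor y = at::glu(x, /*dim=*/1);  // y has shape {4, 3}
//   // y equals x.narrow(1, 0, 3) * at::sigmoid(x.narrow(1, 3, 3))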

Tensor& glu_backward_cpu_out(const Tensor& grad_output, const Tensor& input,
    int64_t dim, Tensor& grad_input) {
  TORCH_CHECK(input.dim() > 0, "glu does not support 0-dimensional tensors");
  auto wrap_dim = maybe_wrap_dim(dim, input.dim());
  const int64_t nIn = input.size(wrap_dim);
  TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ",
              wrap_dim, " is size ", nIn);

  grad_input.resize_as_(input);
  const int64_t inputSize = nIn / 2;
  // views of the two halves of input and grad_input
  Tensor firstHalf = input.narrow(wrap_dim, 0, inputSize);
  Tensor secondHalf = input.narrow(wrap_dim, inputSize, inputSize);
  Tensor gradInputfirstHalf = grad_input.narrow(wrap_dim, 0, inputSize);
  Tensor gradInputsecondHalf = grad_input.narrow(wrap_dim, inputSize, inputSize);

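  // For glu(a, b) = a * sigmoid(b), the gradients are (standard calculus, for
  // reference):
  //   d/da = grad_output * sigmoid(b)
  //   d/db = grad_output * a * sigmoid(b) * (1 - sigmoid(b))
  // The first grad_input half temporarily holds sigmoid(b); the stub below
  // consumes it (together with a and grad_output) to fill the second half,
  // and the final mul_ turns the first half into grad_output * sigmoid(b).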
  at::sigmoid_out(gradInputfirstHalf, secondHalf);
  // for the second grad_input half, fusing the computation into one kernel
  // gives better performance
  auto iter = at::TensorIteratorConfig()
    .add_output(gradInputsecondHalf)
    .add_const_input(gradInputfirstHalf)
    .add_const_input(firstHalf)
    .add_const_input(grad_output)
    .build();
  glu_backward_stub(iter.device_type(), iter);
  gradInputfirstHalf.mul_(grad_output);
  return grad_input;
}

Tensor glu_backward_cpu(const Tensor& grad_output, const Tensor& input, int64_t dim) {
  auto grad_input = at::empty({0}, input.options());
  return glu_backward_cpu_out(grad_output, input, dim, grad_input);
}

Tensor glu_jvp(
    const Tensor& glu,
    const Tensor& x,
    const Tensor& dx,
    int64_t dim
) {
  dim = maybe_wrap_dim(dim, x.dim());
  const auto glu_size = glu.size(dim);
  const auto b = x.narrow(dim, glu_size, glu_size);
  const auto da = dx.narrow(dim, 0, glu_size);
  const auto db = dx.narrow(dim, glu_size, glu_size);
  auto dglu = at::empty_like(glu);
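  // JVP of glu(a, b) = a * sigmoid(b) (standard calculus, for reference):
  //   dglu = da * sigmoid(b) + a * sigmoid(b) * (1 - sigmoid(b)) * db
  //        = da * sigmoid(b) + glu * (1 - sigmoid(b)) * db
  // which is what the fused kernel below computes from (glu, b, da, db).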
  auto iter = at::TensorIteratorConfig()
    .add_output(dglu)
    .add_const_input(glu)
    .add_const_input(b)
    .add_const_input(da)
    .add_const_input(db)
    .build();
  glu_jvp_stub(iter.device_type(), iter);
  return dglu;
}

Tensor glu_backward_jvp(
    const Tensor& grad_x,
    const Tensor& grad_glu,
    const Tensor& x,
    const Tensor& dgrad_glu,
    const Tensor& dx,
    int64_t dim
) {
  dim = maybe_wrap_dim(dim, x.dim());
  const auto glu_size = grad_glu.size(dim);
  const auto a = x.narrow(dim, 0, glu_size);
  const auto b = x.narrow(dim, glu_size, glu_size);
  const auto da = dx.narrow(dim, 0, glu_size);
  const auto db = dx.narrow(dim, glu_size, glu_size);
  // grad_x_a = grad_glu * sigmoid(b)
  const auto grad_x_a = grad_x.narrow(dim, 0, glu_size);
  // grad_x_b = grad_x_a * a * (1 - sigmoid(b))
  const auto grad_x_b = grad_x.narrow(dim, glu_size, glu_size);

  const auto sig_b = at::sigmoid(b);
  // TODO: use glu from forward.
  // TODO: fuse kernels.
  const auto glu = a * sig_b;
  const auto db_neg_sig_b = db - db * sig_b;

  // dgrad_x_a = d(grad_glu * sigmoid(b))
  //           = dgrad_glu * sigmoid(b) + grad_glu * sigmoid(b) * (1 - sigmoid(b)) * db
  //           = dgrad_glu * sig_b + grad_x_a * (db - db * sig_b)
  //           = dgrad_glu * sig_b + grad_x_a * db_neg_sig_b
  const auto dgrad_x_a = dgrad_glu * sig_b + grad_x_a * db_neg_sig_b;

  // dgrad_x_b = d(grad_glu * sigmoid(b) * a * (1 - sigmoid(b)))
  //           = d(grad_glu * sigmoid(b)) * a * (1 - sigmoid(b))
  //             + grad_glu * sigmoid(b) * da * (1 - sigmoid(b))
  //             - grad_glu * sigmoid(b) * a * sigmoid(b) * (1 - sigmoid(b)) * db
  //           = dgrad_x_a * a * (1 - sigmoid(b))
  //             + (grad_glu * sigmoid(b)) * (da * (1 - sigmoid(b)) - a * sigmoid(b) * (1 - sigmoid(b)) * db)
  //           = dgrad_x_a * (a - glu) + grad_x_a * (da - da * sig_b - glu * db_neg_sig_b)
  const auto dgrad_x_b = dgrad_x_a * (a - glu) + grad_x_a * (da - da * sig_b - glu * db_neg_sig_b);

  return at::cat({dgrad_x_a, dgrad_x_b}, dim);
}


} // namespace at::native