// aten/src/ATen/native/GatedLinearUnit.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/Activation.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/glu_backward_jvp_native.h>
#include <ATen/ops/glu_backward_native.h>
#include <ATen/ops/glu_jvp_native.h>
#include <ATen/ops/glu_native.h>
#include <ATen/ops/sigmoid.h>
#endif

namespace at::meta {

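// Gated Linear Unit (GLU), from Dauphin et al., "Language Modeling with
// Gated Convolutional Networks": the input is split into halves [a, b]
// along `dim`, and the result is a * sigmoid(b), so the output is half
// the size of the input along `dim`.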
TORCH_META_FUNC(glu) (
    const Tensor& self, int64_t dim
) {
  // A 0-dimensional tensor has "size" 1, which can't be halved evenly, so it
  // would fail the even-size check below anyway; this check exists to give a
  // nicer error message.
  TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors");
  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
  const int64_t nIn = self.size(wrap_dim);
  TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ",
              wrap_dim, " is size ", nIn);

  // Size the output to half of the input along the halved dimension.
  const int64_t selfSize = nIn / 2;
  Tensor firstHalf = self.narrow(wrap_dim, 0, selfSize);
  Tensor secondHalf = self.narrow(wrap_dim, selfSize, selfSize);
  build_borrowing_binary_op(maybe_get_output(), firstHalf, secondHalf);
}
} // namespace at::meta

namespace at::native {

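// Device-specific kernels (CPU, CUDA, ...) register themselves with these
// dispatch stubs via REGISTER_DISPATCH; the functions below only build a
// TensorIterator and dispatch to whichever kernel matches the device.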
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(glu_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(glu_backward_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(glu_jvp_stub);

TORCH_IMPL_FUNC(glu_out) (const Tensor& self, int64_t dim, const Tensor& out) {
  glu_stub(device_type(), *this);
}
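
// A minimal usage sketch (an illustration with assumed shapes, not code from
// this file): for an input x of shape {N, 2 * C},
//   auto y = at::glu(x, /*dim=*/1);  // y has shape {N, C}
// computes the same values as
//   x.narrow(1, 0, C) * at::sigmoid(x.narrow(1, C, C));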

Tensor& glu_backward_cpu_out(const Tensor& grad_output, const Tensor& input,
                             int64_t dim, Tensor& grad_input) {
  TORCH_CHECK(input.dim() > 0, "glu does not support 0-dimensional tensors");
  auto wrap_dim = maybe_wrap_dim(dim, input.dim());
  const int64_t nIn = input.size(wrap_dim);
  TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ",
              wrap_dim, " is size ", nIn);

  grad_input.resize_as_(input);
  const int64_t inputSize = nIn / 2;
  // Views of the two input halves and the matching gradient halves.
  Tensor firstHalf = input.narrow(wrap_dim, 0, inputSize);
  Tensor secondHalf = input.narrow(wrap_dim, inputSize, inputSize);
  Tensor gradInputfirstHalf = grad_input.narrow(wrap_dim, 0, inputSize);
  Tensor gradInputsecondHalf = grad_input.narrow(wrap_dim, inputSize, inputSize);

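  // With glu(a, b) = a * sigmoid(b), the partial derivatives are
  //   d glu / d a = sigmoid(b)
  //   d glu / d b = a * sigmoid(b) * (1 - sigmoid(b)).
  // sigmoid(b) is staged in the first gradient half; the fused kernel below
  // reads it to produce the second half, and the final mul_ scales the first
  // half by grad_output.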
  at::sigmoid_out(gradInputfirstHalf, secondHalf);
  // The second gradient half gets better performance from a fused kernel.
  auto iter = at::TensorIteratorConfig()
    .add_output(gradInputsecondHalf)
    .add_const_input(gradInputfirstHalf)
    .add_const_input(firstHalf)
    .add_const_input(grad_output)
    .build();
  glu_backward_stub(iter.device_type(), iter);
  gradInputfirstHalf.mul_(grad_output);
  return grad_input;
}

Tensor glu_backward_cpu(const Tensor& grad_output, const Tensor& input, int64_t dim) {
  auto grad_input = at::empty({0}, input.options());
  return glu_backward_cpu_out(grad_output, input, dim, grad_input);
}

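// Forward-mode derivative (JVP) of glu. With glu = a * sigmoid(b) and
// tangents (da, db), the product and chain rules give
//   dglu = sigmoid(b) * da + glu * (1 - sigmoid(b)) * db,
// which is expressible in exactly the four tensors wired into the
// TensorIterator below: glu, b, da, db (so `a` itself is never needed).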
Tensor glu_jvp(
    const Tensor& glu,
    const Tensor& x,
    const Tensor& dx,
    int64_t dim
) {
  dim = maybe_wrap_dim(dim, x.dim());
  const auto glu_size = glu.size(dim);
  const auto b = x.narrow(dim, glu_size, glu_size);
  const auto da = dx.narrow(dim, 0, glu_size);
  const auto db = dx.narrow(dim, glu_size, glu_size);
  auto dglu = at::empty_like(glu);
  auto iter = at::TensorIteratorConfig()
    .add_output(dglu)
    .add_const_input(glu)
    .add_const_input(b)
    .add_const_input(da)
    .add_const_input(db)
    .build();
  glu_jvp_stub(iter.device_type(), iter);
  return dglu;
}

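// Forward-mode derivative (JVP) of glu_backward. The two halves of grad_x
// are grad_x_a = grad_glu * sigmoid(b) and grad_x_b = grad_x_a * a *
// (1 - sigmoid(b)); the comments below differentiate each half step by step.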
Tensor glu_backward_jvp(
    const Tensor& grad_x,
    const Tensor& grad_glu,
    const Tensor& x,
    const Tensor& dgrad_glu,
    const Tensor& dx,
    int64_t dim
) {
  dim = maybe_wrap_dim(dim, x.dim());
  const auto glu_size = grad_glu.size(dim);
  const auto a = x.narrow(dim, 0, glu_size);
  const auto b = x.narrow(dim, glu_size, glu_size);
  const auto da = dx.narrow(dim, 0, glu_size);
  const auto db = dx.narrow(dim, glu_size, glu_size);
  // grad_x_a = grad_glu * sigmoid(b)
  const auto grad_x_a = grad_x.narrow(dim, 0, glu_size);
  // grad_x_b = grad_x_a * a * (1 - sigmoid(b))
  const auto grad_x_b = grad_x.narrow(dim, glu_size, glu_size);

  const auto sig_b = at::sigmoid(b);
  // TODO: use glu from forward.
  // TODO: fuse kernels.
  const auto glu = a * sig_b;
  const auto db_neg_sig_b = db - db * sig_b;

  // dgrad_x_a = d(grad_glu * sigmoid(b))
  //           = dgrad_glu * sigmoid(b) + grad_glu * sigmoid(b) * (1 - sigmoid(b)) * db
  //           = dgrad_glu * sig_b + grad_x_a * (db - db * sig_b)
  //           = dgrad_glu * sig_b + grad_x_a * db_neg_sig_b
  const auto dgrad_x_a = dgrad_glu * sig_b + grad_x_a * db_neg_sig_b;

  // dgrad_x_b = d(grad_glu * sigmoid(b) * a * (1 - sigmoid(b)))
  //           = d(grad_glu * sigmoid(b)) * a * (1 - sigmoid(b))
  //             + grad_glu * sigmoid(b) * da * (1 - sigmoid(b))
  //             - grad_glu * sigmoid(b) * a * sigmoid(b) * (1 - sigmoid(b)) * db
  //           = dgrad_x_a * a * (1 - sigmoid(b))
  //             + (grad_glu * sigmoid(b)) * (da * (1 - sigmoid(b)) - a * sigmoid(b) * (1 - sigmoid(b)) * db)
  //           = dgrad_x_a * (a - glu) + grad_x_a * (da - da * sig_b - glu * db_neg_sig_b)
  const auto dgrad_x_b = dgrad_x_a * (a - glu) + grad_x_a * (da - da * sig_b - glu * db_neg_sig_b);

  return at::cat({dgrad_x_a, dgrad_x_b}, dim);
}

} // namespace at::native