1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/AccumulateType.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/TensorUtils.h>
6 #include <ATen/cuda/CUDAContext.h>
7 #include <ATen/cuda/CUDAApplyUtils.cuh>
8 #include <c10/macros/Macros.h>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/empty.h>
15 #include <ATen/ops/empty_like.h>
16 #include <ATen/ops/_thnn_fused_lstm_cell_native.h>
17 #include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_native.h>
18 #include <ATen/ops/_thnn_fused_gru_cell_native.h>
19 #include <ATen/ops/_thnn_fused_gru_cell_backward_native.h>
20 #endif
21 
22 namespace at::native {
23 
24 namespace {
25 
26 using at::cuda::detail::TensorInfo;
27 using at::cuda::detail::getTensorInfo;
28 using at::cuda::detail::IndexToOffset;
29 using at::cuda::detail::canUse32BitIndexMath;
30 
31 // Factor will be 3 for GRU and 4 for LSTM
32 void checkSizes(CheckedFrom c,
33                 const TensorArg& input_gates, const TensorArg& hidden_gates,
34                 const TensorArg& input_bias, const TensorArg& hidden_bias,
35                 int64_t factor, const TensorArg& prev_hidden) {
36   checkDim(c, input_gates, 2);
37   checkSameSize(c, input_gates, hidden_gates);
38   int64_t gates_size = input_gates->size(1);
39 
40   if (input_bias->defined()) {
41     checkDim(c, input_bias, 1);
42     checkNumel(c, input_bias, gates_size);
43     checkSameSize(c, input_bias, hidden_bias);
44   }
45 
46   checkDim(c, prev_hidden, 2);
47   checkNumel(c, prev_hidden, input_gates->size(0) * gates_size / factor);
48 
49   checkAllSameGPU(c, {input_gates, hidden_gates, input_bias, hidden_bias, prev_hidden});
50 }
51 
52 bool allContiguous(at::TensorList tensors) {
53   return std::all_of(tensors.begin(), tensors.end(),
54                      [](const at::Tensor& t) { return !t.defined() || t.is_contiguous(); });
55 }
56 
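// getLaunchConfig sizes a launch using the pointwise-apply heuristics from
// CUDAApplyUtils; the kernels below then cover `numel` elements with a
// grid-stride loop.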
57 void getLaunchConfig(dim3* block, dim3* grid, int64_t numel) {
58   c10::DeviceIndex curDevice = -1;
59   c10::cuda::GetDevice(&curDevice);
60   *block = cuda::getApplyBlock();
61   TORCH_INTERNAL_ASSERT(cuda::getApplyGrid(numel, *grid, curDevice),
62                         "Could not get grid size for pointwise apply.");
63 }
64 
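// Biases are optional. An undefined tensor becomes a default TensorInfo with
// a null data pointer, which the kernels treat as "no bias".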
65 template<typename T, typename T2>
66 TensorInfo<T, T2> tryGetTensorInfo(const at::Tensor& t) {
67   return t.defined() ? getTensorInfo<T, T2>(t) : TensorInfo<T, T2>{};
68 }
69 
70 void collapseDims() {};
71 template<typename T, typename T2, typename... Args>
72 void collapseDims(TensorInfo<T, T2>& info, Args&... infos) {
73   info.collapseDims();
74   collapseDims(infos...);
75 }
76 
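// DEVICE_LINEAR_GET maps a linear index to a memory offset via IndexToOffset.
// `indexing_kind` is the static dimensionality assumed: 1 when all tensors
// are contiguous and have been collapsed to a single dimension, 2 for the
// general (possibly strided) 2-D case.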
77 #define DEVICE_LINEAR_GET(D_TENSOR, INDEX)                              \
78   D_TENSOR.data[IndexToOffset<scalar_t, index_type, indexing_kind>::get(INDEX, D_TENSOR)]
79 
80 // Biases are always 1D
81 #define DEVICE_BIAS_GET(D_TENSOR, INDEX)                              \
82   D_TENSOR.data[IndexToOffset<scalar_t, index_type, 1>::get(INDEX, D_TENSOR)]
83 
84 #define H2F(input) static_cast<accscalar_t>(input)
85 #define F2H(input) static_cast<scalar_t>(input)
86 
87 template<typename T>
88 __device__ __forceinline__
89 T sigmoid(T in) {
90   T one = static_cast<T>(1.0);
91   return one / (one + ::exp(-in));
92 }
93 
94 namespace kernel {
95 
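// lstm_cell_forward: one thread per (batch, hidden) element. `input` and
// `hidden` hold the input-side and hidden-side pre-activation gates, laid out
// as [batch, 4*hsz] in (input, forget, candidate, output) order, so the
// element at `linearIndex` reads its gates at offset + {0,1,2,3}*hsz.
// The standard LSTM cell update follows:
//   i = sigmoid(i_in + i_hid + b1_i + b2_i)
//   f = sigmoid(f_in + f_hid + b1_f + b2_f)
//   g = tanh   (g_in + g_hid + b1_g + b2_g)
//   o = sigmoid(o_in + o_hid + b1_o + b2_o)
//   cy = f * cx + i * g
//   hy = o * tanh(cy)
// The post-activation gate values i, f, g, o are written to `workspace` so
// the backward kernel can reuse them.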
96 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
97 #if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
98 C10_LAUNCH_BOUNDS_2(512, 4)
99 #endif
100 __global__ void lstm_cell_forward(
101             TensorInfo<scalar_t, index_type> input,
102             TensorInfo<scalar_t, index_type> hidden,
103             TensorInfo<scalar_t, index_type> bias1,
104             TensorInfo<scalar_t, index_type> bias2,
105             TensorInfo<scalar_t, index_type> _cx,
106             TensorInfo<scalar_t, index_type> _hy,
107             TensorInfo<scalar_t, index_type> _cy,
108             TensorInfo<scalar_t, index_type> workspace,
109             index_type hsz,
110             index_type totalElements) {
111     bool has_bias = bias1.data != nullptr;
112     for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
113        linearIndex < totalElements;
114        linearIndex += gridDim.x * blockDim.x) {
115       index_type offset = (linearIndex/hsz)*4*hsz+linearIndex%hsz;
116 
117       scalar_t iig = DEVICE_LINEAR_GET(input, offset+0*hsz);
118       scalar_t ifg = DEVICE_LINEAR_GET(input, offset+1*hsz);
119       scalar_t icg = DEVICE_LINEAR_GET(input, offset+2*hsz);
120       scalar_t iog = DEVICE_LINEAR_GET(input, offset+3*hsz);
121 
122       scalar_t hig = DEVICE_LINEAR_GET(hidden, offset+0*hsz);
123       scalar_t hfg = DEVICE_LINEAR_GET(hidden, offset+1*hsz);
124       scalar_t hcg = DEVICE_LINEAR_GET(hidden,  offset+2*hsz);
125       scalar_t hog = DEVICE_LINEAR_GET(hidden,  offset+3*hsz);
126 
127       scalar_t* wig = &DEVICE_LINEAR_GET(workspace, offset+0*hsz);
128       scalar_t* wfg = &DEVICE_LINEAR_GET(workspace, offset+1*hsz);
129       scalar_t* wcg = &DEVICE_LINEAR_GET(workspace, offset+2*hsz);
130       scalar_t* wog = &DEVICE_LINEAR_GET(workspace, offset+3*hsz);
131 
132       scalar_t cx = DEVICE_LINEAR_GET(_cx, linearIndex);
133 
134       scalar_t* hy = &DEVICE_LINEAR_GET(_hy, linearIndex);
135       scalar_t* cy = &DEVICE_LINEAR_GET(_cy, linearIndex);
136 
137       scalar_t b1i, b1f, b1c, b1o;
138       scalar_t b2i, b2f, b2c, b2o;
139 
140       if (has_bias) {
141         b1i = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 0 * hsz);
142         b1f = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 1 * hsz);
143         b1c = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 2 * hsz);
144         b1o = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 3 * hsz);
145 
146         b2i = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 0 * hsz);
147         b2f = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 1 * hsz);
148         b2c = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 2 * hsz);
149         b2o = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 3 * hsz);
150       } else {
151 #ifndef THC_REAL_IS_HALF
152         b1i = 0.0; b1f = 0.0; b1c = 0.0; b1o = 0.0;
153         b2i = 0.0; b2f = 0.0; b2c = 0.0; b2o = 0.0;
154 #else
155         b1i = F2H(0.0); b1f = F2H(0.0); b1c = F2H(0.0); b1o = F2H(0.0);
156         b2i = F2H(0.0); b2f = F2H(0.0); b2c = F2H(0.0); b2o = F2H(0.0);
157 #endif
158       }
159 
160       accscalar_t ig, fg, cg, og;
161       accscalar_t f_hy, f_cy;
162 
163       ig = sigmoid(H2F(iig) + H2F(hig) + H2F(b1i) + H2F(b2i));
164       fg = sigmoid(H2F(ifg) + H2F(hfg) + H2F(b1f) + H2F(b2f));
165       cg = ::tanh(H2F(icg) + H2F(hcg) + H2F(b1c) + H2F(b2c));
166       og = sigmoid(H2F(iog) + H2F(hog) + H2F(b1o) + H2F(b2o));
167 
168       f_cy = (fg * H2F(cx)) + (ig * cg);
169       f_hy = og * ::tanh(f_cy);
170 
171       *hy = F2H(f_hy);
172       *cy = F2H(f_cy);
173 
174       //SAVE FOR BACKWARDS
175       //Also need cy and cx but can be saved easily in python
176       *wig = F2H(ig);
177       *wfg = F2H(fg);
178       *wcg = F2H(cg);
179       *wog = F2H(og);
180     }
181 }
182 
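// lstm_cell_backward consumes the post-activation gates saved in `storage`
// (the forward workspace) together with cx, cy and the incoming gradients
// w.r.t. hy (`gradoutput`) and cy (`gradoutputcell`). Writing gcx for the
// total gradient flowing into cy:
//   d_o  = go * tanh(cy) * o * (1 - o)
//   gcx  = go * o * (1 - tanh(cy)^2) + goc
//   d_i  = gcx * g * i * (1 - i)
//   d_f  = gcx * cx * f * (1 - f)
//   d_g  = gcx * i * (1 - g^2)
//   d_cx = gcx * f
// The four gate gradients land in `gradInGates`, d_cx in `gradInputCx`.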
183 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
184 #if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
185 C10_LAUNCH_BOUNDS_2(512, 4)
186 #endif
187 __global__ void lstm_cell_backward(
188               TensorInfo<scalar_t, index_type> storage,
189               TensorInfo<scalar_t, index_type> gradInGates,
190               TensorInfo<scalar_t, index_type> _cx,
191               TensorInfo<scalar_t, index_type> _cy,
192               TensorInfo<scalar_t, index_type> gradoutput,
193               TensorInfo<scalar_t, index_type> gradoutputcell,
194               TensorInfo<scalar_t, index_type> gradInputCx,
195               index_type hsz,
196               index_type totalElements) {
197   bool has_gradoutput = gradoutput.data != nullptr;
198   bool has_gradoutputcell = gradoutputcell.data != nullptr;
199   for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
200        linearIndex < totalElements;
201        linearIndex += gridDim.x * blockDim.x) {
202     index_type offset = (linearIndex/hsz)*4*hsz+linearIndex%hsz;
203 
204     scalar_t ig = DEVICE_LINEAR_GET(storage, offset+0*hsz);
205     scalar_t fg = DEVICE_LINEAR_GET(storage, offset+1*hsz);
206     scalar_t cg = DEVICE_LINEAR_GET(storage, offset+2*hsz);
207     scalar_t og = DEVICE_LINEAR_GET(storage, offset+3*hsz);
208 
209     scalar_t* ih = &DEVICE_LINEAR_GET(gradInGates, offset+0*hsz);
210     scalar_t* fh = &DEVICE_LINEAR_GET(gradInGates, offset+1*hsz);
211     scalar_t* ch = &DEVICE_LINEAR_GET(gradInGates, offset+2*hsz);
212     scalar_t* oh = &DEVICE_LINEAR_GET(gradInGates, offset+3*hsz);
213 
214     //will return hidden grads here
215     scalar_t cx = DEVICE_LINEAR_GET(_cx, linearIndex);
216     scalar_t cy = DEVICE_LINEAR_GET(_cy, linearIndex);
217 
218     scalar_t* gi = &DEVICE_LINEAR_GET(gradInputCx, linearIndex);
219 
220     accscalar_t go  = has_gradoutput ? H2F(DEVICE_LINEAR_GET(gradoutput, linearIndex)) : 0.f;
221     accscalar_t goc = has_gradoutputcell ? H2F(DEVICE_LINEAR_GET(gradoutputcell, linearIndex)) : 0.f;
222 
223     accscalar_t gcx = ::tanh(H2F(cy));
224 
225     accscalar_t gog = go * gcx;
226     gcx = go * H2F(og) * (1 - gcx*gcx) + goc;
227 
228     accscalar_t gig = gcx * H2F(cg);
229     accscalar_t gfg = gcx * H2F(cx);
230     accscalar_t gcg = gcx * H2F(ig);
231 
232     gcx = gcx * H2F(fg);
233 
234     gig = gig * (1-H2F(ig)) * H2F(ig);
235     gfg = gfg * (1-H2F(fg)) * H2F(fg);
236     gcg = gcg * (1-H2F(cg)*H2F(cg));
237     gog = gog * (1-H2F(og)) * H2F(og);
238 
239     *ih = F2H(gig);
240     *fh = F2H(gfg);
241     *ch = F2H(gcg);
242     *oh = F2H(gog);
243 
244     *gi = F2H(gcx);
245   }
246 }
247 
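// gru_cell_forward: gates arrive as [batch, 3*hsz] in (reset, update,
// candidate) order; the update gate is called `ig` below. The cell computes
//   r  = sigmoid(r_in + r_hid + b1_r + b2_r)
//   z  = sigmoid(z_in + z_hid + b1_z + b2_z)
//   n  = tanh(n_in + b1_n + r * (n_hid + b2_n))
//   hy = n + z * (hx - n)        // same as (1 - z) * n + z * hx
// and saves r, z, n, hx and (n_hid + b2_n) into a [batch, 5*hsz] workspace
// for the backward kernel, which is why the storage offset switches to a
// stride of 5*hsz.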
248 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
249 #if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
250 C10_LAUNCH_BOUNDS_2(512, 4)
251 #endif
252 __global__ void gru_cell_forward(
253             TensorInfo<scalar_t, index_type> Input,
254             TensorInfo<scalar_t, index_type> Hidden,
255             TensorInfo<scalar_t, index_type> Bias1,
256             TensorInfo<scalar_t, index_type> Bias2,
257             TensorInfo<scalar_t, index_type> _hx,
258             TensorInfo<scalar_t, index_type> _hy,
259             TensorInfo<scalar_t, index_type> storage,
260             index_type hsz,
261             index_type totalElements) {
262   bool has_bias = Bias1.data != nullptr;
263   for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
264        linearIndex < totalElements;
265        linearIndex += gridDim.x * blockDim.x) {
266       index_type offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz;
267 
268       scalar_t ir = DEVICE_LINEAR_GET(Input, offset+0*hsz);
269       scalar_t ii = DEVICE_LINEAR_GET(Input, offset+1*hsz);
270       scalar_t in = DEVICE_LINEAR_GET(Input, offset+2*hsz);
271       scalar_t hr = DEVICE_LINEAR_GET(Hidden,offset+0*hsz);
272       scalar_t hi = DEVICE_LINEAR_GET(Hidden,offset+1*hsz);
273       scalar_t hn = DEVICE_LINEAR_GET(Hidden,  offset+2*hsz);
274 
275       scalar_t hx = DEVICE_LINEAR_GET(_hx, linearIndex);
276       scalar_t* hy = &DEVICE_LINEAR_GET(_hy, linearIndex);
277 
278       scalar_t b1r, b1i, b1n, b2r, b2i, b2n;
279 
280       if (has_bias) {
281         b1r = DEVICE_BIAS_GET(Bias1, linearIndex%hsz+0*hsz);
282         b1i = DEVICE_BIAS_GET(Bias1, linearIndex%hsz+1*hsz);
283         b1n = DEVICE_BIAS_GET(Bias1, linearIndex%hsz+2*hsz);
284 
285         b2r = DEVICE_BIAS_GET(Bias2, linearIndex%hsz+0*hsz);
286         b2i = DEVICE_BIAS_GET(Bias2, linearIndex%hsz+1*hsz);
287         b2n = DEVICE_BIAS_GET(Bias2, linearIndex%hsz+2*hsz);
288       } else {
289 #ifndef THC_REAL_IS_HALF
290         b1r = 0.0; b1i = 0.0; b1n = 0.0;
291         b2r = 0.0; b2i = 0.0; b2n = 0.0;
292 #else
293         b1r = F2H(0.0); b1i = F2H(0.0); b1n = F2H(0.0);
294         b2r = F2H(0.0); b2i = F2H(0.0); b2n = F2H(0.0);
295 #endif
296       }
297 
298       offset = (linearIndex/hsz)*5*hsz+linearIndex%hsz;
299 
300       accscalar_t rg, ig, ng;
301 
302       rg = sigmoid(H2F(ir) + H2F(hr) + H2F(b1r) + H2F(b2r));
303       ig = sigmoid(H2F(ii) + H2F(hi) + H2F(b1i) + H2F(b2i));
304 
305       ng = H2F(in) + H2F(b1n) + rg*( H2F(hn)+H2F(b2n) );
306       ng = ::tanh(ng);
307       *hy = F2H( ng + ig * ( H2F(hx)-ng ) );
308 
309       //SAVE FOR BACKWARDS
310       DEVICE_LINEAR_GET(storage, offset+0*hsz) = F2H(rg);
311       DEVICE_LINEAR_GET(storage, offset+1*hsz) = F2H(ig);
312       DEVICE_LINEAR_GET(storage, offset+2*hsz) = F2H(ng);
313       DEVICE_LINEAR_GET(storage, offset+3*hsz) = hx;
314       DEVICE_LINEAR_GET(storage, offset+4*hsz) = F2H(H2F(hn) + H2F(b2n));
315     }
316 }
317 
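// gru_cell_backward recomputes gradients from the saved workspace. With
// go = dL/dhy and hy = n + z * (hx - n):
//   d_z  = go * (hx - n) * z * (1 - z)
//   d_hx = go * z
//   d_n  = go * (1 - z) * (1 - n^2)
//   d_hn = d_n * r                      (hidden-side candidate column)
//   d_r  = d_n * (n_hid + b2_n) * r * (1 - r)
// The input-side and hidden-side gate gradients share d_r and d_z but differ
// in the candidate column (d_n vs. d_hn); d_hx goes to `gradInputHx`.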
318 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
319 #if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
320 C10_LAUNCH_BOUNDS_2(512, 4)
321 #endif
322 __global__ void gru_cell_backward(
323              TensorInfo<scalar_t, index_type> gradInInput,
324              TensorInfo<scalar_t, index_type> gradInHidden,
325              TensorInfo<scalar_t, index_type> gradOutput,
326              TensorInfo<scalar_t, index_type> gradInputHx,
327              TensorInfo<scalar_t, index_type> storage,
328              index_type hsz,
329              index_type totalElements) {
330   for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
331        linearIndex < totalElements;
332        linearIndex += gridDim.x * blockDim.x) {
333     index_type offset = (linearIndex/hsz)*5*hsz+linearIndex%hsz;
334 
335     scalar_t rg = DEVICE_LINEAR_GET(storage, offset+0*hsz);
336     scalar_t ig = DEVICE_LINEAR_GET(storage, offset+1*hsz);
337     scalar_t ng = DEVICE_LINEAR_GET(storage, offset+2*hsz);
338     scalar_t hx = DEVICE_LINEAR_GET(storage, offset+3*hsz);
339     scalar_t hn = DEVICE_LINEAR_GET(storage, offset+4*hsz);
340 
341     scalar_t go = DEVICE_LINEAR_GET(gradOutput, linearIndex);
342 
343     offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz;
344 
345     accscalar_t gig = H2F(go)*( H2F(hx)-H2F(ng) )*( 1-H2F(ig) )*H2F(ig);
346     accscalar_t ghx = H2F(go)*H2F(ig);
347     accscalar_t gin = H2F(go)*( 1-H2F(ig) )*( 1-H2F(ng)*H2F(ng) );
348     accscalar_t ghn = gin * H2F(rg);
349     accscalar_t grg = gin *H2F(hn)*( 1-H2F(rg) )*H2F(rg);
350 
351     DEVICE_LINEAR_GET(gradInInput, offset+0*hsz) = F2H(grg);
352     DEVICE_LINEAR_GET(gradInInput, offset+1*hsz) = F2H(gig);
353     DEVICE_LINEAR_GET(gradInInput, offset+2*hsz) = F2H(gin);
354 
355     DEVICE_LINEAR_GET(gradInHidden, offset+0*hsz) = F2H(grg);
356     DEVICE_LINEAR_GET(gradInHidden, offset+1*hsz) = F2H(gig);
357     DEVICE_LINEAR_GET(gradInHidden, offset+2*hsz) = F2H(ghn);
358     DEVICE_LINEAR_GET(gradInputHx, linearIndex) = F2H(ghx);
359   }
360 }
361 
362 #undef DEVICE_LINEAR_GET
363 #undef DEVICE_BIAS_GET
364 #undef H2F
365 #undef F2H
366 
367 } // namespace kernel
368 
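// Host-side launchers. Each builds TensorInfo views of its arguments, takes
// the hidden size from the last dimension, and launches the matching kernel
// on the current CUDA stream. When every participating tensor is contiguous,
// collapseDims() flattens the infos and the fast 1-D indexing specialization
// is used; otherwise the kernels fall back to generic 2-D stride arithmetic.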
369 template<typename scalar_t, typename index_type>
370 void lstm_forward_impl(const Tensor& input_gates, const Tensor& hidden_gates,
371                        const Tensor& input_bias, const Tensor& hidden_bias,
372                        const Tensor& cx,
373                        const Tensor& hy, const Tensor& cy, const Tensor& workspace) {
374   using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>;
375 
376   dim3 block, grid;
377   int64_t numel = cx.numel();
378   if (numel == 0) return;
379   getLaunchConfig(&block, &grid, numel);
380 
381   auto input_gatesI = getTensorInfo<scalar_t, index_type>(input_gates);
382   auto hidden_gatesI = getTensorInfo<scalar_t, index_type>(hidden_gates);
383   auto input_biasI = tryGetTensorInfo<scalar_t, index_type>(input_bias);
384   auto hidden_biasI = tryGetTensorInfo<scalar_t, index_type>(hidden_bias);
385   auto cxI = getTensorInfo<scalar_t, index_type>(cx);
386   auto hyI = getTensorInfo<scalar_t, index_type>(hy);
387   auto cyI = getTensorInfo<scalar_t, index_type>(cy);
388   auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
389   index_type hidden_size = cxI.sizes[cxI.dims-1];
390 
391   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
392   if (allContiguous({input_gates, hidden_gates, input_bias, hidden_bias, cx, hy, cy, workspace})) {
393     collapseDims(input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, cxI, hyI, cyI, workspaceI);
394     kernel::lstm_cell_forward<scalar_t, accscalar_t, index_type, 1>
395       <<<grid, block, 0, stream>>>
396         (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, cxI, hyI, cyI, workspaceI, hidden_size, numel);
397     C10_CUDA_KERNEL_LAUNCH_CHECK();
398   } else {
399     kernel::lstm_cell_forward<scalar_t, accscalar_t, index_type, 2>
400       <<<grid, block, 0, stream>>>
401         (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, cxI, hyI, cyI, workspaceI, hidden_size, numel);
402     C10_CUDA_KERNEL_LAUNCH_CHECK();
403   }
404 }
405 
406 template<typename scalar_t, typename index_type>
407 void lstm_backward_impl(const Tensor& grad_hy, const Tensor& grad_cy,
408                         const Tensor& cx, const Tensor& cy,
409                         const Tensor& workspace,
410                         const Tensor& grad_gates, const Tensor& grad_cx) {
411   using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>;
412 
413   dim3 block, grid;
414   int64_t numel = cx.numel();
415   if (numel == 0) return;
416   getLaunchConfig(&block, &grid, numel);
417 
418   auto grad_hyI = tryGetTensorInfo<scalar_t, index_type>(grad_hy);
419   auto grad_cyI = tryGetTensorInfo<scalar_t, index_type>(grad_cy);
420   auto cxI = getTensorInfo<scalar_t, index_type>(cx);
421   auto cyI = getTensorInfo<scalar_t, index_type>(cy);
422   auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
423   auto grad_gatesI = getTensorInfo<scalar_t, index_type>(grad_gates);
424   auto grad_cxI = getTensorInfo<scalar_t, index_type>(grad_cx);
425   index_type hidden_size = cxI.sizes[cxI.dims-1];
426 
427   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
428   if (allContiguous({grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx})) {
429     collapseDims(grad_hyI, grad_cyI, cxI, cyI, workspaceI, grad_gatesI, grad_cxI);
430     kernel::lstm_cell_backward<scalar_t, accscalar_t, index_type, 1>
431       <<<grid, block, 0, stream>>>
432         (workspaceI, grad_gatesI, cxI, cyI, grad_hyI, grad_cyI, grad_cxI, hidden_size, numel);
433     C10_CUDA_KERNEL_LAUNCH_CHECK();
434   } else {
435     kernel::lstm_cell_backward<scalar_t, accscalar_t, index_type, 2>
436       <<<grid, block, 0, stream>>>
437         (workspaceI, grad_gatesI, cxI, cyI, grad_hyI, grad_cyI, grad_cxI, hidden_size, numel);
438     C10_CUDA_KERNEL_LAUNCH_CHECK();
439   }
440 }
441 
442 template<typename scalar_t, typename index_type>
443 void gru_forward_impl(const Tensor& input_gates, const Tensor& hidden_gates,
444                       const Tensor& input_bias, const Tensor& hidden_bias,
445                       const Tensor& hx,
446                       const Tensor& hy, const Tensor& workspace) {
447   using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>;
448 
449   dim3 block, grid;
450   int64_t numel = hx.numel();
451   if (numel == 0) return;
452   getLaunchConfig(&block, &grid, numel);
453 
454   auto input_gatesI = getTensorInfo<scalar_t, index_type>(input_gates);
455   auto hidden_gatesI = getTensorInfo<scalar_t, index_type>(hidden_gates);
456   auto input_biasI = tryGetTensorInfo<scalar_t, index_type>(input_bias);
457   auto hidden_biasI = tryGetTensorInfo<scalar_t, index_type>(hidden_bias);
458   auto hxI = getTensorInfo<scalar_t, index_type>(hx);
459   auto hyI = getTensorInfo<scalar_t, index_type>(hy);
460   auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
461   index_type hidden_size = hxI.sizes[hxI.dims-1];
462 
463   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
464   if (allContiguous({input_gates, hidden_gates, input_bias, hidden_bias, hx, hy, workspace})) {
465     collapseDims(input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, hxI, hyI, workspaceI);
466     kernel::gru_cell_forward<scalar_t, accscalar_t, index_type, 1>
467       <<<grid, block, 0, stream>>>
468         (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, hxI, hyI, workspaceI, hidden_size, numel);
469     C10_CUDA_KERNEL_LAUNCH_CHECK();
470   } else {
471     kernel::gru_cell_forward<scalar_t, accscalar_t, index_type, 2>
472       <<<grid, block, 0, stream>>>
473         (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, hxI, hyI, workspaceI, hidden_size, numel);
474     C10_CUDA_KERNEL_LAUNCH_CHECK();
475   }
476 }
477 
478 template<typename scalar_t, typename index_type>
479 void gru_backward_impl(const Tensor& grad_hy, const Tensor& workspace,
480                        const Tensor& grad_input_gates, const Tensor& grad_hidden_gates, const Tensor& grad_hx) {
481   using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>;
482 
483   dim3 block, grid;
484   int64_t numel = grad_hy.numel();
485   if (numel == 0) return;
486   getLaunchConfig(&block, &grid, numel);
487 
488   auto grad_hyI = getTensorInfo<scalar_t, index_type>(grad_hy);
489   auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
490   auto grad_input_gatesI = getTensorInfo<scalar_t, index_type>(grad_input_gates);
491   auto grad_hidden_gatesI = getTensorInfo<scalar_t, index_type>(grad_hidden_gates);
492   auto grad_hxI = getTensorInfo<scalar_t, index_type>(grad_hx);
493   index_type hidden_size = grad_hyI.sizes[grad_hyI.dims-1];
494 
495   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
496   if (allContiguous({grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx})) {
497     collapseDims(grad_hyI, workspaceI, grad_input_gatesI, grad_hidden_gatesI, grad_hxI);
498     kernel::gru_cell_backward<scalar_t, accscalar_t, index_type, 1>
499       <<<grid, block, 0, stream>>>
500         (grad_input_gatesI, grad_hidden_gatesI, grad_hyI, grad_hxI, workspaceI, hidden_size, numel);
501     C10_CUDA_KERNEL_LAUNCH_CHECK();
502   } else {
503     kernel::gru_cell_backward<scalar_t, accscalar_t, index_type, 2>
504       <<<grid, block, 0, stream>>>
505         (grad_input_gatesI, grad_hidden_gatesI, grad_hyI, grad_hxI, workspaceI, hidden_size, numel);
506     C10_CUDA_KERNEL_LAUNCH_CHECK();
507   }
508 }
509 
510 } // anonymous namespace
511 
512 // Note [64-bit index math check elision]
513 // It's enough to perform the check for 64-bit math on the largest tensor only.
514 // If 32-bit is enough for it, it will suffice for all other tensors too, and we
515 // can save some work using this trick.
516 
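// Fused LSTM cell forward. Returns (hy, cy, workspace), where `workspace`
// holds the post-activation gates and is passed back into the backward entry
// point below. The dispatch covers float, double, half and bfloat16, and the
// 32- vs 64-bit index math decision is made once on the workspace, the
// largest tensor involved (see the note above).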
517 std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_cuda(
518       const Tensor& input_gates, const Tensor& hidden_gates,
519       const Tensor& cx, const std::optional<Tensor>& input_bias_opt, const std::optional<Tensor>& hidden_bias_opt) {
520   // See [Note: hacky wrapper removal for optional tensor]
521   c10::MaybeOwned<Tensor> input_bias_maybe_owned = at::borrow_from_optional_tensor(input_bias_opt);
522   const Tensor& input_bias = *input_bias_maybe_owned;
523   const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();});
524 
525   checkSizes("_thnn_fused_lstm_cell_cuda",
526              {input_gates, "input_gates", 1}, {hidden_gates, "hidden_gates", 2},
527              {input_bias, "input_bias", 3}, {hidden_bias, "hidden_bias", 4},
528              /*factor=*/4, {cx, "prev_hidden", 5});
529 
530   auto workspace = at::empty_like(input_gates, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
531   auto hy = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
532   auto cy = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
533   AT_DISPATCH_FLOATING_TYPES_AND2(
534     at::ScalarType::Half,
535     at::ScalarType::BFloat16,
536     input_gates.scalar_type(),
537     "_thnn_fused_lstm_cell_cuda",
538     [&] {
539       if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
540         lstm_forward_impl<scalar_t, int32_t>(input_gates, hidden_gates, input_bias, hidden_bias, cx, hy, cy, workspace);
541       } else {
542         lstm_forward_impl<scalar_t, int64_t>(input_gates, hidden_gates, input_bias, hidden_bias, cx, hy, cy, workspace);
543       }
544   });
545   return std::make_tuple(std::move(hy), std::move(cy), std::move(workspace));
546 }
547 
548 void checkLSTMBackwardSizes(const TensorArg& grad_hy, const TensorArg& grad_cy,
549                             const TensorArg& cx, const TensorArg& cy,
550                             const TensorArg& workspace) {
551   CheckedFrom c = "fused_lstm_cell_backward";
552   const TensorArg& defined_grad = grad_hy->defined() ? grad_hy : grad_cy;
553   checkDim(c, defined_grad, 2);
554   auto exp_size = defined_grad->sizes();
555   if (grad_hy->defined()) {
556     checkSize(c, grad_hy, exp_size);
557   }
558   if (grad_cy->defined()) {
559     checkSize(c, grad_cy, exp_size);
560   }
561   checkSize(c, cx, exp_size);
562   checkSize(c, cy, exp_size);
563   checkDim(c, workspace, 2);
564   checkNumel(c, workspace, exp_size[0] * exp_size[1] * 4);
565 }
566 
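// Fused LSTM cell backward. Returns (grad_gates, grad_cx, grad_bias): the
// gate gradient is shared by the input-side and hidden-side projections, and
// grad_bias is its sum over the batch dimension, computed only when the
// forward ran with biases. If neither grad_hy nor grad_cy is defined, there
// is nothing to differentiate and undefined tensors are returned.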
567 std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backward_impl_cuda(const std::optional<Tensor>& grad_hy_opt, const std::optional<Tensor>& grad_cy_opt,
568       const Tensor& cx, const Tensor& cy,
569       const Tensor& workspace, bool has_bias) {
570   // See [Note: hacky wrapper removal for optional tensor]
571   c10::MaybeOwned<Tensor> grad_hy_maybe_owned = at::borrow_from_optional_tensor(grad_hy_opt);
572   const Tensor& grad_hy = *grad_hy_maybe_owned;
573   const Tensor& grad_cy = c10::value_or_else(grad_cy_opt, [] {return Tensor();});
574 
575   if (!grad_hy.defined() && !grad_cy.defined()) {
576     return std::tuple<Tensor, Tensor, Tensor>();
577   }
578   checkLSTMBackwardSizes({grad_hy, "grad_hy", 1}, {grad_cy, "grad_cy", 2},
579                          {cx, "cx", 3}, {cy, "cy", 4},
580                          {workspace, "workspace", 5});
581 
582   auto grad_gates = at::empty_like(workspace, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
583   auto grad_cx = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
584   AT_DISPATCH_FLOATING_TYPES_AND2(
585     at::ScalarType::Half,
586     at::ScalarType::BFloat16,
587     workspace.scalar_type(),
588     "_thnn_fused_lstm_cell_cuda_backward",
589     [&] {
590       if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
591         lstm_backward_impl<scalar_t, int32_t>(grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx);
592       } else {
593         lstm_backward_impl<scalar_t, int64_t>(grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx);
594       }
595   });
596 
597   auto grad_bias = has_bias ? grad_gates.sum(0, /*keepdim=*/false) : at::Tensor{};
598   return std::make_tuple(std::move(grad_gates), std::move(grad_cx), std::move(grad_bias));
599 }
600 
601 static constexpr int64_t GRU_WORKSPACE_MULTIPLIER = 5;
602 
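// Fused GRU cell forward. Unlike the LSTM path, the workspace is not shaped
// like the gates: it is [batch, hidden_size * GRU_WORKSPACE_MULTIPLIER],
// holding r, z, n, hx and the hidden-side candidate pre-activation needed by
// the backward pass. Returns (hy, workspace).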
603 std::tuple<Tensor, Tensor> _thnn_fused_gru_cell_cuda(
604       const Tensor& input_gates, const Tensor& hidden_gates,
605       const Tensor& hx, const std::optional<Tensor>& input_bias_opt, const std::optional<Tensor>& hidden_bias_opt) {
606   // See [Note: hacky wrapper removal for optional tensor]
607   c10::MaybeOwned<Tensor> input_bias_maybe_owned = at::borrow_from_optional_tensor(input_bias_opt);
608   const Tensor& input_bias = *input_bias_maybe_owned;
609   const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();});
610 
611   checkSizes("_thnn_fused_gru_cell_cuda",
612              {input_gates, "input_gates", 1}, {hidden_gates, "hidden_gates", 2},
613              {input_bias, "input_bias", 3}, {hidden_bias, "hidden_bias", 4},
614              /*factor=*/3, {hx, "prev_hidden", 5});
615 
616   auto workspace = at::empty({hx.size(0), hx.size(1) * GRU_WORKSPACE_MULTIPLIER}, hx.options());
617   auto hy = at::empty_like(hx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
618   AT_DISPATCH_FLOATING_TYPES_AND2(
619     at::ScalarType::Half,
620     at::ScalarType::BFloat16,
621     input_gates.scalar_type(),
622     "_thnn_fused_gru_cell_cuda",
623     [&] {
624       if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
625         gru_forward_impl<scalar_t, int32_t>(input_gates, hidden_gates, input_bias, hidden_bias, hx, hy, workspace);
626       } else {
627         gru_forward_impl<scalar_t, int64_t>(input_gates, hidden_gates, input_bias, hidden_bias, hx, hy, workspace);
628       }
629   });
630   return std::make_tuple(std::move(hy), std::move(workspace));
631 }
632 
633 void checkGRUBackwardSizes(const TensorArg& grad_hy, const TensorArg& workspace) {
634   CheckedFrom c = "fused_gru_cell_backward";
635   checkDim(c, grad_hy, 2);
636   checkSize(c, workspace, {grad_hy->size(0), grad_hy->size(1) * GRU_WORKSPACE_MULTIPLIER});
637 }
638 
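// Fused GRU cell backward. Produces separate input-side and hidden-side gate
// gradients (each [batch, 3 * hidden_size]), the gradient w.r.t. hx, and,
// when biases were used, the two bias gradients obtained by summing the gate
// gradients over the batch dimension.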
639 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_gru_cell_backward_cuda(
640       const Tensor& grad_hy, const Tensor& workspace, bool has_bias) {
641   checkGRUBackwardSizes({grad_hy, "grad_hy", 1}, {workspace, "workspace", 2});
642 
643   int64_t hidden_size = workspace.size(1) / GRU_WORKSPACE_MULTIPLIER;
644   auto grad_input_gates = at::empty({workspace.size(0), hidden_size * 3}, workspace.options());
645   auto grad_hidden_gates = at::empty({workspace.size(0), hidden_size * 3}, workspace.options());
646   auto grad_hx = at::empty_like(grad_hy, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
647   AT_DISPATCH_FLOATING_TYPES_AND2(
648     at::ScalarType::Half,
649     at::ScalarType::BFloat16,
650     grad_hy.scalar_type(),
651     "_thnn_fused_gru_cell_cuda_backward",
652     [&] {
653       if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
654         gru_backward_impl<scalar_t, int32_t>(grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx);
655       } else {
656         gru_backward_impl<scalar_t, int64_t>(grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx);
657       }
658   });
659 
660   at::Tensor grad_input_bias, grad_hidden_bias;
661   if (has_bias) {
662     grad_input_bias = grad_input_gates.sum(0, /*keepdim=*/false);
663     grad_hidden_bias = grad_hidden_gates.sum(0, /*keepdim=*/false);
664   }
665 
666   return std::make_tuple(
667     std::move(grad_input_gates),
668     std::move(grad_hidden_gates),
669     std::move(grad_hx),
670     std::move(grad_input_bias),
671     std::move(grad_hidden_bias)
672   );
673 }
674 
675 } // namespace at::native
676