#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <c10/macros/Macros.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/_thnn_fused_lstm_cell_native.h>
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_native.h>
#include <ATen/ops/_thnn_fused_gru_cell_native.h>
#include <ATen/ops/_thnn_fused_gru_cell_backward_native.h>
#endif

namespace at::native {

namespace {

using at::cuda::detail::TensorInfo;
using at::cuda::detail::getTensorInfo;
using at::cuda::detail::IndexToOffset;
using at::cuda::detail::canUse32BitIndexMath;

// Factor will be 3 for GRU and 4 for LSTM
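// Expected shapes (summarizing the checks below, not adding new constraints):
//   input_gates, hidden_gates : [batch, factor * hidden_size]
//   input_bias, hidden_bias   : [factor * hidden_size] (optional)
//   prev_hidden               : [batch, hidden_size]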
void checkSizes(CheckedFrom c,
                const TensorArg& input_gates, const TensorArg& hidden_gates,
                const TensorArg& input_bias, const TensorArg& hidden_bias,
                int64_t factor, const TensorArg& prev_hidden) {
  checkDim(c, input_gates, 2);
  checkSameSize(c, input_gates, hidden_gates);
  int64_t gates_size = input_gates->size(1);

  if (input_bias->defined()) {
    checkDim(c, input_bias, 1);
    checkNumel(c, input_bias, gates_size);
    checkSameSize(c, input_bias, hidden_bias);
  }

  checkDim(c, prev_hidden, 2);
  checkNumel(c, prev_hidden, input_gates->size(0) * gates_size / factor);

  checkAllSameGPU(c, {input_gates, hidden_gates, input_bias, hidden_bias, prev_hidden});
}

bool allContiguous(at::TensorList tensors) {
  return std::all_of(tensors.begin(), tensors.end(),
                     [](const at::Tensor& t) { return !t.defined() || t.is_contiguous(); });
}

void getLaunchConfig(dim3* block, dim3* grid, int64_t numel) {
  c10::DeviceIndex curDevice = -1;
  c10::cuda::GetDevice(&curDevice);
  *block = cuda::getApplyBlock();
  TORCH_INTERNAL_ASSERT(cuda::getApplyGrid(numel, *grid, curDevice),
                        "Could not get grid size for pointwise apply.");
}

template<typename T, typename T2>
TensorInfo<T, T2> tryGetTensorInfo(const at::Tensor& t) {
  return t.defined() ? getTensorInfo<T, T2>(t) : TensorInfo<T, T2>{};
}

void collapseDims() {}
template<typename T, typename T2, typename... Args>
void collapseDims(TensorInfo<T, T2>& info, Args&... infos) {
  info.collapseDims();
  collapseDims(infos...);
}

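// `indexing_kind` below is the static rank handed to IndexToOffset: after
// collapseDims() every contiguous tensor is 1-D (kind 1), while the
// non-contiguous fallback keeps the original 2-D layout (kind 2).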
#define DEVICE_LINEAR_GET(D_TENSOR, INDEX) \
  D_TENSOR.data[IndexToOffset<scalar_t, index_type, indexing_kind>::get(INDEX, D_TENSOR)]

// Biases are always 1D
#define DEVICE_BIAS_GET(D_TENSOR, INDEX) \
  D_TENSOR.data[IndexToOffset<scalar_t, index_type, 1>::get(INDEX, D_TENSOR)]

#define H2F(input) static_cast<accscalar_t>(input)
#define F2H(input) static_cast<scalar_t>(input)

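// Logistic function. Call sites promote the summed gate pre-activations to
// accscalar_t via H2F first, so half/bfloat16 inputs are evaluated in float.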
template<typename T>
__device__ __forceinline__
T sigmoid(T in) {
  T one = static_cast<T>(1.0);
  return one / (one + ::exp(-in));
}

namespace kernel {

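// Pointwise LSTM cell math fused below; the gate pre-activations are the
// outputs of two GEMMs computed by the caller (notation is illustrative:
// x_* = input gate pre-activation, h_* = hidden gate pre-activation):
//   i  = sigmoid(x_i + h_i + b1_i + b2_i)
//   f  = sigmoid(x_f + h_f + b1_f + b2_f)
//   g  = tanh   (x_g + h_g + b1_g + b2_g)
//   o  = sigmoid(x_o + h_o + b1_o + b2_o)
//   cy = f * cx + i * g
//   hy = o * tanh(cy)
// The activated gates (i, f, g, o) are written to `workspace` for backward.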
template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void lstm_cell_forward(
    TensorInfo<scalar_t, index_type> input,
    TensorInfo<scalar_t, index_type> hidden,
    TensorInfo<scalar_t, index_type> bias1,
    TensorInfo<scalar_t, index_type> bias2,
    TensorInfo<scalar_t, index_type> _cx,
    TensorInfo<scalar_t, index_type> _hy,
    TensorInfo<scalar_t, index_type> _cy,
    TensorInfo<scalar_t, index_type> workspace,
    index_type hsz,
    index_type totalElements) {
  bool has_bias = bias1.data != nullptr;
  for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x) {
    index_type offset = (linearIndex/hsz)*4*hsz+linearIndex%hsz;

    scalar_t iig = DEVICE_LINEAR_GET(input, offset+0*hsz);
    scalar_t ifg = DEVICE_LINEAR_GET(input, offset+1*hsz);
    scalar_t icg = DEVICE_LINEAR_GET(input, offset+2*hsz);
    scalar_t iog = DEVICE_LINEAR_GET(input, offset+3*hsz);

    scalar_t hig = DEVICE_LINEAR_GET(hidden, offset+0*hsz);
    scalar_t hfg = DEVICE_LINEAR_GET(hidden, offset+1*hsz);
    scalar_t hcg = DEVICE_LINEAR_GET(hidden, offset+2*hsz);
    scalar_t hog = DEVICE_LINEAR_GET(hidden, offset+3*hsz);

    scalar_t* wig = &DEVICE_LINEAR_GET(workspace, offset+0*hsz);
    scalar_t* wfg = &DEVICE_LINEAR_GET(workspace, offset+1*hsz);
    scalar_t* wcg = &DEVICE_LINEAR_GET(workspace, offset+2*hsz);
    scalar_t* wog = &DEVICE_LINEAR_GET(workspace, offset+3*hsz);

    scalar_t cx = DEVICE_LINEAR_GET(_cx, linearIndex);

    scalar_t* hy = &DEVICE_LINEAR_GET(_hy, linearIndex);
    scalar_t* cy = &DEVICE_LINEAR_GET(_cy, linearIndex);

    scalar_t b1i, b1f, b1c, b1o;
    scalar_t b2i, b2f, b2c, b2o;

    if (has_bias) {
      b1i = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 0 * hsz);
      b1f = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 1 * hsz);
      b1c = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 2 * hsz);
      b1o = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 3 * hsz);

      b2i = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 0 * hsz);
      b2f = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 1 * hsz);
      b2c = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 2 * hsz);
      b2o = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 3 * hsz);
    } else {
#ifndef THC_REAL_IS_HALF
      b1i = 0.0; b1f = 0.0; b1c = 0.0; b1o = 0.0;
      b2i = 0.0; b2f = 0.0; b2c = 0.0; b2o = 0.0;
#else
      b1i = F2H(0.0); b1f = F2H(0.0); b1c = F2H(0.0); b1o = F2H(0.0);
      b2i = F2H(0.0); b2f = F2H(0.0); b2c = F2H(0.0); b2o = F2H(0.0);
#endif
    }

    accscalar_t ig, fg, cg, og;
    accscalar_t f_hy, f_cy;

    ig = sigmoid(H2F(iig) + H2F(hig) + H2F(b1i) + H2F(b2i));
    fg = sigmoid(H2F(ifg) + H2F(hfg) + H2F(b1f) + H2F(b2f));
    cg = ::tanh(H2F(icg) + H2F(hcg) + H2F(b1c) + H2F(b2c));
    og = sigmoid(H2F(iog) + H2F(hog) + H2F(b1o) + H2F(b2o));

    f_cy = (fg * H2F(cx)) + (ig * cg);
    f_hy = og * ::tanh(f_cy);

    *hy = F2H(f_hy);
    *cy = F2H(f_cy);

    // SAVE FOR BACKWARDS
    // Also need cy and cx but can be saved easily in python
    *wig = F2H(ig);
    *wfg = F2H(fg);
    *wcg = F2H(cg);
    *wog = F2H(og);
  }
}

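// Backward of the pointwise LSTM math above, reading the saved gates from
// `storage`. With dhy = gradoutput and dcy = gradoutputcell (either may be
// absent and is then treated as zero):
//   dc  = dhy * o * (1 - tanh(cy)^2) + dcy
//   do' = dhy * tanh(cy) * o * (1 - o)      (pre-activation gradients)
//   di' = dc * g * i * (1 - i)
//   df' = dc * cx * f * (1 - f)
//   dg' = dc * i * (1 - g^2)
//   dcx = dc * f
// The pre-activation gradients go to gradInGates and dcx to gradInputCx.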
template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void lstm_cell_backward(
    TensorInfo<scalar_t, index_type> storage,
    TensorInfo<scalar_t, index_type> gradInGates,
    TensorInfo<scalar_t, index_type> _cx,
    TensorInfo<scalar_t, index_type> _cy,
    TensorInfo<scalar_t, index_type> gradoutput,
    TensorInfo<scalar_t, index_type> gradoutputcell,
    TensorInfo<scalar_t, index_type> gradInputCx,
    index_type hsz,
    index_type totalElements) {
  bool has_gradoutput = gradoutput.data != nullptr;
  bool has_gradoutputcell = gradoutputcell.data != nullptr;
  for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x) {
    index_type offset = (linearIndex/hsz)*4*hsz+linearIndex%hsz;

    scalar_t ig = DEVICE_LINEAR_GET(storage, offset+0*hsz);
    scalar_t fg = DEVICE_LINEAR_GET(storage, offset+1*hsz);
    scalar_t cg = DEVICE_LINEAR_GET(storage, offset+2*hsz);
    scalar_t og = DEVICE_LINEAR_GET(storage, offset+3*hsz);

    scalar_t* ih = &DEVICE_LINEAR_GET(gradInGates, offset+0*hsz);
    scalar_t* fh = &DEVICE_LINEAR_GET(gradInGates, offset+1*hsz);
    scalar_t* ch = &DEVICE_LINEAR_GET(gradInGates, offset+2*hsz);
    scalar_t* oh = &DEVICE_LINEAR_GET(gradInGates, offset+3*hsz);

    // will return hidden grads here
    scalar_t cx = DEVICE_LINEAR_GET(_cx, linearIndex);
    scalar_t cy = DEVICE_LINEAR_GET(_cy, linearIndex);

    scalar_t* gi = &DEVICE_LINEAR_GET(gradInputCx, linearIndex);

    accscalar_t go = has_gradoutput ? H2F(DEVICE_LINEAR_GET(gradoutput, linearIndex)) : 0.f;
    accscalar_t goc = has_gradoutputcell ? H2F(DEVICE_LINEAR_GET(gradoutputcell, linearIndex)) : 0.f;

    accscalar_t gcx = ::tanh(H2F(cy));

    accscalar_t gog = go * gcx;
    gcx = go * H2F(og) * (1 - gcx*gcx) + goc;

    accscalar_t gig = gcx * H2F(cg);
    accscalar_t gfg = gcx * H2F(cx);
    accscalar_t gcg = gcx * H2F(ig);

    gcx = gcx * H2F(fg);

    gig = gig * (1-H2F(ig)) * H2F(ig);
    gfg = gfg * (1-H2F(fg)) * H2F(fg);
    gcg = gcg * (1-H2F(cg)*H2F(cg));
    gog = gog * (1-H2F(og)) * H2F(og);

    *ih = F2H(gig);
    *fh = F2H(gfg);
    *ch = F2H(gcg);
    *oh = F2H(gog);

    *gi = F2H(gcx);
  }
}

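// Pointwise GRU cell math fused below; the gate pre-activations are the
// outputs of two GEMMs computed by the caller (x_* / h_* notation as above):
//   r  = sigmoid(x_r + h_r + b1_r + b2_r)        (reset gate)
//   i  = sigmoid(x_i + h_i + b1_i + b2_i)        (update gate)
//   n  = tanh(x_n + b1_n + r * (h_n + b2_n))     (new gate)
//   hy = n + i * (hx - n)                        (== (1 - i) * n + i * hx)
// The workspace stores r, i, n, hx and (h_n + b2_n) per element for backward.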
template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void gru_cell_forward(
    TensorInfo<scalar_t, index_type> Input,
    TensorInfo<scalar_t, index_type> Hidden,
    TensorInfo<scalar_t, index_type> Bias1,
    TensorInfo<scalar_t, index_type> Bias2,
    TensorInfo<scalar_t, index_type> _hx,
    TensorInfo<scalar_t, index_type> _hy,
    TensorInfo<scalar_t, index_type> storage,
    index_type hsz,
    index_type totalElements) {
  bool has_bias = Bias1.data != nullptr;
  for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x) {
    index_type offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz;

    scalar_t ir = DEVICE_LINEAR_GET(Input, offset+0*hsz);
    scalar_t ii = DEVICE_LINEAR_GET(Input, offset+1*hsz);
    scalar_t in = DEVICE_LINEAR_GET(Input, offset+2*hsz);
    scalar_t hr = DEVICE_LINEAR_GET(Hidden, offset+0*hsz);
    scalar_t hi = DEVICE_LINEAR_GET(Hidden, offset+1*hsz);
    scalar_t hn = DEVICE_LINEAR_GET(Hidden, offset+2*hsz);

    scalar_t hx = DEVICE_LINEAR_GET(_hx, linearIndex);
    scalar_t* hy = &DEVICE_LINEAR_GET(_hy, linearIndex);

    scalar_t b1r, b1i, b1n, b2r, b2i, b2n;

    if (has_bias) {
      b1r = DEVICE_BIAS_GET(Bias1, linearIndex%hsz+0*hsz);
      b1i = DEVICE_BIAS_GET(Bias1, linearIndex%hsz+1*hsz);
      b1n = DEVICE_BIAS_GET(Bias1, linearIndex%hsz+2*hsz);

      b2r = DEVICE_BIAS_GET(Bias2, linearIndex%hsz+0*hsz);
      b2i = DEVICE_BIAS_GET(Bias2, linearIndex%hsz+1*hsz);
      b2n = DEVICE_BIAS_GET(Bias2, linearIndex%hsz+2*hsz);
    } else {
#ifndef THC_REAL_IS_HALF
      b1r = 0.0; b1i = 0.0; b1n = 0.0;
      b2r = 0.0; b2i = 0.0; b2n = 0.0;
#else
      b1r = F2H(0.0); b1i = F2H(0.0); b1n = F2H(0.0);
      b2r = F2H(0.0); b2i = F2H(0.0); b2n = F2H(0.0);
#endif
    }

    offset = (linearIndex/hsz)*5*hsz+linearIndex%hsz;

    accscalar_t rg, ig, ng;

    rg = sigmoid(H2F(ir) + H2F(hr) + H2F(b1r) + H2F(b2r));
    ig = sigmoid(H2F(ii) + H2F(hi) + H2F(b1i) + H2F(b2i));

    ng = H2F(in) + H2F(b1n) + rg*( H2F(hn)+H2F(b2n) );
    ng = ::tanh(ng);
    *hy = F2H( ng + ig * ( H2F(hx)-ng ) );

    // SAVE FOR BACKWARDS
    DEVICE_LINEAR_GET(storage, offset+0*hsz) = F2H(rg);
    DEVICE_LINEAR_GET(storage, offset+1*hsz) = F2H(ig);
    DEVICE_LINEAR_GET(storage, offset+2*hsz) = F2H(ng);
    DEVICE_LINEAR_GET(storage, offset+3*hsz) = hx;
    DEVICE_LINEAR_GET(storage, offset+4*hsz) = F2H(H2F(hn) + H2F(b2n));
  }
}

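// Backward of the GRU math above, reading the five saved values from
// `storage`. With dhy = gradOutput:
//   di'  = dhy * (hx - n) * i * (1 - i)      (pre-activation, update gate)
//   dhx  = dhy * i                           (direct term; the path through the
//                                             hidden GEMM is returned via gradInHidden)
//   dn   = dhy * (1 - i) * (1 - n^2)
//   dh_n = dn * r                            (grad w.r.t. h_n + b2_n)
//   dr'  = dn * (h_n + b2_n) * r * (1 - r)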
template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void gru_cell_backward(
    TensorInfo<scalar_t, index_type> gradInInput,
    TensorInfo<scalar_t, index_type> gradInHidden,
    TensorInfo<scalar_t, index_type> gradOutput,
    TensorInfo<scalar_t, index_type> gradInputHx,
    TensorInfo<scalar_t, index_type> storage,
    index_type hsz,
    index_type totalElements) {
  for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x) {
    index_type offset = (linearIndex/hsz)*5*hsz+linearIndex%hsz;

    scalar_t rg = DEVICE_LINEAR_GET(storage, offset+0*hsz);
    scalar_t ig = DEVICE_LINEAR_GET(storage, offset+1*hsz);
    scalar_t ng = DEVICE_LINEAR_GET(storage, offset+2*hsz);
    scalar_t hx = DEVICE_LINEAR_GET(storage, offset+3*hsz);
    scalar_t hn = DEVICE_LINEAR_GET(storage, offset+4*hsz);

    scalar_t go = DEVICE_LINEAR_GET(gradOutput, linearIndex);

    offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz;

    accscalar_t gig = H2F(go)*( H2F(hx)-H2F(ng) )*( 1-H2F(ig) )*H2F(ig);
    accscalar_t ghx = H2F(go)*H2F(ig);
    accscalar_t gin = H2F(go)*( 1-H2F(ig) )*( 1-H2F(ng)*H2F(ng) );
    accscalar_t ghn = gin * H2F(rg);
    accscalar_t grg = gin * H2F(hn)*( 1-H2F(rg) )*H2F(rg);

    DEVICE_LINEAR_GET(gradInInput, offset+0*hsz) = F2H(grg);
    DEVICE_LINEAR_GET(gradInInput, offset+1*hsz) = F2H(gig);
    DEVICE_LINEAR_GET(gradInInput, offset+2*hsz) = F2H(gin);

    DEVICE_LINEAR_GET(gradInHidden, offset+0*hsz) = F2H(grg);
    DEVICE_LINEAR_GET(gradInHidden, offset+1*hsz) = F2H(gig);
    DEVICE_LINEAR_GET(gradInHidden, offset+2*hsz) = F2H(ghn);
    DEVICE_LINEAR_GET(gradInputHx, linearIndex) = F2H(ghx);
  }
}

#undef DEVICE_LINEAR_GET
#undef DEVICE_BIAS_GET
#undef H2F
#undef F2H

} // namespace kernel

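// Host-side launchers. Each one builds TensorInfo views, collapses dimensions
// when every tensor is contiguous (so the kernels can take the cheaper
// indexing_kind == 1 path), and launches on the current CUDA stream. Empty
// inputs return early without launching.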
template<typename scalar_t, typename index_type>
void lstm_forward_impl(const Tensor& input_gates, const Tensor& hidden_gates,
                       const Tensor& input_bias, const Tensor& hidden_bias,
                       const Tensor& cx,
                       const Tensor& hy, const Tensor& cy, const Tensor& workspace) {
  using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>;

  dim3 block, grid;
  int64_t numel = cx.numel();
  if (numel == 0) return;
  getLaunchConfig(&block, &grid, numel);

  auto input_gatesI = getTensorInfo<scalar_t, index_type>(input_gates);
  auto hidden_gatesI = getTensorInfo<scalar_t, index_type>(hidden_gates);
  auto input_biasI = tryGetTensorInfo<scalar_t, index_type>(input_bias);
  auto hidden_biasI = tryGetTensorInfo<scalar_t, index_type>(hidden_bias);
  auto cxI = getTensorInfo<scalar_t, index_type>(cx);
  auto hyI = getTensorInfo<scalar_t, index_type>(hy);
  auto cyI = getTensorInfo<scalar_t, index_type>(cy);
  auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
  index_type hidden_size = cxI.sizes[cxI.dims-1];

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (allContiguous({input_gates, hidden_gates, input_bias, hidden_bias, cx, hy, cy, workspace})) {
    collapseDims(input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, cxI, hyI, cyI, workspaceI);
    kernel::lstm_cell_forward<scalar_t, accscalar_t, index_type, 1>
      <<<grid, block, 0, stream>>>
        (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, cxI, hyI, cyI, workspaceI, hidden_size, numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    kernel::lstm_cell_forward<scalar_t, accscalar_t, index_type, 2>
      <<<grid, block, 0, stream>>>
        (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, cxI, hyI, cyI, workspaceI, hidden_size, numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

template<typename scalar_t, typename index_type>
void lstm_backward_impl(const Tensor& grad_hy, const Tensor& grad_cy,
                        const Tensor& cx, const Tensor& cy,
                        const Tensor& workspace,
                        const Tensor& grad_gates, const Tensor& grad_cx) {
  using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>;

  dim3 block, grid;
  int64_t numel = cx.numel();
  if (numel == 0) return;
  getLaunchConfig(&block, &grid, numel);

  auto grad_hyI = tryGetTensorInfo<scalar_t, index_type>(grad_hy);
  auto grad_cyI = tryGetTensorInfo<scalar_t, index_type>(grad_cy);
  auto cxI = getTensorInfo<scalar_t, index_type>(cx);
  auto cyI = getTensorInfo<scalar_t, index_type>(cy);
  auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
  auto grad_gatesI = getTensorInfo<scalar_t, index_type>(grad_gates);
  auto grad_cxI = getTensorInfo<scalar_t, index_type>(grad_cx);
  index_type hidden_size = cxI.sizes[cxI.dims-1];

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (allContiguous({grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx})) {
    collapseDims(grad_hyI, grad_cyI, cxI, cyI, workspaceI, grad_gatesI, grad_cxI);
    kernel::lstm_cell_backward<scalar_t, accscalar_t, index_type, 1>
      <<<grid, block, 0, stream>>>
        (workspaceI, grad_gatesI, cxI, cyI, grad_hyI, grad_cyI, grad_cxI, hidden_size, numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    kernel::lstm_cell_backward<scalar_t, accscalar_t, index_type, 2>
      <<<grid, block, 0, stream>>>
        (workspaceI, grad_gatesI, cxI, cyI, grad_hyI, grad_cyI, grad_cxI, hidden_size, numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

template<typename scalar_t, typename index_type>
void gru_forward_impl(const Tensor& input_gates, const Tensor& hidden_gates,
                      const Tensor& input_bias, const Tensor& hidden_bias,
                      const Tensor& hx,
                      const Tensor& hy, const Tensor& workspace) {
  using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>;

  dim3 block, grid;
  int64_t numel = hx.numel();
  if (numel == 0) return;
  getLaunchConfig(&block, &grid, numel);

  auto input_gatesI = getTensorInfo<scalar_t, index_type>(input_gates);
  auto hidden_gatesI = getTensorInfo<scalar_t, index_type>(hidden_gates);
  auto input_biasI = tryGetTensorInfo<scalar_t, index_type>(input_bias);
  auto hidden_biasI = tryGetTensorInfo<scalar_t, index_type>(hidden_bias);
  auto hxI = getTensorInfo<scalar_t, index_type>(hx);
  auto hyI = getTensorInfo<scalar_t, index_type>(hy);
  auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
  index_type hidden_size = hxI.sizes[hxI.dims-1];

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (allContiguous({input_gates, hidden_gates, input_bias, hidden_bias, hx, hy, workspace})) {
    collapseDims(input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, hxI, hyI, workspaceI);
    kernel::gru_cell_forward<scalar_t, accscalar_t, index_type, 1>
      <<<grid, block, 0, stream>>>
        (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, hxI, hyI, workspaceI, hidden_size, numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    kernel::gru_cell_forward<scalar_t, accscalar_t, index_type, 2>
      <<<grid, block, 0, stream>>>
        (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, hxI, hyI, workspaceI, hidden_size, numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

template<typename scalar_t, typename index_type>
void gru_backward_impl(const Tensor& grad_hy, const Tensor& workspace,
                       const Tensor& grad_input_gates, const Tensor& grad_hidden_gates, const Tensor& grad_hx) {
  using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>;

  dim3 block, grid;
  int64_t numel = grad_hy.numel();
  if (numel == 0) return;
  getLaunchConfig(&block, &grid, numel);

  auto grad_hyI = getTensorInfo<scalar_t, index_type>(grad_hy);
  auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
  auto grad_input_gatesI = getTensorInfo<scalar_t, index_type>(grad_input_gates);
  auto grad_hidden_gatesI = getTensorInfo<scalar_t, index_type>(grad_hidden_gates);
  auto grad_hxI = getTensorInfo<scalar_t, index_type>(grad_hx);
  index_type hidden_size = grad_hyI.sizes[grad_hyI.dims-1];

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (allContiguous({grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx})) {
    collapseDims(grad_hyI, workspaceI, grad_input_gatesI, grad_hidden_gatesI, grad_hxI);
    kernel::gru_cell_backward<scalar_t, accscalar_t, index_type, 1>
      <<<grid, block, 0, stream>>>
        (grad_input_gatesI, grad_hidden_gatesI, grad_hyI, grad_hxI, workspaceI, hidden_size, numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    kernel::gru_cell_backward<scalar_t, accscalar_t, index_type, 2>
      <<<grid, block, 0, stream>>>
        (grad_input_gatesI, grad_hidden_gatesI, grad_hyI, grad_hxI, workspaceI, hidden_size, numel);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

} // anonymous namespace

// Note [64-bit index math check elision]
// It's enough to perform the check for 64-bit math on the largest tensor only.
// If 32-bit is enough for it, it will suffice for all other tensors too, and we
// can save some work using this trick.

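// Concretely: the workspace (batch x 4*hidden for LSTM, batch x 5*hidden for
// GRU) is the largest tensor involved, so checking it alone is sufficient.
//
// Rough usage sketch (illustrative only; the tensor names below are not part of
// this file): callers first compute the gate pre-activations with two GEMMs and
// then invoke the fused pointwise op, e.g.
//   auto igates = input.mm(w_ih.t());
//   auto hgates = hx.mm(w_hh.t());
//   auto [hy, cy, ws] = at::_thnn_fused_lstm_cell(igates, hgates, cx, b_ih, b_hh);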
std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_cuda(
      const Tensor& input_gates, const Tensor& hidden_gates,
      const Tensor& cx, const std::optional<Tensor>& input_bias_opt, const std::optional<Tensor>& hidden_bias_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> input_bias_maybe_owned = at::borrow_from_optional_tensor(input_bias_opt);
  const Tensor& input_bias = *input_bias_maybe_owned;
  const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();});

  checkSizes("_thnn_fused_lstm_cell_cuda",
             {input_gates, "input_gates", 1}, {hidden_gates, "hidden_gates", 2},
             {input_bias, "input_bias", 3}, {hidden_bias, "hidden_bias", 4},
             /*factor=*/4, {cx, "prev_hidden", 5});

  auto workspace = at::empty_like(input_gates, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto hy = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto cy = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      input_gates.scalar_type(),
      "_thnn_fused_lstm_cell_cuda",
      [&] {
        if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
          lstm_forward_impl<scalar_t, int32_t>(input_gates, hidden_gates, input_bias, hidden_bias, cx, hy, cy, workspace);
        } else {
          lstm_forward_impl<scalar_t, int64_t>(input_gates, hidden_gates, input_bias, hidden_bias, cx, hy, cy, workspace);
        }
      });
  return std::make_tuple(std::move(hy), std::move(cy), std::move(workspace));
}

void checkLSTMBackwardSizes(const TensorArg& grad_hy, const TensorArg& grad_cy,
                            const TensorArg& cx, const TensorArg& cy,
                            const TensorArg& workspace) {
  CheckedFrom c = "fused_lstm_cell_backward";
  const TensorArg& defined_grad = grad_hy->defined() ? grad_hy : grad_cy;
  checkDim(c, defined_grad, 2);
  auto exp_size = defined_grad->sizes();
  if (grad_hy->defined()) {
    checkSize(c, grad_hy, exp_size);
  }
  if (grad_cy->defined()) {
    checkSize(c, grad_cy, exp_size);
  }
  checkSize(c, cx, exp_size);
  checkSize(c, cy, exp_size);
  checkDim(c, workspace, 2);
  checkNumel(c, workspace, exp_size[0] * exp_size[1] * 4);
}

std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backward_impl_cuda(
      const std::optional<Tensor>& grad_hy_opt, const std::optional<Tensor>& grad_cy_opt,
      const Tensor& cx, const Tensor& cy,
      const Tensor& workspace, bool has_bias) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> grad_hy_maybe_owned = at::borrow_from_optional_tensor(grad_hy_opt);
  const Tensor& grad_hy = *grad_hy_maybe_owned;
  const Tensor& grad_cy = c10::value_or_else(grad_cy_opt, [] {return Tensor();});

  if (!grad_hy.defined() && !grad_cy.defined()) {
    return std::tuple<Tensor, Tensor, Tensor>();
  }
  checkLSTMBackwardSizes({grad_hy, "grad_hy", 1}, {grad_cy, "grad_cy", 2},
                         {cx, "cx", 3}, {cy, "cy", 4},
                         {workspace, "workspace", 5});

  auto grad_gates = at::empty_like(workspace, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_cx = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      workspace.scalar_type(),
      "_thnn_fused_lstm_cell_cuda_backward",
      [&] {
        if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
          lstm_backward_impl<scalar_t, int32_t>(grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx);
        } else {
          lstm_backward_impl<scalar_t, int64_t>(grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx);
        }
      });

  auto grad_bias = has_bias ? grad_gates.sum(0, /*keepdim=*/false) : at::Tensor{};
  return std::make_tuple(std::move(grad_gates), std::move(grad_cx), std::move(grad_bias));
}

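// Five values are saved per hidden element in the GRU workspace: r, i (update),
// n, hx, and (h_n + b2_n). See gru_cell_forward / gru_cell_backward above.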
static constexpr int64_t GRU_WORKSPACE_MULTIPLIER = 5;

std::tuple<Tensor, Tensor> _thnn_fused_gru_cell_cuda(
      const Tensor& input_gates, const Tensor& hidden_gates,
      const Tensor& hx, const std::optional<Tensor>& input_bias_opt, const std::optional<Tensor>& hidden_bias_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> input_bias_maybe_owned = at::borrow_from_optional_tensor(input_bias_opt);
  const Tensor& input_bias = *input_bias_maybe_owned;
  const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();});

  checkSizes("_thnn_fused_gru_cell_cuda",
             {input_gates, "input_gates", 1}, {hidden_gates, "hidden_gates", 2},
             {input_bias, "input_bias", 3}, {hidden_bias, "hidden_bias", 4},
             /*factor=*/3, {hx, "prev_hidden", 5});

  auto workspace = at::empty({hx.size(0), hx.size(1) * GRU_WORKSPACE_MULTIPLIER}, hx.options());
  auto hy = at::empty_like(hx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      input_gates.scalar_type(),
      "_thnn_fused_gru_cell_cuda",
      [&] {
        if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
          gru_forward_impl<scalar_t, int32_t>(input_gates, hidden_gates, input_bias, hidden_bias, hx, hy, workspace);
        } else {
          gru_forward_impl<scalar_t, int64_t>(input_gates, hidden_gates, input_bias, hidden_bias, hx, hy, workspace);
        }
      });
  return std::make_tuple(std::move(hy), std::move(workspace));
}

void checkGRUBackwardSizes(const TensorArg& grad_hy, const TensorArg& workspace) {
  CheckedFrom c = "fused_gru_cell_backward";
  checkDim(c, grad_hy, 2);
  checkSize(c, workspace, {grad_hy->size(0), grad_hy->size(1) * GRU_WORKSPACE_MULTIPLIER});
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_gru_cell_backward_cuda(
      const Tensor& grad_hy, const Tensor& workspace, bool has_bias) {
  checkGRUBackwardSizes({grad_hy, "grad_hy", 1}, {workspace, "workspace", 2});

  int64_t hidden_size = workspace.size(1) / GRU_WORKSPACE_MULTIPLIER;
  auto grad_input_gates = at::empty({workspace.size(0), hidden_size * 3}, workspace.options());
  auto grad_hidden_gates = at::empty({workspace.size(0), hidden_size * 3}, workspace.options());
  auto grad_hx = at::empty_like(grad_hy, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      grad_hy.scalar_type(),
      "_thnn_fused_gru_cell_cuda_backward",
      [&] {
        if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision]
          gru_backward_impl<scalar_t, int32_t>(grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx);
        } else {
          gru_backward_impl<scalar_t, int64_t>(grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx);
        }
      });

  at::Tensor grad_input_bias, grad_hidden_bias;
  if (has_bias) {
    grad_input_bias = grad_input_gates.sum(0, /*keepdim=*/false);
    grad_hidden_bias = grad_hidden_gates.sum(0, /*keepdim=*/false);
  }

  return std::make_tuple(
      std::move(grad_input_gates),
      std::move(grad_hidden_gates),
      std::move(grad_hx),
      std::move(grad_input_bias),
      std::move(grad_hidden_bias)
  );
}

} // namespace at::native