xref: /aosp_15_r20/external/pytorch/benchmarks/fastrnns/cells.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Tuple
2
3import torch
4from torch import Tensor
5
6
def milstm_cell(
    x: Tensor,
    hx: Tensor,
    cx: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    alpha: Tensor,
    beta_i: Tensor,
    beta_h: Tensor,
    bias: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Run one step of a multiplicative-integration LSTM (MI-LSTM) cell.

    The gate pre-activations combine the input projection ``Wx`` and the
    recurrent projection ``Uz`` multiplicatively as well as additively:
    ``alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias``.

    Args:
        x: 2-D input batch, projected as ``x @ w_ih.T``.
        hx: previous hidden state, projected as ``hx @ w_hh.T``.
        cx: previous cell state.
        w_ih: input-to-hidden weight.
        w_hh: hidden-to-hidden weight.
        alpha: scale of the multiplicative ``Wx * Uz`` term.
        beta_i: scale of the input-projection term.
        beta_h: scale of the recurrent-projection term.
        bias: bias added to the gate pre-activations.

    Returns:
        ``(hy, cy)``: the next hidden state and the next cell state.
    """
    Wx = x.mm(w_ih.t())
    Uz = hx.mm(w_hh.t())

    # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf
    gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias

    # Same as LSTMCell after this point: gates are split along dim 1 into
    # input, forget, cell (candidate) and output chunks.
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = ingate.sigmoid()
    forgetgate = forgetgate.sigmoid()
    cellgate = cellgate.tanh()
    outgate = outgate.sigmoid()

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * cy.tanh()

    return hy, cy
26
27
def lstm_cell(
    input: Tensor,
    hidden: Tuple[Tensor, Tensor],
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Run one step of a standard LSTM cell.

    Args:
        input: 2-D input batch, projected as ``input @ w_ih.T``.
        hidden: ``(h, c)`` pair carried over from the previous step.
        w_ih: input-to-hidden weight.
        w_hh: hidden-to-hidden weight.
        b_ih: input-side bias.
        b_hh: hidden-side bias.

    Returns:
        ``(h_next, c_next)``: the next hidden state and cell state.
    """
    h_prev = hidden[0]
    c_prev = hidden[1]

    # Floating-point addition is not associative, so the summation order
    # (input projection, recurrent projection, b_ih, b_hh) is kept fixed.
    x_proj = input.mm(w_ih.t())
    h_proj = h_prev.mm(w_hh.t())
    preact = x_proj + h_proj + b_ih + b_hh

    # Gate layout along dim 1: input, forget, cell (candidate), output.
    i_pre, f_pre, g_pre, o_pre = preact.chunk(4, 1)

    c_next = (f_pre.sigmoid() * c_prev) + (i_pre.sigmoid() * g_pre.tanh())
    h_next = o_pre.sigmoid() * c_next.tanh()
    return h_next, c_next
50
51
def flat_lstm_cell(
    input: Tensor,
    hx: Tensor,
    cx: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Run one LSTM step with the hidden and cell states passed separately.

    This computes the same thing as an LSTM cell taking an ``(hx, cx)``
    tuple, but takes ``hx`` and ``cx`` as two flat arguments.

    Returns:
        ``(hy, cy)``: the next hidden state and cell state.
    """
    # Pre-activations for all four gates, split along dim 1 into
    # input, forget, cell (candidate) and output chunks.
    i_raw, f_raw, g_raw, o_raw = torch.chunk(
        torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh, 4, 1
    )

    in_gate = i_raw.sigmoid()
    forget_gate = f_raw.sigmoid()
    candidate = g_raw.tanh()
    out_gate = o_raw.sigmoid()

    cy = (forget_gate * cx) + (in_gate * candidate)
    return out_gate * cy.tanh(), cy
74
75
def premul_lstm_cell(
    igates: Tensor,
    hidden: Tuple[Tensor, Tensor],
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Run one LSTM step whose input projection was computed ahead of time.

    Args:
        igates: already-projected input contribution to the gates (the
            input-weight matmul is not performed here).
        hidden: ``(h, c)`` pair carried over from the previous step.
        w_hh: hidden-to-hidden weight, applied as ``h @ w_hh.T``.
        b_ih: input-side bias, added here rather than folded into ``igates``.
        b_hh: hidden-side bias.

    Returns:
        ``(hy, cy)``: the next hidden state and cell state.
    """
    h_prev, c_prev = hidden

    # Summation order is kept as igates, recurrent term, b_ih, b_hh.
    recurrent = torch.mm(h_prev, w_hh.t())
    gates = igates + recurrent + b_ih + b_hh

    # Gate layout along dim 1: input, forget, cell (candidate), output.
    pre_i, pre_f, pre_g, pre_o = gates.chunk(4, 1)
    i_act = pre_i.sigmoid()
    f_act = pre_f.sigmoid()
    g_act = pre_g.tanh()
    o_act = pre_o.sigmoid()

    cy = (f_act * c_prev) + (i_act * g_act)
    hy = o_act * cy.tanh()
    return hy, cy
97
98
def premul_lstm_cell_no_bias(
    igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor
) -> Tuple[Tensor, Tensor]:
    """Run one LSTM step with a precomputed input projection and no ``b_ih``.

    Only ``b_hh`` is added here. Any input-side bias must already be
    included in ``igates`` by the caller.

    Returns:
        ``(hy, cy)``: the next hidden state and cell state.
    """
    h_prev = hidden[0]
    c_prev = hidden[1]

    preact = igates + h_prev.mm(w_hh.t()) + b_hh

    # Gate layout along dim 1: input, forget, cell (candidate), output.
    chunks = preact.chunk(4, 1)
    in_gate = chunks[0].sigmoid()
    forget_gate = chunks[1].sigmoid()
    candidate = chunks[2].tanh()
    out_gate = chunks[3].sigmoid()

    c_next = (forget_gate * c_prev) + (in_gate * candidate)
    return out_gate * c_next.tanh(), c_next
116
117
def gru_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Run one step of a GRU cell.

    Args:
        input: 2-D input batch, projected as ``input @ w_ih.T + b_ih``.
        hidden: previous hidden state, projected as ``hidden @ w_hh.T + b_hh``.
        w_ih: input-to-hidden weight.
        w_hh: hidden-to-hidden weight.
        b_ih: input-side bias.
        b_hh: hidden-side bias.

    Returns:
        The next hidden state.
    """
    gi = torch.mm(input, w_ih.t()) + b_ih
    gh = torch.mm(hidden, w_hh.t()) + b_hh
    # Each projection is split into thirds along dim 1:
    # reset, update ("input" here), and new/candidate.
    i_r, i_i, i_n = gi.chunk(3, 1)
    h_r, h_i, h_n = gh.chunk(3, 1)

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    # The reset gate scales only the hidden-side candidate term.
    newgate = torch.tanh(i_n + resetgate * h_n)
    # Algebraically equal to (1 - z) * n + z * h, with z = inputgate.
    hy = newgate + inputgate * (hidden - newgate)

    return hy
130
131
def rnn_relu_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Run one step of an Elman RNN cell with ReLU activation.

    Computes ``relu((input @ w_ih.T + b_ih) + (hidden @ w_hh.T + b_hh))``.

    Returns:
        The next hidden state.
    """
    igates = torch.mm(input, w_ih.t()) + b_ih
    hgates = torch.mm(hidden, w_hh.t()) + b_hh
    return torch.relu(igates + hgates)
136
137
def rnn_tanh_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Run one step of an Elman RNN cell with tanh activation.

    Computes ``tanh((input @ w_ih.T + b_ih) + (hidden @ w_hh.T + b_hh))``.

    Returns:
        The next hidden state.
    """
    igates = torch.mm(input, w_ih.t()) + b_ih
    hgates = torch.mm(hidden, w_hh.t()) + b_hh
    return torch.tanh(igates + hgates)
142