from typing import Tuple

import torch
from torch import Tensor


def milstm_cell(
    x: Tensor,
    hx: Tensor,
    cx: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    alpha: Tensor,
    beta_i: Tensor,
    beta_h: Tensor,
    bias: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Run one step of a multiplicative-integration LSTM (MI-LSTM).

    The input projection ``Wx`` and recurrent projection ``Uz`` are combined as
    ``alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias`` instead of the plain
    sum used by an ordinary LSTM; the rest of the update is identical to
    ``lstm_cell``.

    Args:
        x: Input batch; must be 2-D because it goes through ``torch.mm``.
        hx: Previous hidden state (2-D, multiplied by ``w_hh.t()``).
        cx: Previous cell state.
        w_ih: Input-to-hidden weight; ``x.mm(w_ih.t())`` must be valid.
        w_hh: Hidden-to-hidden weight; ``hx.mm(w_hh.t())`` must be valid.
        alpha: Scale of the multiplicative ``Wx * Uz`` term.
        beta_i: Scale of the input projection term.
        beta_h: Scale of the recurrent projection term.
        bias: Additive bias on the gate pre-activations.

    Returns:
        ``(hy, cy)``: the new hidden state and the new cell state.
    """
    Wx = x.mm(w_ih.t())
    Uz = hx.mm(w_hh.t())

    # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf
    gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias

    # Same as LSTMCell after this point
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = ingate.sigmoid()
    forgetgate = forgetgate.sigmoid()
    cellgate = cellgate.tanh()
    outgate = outgate.sigmoid()

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * cy.tanh()

    return hy, cy


def lstm_cell(
    input: Tensor,
    hidden: Tuple[Tensor, Tensor],
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Run one step of a standard LSTM cell.

    Args:
        input: Input batch; must be 2-D because it goes through ``torch.mm``.
        hidden: ``(hx, cx)``, the previous hidden and cell states.
        w_ih: Input-to-hidden weight.
        w_hh: Hidden-to-hidden weight.
        b_ih: Input-to-hidden bias.
        b_hh: Hidden-to-hidden bias.

    Returns:
        ``(hy, cy)``: the new hidden state and the new cell state.
    """
    hx, cx = hidden
    gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh

    # Gate pre-activations are laid out along dim 1 in the order
    # input, forget, cell, output.
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy


def flat_lstm_cell(
    input: Tensor,
    hx: Tensor,
    cx: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Same computation as ``lstm_cell`` but with ``hx``/``cx`` passed separately.

    Args:
        input: Input batch; must be 2-D because it goes through ``torch.mm``.
        hx: Previous hidden state.
        cx: Previous cell state.
        w_ih: Input-to-hidden weight.
        w_hh: Hidden-to-hidden weight.
        b_ih: Input-to-hidden bias.
        b_hh: Hidden-to-hidden bias.

    Returns:
        ``(hy, cy)``: the new hidden state and the new cell state.
    """
    gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy


def premul_lstm_cell(
    igates: Tensor,
    hidden: Tuple[Tensor, Tensor],
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """LSTM step whose input projection has already been computed.

    ``igates`` takes the place of ``torch.mm(input, w_ih.t())`` in
    ``lstm_cell``; both biases are still added here.

    Args:
        igates: Pre-multiplied input projection (no bias applied yet).
        hidden: ``(hx, cx)``, the previous hidden and cell states.
        w_hh: Hidden-to-hidden weight.
        b_ih: Input-to-hidden bias.
        b_hh: Hidden-to-hidden bias.

    Returns:
        ``(hy, cy)``: the new hidden state and the new cell state.
    """
    hx, cx = hidden
    gates = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy


def premul_lstm_cell_no_bias(
    igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor
) -> Tuple[Tensor, Tensor]:
    """LSTM step with a pre-multiplied input projection and no ``b_ih`` term.

    Despite the name, ``b_hh`` is still added. Only the input-to-hidden bias is
    omitted, presumably because the caller has already folded it into
    ``igates`` (NOTE(review): confirm against callers).

    Args:
        igates: Pre-multiplied input projection.
        hidden: ``(hx, cx)``, the previous hidden and cell states.
        w_hh: Hidden-to-hidden weight.
        b_hh: Hidden-to-hidden bias.

    Returns:
        ``(hy, cy)``: the new hidden state and the new cell state.
    """
    hx, cx = hidden
    gates = igates + torch.mm(hx, w_hh.t()) + b_hh

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy


def gru_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Run one step of a GRU cell.

    Args:
        input: Input batch; must be 2-D because it goes through ``torch.mm``.
        hidden: Previous hidden state.
        w_ih: Input-to-hidden weight (reset, update, new chunks along dim 0).
        w_hh: Hidden-to-hidden weight (same chunk layout).
        b_ih: Input-to-hidden bias.
        b_hh: Hidden-to-hidden bias.

    Returns:
        The new hidden state.
    """
    gi = torch.mm(input, w_ih.t()) + b_ih
    gh = torch.mm(hidden, w_hh.t()) + b_hh
    i_r, i_i, i_n = gi.chunk(3, 1)
    h_r, h_i, h_n = gh.chunk(3, 1)

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    # The reset gate scales only the recurrent part of the candidate state.
    newgate = torch.tanh(i_n + resetgate * h_n)
    # Algebraically equal to (1 - inputgate) * newgate + inputgate * hidden.
    hy = newgate + inputgate * (hidden - newgate)

    return hy


def rnn_relu_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Run one step of an Elman RNN cell with a ReLU nonlinearity.

    Returns:
        ``relu(input @ w_ih.T + b_ih + hidden @ w_hh.T + b_hh)``.
    """
    igates = torch.mm(input, w_ih.t()) + b_ih
    hgates = torch.mm(hidden, w_hh.t()) + b_hh
    return torch.relu(igates + hgates)


def rnn_tanh_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Run one step of an Elman RNN cell with a tanh nonlinearity.

    Returns:
        ``tanh(input @ w_ih.T + b_ih + hidden @ w_hh.T + b_hh)``.
    """
    igates = torch.mm(input, w_ih.t()) + b_ih
    hgates = torch.mm(hidden, w_hh.t()) + b_hh
    return torch.tanh(igates + hgates)