import numbers
import warnings
from collections import namedtuple
from typing import List, Tuple

import torch
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter


"""
Some helper classes for writing custom TorchScript LSTMs.

Goals:
- Classes are easy to read, use, and extend
- Performance of custom LSTMs approaches fused-kernel levels of speed.

A few notes about features we could add to clean up the code below:
- Support enumerate with nn.ModuleList:
  https://github.com/pytorch/pytorch/issues/14471
- Support enumerate/zip with lists:
  https://github.com/pytorch/pytorch/issues/15952
- Support overriding of class methods:
  https://github.com/pytorch/pytorch/issues/10733
- Support passing around user-defined namedtuple types for readability
- Support slicing w/ range. It enables reversing lists easily.
  https://github.com/pytorch/pytorch/issues/10774
- Multiline type annotations. List[List[Tuple[Tensor, Tensor]]] is verbose
  https://github.com/pytorch/pytorch/pull/14922
"""


def script_lstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM."""

    # The following are not implemented.
    assert bias
    assert not batch_first

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    elif dropout:
        stack_type = StackedLSTMWithDropout
        layer_type = LSTMLayer
        dirs = 1
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[LSTMCell, input_size, hidden_size],
        other_layer_args=[LSTMCell, hidden_size * dirs, hidden_size],
    )

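# A minimal usage sketch (hypothetical sizes; inputs use the seq-first
# (seq_len, batch, feature) layout since batch_first is not supported):
#
#     rnn = script_lstm(input_size=10, hidden_size=20, num_layers=2)
#     inp = torch.randn(5, 3, 10)
#     states = [LSTMState(torch.zeros(3, 20), torch.zeros(3, 20))
#               for _ in range(2)]  # one LSTMState per layer
#     out, out_states = rnn(inp, states)  # out: (5, 3, 20)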

def script_lnlstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
    decompose_layernorm=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM, with layer
    normalization applied to the gate pre-activations and the cell state."""

    # The following are not implemented.
    assert bias
    assert not batch_first
    assert not dropout

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[
            LayerNormLSTMCell,
            input_size,
            hidden_size,
            decompose_layernorm,
        ],
        other_layer_args=[
            LayerNormLSTMCell,
            hidden_size * dirs,
            hidden_size,
            decompose_layernorm,
        ],
    )

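# Usage mirrors script_lstm above (a sketch): decompose_layernorm=True swaps
# nn.LayerNorm for the hand-written LayerNorm below, which spells the
# normalization out as elementwise ops (e.g. for fusion experiments in this
# benchmark):
#
#     rnn = script_lnlstm(10, 20, 2, decompose_layernorm=True)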

LSTMState = namedtuple("LSTMState", ["hx", "cx"])


def reverse(lst: List[Tensor]) -> List[Tensor]:
    return lst[::-1]


class LSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

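    # The cell computes the standard LSTM equations (gate order i, f, g, o,
    # matching nn.LSTM):
    #
    #     i = sigmoid(W_ii x + b_ii + W_hi h + b_hi)
    #     f = sigmoid(W_if x + b_if + W_hf h + b_hf)
    #     g = tanh(W_ig x + b_ig + W_hg h + b_hg)
    #     o = sigmoid(W_io x + b_io + W_ho h + b_ho)
    #     c' = f * c + i * g
    #     h' = o * tanh(c')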
    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        # All four gate pre-activations are computed in two matmuls, then split.
        gates = (
            torch.mm(input, self.weight_ih.t())
            + self.bias_ih
            + torch.mm(hx, self.weight_hh.t())
            + self.bias_hh
        )
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)
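
# Quick equivalence sketch against nn.LSTMCell (not part of this file's
# tests; hypothetical sizes). Parameter order and gate layout match, so the
# weights can be copied over directly:
#
#     cell = LSTMCell(10, 20)
#     ref = nn.LSTMCell(10, 20)
#     with torch.no_grad():
#         for p_ref, p in zip(ref.parameters(), cell.parameters()):
#             p_ref.copy_(p)
#     x = torch.randn(3, 10)
#     state = (torch.zeros(3, 20), torch.zeros(3, 20))
#     hy, (_, cy) = cell(x, state)
#     h_ref, c_ref = ref(x, state)
#     assert (hy - h_ref).abs().max() < 1e-5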


class LayerNorm(jit.ScriptModule):
    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        # XXX: This is true for our LSTM / NLP use case and helps simplify code
        assert len(normalized_shape) == 1

        self.weight = Parameter(torch.ones(normalized_shape))
        self.bias = Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    @jit.script_method
    def compute_layernorm_stats(self, input):
        mu = input.mean(-1, keepdim=True)
        sigma = input.std(-1, keepdim=True, unbiased=False)
        return mu, sigma

    @jit.script_method
    def forward(self, input):
        mu, sigma = self.compute_layernorm_stats(input)
        # NB: unlike nn.LayerNorm, no eps is added to sigma before dividing.
        return (input - mu) / sigma * self.weight + self.bias


class LayerNormLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size, decompose_layernorm=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        # The layernorms provide learnable biases

        if decompose_layernorm:
            ln = LayerNorm
        else:
            ln = nn.LayerNorm

        self.layernorm_i = ln(4 * hidden_size)
        self.layernorm_h = ln(4 * hidden_size)
        self.layernorm_c = ln(hidden_size)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = igates + hgates
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


class LSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # Unroll the cell over the time dimension (dim 0).
        inputs = input.unbind(0)
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state


class ReverseLSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # Run the cell right-to-left over time, then re-reverse the outputs
        # so they line up with the input order.
        inputs = reverse(input.unbind(0))
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(reverse(outputs)), state


class BidirLSTMLayer(jit.ScriptModule):
    __constants__ = ["directions"]

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.directions = nn.ModuleList(
            [
                LSTMLayer(cell, *cell_args),
                ReverseLSTMLayer(cell, *cell_args),
            ]
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: [forward LSTMState, backward LSTMState]
        outputs = jit.annotate(List[Tensor], [])
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for direction in self.directions:
            state = states[i]
            out, out_state = direction(input, state)
            outputs += [out]
            output_states += [out_state]
            i += 1
        return torch.cat(outputs, -1), output_states


def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
    layers = [layer(*first_layer_args)] + [
        layer(*other_layer_args) for _ in range(num_layers - 1)
    ]
    return nn.ModuleList(layers)


class StackedLSTM(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states


# Differs from StackedLSTM in that its forward method takes
# List[List[Tuple[Tensor, Tensor]]]. It would be nice to subclass StackedLSTM
# except we don't support overriding script methods.
# https://github.com/pytorch/pytorch/issues/10733
class StackedLSTM2(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]
    ) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
        # List[List[LSTMState]]: The outer list is for layers,
        #                        the inner list is for directions.
        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states


class StackedLSTMWithDropout(jit.ScriptModule):
    # Necessary for iterating through self.layers and dropout support
    __constants__ = ["layers", "num_layers"]

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )
        # Introduces a Dropout layer on the outputs of each LSTM layer except
        # the last layer, with dropout probability = 0.4.
        self.num_layers = num_layers

        if num_layers == 1:
            warnings.warn(
                "dropout lstm adds dropout layers after all but last "
                "recurrent layer, so it expects num_layers greater than "
                "1, but got num_layers = 1"
            )

        self.dropout_layer = nn.Dropout(0.4)

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            # Apply dropout to the output of every layer except the last
            if i < self.num_layers - 1:
                output = self.dropout_layer(output)
            output_states += [out_state]
            i += 1
        return output, output_states


def flatten_states(states):
    # Convert a per-layer list of LSTMStates into the stacked (h, c) pair
    # that nn.LSTM expects.
    states = list(zip(*states))
    assert len(states) == 2
    return [torch.stack(state) for state in states]
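
# Shape sketch (hypothetical sizes): num_layers LSTMStates with hx, cx of
# shape (batch, hidden) become [h, c], each of shape
# (num_layers, batch, hidden).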


def double_flatten_states(states):
    # XXX: Can probably write this in a nicer way
    states = flatten_states([flatten_states(inner) for inner in states])
    return [hidden.view([-1] + list(hidden.shape[2:])) for hidden in states]
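
# Shape sketch: for a List[List[LSTMState]] of num_layers x 2 directions,
# each inner flatten gives (2, batch, hidden); the outer flatten gives
# (num_layers, 2, batch, hidden); the view then merges the first two dims
# into (num_layers * 2, batch, hidden), matching bidirectional nn.LSTM.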


def test_script_rnn_layer(seq_len, batch, input_size, hidden_size):
    inp = torch.randn(seq_len, batch, input_size)
    state = LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
    rnn = LSTMLayer(LSTMCell, input_size, hidden_size)
    out, out_state = rnn(inp, state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, 1)
    lstm_state = LSTMState(state.hx.unsqueeze(0), state.cx.unsqueeze(0))
    for lstm_param, custom_param in zip(lstm.all_weights[0], rnn.parameters()):
        assert lstm_param.shape == custom_param.shape
        with torch.no_grad():
            lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (out_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (out_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers)
    out, out_state = rnn(inp, states)
    custom_state = flatten_states(out_state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers)
    lstm_state = flatten_states(states)
    for layer in range(num_layers):
        # Each custom layer owns 4 parameters:
        # weight_ih, weight_hh, bias_ih, bias_hh.
        custom_params = list(rnn.parameters())[4 * layer : 4 * (layer + 1)]
        for lstm_param, custom_param in zip(lstm.all_weights[layer], custom_params):
            assert lstm_param.shape == custom_param.shape
            with torch.no_grad():
                lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_bidir_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        [
            LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
            for _ in range(2)
        ]
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, bidirectional=True)
    out, out_state = rnn(inp, states)
    custom_state = double_flatten_states(out_state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
    lstm_state = double_flatten_states(states)
    for layer in range(num_layers):
        for direct in range(2):
            index = 2 * layer + direct
            custom_params = list(rnn.parameters())[4 * index : 4 * index + 4]
            for lstm_param, custom_param in zip(lstm.all_weights[index], custom_params):
                assert lstm_param.shape == custom_param.shape
                with torch.no_grad():
                    lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_lstm_dropout(
    seq_len, batch, input_size, hidden_size, num_layers
):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, dropout=True)

    # just a smoke test
    out, out_state = rnn(inp, states)


def test_script_stacked_lnlstm(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lnlstm(input_size, hidden_size, num_layers)

    # just a smoke test
    out, out_state = rnn(inp, states)


test_script_rnn_layer(5, 2, 3, 7)
test_script_stacked_rnn(5, 2, 3, 7, 4)
test_script_stacked_bidir_rnn(5, 2, 3, 7, 4)
test_script_stacked_lstm_dropout(5, 2, 3, 7, 4)
test_script_stacked_lnlstm(5, 2, 3, 7, 4)