import numbers
import warnings
from collections import namedtuple
from typing import List, Tuple

import torch
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter


13"""
14Some helper classes for writing custom TorchScript LSTMs.
15
16Goals:
17- Classes are easy to read, use, and extend
18- Performance of custom LSTMs approach fused-kernel-levels of speed.
19
20A few notes about features we could add to clean up the below code:
21- Support enumerate with nn.ModuleList:
22  https://github.com/pytorch/pytorch/issues/14471
23- Support enumerate/zip with lists:
24  https://github.com/pytorch/pytorch/issues/15952
25- Support overriding of class methods:
26  https://github.com/pytorch/pytorch/issues/10733
27- Support passing around user-defined namedtuple types for readability
28- Support slicing w/ range. It enables reversing lists easily.
29  https://github.com/pytorch/pytorch/issues/10774
30- Multiline type annotations. List[List[Tuple[Tensor,Tensor]]] is verbose
31  https://github.com/pytorch/pytorch/pull/14922
32"""


def script_lstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM."""

    # The following are not implemented.
    assert bias
    assert not batch_first

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    elif dropout:
        stack_type = StackedLSTMWithDropout
        layer_type = LSTMLayer
        dirs = 1
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[LSTMCell, input_size, hidden_size],
        other_layer_args=[LSTMCell, hidden_size * dirs, hidden_size],
    )
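

# A minimal usage sketch for script_lstm (shapes are illustrative; the stack,
# layer, and cell classes it instantiates are defined below, and the tests at
# the bottom of this file exercise the same pattern against nn.LSTM):
#
#     rnn = script_lstm(input_size=3, hidden_size=7, num_layers=2)
#     inp = torch.randn(5, 2, 3)  # (seq_len, batch, input_size)
#     states = [LSTMState(torch.randn(2, 7), torch.randn(2, 7))
#               for _ in range(2)]  # one (hx, cx) pair per layer
#     out, out_states = rnn(inp, states)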


def script_lnlstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
    decompose_layernorm=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM with LayerNorm."""

    # The following are not implemented.
    assert bias
    assert not batch_first
    assert not dropout

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[
            LayerNormLSTMCell,
            input_size,
            hidden_size,
            decompose_layernorm,
        ],
        other_layer_args=[
            LayerNormLSTMCell,
            hidden_size * dirs,
            hidden_size,
            decompose_layernorm,
        ],
    )


LSTMState = namedtuple("LSTMState", ["hx", "cx"])


def reverse(lst: List[Tensor]) -> List[Tensor]:
    return lst[::-1]


class LSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
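        # Weights for all four gates are stacked along dim 0 in the order
        # (input, forget, cell, output), matching nn.LSTM's weight_ih_l* /
        # weight_hh_l* layout; the tests at the bottom of this file rely on
        # this when copying parameters into a native nn.LSTM.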
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        gates = (
            torch.mm(input, self.weight_ih.t())
            + self.bias_ih
            + torch.mm(hx, self.weight_hh.t())
            + self.bias_hh
        )
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


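# A hand-rolled LayerNorm, selected when decompose_layernorm=True, presumably
# so that the normalization appears to the JIT as separate elementwise ops
# rather than a single nn.LayerNorm call. Note that, unlike nn.LayerNorm, it
# does not add an eps term to the denominator.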
class LayerNorm(jit.ScriptModule):
    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        # XXX: This is true for our LSTM / NLP use case and helps simplify code
        assert len(normalized_shape) == 1

        self.weight = Parameter(torch.ones(normalized_shape))
        self.bias = Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    @jit.script_method
    def compute_layernorm_stats(self, input):
        mu = input.mean(-1, keepdim=True)
        sigma = input.std(-1, keepdim=True, unbiased=False)
        return mu, sigma

    @jit.script_method
    def forward(self, input):
        mu, sigma = self.compute_layernorm_stats(input)
        return (input - mu) / sigma * self.weight + self.bias


class LayerNormLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size, decompose_layernorm=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        # The layernorms provide learnable biases

        if decompose_layernorm:
            ln = LayerNorm
        else:
            ln = nn.LayerNorm

        self.layernorm_i = ln(4 * hidden_size)
        self.layernorm_h = ln(4 * hidden_size)
        self.layernorm_c = ln(hidden_size)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = igates + hgates
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


class LSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = input.unbind(0)
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state


class ReverseLSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = reverse(input.unbind(0))
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(reverse(outputs)), state


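# Runs a forward-direction LSTMLayer and a ReverseLSTMLayer over the same input
# and concatenates their outputs along the feature dimension, so the per-step
# output size is 2 * hidden_size (hence the hidden_size * dirs input size for
# the upper layers in script_lstm / script_lnlstm).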
class BidirLSTMLayer(jit.ScriptModule):
    __constants__ = ["directions"]

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.directions = nn.ModuleList(
            [
                LSTMLayer(cell, *cell_args),
                ReverseLSTMLayer(cell, *cell_args),
            ]
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: [forward LSTMState, backward LSTMState]
        outputs = jit.annotate(List[Tensor], [])
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for direction in self.directions:
            state = states[i]
            out, out_state = direction(input, state)
            outputs += [out]
            output_states += [out_state]
            i += 1
        return torch.cat(outputs, -1), output_states


def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
    layers = [layer(*first_layer_args)] + [
        layer(*other_layer_args) for _ in range(num_layers - 1)
    ]
    return nn.ModuleList(layers)


class StackedLSTM(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states


# Differs from StackedLSTM in that its forward method takes
# List[List[Tuple[Tensor,Tensor]]]. It would be nice to subclass StackedLSTM
# except we don't support overriding script methods.
# https://github.com/pytorch/pytorch/issues/10733
class StackedLSTM2(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]
    ) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
        # List[List[LSTMState]]: The outer list is for layers,
        #                        inner list is for directions.
        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states


class StackedLSTMWithDropout(jit.ScriptModule):
    # Necessary for iterating through self.layers and dropout support
    __constants__ = ["layers", "num_layers"]

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )
        # Introduces a Dropout layer on the outputs of each LSTM layer except
        # the last layer, with dropout probability = 0.4.
        self.num_layers = num_layers

        if num_layers == 1:
            warnings.warn(
                "dropout lstm adds dropout layers after all but the last "
                "recurrent layer; it expects num_layers greater than "
                "1, but got num_layers = 1"
            )

        self.dropout_layer = nn.Dropout(0.4)

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            # Apply dropout to the output of every layer except the last
            if i < self.num_layers - 1:
                output = self.dropout_layer(output)
            output_states += [out_state]
            i += 1
        return output, output_states


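# Helpers for converting between this file's per-layer state lists and the
# stacked (h, c) tensors that nn.LSTM uses. flatten_states turns a list of
# num_layers LSTMState(hx, cx) pairs, each of shape (batch, hidden_size), into
# [h, c] tensors of shape (num_layers, batch, hidden_size).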
def flatten_states(states):
    states = list(zip(*states))
    assert len(states) == 2
    return [torch.stack(state) for state in states]


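# Same idea for the bidirectional case: [layer][direction] LSTMStates are
# flattened into [h, c] tensors of shape (num_layers * 2, batch, hidden_size),
# matching nn.LSTM's layer-major, direction-minor state layout.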
def double_flatten_states(states):
    # XXX: Can probably write this in a nicer way
    states = flatten_states([flatten_states(inner) for inner in states])
    return [hidden.view([-1] + list(hidden.shape[2:])) for hidden in states]


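# Each test below builds a custom ScriptModule, copies its randomly initialized
# parameters into an equivalently shaped nn.LSTM, runs both on the same input,
# and checks that outputs and final states agree to within 1e-5. The dropout
# and layer-norm variants are only smoke-tested.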
def test_script_rnn_layer(seq_len, batch, input_size, hidden_size):
    inp = torch.randn(seq_len, batch, input_size)
    state = LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
    rnn = LSTMLayer(LSTMCell, input_size, hidden_size)
    out, out_state = rnn(inp, state)

    # Control: PyTorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, 1)
    lstm_state = LSTMState(state.hx.unsqueeze(0), state.cx.unsqueeze(0))
    for lstm_param, custom_param in zip(lstm.all_weights[0], rnn.parameters()):
        assert lstm_param.shape == custom_param.shape
        with torch.no_grad():
            lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (out_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (out_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers)
    out, out_state = rnn(inp, states)
    custom_state = flatten_states(out_state)

    # Control: PyTorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers)
    lstm_state = flatten_states(states)
    for layer in range(num_layers):
        custom_params = list(rnn.parameters())[4 * layer : 4 * (layer + 1)]
        for lstm_param, custom_param in zip(lstm.all_weights[layer], custom_params):
            assert lstm_param.shape == custom_param.shape
            with torch.no_grad():
                lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_bidir_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        [
            LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
            for _ in range(2)
        ]
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, bidirectional=True)
    out, out_state = rnn(inp, states)
    custom_state = double_flatten_states(out_state)

    # Control: PyTorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
    lstm_state = double_flatten_states(states)
    for layer in range(num_layers):
        for direct in range(2):
            index = 2 * layer + direct
            custom_params = list(rnn.parameters())[4 * index : 4 * index + 4]
            for lstm_param, custom_param in zip(lstm.all_weights[index], custom_params):
                assert lstm_param.shape == custom_param.shape
                with torch.no_grad():
                    lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_lstm_dropout(
    seq_len, batch, input_size, hidden_size, num_layers
):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, dropout=True)

    # just a smoke test
    out, out_state = rnn(inp, states)


def test_script_stacked_lnlstm(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lnlstm(input_size, hidden_size, num_layers)

    # just a smoke test
    out, out_state = rnn(inp, states)


test_script_rnn_layer(5, 2, 3, 7)
test_script_stacked_rnn(5, 2, 3, 7, 4)
test_script_stacked_bidir_rnn(5, 2, 3, 7, 4)
test_script_stacked_lstm_dropout(5, 2, 3, 7, 4)
test_script_stacked_lnlstm(5, 2, 3, 7, 4)