from collections import namedtuple
from typing import List, Tuple

import torch
from torch import Tensor

from .cells import flat_lstm_cell, lstm_cell, premul_lstm_cell, premul_lstm_cell_no_bias


# list[list[T]] -> list[T]
def flatten_list(lst):
    result = []
    for inner in lst:
        result.extend(inner)
    return result


"""
Define a creator as a function:
(options) -> (inputs, params, forward, backward_setup, backward)
inputs: the inputs to the returned 'forward'. One can call
    forward(*inputs) directly.
params: List[Tensor] of all parameters with requires_grad=True.
forward: a function, graph executor, or module.
    One can call forward(*inputs) on the outputs of the creator.
backward_setup: backward_inputs = backward_setup(forward_output).
    backward_inputs is then passed to backward. If None, the identity
    function is assumed.
backward: given backward_inputs = backward_setup(forward(*inputs)),
    performs backpropagation. If None, nothing happens.

fastrnns.bench times the forward and backward invocations.
"""
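
# A minimal sketch of how a ModelDef is consumed (illustrative only: the real
# timing loop lives in fastrnns.bench, and the kwargs shown are simply the
# lstm_inputs defaults):
#
#   inputs, params, forward, backward_setup, backward = pytorch_lstm_creator(
#       seqLength=100, numLayers=1, inputSize=512, hiddenSize=512, miniBatch=64
#   )
#   forward_output = forward(*inputs)
#   backward_input = backward_setup(forward_output)  # here: (output, grad_output)
#   backward(*backward_input)  # here: output.backward(grad_output)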


ModelDef = namedtuple(
    "ModelDef", ["inputs", "params", "forward", "backward_setup", "backward"]
)


def lstm_backward_setup(lstm_outputs, seed=None):
    hx, _ = lstm_outputs
    return simple_backward_setup(hx, seed)


def simple_backward_setup(output, seed=None):
    assert isinstance(output, torch.Tensor)
    if seed is not None:
        torch.manual_seed(seed)
    grad_output = torch.randn_like(output)
    return output, grad_output


def simple_backward(output, grad_output, **kwargs):
    return output.backward(grad_output, **kwargs)


def pytorch_lstm_creator(**kwargs):
    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
    return ModelDef(
        inputs=[input, hidden],
        params=flatten_list(module.all_weights),
        forward=module,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory(lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lnlstm_creator(script=True, decompose_layernorm=False, **kwargs):
    assert script is True
    from .custom_lstms import script_lnlstm

    input_size = kwargs["inputSize"]
    hidden_size = kwargs["hiddenSize"]
    seq_len = kwargs["seqLength"]
    batch_size = kwargs["miniBatch"]
    ge = script_lnlstm(
        input_size, hidden_size, 1, decompose_layernorm=decompose_layernorm
    ).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device="cuda")
    states = [
        (
            torch.randn(batch_size, hidden_size, device="cuda"),
            torch.randn(batch_size, hidden_size, device="cuda"),
        )
    ]

    return ModelDef(
        inputs=[input, states],
        params=list(ge.parameters()),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def dropoutlstm_creator(script=True, **kwargs):
    assert script is True
    from .custom_lstms import LSTMState, script_lstm

    input_size = kwargs["inputSize"]
    hidden_size = kwargs["hiddenSize"]
    seq_len = kwargs["seqLength"]
    batch_size = kwargs["miniBatch"]
    num_layers = kwargs["numLayers"]
    ge = script_lstm(input_size, hidden_size, num_layers, dropout=True).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device="cuda")
    states = [
        LSTMState(
            torch.randn(batch_size, hidden_size, device="cuda"),
            torch.randn(batch_size, hidden_size, device="cuda"),
        )
        for _ in range(num_layers)
    ]
    return ModelDef(
        inputs=[input, states],
        params=list(ge.parameters()),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_premul_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_premul(premul_lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_premul_bias_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_premul_bias(premul_lstm_cell_no_bias, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_simple_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input] + [h[0] for h in hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_simple(flat_lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_multilayer_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden, flatten_list(params)]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_multilayer(lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def imagenet_cnn_creator(arch, jit=True):
    def creator(device="cuda", **kwargs):
        model = arch().to(device)
        x = torch.randn(32, 3, 224, 224, device=device)
        if jit:
            model = torch.jit.trace(model, x)
        return ModelDef(
            inputs=(x,),
            params=list(model.parameters()),
            forward=model,
            backward_setup=simple_backward_setup,
            backward=simple_backward,
        )

    return creator


def varlen_lstm_inputs(
    minlen=30,
    maxlen=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    return_module=False,
    device="cuda",
    seed=None,
    **kwargs,
):
    if seed is not None:
        torch.manual_seed(seed)
    lengths = torch.randint(
        low=minlen, high=maxlen, size=[miniBatch], dtype=torch.long, device=device
    )
    x = [torch.randn(length, inputSize, device=device) for length in lengths]
    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers).to(device)

    if return_module:
        return x, lengths, (hx, cx), lstm.all_weights, lstm
    else:
        # NB: lstm.all_weights format:
        # wih, whh, bih, bhh = lstm.all_weights[layer]
        return x, lengths, (hx, cx), lstm.all_weights, None


def varlen_lstm_backward_setup(forward_output, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    rnn_utils = torch.nn.utils.rnn
    sequences = forward_output[0]
    padded = rnn_utils.pad_sequence(sequences)
    grad = torch.randn_like(padded)
    return padded, grad


def varlen_pytorch_lstm_creator(**kwargs):
    rnn_utils = torch.nn.utils.rnn
    sequences, _, hidden, _, module = varlen_lstm_inputs(return_module=True, **kwargs)

    def forward(sequences, hidden):
        packed = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
        out, new_hidden = module(packed, hidden)
        padded, lengths = rnn_utils.pad_packed_sequence(out)
        # XXX: It's more efficient to store the output in its padded form,
        # but that might not be conducive to loss computation.
        # Un-padding the output also makes the backward pass 2x slower...
        # return [padded[:lengths[i], i, :] for i in range(lengths.size(0))]
        return padded, new_hidden

    return ModelDef(
        inputs=[sequences, hidden],
        params=flatten_list(module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def varlen_lstm_factory(cell, script):
    def dynamic_rnn(
        sequences: List[Tensor],
        hiddens: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]:
        hx, cx = hiddens
        hxs = hx.unbind(1)
        cxs = cx.unbind(1)
        # List of: (output, hx, cx)
        outputs = []
        hx_outs = []
        cx_outs = []

        for batch in range(len(sequences)):
            output = []
            hy, cy = hxs[batch], cxs[batch]
            inputs = sequences[batch].unbind(0)

            for seq_idx in range(len(inputs)):
                hy, cy = cell(
                    inputs[seq_idx].unsqueeze(0), (hy, cy), wih, whh, bih, bhh
                )
                output += [hy]
            outputs += [torch.stack(output)]
            hx_outs += [hy.unsqueeze(0)]
            cx_outs += [cy.unsqueeze(0)]

        return outputs, (hx_outs, cx_outs)

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


def varlen_lstm_creator(script=False, **kwargs):
    sequences, _, hidden, params, _ = varlen_lstm_inputs(return_module=False, **kwargs)
    inputs = [sequences, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=varlen_lstm_factory(lstm_cell, script),
        backward_setup=varlen_lstm_backward_setup,
        backward=simple_backward,
    )


# cudnn_layernorm_lstm: since cudnn has no LayerNorm LSTM, we cannot benchmark
# the lower bound directly. Instead, we benchmark only the forward pass by
# mimicking the computation of a cudnn LSTM plus seq_len * 3 LayerNorm
# computations. This should serve as a perf lower bound for the LayerNorm LSTM
# forward pass (given that LayerNorm itself is invariant). The lower bound of
# the backward pass is hard to get since we lose the intermediate results;
# we could still optimize the LayerNorm implementation to obtain a faster
# forward lower bound, though.
def layernorm_pytorch_lstm_creator(**kwargs):
    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
    batch_size = kwargs["miniBatch"]
    hidden_size = kwargs["hiddenSize"]
    ln_i = torch.nn.LayerNorm(4 * hidden_size).cuda()
    ln_h = torch.nn.LayerNorm(4 * hidden_size).cuda()
    ln_c = torch.nn.LayerNorm(hidden_size).cuda()
    ln_input1 = torch.randn(batch_size, 4 * hidden_size, device="cuda")

    def forward(input, hidden):
        out, new_hidden = module(input, hidden)
        # plus (seq_len * three LayerNorm cell computations) to mimic the
        # lower bound of a LayerNorm cudnn LSTM forward pass
        seq_len = len(input.unbind(0))
        hy, cy = new_hidden
        for i in range(seq_len):
            # outputs are discarded; we only pay for the LayerNorm computation
            ln_i_output = ln_i(ln_input1)
            ln_h_output = ln_h(ln_input1)
            cy = ln_c(cy)

        return out, (hy, cy)

    return ModelDef(
        inputs=[input, hidden],
        params=flatten_list(module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=None,
    )


# input: lstm.all_weights format (wih, whh, bih, bhh = lstm.all_weights[layer])
# output: packed_weights with format
# packed_weights[0] is wih with size (layer, 4*hiddenSize, inputSize)
# packed_weights[1] is whh with size (layer, 4*hiddenSize, hiddenSize)
# packed_weights[2] is bih with size (layer, 4*hiddenSize)
# packed_weights[3] is bhh with size (layer, 4*hiddenSize)
def stack_weights(weights):
    def unzip_columns(mat):
        assert isinstance(mat, list)
        assert isinstance(mat[0], list)
        layers = len(mat)
        columns = len(mat[0])
        return [[mat[layer][col] for layer in range(layers)] for col in range(columns)]

    # XXX: script fns have problems indexing multidim lists, so we try to
    # avoid them by stacking tensors
    all_weights = weights
    packed_weights = [torch.stack(param) for param in unzip_columns(all_weights)]
    return packed_weights
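
# A small usage sketch for stack_weights (illustrative; the shapes follow the
# packed_weights format described above, using hiddenSize == inputSize == 512
# so that the per-layer weights stack cleanly):
#
#   lstm = torch.nn.LSTM(512, 512, 2)
#   wih, whh, bih, bhh = stack_weights(lstm.all_weights)
#   # wih: (2, 2048, 512), whh: (2, 2048, 512), bih: (2, 2048), bhh: (2, 2048)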


# returns: x, (hx, cx), all_weights, lstm module with all_weights as params
def lstm_inputs(
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    dropout=0.0,
    return_module=False,
    device="cuda",
    seed=None,
):
    if seed is not None:
        torch.manual_seed(seed)
    x = torch.randn(seqLength, miniBatch, inputSize, device=device)
    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers, dropout=dropout)
    if "cuda" in device:
        lstm = lstm.cuda()

    if return_module:
        return x, (hx, cx), lstm.all_weights, lstm
    else:
        # NB: lstm.all_weights format:
        # wih, whh, bih, bhh = lstm.all_weights[layer]
        return x, (hx, cx), lstm.all_weights, None


def lstm_factory(cell, script):
    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        inputs = input.unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn
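
# A minimal sketch wiring lstm_inputs into an lstm_factory forward
# (illustrative; this mirrors what lstm_creator does above):
#
#   x, (hx, cx), all_weights, _ = lstm_inputs(return_module=False)
#   forward = lstm_factory(lstm_cell, script=True)
#   out, (hy, cy) = forward(x, (hx, cx), *all_weights[0])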


# premul: we're going to premultiply the inputs & weights
def lstm_factory_premul(premul_cell, script):
    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        inputs = torch.matmul(input, wih.t()).unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bih, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# premul: we're going to premultiply the inputs & weights, and add bias
def lstm_factory_premul_bias(premul_cell, script):
    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        # add bias for all timesteps instead of going step-by-step; this
        # results in a single reduction kernel in the backward pass.
        # FIXME: matmul(x, y) + bias currently goes through jit AD, and the
        # backward formula in AD is not optimized for this case. Work around
        # with mm and views.
        inpSize = input.size()
        inputs = torch.mm(input.view(-1, inpSize[2]), wih.t()) + bih
        inputs = inputs.view(inpSize[0], inpSize[1], -1).unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# simple: flat inputs (no tuples), no list to accumulate outputs
#         useful mostly for benchmarking older JIT versions
def lstm_factory_simple(cell, script):
    def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh):
        hy = hx  # for scoping
        cy = cx  # for scoping
        inputs = input.unbind(0)
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], hy, cy, wih, whh, bih, bhh)
        return hy, cy

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


def lstm_factory_multilayer(cell, script):
    def dynamic_rnn(
        input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        params_stride = 4  # NB: this assumes that biases are there
        hx, cx = hidden
        hy, cy = hidden  # for scoping...
        inputs, outputs = input.unbind(0), []
        for layer in range(hx.size(0)):
            hy = hx[layer]
            cy = cx[layer]
            base_idx = layer * params_stride
            wih = params[base_idx]
            whh = params[base_idx + 1]
            bih = params[base_idx + 2]
            bhh = params[base_idx + 3]
            for seq_idx in range(len(inputs)):
                hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
                outputs += [hy]
            inputs, outputs = outputs, []
        return torch.stack(inputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn
534