from collections import namedtuple
from typing import List, Tuple

import torch
from torch import Tensor

from .cells import flat_lstm_cell, lstm_cell, premul_lstm_cell, premul_lstm_cell_no_bias


# list[list[T]] -> list[T]
def flatten_list(lst):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    result = []
    for inner in lst:
        result.extend(inner)
    return result


"""
Define a creator as a function:
(options) -> (inputs, params, forward, backward_setup, backward)
inputs: the inputs to the returned 'forward'. One can call
    forward(*inputs) directly.
params: List[Tensor] all requires_grad=True parameters.
forward: function / graph executor / module
    One can call rnn(rnn_inputs) using the outputs of the creator.
backward_setup: backward_inputs = backward_setup(*outputs)
    Then, we pass backward_inputs to backward. If None, then it is assumed to
    be the identity function.
backward: Given `output = backward_setup(*forward(*inputs))`, performs
    backpropagation. If None, then nothing happens.

fastrnns.bench times the forward and backward invocations.
"""


ModelDef = namedtuple(
    "ModelDef", ["inputs", "params", "forward", "backward_setup", "backward"]
)


def lstm_backward_setup(lstm_outputs, seed=None):
    """Build (output, grad_output) from an LSTM's (output, hidden) pair.

    Only the output tensor (first element) receives a gradient.
    """
    hx, _ = lstm_outputs
    return simple_backward_setup(hx, seed)


def simple_backward_setup(output, seed=None):
    """Pair `output` with a random gradient tensor of the same shape.

    If `seed` is given (including 0), the global RNG is seeded first so the
    gradient is reproducible.
    """
    assert isinstance(output, torch.Tensor)
    # Bug fix: was `if seed:`, which silently skipped seeding for seed=0.
    # `is not None` matches the convention used by lstm_inputs and
    # varlen_lstm_inputs elsewhere in this file.
    if seed is not None:
        torch.manual_seed(seed)
    grad_output = torch.randn_like(output)
    return output, grad_output


def simple_backward(output, grad_output, **kwargs):
    """Run backprop from `output` using the precomputed `grad_output`."""
    return output.backward(grad_output, **kwargs)
def pytorch_lstm_creator(**kwargs):
    """Baseline creator: cudnn-backed torch.nn.LSTM module."""
    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
    return ModelDef(
        inputs=[input, hidden],
        params=flatten_list(module.all_weights),
        forward=module,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_creator(script=True, **kwargs):
    """Single-layer LSTM assembled from lstm_cell (TorchScript-ed if `script`)."""
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    # params[0] holds the layer-0 weights: wih, whh, bih, bhh.
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory(lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lnlstm_creator(script=True, decompose_layernorm=False, **kwargs):
    """LayerNorm LSTM built from custom_lstms.script_lnlstm (CUDA only).

    Only the scripted variant exists, hence the `script is True` assert.
    """
    assert script is True
    from .custom_lstms import script_lnlstm

    input_size = kwargs["inputSize"]
    hidden_size = kwargs["hiddenSize"]
    seq_len = kwargs["seqLength"]
    batch_size = kwargs["miniBatch"]
    ge = script_lnlstm(
        input_size, hidden_size, 1, decompose_layernorm=decompose_layernorm
    ).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device="cuda")
    # One (h, c) state pair for the single layer.
    states = [
        (
            torch.randn(batch_size, hidden_size, device="cuda"),
            torch.randn(batch_size, hidden_size, device="cuda"),
        )
    ]

    return ModelDef(
        inputs=[input, states],
        # NOTE(review): this is a generator, unlike the list other creators
        # pass — a consumer iterating params twice would see it empty the
        # second time; confirm bench only iterates once.
        params=ge.parameters(),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def dropoutlstm_creator(script=True, **kwargs):
    """Multi-layer LSTM with dropout from custom_lstms.script_lstm (CUDA only)."""
    assert script is True
    from .custom_lstms import LSTMState, script_lstm

    input_size = kwargs["inputSize"]
    hidden_size = kwargs["hiddenSize"]
    seq_len = kwargs["seqLength"]
    batch_size = kwargs["miniBatch"]
    num_layers = kwargs["numLayers"]
    ge = script_lstm(input_size, hidden_size, num_layers, dropout=True).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device="cuda")
    # One LSTMState per layer.
    states = [
        LSTMState(
            torch.randn(batch_size, hidden_size, device="cuda"),
            torch.randn(batch_size, hidden_size, device="cuda"),
        )
        for _ in range(num_layers)
    ]
    return ModelDef(
        inputs=[input, states],
        # NOTE(review): generator, not list — see lnlstm_creator.
        params=ge.parameters(),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_premul_creator(script=True, **kwargs):
    """LSTM whose input-to-hidden matmul is hoisted out of the time loop."""
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_premul(premul_lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_premul_bias_creator(script=True, **kwargs):
    """Like lstm_premul_creator, but the input bias is also pre-added."""
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_premul_bias(premul_lstm_cell_no_bias, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_simple_creator(script=True, **kwargs):
    """LSTM variant with flat (un-tupled) inputs; see lstm_factory_simple."""
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    # h[0] unwraps the layer dimension of hx and cx.
    inputs = [input] + [h[0] for h in hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_simple(flat_lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_multilayer_creator(script=True, **kwargs):
    """Multi-layer LSTM driven by a flattened parameter list."""
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden, flatten_list(params)]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_multilayer(lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def imagenet_cnn_creator(arch, jit=True):
    """Return a creator for an ImageNet CNN `arch`, traced with JIT if `jit`."""

    def creator(device="cuda", **kwargs):
        model = arch().to(device)
        # Fixed benchmark input: batch of 32 ImageNet-sized images.
        x = torch.randn(32, 3, 224, 224, device=device)
        if jit:
            model = torch.jit.trace(model, x)
        return ModelDef(
            inputs=(x,),
            params=list(model.parameters()),
            forward=model,
            backward_setup=simple_backward_setup,
            backward=simple_backward,
        )

    return creator


def varlen_lstm_inputs(
    minlen=30,
    maxlen=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    return_module=False,
    device="cuda",
    seed=None,
    **kwargs,
):
    """Build variable-length LSTM inputs: one sequence per batch element.

    Returns (x, lengths, (hx, cx), all_weights, module-or-None), where x is a
    list of (length_i, inputSize) tensors with length_i in [minlen, maxlen).
    """
    if seed is not None:
        torch.manual_seed(seed)
    lengths = torch.randint(
        low=minlen, high=maxlen, size=[miniBatch], dtype=torch.long, device=device
    )
    x = [torch.randn(length, inputSize, device=device) for length in lengths]
    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers).to(device)

    if return_module:
        return x, lengths, (hx, cx), lstm.all_weights, lstm
    else:
        # NB: lstm.all_weights format:
        # wih, whh, bih, bhh = lstm.all_weights[layer]
        return x, lengths, (hx, cx), lstm.all_weights, None


def varlen_lstm_backward_setup(forward_output, seed=None):
    """Pad the variable-length outputs and pair them with a random gradient."""
    # NOTE(review): `if seed:` skips seeding for seed=0, unlike the
    # `is not None` check in varlen_lstm_inputs — intentional? confirm.
    if seed:
        torch.manual_seed(seed)
    rnn_utils = torch.nn.utils.rnn
    sequences = forward_output[0]
    padded = rnn_utils.pad_sequence(sequences)
    grad = torch.randn_like(padded)
    return padded, grad


def varlen_pytorch_lstm_creator(**kwargs):
    """cudnn LSTM on variable-length input via pack_sequence/pad_packed_sequence."""
    rnn_utils = torch.nn.utils.rnn
    sequences, _, hidden, _, module = varlen_lstm_inputs(return_module=True, **kwargs)

    def forward(sequences, hidden):
        packed = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
        out, new_hidden = module(packed, hidden)
        padded, lengths = rnn_utils.pad_packed_sequence(out)
        # XXX: It's more efficient to store the output in its padded form,
        # but that might not be conducive to loss computation.
        # Un-padding the output also makes the backward pass 2x slower...
        # return [padded[:lengths[i], i, :] for i in range(lengths.size(0))]
        return padded, new_hidden

    return ModelDef(
        inputs=[sequences, hidden],
        params=flatten_list(module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def varlen_lstm_factory(cell, script):
    """Build a dynamic RNN over a list of variable-length sequences.

    The outer loop walks batch elements, the inner loop time steps; each
    batch element keeps its own (hy, cy) state taken from hx/cx column
    `batch`. Kept in TorchScript-compatible form (scripted when `script`).
    """

    def dynamic_rnn(
        sequences: List[Tensor],
        hiddens: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]:
        hx, cx = hiddens
        # Split the (numLayers, miniBatch, hidden) states per batch element.
        hxs = hx.unbind(1)
        cxs = cx.unbind(1)
        # List of: (output, hx, cx)
        outputs = []
        hx_outs = []
        cx_outs = []

        for batch in range(len(sequences)):
            output = []
            hy, cy = hxs[batch], cxs[batch]
            inputs = sequences[batch].unbind(0)

            for seq_idx in range(len(inputs)):
                hy, cy = cell(
                    inputs[seq_idx].unsqueeze(0), (hy, cy), wih, whh, bih, bhh
                )
                output += [hy]
            outputs += [torch.stack(output)]
            hx_outs += [hy.unsqueeze(0)]
            cx_outs += [cy.unsqueeze(0)]

        return outputs, (hx_outs, cx_outs)

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


def varlen_lstm_creator(script=False, **kwargs):
    """Creator wiring varlen_lstm_inputs to varlen_lstm_factory(lstm_cell)."""
    sequences, _, hidden, params, _ = varlen_lstm_inputs(return_module=False, **kwargs)
    inputs = [sequences, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=varlen_lstm_factory(lstm_cell, script),
        backward_setup=varlen_lstm_backward_setup,
        backward=simple_backward,
    )


# cudnn_layernorm_lstm: since cudnn does not have Layernorm LSTM, we cannot benchmark
# the lowerbound directly. Instead, we only benchmark the forward pass by mimicking
# the computation of a cudnn lstm + seq_len * 3 layernorm computations. This should
# serve as a perf lowerbound for the Layernorm LSTM forward pass (given that Layernorm
# itself is invariant); the lowerbound of the backward pass is hard to get since we
# lose the intermediate results. We can still optimize the layernorm implementation
# to make a faster forward lowerbound though.
def layernorm_pytorch_lstm_creator(**kwargs):
    """Forward-only lower bound for a LayerNorm LSTM.

    Runs the cudnn LSTM, then performs seq_len * (three LayerNorm
    applications) to mimic the extra per-timestep cost. backward is None
    because the mimicked layernorm work is not part of a real graph.
    """
    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
    nbatch, nhid = kwargs["miniBatch"], kwargs["hiddenSize"]
    ln_i = torch.nn.LayerNorm(4 * nhid).cuda()
    ln_h = torch.nn.LayerNorm(4 * nhid).cuda()
    ln_c = torch.nn.LayerNorm(nhid).cuda()
    ln_input1 = torch.randn(nbatch, 4 * nhid, device="cuda")

    def forward(input, hidden):
        out, new_hidden = module(input, hidden)
        hy, cy = new_hidden
        # Mimic the per-timestep layernorm cost of a LayerNorm cudnn LSTM:
        # two gate-sized layernorms plus one cell-sized layernorm per step.
        for _ in range(len(input.unbind(0))):
            ln_i(ln_input1)
            ln_h(ln_input1)
            cy = ln_c(cy)

        return out, (hy, cy)

    return ModelDef(
        inputs=[input, hidden],
        params=flatten_list(module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=None,
    )


# input: lstm.all_weights format (wih, whh, bih, bhh = lstm.all_weights[layer])
# output: packed_weights with format
# packed_weights[0] is wih with size (layer, 4*hiddenSize, inputSize)
# packed_weights[1] is whh with size (layer, 4*hiddenSize, hiddenSize)
# packed_weights[2] is bih with size (layer, 4*hiddenSize)
# packed_weights[3] is bhh with size (layer, 4*hiddenSize)
def stack_weights(weights):
    """Transpose the per-layer weight lists and stack each column of tensors.

    XXX: script fns have problems indexing multidim lists, so we avoid them
    by stacking tensors into a single tensor per weight kind.
    """

    def transposed(rows):
        assert isinstance(rows, list)
        assert isinstance(rows[0], list)
        n_cols = len(rows[0])
        return [[row[c] for row in rows] for c in range(n_cols)]

    return [torch.stack(column) for column in transposed(weights)]
# returns: x, (hx, cx), all_weights, lstm module with all_weights as params
def lstm_inputs(
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    dropout=0.0,
    return_module=False,
    device="cuda",
    seed=None,
):
    """Build fixed-length LSTM benchmark inputs, weights, and optionally the module."""
    if seed is not None:
        torch.manual_seed(seed)
    x = torch.randn(seqLength, miniBatch, inputSize, device=device)
    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers, dropout=dropout)
    if "cuda" in device:
        lstm = lstm.cuda()

    if return_module:
        return x, (hx, cx), lstm.all_weights, lstm
    else:
        # NB: lstm.all_weights format:
        # wih, whh, bih, bhh = lstm.all_weights[layer]
        return x, (hx, cx), lstm.all_weights, None


def lstm_factory(cell, script):
    """Build a single-layer, fixed-length LSTM loop around `cell`.

    Kept in TorchScript-compatible form; both the cell and the loop are
    scripted when `script` is True.
    """

    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        inputs = input.unbind(0)
        # hx/cx carry a leading layer dimension; this factory is single-layer.
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# premul: we're going to premultiply the inputs & weights
def lstm_factory_premul(premul_cell, script):
    """LSTM loop with the input-to-hidden matmul hoisted out of the time loop."""

    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        # One big matmul over all timesteps, then unbind per step.
        inputs = torch.matmul(input, wih.t()).unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bih, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# premul: we're going to premultiply the inputs & weights, and add bias
def lstm_factory_premul_bias(premul_cell, script):
    """Like lstm_factory_premul, but also pre-adds the input bias bih."""

    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        inpSize = input.size()
        # add bias for all timesteps instead of going step-by-step, results in a single reduction kernel in the backward
        # FIXME matmul(x,y) + bias currently goes through jit AD, and backward formula in AD is not optimized for this
        # case. Workaround with mm and views.
        # (NOTE(review): this re-reads input.size() already taken above.)
        inpSize = input.size()
        inputs = torch.mm(input.view(-1, inpSize[2]), wih.t()) + bih
        inputs = inputs.view(inpSize[0], inpSize[1], -1).unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# simple: flat inputs (no tuples), no list to accumulate outputs
# useful mostly for benchmarking older JIT versions
def lstm_factory_simple(cell, script):
    """Minimal LSTM loop: flat arguments, returns only the final (hy, cy)."""

    def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh):
        hy = hx  # for scoping
        cy = cx  # for scoping
        inputs = input.unbind(0)
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], hy, cy, wih, whh, bih, bhh)
        return hy, cy

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


def lstm_factory_multilayer(cell, script):
    """Multi-layer LSTM loop driven by a flat params list (4 tensors per layer)."""

    def dynamic_rnn(
        input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        params_stride = 4  # NB: this assumes that biases are there
        hx, cx = hidden
513*da0073e9SAndroid Build Coastguard Worker hy, cy = hidden # for scoping... 514*da0073e9SAndroid Build Coastguard Worker inputs, outputs = input.unbind(0), [] 515*da0073e9SAndroid Build Coastguard Worker for layer in range(hx.size(0)): 516*da0073e9SAndroid Build Coastguard Worker hy = hx[layer] 517*da0073e9SAndroid Build Coastguard Worker cy = cx[layer] 518*da0073e9SAndroid Build Coastguard Worker base_idx = layer * params_stride 519*da0073e9SAndroid Build Coastguard Worker wih = params[base_idx] 520*da0073e9SAndroid Build Coastguard Worker whh = params[base_idx + 1] 521*da0073e9SAndroid Build Coastguard Worker bih = params[base_idx + 2] 522*da0073e9SAndroid Build Coastguard Worker bhh = params[base_idx + 3] 523*da0073e9SAndroid Build Coastguard Worker for seq_idx in range(len(inputs)): 524*da0073e9SAndroid Build Coastguard Worker hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh) 525*da0073e9SAndroid Build Coastguard Worker outputs += [hy] 526*da0073e9SAndroid Build Coastguard Worker inputs, outputs = outputs, [] 527*da0073e9SAndroid Build Coastguard Worker return torch.stack(inputs), (hy.unsqueeze(0), cy.unsqueeze(0)) 528*da0073e9SAndroid Build Coastguard Worker 529*da0073e9SAndroid Build Coastguard Worker if script: 530*da0073e9SAndroid Build Coastguard Worker cell = torch.jit.script(cell) 531*da0073e9SAndroid Build Coastguard Worker dynamic_rnn = torch.jit.script(dynamic_rnn) 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build Coastguard Worker return dynamic_rnn 534