from collections import namedtuple
from typing import List, Tuple

import torch
from torch import Tensor

from .cells import (
    flat_lstm_cell,
    lstm_cell,
    premul_lstm_cell,
    premul_lstm_cell_no_bias,
)


# list[list[T]] -> list[T]
def flatten_list(lst):
    result = []
    for inner in lst:
        result.extend(inner)
    return result


"""
Define a creator as a function:
(options) -> (inputs, params, forward, backward_setup, backward)
inputs: the inputs to the returned 'forward'. One can call
    forward(*inputs) directly.
params: List[Tensor] all requires_grad=True parameters.
forward: function / graph executor / module
    One can call it as forward(*inputs) using the outputs of the creator.
backward_setup: backward_inputs = backward_setup(forward_output)
    Then, we pass backward_inputs to backward. If None, then it is assumed to
    be the identity function.
backward: Given `backward_inputs = backward_setup(forward(*inputs))`, calling
    `backward(*backward_inputs)` performs backpropagation. If None, then
    nothing happens.

fastrnns.bench times the forward and backward invocations.
(See the illustrative driver sketch further below in this file.)
"""


ModelDef = namedtuple(
    "ModelDef", ["inputs", "params", "forward", "backward_setup", "backward"]
)


def lstm_backward_setup(lstm_outputs, seed=None):
    hx, _ = lstm_outputs
    return simple_backward_setup(hx, seed)


def simple_backward_setup(output, seed=None):
    assert isinstance(output, torch.Tensor)
    if seed:
        torch.manual_seed(seed)
    grad_output = torch.randn_like(output)
    return output, grad_output


def simple_backward(output, grad_output, **kwargs):
    return output.backward(grad_output, **kwargs)


def pytorch_lstm_creator(**kwargs):
    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
    return ModelDef(
        inputs=[input, hidden],
        params=flatten_list(module.all_weights),
        forward=module,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory(lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lnlstm_creator(script=True, decompose_layernorm=False, **kwargs):
    assert script is True
    from .custom_lstms import script_lnlstm

    input_size = kwargs["inputSize"]
    hidden_size = kwargs["hiddenSize"]
    seq_len = kwargs["seqLength"]
    batch_size = kwargs["miniBatch"]
    ge = script_lnlstm(
        input_size, hidden_size, 1, decompose_layernorm=decompose_layernorm
    ).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device="cuda")
    states = [
        (
            torch.randn(batch_size, hidden_size, device="cuda"),
            torch.randn(batch_size, hidden_size, device="cuda"),
        )
    ]

    return ModelDef(
        inputs=[input, states],
        params=ge.parameters(),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


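# Illustrative, hedged sketch (not used by the benchmarks): one way to drive a
# ModelDef produced by the creators in this file, following the contract in the
# module docstring and the single-argument convention of the *_backward_setup
# helpers above. The helper name `_example_run_modeldef`, the chosen creator,
# and the sizes are assumptions for illustration only; the real timing loop
# lives in fastrnns.bench.
def _example_run_modeldef():
    inputs, params, forward, backward_setup, backward = pytorch_lstm_creator(
        seqLength=10,
        numLayers=1,
        inputSize=64,
        hiddenSize=64,
        miniBatch=8,
        device="cpu",  # CPU so the sketch runs without a GPU
    )
    forward_output = forward(*inputs)
    # backward_setup turns the forward output into the arguments for backward;
    # if it is None, it is treated as the identity.
    if backward_setup is not None:
        backward_inputs = backward_setup(forward_output)
    else:
        backward_inputs = forward_output
    if backward is not None:
        backward(*backward_inputs)
    # all ModelDef params require grad, so their .grad fields are now populated
    return [p.grad for p in params]

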
def dropoutlstm_creator(script=True, **kwargs):
    assert script is True
    from .custom_lstms import LSTMState, script_lstm

    input_size = kwargs["inputSize"]
    hidden_size = kwargs["hiddenSize"]
    seq_len = kwargs["seqLength"]
    batch_size = kwargs["miniBatch"]
    num_layers = kwargs["numLayers"]
    ge = script_lstm(input_size, hidden_size, num_layers, dropout=True).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device="cuda")
    states = [
        LSTMState(
            torch.randn(batch_size, hidden_size, device="cuda"),
            torch.randn(batch_size, hidden_size, device="cuda"),
        )
        for _ in range(num_layers)
    ]
    return ModelDef(
        inputs=[input, states],
        params=ge.parameters(),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_premul_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_premul(premul_lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_premul_bias_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_premul_bias(premul_lstm_cell_no_bias, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_simple_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input] + [h[0] for h in hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_simple(flat_lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_multilayer_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden, flatten_list(params)]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_multilayer(lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def imagenet_cnn_creator(arch, jit=True):
    def creator(device="cuda", **kwargs):
        model = arch().to(device)
        x = torch.randn(32, 3, 224, 224, device=device)
        if jit:
            model = torch.jit.trace(model, x)
        return ModelDef(
            inputs=(x,),
            params=list(model.parameters()),
            forward=model,
            backward_setup=simple_backward_setup,
            backward=simple_backward,
        )

    return creator


def varlen_lstm_inputs(
    minlen=30,
    maxlen=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    return_module=False,
    device="cuda",
    seed=None,
    **kwargs,
):
    if seed is not None:
        torch.manual_seed(seed)
    lengths = torch.randint(
        low=minlen, high=maxlen, size=[miniBatch], dtype=torch.long, device=device
    )
    x = [torch.randn(length, inputSize, device=device) for length in lengths]
    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers).to(device)

    if return_module:
        return x, lengths, (hx, cx), lstm.all_weights, lstm
    else:
        # NB: lstm.all_weights format:
        # wih, whh, bih, bhh = lstm.all_weights[layer]
        return x, lengths, (hx, cx), lstm.all_weights, None


def varlen_lstm_backward_setup(forward_output, seed=None):
    if seed:
        torch.manual_seed(seed)
    rnn_utils = torch.nn.utils.rnn
    sequences = forward_output[0]
    padded = rnn_utils.pad_sequence(sequences)
    grad = torch.randn_like(padded)
    return padded, grad


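# Illustrative, hedged sketch (not used by the benchmarks): the variable-length
# input format produced by varlen_lstm_inputs is a plain Python list of
# (length_i, inputSize) tensors, one per batch element. pad_sequence (used by
# varlen_lstm_backward_setup above) stacks them into a zero-padded
# (max_len, batch, inputSize) tensor, while pack_sequence/pad_packed_sequence
# (used by varlen_pytorch_lstm_creator below) round-trip them through the packed
# representation. The helper name `_example_varlen_roundtrip` and the sizes are
# assumptions for illustration only.
def _example_varlen_roundtrip():
    rnn_utils = torch.nn.utils.rnn
    seqs = [torch.randn(n, 8) for n in (5, 3, 4)]  # three sequences, inputSize=8
    padded = rnn_utils.pad_sequence(seqs)  # shape (5, 3, 8), zero-padded
    packed = rnn_utils.pack_sequence(seqs, enforce_sorted=False)
    lstm = torch.nn.LSTM(8, 16)
    out, _ = lstm(packed)
    # padding is restored in the original batch order; lengths == [5, 3, 4]
    unpacked, lengths = rnn_utils.pad_packed_sequence(out)
    return padded.shape, unpacked.shape, lengths

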
def varlen_pytorch_lstm_creator(**kwargs):
    rnn_utils = torch.nn.utils.rnn
    sequences, _, hidden, _, module = varlen_lstm_inputs(return_module=True, **kwargs)

    def forward(sequences, hidden):
        packed = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
        out, new_hidden = module(packed, hidden)
        padded, lengths = rnn_utils.pad_packed_sequence(out)
        # XXX: It's more efficient to store the output in its padded form,
        # but that might not be conducive to loss computation.
        # Un-padding the output also makes the backward pass 2x slower...
        # return [padded[:lengths[i], i, :] for i in range(lengths.size(0))]
        return padded, new_hidden

    return ModelDef(
        inputs=[sequences, hidden],
        params=flatten_list(module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def varlen_lstm_factory(cell, script):
    def dynamic_rnn(
        sequences: List[Tensor],
        hiddens: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]:
        hx, cx = hiddens
        hxs = hx.unbind(1)
        cxs = cx.unbind(1)
        # List of: (output, hx, cx)
        outputs = []
        hx_outs = []
        cx_outs = []

        for batch in range(len(sequences)):
            output = []
            hy, cy = hxs[batch], cxs[batch]
            inputs = sequences[batch].unbind(0)

            for seq_idx in range(len(inputs)):
                hy, cy = cell(
                    inputs[seq_idx].unsqueeze(0), (hy, cy), wih, whh, bih, bhh
                )
                output += [hy]
            outputs += [torch.stack(output)]
            hx_outs += [hy.unsqueeze(0)]
            cx_outs += [cy.unsqueeze(0)]

        return outputs, (hx_outs, cx_outs)

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


def varlen_lstm_creator(script=False, **kwargs):
    sequences, _, hidden, params, _ = varlen_lstm_inputs(return_module=False, **kwargs)
    inputs = [sequences, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=varlen_lstm_factory(lstm_cell, script),
        backward_setup=varlen_lstm_backward_setup,
        backward=simple_backward,
    )


# cudnn_layernorm_lstm: since cudnn does not have a LayerNorm LSTM, we cannot
# benchmark the lower bound directly. Instead, we only benchmark the forward
# pass by mimicking the computation of a cudnn LSTM + seq_len * 3 LayerNorm
# computations. This should serve as a perf lower bound for the LayerNorm LSTM
# forward pass (given that LayerNorm itself is invariant). The lower bound of
# the backward pass is hard to get since we lose the intermediate results, but
# we can still optimize the LayerNorm implementation to make a faster forward
# lower bound.
def layernorm_pytorch_lstm_creator(**kwargs):
    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
    batch_size = kwargs["miniBatch"]
    hidden_size = kwargs["hiddenSize"]
    ln_i = torch.nn.LayerNorm(4 * hidden_size).cuda()
    ln_h = torch.nn.LayerNorm(4 * hidden_size).cuda()
    ln_c = torch.nn.LayerNorm(hidden_size).cuda()
    ln_input1 = torch.randn(batch_size, 4 * hidden_size, device="cuda")

    def forward(input, hidden):
        out, new_hidden = module(input, hidden)
        # plus (seq_len * three layernorm cell computations) to mimic the lower
        # bound of a LayerNorm cudnn LSTM in the forward pass
        seq_len = len(input.unbind(0))
        hy, cy = new_hidden
        for i in range(seq_len):
            ln_i_output = ln_i(ln_input1)
            ln_h_output = ln_h(ln_input1)
            cy = ln_c(cy)

        return out, (hy, cy)

    return ModelDef(
        inputs=[input, hidden],
        params=flatten_list(module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=None,
    )


# input: lstm.all_weights format (wih, whh, bih, bhh = lstm.all_weights[layer])
# output: packed_weights with format
#   packed_weights[0] is wih with size (layer, 4*hiddenSize, inputSize)
#   packed_weights[1] is whh with size (layer, 4*hiddenSize, hiddenSize)
#   packed_weights[2] is bih with size (layer, 4*hiddenSize)
#   packed_weights[3] is bhh with size (layer, 4*hiddenSize)
def stack_weights(weights):
    def unzip_columns(mat):
        assert isinstance(mat, list)
        assert isinstance(mat[0], list)
        layers = len(mat)
        columns = len(mat[0])
        return [[mat[layer][col] for layer in range(layers)] for col in range(columns)]

    # XXX: script fns have problems indexing multidim lists, so we try to
    # avoid them by stacking tensors
    all_weights = weights
    packed_weights = [torch.stack(param) for param in unzip_columns(all_weights)]
    return packed_weights


# returns: x, (hx, cx), all_weights, lstm module with all_weights as params
def lstm_inputs(
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    dropout=0.0,
    return_module=False,
    device="cuda",
    seed=None,
):
    if seed is not None:
        torch.manual_seed(seed)
    x = torch.randn(seqLength, miniBatch, inputSize, device=device)
    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers, dropout=dropout)
    if "cuda" in device:
        lstm = lstm.cuda()

    if return_module:
        return x, (hx, cx), lstm.all_weights, lstm
    else:
        # NB: lstm.all_weights format:
        # wih, whh, bih, bhh = lstm.all_weights[layer]
        return x, (hx, cx), lstm.all_weights, None


def lstm_factory(cell, script):
    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        inputs = input.unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


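# Illustrative, hedged sketch (not used by the benchmarks): the "premul"
# factories below hoist the input-side projection out of the time loop.
# Computing matmul(input, wih.t()) once over the whole (seq_len, batch,
# inputSize) tensor is a single large GEMM and is numerically the same as
# projecting each timestep separately, which is what the per-step cell would
# otherwise do. The helper name `_example_premul_equivalence` and the sizes are
# assumptions for illustration only.
def _example_premul_equivalence(seq_len=7, batch=4, input_size=16, hidden_size=8):
    x = torch.randn(seq_len, batch, input_size)
    wih = torch.randn(4 * hidden_size, input_size)
    batched = torch.matmul(x, wih.t())  # one big GEMM over all timesteps
    stepwise = torch.stack([x[t] @ wih.t() for t in range(seq_len)])  # per-step GEMMs
    assert torch.allclose(batched, stepwise, atol=1e-5)
    return batched

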
# premul: we're going to premultiply the inputs & weights
def lstm_factory_premul(premul_cell, script):
    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        inputs = torch.matmul(input, wih.t()).unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bih, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# premul: we're going to premultiply the inputs & weights, and add bias
def lstm_factory_premul_bias(premul_cell, script):
    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        # Add the bias for all timesteps instead of going step-by-step; this
        # results in a single reduction kernel in the backward pass.
        # FIXME: matmul(x, y) + bias currently goes through jit AD, and the
        # backward formula in AD is not optimized for this case. Work around
        # it with mm and views.
        inpSize = input.size()
        inputs = torch.mm(input.view(-1, inpSize[2]), wih.t()) + bih
        inputs = inputs.view(inpSize[0], inpSize[1], -1).unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# simple: flat inputs (no tuples), no list to accumulate outputs
# useful mostly for benchmarking older JIT versions
def lstm_factory_simple(cell, script):
    def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh):
        hy = hx  # for scoping
        cy = cx  # for scoping
        inputs = input.unbind(0)
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], hy, cy, wih, whh, bih, bhh)
        return hy, cy

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


def lstm_factory_multilayer(cell, script):
    def dynamic_rnn(
        input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        params_stride = 4  # NB: this assumes that biases are there
        hx, cx = hidden
        hy, cy = hidden  # for scoping...
        inputs, outputs = input.unbind(0), []
        for layer in range(hx.size(0)):
            hy = hx[layer]
            cy = cx[layer]
            base_idx = layer * params_stride
            wih = params[base_idx]
            whh = params[base_idx + 1]
            bih = params[base_idx + 2]
            bhh = params[base_idx + 3]
            for seq_idx in range(len(inputs)):
                hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
                outputs += [hy]
            inputs, outputs = outputs, []
        return torch.stack(inputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


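# Illustrative, hedged sketch (not used by the benchmarks): the flat params list
# consumed by lstm_factory_multilayer is nn.LSTM's per-layer
# [w_ih, w_hh, b_ih, b_hh] lists flattened in layer order, which is what
# lstm_multilayer_creator builds with flatten_list and what the stride-4
# indexing above assumes. The helper name `_example_multilayer_params` and the
# sizes are assumptions for illustration only.
def _example_multilayer_params(num_layers=2, input_size=32, hidden_size=32):
    lstm = torch.nn.LSTM(input_size, hidden_size, num_layers)
    flat = flatten_list(lstm.all_weights)
    assert len(flat) == 4 * num_layers
    # layer 1's input-hidden weight sits at index 1 * 4 + 0 under the stride-4 layout
    assert flat[4].shape == (4 * hidden_size, hidden_size)
    return flat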