import argparse

import torch
import torch.nn as nn

from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
from .runner import get_nn_runners


def barf():
    import pdb

    pdb.set_trace()


def assertEqual(tensor, expected, threshold=0.001):
    if isinstance(tensor, (list, tuple)):
        for t, e in zip(tensor, expected):
            assertEqual(t, e)
    else:
        if (tensor - expected).abs().max() > threshold:
            barf()


def filter_requires_grad(tensors):
    return [t for t in tensors if t.requires_grad]


def test_rnns(
    experim_creator,
    control_creator,
    check_grad=True,
    verbose=False,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=17,
):
    creator_args = dict(
        seqLength=seqLength,
        numLayers=numLayers,
        inputSize=inputSize,
        hiddenSize=hiddenSize,
        miniBatch=miniBatch,
        device=device,
        seed=seed,
    )

    print("Setting up...")
    control = control_creator(**creator_args)
    experim = experim_creator(**creator_args)

    # Precondition
    assertEqual(experim.inputs, control.inputs)
    assertEqual(experim.params, control.params)

    print("Checking outputs...")
    control_outputs = control.forward(*control.inputs)
    experim_outputs = experim.forward(*experim.inputs)
    assertEqual(experim_outputs, control_outputs)

    print("Checking grads...")
    assert control.backward_setup is not None
    assert experim.backward_setup is not None
    assert control.backward is not None
    assert experim.backward is not None
    control_backward_inputs = control.backward_setup(control_outputs, seed)
    experim_backward_inputs = experim.backward_setup(experim_outputs, seed)

    control.backward(*control_backward_inputs)
    experim.backward(*experim_backward_inputs)

    control_grads = [p.grad for p in control.params]
    experim_grads = [p.grad for p in experim.params]
    assertEqual(experim_grads, control_grads)

    if verbose:
        print(experim.forward.graph_for(*experim.inputs))
        print()


def test_vl_py(**test_args):
    # XXX: This compares vl_py with vl_lstm.
    # It's done this way because those two don't give the same outputs so
    # the result isn't an apples-to-apples comparison right now.
    control_creator = varlen_pytorch_lstm_creator
    name, experim_creator, context = get_nn_runners("vl_py")[0]
    with context():
        print(f"testing {name}...")
        creator_keys = [
            "seqLength",
            "numLayers",
            "inputSize",
            "hiddenSize",
            "miniBatch",
            "device",
            "seed",
        ]
        creator_args = {key: test_args[key] for key in creator_keys}

        print("Setting up...")
        control = control_creator(**creator_args)
        experim = experim_creator(**creator_args)

        # Precondition
        assertEqual(experim.inputs, control.inputs[:2])
        assertEqual(experim.params, control.params)

        print("Checking outputs...")
        control_out, control_hiddens = control.forward(*control.inputs)
        control_hx, control_cx = control_hiddens
        experim_out, experim_hiddens = experim.forward(*experim.inputs)
        experim_hx, experim_cx = experim_hiddens

        experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2)
        assertEqual(experim_padded, control_out)
        assertEqual(torch.cat(experim_hx, dim=1), control_hx)
        assertEqual(torch.cat(experim_cx, dim=1), control_cx)

        print("Checking grads...")
        assert control.backward_setup is not None
        assert experim.backward_setup is not None
        assert control.backward is not None
        assert experim.backward is not None
        control_backward_inputs = control.backward_setup(
            (control_out, control_hiddens), test_args["seed"]
        )
        experim_backward_inputs = experim.backward_setup(
            (experim_out, experim_hiddens), test_args["seed"]
        )

        control.backward(*control_backward_inputs)
        experim.backward(*experim_backward_inputs)

        control_grads = [p.grad for p in control.params]
        experim_grads = [p.grad for p in experim.params]
        assertEqual(experim_grads, control_grads)

        if test_args["verbose"]:
            print(experim.forward.graph_for(*experim.inputs))
            print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test lstm correctness")

    parser.add_argument("--seqLength", default="100", type=int)
    parser.add_argument("--numLayers", default="1", type=int)
    parser.add_argument("--inputSize", default="512", type=int)
    parser.add_argument("--hiddenSize", default="512", type=int)
    parser.add_argument("--miniBatch", default="64", type=int)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--check-grad", "--check_grad", default="True", type=bool)
    parser.add_argument("--variable-lstms", "--variable_lstms", action="store_true")
    parser.add_argument("--seed", default="17", type=int)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--rnns", nargs="*", help="What to run. jit_premul, jit, etc")
    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ["jit_premul", "jit"]
    print(args)

    if "cuda" in args.device:
        assert torch.cuda.is_available()

    rnn_runners = get_nn_runners(*args.rnns)

    should_test_varlen_lstms = args.variable_lstms
    test_args = vars(args)
    del test_args["rnns"]
    del test_args["variable_lstms"]

    if should_test_varlen_lstms:
        test_vl_py(**test_args)

    for name, creator, context in rnn_runners:
        with context():
            print(f"testing {name}...")
            test_rnns(creator, pytorch_lstm_creator, **test_args)
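
# Example invocation (a sketch, not part of the original script): because of the
# relative imports above, the file must be run as a module from the directory that
# contains its package; the package name `fastrnns` here is an assumption based on
# how this script is usually laid out. The flags mirror the argparse options above,
# and every flag is optional since the defaults match the keyword defaults of
# test_rnns (seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
# miniBatch=64, device="cuda", seed=17):
#
#   python -m fastrnns.test --rnns jit jit_premul --device cuda --variable-lstms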