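"""Correctness checks for the fastrnns benchmarks.

Runs each requested experimental LSTM runner (e.g. jit_premul, jit) against
the stock PyTorch LSTM and compares outputs and gradients, optionally also
checking the variable-length Python LSTM (vl_py) against the variable-length
baseline.

Example invocation (assumed; this module uses relative imports, so it must be
run as part of its package):

    python -m fastrnns.test --rnns jit --device cuda
"""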
import argparse

import torch
import torch.nn as nn

from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
from .runner import get_nn_runners

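# Notes on the imported helpers (inferred from their usage below): creators
# return records exposing `inputs`, `params`, `forward`, `backward_setup`,
# and `backward`; `get_nn_runners` resolves runner names such as "jit" into
# (name, creator, context) triples; `pytorch_lstm_creator` builds the stock
# PyTorch LSTM used as the control.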

def barf():
    # Drop into the debugger at the point of failure so the mismatched
    # tensors can be inspected interactively.
    import pdb

    pdb.set_trace()


def assertEqual(tensor, expected, threshold=0.001):
    # Recursively compare tensors (or nested sequences of tensors) and drop
    # into pdb on the first element that differs by more than `threshold`.
    if isinstance(tensor, (list, tuple)):
        for t, e in zip(tensor, expected):
            assertEqual(t, e, threshold)
    else:
        if (tensor - expected).abs().max() > threshold:
            barf()


def filter_requires_grad(tensors):
    return [t for t in tensors if t.requires_grad]


def test_rnns(
    experim_creator,
    control_creator,
    check_grad=True,
    verbose=False,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=17,
):
    """Check that `experim_creator`'s LSTM matches `control_creator`'s
    outputs and (optionally) gradients on a common configuration."""
    creator_args = dict(
        seqLength=seqLength,
        numLayers=numLayers,
        inputSize=inputSize,
        hiddenSize=hiddenSize,
        miniBatch=miniBatch,
        device=device,
        seed=seed,
    )

    print("Setting up...")
    control = control_creator(**creator_args)
    experim = experim_creator(**creator_args)

    # Precondition: both runners must start from identical inputs and
    # weights, otherwise the comparisons below are meaningless.
    assertEqual(experim.inputs, control.inputs)
    assertEqual(experim.params, control.params)

    print("Checking outputs...")
    control_outputs = control.forward(*control.inputs)
    experim_outputs = experim.forward(*experim.inputs)
    assertEqual(experim_outputs, control_outputs)

    if check_grad:
        print("Checking grads...")
        assert control.backward_setup is not None
        assert experim.backward_setup is not None
        assert control.backward is not None
        assert experim.backward is not None
        control_backward_inputs = control.backward_setup(control_outputs, seed)
        experim_backward_inputs = experim.backward_setup(experim_outputs, seed)

        control.backward(*control_backward_inputs)
        experim.backward(*experim_backward_inputs)

        control_grads = [p.grad for p in control.params]
        experim_grads = [p.grad for p in experim.params]
        assertEqual(experim_grads, control_grads)

    if verbose:
        print(experim.forward.graph_for(*experim.inputs))
    print()

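# A minimal programmatic use of test_rnns (hypothetical sizes; mirrors the
# __main__ driver at the bottom of this file):
#
#   name, creator, context = get_nn_runners("jit")[0]
#   with context():
#       test_rnns(creator, pytorch_lstm_creator, seqLength=10, miniBatch=4)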

def test_vl_py(**test_args):
    # XXX: This compares vl_py with vl_lstm (the variable-length control
    # built by varlen_pytorch_lstm_creator). The two don't produce outputs
    # in the same format, so the checks below have to reshape them first and
    # the comparison isn't fully apples-to-apples yet.
    control_creator = varlen_pytorch_lstm_creator
    name, experim_creator, context = get_nn_runners("vl_py")[0]
    with context():
        print(f"testing {name}...")
        creator_keys = [
            "seqLength",
            "numLayers",
            "inputSize",
            "hiddenSize",
            "miniBatch",
            "device",
            "seed",
        ]
        creator_args = {key: test_args[key] for key in creator_keys}

        print("Setting up...")
        control = control_creator(**creator_args)
        experim = experim_creator(**creator_args)

        # Precondition: the experiment must see the same data and weights as
        # the control (the control takes additional inputs beyond the first
        # two, hence the slice).
        assertEqual(experim.inputs, control.inputs[:2])
        assertEqual(experim.params, control.params)

        print("Checking outputs...")
        control_out, control_hiddens = control.forward(*control.inputs)
        control_hx, control_cx = control_hiddens
        experim_out, experim_hiddens = experim.forward(*experim.inputs)
        experim_hx, experim_cx = experim_hiddens

        # vl_py emits a list of per-sequence outputs and per-sequence hidden
        # states; pad the outputs back into a single tensor matching the
        # control's layout, and concatenate the hidden/cell states along the
        # batch dimension.
        experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2)
        assertEqual(experim_padded, control_out)
        assertEqual(torch.cat(experim_hx, dim=1), control_hx)
        assertEqual(torch.cat(experim_cx, dim=1), control_cx)

        if test_args.get("check_grad", True):
            print("Checking grads...")
            assert control.backward_setup is not None
            assert experim.backward_setup is not None
            assert control.backward is not None
            assert experim.backward is not None
            control_backward_inputs = control.backward_setup(
                (control_out, control_hiddens), test_args["seed"]
            )
            experim_backward_inputs = experim.backward_setup(
                (experim_out, experim_hiddens), test_args["seed"]
            )

            control.backward(*control_backward_inputs)
            experim.backward(*experim_backward_inputs)

            control_grads = [p.grad for p in control.params]
            experim_grads = [p.grad for p in experim.params]
            assertEqual(experim_grads, control_grads)

        if test_args["verbose"]:
            print(experim.forward.graph_for(*experim.inputs))
        print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test LSTM correctness")

    parser.add_argument("--seqLength", default=100, type=int)
    parser.add_argument("--numLayers", default=1, type=int)
    parser.add_argument("--inputSize", default=512, type=int)
    parser.add_argument("--hiddenSize", default=512, type=int)
    parser.add_argument("--miniBatch", default=64, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    # NB: type=bool is a trap here; bool("False") is True, so any value
    # passed on the command line would enable the check. Parse explicitly.
    parser.add_argument(
        "--check-grad",
        "--check_grad",
        default=True,
        type=lambda x: x.lower() in ("true", "1", "yes"),
    )
    parser.add_argument("--variable-lstms", "--variable_lstms", action="store_true")
    parser.add_argument("--seed", default=17, type=int)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument(
        "--rnns", nargs="*", help="Which runners to test, e.g. jit_premul, jit"
    )
    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ["jit_premul", "jit"]
    print(args)

    if "cuda" in args.device:
        assert torch.cuda.is_available(), "--device cuda requires a CUDA-capable GPU"

    rnn_runners = get_nn_runners(*args.rnns)

    should_test_varlen_lstms = args.variable_lstms
    test_args = vars(args)
    del test_args["rnns"]
    del test_args["variable_lstms"]

    if should_test_varlen_lstms:
        test_vl_py(**test_args)

    for name, creator, context in rnn_runners:
        with context():
            print(f"testing {name}...")
            test_rnns(creator, pytorch_lstm_creator, **test_args)