xref: /aosp_15_r20/external/pytorch/scripts/release_notes/test_release_notes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import tempfile
2import unittest
3
4from commitlist import CommitList
5
6
class TestCommitList(unittest.TestCase):
    """Tests for ``CommitList`` creation, on-disk round-tripping, and updating.

    NOTE(review): every test passes real revisions ("v1.5.0", "6000dca5df",
    "7543e7e558", "5702a28b26") to ``CommitList``, so these presumably must
    run inside a PyTorch git checkout whose history contains them — confirm.
    """

    # Number of commits ``create_new(path, "v1.5.0", "7543e7e558")`` is
    # expected to produce.
    _COMMITS_V150_TO_7543E7E558 = 2143
    # Number of commits ``update_to("5702a28b26")`` is expected to add on top
    # of that list.
    _COMMITS_7543E7E558_TO_5702A28B26 = 4

    def _assert_title_startswith(self, commit, prefix):
        """Assert ``commit.title`` starts with ``prefix``, with a useful message."""
        # A bare assertTrue(str.startswith(...)) fails with only
        # "False is not true"; report both strings instead.
        self.assertTrue(
            commit.title.startswith(prefix),
            msg=f"title {commit.title!r} does not start with {prefix!r}",
        )

    def test_create_new(self):
        """A fresh list spans the expected commits, oldest first."""
        with tempfile.TemporaryDirectory() as tempdir:
            commit_list_path = f"{tempdir}/commitlist.csv"
            commit_list = CommitList.create_new(
                commit_list_path, "v1.5.0", "6000dca5df"
            )
            self.assertEqual(len(commit_list.commits), 33)

            first, last = commit_list.commits[0], commit_list.commits[-1]
            self.assertEqual(first.commit_hash, "7335f079abb")
            self._assert_title_startswith(first, "[pt][quant] qmul and qadd")
            self.assertEqual(last.commit_hash, "6000dca5df6")
            self._assert_title_startswith(
                last, "[nomnigraph] Copy device option when customize "
            )

    def test_read_write(self):
        """Edits written to disk are read back unchanged."""
        with tempfile.TemporaryDirectory() as tempdir:
            commit_list_path = f"{tempdir}/commitlist.csv"
            initial = CommitList.create_new(commit_list_path, "v1.5.0", "7543e7e558")
            initial.write_to_disk()

            expected = CommitList.from_existing(commit_list_path)
            expected.commits[-2].category = "foobar"
            expected.write_to_disk()

            commit_list = CommitList.from_existing(commit_list_path)
            # zip() stops at the shorter list, so without this check a reload
            # that lost commits (or returned none) would pass vacuously.
            self.assertEqual(len(commit_list.commits), len(expected.commits))
            for index, (commit, expected_commit) in enumerate(
                zip(commit_list.commits, expected.commits)
            ):
                with self.subTest(index=index):
                    self.assertEqual(commit, expected_commit)
            # The field edited above must survive the write/read round trip.
            self.assertEqual(commit_list.commits[-2].category, "foobar")

    def test_update_to(self):
        """Updating a saved list appends the new commits after the old ones."""
        num_initial = self._COMMITS_V150_TO_7543E7E558
        num_added = self._COMMITS_7543E7E558_TO_5702A28B26
        with tempfile.TemporaryDirectory() as tempdir:
            commit_list_path = f"{tempdir}/commitlist.csv"
            initial = CommitList.create_new(commit_list_path, "v1.5.0", "7543e7e558")
            initial.commits[-2].category = "foobar"
            self.assertEqual(len(initial.commits), num_initial)
            initial.write_to_disk()

            commit_list = CommitList.from_existing(commit_list_path)
            commit_list.update_to("5702a28b26")
            self.assertEqual(len(commit_list.commits), num_initial + num_added)
            # The previously-last commit should now sit immediately before the
            # ``num_added`` commits that update_to appended.
            self.assertEqual(
                commit_list.commits[-(num_added + 1)], initial.commits[-1]
            )
52
53
if __name__ == "__main__":
    # Executed as a script (not imported): discover and run this module's
    # TestCase classes at the default verbosity.
    unittest.main(verbosity=1)
56