xref: /aosp_15_r20/external/pytorch/test/test_hub.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: hub"]
2
3import os
4import tempfile
5import unittest
6import warnings
7from unittest.mock import patch
8
9import torch
10import torch.hub as hub
11from torch.testing._internal.common_utils import IS_SANDCASTLE, retry, TestCase
12
13
def sum_of_state_dict(state_dict):
    """Return the sum of all elements of every value in ``state_dict``.

    Args:
        state_dict: Mapping whose values expose a ``.sum()`` method
            (e.g. the tensors of a module's ``state_dict()``).

    Returns:
        The accumulated total of ``value.sum()`` over all values, or the
        integer ``0`` when the mapping is empty.
    """
    # ``sum`` accumulates with out-of-place ``+``. The previous ``s += ...``
    # loop turned into an in-place tensor add after the first element, which
    # cannot promote dtypes and therefore raised when an integer tensor was
    # followed by a floating-point one.
    return sum(value.sum() for value in state_dict.values())
19
20
# Value the tests expect ``sum_of_state_dict`` to return for the example
# checkpoints and models loaded below.
SUM_OF_HUB_EXAMPLE = 431_080
# Release asset fetched by the download / load_state_dict_from_url tests.
TORCHHUB_EXAMPLE_RELEASE_URL = (
    "https://github.com/ailzhang/torchhub_example"
    "/releases/download/0.1/mnist_init_ones"
)
25
26
@unittest.skipIf(IS_SANDCASTLE, "Sandcastle cannot ping external")
class TestHub(TestCase):
    """Tests for ``torch.hub`` that load the public example repositories.

    Each test runs with the hub directory redirected to a fresh temporary
    directory (see ``setUp``), so downloads and the trusted-repo list written
    during a test do not land in the previously configured hub directory.
    The tests fetch from GitHub and are skipped when ``IS_SANDCASTLE`` is set.
    """

    def setUp(self):
        """Redirect the hub directory to a new temporary directory."""
        super().setUp()
        self.previous_hub_dir = torch.hub.get_dir()
        self.tmpdir = tempfile.TemporaryDirectory("hub_dir")
        torch.hub.set_dir(self.tmpdir.name)
        self.trusted_list_path = os.path.join(torch.hub.get_dir(), "trusted_list")

    def tearDown(self):
        """Restore the previous hub directory and remove the temporary one."""
        try:
            super().tearDown()
        finally:
            # The hub directory is process-wide state (test_get_set_dir even
            # re-points it), so restore it and drop the temporary directory
            # even if the parent tearDown raises.
            torch.hub.set_dir(self.previous_hub_dir)
            self.tmpdir.cleanup()

    def _assert_trusted_list_is_empty(self):
        """Fail unless the trusted-list file contains no lines.

        The file itself must exist; a missing file surfaces as the
        ``FileNotFoundError`` raised by ``open``.
        """
        with open(self.trusted_list_path) as f:
            entries = f.readlines()
        self.assertFalse(entries, "trusted list is expected to be empty")

    def _assert_in_trusted_list(self, line):
        """Fail unless ``line`` equals one of the stripped trusted-list lines."""
        with open(self.trusted_list_path) as f:
            entries = [entry.strip() for entry in f]
        self.assertIn(line, entries)

    @retry(Exception, tries=3)
    def test_load_from_github(self):
        # Load the "mnist" entrypoint with an explicit source="github".
        hub_model = hub.load(
            "ailzhang/torchhub_example",
            "mnist",
            source="github",
            pretrained=True,
            verbose=False,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_from_local_dir(self):
        # Obtain a local checkout through the private cache helper, then load
        # from the directory it returns with source="local".
        local_dir = hub._get_cache_or_reload(
            "ailzhang/torchhub_example",
            force_reload=False,
            trust_repo=True,
            calling_fn=None,
        )
        hub_model = hub.load(
            local_dir, "mnist", source="local", pretrained=True, verbose=False
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_from_branch(self):
        # The ref after the colon contains a slash ("ci/test_slash").
        hub_model = hub.load(
            "ailzhang/torchhub_example:ci/test_slash",
            "mnist",
            pretrained=True,
            verbose=False,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_get_set_dir(self):
        previous_hub_dir = torch.hub.get_dir()
        with tempfile.TemporaryDirectory("hub_dir") as tmpdir:
            torch.hub.set_dir(tmpdir)
            self.assertEqual(torch.hub.get_dir(), tmpdir)
            self.assertNotEqual(previous_hub_dir, tmpdir)

            hub_model = hub.load(
                "ailzhang/torchhub_example", "mnist", pretrained=True, verbose=False
            )
            self.assertEqual(
                sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE
            )
            # The repository must have been extracted under the new hub dir.
            self.assertTrue(
                os.path.exists(
                    os.path.join(tmpdir, "ailzhang_torchhub_example_master")
                )
            )

        # Test that set_dir properly calls expanduser()
        # non-regression test for https://github.com/pytorch/pytorch/issues/69761
        new_dir = os.path.join("~", "hub")
        torch.hub.set_dir(new_dir)
        self.assertEqual(torch.hub.get_dir(), os.path.expanduser(new_dir))

    @retry(Exception, tries=3)
    def test_list_entrypoints(self):
        entry_lists = hub.list("ailzhang/torchhub_example", trust_repo=True)
        self.assertObjectIn("mnist", entry_lists)

    @retry(Exception, tries=3)
    def test_download_url_to_file(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            f = os.path.join(tmpdir, "temp")
            hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL, f, progress=False)
            # NOTE(review): torch.load unpickles the downloaded file; this is
            # only acceptable because the URL is a fixed test asset.
            loaded_state = torch.load(f)
            self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)
            # Check that the downloaded file has default file permissions
            f_ref = os.path.join(tmpdir, "reference")
            # Create an empty reference file to learn the default permissions.
            with open(f_ref, "w"):
                pass
            expected_permissions = oct(os.stat(f_ref).st_mode & 0o777)
            actual_permissions = oct(os.stat(f).st_mode & 0o777)
            self.assertEqual(actual_permissions, expected_permissions)

    @retry(Exception, tries=3)
    def test_load_state_dict_from_url(self):
        loaded_state = hub.load_state_dict_from_url(TORCHHUB_EXAMPLE_RELEASE_URL)
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

        # with name
        file_name = "the_file_name"
        loaded_state = hub.load_state_dict_from_url(
            TORCHHUB_EXAMPLE_RELEASE_URL, file_name=file_name
        )
        expected_file_path = os.path.join(torch.hub.get_dir(), "checkpoints", file_name)
        self.assertTrue(os.path.exists(expected_file_path))
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

        # with safe weight_only
        loaded_state = hub.load_state_dict_from_url(
            TORCHHUB_EXAMPLE_RELEASE_URL, weights_only=True
        )
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_legacy_zip_checkpoint(self):
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")
            hub_model = hub.load(
                "ailzhang/torchhub_example", "mnist_zip", pretrained=True, verbose=False
            )
            self.assertEqual(
                sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE
            )
            # Match on the warning text itself rather than on the repr of the
            # recorded WarningMessage (which also embeds file name and line).
            self.assertTrue(
                any(
                    "will be deprecated in favor of default zipfile" in str(w.message)
                    for w in ws
                )
            )

    # Test the default zipfile serialization format produced by >=1.6 release.
    @retry(Exception, tries=3)
    def test_load_zip_1_6_checkpoint(self):
        hub_model = hub.load(
            "ailzhang/torchhub_example",
            "mnist_zip_1_6",
            pretrained=True,
            verbose=False,
            trust_repo=True,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_hub_parse_repo_info(self):
        # If the branch is specified we just parse the input and return
        self.assertEqual(torch.hub._parse_repo_info("a/b:c"), ("a", "b", "c"))
        # For torchvision, the default branch is main
        self.assertEqual(
            torch.hub._parse_repo_info("pytorch/vision"), ("pytorch", "vision", "main")
        )
        # For the torchhub_example repo, the default branch is still master
        self.assertEqual(
            torch.hub._parse_repo_info("ailzhang/torchhub_example"),
            ("ailzhang", "torchhub_example", "master"),
        )

    @retry(Exception, tries=3)
    def test_load_commit_from_forked_repo(self):
        with self.assertRaisesRegex(ValueError, "If it's a commit from a forked repo"):
            torch.hub.load("pytorch/vision:4e2c216", "resnet18")

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="")
    def test_trust_repo_false_emptystring(self, patched_input):
        # The patched prompt answers with an empty string.
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        # A second attempt must prompt again and still leave the list empty.
        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="no")
    def test_trust_repo_false_no(self, patched_input):
        # The patched prompt answers "no".
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        # A second attempt must prompt again and still leave the list empty.
        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="y")
    def test_trusted_repo_false_yes(self, patched_input):
        # The patched prompt answers "y", so the repo is recorded as trusted.
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False)
        self._assert_in_trusted_list("ailzhang_torchhub_example")
        patched_input.assert_called_once()

        # Loading a second time with "check", we don't ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        patched_input.assert_not_called()

        # Loading again with False, we still ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False)
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="no")
    def test_trust_repo_check_no(self, patched_input):
        # The patched prompt answers "no".
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check"
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        # A second attempt must prompt again and still leave the list empty
        # (same check as in test_trust_repo_false_no).
        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check"
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="y")
    def test_trust_repo_check_yes(self, patched_input):
        # The patched prompt answers "y", so the repo is recorded as trusted.
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        self._assert_in_trusted_list("ailzhang_torchhub_example")
        patched_input.assert_called_once()

        # Loading a second time with "check", we don't ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        patched_input.assert_not_called()

    @retry(Exception, tries=3)
    def test_trust_repo_true(self):
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True)
        self._assert_in_trusted_list("ailzhang_torchhub_example")

    @retry(Exception, tries=3)
    def test_trust_repo_builtin_trusted_owners(self):
        # Loading pytorch/vision with "check" is expected to succeed without
        # adding anything to the trusted list.
        torch.hub.load("pytorch/vision", "resnet18", trust_repo="check")
        self._assert_trusted_list_is_empty()

    @retry(Exception, tries=3)
    def test_trust_repo_none(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=None
            )
            # Exactly one UserWarning about the untrusted repository.
            self.assertEqual(len(w), 1)
            self.assertTrue(issubclass(w[-1].category, UserWarning))
            self.assertIn(
                "You are about to download and run code from an untrusted repository",
                str(w[-1].message),
            )

        self._assert_trusted_list_is_empty()

    @retry(Exception, tries=3)
    def test_trust_repo_legacy(self):
        # We first download a repo and then delete the allowlist file
        # Then we check that the repo is indeed trusted without a prompt,
        # because it was already downloaded in the past.
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True)
        os.remove(self.trusted_list_path)

        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")

        self._assert_trusted_list_is_empty()
310