1# Owner(s): ["module: hub"] 2 3import os 4import tempfile 5import unittest 6import warnings 7from unittest.mock import patch 8 9import torch 10import torch.hub as hub 11from torch.testing._internal.common_utils import IS_SANDCASTLE, retry, TestCase 12 13 14def sum_of_state_dict(state_dict): 15 s = 0 16 for v in state_dict.values(): 17 s += v.sum() 18 return s 19 20 21SUM_OF_HUB_EXAMPLE = 431080 22TORCHHUB_EXAMPLE_RELEASE_URL = ( 23 "https://github.com/ailzhang/torchhub_example/releases/download/0.1/mnist_init_ones" 24) 25 26 27@unittest.skipIf(IS_SANDCASTLE, "Sandcastle cannot ping external") 28class TestHub(TestCase): 29 def setUp(self): 30 super().setUp() 31 self.previous_hub_dir = torch.hub.get_dir() 32 self.tmpdir = tempfile.TemporaryDirectory("hub_dir") 33 torch.hub.set_dir(self.tmpdir.name) 34 self.trusted_list_path = os.path.join(torch.hub.get_dir(), "trusted_list") 35 36 def tearDown(self): 37 super().tearDown() 38 torch.hub.set_dir(self.previous_hub_dir) # probably not needed, but can't hurt 39 self.tmpdir.cleanup() 40 41 def _assert_trusted_list_is_empty(self): 42 with open(self.trusted_list_path) as f: 43 assert not f.readlines() 44 45 def _assert_in_trusted_list(self, line): 46 with open(self.trusted_list_path) as f: 47 assert line in (l.strip() for l in f) 48 49 @retry(Exception, tries=3) 50 def test_load_from_github(self): 51 hub_model = hub.load( 52 "ailzhang/torchhub_example", 53 "mnist", 54 source="github", 55 pretrained=True, 56 verbose=False, 57 ) 58 self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE) 59 60 @retry(Exception, tries=3) 61 def test_load_from_local_dir(self): 62 local_dir = hub._get_cache_or_reload( 63 "ailzhang/torchhub_example", 64 force_reload=False, 65 trust_repo=True, 66 calling_fn=None, 67 ) 68 hub_model = hub.load( 69 local_dir, "mnist", source="local", pretrained=True, verbose=False 70 ) 71 self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE) 72 73 @retry(Exception, tries=3) 74 def test_load_from_branch(self): 75 hub_model = hub.load( 76 "ailzhang/torchhub_example:ci/test_slash", 77 "mnist", 78 pretrained=True, 79 verbose=False, 80 ) 81 self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE) 82 83 @retry(Exception, tries=3) 84 def test_get_set_dir(self): 85 previous_hub_dir = torch.hub.get_dir() 86 with tempfile.TemporaryDirectory("hub_dir") as tmpdir: 87 torch.hub.set_dir(tmpdir) 88 self.assertEqual(torch.hub.get_dir(), tmpdir) 89 self.assertNotEqual(previous_hub_dir, tmpdir) 90 91 hub_model = hub.load( 92 "ailzhang/torchhub_example", "mnist", pretrained=True, verbose=False 93 ) 94 self.assertEqual( 95 sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE 96 ) 97 assert os.path.exists( 98 os.path.join(tmpdir, "ailzhang_torchhub_example_master") 99 ) 100 101 # Test that set_dir properly calls expanduser() 102 # non-regression test for https://github.com/pytorch/pytorch/issues/69761 103 new_dir = os.path.join("~", "hub") 104 torch.hub.set_dir(new_dir) 105 self.assertEqual(torch.hub.get_dir(), os.path.expanduser(new_dir)) 106 107 @retry(Exception, tries=3) 108 def test_list_entrypoints(self): 109 entry_lists = hub.list("ailzhang/torchhub_example", trust_repo=True) 110 self.assertObjectIn("mnist", entry_lists) 111 112 @retry(Exception, tries=3) 113 def test_download_url_to_file(self): 114 with tempfile.TemporaryDirectory() as tmpdir: 115 f = os.path.join(tmpdir, "temp") 116 hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL, f, progress=False) 117 loaded_state = torch.load(f) 118 self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE) 119 # Check that the downloaded file has default file permissions 120 f_ref = os.path.join(tmpdir, "reference") 121 open(f_ref, "w").close() 122 expected_permissions = oct(os.stat(f_ref).st_mode & 0o777) 123 actual_permissions = oct(os.stat(f).st_mode & 0o777) 124 assert actual_permissions == expected_permissions 125 126 @retry(Exception, tries=3) 127 def test_load_state_dict_from_url(self): 128 loaded_state = hub.load_state_dict_from_url(TORCHHUB_EXAMPLE_RELEASE_URL) 129 self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE) 130 131 # with name 132 file_name = "the_file_name" 133 loaded_state = hub.load_state_dict_from_url( 134 TORCHHUB_EXAMPLE_RELEASE_URL, file_name=file_name 135 ) 136 expected_file_path = os.path.join(torch.hub.get_dir(), "checkpoints", file_name) 137 self.assertTrue(os.path.exists(expected_file_path)) 138 self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE) 139 140 # with safe weight_only 141 loaded_state = hub.load_state_dict_from_url( 142 TORCHHUB_EXAMPLE_RELEASE_URL, weights_only=True 143 ) 144 self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE) 145 146 @retry(Exception, tries=3) 147 def test_load_legacy_zip_checkpoint(self): 148 with warnings.catch_warnings(record=True) as ws: 149 warnings.simplefilter("always") 150 hub_model = hub.load( 151 "ailzhang/torchhub_example", "mnist_zip", pretrained=True, verbose=False 152 ) 153 self.assertEqual( 154 sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE 155 ) 156 assert any( 157 "will be deprecated in favor of default zipfile" in str(w) for w in ws 158 ) 159 160 # Test the default zipfile serialization format produced by >=1.6 release. 161 @retry(Exception, tries=3) 162 def test_load_zip_1_6_checkpoint(self): 163 hub_model = hub.load( 164 "ailzhang/torchhub_example", 165 "mnist_zip_1_6", 166 pretrained=True, 167 verbose=False, 168 trust_repo=True, 169 ) 170 self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE) 171 172 @retry(Exception, tries=3) 173 def test_hub_parse_repo_info(self): 174 # If the branch is specified we just parse the input and return 175 self.assertEqual(torch.hub._parse_repo_info("a/b:c"), ("a", "b", "c")) 176 # For torchvision, the default branch is main 177 self.assertEqual( 178 torch.hub._parse_repo_info("pytorch/vision"), ("pytorch", "vision", "main") 179 ) 180 # For the torchhub_example repo, the default branch is still master 181 self.assertEqual( 182 torch.hub._parse_repo_info("ailzhang/torchhub_example"), 183 ("ailzhang", "torchhub_example", "master"), 184 ) 185 186 @retry(Exception, tries=3) 187 def test_load_commit_from_forked_repo(self): 188 with self.assertRaisesRegex(ValueError, "If it's a commit from a forked repo"): 189 torch.hub.load("pytorch/vision:4e2c216", "resnet18") 190 191 @retry(Exception, tries=3) 192 @patch("builtins.input", return_value="") 193 def test_trust_repo_false_emptystring(self, patched_input): 194 with self.assertRaisesRegex(Exception, "Untrusted repository."): 195 torch.hub.load( 196 "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False 197 ) 198 self._assert_trusted_list_is_empty() 199 patched_input.assert_called_once() 200 201 patched_input.reset_mock() 202 with self.assertRaisesRegex(Exception, "Untrusted repository."): 203 torch.hub.load( 204 "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False 205 ) 206 self._assert_trusted_list_is_empty() 207 patched_input.assert_called_once() 208 209 @retry(Exception, tries=3) 210 @patch("builtins.input", return_value="no") 211 def test_trust_repo_false_no(self, patched_input): 212 with self.assertRaisesRegex(Exception, "Untrusted repository."): 213 torch.hub.load( 214 "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False 215 ) 216 self._assert_trusted_list_is_empty() 217 patched_input.assert_called_once() 218 219 patched_input.reset_mock() 220 with self.assertRaisesRegex(Exception, "Untrusted repository."): 221 torch.hub.load( 222 "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False 223 ) 224 self._assert_trusted_list_is_empty() 225 patched_input.assert_called_once() 226 227 @retry(Exception, tries=3) 228 @patch("builtins.input", return_value="y") 229 def test_trusted_repo_false_yes(self, patched_input): 230 torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False) 231 self._assert_in_trusted_list("ailzhang_torchhub_example") 232 patched_input.assert_called_once() 233 234 # Loading a second time with "check", we don't ask for user input 235 patched_input.reset_mock() 236 torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check") 237 patched_input.assert_not_called() 238 239 # Loading again with False, we still ask for user input 240 patched_input.reset_mock() 241 torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False) 242 patched_input.assert_called_once() 243 244 @retry(Exception, tries=3) 245 @patch("builtins.input", return_value="no") 246 def test_trust_repo_check_no(self, patched_input): 247 with self.assertRaisesRegex(Exception, "Untrusted repository."): 248 torch.hub.load( 249 "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check" 250 ) 251 self._assert_trusted_list_is_empty() 252 patched_input.assert_called_once() 253 254 patched_input.reset_mock() 255 with self.assertRaisesRegex(Exception, "Untrusted repository."): 256 torch.hub.load( 257 "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check" 258 ) 259 patched_input.assert_called_once() 260 261 @retry(Exception, tries=3) 262 @patch("builtins.input", return_value="y") 263 def test_trust_repo_check_yes(self, patched_input): 264 torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check") 265 self._assert_in_trusted_list("ailzhang_torchhub_example") 266 patched_input.assert_called_once() 267 268 # Loading a second time with "check", we don't ask for user input 269 patched_input.reset_mock() 270 torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check") 271 patched_input.assert_not_called() 272 273 @retry(Exception, tries=3) 274 def test_trust_repo_true(self): 275 torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True) 276 self._assert_in_trusted_list("ailzhang_torchhub_example") 277 278 @retry(Exception, tries=3) 279 def test_trust_repo_builtin_trusted_owners(self): 280 torch.hub.load("pytorch/vision", "resnet18", trust_repo="check") 281 self._assert_trusted_list_is_empty() 282 283 @retry(Exception, tries=3) 284 def test_trust_repo_none(self): 285 with warnings.catch_warnings(record=True) as w: 286 warnings.simplefilter("always") 287 torch.hub.load( 288 "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=None 289 ) 290 assert len(w) == 1 291 assert issubclass(w[-1].category, UserWarning) 292 assert ( 293 "You are about to download and run code from an untrusted repository" 294 in str(w[-1].message) 295 ) 296 297 self._assert_trusted_list_is_empty() 298 299 @retry(Exception, tries=3) 300 def test_trust_repo_legacy(self): 301 # We first download a repo and then delete the allowlist file 302 # Then we check that the repo is indeed trusted without a prompt, 303 # because it was already downloaded in the past. 304 torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True) 305 os.remove(self.trusted_list_path) 306 307 torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check") 308 309 self._assert_trusted_list_is_empty() 310