# Owner(s): ["module: hub"]

import os
import tempfile
import unittest
import warnings
from unittest.mock import patch

import torch
import torch.hub as hub
from torch.testing._internal.common_utils import IS_SANDCASTLE, retry, TestCase


def sum_of_state_dict(state_dict):
    """Return the sum of every element of every value in ``state_dict``.

    Each value must provide ``.sum()`` (in these tests they are tensors from a
    model ``state_dict()`` or a loaded checkpoint). An empty mapping yields 0.
    """
    s = 0
    for v in state_dict.values():
        s += v.sum()
    return s


# Expected ``sum_of_state_dict`` result for every checkpoint/model loaded from
# the example repo in the tests below.
SUM_OF_HUB_EXAMPLE = 431080
# Release asset downloaded directly by the URL-based tests.
TORCHHUB_EXAMPLE_RELEASE_URL = (
    "https://github.com/ailzhang/torchhub_example/releases/download/0.1/mnist_init_ones"
)


@unittest.skipIf(IS_SANDCASTLE, "Sandcastle cannot ping external")
class TestHub(TestCase):
    """Tests for ``torch.hub`` loading, caching and repo-trust handling.

    Most tests fetch ``ailzhang/torchhub_example`` (or ``pytorch/vision``)
    over the network. They are skipped on Sandcastle and wrapped in
    ``retry(Exception, tries=3)``, presumably to tolerate transient network
    failures.
    """

    def setUp(self):
        super().setUp()
        # Redirect the hub to a fresh temporary directory so each test starts
        # with an empty cache and trusted_list and never touches the user's.
        self.previous_hub_dir = torch.hub.get_dir()
        self.tmpdir = tempfile.TemporaryDirectory("hub_dir")
        torch.hub.set_dir(self.tmpdir.name)
        self.trusted_list_path = os.path.join(torch.hub.get_dir(), "trusted_list")

    def tearDown(self):
        try:
            super().tearDown()
        finally:
            # Restore the global hub dir and delete the temp dir even if the
            # base-class tearDown raises.
            torch.hub.set_dir(self.previous_hub_dir)
            self.tmpdir.cleanup()

    def _assert_trusted_list_is_empty(self):
        """Fail unless the hub's ``trusted_list`` file exists and has no lines."""
        with open(self.trusted_list_path) as f:
            lines = f.readlines()
        self.assertFalse(lines, f"trusted_list is not empty: {lines!r}")

    def _assert_in_trusted_list(self, line):
        """Fail unless ``line`` is one of the (stripped) ``trusted_list`` entries."""
        with open(self.trusted_list_path) as f:
            entries = [entry.strip() for entry in f]
        self.assertIn(line, entries)

    @retry(Exception, tries=3)
    def test_load_from_github(self):
        """Load the ``mnist`` entrypoint with an explicit ``source="github"``."""
        hub_model = hub.load(
            "ailzhang/torchhub_example",
            "mnist",
            source="github",
            pretrained=True,
            verbose=False,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_from_local_dir(self):
        """Populate the cache, then load from that directory with ``source="local"``."""
        local_dir = hub._get_cache_or_reload(
            "ailzhang/torchhub_example",
            force_reload=False,
            trust_repo=True,
            calling_fn=None,
        )
        hub_model = hub.load(
            local_dir, "mnist", source="local", pretrained=True, verbose=False
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_from_branch(self):
        """Load from a branch whose name contains a slash (``ci/test_slash``)."""
        hub_model = hub.load(
            "ailzhang/torchhub_example:ci/test_slash",
            "mnist",
            pretrained=True,
            verbose=False,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_get_set_dir(self):
        """``set_dir`` changes where repos are cached and expands ``~``."""
        previous_hub_dir = torch.hub.get_dir()
        with tempfile.TemporaryDirectory("hub_dir") as tmpdir:
            torch.hub.set_dir(tmpdir)
            self.assertEqual(torch.hub.get_dir(), tmpdir)
            self.assertNotEqual(previous_hub_dir, tmpdir)

            hub_model = hub.load(
                "ailzhang/torchhub_example", "mnist", pretrained=True, verbose=False
            )
            self.assertEqual(
                sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE
            )
            # The repo must have been cached under the newly set directory.
            self.assertTrue(
                os.path.exists(
                    os.path.join(tmpdir, "ailzhang_torchhub_example_master")
                )
            )

        # Test that set_dir properly calls expanduser()
        # non-regression test for https://github.com/pytorch/pytorch/issues/69761
        new_dir = os.path.join("~", "hub")
        torch.hub.set_dir(new_dir)
        self.assertEqual(torch.hub.get_dir(), os.path.expanduser(new_dir))

    @retry(Exception, tries=3)
    def test_list_entrypoints(self):
        """``hub.list`` reports the ``mnist`` entrypoint of the example repo."""
        entry_lists = hub.list("ailzhang/torchhub_example", trust_repo=True)
        self.assertObjectIn("mnist", entry_lists)

    @retry(Exception, tries=3)
    def test_download_url_to_file(self):
        """Download a checkpoint, check its contents and its file permissions."""
        with tempfile.TemporaryDirectory() as tmpdir:
            f = os.path.join(tmpdir, "temp")
            hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL, f, progress=False)
            loaded_state = torch.load(f)
            self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)
            # Check that the downloaded file has default file permissions:
            # compare against a freshly created reference file in the same dir.
            f_ref = os.path.join(tmpdir, "reference")
            with open(f_ref, "w"):
                pass
            expected_permissions = oct(os.stat(f_ref).st_mode & 0o777)
            actual_permissions = oct(os.stat(f).st_mode & 0o777)
            self.assertEqual(actual_permissions, expected_permissions)

    @retry(Exception, tries=3)
    def test_load_state_dict_from_url(self):
        """``load_state_dict_from_url`` with default args, ``file_name`` and ``weights_only``."""
        loaded_state = hub.load_state_dict_from_url(TORCHHUB_EXAMPLE_RELEASE_URL)
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

        # with name: the checkpoint is stored under <hub_dir>/checkpoints/<file_name>
        file_name = "the_file_name"
        loaded_state = hub.load_state_dict_from_url(
            TORCHHUB_EXAMPLE_RELEASE_URL, file_name=file_name
        )
        expected_file_path = os.path.join(torch.hub.get_dir(), "checkpoints", file_name)
        self.assertTrue(os.path.exists(expected_file_path))
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

        # with safe weight_only
        loaded_state = hub.load_state_dict_from_url(
            TORCHHUB_EXAMPLE_RELEASE_URL, weights_only=True
        )
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_legacy_zip_checkpoint(self):
        """Loading the legacy ``mnist_zip`` checkpoint works and emits a deprecation warning."""
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")
            hub_model = hub.load(
                "ailzhang/torchhub_example", "mnist_zip", pretrained=True, verbose=False
            )
            self.assertEqual(
                sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE
            )
            self.assertTrue(
                any(
                    "will be deprecated in favor of default zipfile" in str(w)
                    for w in ws
                )
            )

    # Test the default zipfile serialization format produced by >=1.6 release.
    @retry(Exception, tries=3)
    def test_load_zip_1_6_checkpoint(self):
        hub_model = hub.load(
            "ailzhang/torchhub_example",
            "mnist_zip_1_6",
            pretrained=True,
            verbose=False,
            trust_repo=True,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_hub_parse_repo_info(self):
        # If the branch is specified we just parse the input and return
        self.assertEqual(torch.hub._parse_repo_info("a/b:c"), ("a", "b", "c"))
        # For torchvision, the default branch is main
        self.assertEqual(
            torch.hub._parse_repo_info("pytorch/vision"), ("pytorch", "vision", "main")
        )
        # For the torchhub_example repo, the default branch is still master
        self.assertEqual(
            torch.hub._parse_repo_info("ailzhang/torchhub_example"),
            ("ailzhang", "torchhub_example", "master"),
        )

    @retry(Exception, tries=3)
    def test_load_commit_from_forked_repo(self):
        """Loading ``pytorch/vision:4e2c216`` raises a ValueError about forked-repo commits."""
        with self.assertRaisesRegex(ValueError, "If it's a commit from a forked repo"):
            torch.hub.load("pytorch/vision:4e2c216", "resnet18")

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="")
    def test_trust_repo_false_emptystring(self, patched_input):
        """With ``trust_repo=False``, an empty answer rejects the repo every time."""
        with self.assertRaisesRegex(Exception, r"Untrusted repository\."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, r"Untrusted repository\."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="no")
    def test_trust_repo_false_no(self, patched_input):
        """With ``trust_repo=False``, answering "no" rejects the repo every time."""
        with self.assertRaisesRegex(Exception, r"Untrusted repository\."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, r"Untrusted repository\."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="y")
    def test_trusted_repo_false_yes(self, patched_input):
        """Answering "y" records the repo; ``False`` still prompts, ``"check"`` does not."""
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False)
        self._assert_in_trusted_list("ailzhang_torchhub_example")
        patched_input.assert_called_once()

        # Loading a second time with "check", we don't ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        patched_input.assert_not_called()

        # Loading again with False, we still ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False)
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="no")
    def test_trust_repo_check_no(self, patched_input):
        """With ``trust_repo="check"``, answering "no" rejects and re-prompts next time."""
        with self.assertRaisesRegex(Exception, r"Untrusted repository\."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check"
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, r"Untrusted repository\."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check"
            )
        # A rejected repo must still not have been recorded as trusted.
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="y")
    def test_trust_repo_check_yes(self, patched_input):
        """With ``trust_repo="check"``, answering "y" records the repo; no later prompt."""
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        self._assert_in_trusted_list("ailzhang_torchhub_example")
        patched_input.assert_called_once()

        # Loading a second time with "check", we don't ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        patched_input.assert_not_called()

    @retry(Exception, tries=3)
    def test_trust_repo_true(self):
        """``trust_repo=True`` records the repo in ``trusted_list`` without prompting."""
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True)
        self._assert_in_trusted_list("ailzhang_torchhub_example")

    @retry(Exception, tries=3)
    def test_trust_repo_builtin_trusted_owners(self):
        """A ``pytorch/...`` repo loads under ``"check"`` without being added to the list."""
        torch.hub.load("pytorch/vision", "resnet18", trust_repo="check")
        self._assert_trusted_list_is_empty()

    @retry(Exception, tries=3)
    def test_trust_repo_none(self):
        """``trust_repo=None`` emits exactly one UserWarning and records nothing."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=None
            )
            self.assertEqual(len(w), 1)
            self.assertTrue(issubclass(w[-1].category, UserWarning))
            self.assertIn(
                "You are about to download and run code from an untrusted repository",
                str(w[-1].message),
            )

        self._assert_trusted_list_is_empty()

    @retry(Exception, tries=3)
    def test_trust_repo_legacy(self):
        # We first download a repo and then delete the allowlist file
        # Then we check that the repo is indeed trusted without a prompt,
        # because it was already downloaded in the past.
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True)
        os.remove(self.trusted_list_path)

        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")

        self._assert_trusted_list_is_empty()