xref: /aosp_15_r20/external/pytorch/test/test_hub.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: hub"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport tempfile
5*da0073e9SAndroid Build Coastguard Workerimport unittest
6*da0073e9SAndroid Build Coastguard Workerimport warnings
7*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerimport torch.hub as hub
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import IS_SANDCASTLE, retry, TestCase
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
def sum_of_state_dict(state_dict):
    """Return the sum of every element of every value in ``state_dict``.

    The tests use this as a cheap fingerprint of downloaded weights, comparing
    it against ``SUM_OF_HUB_EXAMPLE``.

    Args:
        state_dict: mapping whose values support ``.sum()`` (here, the tensors
            of a module ``state_dict()`` or of a loaded checkpoint).

    Returns:
        The total as a 0-dim tensor, or the int ``0`` if ``state_dict`` is empty.
    """
    # Accumulate with out-of-place ``+`` (what builtin ``sum`` does) instead of
    # ``s += ...``: once ``s`` is a tensor, ``+=`` is an in-place add that pins
    # the accumulator to the first tensor's dtype. That raises when an integer
    # tensor is followed by a floating-point one, and loses precision when the
    # first tensor has a low-precision dtype.
    return sum(v.sum() for v in state_dict.values())
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
# Expected element sum (as computed by sum_of_state_dict) of the example
# checkpoints published in the ailzhang/torchhub_example repository; the tests
# use it as a fingerprint of correctly downloaded weights.
SUM_OF_HUB_EXAMPLE = 431_080

# Direct download link to a release asset of the same example repository.
# Its contents are expected to sum to SUM_OF_HUB_EXAMPLE as well.
TORCHHUB_EXAMPLE_RELEASE_URL = (
    "https://github.com/ailzhang/torchhub_example"
    "/releases/download/0.1/mnist_init_ones"
)
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(IS_SANDCASTLE, "Sandcastle cannot ping external")
class TestHub(TestCase):
    """Tests for ``torch.hub``; most of them download from GitHub.

    Each test runs against a fresh temporary hub directory installed in
    ``setUp``, so cached repos, checkpoints and the ``trusted_list`` file do not
    leak between tests. Network-bound tests are wrapped in ``retry`` to absorb
    transient failures, and the whole class is skipped on Sandcastle.
    """

    def setUp(self):
        super().setUp()
        self.previous_hub_dir = torch.hub.get_dir()
        self.tmpdir = tempfile.TemporaryDirectory("hub_dir")
        # Cleanups (unlike tearDown) still run if a later line of setUp raises,
        # so the temporary directory is never leaked and the global hub dir is
        # always restored. They run LIFO: the hub dir is restored first, then
        # the temporary directory is deleted.
        self.addCleanup(self.tmpdir.cleanup)
        self.addCleanup(torch.hub.set_dir, self.previous_hub_dir)
        torch.hub.set_dir(self.tmpdir.name)
        self.trusted_list_path = os.path.join(torch.hub.get_dir(), "trusted_list")

    def _assert_trusted_list_is_empty(self):
        """Fail unless the hub's ``trusted_list`` file exists and has no lines."""
        with open(self.trusted_list_path) as f:
            lines = f.readlines()
        # self.assert* instead of bare ``assert`` so the check survives ``python -O``.
        self.assertFalse(lines, f"expected an empty trusted_list, got {lines!r}")

    def _assert_in_trusted_list(self, line):
        """Fail unless ``line`` is one of the whitespace-stripped trusted_list entries."""
        with open(self.trusted_list_path) as f:
            entries = [entry.strip() for entry in f]
        self.assertIn(line, entries)

    @retry(Exception, tries=3)
    def test_load_from_github(self):
        hub_model = hub.load(
            "ailzhang/torchhub_example",
            "mnist",
            source="github",
            pretrained=True,
            verbose=False,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_from_local_dir(self):
        local_dir = hub._get_cache_or_reload(
            "ailzhang/torchhub_example",
            force_reload=False,
            trust_repo=True,
            calling_fn=None,
        )
        hub_model = hub.load(
            local_dir, "mnist", source="local", pretrained=True, verbose=False
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_from_branch(self):
        hub_model = hub.load(
            "ailzhang/torchhub_example:ci/test_slash",
            "mnist",
            pretrained=True,
            verbose=False,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_get_set_dir(self):
        previous_hub_dir = torch.hub.get_dir()
        with tempfile.TemporaryDirectory("hub_dir") as tmpdir:
            torch.hub.set_dir(tmpdir)
            self.assertEqual(torch.hub.get_dir(), tmpdir)
            self.assertNotEqual(previous_hub_dir, tmpdir)

            hub_model = hub.load(
                "ailzhang/torchhub_example", "mnist", pretrained=True, verbose=False
            )
            self.assertEqual(
                sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE
            )
            self.assertTrue(
                os.path.exists(os.path.join(tmpdir, "ailzhang_torchhub_example_master"))
            )

        # Test that set_dir properly calls expanduser()
        # non-regression test for https://github.com/pytorch/pytorch/issues/69761
        new_dir = os.path.join("~", "hub")
        torch.hub.set_dir(new_dir)
        self.assertEqual(torch.hub.get_dir(), os.path.expanduser(new_dir))

    @retry(Exception, tries=3)
    def test_list_entrypoints(self):
        entry_lists = hub.list("ailzhang/torchhub_example", trust_repo=True)
        # Entrypoint names must be compared by equality, not object identity.
        self.assertIn("mnist", entry_lists)

    @retry(Exception, tries=3)
    def test_download_url_to_file(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            f = os.path.join(tmpdir, "temp")
            hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL, f, progress=False)
            loaded_state = torch.load(f)
            self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)
            # Check that the downloaded file has default file permissions
            f_ref = os.path.join(tmpdir, "reference")
            open(f_ref, "w").close()
            expected_permissions = oct(os.stat(f_ref).st_mode & 0o777)
            actual_permissions = oct(os.stat(f).st_mode & 0o777)
            self.assertEqual(actual_permissions, expected_permissions)

    @retry(Exception, tries=3)
    def test_load_state_dict_from_url(self):
        loaded_state = hub.load_state_dict_from_url(TORCHHUB_EXAMPLE_RELEASE_URL)
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

        # with name
        file_name = "the_file_name"
        loaded_state = hub.load_state_dict_from_url(
            TORCHHUB_EXAMPLE_RELEASE_URL, file_name=file_name
        )
        expected_file_path = os.path.join(torch.hub.get_dir(), "checkpoints", file_name)
        self.assertTrue(os.path.exists(expected_file_path))
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

        # with safe weight_only
        loaded_state = hub.load_state_dict_from_url(
            TORCHHUB_EXAMPLE_RELEASE_URL, weights_only=True
        )
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_load_legacy_zip_checkpoint(self):
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")
            hub_model = hub.load(
                "ailzhang/torchhub_example", "mnist_zip", pretrained=True, verbose=False
            )
            self.assertEqual(
                sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE
            )
            self.assertTrue(
                any(
                    "will be deprecated in favor of default zipfile" in str(w)
                    for w in ws
                ),
                "expected a legacy-zipfile deprecation warning",
            )

    # Test the default zipfile serialization format produced by >=1.6 release.
    @retry(Exception, tries=3)
    def test_load_zip_1_6_checkpoint(self):
        hub_model = hub.load(
            "ailzhang/torchhub_example",
            "mnist_zip_1_6",
            pretrained=True,
            verbose=False,
            trust_repo=True,
        )
        self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)

    @retry(Exception, tries=3)
    def test_hub_parse_repo_info(self):
        # If the branch is specified we just parse the input and return
        self.assertEqual(torch.hub._parse_repo_info("a/b:c"), ("a", "b", "c"))
        # For torchvision, the default branch is main
        self.assertEqual(
            torch.hub._parse_repo_info("pytorch/vision"), ("pytorch", "vision", "main")
        )
        # For the torchhub_example repo, the default branch is still master
        self.assertEqual(
            torch.hub._parse_repo_info("ailzhang/torchhub_example"),
            ("ailzhang", "torchhub_example", "master"),
        )

    @retry(Exception, tries=3)
    def test_load_commit_from_forked_repo(self):
        with self.assertRaisesRegex(ValueError, "If it's a commit from a forked repo"):
            torch.hub.load("pytorch/vision:4e2c216", "resnet18")

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="")
    def test_trust_repo_false_emptystring(self, patched_input):
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="no")
    def test_trust_repo_false_no(self, patched_input):
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="y")
    def test_trusted_repo_false_yes(self, patched_input):
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False)
        self._assert_in_trusted_list("ailzhang_torchhub_example")
        patched_input.assert_called_once()

        # Loading a second time with "check", we don't ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        patched_input.assert_not_called()

        # Loading again with False, we still ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=False)
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="no")
    def test_trust_repo_check_no(self, patched_input):
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check"
            )
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

        patched_input.reset_mock()
        with self.assertRaisesRegex(Exception, "Untrusted repository."):
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check"
            )
        # A second refusal must not add the repo to the trusted list either.
        self._assert_trusted_list_is_empty()
        patched_input.assert_called_once()

    @retry(Exception, tries=3)
    @patch("builtins.input", return_value="y")
    def test_trust_repo_check_yes(self, patched_input):
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        self._assert_in_trusted_list("ailzhang_torchhub_example")
        patched_input.assert_called_once()

        # Loading a second time with "check", we don't ask for user input
        patched_input.reset_mock()
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")
        patched_input.assert_not_called()

    @retry(Exception, tries=3)
    def test_trust_repo_true(self):
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True)
        self._assert_in_trusted_list("ailzhang_torchhub_example")

    @retry(Exception, tries=3)
    def test_trust_repo_builtin_trusted_owners(self):
        torch.hub.load("pytorch/vision", "resnet18", trust_repo="check")
        self._assert_trusted_list_is_empty()

    @retry(Exception, tries=3)
    def test_trust_repo_none(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            torch.hub.load(
                "ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=None
            )
        # Only count the untrusted-repository warning, so unrelated warnings
        # raised while downloading/loading cannot make this test flaky.
        untrusted_warnings = [
            warning
            for warning in w
            if issubclass(warning.category, UserWarning)
            and "You are about to download and run code from an untrusted repository"
            in str(warning.message)
        ]
        self.assertEqual(len(untrusted_warnings), 1)

        self._assert_trusted_list_is_empty()

    @retry(Exception, tries=3)
    def test_trust_repo_legacy(self):
        # We first download a repo and then delete the allowlist file
        # Then we check that the repo is indeed trusted without a prompt,
        # because it was already downloaded in the past.
        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo=True)
        os.remove(self.trusted_list_path)

        torch.hub.load("ailzhang/torchhub_example", "mnist_zip_1_6", trust_repo="check")

        self._assert_trusted_list_is_empty()
310