xref: /aosp_15_r20/external/pytorch/torch/hub.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3import errno
4import hashlib
5import json
6import os
7import re
8import shutil
9import sys
10import tempfile
11import uuid
12import warnings
13import zipfile
14from pathlib import Path
15from typing import Any, Dict, Optional
16from typing_extensions import deprecated
17from urllib.error import HTTPError, URLError
18from urllib.parse import urlparse  # noqa: F401
19from urllib.request import Request, urlopen
20
21import torch
22from torch.serialization import MAP_LOCATION
23
24
25class _Faketqdm:  # type: ignore[no-redef]
26    def __init__(self, total=None, disable=False, unit=None, *args, **kwargs):
27        self.total = total
28        self.disable = disable
29        self.n = 0
30        # Ignore all extra *args and **kwargs lest you want to reinvent tqdm
31
32    def update(self, n):
33        if self.disable:
34            return
35
36        self.n += n
37        if self.total is None:
38            sys.stderr.write(f"\r{self.n:.1f} bytes")
39        else:
40            sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%")
41        sys.stderr.flush()
42
43    # Don't bother implementing; use real tqdm if you want
44    def set_description(self, *args, **kwargs):
45        pass
46
47    def write(self, s):
48        sys.stderr.write(f"{s}\n")
49
50    def close(self):
51        self.disable = True
52
53    def __enter__(self):
54        return self
55
56    def __exit__(self, exc_type, exc_val, exc_tb):
57        if self.disable:
58            return
59
60        sys.stderr.write("\n")
61
62
# Prefer the real tqdm progress bar; fall back to the minimal stderr-based
# stand-in (_Faketqdm) when tqdm cannot be imported.
try:
    from tqdm import tqdm  # If tqdm is installed use it, otherwise use the fake wrapper
except ImportError:
    tqdm = _Faketqdm

# Names exported by ``from torch.hub import *``.
__all__ = [
    "download_url_to_file",
    "get_dir",
    "help",
    "list",
    "load",
    "load_state_dict_from_url",
    "set_dir",
]

# matches bfd8deac from resnet18-bfd8deac.pth
# (group 1 is the run of lowercase hex characters between "-" and ".")
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")

# Repo owners that _check_repo_is_trusted treats as trusted without
# consulting the on-disk trusted list.
_TRUSTED_REPO_OWNERS = (
    "facebookresearch",
    "facebookincubator",
    "pytorch",
    "fairinternal",
)
# Environment variable holding an optional GitHub API token.
ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
# Environment variables consulted by _get_torch_home, and the fallback
# cache root used when XDG_CACHE_HOME is unset.
ENV_TORCH_HOME = "TORCH_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"
# Attribute of a hubconf module that lists its required packages.
VAR_DEPENDENCY = "dependencies"
# File name of the entry-point module expected in a hub repo.
MODULE_HUBCONF = "hubconf.py"
# Number of bytes requested per read while downloading (128 KiB).
READ_DATA_CHUNK = 128 * 1024
# Hub directory override set via set_dir(); None means "use the default".
_hub_dir: Optional[str] = None
95
96
97@contextlib.contextmanager
98def _add_to_sys_path(path):
99    sys.path.insert(0, path)
100    try:
101        yield
102    finally:
103        sys.path.remove(path)
104
105
106# Copied from tools/shared/module_loader to be included in torch package
107def _import_module(name, path):
108    import importlib.util
109    from importlib.abc import Loader
110
111    spec = importlib.util.spec_from_file_location(name, path)
112    assert spec is not None
113    module = importlib.util.module_from_spec(spec)
114    assert isinstance(spec.loader, Loader)
115    spec.loader.exec_module(module)
116    return module
117
118
119def _remove_if_exists(path):
120    if os.path.exists(path):
121        if os.path.isfile(path):
122            os.remove(path)
123        else:
124            shutil.rmtree(path)
125
126
127def _git_archive_link(repo_owner, repo_name, ref):
128    # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip
129    return f"https://github.com/{repo_owner}/{repo_name}/zipball/{ref}"
130
131
132def _load_attr_from_module(module, func_name):
133    # Check if callable is defined in the module
134    if func_name not in dir(module):
135        return None
136    return getattr(module, func_name)
137
138
def _get_torch_home():
    """Return the torch home directory with ``~`` expanded.

    The ``TORCH_HOME`` environment variable wins when set; otherwise the
    result is ``<cache root>/torch`` where the cache root comes from
    ``XDG_CACHE_HOME`` or, failing that, ``DEFAULT_CACHE_DIR``.
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(cache_root, "torch")
    return os.path.expanduser(os.getenv(ENV_TORCH_HOME, fallback_home))
147
148
149def _parse_repo_info(github):
150    if ":" in github:
151        repo_info, ref = github.split(":")
152    else:
153        repo_info, ref = github, None
154    repo_owner, repo_name = repo_info.split("/")
155
156    if ref is None:
157        # The ref wasn't specified by the user, so we need to figure out the
158        # default branch: main or master. Our assumption is that if main exists
159        # then it's the default branch, otherwise it's master.
160        try:
161            with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
162                ref = "main"
163        except HTTPError as e:
164            if e.code == 404:
165                ref = "master"
166            else:
167                raise
168        except URLError as e:
169            # No internet connection, need to check for cache as last resort
170            for possible_ref in ("main", "master"):
171                if os.path.exists(
172                    f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"
173                ):
174                    ref = possible_ref
175                    break
176            if ref is None:
177                raise RuntimeError(
178                    "It looks like there is no internet connection and the "
179                    f"repo could not be found in the cache ({get_dir()})"
180                ) from e
181    return repo_owner, repo_name, ref
182
183
184def _read_url(url):
185    with urlopen(url) as r:
186        return r.read().decode(r.headers.get_content_charset("utf-8"))
187
188
def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
    """Check that ``ref`` names a branch or tag of ``repo_owner/repo_name``.

    The GitHub API listings of branches and then tags are paged through,
    100 entries at a time. The function returns as soon as an entry's name
    equals ``ref`` or its commit SHA starts with ``ref``.

    Raises:
        ValueError: no branch or tag matched ``ref``.
    """
    # Use urlopen to avoid depending on local git.
    headers = {"Accept": "application/vnd.github.v3+json"}
    token = os.environ.get(ENV_GITHUB_TOKEN)
    if token is not None:
        headers["Authorization"] = f"token {token}"

    api_root = f"https://api.github.com/repos/{repo_owner}/{repo_name}"
    for listing in ("branches", "tags"):
        page = 1
        while True:
            url = f"{api_root}/{listing}?per_page=100&page={page}"
            entries = json.loads(_read_url(Request(url, headers=headers)))
            # Empty response means no more data to process
            if not entries:
                break
            if any(
                item["name"] == ref or item["commit"]["sha"].startswith(ref)
                for item in entries
            ):
                return
            page += 1

    raise ValueError(
        f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. "
        "If it's a commit from a forked repo, please call hub.load() with forked repo directly."
    )
215
216
def _get_cache_or_reload(
    github,
    force_reload,
    trust_repo,
    calling_fn,
    verbose=True,
    skip_validation=False,
):
    """Return the local directory holding the requested repo, downloading it if needed.

    Args:
        github: repo spec of the form ``repo_owner/repo_name[:ref]``.
        force_reload: when true, ignore any cached copy and download again.
        trust_repo: forwarded to ``_check_repo_is_trusted``.
        calling_fn: name of the public function on whose behalf this runs;
            forwarded to ``_check_repo_is_trusted``.
        verbose: when true, report on stderr that a cached copy is used.
        skip_validation: when true, skip ``_validate_not_a_forked_repo``.

    Returns:
        Path of the repo directory inside the hub directory.
    """
    # Setup hub_dir to save downloaded files
    hub_dir = get_dir()
    os.makedirs(hub_dir, exist_ok=True)

    repo_owner, repo_name, ref = _parse_repo_info(github)
    # Github allows branch name with slash '/',
    # this causes confusion with path on both Linux and Windows.
    # Backslash is not allowed in Github branch name so no need to
    # to worry about it.
    flat_ref = ref.replace("/", "_")
    # Github renames folder repo-v1.x.x to repo-1.x.x
    # We don't know the repo name before downloading the zip file
    # and inspect name from it.
    # To check if cached repo exists, we need to normalize folder names.
    cache_key = f"{repo_owner}_{repo_name}_{flat_ref}"
    repo_dir = os.path.join(hub_dir, cache_key)

    # Check that the repo is in the trusted list
    _check_repo_is_trusted(
        repo_owner,
        repo_name,
        cache_key,
        trust_repo=trust_repo,
        calling_fn=calling_fn,
    )

    # Fast path: a cached checkout exists and no reload was requested.
    if not force_reload and os.path.exists(repo_dir):
        if verbose:
            sys.stderr.write(f"Using cache found in {repo_dir}\n")
        return repo_dir

    # Validate the tag/branch is from the original repo instead of a forked repo
    if not skip_validation:
        _validate_not_a_forked_repo(repo_owner, repo_name, ref)

    archive_path = os.path.join(hub_dir, f"{flat_ref}.zip")
    _remove_if_exists(archive_path)

    try:
        url = _git_archive_link(repo_owner, repo_name, ref)
        sys.stderr.write(f'Downloading: "{url}" to {archive_path}\n')
        download_url_to_file(url, archive_path, progress=False)
    except HTTPError as err:
        if err.code != 300:
            raise
        # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch
        # in the repo. This can be disambiguated by explicitly using refs/heads/ or refs/tags
        # See https://git-scm.com/book/en/v2/Git-Internals-Git-References
        # Here, we do the same as git: we throw a warning, and assume the user wanted the branch
        warnings.warn(
            f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? "
            "Torchhub will now assume that it's a branch. "
            "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or "
            "refs/tags/tag_name as the ref. That might require using skip_validation=True."
        )
        url = _git_archive_link(repo_owner, repo_name, ref=f"refs/heads/{ref}")
        download_url_to_file(url, archive_path, progress=False)

    with zipfile.ZipFile(archive_path) as archive:
        # The first archive entry is taken as the top-level folder name.
        top_level_name = archive.infolist()[0].filename
        unpacked_dir = os.path.join(hub_dir, top_level_name)
        _remove_if_exists(unpacked_dir)
        # Unzip the code and rename the base folder
        archive.extractall(hub_dir)

    _remove_if_exists(archive_path)
    _remove_if_exists(repo_dir)
    shutil.move(unpacked_dir, repo_dir)  # rename the repo

    return repo_dir
299
300
def _check_repo_is_trusted(
    repo_owner,
    repo_name,
    owner_name_branch,
    trust_repo,
    calling_fn="load",
):
    """Apply the ``trust_repo`` policy to a repo before its code is downloaded.

    A repo counts as already trusted when ``<owner>_<name>`` is a line of the
    ``trusted_list`` file in the hub directory, when ``owner_name_branch`` is
    an existing sub-directory of the hub directory, or when ``repo_owner`` is
    one of ``_TRUSTED_REPO_OWNERS``.

    Args:
        repo_owner: owner part of the repo spec.
        repo_name: name part of the repo spec.
        owner_name_branch: cache directory name for this repo and ref.
        trust_repo: ``None`` (only warn when untrusted), ``False`` (always
            prompt), ``"check"`` (prompt only when untrusted) or ``True``
            (trust without prompting).
        calling_fn: name of the public function, used in the warning text.

    Raises:
        Exception: the user declined the prompt.
        ValueError: the user's answer to the prompt was not recognised.
    """
    hub_dir = get_dir()
    filepath = os.path.join(hub_dir, "trusted_list")

    if not os.path.exists(filepath):
        Path(filepath).touch()
    with open(filepath) as file:
        trusted_repos = tuple(line.strip() for line in file)

    # To minimize friction of introducing the new trust_repo mechanism, we consider that
    # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
    trusted_repos_legacy = next(os.walk(hub_dir))[1]

    owner_name = "_".join([repo_owner, repo_name])
    is_trusted = (
        owner_name in trusted_repos
        or owner_name_branch in trusted_repos_legacy
        or repo_owner in _TRUSTED_REPO_OWNERS
    )

    # TODO: Remove `None` option in 2.0 and change the default to "check"
    if trust_repo is None:
        if not is_trusted:
            # Every fragment that mentions {calling_fn} must be an f-string;
            # the second fragment previously was not, so the placeholder was
            # shown to the user verbatim.
            warnings.warn(
                "You are about to download and run code from an untrusted repository. In a future release, this won't "
                f"be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., "
                "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
                f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
                f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
                "confirmation if the repo is not already trusted. This will eventually be the default behaviour"
            )
        return

    if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
        response = input(
            f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?"
        )
        if response.lower() in ("y", "yes"):
            if is_trusted:
                print("The repository is already trusted.")
        elif response.lower() in ("n", "no", ""):
            raise Exception("Untrusted repository.")  # noqa: TRY002
        else:
            raise ValueError(f"Unrecognized response {response}.")

    # At this point we're sure that the user trusts the repo (or wants to trust it)
    if not is_trusted:
        with open(filepath, "a") as file:
            file.write(owner_name + "\n")
357
358
359def _check_module_exists(name):
360    import importlib.util
361
362    return importlib.util.find_spec(name) is not None
363
364
def _check_dependencies(m):
    """Verify that every package named in ``m``'s dependency list is importable.

    Nothing is checked when the module has no such attribute (or it is ``None``).

    Raises:
        RuntimeError: one or more listed packages cannot be found.
    """
    dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)
    if dependencies is None:
        return

    missing_deps = []
    for pkg in dependencies:
        if not _check_module_exists(pkg):
            missing_deps.append(pkg)

    if missing_deps:
        raise RuntimeError(f"Missing dependencies: {', '.join(missing_deps)}")
372
373
def _load_entry_from_hubconf(m, model):
    """Look up the callable entrypoint ``model`` in the hubconf module ``m``.

    Raises:
        ValueError: ``model`` is not a string.
        RuntimeError: a declared dependency is missing, or ``m`` has no
            callable attribute named ``model``.
    """
    if not isinstance(model, str):
        raise ValueError("Invalid input: model should be a string of function name")

    # Note that if a missing dependency is imported at top level of hubconf, it will
    # throw before this function. It's a chicken and egg situation where we have to
    # load hubconf to know what're the dependencies, but to import hubconf it requires
    # a missing package. This is fine, Python will throw proper error message for users.
    _check_dependencies(m)

    entry = _load_attr_from_module(m, model)
    # A missing attribute comes back as None, which is not callable either.
    if not callable(entry):
        raise RuntimeError(f"Cannot find callable {model} in hubconf")
    return entry
390
391
def get_dir():
    r"""
    Return the Torch Hub cache directory where downloaded models & weights are stored.

    A directory set through :func:`~torch.hub.set_dir` takes precedence.
    Otherwise the result is ``$TORCH_HOME/hub``, where ``$TORCH_HOME``
    defaults to ``$XDG_CACHE_HOME/torch`` and ``$XDG_CACHE_HOME`` in turn
    defaults to ``~/.cache`` when the environment variable is not set.
    """
    # Issue warning to move data if old env is set
    if os.getenv("TORCH_HUB"):
        warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead")

    return _hub_dir if _hub_dir is not None else os.path.join(_get_torch_home(), "hub")
409
410
def set_dir(d):
    r"""
    Override the Torch Hub directory in which downloaded models & weights are saved.

    Args:
        d (str): path to a local folder to save downloaded models & weights.
            A leading ``~`` is expanded to the user's home directory.
    """
    global _hub_dir
    expanded = os.path.expanduser(d)
    _hub_dir = expanded
420
421
def list(
    github,
    force_reload=False,
    skip_validation=False,
    trust_repo=None,
    verbose=True,
):
    r"""
    Return the names of all callable entrypoints exposed by the repo ``github``.

    Args:
        github (str): repo spec in the form "repo_owner/repo_name[:ref]", where the
            optional ref is a tag or a branch. Without a ref, the default branch is
            taken to be ``main`` if it exists and ``master`` otherwise.
            Example: 'pytorch/vision:0.10'
        force_reload (bool, optional): discard any cached copy and download afresh.
            Default is ``False``.
        skip_validation (bool, optional): if ``False``, torchhub verifies that the
            branch or commit given in ``github`` really belongs to the repo owner.
            This issues requests to the GitHub API; a non-default GitHub token can
            be supplied through the ``GITHUB_TOKEN`` environment variable.
            Default is ``False``.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            Introduced in v1.12 to help ensure that users only run code from
            repos they trust.

            - ``False``: the user is prompted on whether the repo should be
              trusted.
            - ``True``: the repo is added to the trusted list and loaded without
              asking for confirmation.
            - ``"check"``: the repo is looked up in the cached list of trusted
              repos; if absent, behaviour falls back to ``trust_repo=False``.
            - ``None``: a warning is raised inviting the user to pick ``False``,
              ``True`` or ``"check"``. Kept only for backward compatibility and
              will be removed in v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
        verbose (bool, optional): if ``False``, suppress messages about hitting
            local caches. The message about the first download cannot be
            suppressed. Default is ``True``.

    Returns:
        list: The available callables entrypoint

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
    """
    repo_dir = _get_cache_or_reload(
        github,
        force_reload,
        trust_repo,
        "list",
        verbose=verbose,
        skip_validation=skip_validation,
    )

    hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
    with _add_to_sys_path(repo_dir):
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

    # We take functions starts with '_' as internal helper functions
    entrypoints = []
    for attr_name in dir(hub_module):
        if callable(getattr(hub_module, attr_name)) and not attr_name.startswith("_"):
            entrypoints.append(attr_name)
    return entrypoints
492
493
def help(github, model, force_reload=False, skip_validation=False, trust_repo=None):
    r"""
    Return the docstring of the entrypoint ``model``.

    Args:
        github (str): repo spec in the form <repo_owner/repo_name[:ref]>, where the
            optional ref is a tag or a branch. Without a ref, the default branch is
            taken to be ``main`` if it exists and ``master`` otherwise.
            Example: 'pytorch/vision:0.10'
        model (str): name of an entrypoint defined in the repo's ``hubconf.py``
        force_reload (bool, optional): discard any cached copy and download afresh.
            Default is ``False``.
        skip_validation (bool, optional): if ``False``, torchhub verifies that the
            ref given in ``github`` really belongs to the repo owner. This issues
            requests to the GitHub API; a non-default GitHub token can be supplied
            through the ``GITHUB_TOKEN`` environment variable. Default is ``False``.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            Introduced in v1.12 to help ensure that users only run code from
            repos they trust.

            - ``False``: the user is prompted on whether the repo should be
              trusted.
            - ``True``: the repo is added to the trusted list and loaded without
              asking for confirmation.
            - ``"check"``: the repo is looked up in the cached list of trusted
              repos; if absent, behaviour falls back to ``trust_repo=False``.
            - ``None``: a warning is raised inviting the user to pick ``False``,
              ``True`` or ``"check"``. Kept only for backward compatibility and
              will be removed in v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
    """
    repo_dir = _get_cache_or_reload(
        github,
        force_reload,
        trust_repo,
        "help",
        verbose=True,
        skip_validation=skip_validation,
    )

    hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
    with _add_to_sys_path(repo_dir):
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

    return _load_entry_from_hubconf(hub_module, model).__doc__
547
548
def load(
    repo_or_dir,
    model,
    *args,
    source="github",
    trust_repo=None,
    force_reload=False,
    verbose=True,
    skip_validation=False,
    **kwargs,
):
    r"""
    Load a model from a github repo or a local directory.

    Note: models are the usual case, but any object produced by an
    entrypoint (tokenizers, loss functions, etc.) can be loaded this way.

    With ``source`` set to 'github', ``repo_or_dir`` must have the form
    ``repo_owner/repo_name[:ref]``, the optional ref being a tag or a branch.

    With ``source`` set to 'local', ``repo_or_dir`` must be a path to a
    local directory.

    Args:
        repo_or_dir (str): for ``source`` 'github', a github repo given as
            ``repo_owner/repo_name[:ref]`` with an optional ref (tag or branch),
            for example 'pytorch/vision:0.10'. Without a ref, the default branch
            is taken to be ``main`` if it exists and ``master`` otherwise.
            For ``source`` 'local', a path to a local directory.
        model (str): the name of a callable (entrypoint) defined in the
            repo/dir's ``hubconf.py``.
        *args (optional): positional arguments passed to the callable ``model``.
        source (str, optional): 'github' or 'local'; selects how
            ``repo_or_dir`` is interpreted. Default is 'github'.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            Introduced in v1.12 to help ensure that users only run code from
            repos they trust.

            - ``False``: the user is prompted on whether the repo should be
              trusted.
            - ``True``: the repo is added to the trusted list and loaded without
              asking for confirmation.
            - ``"check"``: the repo is looked up in the cached list of trusted
              repos; if absent, behaviour falls back to ``trust_repo=False``.
            - ``None``: a warning is raised inviting the user to pick ``False``,
              ``True`` or ``"check"``. Kept only for backward compatibility and
              will be removed in v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
        force_reload (bool, optional): unconditionally download the github repo
            afresh. Has no effect when ``source = 'local'``. Default is ``False``.
        verbose (bool, optional): if ``False``, suppress messages about hitting
            local caches. The message about the first download cannot be
            suppressed. Has no effect when ``source = 'local'``.
            Default is ``True``.
        skip_validation (bool, optional): if ``False``, torchhub verifies that the
            branch or commit given in ``repo_or_dir`` really belongs to the repo
            owner. This issues requests to the GitHub API; a non-default GitHub
            token can be supplied through the ``GITHUB_TOKEN`` environment
            variable. Default is ``False``.
        **kwargs (optional): keyword arguments passed to the callable ``model``.

    Returns:
        The output of the ``model`` callable when called with the given
        ``*args`` and ``**kwargs``.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> # from a github repo
        >>> repo = "pytorch/vision"
        >>> model = torch.hub.load(
        ...     repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
        ... )
        >>> # from a local directory
        >>> path = "/some/local/path/pytorch/vision"
        >>> # xdoctest: +SKIP
        >>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
    """
    source = source.lower()

    if source == "github":
        # Resolve the repo spec to a local checkout (downloading if required).
        repo_or_dir = _get_cache_or_reload(
            repo_or_dir,
            force_reload,
            trust_repo,
            "load",
            verbose=verbose,
            skip_validation=skip_validation,
        )
    elif source != "local":
        raise ValueError(
            f'Unknown source: "{source}". Allowed values: "github" | "local".'
        )

    return _load_local(repo_or_dir, model, *args, **kwargs)
649
650
def _load_local(hubconf_dir, model, *args, **kwargs):
    r"""
    Call the entrypoint ``model`` from the ``hubconf.py`` found in a local directory.

    Args:
        hubconf_dir (str): path to a local directory that contains a
            ``hubconf.py``.
        model (str): name of an entrypoint defined in the directory's
            ``hubconf.py``.
        *args (optional): positional arguments passed to the callable ``model``.
        **kwargs (optional): keyword arguments passed to the callable ``model``.

    Returns:
        Whatever the entrypoint returns when called with ``*args`` and ``**kwargs``.

    Example:
        >>> # xdoctest: +SKIP("stub local path")
        >>> path = "/some/local/path/pytorch/vision"
        >>> model = _load_local(path, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1")
    """
    hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
    # The directory stays on sys.path while the entrypoint itself runs, not
    # just while hubconf.py is imported.
    with _add_to_sys_path(hubconf_dir):
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
        entrypoint = _load_entry_from_hubconf(hub_module, model)
        return entrypoint(*args, **kwargs)
679
680
def download_url_to_file(
    url: str,
    dst: str,
    hash_prefix: Optional[str] = None,
    progress: bool = True,
) -> None:
    r"""Download object at the given URL to a local path.

    The data is first written to a temporary ``<dst>.<uuid>.partial`` file
    and only moved to ``dst`` after the download (and the optional hash
    check) succeeded. The temporary file is removed on failure.

    Args:
        url (str): URL of the object to download
        dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file``
        hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
            Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True

    Raises:
        RuntimeError: the SHA256 digest of the download does not start with
            ``hash_prefix``.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> # xdoctest: +REQUIRES(POSIX)
        >>> torch.hub.download_url_to_file(
        ...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth",
        ...     "/tmp/temporary_file",
        ... )

    """
    file_size = None
    req = Request(url, headers={"User-Agent": "torch.hub"})
    # Use the response as a context manager so that the underlying connection
    # is always closed, including when the download or hash check fails.
    with urlopen(req) as u:
        meta = u.info()
        if hasattr(meta, "getheaders"):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        # We deliberately save it in a temp file and move it after
        # download is complete. This prevents a local working checkpoint
        # being overridden by a broken download.
        # We deliberately do not use NamedTemporaryFile to avoid restrictive
        # file permissions being applied to the downloaded file.
        dst = os.path.expanduser(dst)
        for _ in range(tempfile.TMP_MAX):
            tmp_dst = dst + "." + uuid.uuid4().hex + ".partial"
            try:
                # NOTE(review): mode "w+b" truncates an existing file rather
                # than raising FileExistsError, so the retry branch below is
                # effectively unreachable -- confirm whether "x+b" was meant.
                f = open(tmp_dst, "w+b")
            except FileExistsError:
                continue
            break
        else:
            raise FileExistsError(errno.EEXIST, "No usable temporary file name found")

        try:
            if hash_prefix is not None:
                sha256 = hashlib.sha256()
            with tqdm(
                total=file_size,
                disable=not progress,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                while True:
                    buffer = u.read(READ_DATA_CHUNK)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)  # type: ignore[possibly-undefined]
                    if hash_prefix is not None:
                        sha256.update(buffer)  # type: ignore[possibly-undefined]
                    pbar.update(len(buffer))

            # Close before moving so all data is flushed to disk.
            f.close()
            if hash_prefix is not None:
                digest = sha256.hexdigest()  # type: ignore[possibly-undefined]
                if digest[: len(hash_prefix)] != hash_prefix:
                    raise RuntimeError(
                        f'invalid hash value (expected "{hash_prefix}", got "{digest}")'
                    )
            shutil.move(f.name, dst)
        finally:
            # On success the temp file has been moved away; on failure it is
            # still present and gets cleaned up here.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)
763            os.remove(f.name)
764
765
# Hub used to support automatically extracting checkpoints from zipfiles manually compressed by users.
# The legacy zip format expects the zip to contain only one file, produced by torch.save() < 1.6.
# We should remove this support since zipfile is now the default format for torch.save().
769def _is_legacy_zip_format(filename: str) -> bool:
770    if zipfile.is_zipfile(filename):
771        infolist = zipfile.ZipFile(filename).infolist()
772        return len(infolist) == 1 and not infolist[0].is_dir()
773    return False
774
775
@deprecated(
    "Falling back to the old format < 1.6. This support will be "
    "deprecated in favor of default zipfile format introduced in 1.6. "
    "Please redo torch.save() to save it in the new zipfile format.",
    category=FutureWarning,
)
def _legacy_zip_load(
    filename: str,
    model_dir: str,
    map_location: MAP_LOCATION,
    weights_only: bool,
) -> Dict[str, Any]:
    """Extract a single-entry zip archive into ``model_dir`` and ``torch.load`` the result.

    Args:
        filename: path of the zip archive.
        model_dir: directory into which the archive is extracted.
        map_location: forwarded to ``torch.load``.
        weights_only: forwarded to ``torch.load``.

    Raises:
        RuntimeError: the archive does not contain exactly one entry.
    """
    # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
    #       We deliberately don't handle tarfile here since our legacy serialization format was in tar.
    #       E.g. resnet18-5c106cde.pth which is widely used.
    with zipfile.ZipFile(filename) as archive:
        entries = archive.infolist()
        if len(entries) != 1:
            raise RuntimeError("Only one file(not dir) is allowed in the zipfile")
        archive.extractall(model_dir)
        extracted_file = os.path.join(model_dir, entries[0].filename)
    return torch.load(
        extracted_file, map_location=map_location, weights_only=weights_only
    )
801
802
def load_state_dict_from_url(
    url: str,
    model_dir: Optional[str] = None,
    map_location: MAP_LOCATION = None,
    progress: bool = True,
    check_hash: bool = False,
    file_name: Optional[str] = None,
    weights_only: bool = False,
) -> Dict[str, Any]:
    r"""Loads the Torch serialized object at the given URL.

    A downloaded file in the legacy single-entry zip format is decompressed
    automatically.

    If the object already exists in `model_dir`, it is deserialized and
    returned without downloading again.
    ``model_dir`` defaults to ``<hub_dir>/checkpoints``, ``hub_dir`` being
    the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (str): URL of the object to download
        model_dir (str, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False
        file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set.
        weights_only(bool, optional): If True, only weights will be loaded and no complex pickled objects.
            Recommended for untrusted sources. See :func:`~torch.load` for more details.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> state_dict = torch.hub.load_state_dict_from_url(
        ...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
        ... )

    """
    # Issue warning to move data if old env is set
    if os.getenv("TORCH_MODEL_ZOO"):
        warnings.warn(
            "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead"
        )

    if model_dir is None:
        model_dir = os.path.join(get_dir(), "checkpoints")
    os.makedirs(model_dir, exist_ok=True)

    # An explicit file_name wins over the last path component of the URL.
    url_basename = os.path.basename(urlparse(url).path)
    filename = url_basename if file_name is None else file_name
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
        hash_prefix = None
        if check_hash:
            match = HASH_REGEX.search(filename)  # match is Optional[Match[str]]
            if match:
                hash_prefix = match.group(1)
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if not _is_legacy_zip_format(cached_file):
        return torch.load(
            cached_file, map_location=map_location, weights_only=weights_only
        )
    return _legacy_zip_load(cached_file, model_dir, map_location, weights_only)
872