# mypy: allow-untyped-defs
import contextlib
import errno
import hashlib
import json
import os
import re
import shutil
import sys
import tempfile
import uuid
import warnings
import zipfile
from pathlib import Path
from typing import Any, Dict, Optional
from typing_extensions import deprecated
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse  # noqa: F401
from urllib.request import Request, urlopen

import torch
from torch.serialization import MAP_LOCATION


class _Faketqdm:  # type: ignore[no-redef]
    """Minimal stand-in for ``tqdm``, used when the real package is not installed.

    Only the small subset of the ``tqdm`` API used by this module is provided.
    Progress is written to stderr either as a byte count (unknown total) or as
    a percentage of ``total``.
    """

    def __init__(self, total=None, disable=False, unit=None, *args, **kwargs):
        self.total = total
        self.disable = disable
        self.n = 0
        # Ignore all extra *args and **kwargs lest you want to reinvent tqdm

    def update(self, n):
        if self.disable:
            return

        self.n += n
        # A missing or zero total (e.g. a ``Content-Length: 0`` header followed
        # by data) cannot be turned into a percentage without dividing by zero,
        # so fall back to printing the raw byte count.
        if not self.total:
            sys.stderr.write(f"\r{self.n:.1f} bytes")
        else:
            sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%")
        sys.stderr.flush()

    # Don't bother implementing; use real tqdm if you want
    def set_description(self, *args, **kwargs):
        pass

    def write(self, s):
        sys.stderr.write(f"{s}\n")

    def close(self):
        self.disable = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.disable:
            return

        sys.stderr.write("\n")


try:
    from tqdm import tqdm  # If tqdm is installed use it, otherwise use the fake wrapper
except ImportError:
    tqdm = _Faketqdm

__all__ = [
    "download_url_to_file",
    "get_dir",
    "help",
    "list",
    "load",
    "load_state_dict_from_url",
    "set_dir",
]

# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")

_TRUSTED_REPO_OWNERS = (
    "facebookresearch",
    "facebookincubator",
    "pytorch",
    "fairinternal",
)
ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
ENV_TORCH_HOME = "TORCH_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"
VAR_DEPENDENCY = "dependencies"
MODULE_HUBCONF = "hubconf.py"
READ_DATA_CHUNK = 128 * 1024
_hub_dir: Optional[str] = None


@contextlib.contextmanager
def _add_to_sys_path(path):
    """Prepend ``path`` to ``sys.path`` for the duration of the ``with`` block."""
    sys.path.insert(0, path)
    try:
        yield
    finally:
        sys.path.remove(path)


# Copied from tools/shared/module_loader to be included in torch package
def _import_module(name, path):
    """Execute the Python file at ``path`` as a module called ``name`` and return it."""
    import importlib.util
    from importlib.abc import Loader

    spec = importlib.util.spec_from_file_location(name, path)
    assert spec is not None
    module = importlib.util.module_from_spec(spec)
    assert isinstance(spec.loader, Loader)
    spec.loader.exec_module(module)
    return module


def _remove_if_exists(path):
    """Delete ``path`` if it exists: files with ``os.remove``, anything else with ``shutil.rmtree``."""
    if os.path.exists(path):
        if os.path.isfile(path):
            os.remove(path)
        else:
            shutil.rmtree(path)


def _git_archive_link(repo_owner, repo_name, ref):
    """Return the GitHub zipball download URL for ``repo_owner/repo_name`` at ``ref``."""
    # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip
    return f"https://github.com/{repo_owner}/{repo_name}/zipball/{ref}"


def _load_attr_from_module(module, func_name):
    """Return ``module.func_name``, or ``None`` if ``dir(module)`` does not list it."""
    # Check if callable is defined in the module
    if func_name not in dir(module):
        return None
    return getattr(module, func_name)


def _get_torch_home():
    """Return ``$TORCH_HOME``, else ``$XDG_CACHE_HOME/torch``, else ``~/.cache/torch`` (with ``~`` expanded)."""
    torch_home = os.path.expanduser(
        os.getenv(
            ENV_TORCH_HOME,
            os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
        )
    )
    return torch_home


def _parse_repo_info(github):
    """Split ``"repo_owner/repo_name[:ref]"`` into ``(repo_owner, repo_name, ref)``.

    When no ref is given, the default branch is probed over the network
    (``main`` if it exists, otherwise ``master``). Without network access the
    local hub cache is consulted instead; ``RuntimeError`` is raised if neither
    branch is cached.
    """
    if ":" in github:
        repo_info, ref = github.split(":")
    else:
        repo_info, ref = github, None
    repo_owner, repo_name = repo_info.split("/")

    if ref is None:
        # The ref wasn't specified by the user, so we need to figure out the
        # default branch: main or master. Our assumption is that if main exists
        # then it's the default branch, otherwise it's master.
        try:
            with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
                ref = "main"
        except HTTPError as e:
            if e.code == 404:
                ref = "master"
            else:
                raise
        except URLError as e:
            # No internet connection, need to check for cache as last resort
            for possible_ref in ("main", "master"):
                if os.path.exists(
                    f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"
                ):
                    ref = possible_ref
                    break
            if ref is None:
                raise RuntimeError(
                    "It looks like there is no internet connection and the "
                    f"repo could not be found in the cache ({get_dir()})"
                ) from e
    return repo_owner, repo_name, ref


def _read_url(url):
    """Fetch ``url`` (a string or ``Request``) and decode the body with the response charset (UTF-8 by default)."""
    with urlopen(url) as r:
        return r.read().decode(r.headers.get_content_charset("utf-8"))


def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
    """Raise ``ValueError`` unless ``ref`` belongs to ``repo_owner/repo_name``.

    ``ref`` is accepted if it equals a branch or tag name, or is a prefix of the
    commit SHA those branches/tags point at, as listed by the GitHub REST API.
    """
    # Use urlopen to avoid depending on local git.
    headers = {"Accept": "application/vnd.github.v3+json"}
    token = os.environ.get(ENV_GITHUB_TOKEN)
    if token is not None:
        headers["Authorization"] = f"token {token}"
    for url_prefix in (
        f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches",
        f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags",
    ):
        page = 0
        while True:
            page += 1
            url = f"{url_prefix}?per_page=100&page={page}"
            response = json.loads(_read_url(Request(url, headers=headers)))
            # Empty response means no more data to process
            if not response:
                break
            for br in response:
                if br["name"] == ref or br["commit"]["sha"].startswith(ref):
                    return

    raise ValueError(
        f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. "
        "If it's a commit from a forked repo, please call hub.load() with forked repo directly."
    )


def _get_cache_or_reload(
    github,
    force_reload,
    trust_repo,
    calling_fn,
    verbose=True,
    skip_validation=False,
):
    """Return the local directory holding the hub repo described by ``github``.

    The repo is checked against the trust list first. A cached copy is reused
    unless ``force_reload`` is set; otherwise the GitHub zipball is downloaded,
    extracted into the hub dir and renamed to ``<owner>_<name>_<ref>``.
    """
    # Setup hub_dir to save downloaded files
    hub_dir = get_dir()
    os.makedirs(hub_dir, exist_ok=True)
    # Parse github repo information
    repo_owner, repo_name, ref = _parse_repo_info(github)
    # Github allows branch name with slash '/',
    # this causes confusion with path on both Linux and Windows.
    # Backslash is not allowed in Github branch name so no need to
    # worry about it.
    normalized_br = ref.replace("/", "_")
    # Github renames folder repo-v1.x.x to repo-1.x.x
    # We don't know the repo name before downloading the zip file
    # and inspect name from it.
    # To check if cached repo exists, we need to normalize folder names.
    owner_name_branch = "_".join([repo_owner, repo_name, normalized_br])
    repo_dir = os.path.join(hub_dir, owner_name_branch)
    # Check that the repo is in the trusted list
    _check_repo_is_trusted(
        repo_owner,
        repo_name,
        owner_name_branch,
        trust_repo=trust_repo,
        calling_fn=calling_fn,
    )

    use_cache = (not force_reload) and os.path.exists(repo_dir)

    if use_cache:
        if verbose:
            sys.stderr.write(f"Using cache found in {repo_dir}\n")
    else:
        # Validate the tag/branch is from the original repo instead of a forked repo
        if not skip_validation:
            _validate_not_a_forked_repo(repo_owner, repo_name, ref)

        cached_file = os.path.join(hub_dir, normalized_br + ".zip")
        _remove_if_exists(cached_file)

        try:
            url = _git_archive_link(repo_owner, repo_name, ref)
            sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
            download_url_to_file(url, cached_file, progress=False)
        except HTTPError as err:
            if err.code == 300:
                # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch
                # in the repo. This can be disambiguated by explicitly using refs/heads/ or refs/tags
                # See https://git-scm.com/book/en/v2/Git-Internals-Git-References
                # Here, we do the same as git: we throw a warning, and assume the user wanted the branch
                warnings.warn(
                    f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? "
                    "Torchhub will now assume that it's a branch. "
                    "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or "
                    "refs/tags/tag_name as the ref. That might require using skip_validation=True."
                )
                disambiguated_branch_ref = f"refs/heads/{ref}"
                url = _git_archive_link(
                    repo_owner, repo_name, ref=disambiguated_branch_ref
                )
                download_url_to_file(url, cached_file, progress=False)
            else:
                raise

        with zipfile.ZipFile(cached_file) as cached_zipfile:
            # The first archive entry is the top-level folder GitHub chose for the repo.
            extracted_repo_name = cached_zipfile.infolist()[0].filename
            extracted_repo = os.path.join(hub_dir, extracted_repo_name)
            _remove_if_exists(extracted_repo)
            # Unzip the code and rename the base folder
            cached_zipfile.extractall(hub_dir)

        _remove_if_exists(cached_file)
        _remove_if_exists(repo_dir)
        shutil.move(extracted_repo, repo_dir)  # rename the repo

    return repo_dir


def _check_repo_is_trusted(
    repo_owner,
    repo_name,
    owner_name_branch,
    trust_repo,
    calling_fn="load",
):
    """Enforce the ``trust_repo`` policy for ``repo_owner/repo_name``.

    A repo counts as trusted if ``<owner>_<name>`` is in ``<hub_dir>/trusted_list``,
    if ``owner_name_branch`` already exists as a directory in the hub dir, or if
    the owner is one of ``_TRUSTED_REPO_OWNERS``.

    Raises:
        Exception: the user answered no (or nothing) at the trust prompt.
        ValueError: the user's answer was not recognized.
    """
    hub_dir = get_dir()
    filepath = os.path.join(hub_dir, "trusted_list")

    if not os.path.exists(filepath):
        Path(filepath).touch()
    with open(filepath) as file:
        trusted_repos = tuple(line.strip() for line in file)

    # To minimize friction of introducing the new trust_repo mechanism, we consider that
    # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
    trusted_repos_legacy = next(os.walk(hub_dir))[1]

    owner_name = "_".join([repo_owner, repo_name])
    is_trusted = (
        owner_name in trusted_repos
        or owner_name_branch in trusted_repos_legacy
        or repo_owner in _TRUSTED_REPO_OWNERS
    )

    # TODO: Remove `None` option in 2.0 and change the default to "check"
    if trust_repo is None:
        if not is_trusted:
            warnings.warn(
                "You are about to download and run code from an untrusted repository. In a future release, this won't "
                f"be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., "
                "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
                f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
                f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
                "confirmation if the repo is not already trusted. This will eventually be the default behaviour"
            )
        return

    if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
        response = input(
            f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?"
        )
        if response.lower() in ("y", "yes"):
            if is_trusted:
                print("The repository is already trusted.")
        elif response.lower() in ("n", "no", ""):
            raise Exception("Untrusted repository.")  # noqa: TRY002
        else:
            raise ValueError(f"Unrecognized response {response}.")

    # At this point we're sure that the user trusts the repo (or wants to trust it)
    if not is_trusted:
        with open(filepath, "a") as file:
            file.write(owner_name + "\n")


def _check_module_exists(name):
    """Return True if a module called ``name`` can be found by the import system."""
    import importlib.util

    return importlib.util.find_spec(name) is not None


def _check_dependencies(m):
    """Raise ``RuntimeError`` listing any package named in ``m.dependencies`` that cannot be found."""
    dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)

    if dependencies is not None:
        missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
        if len(missing_deps):
            raise RuntimeError(f"Missing dependencies: {', '.join(missing_deps)}")


def _load_entry_from_hubconf(m, model):
    """Return the callable entrypoint ``model`` from hubconf module ``m``.

    Raises:
        ValueError: ``model`` is not a string.
        RuntimeError: a declared dependency is missing, or ``model`` is not a callable in ``m``.
    """
    if not isinstance(model, str):
        raise ValueError("Invalid input: model should be a string of function name")

    # Note that if a missing dependency is imported at top level of hubconf, it will
    # throw before this function. It's a chicken and egg situation where we have to
    # load hubconf to know what're the dependencies, but to import hubconf it requires
    # a missing package. This is fine, Python will throw proper error message for users.
    _check_dependencies(m)

    func = _load_attr_from_module(m, model)

    if func is None or not callable(func):
        raise RuntimeError(f"Cannot find callable {model} in hubconf")

    return func


def get_dir():
    r"""
    Get the Torch Hub cache directory used for storing downloaded models & weights.

    If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the XDG Base Directory Specification of the Linux
    filesystem layout, with a default value ``~/.cache`` if the environment
    variable is not set.
    """
    # Issue warning to move data if old env is set
    if os.getenv("TORCH_HUB"):
        warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead")

    if _hub_dir is not None:
        return _hub_dir
    return os.path.join(_get_torch_home(), "hub")


def set_dir(d):
    r"""
    Optionally set the Torch Hub directory used to save downloaded models & weights.

    Args:
        d (str): path to a local folder to save downloaded models & weights.
    """
    global _hub_dir
    _hub_dir = os.path.expanduser(d)


def list(
    github,
    force_reload=False,
    skip_validation=False,
    trust_repo=None,
    verbose=True,
):
    r"""
    List all callable entrypoints available in the repo specified by ``github``.

    Args:
        github (str): a string with format "repo_owner/repo_name[:ref]" with an optional
            ref (tag or branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if
            it exists, and otherwise ``master``.
            Example: 'pytorch/vision:0.10'
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is ``False``.
        skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit
            specified by the ``github`` argument properly belongs to the repo owner. This will make
            requests to the GitHub API; you can specify a non-default GitHub token by setting the
            ``GITHUB_TOKEN`` environment variable. Default is ``False``.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            This parameter was introduced in v1.12 and helps ensuring that users
            only run code from repos that they trust.

            - If ``False``, a prompt will ask the user whether the repo should
              be trusted.
            - If ``True``, the repo will be added to the trusted list and loaded
              without requiring explicit confirmation.
            - If ``"check"``, the repo will be checked against the list of
              trusted repos in the cache. If it is not present in that list, the
              behaviour will fall back onto the ``trust_repo=False`` option.
            - If ``None``: this will raise a warning, inviting the user to set
              ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This
              is only present for backward compatibility and will be removed in
              v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
        verbose (bool, optional): If ``False``, mute messages about hitting
            local caches. Note that the message about first download cannot be
            muted. Default is ``True``.

    Returns:
        list: The names of the available callable entrypoints

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
    """
    repo_dir = _get_cache_or_reload(
        github,
        force_reload,
        trust_repo,
        "list",
        verbose=verbose,
        skip_validation=skip_validation,
    )

    with _add_to_sys_path(repo_dir):
        hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

    # We take functions starts with '_' as internal helper functions
    entrypoints = [
        f
        for f in dir(hub_module)
        if callable(getattr(hub_module, f)) and not f.startswith("_")
    ]

    return entrypoints


def help(github, model, force_reload=False, skip_validation=False, trust_repo=None):
    r"""
    Show the docstring of entrypoint ``model``.

    Args:
        github (str): a string with format <repo_owner/repo_name[:ref]> with an optional
            ref (a tag or a branch). If ``ref`` is not specified, the default branch is assumed
            to be ``main`` if it exists, and otherwise ``master``.
            Example: 'pytorch/vision:0.10'
        model (str): a string of entrypoint name defined in repo's ``hubconf.py``
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is ``False``.
        skip_validation (bool, optional): if ``False``, torchhub will check that the ref
            specified by the ``github`` argument properly belongs to the repo owner. This will make
            requests to the GitHub API; you can specify a non-default GitHub token by setting the
            ``GITHUB_TOKEN`` environment variable. Default is ``False``.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            This parameter was introduced in v1.12 and helps ensuring that users
            only run code from repos that they trust.

            - If ``False``, a prompt will ask the user whether the repo should
              be trusted.
            - If ``True``, the repo will be added to the trusted list and loaded
              without requiring explicit confirmation.
            - If ``"check"``, the repo will be checked against the list of
              trusted repos in the cache. If it is not present in that list, the
              behaviour will fall back onto the ``trust_repo=False`` option.
            - If ``None``: this will raise a warning, inviting the user to set
              ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This
              is only present for backward compatibility and will be removed in
              v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
    """
    repo_dir = _get_cache_or_reload(
        github,
        force_reload,
        trust_repo,
        "help",
        verbose=True,
        skip_validation=skip_validation,
    )

    with _add_to_sys_path(repo_dir):
        hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

    entry = _load_entry_from_hubconf(hub_module, model)

    return entry.__doc__


def load(
    repo_or_dir,
    model,
    *args,
    source="github",
    trust_repo=None,
    force_reload=False,
    verbose=True,
    skip_validation=False,
    **kwargs,
):
    r"""
    Load a model from a github repo or a local directory.

    Note: Loading a model is the typical use case, but this can also be used
    for loading other objects such as tokenizers, loss functions, etc.

    If ``source`` is 'github', ``repo_or_dir`` is expected to be
    of the form ``repo_owner/repo_name[:ref]`` with an optional
    ref (a tag or a branch).

    If ``source`` is 'local', ``repo_or_dir`` is expected to be a
    path to a local directory.

    Args:
        repo_or_dir (str): If ``source`` is 'github',
            this should correspond to a github repo with format ``repo_owner/repo_name[:ref]`` with
            an optional ref (tag or branch), for example 'pytorch/vision:0.10'. If ``ref`` is not specified,
            the default branch is assumed to be ``main`` if it exists, and otherwise ``master``.
            If ``source`` is 'local'  then it should be a path to a local directory.
        model (str): the name of a callable (entrypoint) defined in the
            repo/dir's ``hubconf.py``.
        *args (optional): the corresponding args for callable ``model``.
        source (str, optional): 'github' or 'local'. Specifies how
            ``repo_or_dir`` is to be interpreted. Default is 'github'.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            This parameter was introduced in v1.12 and helps ensuring that users
            only run code from repos that they trust.

            - If ``False``, a prompt will ask the user whether the repo should
              be trusted.
            - If ``True``, the repo will be added to the trusted list and loaded
              without requiring explicit confirmation.
            - If ``"check"``, the repo will be checked against the list of
              trusted repos in the cache. If it is not present in that list, the
              behaviour will fall back onto the ``trust_repo=False`` option.
            - If ``None``: this will raise a warning, inviting the user to set
              ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This
              is only present for backward compatibility and will be removed in
              v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
        force_reload (bool, optional): whether to force a fresh download of
            the github repo unconditionally. Does not have any effect if
            ``source = 'local'``. Default is ``False``.
        verbose (bool, optional): If ``False``, mute messages about hitting
            local caches. Note that the message about first download cannot be
            muted. Does not have any effect if ``source = 'local'``.
            Default is ``True``.
        skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit
            specified by the ``github`` argument properly belongs to the repo owner. This will make
            requests to the GitHub API; you can specify a non-default GitHub token by setting the
            ``GITHUB_TOKEN`` environment variable. Default is ``False``.
        **kwargs (optional): the corresponding kwargs for callable ``model``.

    Returns:
        The output of the ``model`` callable when called with the given
        ``*args`` and ``**kwargs``.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> # from a github repo
        >>> repo = "pytorch/vision"
        >>> model = torch.hub.load(
        ...     repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
        ... )
        >>> # from a local directory
        >>> path = "/some/local/path/pytorch/vision"
        >>> # xdoctest: +SKIP
        >>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
    """
    source = source.lower()

    if source not in ("github", "local"):
        raise ValueError(
            f'Unknown source: "{source}". Allowed values: "github" | "local".'
        )

    if source == "github":
        repo_or_dir = _get_cache_or_reload(
            repo_or_dir,
            force_reload,
            trust_repo,
            "load",
            verbose=verbose,
            skip_validation=skip_validation,
        )

    model = _load_local(repo_or_dir, model, *args, **kwargs)
    return model


def _load_local(hubconf_dir, model, *args, **kwargs):
    r"""
    Load a model from a local directory with a ``hubconf.py``.

    Args:
        hubconf_dir (str): path to a local directory that contains a
            ``hubconf.py``.
        model (str): name of an entrypoint defined in the directory's
            ``hubconf.py``.
        *args (optional): the corresponding args for callable ``model``.
        **kwargs (optional): the corresponding kwargs for callable ``model``.

    Returns:
        a single model with corresponding pretrained weights.

    Example:
        >>> # xdoctest: +SKIP("stub local path")
        >>> path = "/some/local/path/pytorch/vision"
        >>> model = _load_local(path, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1")
    """
    with _add_to_sys_path(hubconf_dir):
        hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

        entry = _load_entry_from_hubconf(hub_module, model)
        model = entry(*args, **kwargs)

    return model


def download_url_to_file(
    url: str,
    dst: str,
    hash_prefix: Optional[str] = None,
    progress: bool = True,
) -> None:
    r"""Download object at the given URL to a local path.

    Args:
        url (str): URL of the object to download
        dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file``
        hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
            Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True

    Raises:
        RuntimeError: the SHA256 of the downloaded data does not start with ``hash_prefix``.
            In that case ``dst`` is left untouched and the partial file is removed.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> # xdoctest: +REQUIRES(POSIX)
        >>> torch.hub.download_url_to_file(
        ...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth",
        ...     "/tmp/temporary_file",
        ... )

    """
    file_size = None
    req = Request(url, headers={"User-Agent": "torch.hub"})
    # The response is used as a context manager so the connection / file handle
    # is released even when the download or the hash check fails.
    with urlopen(req) as u:
        meta = u.info()
        if hasattr(meta, "getheaders"):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        # We deliberately save it in a temp file and move it after
        # download is complete. This prevents a local working checkpoint
        # being overridden by a broken download.
        # We deliberately do not use NamedTemporaryFile to avoid restrictive
        # file permissions being applied to the downloaded file.
        dst = os.path.expanduser(dst)
        for _ in range(tempfile.TMP_MAX):
            tmp_dst = dst + "." + uuid.uuid4().hex + ".partial"
            try:
                # Exclusive-create mode raises FileExistsError if the name is
                # already taken, which is what makes this retry loop meaningful
                # (a plain "w" mode would silently truncate an existing file).
                f = open(tmp_dst, "xb")
            except FileExistsError:
                continue
            break
        else:
            raise FileExistsError(errno.EEXIST, "No usable temporary file name found")

        try:
            if hash_prefix is not None:
                sha256 = hashlib.sha256()
            with tqdm(
                total=file_size,
                disable=not progress,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                while True:
                    buffer = u.read(READ_DATA_CHUNK)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)  # type: ignore[possibly-undefined]
                    if hash_prefix is not None:
                        sha256.update(buffer)  # type: ignore[possibly-undefined]
                    pbar.update(len(buffer))

            f.close()
            if hash_prefix is not None:
                digest = sha256.hexdigest()  # type: ignore[possibly-undefined]
                if digest[: len(hash_prefix)] != hash_prefix:
                    raise RuntimeError(
                        f'invalid hash value (expected "{hash_prefix}", got "{digest}")'
                    )
            shutil.move(f.name, dst)
        finally:
            # On success the temp file has already been moved to ``dst``; on
            # failure it is removed so no ``.partial`` file is left behind.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)


# Hub used to support automatically extracts from zipfile manually compressed by users.
# The legacy zip format expects only one file from torch.save() < 1.6 in the zip.
# We should remove this support since zipfile is now default zipfile format for torch.save().
def _is_legacy_zip_format(filename: str) -> bool:
    """Return True if ``filename`` is a zip archive holding exactly one non-directory entry."""
    if zipfile.is_zipfile(filename):
        # Close the archive (and its file handle) as soon as the listing is read.
        with zipfile.ZipFile(filename) as zf:
            infolist = zf.infolist()
        return len(infolist) == 1 and not infolist[0].is_dir()
    return False


@deprecated(
    "Falling back to the old format < 1.6. This support will be "
    "deprecated in favor of default zipfile format introduced in 1.6. "
    "Please redo torch.save() to save it in the new zipfile format.",
    category=FutureWarning,
)
def _legacy_zip_load(
    filename: str,
    model_dir: str,
    map_location: MAP_LOCATION,
    weights_only: bool,
) -> Dict[str, Any]:
    """Extract the single file inside legacy zip ``filename`` into ``model_dir`` and ``torch.load`` it.

    Raises:
        RuntimeError: the archive does not contain exactly one entry.
    """
    # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
    #       We deliberately don't handle tarfile here since our legacy serialization format was in tar.
    #       E.g. resnet18-5c106cde.pth which is widely used.
    with zipfile.ZipFile(filename) as f:
        members = f.infolist()
        if len(members) != 1:
            raise RuntimeError("Only one file(not dir) is allowed in the zipfile")
        f.extractall(model_dir)
        extracted_name = members[0].filename
        extracted_file = os.path.join(model_dir, extracted_name)
    return torch.load(
        extracted_file, map_location=map_location, weights_only=weights_only
    )


def load_state_dict_from_url(
    url: str,
    model_dir: Optional[str] = None,
    map_location: MAP_LOCATION = None,
    progress: bool = True,
    check_hash: bool = False,
    file_name: Optional[str] = None,
    weights_only: bool = False,
) -> Dict[str, Any]:
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically
    decompressed.

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (str): URL of the object to download
        model_dir (str, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False
        file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set.
        weights_only(bool, optional): If True, only weights will be loaded and no complex pickled objects.
            Recommended for untrusted sources. See :func:`~torch.load` for more details.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> state_dict = torch.hub.load_state_dict_from_url(
        ...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
        ... )

    """
    # Issue warning to move data if old env is set
    if os.getenv("TORCH_MODEL_ZOO"):
        warnings.warn(
            "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead"
        )

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, "checkpoints")

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if _is_legacy_zip_format(cached_file):
        return _legacy_zip_load(cached_file, model_dir, map_location, weights_only)
    # NOTE(review): with the default weights_only=False, torch.load unpickles
    # arbitrary objects; callers loading from untrusted URLs should pass
    # weights_only=True.
    return torch.load(cached_file, map_location=map_location, weights_only=weights_only)