"""Download and decompress the MNIST dataset files.

Each archive listed in ``RESOURCES`` is fetched from the first mirror in
``MIRRORS`` that serves it, then gunzipped next to the archive.
"""

import argparse
import gzip
import os
import shutil
import sys
from urllib.error import URLError
from urllib.request import urlretrieve

# Base URLs tried in order; the resource file name is appended verbatim.
MIRRORS = [
    "http://yann.lecun.com/exdb/mnist/",
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
]

# Gzip archives that make up the dataset.
RESOURCES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]

# Width of the textual progress bar, in characters.
_BAR_WIDTH = 64


def _remove_if_exists(path: str) -> None:
    """Delete *path*, ignoring the case where it is already gone."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def report_download_progress(
    chunk_number: int,
    chunk_size: int,
    file_size: int,
) -> None:
    """Draw a one-line progress bar on stdout.

    Has the signature expected by the ``reporthook`` argument of
    ``urllib.request.urlretrieve``.

    Args:
        chunk_number: Number of chunks transferred so far.
        chunk_size: Size of one chunk in bytes.
        file_size: Total size in bytes, or ``-1`` when it is unknown.
    """
    # A non-positive size means "unknown" (-1) or an empty body (0); in both
    # cases no ratio can be computed, and 0 would divide by zero.
    if file_size <= 0:
        return
    percent = min(1, (chunk_number * chunk_size) / file_size)
    bar = "#" * int(_BAR_WIDTH * percent)
    sys.stdout.write(f"\r0% |{bar:<{_BAR_WIDTH}}| {int(percent * 100)}%")
    # The line has no trailing newline, so flush to make it visible.
    sys.stdout.flush()


def download(destination_path: str, resource: str, quiet: bool) -> None:
    """Fetch *resource* into *destination_path*, trying each mirror in turn.

    Nothing is downloaded when *destination_path* already exists.

    Args:
        destination_path: File path the resource is stored at.
        resource: File name appended to each mirror's base URL.
        quiet: When true, suppress the progress bar and the skip message.

    Raises:
        RuntimeError: If no mirror could deliver the resource.
    """
    if os.path.exists(destination_path):
        if not quiet:
            print(f"{destination_path} already exists, skipping ...")
        return

    hook = None if quiet else report_download_progress
    # Download to a side file so that an aborted transfer never leaves a
    # truncated file at destination_path, which a later run would skip.
    partial_path = destination_path + ".part"
    for mirror in MIRRORS:
        url = mirror + resource
        print(f"Downloading {url} ...")
        try:
            urlretrieve(url, partial_path, reporthook=hook)
        except (URLError, ConnectionError) as e:
            print(f"Failed to download (trying next):\n{e}")
            _remove_if_exists(partial_path)
            continue
        except BaseException:
            # E.g. KeyboardInterrupt: clean up, then let it propagate.
            _remove_if_exists(partial_path)
            raise
        finally:
            if not quiet:
                # Just a newline, to terminate the progress-bar line.
                print()
        os.replace(partial_path, destination_path)
        return
    raise RuntimeError("Error downloading resource!")


def unzip(zipped_path: str, quiet: bool) -> None:
    """Decompress a gzip file next to itself, dropping its last extension.

    Nothing is done when the decompressed file already exists.

    Args:
        zipped_path: Path of the gzip archive.
        quiet: When true, suppress status messages.
    """
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print(f"{unzipped_path} already exists, skipping ... ")
        return

    # Write to a side file and rename on success, so a failed or interrupted
    # decompression is not mistaken for a finished one on the next run.
    partial_path = unzipped_path + ".part"
    try:
        with gzip.open(zipped_path, "rb") as zipped_file, open(
            partial_path, "wb"
        ) as unzipped_file:
            # Stream in chunks instead of loading the whole payload in memory.
            shutil.copyfileobj(zipped_file, unzipped_file)
    except BaseException:
        _remove_if_exists(partial_path)
        raise
    os.replace(partial_path, unzipped_path)
    if not quiet:
        print(f"Unzipped {zipped_path} ...")


def main() -> None:
    """Parse command-line options, then download and unzip every resource."""
    parser = argparse.ArgumentParser(
        description="Download the MNIST dataset from the internet"
    )
    parser.add_argument(
        "-d", "--destination", default=".", help="Destination directory"
    )
    parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="Don't report about progress",
    )
    options = parser.parse_args()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(options.destination, exist_ok=True)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            download(path, resource, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print("Interrupted")


if __name__ == "__main__":
    main()