# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Utilities for file download and caching."""

from abc import abstractmethod
from contextlib import closing
import functools
import hashlib
import multiprocessing
import multiprocessing.dummy
import os
import queue
import random
import shutil
import sys  # pylint: disable=unused-import
import tarfile
import threading
import time
import typing
import urllib
import warnings
import weakref
import zipfile

import numpy as np

from tensorflow.python.framework import ops
from six.moves.urllib.request import urlopen
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.util.tf_export import keras_export

# Required to support google internal urlretrieve
if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    """Replacement for `urlretrieve` for Python 2.

    Under Python 2, `urlretrieve` relies on `FancyURLopener` from legacy
    `urllib` module, known to have issues with proxy management.

    Args:
      url: url to retrieve.
      filename: where to store the retrieved data locally.
      reporthook: a hook function that will be called once on establishment of
        the network connection and once after each block read thereafter. The
        hook will be passed three arguments; a count of blocks transferred so
        far, a block size in bytes, and the total size of the file.
      data: `data` argument passed to `urlopen`.
    """

    def chunk_read(response, chunk_size=8192, reporthook=None):
      # The Content-Length header, when present, gives the total size; -1
      # signals "unknown" to the reporthook.
      content_length = response.info().get('Content-Length')
      total_size = -1
      if content_length is not None:
        total_size = int(content_length.strip())
      count = 0
      while True:
        chunk = response.read(chunk_size)
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        if chunk:
          yield chunk
        else:
          break

    # `closing` guarantees the network response is released even if writing
    # to `filename` fails part-way through.
    with closing(urlopen(url, data)) as response:
      with open(filename, 'wb') as fd:
        for chunk in chunk_read(response, reporthook=reporthook):
          fd.write(chunk)
else:
  from urllib.request import urlretrieve  # pylint: disable=g-importing-member


def is_generator_or_sequence(x):
  """Check if `x` is a Keras generator type.

  Tensors, NumPy arrays and builtin containers/strings are never considered
  generators, even though some of them are iterable.

  Args:
    x: object to inspect.

  Returns:
    True if `x` is a generator, a `Sequence`, or another `typing.Iterator`.
  """
  builtin_iterators = (str, list, tuple, dict, set, frozenset)
  if isinstance(x, (ops.Tensor, np.ndarray) + builtin_iterators):
    return False
  return (tf_inspect.isgenerator(x) or
          isinstance(x, Sequence) or
          isinstance(x, typing.Iterator))


def _resolves_inside(base_dir, relative_path):
  """Returns whether `relative_path`, joined onto `base_dir`, stays inside it.

  Both sides go through `os.path.realpath`, so `..` components and symlinks
  that already exist on disk are resolved before comparing. An absolute
  `relative_path` replaces `base_dir` in the join and is accepted only if it
  points inside `base_dir`.

  Args:
    base_dir: directory that must contain the result.
    relative_path: path to check, interpreted relative to `base_dir`.

  Returns:
    True if the resolved path is `base_dir` itself or lies below it.
  """
  base = os.path.realpath(base_dir)
  target = os.path.realpath(os.path.join(base, relative_path))
  try:
    return os.path.commonpath([base, target]) == base
  except ValueError:
    # Raised e.g. on Windows when the two paths are on different drives.
    return False


def _safe_tar_members(archive, path):
  """Yields the members of tar `archive` that extract inside `path`.

  Members that would land outside `path` are skipped with a warning instead
  of being extracted. This covers absolute names, `..` traversal, and
  symbolic or hard links whose target leaves `path`.

  Args:
    archive: an open `tarfile.TarFile`.
    path: extraction destination directory.

  Yields:
    `tarfile.TarInfo` objects that are safe to pass to `extractall`.
  """
  for member in archive.getmembers():
    safe = _resolves_inside(path, member.name)
    if safe and member.issym():
      # Symlink targets are interpreted relative to the link's own directory.
      safe = _resolves_inside(
          path, os.path.join(os.path.dirname(member.name), member.linkname))
    elif safe and member.islnk():
      # Hard-link targets name another member, i.e. are relative to `path`.
      safe = _resolves_inside(path, member.linkname)
    if safe:
      yield member
    else:
      warnings.warn('Skipping archive member {!r}: it would be extracted '
                    'outside of {!r}.'.format(member.name, path))


def _extract_archive(file_path, path='.', archive_format='auto'):
  """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.

  Args:
    file_path: path to the archive file
    path: path to extract the archive file
    archive_format: Archive format to try for extracting the file.
      Options are 'auto', 'tar', 'zip', and None.
      'tar' includes tar, tar.gz, and tar.bz files.
      The default 'auto' is ['tar', 'zip'].
      None or an empty list will return no matches found.

  Returns:
    True if a match was found and an archive extraction was completed,
    False otherwise.

  Raises:
    ValueError: if `archive_format` names a format other than 'tar' or 'zip'.
  """
  if archive_format is None:
    return False
  if archive_format == 'auto':
    archive_format = ['tar', 'zip']
  if isinstance(archive_format, str):
    archive_format = [archive_format]

  # Each supported format maps to its (open function, format detector) pair.
  handlers = {
      'tar': (tarfile.open, tarfile.is_tarfile),
      'zip': (zipfile.ZipFile, zipfile.is_zipfile),
  }
  unknown = [fmt for fmt in archive_format if fmt not in handlers]
  if unknown:
    raise ValueError('Unsupported archive format(s) {}; expected a subset of '
                     "['tar', 'zip'].".format(unknown))

  file_path = path_to_string(file_path)
  path = path_to_string(path)

  for archive_type in archive_format:
    open_fn, is_match_fn = handlers[archive_type]
    if not is_match_fn(file_path):
      continue
    with open_fn(file_path) as archive:
      try:
        if archive_type == 'tar':
          # tarfile does not sanitize member paths on its own, so filter out
          # members that would escape the destination directory.
          archive.extractall(path, members=_safe_tar_members(archive, path))
        else:
          # ZipFile.extractall strips absolute prefixes and `..` components.
          archive.extractall(path)
      except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
        # NOTE(review): `path` is the whole destination directory (for
        # `get_file` it is the cache subdirectory), so a failed extraction
        # removes everything under it, not only the partially extracted files.
        if os.path.exists(path):
          if os.path.isfile(path):
            os.remove(path)
          else:
            shutil.rmtree(path)
        raise
    return True
  return False


@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  """Downloads a file from a URL if it is not already in the cache.

  By default the file at the url `origin` is downloaded to the
  cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
  and given the filename `fname`. The final location of a file
  `example.txt` would therefore be `~/.keras/datasets/example.txt`.

  Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
  Passing a hash will verify the file after download. The command line
  programs `shasum` and `sha256sum` can compute the hash.

  Example:

  ```python
  path_to_downloaded_file = tf.keras.utils.get_file(
      "flower_photos",
      "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
      untar=True)
  ```

  Args:
    fname: Name of the file. If an absolute path `/path/to/file.txt` is
      specified the file will be saved at that location.
    origin: Original URL of the file.
    untar: Deprecated in favor of `extract` argument.
      boolean, whether the file should be decompressed
    md5_hash: Deprecated in favor of `file_hash` argument.
      md5 hash of the file for verification
    file_hash: The expected hash string of the file after download.
      The sha256 and md5 hash algorithms are both supported.
    cache_subdir: Subdirectory under the Keras cache dir where the file is
      saved. If an absolute path `/path/to/folder` is
      specified the file will be saved at that location.
    hash_algorithm: Select the hash algorithm to verify the file.
      options are `'md5'`, `'sha256'`, and `'auto'`.
      The default 'auto' detects the hash algorithm in use.
    extract: True tries extracting the file as an Archive, like tar or zip.
    archive_format: Archive format to try for extracting the file.
      Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
      `'tar'` includes tar, tar.gz, and tar.bz files.
      The default `'auto'` corresponds to `['tar', 'zip']`.
      None or an empty list will return no matches found.
    cache_dir: Location to store cached files, when None it
      defaults to the default directory `~/.keras/`.

  Returns:
    Path to the downloaded file

  Raises:
    Exception: If fetching `origin` fails with an HTTP or URL error.
    ValueError: If `file_hash` is given and the freshly downloaded file does
      not match it, or if `extract` is set and `archive_format` names an
      unsupported format.
  """
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
  if not os.access(datadir_base, os.W_OK):
    # Fall back to a temporary location when the cache dir is not writable.
    datadir_base = os.path.join('/tmp', '.keras')
  datadir = os.path.join(datadir_base, cache_subdir)
  _makedirs_exist_ok(datadir)

  fname = path_to_string(fname)

  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      # The first callback creates the progress bar; later ones advance it.
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except urllib.error.HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg)) from e
      except urllib.error.URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason)) from e
    except (Exception, KeyboardInterrupt):
      # Remove any partial download so it is not mistaken for a cached copy.
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

    # A supplied hash must also verify the fresh download, not only a
    # previously cached copy; a mismatch means a corrupted or changed file.
    if file_hash is not None and not validate_file(
        fpath, file_hash, algorithm=hash_algorithm):
      raise ValueError(
          'Incomplete or corrupted file detected. The {} file hash does not '
          'match the provided value of {}.'.format(hash_algorithm, file_hash))

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath


def _makedirs_exist_ok(datadir):
  """Creates `datadir` and its parents; a no-op if it already exists."""
  os.makedirs(datadir, exist_ok=True)  # pylint: disable=unexpected-keyword-arg


def _resolve_hasher(algorithm, file_hash=None):
  """Returns a fresh hashlib hash object for `algorithm`.

  `'sha256'` selects SHA-256. `'auto'` selects SHA-256 when `file_hash` is
  64 characters long (the length of a hex SHA-256 digest). Any other
  combination falls back to MD5.

  Args:
    algorithm: `'sha256'`, `'md5'` or `'auto'`.
    file_hash: optional expected hash, used only by `'auto'` detection.

  Returns:
    A new `hashlib` hash object.
  """
  if algorithm == 'sha256':
    return hashlib.sha256()

  if algorithm == 'auto' and file_hash is not None and len(file_hash) == 64:
    return hashlib.sha256()

  # This is used only for legacy purposes.
  return hashlib.md5()


def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
  """Calculates a file sha256 or md5 hash.

  Example:

  ```python
  _hash_file('/path/to/file.zip')
  'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
  ```

  Args:
    fpath: path to the file being validated
    algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`, or an
      already constructed hashlib hash object. A string `'auto'` resolves to
      md5 here because no expected hash is available to detect from.
    chunk_size: Bytes to read at a time, important for large files.

  Returns:
    The file hash
  """
  if isinstance(algorithm, str):
    hasher = _resolve_hasher(algorithm)
  else:
    hasher = algorithm

  with open(fpath, 'rb') as fpath_file:
    for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
      hasher.update(chunk)

  return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
  """Validates a file against a sha256 or md5 hash.

  Args:
    fpath: path to the file being validated
    file_hash: The expected hash string of the file.
      The sha256 and md5 hash algorithms are both supported.
    algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
      The default 'auto' detects the hash algorithm in use.
    chunk_size: Bytes to read at a time, important for large files.

  Returns:
    Whether the file is valid
  """
  hasher = _resolve_hasher(algorithm, file_hash)
  return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)


class ThreadsafeIter(object):
  """Wrap an iterator with a lock and propagate exceptions to all threads."""

  def __init__(self, it):
    self.it = it
    self.lock = threading.Lock()

    # After a generator throws an exception all subsequent next() calls raise a
    # StopIteration Exception. This, however, presents an issue when mixing
    # generators and threading because it means the order of retrieval need not
    # match the order in which the generator was called. This can make it appear
    # that a generator exited normally when in fact the terminating exception is
    # just in a different thread. In order to provide thread safety, once
    # self.it has thrown an exception we continue to throw the same exception.
    self._exception = None

  def __iter__(self):
    return self

  def next(self):
    return self.__next__()

  def __next__(self):
    with self.lock:
      if self._exception:
        raise self._exception  # pylint: disable=raising-bad-type

      try:
        return next(self.it)
      except Exception as e:
        self._exception = e
        raise


def threadsafe_generator(f):
  """Decorator that wraps the iterator returned by `f` in `ThreadsafeIter`."""

  @functools.wraps(f)
  def g(*a, **kw):
    return ThreadsafeIter(f(*a, **kw))

  return g


@keras_export('keras.utils.Sequence')
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
  If you want to modify your dataset between epochs you may implement
  `on_epoch_end`.
  The method `__getitem__` should return a complete batch.

  Notes:

  `Sequence`s are a safer way to do multiprocessing. This structure guarantees
  that the network will only train once
  on each sample per epoch which is not the case with generators.

  Examples:

  ```python
  from skimage.io import imread
  from skimage.transform import resize
  import numpy as np
  import math

  # Here, `x_set` is list of path to the images
  # and `y_set` are the associated classes.

  class CIFAR10Sequence(Sequence):

      def __init__(self, x_set, y_set, batch_size):
          self.x, self.y = x_set, y_set
          self.batch_size = batch_size

      def __len__(self):
          return math.ceil(len(self.x) / self.batch_size)

      def __getitem__(self, idx):
          batch_x = self.x[idx * self.batch_size:(idx + 1) *
          self.batch_size]
          batch_y = self.y[idx * self.batch_size:(idx + 1) *
          self.batch_size]

          return np.array([
              resize(imread(file_name), (200, 200))
                 for file_name in batch_x]), np.array(batch_y)
  ```
  """

  @abstractmethod
  def __getitem__(self, index):
    """Gets batch at position `index`.

    Args:
      index: position of the batch in the Sequence.

    Returns:
      A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Number of batches in the Sequence.

    Returns:
      The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Method called at the end of every epoch.
    """
    pass

  def __iter__(self):
    """Creates a generator that iterates over the Sequence once, in order."""
    for i in range(len(self)):
      yield self[i]


def iter_sequence_infinite(seq):
  """Iterates indefinitely over a Sequence.

  Args:
    seq: `Sequence` instance.

  Yields:
    Batches of data from the `Sequence`.
  """
  while True:
    for item in seq:
      yield item


# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
# We use a Value to provide unique id to different processes.
_SEQUENCE_COUNTER = None


# Because multiprocessing pools are inherently unsafe, starting from a clean
# state can be essential to avoiding deadlocks. In order to accomplish this, we
# need to be able to check on the status of Pools that we create.
_DATA_POOLS = weakref.WeakSet()
_WORKER_ID_QUEUE = None  # Only created if needed.
_WORKER_IDS = set()
_FORCE_THREADPOOL = False
_FORCE_THREADPOOL_LOCK = threading.RLock()


def dont_use_multiprocessing_pool(f):
  """Decorator that makes `get_pool_class` return a thread pool during `f`.

  While the wrapped function runs, `_FORCE_THREADPOOL` is set, so
  `get_pool_class` returns `multiprocessing.dummy.Pool` even when
  multiprocessing was requested. The previous value is restored when the call
  finishes, including when it raises.

  Args:
    f: the function to wrap.

  Returns:
    The wrapped function.
  """

  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    global _FORCE_THREADPOOL
    with _FORCE_THREADPOOL_LOCK:
      old_force_threadpool, _FORCE_THREADPOOL = _FORCE_THREADPOOL, True
      try:
        return f(*args, **kwargs)
      finally:
        # Restore even on error so one failing call cannot leave every later
        # pool forced onto threads.
        _FORCE_THREADPOOL = old_force_threadpool

  return wrapped


def get_pool_class(use_multiprocessing):
  """Returns the pool class to use for worker pools.

  Args:
    use_multiprocessing: whether process-based workers are requested.

  Returns:
    `multiprocessing.Pool` when `use_multiprocessing` is truthy and
    `_FORCE_THREADPOOL` is not set; otherwise `multiprocessing.dummy.Pool`.
  """
  if not use_multiprocessing or _FORCE_THREADPOOL:
    return multiprocessing.dummy.Pool  # ThreadPool
  return multiprocessing.Pool


def get_worker_id_queue():
  """Lazily create the queue to track worker ids."""
  global _WORKER_ID_QUEUE
  if _WORKER_ID_QUEUE is None:
    _WORKER_ID_QUEUE = multiprocessing.Queue()
  return _WORKER_ID_QUEUE


def init_pool(seqs):
  """Pool initializer that installs `seqs` as the worker's shared sequences."""
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = seqs


def get_index(uid, i):
  """Get the value from the Sequence `uid` at index `i`.

  To allow multiple Sequences to be used at the same time, we use `uid` to
  get a specific one. A single Sequence would cause the validation to
  overwrite the training Sequence.

  Args:
    uid: int, Sequence identifier
    i: index

  Returns:
    The value at index `i`.
  """
  return _SHARED_SEQUENCES[uid][i]


@keras_export('keras.utils.SequenceEnqueuer')
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  The task of an Enqueuer is to use parallelism to speed up preprocessing.
  This is done with processes or threads.

  Example:

  ```python
      enqueuer = SequenceEnqueuer(...)
      enqueuer.start()
      datas = enqueuer.get()
      for data in datas:
          # Use the inputs; training, evaluating, predicting.
          # ... stop sometime.
      enqueuer.stop()
  ```

  The `enqueuer.get()` should be an infinite stream of data.
  """

  def __init__(self, sequence,
               use_multiprocessing=False):
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # In this case the OS does not allow us to use
        # multiprocessing. We resort to an int
        # for enqueuer indexing.
        _SEQUENCE_COUNTER = 0

    if isinstance(_SEQUENCE_COUNTER, int):
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1
    else:
      # Doing Multiprocessing.Value += x is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1

    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    return self.stop_signal is not None and not self.stop_signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Args:
      workers: Number of workers.
      max_queue_size: queue size
        (when full, workers could block on `put()`)
    """
    if self.use_multiprocessing:
      self.executor_fn = self._get_executor_init(workers)
    else:
      # We do not need the init since it's threads.
      self.executor_fn = lambda _: get_pool_class(False)(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Sends current Iterable to all workers."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and waits for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    Args:
      timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    # Drop pending results and wake a producer blocked on a full queue so the
    # run thread can observe the stop signal and exit.
    with self.queue.mutex:
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    _SHARED_SEQUENCES[self.uid] = None

  def __del__(self):
    if self.is_running():
      self.stop()

  @abstractmethod
  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
      Function, a Function to initialize the pool
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Returns:
      Generator yielding tuples `(inputs, targets)`
      or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError


@keras_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds an Enqueuer from a Sequence.

  Args:
    sequence: A `tf.keras.utils.data_utils.Sequence` object.
    use_multiprocessing: use multiprocessing if True, otherwise threading
    shuffle: whether to shuffle the data at the beginning of each epoch
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
      Function, a Function to initialize the pool
    """
    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, None, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _wait_queue(self):
    """Wait for the queue to be empty."""
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(sequence)

      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for i in sequence:
          if self.stop_signal.is_set():
            return

          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.
    """
    while self.is_running():
      try:
        inputs = self.queue.get(block=True, timeout=5).get()
        if self.is_running():
          self.queue.task_done()
        if inputs is not None:
          yield inputs
      except queue.Empty:
        pass
      except Exception:  # pylint: disable=broad-except
        self.stop()
        raise


def init_pool_generator(gens, random_seed=None, id_queue=None):
  """Initializer function for pool workers.

  Args:
    gens: State which should be made available to worker processes.
    random_seed: An optional value with which to seed child processes.
    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
      that a worker process was created by Keras and can be terminated using
      the cleanup_all_keras_forkpools utility.
  """
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = gens

  worker_proc = multiprocessing.current_process()

  # name isn't used for anything, but setting a more descriptive name is helpful
  # when diagnosing orphaned processes.
  worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)

  if random_seed is not None:
    np.random.seed(random_seed + worker_proc.ident)

  if id_queue is not None:
    # If a worker dies during init, the pool will just create a replacement.
    id_queue.put(worker_proc.ident, block=True, timeout=0.1)


def next_sample(uid):
  """Gets the next value from the generator `uid`.

  To allow multiple generators to be used at the same time, we use `uid` to
  get a specific one. A single generator would cause the validation to
  overwrite the training generator.

  Args:
    uid: int, generator identifier

  Returns:
    The next value of generator `uid`.
  """
  return next(_SHARED_SEQUENCES[uid])


@keras_export('keras.utils.GeneratorEnqueuer')
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator.

  The provided generator can be finite. When it is exhausted, `get()` stops
  after yielding the results that were already queued and completed
  successfully.

  Args:
    generator: a generator function which yields data
    use_multiprocessing: use multiprocessing if True, otherwise threading
    random_seed: Initial seed for workers,
      will be incremented by one for each worker.
  """

  def __init__(self, generator,
               use_multiprocessing=False,
               random_seed=None):
    super(GeneratorEnqueuer, self).__init__(generator, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
      A Function to initialize the pool
    """
    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, self.random_seed, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool
    return pool_fn

  def _run(self):
    """Submits request to the executor and queue the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while True:
        if self.stop_signal.is_set():
          return

        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)

  def get(self):
    """Creates a generator to extract data from the queue.

    Skip the data if it is `None`.

    Yields:
      The next element in the queue, i.e. a tuple
      `(inputs, targets)` or
      `(inputs, targets, sample_weights)`.

    Raises:
      RuntimeError: if a non-thread-safe generator is used from several
        threads (detected from the "generator already executing" error).
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators
      last_ones = []
      while self.queue.qsize() > 0:
        last_ones.append(self.queue.get(block=True))
      # Wait for them to complete
      for f in last_ones:
        f.wait()
      # Keep the good ones
      last_ones = [future.get() for future in last_ones if future.successful()]
      for inputs in last_ones:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`. ') from e
      raise