# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Utilities for file download and caching."""

from abc import abstractmethod
from contextlib import closing
import functools
import hashlib
import multiprocessing
import multiprocessing.dummy
import os
import queue
import random
import shutil
import sys  # pylint: disable=unused-import
import tarfile
import threading
import time
import typing
import urllib
import weakref
import zipfile

import numpy as np

from tensorflow.python.framework import ops
from six.moves.urllib.request import urlopen
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.util.tf_export import keras_export

# Required to support google internal urlretrieve
if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    """Replacement for `urlretrieve` for Python 2.

    Under Python 2, `urlretrieve` relies on `FancyURLopener` from the legacy
    `urllib` module, which is known to have issues with proxy management.

    Args:
        url: url to retrieve.
        filename: where to store the retrieved data locally.
        reporthook: a hook function that will be called once on establishment of
          the network connection and once after each block read thereafter. The
          hook will be passed three arguments: a count of blocks transferred so
          far, a block size in bytes, and the total size of the file.
        data: `data` argument passed to `urlopen`.
    """

    def chunk_read(response, chunk_size=8192, reporthook=None):
      content_length = response.info().get('Content-Length')
      total_size = -1
      if content_length is not None:
        total_size = int(content_length.strip())
      count = 0
      while True:
        chunk = response.read(chunk_size)
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        if chunk:
          yield chunk
        else:
          break

    response = urlopen(url, data)
    with open(filename, 'wb') as fd:
      for chunk in chunk_read(response, reporthook=reporthook):
        fd.write(chunk)
else:
  from urllib.request import urlretrieve  # pylint: disable=g-importing-member


def is_generator_or_sequence(x):
  """Check if `x` is a Keras generator type."""
  builtin_iterators = (str, list, tuple, dict, set, frozenset)
  if isinstance(x, (ops.Tensor, np.ndarray) + builtin_iterators):
    return False
  return (tf_inspect.isgenerator(x) or
          isinstance(x, Sequence) or
          isinstance(x, typing.Iterator))


def _extract_archive(file_path, path='.', archive_format='auto'):
  """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.

  Args:
      file_path: path to the archive file
      path: path to extract the archive file
      archive_format: Archive format to try for extracting the file.
          Options are 'auto', 'tar', 'zip', and None.
          'tar' includes tar, tar.gz, and tar.bz files.
          The default 'auto' is ['tar', 'zip'].
          None or an empty list will return no matches found.

  Returns:
      True if a match was found and an archive extraction was completed,
      False otherwise.
  """
  if archive_format is None:
    return False
  if archive_format == 'auto':
    archive_format = ['tar', 'zip']
  if isinstance(archive_format, str):
    archive_format = [archive_format]

  file_path = path_to_string(file_path)
  path = path_to_string(path)

  for archive_type in archive_format:
    if archive_type == 'tar':
      open_fn = tarfile.open
      is_match_fn = tarfile.is_tarfile
    elif archive_type == 'zip':
      open_fn = zipfile.ZipFile
      is_match_fn = zipfile.is_zipfile
    else:
      # Unrecognized format strings are skipped rather than raising.
      continue

    if is_match_fn(file_path):
      with open_fn(file_path) as archive:
        try:
          archive.extractall(path)
        except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
          if os.path.exists(path):
            if os.path.isfile(path):
              os.remove(path)
            else:
              shutil.rmtree(path)
          raise
      return True
  return False


@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  """Downloads a file from a URL if it is not already in the cache.

  By default the file at the url `origin` is downloaded to the
  cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
  and given the filename `fname`. The final location of a file
  `example.txt` would therefore be `~/.keras/datasets/example.txt`.

  Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
  Passing a hash will verify the file after download. The command line
  programs `shasum` and `sha256sum` can compute the hash.

  Example:

  ```python
  path_to_downloaded_file = tf.keras.utils.get_file(
      "flower_photos",
      "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
      untar=True)
  ```

  Args:
      fname: Name of the file. If an absolute path `/path/to/file.txt` is
          specified the file will be saved at that location.
      origin: Original URL of the file.
      untar: Deprecated in favor of the `extract` argument.
          Boolean, whether the file should be decompressed.
      md5_hash: Deprecated in favor of the `file_hash` argument.
          md5 hash of the file for verification.
      file_hash: The expected hash string of the file after download.
          The sha256 and md5 hash algorithms are both supported.
      cache_subdir: Subdirectory under the Keras cache dir where the file is
          saved. If an absolute path `/path/to/folder` is
          specified the file will be saved at that location.
      hash_algorithm: Select the hash algorithm to verify the file.
          Options are `'md5'`, `'sha256'`, and `'auto'`.
          The default `'auto'` detects the hash algorithm in use.
      extract: True tries extracting the file as an archive, like tar or zip.
      archive_format: Archive format to try for extracting the file.
          Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
          `'tar'` includes tar, tar.gz, and tar.bz files.
          The default `'auto'` corresponds to `['tar', 'zip']`.
          `None` or an empty list will return no matches found.
      cache_dir: Location to store cached files; when None it
          defaults to the default directory `~/.keras/`.

  Returns:
      Path to the downloaded file
  """
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
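  # If the chosen cache directory is not writable (for example, a read-only
  # home directory), fall back to a `.keras` cache under `/tmp`.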
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
  datadir = os.path.join(datadir_base, cache_subdir)
  _makedirs_exist_ok(datadir)

  fname = path_to_string(fname)

  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        urlretrieve(origin, fpath, dl_progress)
      except urllib.error.HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except urllib.error.URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  return fpath


def _makedirs_exist_ok(datadir):
  os.makedirs(datadir, exist_ok=True)  # pylint: disable=unexpected-keyword-arg


def _resolve_hasher(algorithm, file_hash=None):
  """Returns hash algorithm as hashlib function."""
  if algorithm == 'sha256':
    return hashlib.sha256()

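  # A sha256 hex digest is 64 characters long, so a 64-character `file_hash`
  # implies sha256; anything else falls through to md5 below.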
  if algorithm == 'auto' and file_hash is not None and len(file_hash) == 64:
    return hashlib.sha256()

  # This is used only for legacy purposes.
  return hashlib.md5()


def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
  """Calculates a file sha256 or md5 hash.

  Example:

  ```python
  _hash_file('/path/to/file.zip')
  'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
  ```

  Args:
      fpath: path to the file being validated
      algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
          Defaults to `'sha256'`; `'auto'` detects the hash algorithm in use.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      The file hash
  """
  if isinstance(algorithm, str):
    hasher = _resolve_hasher(algorithm)
  else:
    hasher = algorithm

  with open(fpath, 'rb') as fpath_file:
    for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
      hasher.update(chunk)

  return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
  """Validates a file against a sha256 or md5 hash.

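  Example:

  ```python
  # A minimal usage sketch; the path is a placeholder and the hash shown is the
  # sha256 of empty content, borrowed from the `_hash_file` example above.
  validate_file(
      '/path/to/file.zip',
      'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855')
  ```
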
  Args:
      fpath: path to the file being validated
      file_hash: The expected hash string of the file.
          The sha256 and md5 hash algorithms are both supported.
      algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
          The default 'auto' detects the hash algorithm in use.
      chunk_size: Bytes to read at a time, important for large files.

  Returns:
      Whether the file is valid
  """
  hasher = _resolve_hasher(algorithm, file_hash)

  return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)


class ThreadsafeIter(object):
  """Wrap an iterator with a lock and propagate exceptions to all threads."""

  def __init__(self, it):
    self.it = it
    self.lock = threading.Lock()

    # After a generator throws an exception all subsequent next() calls raise a
    # StopIteration Exception. This, however, presents an issue when mixing
    # generators and threading because it means the order of retrieval need not
    # match the order in which the generator was called. This can make it appear
    # that a generator exited normally when in fact the terminating exception is
    # just in a different thread. In order to provide thread safety, once
    # self.it has thrown an exception we continue to throw the same exception.
    self._exception = None

  def __iter__(self):
    return self

  def next(self):
    return self.__next__()

  def __next__(self):
    with self.lock:
      if self._exception:
        raise self._exception  # pylint: disable=raising-bad-type

      try:
        return next(self.it)
      except Exception as e:
        self._exception = e
        raise


def threadsafe_generator(f):
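  """Decorator wrapping a generator function's iterator in a `ThreadsafeIter`.

  Calls to `next()` on the returned iterator are serialized by a lock, and an
  exception raised by the underlying generator is re-raised to every thread.
  """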

  @functools.wraps(f)
  def g(*a, **kw):
    return ThreadsafeIter(f(*a, **kw))

  return g


@keras_export('keras.utils.Sequence')
class Sequence(object):
  """Base object for fitting to a sequence of data, such as a dataset.

  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
  If you want to modify your dataset between epochs, you may implement
  `on_epoch_end`.
  The method `__getitem__` should return a complete batch.

  Notes:

  `Sequence` is a safer way to do multiprocessing. This structure guarantees
  that the network will only train once on each sample per epoch, which is not
  the case with generators.

  Examples:

  ```python
  from skimage.io import imread
  from skimage.transform import resize
  import numpy as np
  import math

  # Here, `x_set` is a list of paths to the images
  # and `y_set` are the associated classes.

  class CIFAR10Sequence(Sequence):

      def __init__(self, x_set, y_set, batch_size):
          self.x, self.y = x_set, y_set
          self.batch_size = batch_size

      def __len__(self):
          return math.ceil(len(self.x) / self.batch_size)

      def __getitem__(self, idx):
          batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
          batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

          return np.array([
              resize(imread(file_name), (200, 200))
                 for file_name in batch_x]), np.array(batch_y)
  ```
  """

  @abstractmethod
  def __getitem__(self, index):
    """Gets batch at position `index`.

    Args:
        index: position of the batch in the Sequence.

    Returns:
        A batch
    """
    raise NotImplementedError

  @abstractmethod
  def __len__(self):
    """Number of batches in the Sequence.

    Returns:
        The number of batches in the Sequence.
    """
    raise NotImplementedError

  def on_epoch_end(self):
    """Method called at the end of every epoch."""
    pass

  def __iter__(self):
    """Creates a generator that iterates over the Sequence."""
    for item in (self[i] for i in range(len(self))):
      yield item


def iter_sequence_infinite(seq):
  """Iterates indefinitely over a Sequence.

  Args:
    seq: `Sequence` instance.

  Yields:
    Batches of data from the `Sequence`.
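
  Example:

  ```python
  # Illustrative only; `seq` stands in for any `Sequence` instance defined
  # elsewhere.
  infinite_batches = iter_sequence_infinite(seq)
  first_batch = next(infinite_batches)
  ```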
  """
  while True:
    for item in seq:
      yield item


# Global variables to be shared across processes
_SHARED_SEQUENCES = {}
# We use a Value to provide unique id to different processes.
_SEQUENCE_COUNTER = None


# Because multiprocessing pools are inherently unsafe, starting from a clean
# state can be essential to avoiding deadlocks. In order to accomplish this, we
# need to be able to check on the status of Pools that we create.
_DATA_POOLS = weakref.WeakSet()
_WORKER_ID_QUEUE = None  # Only created if needed.
_WORKER_IDS = set()
_FORCE_THREADPOOL = False
_FORCE_THREADPOOL_LOCK = threading.RLock()


def dont_use_multiprocessing_pool(f):
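  """Decorator forcing `f` (and its callees) to use a thread pool.

  While `f` runs, `get_pool_class` returns the thread-based pool even if
  `use_multiprocessing=True` is requested.
  """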
  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    with _FORCE_THREADPOOL_LOCK:
      global _FORCE_THREADPOOL
      old_force_threadpool, _FORCE_THREADPOOL = _FORCE_THREADPOOL, True
      out = f(*args, **kwargs)
      _FORCE_THREADPOOL = old_force_threadpool
      return out
  return wrapped


def get_pool_class(use_multiprocessing):
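  """Returns the pool class to use for worker pools.

  Returns `multiprocessing.Pool` when `use_multiprocessing` is True and a
  thread pool has not been forced via `dont_use_multiprocessing_pool`;
  otherwise returns the thread-based `multiprocessing.dummy.Pool`.
  """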
  global _FORCE_THREADPOOL
  if not use_multiprocessing or _FORCE_THREADPOOL:
    return multiprocessing.dummy.Pool  # ThreadPool
  return multiprocessing.Pool


def get_worker_id_queue():
  """Lazily create the queue to track worker ids."""
  global _WORKER_ID_QUEUE
  if _WORKER_ID_QUEUE is None:
    _WORKER_ID_QUEUE = multiprocessing.Queue()
  return _WORKER_ID_QUEUE


def init_pool(seqs):
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = seqs


def get_index(uid, i):
  """Get the value from the Sequence `uid` at index `i`.

  To allow multiple Sequences to be used at the same time, we use `uid` to
  get a specific one. A single Sequence would cause the validation to
  overwrite the training Sequence.

  Args:
      uid: int, Sequence identifier
      i: index

  Returns:
      The value at index `i`.
  """
  return _SHARED_SEQUENCES[uid][i]


@keras_export('keras.utils.SequenceEnqueuer')
class SequenceEnqueuer(object):
  """Base class to enqueue inputs.

  The task of an Enqueuer is to use parallelism to speed up preprocessing.
  This is done with processes or threads.

  Example:

  ```python
      enqueuer = SequenceEnqueuer(...)
      enqueuer.start()
      datas = enqueuer.get()
      for data in datas:
          # Use the inputs; training, evaluating, predicting.
          # ... stop sometime.
      enqueuer.stop()
  ```

  `enqueuer.get()` should return an infinite stream of data.
  """

  def __init__(self, sequence,
               use_multiprocessing=False):
    self.sequence = sequence
    self.use_multiprocessing = use_multiprocessing

    global _SEQUENCE_COUNTER
    if _SEQUENCE_COUNTER is None:
      try:
        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
      except OSError:
        # In this case the OS does not allow us to use
        # multiprocessing. We resort to an int
        # for enqueuer indexing.
        _SEQUENCE_COUNTER = 0

    if isinstance(_SEQUENCE_COUNTER, int):
      self.uid = _SEQUENCE_COUNTER
      _SEQUENCE_COUNTER += 1
    else:
      # Doing `multiprocessing.Value += x` is not process-safe.
      with _SEQUENCE_COUNTER.get_lock():
        self.uid = _SEQUENCE_COUNTER.value
        _SEQUENCE_COUNTER.value += 1

    self.workers = 0
    self.executor_fn = None
    self.queue = None
    self.run_thread = None
    self.stop_signal = None

  def is_running(self):
    return self.stop_signal is not None and not self.stop_signal.is_set()

  def start(self, workers=1, max_queue_size=10):
    """Starts the handler's workers.

    Args:
        workers: Number of workers.
        max_queue_size: queue size
            (when full, workers could block on `put()`)
    """
    if self.use_multiprocessing:
      self.executor_fn = self._get_executor_init(workers)
    else:
      # We do not need the init since it's threads.
      self.executor_fn = lambda _: get_pool_class(False)(workers)
    self.workers = workers
    self.queue = queue.Queue(max_queue_size)
    self.stop_signal = threading.Event()
    self.run_thread = threading.Thread(target=self._run)
    self.run_thread.daemon = True
    self.run_thread.start()

  def _send_sequence(self):
    """Sends current Iterable to all workers."""
    # For new processes that may spawn
    _SHARED_SEQUENCES[self.uid] = self.sequence

  def stop(self, timeout=None):
    """Stops running threads and waits for them to exit, if necessary.

    Should be called by the same thread which called `start()`.

    Args:
        timeout: maximum time to wait on `thread.join()`
    """
    self.stop_signal.set()
    with self.queue.mutex:
      self.queue.queue.clear()
      self.queue.unfinished_tasks = 0
      self.queue.not_full.notify()
    self.run_thread.join(timeout)
    _SHARED_SEQUENCES[self.uid] = None

  def __del__(self):
    if self.is_running():
      self.stop()

  @abstractmethod
  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    raise NotImplementedError

  @abstractmethod
  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
        workers: Number of workers.

    Returns:
        Function, a function to initialize the pool.
    """
    raise NotImplementedError

  @abstractmethod
  def get(self):
    """Creates a generator to extract data from the queue.

    Skips the data if it is `None`.

    Returns:
        Generator yielding tuples `(inputs, targets)`
            or `(inputs, targets, sample_weights)`.
    """
    raise NotImplementedError


@keras_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
  """Builds an Enqueuer from a Sequence.

  Args:
      sequence: A `tf.keras.utils.data_utils.Sequence` object.
      use_multiprocessing: use multiprocessing if True, otherwise threading
      shuffle: whether to shuffle the data at the beginning of each epoch
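
  Example:

  ```python
  # A minimal sketch; `MySequence` stands in for a user-defined
  # `tf.keras.utils.Sequence` subclass (it is not defined in this module).
  seq = MySequence()
  enqueuer = OrderedEnqueuer(seq, use_multiprocessing=False, shuffle=True)
  enqueuer.start(workers=2, max_queue_size=10)
  batches = enqueuer.get()  # Infinite generator over the batches of `seq`.
  for _ in range(len(seq)):  # One epoch.
      batch = next(batches)
  enqueuer.stop()
  ```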
  """

  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
    self.shuffle = shuffle

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
        workers: Number of workers.

    Returns:
        Function, a function to initialize the pool.
    """
    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, None, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool

    return pool_fn

  def _wait_queue(self):
    """Wait for the queue to be empty."""
    while True:
      time.sleep(0.1)
      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
        return

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence
    while True:
      if self.shuffle:
        random.shuffle(sequence)

      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
        for i in sequence:
          if self.stop_signal.is_set():
            return

          self.queue.put(
              executor.apply_async(get_index, (self.uid, i)), block=True)

        # Done with the current epoch, waiting for the final batches
        self._wait_queue()

        if self.stop_signal.is_set():
          # We're done
          return

      # Call the internal on epoch end.
      self.sequence.on_epoch_end()
      self._send_sequence()  # Update the pool

  def get(self):
    """Creates a generator to extract data from the queue.

    Skips the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    while self.is_running():
      try:
        inputs = self.queue.get(block=True, timeout=5).get()
        if self.is_running():
          self.queue.task_done()
        if inputs is not None:
          yield inputs
      except queue.Empty:
        pass
      except Exception as e:  # pylint: disable=broad-except
        self.stop()
        raise e


def init_pool_generator(gens, random_seed=None, id_queue=None):
  """Initializer function for pool workers.

  Args:
    gens: State which should be made available to worker processes.
    random_seed: An optional value with which to seed child processes.
    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
      that a worker process was created by Keras and can be terminated using
      the cleanup_all_keras_forkpools utility.
  """
  global _SHARED_SEQUENCES
  _SHARED_SEQUENCES = gens

  worker_proc = multiprocessing.current_process()

  # name isn't used for anything, but setting a more descriptive name is helpful
  # when diagnosing orphaned processes.
  worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)

  if random_seed is not None:
    np.random.seed(random_seed + worker_proc.ident)

  if id_queue is not None:
    # If a worker dies during init, the pool will just create a replacement.
    id_queue.put(worker_proc.ident, block=True, timeout=0.1)


def next_sample(uid):
  """Gets the next value from the generator `uid`.

  To allow multiple generators to be used at the same time, we use `uid` to
  get a specific one. A single generator would cause the validation to
  overwrite the training generator.

  Args:
      uid: int, generator identifier

  Returns:
      The next value of generator `uid`.
  """
  return next(_SHARED_SEQUENCES[uid])


@keras_export('keras.utils.GeneratorEnqueuer')
class GeneratorEnqueuer(SequenceEnqueuer):
  """Builds a queue out of a data generator.

  The provided generator can be finite, in which case the class will throw
  a `StopIteration` exception.

  Args:
      generator: a generator function which yields data
      use_multiprocessing: use multiprocessing if True, otherwise threading
      random_seed: Initial seed for workers,
          will be incremented by one for each worker.
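
  Example:

  ```python
  # A minimal sketch; `my_generator()` stands in for a user-supplied,
  # thread-safe generator yielding `(inputs, targets)` tuples.
  enqueuer = GeneratorEnqueuer(my_generator(), use_multiprocessing=False)
  enqueuer.start(workers=1, max_queue_size=10)
  batches = enqueuer.get()
  inputs, targets = next(batches)
  enqueuer.stop()
  ```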
  """

  def __init__(self, generator,
               use_multiprocessing=False,
               random_seed=None):
    super(GeneratorEnqueuer, self).__init__(generator, use_multiprocessing)
    self.random_seed = random_seed

  def _get_executor_init(self, workers):
    """Gets the Pool initializer for multiprocessing.

    Args:
      workers: Number of workers.

    Returns:
        A function to initialize the pool.
    """
    def pool_fn(seqs):
      pool = get_pool_class(True)(
          workers, initializer=init_pool_generator,
          initargs=(seqs, self.random_seed, get_worker_id_queue()))
      _DATA_POOLS.add(pool)
      return pool
    return pool_fn

  def _run(self):
    """Submits requests to the executor and queues the `Future` objects."""
    self._send_sequence()  # Share the initial generator
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
      while True:
        if self.stop_signal.is_set():
          return

        self.queue.put(
            executor.apply_async(next_sample, (self.uid,)), block=True)

  def get(self):
    """Creates a generator to extract data from the queue.

    Skips the data if it is `None`.

    Yields:
        The next element in the queue, i.e. a tuple
        `(inputs, targets)` or
        `(inputs, targets, sample_weights)`.
    """
    try:
      while self.is_running():
        inputs = self.queue.get(block=True).get()
        self.queue.task_done()
        if inputs is not None:
          yield inputs
    except StopIteration:
      # Special case for finite generators
      last_ones = []
      while self.queue.qsize() > 0:
        last_ones.append(self.queue.get(block=True))
      # Wait for them to complete
      for f in last_ones:
        f.wait()
      # Keep the good ones
      last_ones = [future.get() for future in last_ones if future.successful()]
      for inputs in last_ones:
        if inputs is not None:
          yield inputs
    except Exception as e:  # pylint: disable=broad-except
      self.stop()
      if 'generator already executing' in str(e):
        raise RuntimeError(
            'Your generator is NOT thread-safe. '
            'Keras requires a thread-safe generator when '
            '`use_multiprocessing=False, workers > 1`. ')
      raise e