# mypy: allow-untyped-defs
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

These need to be in global scope since Py2 doesn't support serializing
static methods.
"""

import os
import queue
import random
from dataclasses import dataclass
from typing import Optional, TYPE_CHECKING, Union

import torch
from torch._utils import ExceptionWrapper

from . import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling


if TYPE_CHECKING:
    from torch.utils.data import Dataset

if IS_WINDOWS:
    import ctypes
    from ctypes.wintypes import BOOL, DWORD, HANDLE

    # On Windows, the parent ID of the worker process remains unchanged when the manager process
    # is gone, and the only way to check it through the OS is to let the worker have a process handle
    # of the manager and ask if the process status has changed.
    class ManagerWatchdog:
        def __init__(self) -> None:
            self.manager_pid = os.getppid()

            # mypy cannot detect this code is windows only
            self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)  # type: ignore[attr-defined]
            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            self.kernel32.OpenProcess.restype = HANDLE
            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            self.kernel32.WaitForSingleObject.restype = DWORD

            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
            SYNCHRONIZE = 0x00100000
            self.manager_handle = self.kernel32.OpenProcess(
                SYNCHRONIZE, 0, self.manager_pid
            )

            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())  # type: ignore[attr-defined]

            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
                self.manager_dead = (
                    self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
                )
            return not self.manager_dead

else:

    class ManagerWatchdog:  # type: ignore[no-redef]
        def __init__(self) -> None:
            self.manager_pid = os.getppid()
            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                self.manager_dead = os.getppid() != self.manager_pid
            return not self.manager_dead


_worker_info: Optional["WorkerInfo"] = None


class WorkerInfo:
    id: int
    num_workers: int
    seed: int
    dataset: "Dataset"
    __initialized = False

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.__keys = tuple(kwargs.keys())
        self.__initialized = True

    def __setattr__(self, key, val):
        if self.__initialized:
            raise RuntimeError(
                f"Cannot assign attributes to {self.__class__.__name__} objects"
            )
        return super().__setattr__(key, val)

    def __repr__(self):
        items = []
        for k in self.__keys:
            items.append(f"{k}={getattr(self, k)}")
        return f"{self.__class__.__name__}({', '.join(items)})"


def get_worker_info() -> Optional[WorkerInfo]:
    r"""Returns the information about the current
    :class:`~torch.utils.data.DataLoader` iterator worker process.

    When called in a worker, this returns an object guaranteed to have the
    following attributes:

    * :attr:`id`: the current worker id.
    * :attr:`num_workers`: the total number of workers.
    * :attr:`seed`: the random seed set for the current worker. This value is
      determined by the main process RNG and the worker id.
      See :class:`~torch.utils.data.DataLoader`'s documentation for more details.
    * :attr:`dataset`: the copy of the dataset object in **this** process. Note
      that this will be a different object in a different process than the one
      in the main process.

    When called in the main process, this returns ``None``.

    .. note::
        When used in a :attr:`worker_init_fn` passed over to
        :class:`~torch.utils.data.DataLoader`, this method can be useful to
        set up each worker process differently, for instance, using ``worker_id``
        to configure the ``dataset`` object to only read a specific fraction of a
        sharded dataset, or using ``seed`` to seed other libraries used in dataset
        code.
    """
    return _worker_info
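

# Illustrative sketch only; not used by the worker machinery in this module.
# The note in `get_worker_info` above suggests using ``worker_id`` to shard a
# dataset and ``seed`` to re-seed other libraries from a ``worker_init_fn``.
# A user-side ``worker_init_fn`` along those lines could look like the function
# below, passed as ``DataLoader(..., worker_init_fn=...)``. The ``start`` and
# ``end`` attributes on the dataset copy are hypothetical stand-ins for
# whatever sharding handle a real dataset exposes.
def _example_shard_worker_init_fn(worker_id: int) -> None:
    info = get_worker_info()
    assert info is not None, "must run inside a DataLoader worker process"
    dataset = info.dataset  # the per-process copy of the dataset
    # Split a hypothetical [dataset.start, dataset.end) index range so that
    # each worker only reads its own contiguous shard.
    overall_start, overall_end = dataset.start, dataset.end
    per_worker = -(-(overall_end - overall_start) // info.num_workers)  # ceiling division
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)
    # Re-seed libraries used inside the dataset code with the per-worker seed.
    random.seed(info.seed)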


@dataclass(frozen=True)
class _IterableDatasetStopIteration:
    """Dummy class used to signal the end of an IterableDataset."""

    worker_id: int


@dataclass(frozen=True)
class _ResumeIteration:
    """Dummy class used to resume the fetching when worker reuse is enabled."""

    seed: Optional[int] = None


# The function `_generate_state` is adapted from `numpy.random.SeedSequence`
# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
# It's MIT licensed, here is the copyright:

# Copyright (c) 2015 Melissa E. O'Neill
# Copyright (c) 2019 NumPy Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


# This function generates an array of int32 as the seed for
# `numpy.random`, in order to prevent state collision due to same
# seed and algorithm for `numpy.random` and `random` modules.
# TODO: Implement `SeedSequence` like object for `torch.random`
def _generate_state(base_seed, worker_id):
    INIT_A = 0x43B0D7E5
    MULT_A = 0x931E8875
    INIT_B = 0x8B51F9DD
    MULT_B = 0x58F38DED
    MIX_MULT_L = 0xCA01F9DD
    MIX_MULT_R = 0x4973F715
    XSHIFT = 4 * 8 // 2
    MASK32 = 0xFFFFFFFF

    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
    pool = [0] * 4

    hash_const_A = INIT_A

    def hash(value):
        nonlocal hash_const_A
        value = (value ^ hash_const_A) & MASK32
        hash_const_A = (hash_const_A * MULT_A) & MASK32
        value = (value * hash_const_A) & MASK32
        value = (value ^ (value >> XSHIFT)) & MASK32
        return value

    def mix(x, y):
        result_x = (MIX_MULT_L * x) & MASK32
        result_y = (MIX_MULT_R * y) & MASK32
        result = (result_x - result_y) & MASK32
        result = (result ^ (result >> XSHIFT)) & MASK32
        return result

    # Add in the entropy to the pool.
    for i in range(len(pool)):
        pool[i] = hash(entropy[i])

    # Mix all bits together so late bits can affect earlier bits.
    for i_src in range(len(pool)):
        for i_dst in range(len(pool)):
            if i_src != i_dst:
                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))

    hash_const_B = INIT_B
    state = []
    for i_dst in range(4):
        data_val = pool[i_dst]
        data_val = (data_val ^ hash_const_B) & MASK32
        hash_const_B = (hash_const_B * MULT_B) & MASK32
        data_val = (data_val * hash_const_B) & MASK32
        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
        state.append(data_val)
    return state


def _worker_loop(
    dataset_kind,
    dataset,
    index_queue,
    data_queue,
    done_event,
    auto_collation,
    collate_fn,
    drop_last,
    base_seed,
    init_fn,
    worker_id,
    num_workers,
    persistent_workers,
    shared_seed,
):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.multiprocessing._set_thread_name("pt_data_worker")

        torch.set_num_threads(1)
        seed = base_seed + worker_id
        random.seed(seed)
        torch.manual_seed(seed)
        if HAS_NUMPY:
            np_seed = _generate_state(base_seed, worker_id)
            import numpy as np

            np.random.seed(np_seed)

        from torch.utils.data import IterDataPipe
        from torch.utils.data.graph_settings import apply_random_seed

        shared_rng = torch.Generator()
        if isinstance(dataset, IterDataPipe):
            assert shared_seed is not None
            shared_rng.manual_seed(shared_seed)
            dataset = apply_random_seed(dataset, shared_rng)

        global _worker_info
        _worker_info = WorkerInfo(
            id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
        )

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(
                dataset_kind, dataset, auto_collation, collate_fn, drop_last
            )
        except Exception:
            init_exception = ExceptionWrapper(
                where=f"in DataLoader worker process {worker_id}"
            )

        # In Iterable mode, some workers can exit earlier than others because the
        # IterableDataset behaves differently for different workers (each worker
        # operates on its own copy of the dataset; see the illustrative sketch at
        # the end of this module). When that happens, an
        # `_IterableDatasetStopIteration` object is sent over to the main process
        # with the ID of this worker, so that the main process won't send more
        # tasks to this worker, and will send `None` to this worker to properly
        # exit it.
        #
        # Note that we cannot set `done_event` from a worker as it is shared
        # among all processes. Instead, we set the `iteration_end` flag to
        # signify that the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing steps and just wait for
        # `None`.
        iteration_end = False

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if isinstance(r, _ResumeIteration):
                # Acknowledge the main process
                data_queue.put((r, None))
                iteration_end = False

                if isinstance(dataset, IterDataPipe):
                    assert r.seed is not None
                    shared_rng.manual_seed(r.seed)
                    dataset = apply_random_seed(dataset, shared_rng)

                # Recreate the fetcher for worker-reuse policy
                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind, dataset, auto_collation, collate_fn, drop_last
                )
                continue
            elif r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set, but the final signal (`None`) hasn't
                # arrived yet. Keep looping (skipping the processing steps)
                # until it does.
                continue
            idx, index = r
            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
                except Exception as e:
                    if (
                        isinstance(e, StopIteration)
                        and dataset_kind == _DatasetKind.Iterable
                    ):
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        # (1) to save future `next(...)` calls, and
                        # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # It is important that we don't store exc_info in a variable.
                        # `ExceptionWrapper` does the correct thing.
                        # See NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where=f"in DataLoader worker process {worker_id}"
                        )
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyway.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()
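

# Illustrative sketch only; not used by the worker machinery above. With an
# `IterableDataset`, every worker receives its own copy of the dataset, which
# is why different workers may exhaust their iterators at different times and
# why `_IterableDatasetStopIteration` exists. A user-side dataset typically
# splits its workload per worker inside ``__iter__`` by consulting
# ``get_worker_info()``. The factory below is hypothetical; it imports
# `torch.utils.data` lazily, mirroring the lazy imports in `_worker_loop`, to
# avoid an import cycle at module load time.
def _example_sharded_range_dataset(lo: int, hi: int):
    from torch.utils.data import IterableDataset

    class _RangeDataset(IterableDataset):
        def __init__(self, lo: int, hi: int) -> None:
            self.lo, self.hi = lo, hi

        def __iter__(self):
            info = get_worker_info()
            if info is None:
                # Single-process data loading: iterate over the full range.
                return iter(range(self.lo, self.hi))
            # Worker process: iterate only over this worker's contiguous slice.
            per_worker = -(-(self.hi - self.lo) // info.num_workers)  # ceiling division
            start = self.lo + info.id * per_worker
            return iter(range(start, min(start + per_worker, self.hi)))

    return _RangeDataset(lo, hi)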