# mypy: allow-untyped-defs
import inspect
import warnings
from typing import Any, List, Optional, Set
from typing_extensions import deprecated

import torch
from torch.utils.data.datapipes.iter.sharding import (
    _ShardingIterDataPipe,
    SHARDING_PRIORITIES,
)
from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps


__all__ = [
    "apply_random_seed",
    "apply_sharding",
    "apply_shuffle_seed",
    "apply_shuffle_settings",
    "get_all_graph_pipes",
]


def get_all_graph_pipes(graph: DataPipeGraph) -> List[DataPipe]:
    return _get_all_graph_pipes_helper(graph, set())


def _get_all_graph_pipes_helper(
    graph: DataPipeGraph, id_cache: Set[int]
) -> List[DataPipe]:
    results: List[DataPipe] = []
    for dp_id, (datapipe, sub_graph) in graph.items():
        if dp_id in id_cache:
            continue
        id_cache.add(dp_id)
        results.append(datapipe)
        results.extend(_get_all_graph_pipes_helper(sub_graph, id_cache))
    return results


def _is_sharding_datapipe(datapipe: DataPipe) -> bool:
    return isinstance(datapipe, _ShardingIterDataPipe) or (
        hasattr(datapipe, "apply_sharding")
        and inspect.ismethod(datapipe.apply_sharding)
    )


def apply_sharding(
    datapipe: DataPipe,
    num_of_instances: int,
    instance_id: int,
    sharding_group=SHARDING_PRIORITIES.DEFAULT,
) -> DataPipe:
    r"""
    Apply dynamic sharding over the ``sharding_filter`` DataPipe that has a method ``apply_sharding``.

    A ``RuntimeError`` will be raised when multiple ``sharding_filter`` instances are present in the same branch.
    """
    graph = traverse_dps(datapipe)

    def _helper(graph, prev_applied=None):
        for dp, sub_graph in graph.values():
            applied = None
            if _is_sharding_datapipe(dp):
                if prev_applied is not None:
                    raise RuntimeError(
                        "Sharding twice on a single pipeline is likely unintended and will cause data loss. "
                        f"Sharding already applied to {prev_applied} while trying to apply to {dp}"
                    )
                # For backward compatibility, only provide sharding_group if the method accepts it
                sig = inspect.signature(dp.apply_sharding)
                if len(sig.parameters) < 3:
                    dp.apply_sharding(num_of_instances, instance_id)
                else:
                    dp.apply_sharding(
                        num_of_instances, instance_id, sharding_group=sharding_group
                    )
                applied = dp
            if applied is None:
                applied = prev_applied
            _helper(sub_graph, applied)

    _helper(graph)

    return datapipe


def _is_shuffle_datapipe(datapipe: DataPipe) -> bool:
    return (
        hasattr(datapipe, "set_shuffle")
        and hasattr(datapipe, "set_seed")
        and inspect.ismethod(datapipe.set_shuffle)
        and inspect.ismethod(datapipe.set_seed)
    )


def apply_shuffle_settings(
    datapipe: DataPipe, shuffle: Optional[bool] = None
) -> DataPipe:
    r"""
    Traverse the graph of ``DataPipes`` to find and set the shuffle attribute.

    Apply the setting to each ``DataPipe`` that exposes the ``set_shuffle``
    and ``set_seed`` APIs.

    Args:
        datapipe: DataPipe that needs to set the shuffle attribute
        shuffle: Shuffle option (default: ``None``, which leaves the graph unchanged)
    """
    if shuffle is None:
        return datapipe

    graph = traverse_dps(datapipe)
    all_pipes = get_all_graph_pipes(graph)
    shufflers = [pipe for pipe in all_pipes if _is_shuffle_datapipe(pipe)]
    if not shufflers and shuffle:
        warnings.warn(
            "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
            "Be aware that the default buffer size might not be sufficient for your task."
        )
        datapipe = datapipe.shuffle()
        shufflers = [
            datapipe,
        ]  # type: ignore[list-item]

    for shuffler in shufflers:
        shuffler.set_shuffle(shuffle)

    return datapipe


@deprecated(
    "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in future releases. "
    "Please use `apply_random_seed` instead.",
    category=FutureWarning,
)
def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
    return apply_random_seed(datapipe, rng)


def _is_random_datapipe(datapipe: DataPipe) -> bool:
    return hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed)


def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe:
    r"""
    Traverse the graph of ``DataPipes`` to find random ``DataPipe`` objects with a ``set_seed`` API.

    Then set a random seed drawn from the provided RNG on each of those ``DataPipe`` objects.

    Args:
        datapipe: DataPipe that needs to set randomness
        rng: Random number generator to generate random seeds
    """
    graph = traverse_dps(datapipe)
    all_pipes = get_all_graph_pipes(graph)
    # Track DataPipe ids in a set so randomness is set at most once per DataPipe.
    # `id` is used because a DataPipe may be unhashable.
    cache = set()
    random_datapipes = []
    for pipe in all_pipes:
        if id(pipe) in cache:
            continue
        if _is_random_datapipe(pipe):
            random_datapipes.append(pipe)
        cache.add(id(pipe))

    for pipe in random_datapipes:
        random_seed = int(
            torch.empty((), dtype=torch.int64).random_(generator=rng).item()
        )
        pipe.set_seed(random_seed)

    return datapipe
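

# Minimal usage sketch: assumes the standard ``IterableWrapper`` DataPipe and its
# ``shuffle``/``sharding_filter`` functional forms, and shows one way the helpers
# above can be combined to shard, shuffle, and seed a small pipeline.
if __name__ == "__main__":
    from torch.utils.data.datapipes.iter import IterableWrapper

    # Build a pipeline that shuffles before the sharding point.
    pipe = IterableWrapper(range(10)).shuffle().sharding_filter()

    # Keep only the elements assigned to instance 0 out of 2 instances.
    pipe = apply_sharding(pipe, num_of_instances=2, instance_id=0)

    # Enable shuffling on every Shuffler found in the graph.
    pipe = apply_shuffle_settings(pipe, shuffle=True)

    # Seed every random DataPipe (here, the Shuffler) from a single generator.
    rng = torch.Generator()
    rng.manual_seed(0)
    pipe = apply_random_seed(pipe, rng)

    print(list(pipe))  # 5 of the 10 elements, in a seeded shuffled order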