# mypy: allow-untyped-defs
from enum import IntEnum
from typing import Dict, Sized, Tuple

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe


__all__ = [
    "SHARDING_PRIORITIES",
    "ShardingFilterIterDataPipe",
]


class SHARDING_PRIORITIES(IntEnum):
    # Ordering of sharding groups; when several groups are applied, the
    # highest-numbered group becomes the innermost (fastest-varying) shard.
    DEFAULT = 1
    DISTRIBUTED = 2
    MULTIPROCESSING = 3


class _ShardingIterDataPipe(IterDataPipe):
    """Abstract base for DataPipes that can be sharded via ``apply_sharding``."""

    def apply_sharding(
        self,
        num_of_instances: int,
        instance_id: int,
        sharding_group: SHARDING_PRIORITIES,
    ):
        raise NotImplementedError


@functional_datapipe("sharding_filter")
class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
    r"""
    Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``).

    After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the
    original DataPipe, where `n` equals to the number of instances.

    Args:
        source_datapipe: Iterable DataPipe that will be sharded
    """

    def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None):
        self.source_datapipe = source_datapipe
        self.sharding_group_filter = sharding_group_filter
        # Per-group sharding spec: group id -> (num_of_instances, instance_id).
        self.groups: Dict[int, Tuple[int, int]] = {}
        self.num_of_instances = 1
        self.instance_id = 0
        self._update_num_of_instances()

    def apply_sharding(
        self, num_of_instances, instance_id, sharding_group=SHARDING_PRIORITIES.DEFAULT
    ):
        """Register a sharding spec for ``sharding_group`` and recompute the
        effective (num_of_instances, instance_id) pair.

        Raises:
            ValueError: if ``instance_id`` is out of range.
            RuntimeError: if DEFAULT and non-DEFAULT groups are mixed.
        """
        if instance_id >= num_of_instances:
            raise ValueError(
                f"instance_id({instance_id}) should be smaller than num_of_instances({num_of_instances})"
            )
        # A DEFAULT group may only coexist with other DEFAULT entries, and
        # non-DEFAULT groups may not be combined with an existing DEFAULT one.
        registering_default = sharding_group == SHARDING_PRIORITIES.DEFAULT
        default_present = SHARDING_PRIORITIES.DEFAULT in self.groups
        mixes_groups = (
            (registering_default and self.groups and not default_present)
            or (not registering_default and default_present)
        )
        if mixes_groups:
            raise RuntimeError(
                "ShardingFilter cannot mix DEFAULT and non DEFAULT groups"
            )
        self.groups[sharding_group] = (num_of_instances, instance_id)
        self._update_num_of_instances()

    def _update_num_of_instances(self):
        """Fold all (filtered) group specs into one flat shard assignment.

        Groups are combined from highest priority to lowest: each step treats
        the already-accumulated shards as one unit and subdivides further.
        """
        selected = [
            self.groups[key]
            for key in sorted(self.groups, reverse=True)
            if self.sharding_group_filter is None or key == self.sharding_group_filter
        ]

        total_instances, shard_offset = 1, 0
        for group_instances, group_id in selected:
            shard_offset += total_instances * group_id
            total_instances *= group_instances

        self.num_of_instances = total_instances
        self.instance_id = shard_offset

    def __iter__(self):
        # Keep every n-th element, starting at this shard's offset.
        for idx, element in enumerate(self.source_datapipe):
            if idx % self.num_of_instances == self.instance_id:
                yield element

    def __len__(self):
        if not isinstance(self.source_datapipe, Sized):
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        # Shards with id below the remainder receive one extra element.
        quotient, remainder = divmod(len(self.source_datapipe), self.num_of_instances)
        return quotient + (1 if self.instance_id < remainder else 0)