# mypy: allow-untyped-defs
from typing import Callable, Iterator, Tuple, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import (
    _check_unpickable_fn,
    StreamWrapper,
    validate_input_col,
)


__all__ = ["FilterIterDataPipe"]


_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)


@functional_datapipe("filter")
class FilterIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``).

    Args:
        datapipe: Iterable DataPipe being filtered
        filter_fn: Customized function mapping an element to a boolean.
        input_col: Index or indices of data which ``filter_fn`` is applied, such as:

            - ``None`` as default to apply ``filter_fn`` to the data directly.
            - Integer(s) is used for list/tuple.
            - Key(s) is used for dict.

    Raises:
        ValueError: During iteration, if ``filter_fn`` returns a value that is
            neither a ``bool`` nor something ``dataframe_wrapper.is_column``
            recognizes as a column.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> def is_even(n):
        ...     return n % 2 == 0
        >>> dp = IterableWrapper(range(5))
        >>> filter_dp = dp.filter(filter_fn=is_even)
        >>> list(filter_dp)
        [0, 2, 4]
    """

    datapipe: IterDataPipe[_T_co]
    filter_fn: Callable

    def __init__(
        self,
        datapipe: IterDataPipe[_T_co],
        filter_fn: Callable,
        input_col=None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe

        _check_unpickable_fn(filter_fn)
        self.filter_fn = filter_fn  # type: ignore[assignment]

        self.input_col = input_col
        validate_input_col(filter_fn, input_col)

    def _apply_filter_fn(self, data) -> bool:
        """Call ``filter_fn`` on ``data`` or on the part(s) selected by ``input_col``."""
        if self.input_col is None:
            return self.filter_fn(data)
        elif isinstance(self.input_col, (list, tuple)):
            # Several columns: each selected value becomes its own positional argument.
            args = tuple(data[col] for col in self.input_col)
            return self.filter_fn(*args)
        else:
            return self.filter_fn(data[self.input_col])

    def __iter__(self) -> Iterator[_T_co]:
        """Yield the elements (or filtered DataFrame pieces) accepted by ``filter_fn``."""
        for data in self.datapipe:
            condition, filtered = self._returnIfTrue(data)
            if condition:
                yield filtered
            else:
                # The rejected element is not passed downstream; hand it to
                # StreamWrapper.close_streams instead.
                StreamWrapper.close_streams(data)

    def _returnIfTrue(self, data: _T) -> Tuple[bool, _T]:
        """
        Evaluate the filter for ``data`` and return ``(keep, value)``.

        For a plain boolean result, ``value`` is ``data`` itself. When the
        result is a DataFrame column (a per-row mask), ``value`` is the
        concatenation of the rows whose mask entry is truthy, or ``None``
        with ``keep=False`` when no row is selected.

        Raises:
            ValueError: If the result is neither a column nor a ``bool``.
        """
        condition = self._apply_filter_fn(data)

        if df_wrapper.is_column(condition):
            # We are operating on DataFrames filter here
            selected = [
                df_wrapper.get_item(data, idx)
                for idx, mask in enumerate(df_wrapper.iterate(condition))
                if mask
            ]
            if selected:
                return True, df_wrapper.concat(selected)
            return False, None  # type: ignore[return-value]

        if not isinstance(condition, bool):
            # Build one message string; passing the type as a second positional
            # argument would make the exception render as a tuple repr.
            raise ValueError(
                "Boolean output is required for `filter_fn` of FilterIterDataPipe, "
                f"got {type(condition)}"
            )

        return condition, data