# mypy: allow-untyped-defs
import inspect
import warnings
from typing import Any, List, Optional, Set
from typing_extensions import deprecated

import torch
from torch.utils.data.datapipes.iter.sharding import (
    _ShardingIterDataPipe,
    SHARDING_PRIORITIES,
)
from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps


__all__ = [
    "apply_random_seed",
    "apply_sharding",
    "apply_shuffle_seed",
    "apply_shuffle_settings",
    "get_all_graph_pipes",
]


def get_all_graph_pipes(graph: DataPipeGraph) -> List[DataPipe]:
    r"""Flatten a ``DataPipeGraph`` into a list of unique ``DataPipes``."""
    return _get_all_graph_pipes_helper(graph, set())


def _get_all_graph_pipes_helper(
    graph: DataPipeGraph, id_cache: Set[int]
) -> List[DataPipe]:
    results: List[DataPipe] = []
    for dp_id, (datapipe, sub_graph) in graph.items():
        if dp_id in id_cache:
            continue
        id_cache.add(dp_id)
        results.append(datapipe)
        results.extend(_get_all_graph_pipes_helper(sub_graph, id_cache))
    return results
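
# An illustrative sketch (not part of the module): ``traverse_dps`` returns a
# nested dict of ``id -> (datapipe, sub_graph)``, and ``get_all_graph_pipes``
# flattens it outermost-first. ``IterableWrapper`` is the standard in-memory
# IterDataPipe.
#
#     >>> from torch.utils.data.datapipes.iter import IterableWrapper
#     >>> dp = IterableWrapper(range(4)).map(str).shuffle()
#     >>> graph = traverse_dps(dp)
#     >>> [type(p).__name__ for p in get_all_graph_pipes(graph)]
#     ['ShufflerIterDataPipe', 'MapperIterDataPipe', 'IterableWrapperIterDataPipe']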


def _is_sharding_datapipe(datapipe: DataPipe) -> bool:
    return isinstance(datapipe, _ShardingIterDataPipe) or (
        hasattr(datapipe, "apply_sharding")
        and inspect.ismethod(datapipe.apply_sharding)
    )
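
# An illustrative sketch (hypothetical pipe, not shipped with torch): any
# DataPipe exposing a bound ``apply_sharding`` method is treated as a sharding
# point, not only ``_ShardingIterDataPipe`` subclasses.
#
#     >>> from torch.utils.data import IterDataPipe
#     >>> class _EveryNthPipe(IterDataPipe):
#     ...     def __init__(self):
#     ...         self.num_of_instances, self.instance_id = 1, 0
#     ...     def apply_sharding(self, num_of_instances, instance_id):
#     ...         self.num_of_instances = num_of_instances
#     ...         self.instance_id = instance_id
#     ...     def __iter__(self):
#     ...         yield from range(self.instance_id, 10, self.num_of_instances)
#     >>> _is_sharding_datapipe(_EveryNthPipe())
#     True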


def apply_sharding(
    datapipe: DataPipe,
    num_of_instances: int,
    instance_id: int,
    sharding_group=SHARDING_PRIORITIES.DEFAULT,
) -> DataPipe:
    r"""
    Apply dynamic sharding over the ``sharding_filter`` ``DataPipe``, that is, any ``DataPipe`` with an ``apply_sharding`` method.

    A ``RuntimeError`` is raised when multiple ``sharding_filter`` instances are present in the same branch.
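
    Example (an illustrative sketch; ``IterableWrapper`` and the
    ``sharding_filter`` functional API are the standard built-ins):

        >>> # xdoctest: +SKIP
        >>> from torch.utils.data.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10)).sharding_filter()
        >>> dp = apply_sharding(dp, num_of_instances=2, instance_id=0)
        >>> list(dp)
        [0, 2, 4, 6, 8]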
58    """
59    graph = traverse_dps(datapipe)
60
61    def _helper(graph, prev_applied=None):
62        for dp, sub_graph in graph.values():
63            applied = None
64            if _is_sharding_datapipe(dp):
65                if prev_applied is not None:
66                    raise RuntimeError(
67                        "Sharding twice on a single pipeline is likely unintended and will cause data loss. "
68                        f"Sharding already applied to {prev_applied} while trying to apply to {dp}"
69                    )
70                # For BC, only provide sharding_group if accepted
71                sig = inspect.signature(dp.apply_sharding)
72                if len(sig.parameters) < 3:
73                    dp.apply_sharding(num_of_instances, instance_id)
74                else:
75                    dp.apply_sharding(
76                        num_of_instances, instance_id, sharding_group=sharding_group
77                    )
78                applied = dp
79            if applied is None:
80                applied = prev_applied
81            _helper(sub_graph, applied)
82
83    _helper(graph)
84
85    return datapipe


def _is_shuffle_datapipe(datapipe: DataPipe) -> bool:
    return (
        hasattr(datapipe, "set_shuffle")
        and hasattr(datapipe, "set_seed")
        and inspect.ismethod(datapipe.set_shuffle)
        and inspect.ismethod(datapipe.set_seed)
    )


def apply_shuffle_settings(
    datapipe: DataPipe, shuffle: Optional[bool] = None
) -> DataPipe:
    r"""
    Traverse the graph of ``DataPipes`` to find and set the shuffle attribute.

    The setting is applied to every ``DataPipe`` that exposes both the
    ``set_shuffle`` and ``set_seed`` APIs.

    Args:
        datapipe: DataPipe on which to set the shuffle attribute
        shuffle: Shuffle option (default: ``None``, which is a no-op on the graph)
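
    Example (an illustrative sketch; disabling an existing ``Shuffler``):

        >>> # xdoctest: +SKIP
        >>> from torch.utils.data.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10)).shuffle()
        >>> dp = apply_shuffle_settings(dp, shuffle=False)
        >>> list(dp)  # the Shuffler now passes elements through unshuffled
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]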
109    """
110    if shuffle is None:
111        return datapipe
112
113    graph = traverse_dps(datapipe)
114    all_pipes = get_all_graph_pipes(graph)
115    shufflers = [pipe for pipe in all_pipes if _is_shuffle_datapipe(pipe)]
116    if not shufflers and shuffle:
117        warnings.warn(
118            "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
119            "Be aware that the default buffer size might not be sufficient for your task."
120        )
121        datapipe = datapipe.shuffle()
122        shufflers = [
123            datapipe,
124        ]  # type: ignore[list-item]
125
126    for shuffler in shufflers:
127        shuffler.set_shuffle(shuffle)
128
129    return datapipe


@deprecated(
    "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in future releases. "
    "Please use `apply_random_seed` instead.",
    category=FutureWarning,
)
def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
    return apply_random_seed(datapipe, rng)


def _is_random_datapipe(datapipe: DataPipe) -> bool:
    return hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed)


def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe:
    r"""
    Traverse the graph of ``DataPipes`` to find every random ``DataPipe`` with a ``set_seed`` API.

    Then set a random seed drawn from the provided RNG on each of those ``DataPipes``.

    Args:
        datapipe: DataPipe that needs its randomness seeded
        rng: Random number generator used to generate the random seeds
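
    Example (an illustrative sketch; a ``Shuffler`` is one DataPipe with a
    ``set_seed`` API):

        >>> # xdoctest: +SKIP
        >>> from torch.utils.data.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10)).shuffle()
        >>> rng = torch.Generator().manual_seed(0)
        >>> dp = apply_random_seed(dp, rng)  # the shuffle order is now reproducible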
154    """
155    graph = traverse_dps(datapipe)
156    all_pipes = get_all_graph_pipes(graph)
157    # Using a set to track id of DataPipe to prevent setting randomness per DataPipe more than once.
158    # And, `id` is used in case of unhashable DataPipe
159    cache = set()
160    random_datapipes = []
161    for pipe in all_pipes:
162        if id(pipe) in cache:
163            continue
164        if _is_random_datapipe(pipe):
165            random_datapipes.append(pipe)
166            cache.add(id(pipe))
167
168    for pipe in random_datapipes:
169        random_seed = int(
170            torch.empty((), dtype=torch.int64).random_(generator=rng).item()
171        )
172        pipe.set_seed(random_seed)
173
174    return datapipe