xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/partitioner_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from enum import Enum
3from typing import NamedTuple, Dict, List, Set
4
5from torch.fx.node import Node, map_arg
6
7
class Partition:
    """Partition class contains all the information about an individual partition.

    It also provides the methods needed for manipulating the partition: adding
    and removing nodes while keeping ``used_mem_bytes`` consistent with the
    current node set.
    """

    def __init__(self, partition_id: int) -> None:
        # FX nodes currently assigned to this partition.
        self.nodes: Set[Node] = set()
        self.partition_id = partition_id
        # Neighbouring partitions; left empty here and filled in by callers.
        self.parents: Set[Partition] = set()
        self.children: Set[Partition] = set()
        # Initialized to -1; presumably means "level not assigned yet" — confirm with callers.
        self.bfs_level: int = -1
        # Kept in sync by recalculate_mem_size() after every add/remove.
        self.used_mem_bytes: int = 0
        self.logical_device_ids: List[int] = []

    def __str__(self):
        return str(self.partition_id)

    @staticmethod
    def _input_nodes_of(node: Node) -> Dict[Node, None]:
        """Return the unique input nodes of ``node`` (args, then kwargs) in first-seen order."""
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, input_nodes.setdefault)
        map_arg(node.kwargs, input_nodes.setdefault)
        return input_nodes

    def recalculate_mem_size(self):
        """Recompute ``used_mem_bytes`` from scratch over every node in the partition."""
        self.used_mem_bytes = 0
        for node in self.nodes:
            self.used_mem_bytes += get_extra_size_of(node, self.nodes)

    def add_node(self, node):
        """Add ``node`` to the partition and refresh ``used_mem_bytes``.

        Placeholder and get_attr inputs of ``node`` are pulled in as well, so
        the partition owns the graph inputs/constants it needs.
        """
        for n in self._input_nodes_of(node):
            if n.op in {"placeholder", "get_attr"}:
                self.nodes.add(n)
        self.nodes.add(node)
        self.recalculate_mem_size()

    def remove_node(self, node):
        """Remove ``node`` from the partition if present and refresh ``used_mem_bytes``.

        Placeholder/get_attr inputs of ``node`` that no longer have any user
        inside the partition are dropped too. Removing a node that is not in
        the partition is a no-op.
        """
        if node not in self.nodes:
            return
        self.nodes.remove(node)
        for input_node in self._input_nodes_of(node):
            if input_node.op in {"placeholder", "get_attr"} and all(
                n not in self.nodes for n in input_node.users
            ):
                # discard(): the input may never have been added to this partition
                # (e.g. nodes inserted directly via ``partition.nodes``); remove()
                # would raise KeyError in that case.
                self.nodes.discard(input_node)
        self.recalculate_mem_size()
58
59
class Device(NamedTuple):
    """A device the partitioner may place partitions on."""

    # Device name.
    name: str
    # Memory available on the device, in bytes.
    available_mem_bytes: int
    # Logical identifier of the device.
    logical_id: int
64
65
class NodeLatency(NamedTuple):
    """Latency estimate of a single node, split by its source."""

    # Seconds attributable to memory bandwidth.
    mem_latency_sec: float
    # Seconds attributable to computation (field name kept as-is for compatibility).
    computer_latency_sec: float
71
72
class PartitionLatency(NamedTuple):
    """Latency figures for the critical path of one partition."""

    # Memory latency summed over all nodes on the critical path.
    mem_latency_sec: float
    # Compute latency summed over all nodes on the critical path.
    computer_latency_sec: float
    # Total latency of the critical path.
    overall_latency_sec: float
80
81
class PartitionMode(Enum):
    """Partitioning strategy selected through ``PartitionerConfig.mode``.

    The integer values are explicit and must stay stable.
    """

    size_based = 0  # the PartitionerConfig default
    sparse_nn = 1
    cost_aware = 2
    kl_based = 3
    aot_based = 4
88
89
class PartitionerConfig(NamedTuple):
    """Settings that drive a partitioning run.

    Only ``devices`` is required; every other field has a default.
    """

    # Devices the partitions may be placed on.
    devices: List[Device]
    # Partitioning strategy.
    mode: PartitionMode = PartitionMode.size_based
    # Transfer rate used when estimating communication latency between partitions.
    transfer_rate_bytes_per_sec: float = 0.0
    # NOTE(review): the three dict defaults below are single objects shared by
    # every config built without an explicit value (NamedTuple has no
    # default_factory). Treat them as read-only, or pass a fresh dict.
    node_to_latency_mapping: Dict[Node, NodeLatency] = {}
    node_to_partition_mapping: Dict[Node, int] = {}
    partition_to_logical_device_mapping: Dict[int, List[int]] = {}
    # Saturate host by replicating partitions to the remaining idle devices.
    saturate_host: bool = False
99
100
def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
    """Return how many extra bytes including ``node`` in ``nodes`` would cost.

    The cost is the ``output_size`` of every distinct input of ``node`` that is
    not already in ``nodes``, plus the ``total_size`` of ``node`` itself. Sizes
    are read from each node's ``size_bytes`` attribute.

    Raises:
        RuntimeError: if ``node``, or an input of it outside ``nodes``, has no
            (or a falsy) ``size_bytes`` attribute.
    """
    # Unique inputs of the node, in first-seen order (dict keys deduplicate).
    unique_inputs: Dict[Node, None] = {}
    map_arg(node.args, unique_inputs.setdefault)
    map_arg(node.kwargs, unique_inputs.setdefault)

    extra_bytes = 0
    for inp in unique_inputs:
        if inp in nodes:
            # Already part of the set: including it again costs nothing.
            continue
        inp_size = getattr(inp, "size_bytes", None)
        if not inp_size:
            raise RuntimeError("node has no size_bytes attr")
        extra_bytes += inp_size.output_size

    # The node itself is always counted in full.
    own_size = getattr(node, "size_bytes", None)
    if not own_size:
        raise RuntimeError("node has no size_bytes attr")
    return extra_bytes + own_size.total_size
127
128
def get_latency_of_one_partition(
    partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
) -> PartitionLatency:
    """Given a partition and its nodes' latency, return a PartitionLatency for this partition.

    The result describes the critical path of the partition: the path through
    its compute nodes (every node whose op is not ``placeholder`` or
    ``get_attr``) with the largest ``overall_latency_sec``. Along a path, each
    node adds its ``mem_latency_sec`` and ``computer_latency_sec`` to the
    respective sums, and ``max(computer_latency_sec, mem_latency_sec)`` to the
    overall latency.

    Paths start at compute nodes that have no compute-node input inside the
    partition. They end at compute nodes that have no user inside the
    partition. An all-zero PartitionLatency is returned when the partition has
    no compute nodes.

    Implementation: longest path in a DAG, computed with one pass over the
    compute nodes in topological order (Kahn's algorithm). This is linear in
    nodes + edges and non-recursive. Enumerating every path instead is
    exponential on graphs with many diamonds and can exceed Python's recursion
    limit on long chains. Sums are accumulated top-down, as a path walk would
    do, so the returned overall latency is the same float value.

    Raises:
        KeyError: if a compute node of the partition is missing from
            ``node_to_latency_mapping``.
    """
    skipped_ops = {"placeholder", "get_attr"}
    zero_latency = PartitionLatency(
        mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
    )

    # In-partition compute predecessors / successors of every compute node.
    preds: Dict[Node, List[Node]] = {}
    succs: Dict[Node, List[Node]] = {}
    for node in partition.nodes:
        if node.op in skipped_ops:
            continue
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, input_nodes.setdefault)
        map_arg(node.kwargs, input_nodes.setdefault)
        preds[node] = [
            n
            for n in input_nodes
            if n in partition.nodes and n.op not in skipped_ops
        ]
        succs.setdefault(node, [])
        for n in preds[node]:
            succs.setdefault(n, []).append(node)

    # Nodes with no pending predecessor are the top nodes of the partition.
    pending = {node: len(node_preds) for node, node_preds in preds.items()}
    ready = [node for node, count in pending.items() if count == 0]
    # best_path[node]: latency of the slowest path ending at ``node``.
    best_path: Dict[Node, PartitionLatency] = {}
    critical_path_latency = zero_latency
    while ready:
        node = ready.pop()
        # All predecessors are final here; extend the slowest one
        # (top nodes extend an all-zero path).
        prefix = max(
            (best_path[p] for p in preds[node]),
            key=lambda lat: lat.overall_latency_sec,
            default=zero_latency,
        )
        node_latency = node_to_latency_mapping[node]
        path_latency = PartitionLatency(
            mem_latency_sec=prefix.mem_latency_sec + node_latency.mem_latency_sec,
            computer_latency_sec=prefix.computer_latency_sec
            + node_latency.computer_latency_sec,
            overall_latency_sec=prefix.overall_latency_sec
            + max(node_latency.computer_latency_sec, node_latency.mem_latency_sec),
        )
        best_path[node] = path_latency
        if not succs[node]:
            # Bottom node: a complete path ends here.
            if (
                path_latency.overall_latency_sec
                > critical_path_latency.overall_latency_sec
            ):
                critical_path_latency = path_latency
        for user in succs[node]:
            pending[user] -= 1
            if pending[user] == 0:
                ready.append(user)
    return critical_path_latency
216
217
def get_partition_to_latency_mapping(
    partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
) -> Dict[Partition, PartitionLatency]:
    """Map every partition to the latency of its own critical path.

    Each entry is computed independently by ``get_latency_of_one_partition``
    using the shared ``node_to_latency_mapping``.
    """
    return {
        part: get_latency_of_one_partition(part, node_to_latency_mapping)
        for part in partitions
    }
232
233
def get_comm_latency_between(
    parent_partition: Partition,
    child_partition: Partition,
    transfer_rate_bytes_per_sec: float,
):
    """Given two partitions (parent and child),
    calculate the communication latency between the two.

    The latency is the number of bytes the child consumes from the parent,
    divided by ``transfer_rate_bytes_per_sec``. The byte count is the summed
    ``size_bytes.output_size`` of each distinct parent node used as an input
    by a child node. Parent nodes without a ``size_bytes`` attribute count as
    0 bytes.

    Returns 0.0 when both partitions are mapped to the same non-empty list of
    logical devices, or when no bytes cross the boundary.

    Raises:
        ZeroDivisionError: if bytes do cross the boundary and
            ``transfer_rate_bytes_per_sec`` is 0.
    """
    # If two partitions are on the same device, the comm latency is 0.
    if (
        parent_partition.logical_device_ids != []
        and child_partition.logical_device_ids != []
        and parent_partition.logical_device_ids == child_partition.logical_device_ids
    ):
        return 0.0
    comm_size = 0
    # Each parent output is transferred once, however many child nodes use it.
    visited_nodes: Set[Node] = set()
    for node in child_partition.nodes:
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, input_nodes.setdefault)
        map_arg(node.kwargs, input_nodes.setdefault)
        for n in input_nodes:
            if n in parent_partition.nodes and n not in visited_nodes:
                size_bytes = getattr(n, "size_bytes", None)
                if size_bytes is not None:
                    comm_size += size_bytes.output_size
                visited_nodes.add(n)
    # Nothing is transferred: no latency. This also avoids evaluating 0 / 0.0,
    # which raises ZeroDivisionError when the transfer rate is left at 0.0.
    if comm_size == 0:
        return 0.0
    return comm_size / transfer_rate_bytes_per_sec
268
269
def get_latency_of_partitioned_graph(
    partitions: List[Partition],
    partition_to_latency_mapping: Dict[Partition, PartitionLatency],
    transfer_rate_bytes_per_sec: float,
):
    """Given all partitions in a graph, find the critical path among all partitions
    and return its latency as the latency of the whole graph.

    A path starts at a partition with no parents and follows ``children``
    links. Its latency is the sum of two parts:

    * the ``overall_latency_sec`` of every partition on the path;
    * the communication latency (``get_comm_latency_between``) of every
      parent -> child hop.

    Returns 0.0 when ``partitions`` is empty or has no parentless partition.
    """

    def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
        """Return the largest total latency of any path that continues from
        ``partition``, given ``latency_so_far_sec`` accumulated before it."""
        latency_so_far_sec += partition_to_latency_mapping[
            partition
        ].overall_latency_sec
        if not partition.children:
            # Leaf partition: the path ends here.
            return latency_so_far_sec
        max_latency_sec = 0.0
        for child in partition.children:
            # Pay the transfer cost of moving data from this partition to the child.
            comm_latency_sec = get_comm_latency_between(
                partition, child, transfer_rate_bytes_per_sec
            )
            new_latency_sec = dfs_helper(child, latency_so_far_sec + comm_latency_sec)
            if new_latency_sec > max_latency_sec:
                max_latency_sec = new_latency_sec
        return max_latency_sec

    critical_path_latency_sec = 0.0
    for partition in partitions:
        # Only partitions without parents are starting points of a path.
        if partition.parents:
            continue
        latency_sec = dfs_helper(partition, 0.0)
        if latency_sec > critical_path_latency_sec:
            critical_path_latency_sec = latency_sec
    return critical_path_latency_sec
319