xref: /aosp_15_r20/external/pytorch/tools/flight_recorder/components/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import math
9from typing import Any, Dict, List, Set, Tuple
10
11from tools.flight_recorder.components.types import (
12    Group,
13    MatchState,
14    Membership,
15    Op,
16    P2P,
17)
18
19
# `tabulate` is an optional dependency used only for pretty-printing tables.
# If it is missing we continue, but later calls to `tabulate(...)` in this
# module will raise NameError.
try:
    from tabulate import tabulate
except ModuleNotFoundError:
    print("tabulate is not installed. Proceeding without it.")
24
25
def format_frame(frame: Dict[str, str]) -> str:
    """Render a single stack-frame dict as ``name at filename:line``.

    Args:
        frame: Mapping with at least ``name``, ``filename`` and ``line`` keys.

    Returns:
        A one-line human-readable description of the frame.
    """
    name = frame["name"]
    filename = frame["filename"]
    line = frame["line"]
    # Bug fix: `filename` was looked up but never used; the output previously
    # rendered a literal "(unknown)" placeholder instead of the file name.
    return f"{name} at {filename}:{line}"
31
32
def format_frames(frames: List[Dict[str, str]]) -> str:
    """Format every frame in *frames* via :func:`format_frame`, one per line."""
    return "\n".join(format_frame(frame) for frame in frames)
38
39
def match_one_event(
    event_a: Dict[Any, Any],
    event_b: Dict[Any, Any],
    memberships: Dict[str, Set[Any]],
    pg_name: str,
) -> MatchState:
    """Wrap two trace events in ``Op`` objects and report how well they match."""
    return Op(event_a, memberships, pg_name).match(
        Op(event_b, memberships, pg_name)
    )
49
50
def match_coalesced_groups(
    all_rank_events: Dict[Any, Any],
    group_size: int,
    groups: Dict[str, Group],
    memberships: Dict[str, Set[Any]],
    _pg_guids: Dict[Tuple[str, int], str],
) -> bool:
    """
    all_rank_events: {
        rank: [
            (idx, event_dict)
        ]
    }

    Note: it is possible for event dicts in a coalesced group to be asymmetric.
        e.g. the following events lists form a valid coalescing group
             events0 [send:1]
             events1 [recv:0, send:2]
             events2 [recv:1]

    Rule 1: all ops should find a match
    Rule 2: relative ordering of sends and recvs in one event list can be arbitrary
        e.g.
        events1 [recv:0, send:2]  —> okay
        events1 [send:2, recv:0] —> also okay
    Rule 3: sends to the same dest or recvs from the src should be in a consistent order
        e.g.
        rank0 [send:1 (100B), send:1 (1000B)]
        rank1 [recv:0 (1000B), recv:0 (100B)]   —> not okay
    """
    # Wrap every recorded event in an Op, keyed by rank.  The pg name is
    # resolved per (pg id, rank) through _pg_guids.  This dict is consumed
    # destructively by the matching loop below.
    all_ops = {
        rank: [
            Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
            for i, e in all_rank_events[rank]
        ]
        for rank in all_rank_events
    }

    def visualize_ops(
        match: bool,
        _pg_guids: Dict[Tuple[str, int], str],
    ) -> None:
        # Debug printer: rebuilds the ops from all_rank_events (the outer
        # `all_ops` has already been partially consumed) and prints them as a
        # table with one column per rank.  Note this local `all_ops` shadows
        # the outer one on purpose.
        all_ops = {
            rank: [
                Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
                for i, e in all_rank_events[rank]
            ]
            for rank in all_rank_events
        }

        # Walk row `i` across all ranks until no rank has an i-th entry left;
        # ranks with fewer entries get None placeholders.
        i = 0
        row = []
        progress = True
        table = []
        while progress:
            progress = False
            for r in all_ops:
                if len(all_ops[r]) > i:
                    # NOTE(review): the tuples are (idx, event_dict) per the
                    # docstring above, so `rank` here is bound to the entry
                    # *index*, yet it is used as the rank key into _pg_guids
                    # below — looks suspicious; confirm against callers.
                    rank, event = all_rank_events[r][i]
                    row.append(
                        Op(
                            event,
                            memberships,
                            _pg_guids[(event["process_group"][0], rank)],
                        )
                    )
                    progress = True
                else:
                    row.append(None)  # type: ignore[arg-type]
            table.append(row)
            row = []
            i += 1
        title = "Match" if match else "MISMATCH"
        print(f"{title}\n", tabulate(table))  # type: ignore[operator]

    # TODO can't verify seq_id bc there might have been valid seq deltas between ranks even within a pg.
    # Each rank's list is expected to end with a 'coalesced' marker op; strip
    # it before matching.  An empty list means not every rank participated.
    for op_list in all_ops.values():
        if not op_list:
            # print("TODO- not sure if its valid for only some ranks in a PG to participate in a coalesced op?")
            return False
        assert op_list[-1].type == "coalesced"
        op_list.pop(-1)

    # Greedy matching: repeatedly take the first unmatched op of some rank and
    # look for its counterpart on the peer rank; remove both when matched.
    # Ranks are dropped from all_ops once their list is exhausted, so the loop
    # terminates when every op has been paired (success) or a mismatch occurs.
    while all_ops:
        first_rank = next(iter(all_ops))
        my_ops = all_ops[first_rank]

        if len(all_ops[first_rank]) == 0:
            all_ops.pop(first_rank)
            continue

        # lets match the first collective! we need to know which ranks are involved, and ensure that this same
        # collective is also the first one on those ranks within that group
        op = my_ops[0]
        match_idx = -1
        if op.type in P2P:
            # Resolve op.dst (a group-local rank) to a global rank via the
            # sorted membership of the op's process group.
            dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
            peer_ops = all_ops[dst_global_rank]
            for i, other in enumerate(peer_ops):
                if op.match(other) == MatchState.FULLY_MATCHED:
                    match_idx = i
                    break
                elif op.dst == other.src:
                    # Rule 3: a same-src/dst op that didn't fully match means
                    # the per-peer ordering is inconsistent — stop searching.
                    break
                else:
                    # Rule 1: unrelated op, keep scanning for a match.
                    continue
        else:
            # Only P2P coalescing is handled; coalesced collectives are TODO.
            raise NotImplementedError("coalesced collective ops")
        if match_idx >= 0:
            my_ops.pop(0)
            peer_ops.pop(match_idx)
        else:
            visualize_ops(False, _pg_guids)
            return False

    visualize_ops(True, _pg_guids)
    return True
170
171
def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int, int]:
    """Check that the total input numel equals the total output numel across
    the given all-to-all entries.

    Returns:
        Tuple of (sizes_match, total_input_numel, total_output_numel).
    """
    total_input = sum(math.prod(case["input_sizes"][0]) for case in alltoall_cases)
    total_output = sum(math.prod(case["output_sizes"][0]) for case in alltoall_cases)
    return total_input == total_output, total_input, total_output
179
180
def find_coalesced_group(
    pg_name: str,
    entries: List[Dict[str, Any]],
    _pg_guids: Dict[Tuple[str, int], str],
    rank: int,
) -> List[Tuple[int, Dict[str, Any]]]:
    """Collect the leading run of entries on *pg_name* that share one seq id.

    Scans `entries`, skipping those belonging to other process groups, and
    gathers consecutive entries whose seq id (p2p or collective, depending on
    the entry) equals that of the first matching entry.  A valid coalesced
    group has more than one entry and ends in a 'nccl:coalesced' op; anything
    else yields an empty list.
    """
    found: List[Tuple[int, Dict[str, Any]]] = []
    seq_id = None
    for idx, entry in enumerate(entries):
        if _pg_guids[(entry["process_group"][0], rank)] != pg_name:
            continue
        # P2P entries are sequenced separately from collectives.
        entry_seq = entry["p2p_seq_id"] if entry["is_p2p"] else entry["collective_seq_id"]
        if seq_id is None:
            seq_id = entry_seq
            found.append((idx, entry))
        elif entry_seq == seq_id:
            found.append((idx, entry))
        else:
            break

    if len(found) <= 1:
        return []
    assert found[-1][1]["profiling_name"] == "nccl:coalesced"
    return found
211
212
def just_print_entries(
    all_entries: Dict[int, List[Dict[str, Any]]],
    _groups: Dict[str, Group],
    _memberships: Dict[str, Set[Any]],
    _pg_guids: Dict[Tuple[str, int], str],
    args: argparse.Namespace,
) -> None:
    """Print a side-by-side table of trace entries, one column per shown rank.

    Consumes (pops) entries from ``all_entries``.  Ranks excluded by
    ``args.selected_ranks`` are omitted entirely; entries filtered out by
    ``args.pg_filters`` are still consumed but rendered as blank cells.
    """
    ranks = sorted(all_entries.keys())

    def _shown(rank: int) -> bool:
        # A rank is displayed unless an explicit selection excludes it.
        return args.selected_ranks is None or rank in args.selected_ranks

    headers = [f"Rank {rank}" for rank in ranks if _shown(rank)]

    rows = []
    while True:
        any_popped = False
        row = []
        for rank in ranks:
            if not _shown(rank):
                continue
            rank_entries = all_entries[rank]
            if not rank_entries:
                row.append("")
                continue
            entry = rank_entries.pop(0)
            any_popped = True
            pg_name = _pg_guids[(entry["process_group"][0], rank)]
            if args.pg_filters is None or entry["process_group"][1] in args.pg_filters:
                row.append(str(Op(entry, _memberships, pg_name)))
            else:
                row.append("")
        if not any_popped:
            # Every shown rank is exhausted; the all-blank row is discarded.
            break
        rows.append(row)

    print(tabulate(rows, headers=headers))
251
252
def check_no_missing_dump_files(
    entries: Dict[int, Any], memberships: List[Membership]
) -> None:
    """Assert that a trace dump was collected for every known global rank."""
    expected_ranks = {int(membership.global_rank) for membership in memberships}
    dumped_ranks = {int(rank) for rank in entries}
    missing = expected_ranks - dumped_ranks
    assert dumped_ranks == expected_ranks, f"Missing dump files from ranks {missing}"
263
264
def check_version(version_by_ranks: Dict[str, str], version: str) -> None:
    """Assert every rank recorded the same flight-recorder version string."""
    for rank_name, rank_version in version_by_ranks.items():
        assert (
            rank_version == version
        ), f"Rank {rank_name} has different version {rank_version} from the given version {version}"
270
271
def get_version_detail(version: str) -> Tuple[int, int]:
    """Parse a ``"<major>.<minor>"`` version string into integers.

    Args:
        version: Version string with exactly two dot-separated components.

    Returns:
        Tuple of (major, minor).

    Raises:
        AssertionError: If the string does not have exactly two components.
        ValueError: If either component is not an integer.
    """
    # Use a separate local instead of rebinding the `str`-annotated parameter
    # to a list; the old code also made the assertion message print the split
    # list rather than the original version string.
    parts = version.split(".")
    assert len(parts) == 2, f"Invalid version {version}"
    major, minor = map(int, parts)
    return major, minor
277
278
def align_trace_from_beginning(
    entries: Dict[int, List[Dict[str, Any]]]
) -> Dict[int, List[Dict[str, Any]]]:
    """
    Align the trace entries by record ID for entries.
    This function takes a dictionary of rank names to lists of trace entries as input.
    Each trace entry is a dictionary containing information about a collective operation,
    including its unique identifier (`record_id` is monotonically increasing as we write into the ring buffer).
    The function finds the largest starting point across all ranks by taking the maximum
    `record_id` value of the first entry in each rank. Finally, it filters out any
    entries with `record_id` values less than the maximum starting point.
    The function returns the updated dictionary of sorted and filtered trace entries.

    Args:
        entries (Dict[str, List[Dict[str, Any]]]): A dictionary of rank names to lists of trace entries.

    Returns:
        entries (Dict[str, List[Dict[str, Any]]]): Entries sorted by record ID and filtered by the maximum starting point.
    """

    maximum_starting_record_id = 0
    for rank in entries:
        # Although this is a ring buffer, we already sort the entries by `record_id` when dumping, we just
        # need to find the largest starting point. For example, if the buffer has the following entries:
        # Rank 0: [0, 1, 2, 3, 4, 5, 6]
        # Rank 1: [1, 2, 3, 4, 5, 6, 7]
        # Rank 2: [2, 3, 4, 5, 6, 7, 8]
        # Rank 3: [0, 1, 2, 3, 4, 5, None]
        # Then we should start from collective 2 not 0 because any collective before,
        # we don't have complete records from all ranks so we need to ignore them.
        #
        # Robustness fix: a rank with no entries used to raise IndexError here.
        # Such a rank contributes no lower bound and simply stays empty after
        # the filtering below.
        if not entries[rank]:
            continue
        first_record_id = entries[rank][0]["record_id"]
        maximum_starting_record_id = max(maximum_starting_record_id, first_record_id)

    # Drop every entry that precedes the common starting point, in place.
    for rank in entries:
        entries[rank] = [
            entry
            for entry in entries[rank]
            if entry["record_id"] >= maximum_starting_record_id
        ]

    return entries
320