1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import argparse 8import math 9from typing import Any, Dict, List, Set, Tuple 10 11from tools.flight_recorder.components.types import ( 12 Group, 13 MatchState, 14 Membership, 15 Op, 16 P2P, 17) 18 19 20try: 21 from tabulate import tabulate 22except ModuleNotFoundError: 23 print("tabulate is not installed. Proceeding without it.") 24 25 26def format_frame(frame: Dict[str, str]) -> str: 27 name = frame["name"] 28 filename = frame["filename"] 29 line = frame["line"] 30 return f"{name} at {filename}:{line}" 31 32 33def format_frames(frames: List[Dict[str, str]]) -> str: 34 formatted_frames = [] 35 for frame in frames: 36 formatted_frames.append(format_frame(frame)) 37 return "\n".join(formatted_frames) 38 39 40def match_one_event( 41 event_a: Dict[Any, Any], 42 event_b: Dict[Any, Any], 43 memberships: Dict[str, Set[Any]], 44 pg_name: str, 45) -> MatchState: 46 op_a = Op(event_a, memberships, pg_name) 47 op_b = Op(event_b, memberships, pg_name) 48 return op_a.match(op_b) 49 50 51def match_coalesced_groups( 52 all_rank_events: Dict[Any, Any], 53 group_size: int, 54 groups: Dict[str, Group], 55 memberships: Dict[str, Set[Any]], 56 _pg_guids: Dict[Tuple[str, int], str], 57) -> bool: 58 """ 59 all_rank_events: { 60 rank: [ 61 (idx, event_dict) 62 ] 63 } 64 65 Note: it is possible for event dicts in a coalesced group to be asymmetric. 66 e.g. the following events lists form a valid coalescing group 67 events0 [send:1] 68 events1 [recv:0, send:2] 69 events2 [recv:1] 70 71 Rule 1: all ops should find a match 72 Rule 2: relative ordering of sends and recvs in one event list can be arbitrary 73 e.g. 74 events1 [recv:0, send:2] —> okay 75 events1 [send:2, recv:0] —> also okay 76 Rule 3: sends to the same dest or recvs from the src should be in a consistent order 77 e.g. 78 rank0 [send:1 (100B), send:1 (1000B)] 79 rank1 [recv:0 (1000B), recv:0 (100B)] —> not okay 80 """ 81 all_ops = { 82 rank: [ 83 Op(e, memberships, _pg_guids[(e["process_group"][0], rank)]) 84 for i, e in all_rank_events[rank] 85 ] 86 for rank in all_rank_events 87 } 88 89 def visualize_ops( 90 match: bool, 91 _pg_guids: Dict[Tuple[str, int], str], 92 ) -> None: 93 all_ops = { 94 rank: [ 95 Op(e, memberships, _pg_guids[(e["process_group"][0], rank)]) 96 for i, e in all_rank_events[rank] 97 ] 98 for rank in all_rank_events 99 } 100 101 i = 0 102 row = [] 103 progress = True 104 table = [] 105 while progress: 106 progress = False 107 for r in all_ops: 108 if len(all_ops[r]) > i: 109 rank, event = all_rank_events[r][i] 110 row.append( 111 Op( 112 event, 113 memberships, 114 _pg_guids[(event["process_group"][0], rank)], 115 ) 116 ) 117 progress = True 118 else: 119 row.append(None) # type: ignore[arg-type] 120 table.append(row) 121 row = [] 122 i += 1 123 title = "Match" if match else "MISMATCH" 124 print(f"{title}\n", tabulate(table)) # type: ignore[operator] 125 126 # TODO can't verify seq_id bc there might have been valid seq deltas between ranks even within a pg. 127 for op_list in all_ops.values(): 128 if not op_list: 129 # print("TODO- not sure if its valid for only some ranks in a PG to participate in a coalesced op?") 130 return False 131 assert op_list[-1].type == "coalesced" 132 op_list.pop(-1) 133 134 while all_ops: 135 first_rank = next(iter(all_ops)) 136 my_ops = all_ops[first_rank] 137 138 if len(all_ops[first_rank]) == 0: 139 all_ops.pop(first_rank) 140 continue 141 142 # lets match the first collective! we need to know which ranks are involved, and ensure that this same 143 # collective is also the first one on those ranks within that group 144 op = my_ops[0] 145 match_idx = -1 146 if op.type in P2P: 147 dst_global_rank = sorted(memberships[op.pg_name])[op.dst] 148 peer_ops = all_ops[dst_global_rank] 149 for i, other in enumerate(peer_ops): 150 if op.match(other) == MatchState.FULLY_MATCHED: 151 match_idx = i 152 break 153 elif op.dst == other.src: 154 # Rule 3 155 break 156 else: 157 # Rule 1 158 continue 159 else: 160 raise NotImplementedError("coalesced collective ops") 161 if match_idx >= 0: 162 my_ops.pop(0) 163 peer_ops.pop(match_idx) 164 else: 165 visualize_ops(False, _pg_guids) 166 return False 167 168 visualize_ops(True, _pg_guids) 169 return True 170 171 172def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int, int]: 173 input_numel = 0 174 output_numel = 0 175 for e in alltoall_cases: 176 input_numel += math.prod(e["input_sizes"][0]) 177 output_numel += math.prod(e["output_sizes"][0]) 178 return input_numel == output_numel, input_numel, output_numel 179 180 181def find_coalesced_group( 182 pg_name: str, 183 entries: List[Dict[str, Any]], 184 _pg_guids: Dict[Tuple[str, int], str], 185 rank: int, 186) -> List[Tuple[int, Dict[str, Any]]]: 187 """Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones, 188 build an return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id 189 """ 190 found = [] 191 collective_seq_id = None 192 for i, e in enumerate(entries): 193 if _pg_guids[(e["process_group"][0], rank)] != pg_name: 194 continue 195 elif collective_seq_id is None: 196 collective_seq_id = ( 197 e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"] 198 ) 199 found.append((i, e)) 200 elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id: 201 found.append((i, e)) 202 elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id: 203 found.append((i, e)) 204 else: 205 break 206 207 if len(found) > 1: 208 assert found[-1][1]["profiling_name"] == "nccl:coalesced" 209 return found 210 return [] 211 212 213def just_print_entries( 214 all_entries: Dict[int, List[Dict[str, Any]]], 215 _groups: Dict[str, Group], 216 _memberships: Dict[str, Set[Any]], 217 _pg_guids: Dict[Tuple[str, int], str], 218 args: argparse.Namespace, 219) -> None: 220 rows = [] 221 ranks = sorted(all_entries.keys()) 222 headers = [ 223 f"Rank {rank}" 224 for rank in ranks 225 if args.selected_ranks is None or rank in args.selected_ranks 226 ] 227 progress = True 228 while progress: 229 progress = False 230 row = [] 231 for rank in ranks: 232 if args.selected_ranks is not None and rank not in args.selected_ranks: 233 continue 234 if len(all_entries[rank]) == 0: 235 row.append("") 236 else: 237 entry = all_entries[rank].pop(0) 238 pg_name = _pg_guids[(entry["process_group"][0], rank)] 239 if ( 240 args.pg_filters is None 241 or entry["process_group"][1] in args.pg_filters 242 ): 243 row.append(str(Op(entry, _memberships, pg_name))) 244 else: 245 row.append("") 246 progress = True 247 if progress: 248 rows.append(row) 249 250 print(tabulate(rows, headers=headers)) 251 252 253def check_no_missing_dump_files( 254 entries: Dict[int, Any], memberships: List[Membership] 255) -> None: 256 all_ranks = set() 257 for membership in memberships: 258 all_ranks.add(int(membership.global_rank)) 259 dumps_ranks = {int(key) for key in entries.keys()} 260 assert ( 261 dumps_ranks == all_ranks 262 ), f"Missing dump files from ranks {all_ranks - dumps_ranks}" 263 264 265def check_version(version_by_ranks: Dict[str, str], version: str) -> None: 266 for rank, v in version_by_ranks.items(): 267 assert ( 268 v == version 269 ), f"Rank {rank} has different version {v} from the given version {version}" 270 271 272def get_version_detail(version: str) -> Tuple[int, int]: 273 version = version.split(".") 274 assert len(version) == 2, f"Invalid version {version}" 275 major, minor = map(int, version) 276 return major, minor 277 278 279def align_trace_from_beginning( 280 entries: Dict[int, List[Dict[str, Any]]] 281) -> Dict[int, List[Dict[str, Any]]]: 282 """ 283 Align the trace entries by record ID for entries. 284 This function takes a dictionary of rank names to lists of trace entries as input. 285 Each trace entry is a dictionary containing information about a collective operation, 286 including its unique identifier (`record_id` is monotonically increasing as we write into the ring buffer). 287 The function finds the largest starting point across all ranks by taking the maximum 288 `record_id` value of the first entry in each rank. Finally, it filters out any 289 entries with `record_id` values less than the maximum starting point. 290 The function returns the updated dictionary of sorted and filtered trace entries. 291 292 Args: 293 entries (Dict[str, List[Dict[str, Any]]]): A dictionary of rank names to lists of trace entries. 294 295 Returns: 296 entries (Dict[str, List[Dict[str, Any]]]): Entries sorted by record ID and filtered by the maximum starting point. 297 """ 298 299 maximum_starting_record_id = 0 300 for rank in entries: 301 # Although this is a ring buffer, we already sort the entries by `record_id` when dumping, we just 302 # need to find the largest starting point. For example, if the buffer has the following entries: 303 # Rank 0: [0, 1, 2, 3, 4, 5, 6] 304 # Rank 1: [1, 2, 3, 4, 5, 6, 7] 305 # Rank 2: [2, 3, 4, 5, 6, 7, 8] 306 # Rank 3: [0, 1, 2, 3, 4, 5, None] 307 # Then we should start from collective 2 not 0 because any collective before, 308 # we don't have complete records from all ranks so we need to ignore them. 309 first_record_id = entries[rank][0]["record_id"] 310 maximum_starting_record_id = max(maximum_starting_record_id, first_record_id) 311 312 for rank in entries: 313 entries[rank] = [ 314 entry 315 for entry in entries[rank] 316 if entry["record_id"] >= maximum_starting_record_id 317 ] 318 319 return entries 320