# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch._dynamo.utils import counters
from torch._inductor.utils import InputType


perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)


OutputType = List[Optional[Union[int, torch.Tensor]]]
ModelType = Callable[[List[InputType]], OutputType]


@dataclasses.dataclass(frozen=True)
class FunctionID:
    "Unique counter of a function wrapped in cudagraphify_impl"
    id: int


@dataclasses.dataclass(frozen=True)
class PlaceholderInfo:
    """
    A serializable version of torch.fx.Node that contains information
    pertinent to placeholder stack traces. We use these in logging and error messages
    related to cudagraphs, and will cache these results.
    """

    name: str
    stack_trace: Optional[str]
    # This field is recursive, but never cyclic (since a node never uses itself)
    users: List[PlaceholderInfo]
    mutating_use_stack_trace: Optional[str]


@dataclasses.dataclass(frozen=True)
class WrappedFunction:
    """
    Represents a function that you want to record for CUDA graph replay,
    with a little more metadata so we can identify if we have an applicable
    CUDA graph in our CUDA graph tree for it.
    """

    model: Callable[..., Any]
    static_input_idxs: Sequence[int]
    id: FunctionID
    constants: Tuple[torch.Tensor, ...]
    placeholders: Sequence[PlaceholderInfo]
    mutated_input_idxs: Sequence[int]


def get_mutating_use_stack_trace_from_node(
    placeholder_node: torch.fx.Node,
) -> Optional[str]:
    # reinplaced uses might have a single, non-copy_ use
    if len(placeholder_node.users) == 1:
        return next(iter(placeholder_node.users)).meta.get("stack_trace", None)

    for use in placeholder_node.users:
        if use.target == torch.ops.aten.copy_.default:
            if stack_trace := use.meta.get("stack_trace", None):
                return stack_trace

    return None


def get_mutating_use_stack_trace(placeholder_info: PlaceholderInfo) -> Optional[str]:
    return placeholder_info.mutating_use_stack_trace


def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo:
    name = placeholder_node.name
    stack_trace = placeholder_node.meta.get("stack_trace", None)
    users = []
    mutating_use_stack_trace = None
    # Only recurse to users once, since we only care about users' stack traces
    if placeholder_node.op == "placeholder":
        users = [to_placeholder_info(i) for i in placeholder_node.users]
        mutating_use_stack_trace = get_mutating_use_stack_trace_from_node(
            placeholder_node
        )

    return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace)


def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]:
    return [
        to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder"
    ]


def format_default_skip_message(reason: str) -> str:
    return f"skipping cudagraphs due to {reason}"


def get_mutation_stack_trace(
    placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int]
) -> str:
    stack_trace: Optional[str] = ""

    for idx in mutation_indices:
        placeholder = placeholders[idx]
        if stack_trace := get_mutating_use_stack_trace(placeholder):
            break

    msg = format_default_skip_message(
        f"mutated inputs ({len(mutation_indices)} instances)"
    )
    if stack_trace:
        return f"{msg}. Found from : \n {stack_trace}"

    return msg


def check_for_mutation(
    func: WrappedFunction,
    inputs: List[InputType],
    is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool],
) -> Optional[str]:
    # doesn't work for non-trees because the warmup run would apply mutation twice
    if torch._inductor.config.triton.cudagraph_trees:
        # checking if mutation is only on parameters/static inputs
        mutation_indices: Sequence[int] = [
            idx
            for idx in func.mutated_input_idxs
            if not (
                idx in func.static_input_idxs
                or is_cuda_graph_recorded_tensor(inputs[idx])  # type: ignore[arg-type]
            )
        ]
    else:
        mutation_indices = func.mutated_input_idxs

    static_inputs_log.debug(
        "check mutation static input indices: %s", func.static_input_idxs
    )
    static_inputs_log.debug("check mutation mutation indices: %s", mutation_indices)

    return (
        get_mutation_stack_trace(func.placeholders, mutation_indices)
        if mutation_indices
        else None
    )


def _get_use_stack_trace(node) -> Optional[str]:
    for use in node.users:
        if stack_trace := use.meta.get("stack_trace", None):
            return stack_trace
    return None


def check_multiple_devices_or_any_cpu_nodes(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
) -> Optional[str]:
    if cpu_node := device_node_mapping.get(torch.device("cpu")):
        msg = f"cpu device ({cpu_node.name})"
        if stack_trace := _get_use_stack_trace(cpu_node):
            return format_default_skip_message(f"{msg}. Found from : \n {stack_trace}")

        return format_default_skip_message(msg)

    if (
        len(device_node_mapping) == 1
        and next(iter(device_node_mapping.keys())).type == "cuda"
    ):
        return None

    keys_repr = (repr(key) for key in device_node_mapping.keys())
    return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}")


def check_lowering_disable_cudagraph(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
):
    return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)


def log_cudagraph_skip_and_bump_counter(msg):
    perf_hint_log.warning(msg)
    counters["inductor"]["cudagraph_skips"] += 1


@dataclasses.dataclass
class BoxedDeviceIndex:
    value: Optional[int]

    def set(self, device_idx: Optional[int]):
        assert device_idx is None or isinstance(device_idx, int)
        self.value = device_idx


def check_for_mutation_ignore_cuda_graph_managed_tensor(
    gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: Sequence[int]
) -> Optional[str]:
    default_msg = format_default_skip_message("mutated inputs")

    # doesn't work for non-trees because the warmup run would apply mutation twice
    if torch._inductor.config.triton.cudagraph_trees:
        unique_idxs = set(static_input_idxs)
        # checking if mutation is only on parameters/static inputs
        mutation_indices = [
            idx for idx in compiled_graph.mutated_input_idxs if idx not in unique_idxs
        ]
        has_mutation = len(mutation_indices) != 0
        if not has_mutation:
            return None
        placeholders = get_placeholder_info(gm.graph)
        return get_mutation_stack_trace(placeholders, mutation_indices)

    else:
        has_mutation = len(compiled_graph.mutated_inputs) != 0
        return None if not has_mutation else default_msg


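# A minimal usage sketch, kept as a comment for illustration only: the toy module
# `M` and the index list below are made up, not part of this file's API.
# `get_placeholder_info` snapshots the placeholder nodes of a traced graph into
# serializable PlaceholderInfo records, and `get_mutation_stack_trace` turns a
# list of mutated-input indices into the human-readable skip message built above.
# With no recorded stack traces, the result is just the default skip message.
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return x.add_(1)  # in-place mutation of an input
#
#     gm = torch.fx.symbolic_trace(M())
#     placeholders = get_placeholder_info(gm.graph)
#     print(get_mutation_stack_trace(placeholders, mutation_indices=[0]))
#     # -> "skipping cudagraphs due to mutated inputs (1 instances)"

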
def get_placeholder_stack_trace(placeholder: PlaceholderInfo) -> Optional[str]:
    """
    Gets the first non-empty stack trace of a placeholder or its users.
    """
    if placeholder.stack_trace:
        return placeholder.stack_trace

    for user in placeholder.users:
        if user.stack_trace:
            return user.stack_trace

    return None


class CheckInvariantStatus(Enum):
    # Check invariant succeeded
    SUCCESS = 1

    # Previously managed data pointers are not stable
    CudagraphManagedIdxMismatch = 2

    # Static tensor input addresses are not stable
    StaticInputIdxMismatch = 3

    # Expected dead indices before graph are live
    ExpectedDeadIndicesBeforeGraphMismatch = 4

    def __str__(self) -> str:
        if self.name == "CudagraphManagedIdxMismatch":
            return "cudagraph managed tensor data pointer changed"
        elif self.name == "StaticInputIdxMismatch":
            return "static input data pointer changed"
        elif self.name == "ExpectedDeadIndicesBeforeGraphMismatch":
            return "expected dead indices before graph are live"
        else:
            return f"{self.name}: {self.value}"


def log_data_ptr_mismatch(
    placeholders: Sequence[PlaceholderInfo],
    inputs: List[InputType],
    recorded_data_ptr: Sequence[Optional[int]],
    target_idxs: Sequence[int],
    mismatch: CheckInvariantStatus,
) -> str:
    """
    Builds an error message describing the mismatch between input data pointers
    and recorded data pointers. Only the indices in target_idxs are checked.
    """
    assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(
        placeholders
    ), "length mismatch between inputs, recorded_data_ptr, and placeholders"

    t_tensors = [inputs[i] for i in target_idxs]
    t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs]
    error_msg = f"{mismatch}.\n"
    for i, (tensor, data_ptr) in enumerate(zip(t_tensors, t_data_ptrs)):
        assert isinstance(tensor, torch.Tensor)
        index = target_idxs[i]
        if tensor.data_ptr() != data_ptr:
            placeholder = placeholders[index]
            error_msg = (
                f"{error_msg}input name: {placeholder.name}. "
                f"data pointer changed from {data_ptr} to {tensor.data_ptr()}. "
                f"input stack trace: {get_placeholder_stack_trace(placeholder)}\n"
            )
    return error_msg


def maybe_warning_due_to_dynamic_shape(
    fn_cache: Dict[Tuple[int, ...], Callable[..., Any]],
    new_int_key: Any,
) -> bool:
    num_cudagraphs = len(fn_cache.keys()) + 1

    def warn_msg():
        return (
            "CUDAGraph supports dynamic shapes by recording a new graph for each "
            "distinct input size. Recording too many CUDAGraphs may lead to "
            f"extra overhead. We have observed {num_cudagraphs} distinct sizes. "
            "Please consider the following options for better performance: "
            "a) padding inputs to a small number of fixed shapes; or b) set "
            "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
            "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
            "to silence this warning."
        )

    if (
        torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit
        and num_cudagraphs
        > torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit
    ):
        perf_hint_log.warning(warn_msg())
        return True

    return False


@dataclasses.dataclass(frozen=True)
class CudagraphCachedInfo:
    """
    Info needed to realign inputs
    """

    placeholders: Sequence[PlaceholderInfo]
    stack_traces: List[Optional[str]]
    cudagraph_fail_reasons: List[str]
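

# An illustrative, self-contained sketch of how the device check and the skip
# counter above fit together. The example graph and the device keys are made up;
# the block is guarded so it only runs when this file is executed directly and
# has no effect on importing the module.
if __name__ == "__main__":
    _graph = torch.fx.Graph()
    _x = _graph.placeholder("x")  # a placeholder Node with no users or metadata

    # Any CPU node disables cudagraphs and produces a skip message.
    _cpu_reason = check_multiple_devices_or_any_cpu_nodes({torch.device("cpu"): _x})
    assert _cpu_reason is not None
    log_cudagraph_skip_and_bump_counter(_cpu_reason)

    # A single CUDA device is the supported case and yields no skip reason.
    assert (
        check_multiple_devices_or_any_cpu_nodes({torch.device("cuda", 0): _x}) is None
    )

    print(_cpu_reason, "| skips so far:", counters["inductor"]["cudagraph_skips"])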