# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch._dynamo.utils import counters
from torch._inductor.utils import InputType


perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
static_inputs_log = torch._logging.getArtifactLogger(
    __name__, "cudagraph_static_inputs"
)


OutputType = List[Optional[Union[int, torch.Tensor]]]
ModelType = Callable[[List[InputType]], OutputType]


@dataclasses.dataclass(frozen=True)
class FunctionID:
    "Unique counter of a function wrapped in cudagraphify_impl"
    id: int


@dataclasses.dataclass(frozen=True)
class PlaceholderInfo:
    """
    A serializable version of torch.fx.Node that contains information
    pertinent to placeholder stack traces. We use these in logging and error messages
    related to cudagraphs, and will cache these results.
    """

    name: str
    stack_trace: Optional[str]
    # This field is recursive, but never cyclic (since a node never uses itself)
    users: List[PlaceholderInfo]
    mutating_use_stack_trace: Optional[str]


@dataclasses.dataclass(frozen=True)
class WrappedFunction:
    """
    Represents a function that you want to record for CUDA graph replay,
    with a little more metadata so we can identify if we have an applicable
    CUDA graph in our CUDA graph tree for it.
    """

    model: Callable[..., Any]
    static_input_idxs: Sequence[int]
    id: FunctionID
    constants: Tuple[torch.Tensor, ...]
    placeholders: Sequence[PlaceholderInfo]
    mutated_input_idxs: Sequence[int]


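# Illustrative sketch (not part of the original module): how the metadata above might
# be assembled around a compiled callable. `compiled_fn` and `gm` are hypothetical
# stand-ins for a compiled callable and its traced torch.fx.GraphModule.
#
#   wrapped = WrappedFunction(
#       model=compiled_fn,
#       static_input_idxs=[0],
#       id=FunctionID(0),
#       constants=(),
#       placeholders=get_placeholder_info(gm.graph),
#       mutated_input_idxs=[],
#   )
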
def get_mutating_use_stack_trace_from_node(
    placeholder_node: torch.fx.Node,
) -> Optional[str]:
    # reinplaced uses might have a single, non-copy_ use
    if len(placeholder_node.users) == 1:
        return next(iter(placeholder_node.users)).meta.get("stack_trace", None)

    for use in placeholder_node.users:
        if use.target == torch.ops.aten.copy_.default:
            if stack_trace := use.meta.get("stack_trace", None):
                return stack_trace

    return None


def get_mutating_use_stack_trace(placeholder_info: PlaceholderInfo) -> Optional[str]:
    return placeholder_info.mutating_use_stack_trace


def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo:
    name = placeholder_node.name
    stack_trace = placeholder_node.meta.get("stack_trace", None)
    users = []
    mutating_use_stack_trace = None
    # Only recurse to users once, since we only care about the users' stack traces
    if placeholder_node.op == "placeholder":
        users = [to_placeholder_info(i) for i in placeholder_node.users]
        mutating_use_stack_trace = get_mutating_use_stack_trace_from_node(
            placeholder_node
        )

    return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace)


def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]:
    return [
        to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder"
    ]

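# Illustrative sketch (not part of the original module): collecting PlaceholderInfo
# records from a traced module. `MyModule` is a hypothetical nn.Module.
#
#   gm = torch.fx.symbolic_trace(MyModule())
#   for info in get_placeholder_info(gm.graph):
#       print(info.name, info.stack_trace, len(info.users))
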

def format_default_skip_message(reason: str) -> str:
    return f"skipping cudagraphs due to {reason}"


def get_mutation_stack_trace(
    placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int]
) -> str:
    stack_trace: Optional[str] = ""

    for idx in mutation_indices:
        placeholder = placeholders[idx]
        if stack_trace := get_mutating_use_stack_trace(placeholder):
            break

    msg = format_default_skip_message(
        f"mutated inputs ({len(mutation_indices)} instances)"
    )
    if stack_trace:
        return f"{msg}. Found from : \n {stack_trace}"

    return msg


def check_for_mutation(
    func: WrappedFunction,
    inputs: List[InputType],
    is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool],
) -> Optional[str]:
    # doesn't work for non-trees because the warmup run would apply the mutation twice
    if torch._inductor.config.triton.cudagraph_trees:
        # checking if mutation is only on parameters/static inputs
        mutation_indices: Sequence[int] = [
            idx
            for idx in func.mutated_input_idxs
            if not (
                idx in func.static_input_idxs
                or is_cuda_graph_recorded_tensor(inputs[idx])  # type: ignore[arg-type]
            )
        ]
    else:
        mutation_indices = func.mutated_input_idxs

    static_inputs_log.debug(
        "check mutation static input indices: %s", func.static_input_idxs
    )
    static_inputs_log.debug("check mutation mutation indices: %s", mutation_indices)

    return (
        get_mutation_stack_trace(func.placeholders, mutation_indices)
        if mutation_indices
        else None
    )


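# Illustrative sketch (not part of the original module): asking whether a wrapped
# function mutates inputs in a way that rules out cudagraphs. `wrapped` and
# `example_inputs` are hypothetical; the predicate conservatively reports that no
# tensor is already cudagraph-recorded.
#
#   reason = check_for_mutation(
#       wrapped, example_inputs, is_cuda_graph_recorded_tensor=lambda t: False
#   )
#   if reason is not None:
#       log_cudagraph_skip_and_bump_counter(reason)
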
def _get_use_stack_trace(node) -> Optional[str]:
    for use in node.users:
        if stack_trace := use.meta.get("stack_trace", None):
            return stack_trace
    return None


def check_multiple_devices_or_any_cpu_nodes(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
) -> Optional[str]:
    if cpu_node := device_node_mapping.get(torch.device("cpu")):
        msg = f"cpu device ({cpu_node.name})"
        if stack_trace := _get_use_stack_trace(cpu_node):
            return format_default_skip_message(f"{msg}. Found from : \n {stack_trace}")

        return format_default_skip_message(msg)

    if (
        len(device_node_mapping) == 1
        and next(iter(device_node_mapping.keys())).type == "cuda"
    ):
        return None

    keys_repr = (repr(key) for key in device_node_mapping.keys())
    return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}")


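# Illustrative sketch (not part of the original module): the check returns None only
# when every node lives on a single CUDA device. `gpu_node` and `cpu_node` are
# hypothetical torch.fx.Node objects.
#
#   mapping = {torch.device("cuda:0"): gpu_node}
#   assert check_multiple_devices_or_any_cpu_nodes(mapping) is None
#
#   mapping[torch.device("cpu")] = cpu_node
#   check_multiple_devices_or_any_cpu_nodes(mapping)
#   # -> "skipping cudagraphs due to cpu device (...)"
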
def check_lowering_disable_cudagraph(
    device_node_mapping: Dict[torch.device, torch.fx.Node]
):
    return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)


def log_cudagraph_skip_and_bump_counter(msg):
    perf_hint_log.warning(msg)
    counters["inductor"]["cudagraph_skips"] += 1


@dataclasses.dataclass
class BoxedDeviceIndex:
    value: Optional[int]

    def set(self, device_idx: Optional[int]):
        assert device_idx is None or isinstance(device_idx, int)
        self.value = device_idx


def check_for_mutation_ignore_cuda_graph_managed_tensor(
    gm: torch.fx.GraphModule, compiled_graph, static_input_idxs: Sequence[int]
) -> Optional[str]:
    default_msg = format_default_skip_message("mutated inputs")

    # doesn't work for non-trees because the warmup run would apply the mutation twice
    if torch._inductor.config.triton.cudagraph_trees:
        unique_idxs = set(static_input_idxs)
        # checking if mutation is only on parameters/static inputs
        mutation_indices = [
            idx for idx in compiled_graph.mutated_input_idxs if idx not in unique_idxs
        ]
        has_mutation = len(mutation_indices) != 0
        if not has_mutation:
            return None
        placeholders = get_placeholder_info(gm.graph)
        return get_mutation_stack_trace(placeholders, mutation_indices)

    else:
        has_mutation = len(compiled_graph.mutated_inputs) != 0
        return None if not has_mutation else default_msg


def get_placeholder_stack_trace(placeholder: PlaceholderInfo) -> Optional[str]:
    """
    Gets the first non-empty stack trace of a placeholder or its users.
    """
    if placeholder.stack_trace:
        return placeholder.stack_trace

    for user in placeholder.users:
        if user.stack_trace:
            return user.stack_trace

    return None


class CheckInvariantStatus(Enum):
    # Check invariant succeeded
    SUCCESS = 1

    # Previously managed data pointers are not stable
    CudagraphManagedIdxMismatch = 2

    # Static tensor input addresses are not stable
    StaticInputIdxMismatch = 3

    # Expected dead indices before graph are live
    ExpectedDeadIndicesBeforeGraphMismatch = 4

    def __str__(self) -> str:
        if self.name == "CudagraphManagedIdxMismatch":
            return "cudagraph managed tensor data pointer changed"
        elif self.name == "StaticInputIdxMismatch":
            return "static input data pointer changed"
        elif self.name == "ExpectedDeadIndicesBeforeGraphMismatch":
            return "expected dead indices before graph are live"
        else:
            return f"{self.name}: {self.value}"


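# Illustrative sketch (not part of the original module): the human-readable form of
# the status values above.
#
#   str(CheckInvariantStatus.StaticInputIdxMismatch)
#   # -> "static input data pointer changed"
#   str(CheckInvariantStatus.SUCCESS)
#   # -> "SUCCESS: 1"
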
def log_data_ptr_mismatch(
    placeholders: Sequence[PlaceholderInfo],
    inputs: List[InputType],
    recorded_data_ptr: Sequence[Optional[int]],
    target_idxs: Sequence[int],
    mismatch: CheckInvariantStatus,
) -> str:
    """
    Builds an error message describing the mismatch between current input data
    pointers and the recorded data pointers. Only the indices in target_idxs
    are checked.
    """
    assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(
        placeholders
    ), "length mismatch between inputs, recorded_data_ptr, and placeholders"

    t_tensors = [inputs[i] for i in target_idxs]
    t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs]
    error_msg = f"{mismatch}.\n"
    for i, (tensor, data_ptr) in enumerate(zip(t_tensors, t_data_ptrs)):
        assert isinstance(tensor, torch.Tensor)
        index = target_idxs[i]
        if tensor.data_ptr() != data_ptr:
            placeholder = placeholders[index]
            error_msg = (
                f"{error_msg}input name: {placeholder.name}. "
                f"data pointer changed from {data_ptr} to {tensor.data_ptr()}. "
                f"input stack trace: {get_placeholder_stack_trace(placeholder)}\n"
            )
    return error_msg


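# Illustrative sketch (not part of the original module): reporting a changed static
# input address for a single-input graph. `placeholders`, `inputs`, and `old_ptr`
# are hypothetical; inputs[0] is assumed to be a tensor whose recorded address no
# longer matches.
#
#   msg = log_data_ptr_mismatch(
#       placeholders,
#       inputs,
#       recorded_data_ptr=[old_ptr],
#       target_idxs=[0],
#       mismatch=CheckInvariantStatus.StaticInputIdxMismatch,
#   )
#   # msg lists the input name, the old/new pointers, and the placeholder stack trace.
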
def maybe_warning_due_to_dynamic_shape(
    fn_cache: Dict[Tuple[int, ...], Callable[..., Any]],
    new_int_key: Any,
) -> bool:
    num_cudagraphs = len(fn_cache.keys()) + 1

    def warn_msg():
        return (
            "CUDAGraph supports dynamic shapes by recording a new graph for each "
            "distinct input size. Recording too many CUDAGraphs may lead to "
            f"extra overhead. We have observed {num_cudagraphs} distinct sizes. "
            "Please consider the following options for better performance: "
            "a) padding inputs to a small number of fixed shapes; or b) set "
            "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
            "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
            "to silence this warning."
        )

    if (
        torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit
        and num_cudagraphs
        > torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit
    ):
        perf_hint_log.warning(warn_msg())
        return True

    return False


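# Illustrative sketch (not part of the original module): `fn_cache` maps input-shape
# keys to recorded callables, and the warning fires once the number of distinct
# shapes (including the one about to be recorded) exceeds the configured limit.
# The cache contents below are hypothetical.
#
#   fn_cache = {(1, 128): graphed_fn_a, (1, 256): graphed_fn_b}
#   maybe_warning_due_to_dynamic_shape(fn_cache, new_int_key=(1, 512))
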
@dataclasses.dataclass(frozen=True)
class CudagraphCachedInfo:
    """
    Info needed to realign inputs
    """

    placeholders: Sequence[PlaceholderInfo]
    stack_traces: List[Optional[str]]
    cudagraph_fail_reasons: List[str]