xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/recording.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import inspect
4import itertools
5import logging
6from dataclasses import dataclass
7from typing import Any, Callable, Dict, List, Optional, Tuple, Union
8
9import torch
10import torch.utils._pytree as pytree
11
12
# Module-level logger; used in this module to report failures while recording
# or replaying events.
log = logging.getLogger(__name__)
# Artifact logger registered under the "trace_shape_events" name. It receives
# the debug-level call/return trace emitted by functions wrapped with
# record_shapeenv_event.
trace_shape_events_log = torch._logging.getArtifactLogger(
    __name__, "trace_shape_events"
)


# Names exported by this module.
__all__ = [
    "ShapeEnvEvent",
    "record_shapeenv_event",
    "replay_shape_env_events",
    "FakeTensorMeta",
    "shape_env_check_state_equal",
    "NotEqualError",
]
27
28# [Note: Recording ShapeEnv Events]
29# =================================
30#
31# What is a ShapeEnv event?
32# -------------------------
# We consider a ShapeEnv event to be every function call (ShapeEnv method or
# independent function) that modifies the state of the ShapeEnv instance.
# Such calls are recorded alongside their positional and keyword arguments,
# so that they may be replayed over a different ShapeEnv instance.
37#
38# See [Note: ShapeEnv State Equality] for what is considered the state
39# of a ShapeEnv instance.
40#
41# What is it for?
42# ---------------
43# ShapeEnv events recording is used for reconstructing the ShapeEnv in an
44# arbitrary state in time.
45#
46# Being able to arbitrarily replay events like so is useful, mainly for
47# translation validation bisection. i.e. if a ValidationException has been
48# raised, find the earliest point in time where the translation validation
49# fails.
50#
51# Besides that, it also allows us to inspect the given instance and,
52# for example, check the guards that would actually be issued at that point.
53#
54# What kind of arguments can be stored in an event?
55# -------------------------------------------------
56# There's no specific rule for what cannot be used as an argument.
57# That said, pay special attention to the following cases:
58#
59#   1. Tensor inputs: there are some tests that check whether the inputs
60#      were garbage collected after execution. These will fail if there's
61#      an event that is holding a reference to those inputs.
62#
63#   2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that
64#      will be automatically replaced by the new given ShapeEnv instance.
65#
66#   3. SymTypes arguments: they also hold references to ShapeEnv. So,
67#      whenever we see them, we create a new instance, replacing the
68#      ShapeEnv reference.
69#
70#   4. FX nodes: specifically, FX nodes from the FX graph for symbolic
71#      shapes. That argument must be replaced when replaying the event at
72#      ShapeEnvEvent.run, since it has to reference a node from the given
73#      instance, and not from the recorded instance.
74
75
76# Event class for reconstructing ShapeEnv at arbitrary time.
77#
78# Represents a method call that mutates ShapeEnv in a way that affects the
79# issued guards, when ShapeEnv.produce_guards is called.
80@dataclass
81class ShapeEnvEvent:
82    # ShapeEnv method.
83    f: Callable
84
85    # Arguments and keyword arguments called with.
86    args: Optional[List[Any]] = None
87    kwargs: Optional[Dict[str, Any]] = None
88
89    # List of tracked_fakes at the time the method was called.
90    tracked_fakes: Optional[List[Any]] = None
91
92    # Name of the captured event.
93    # Used for special handling of particular methods.
94    name: Optional[str] = None
95
96    # Replay itself, but using shape_env as self.
97    def run(self, shape_env=None) -> Any:
98        from torch.fx.experimental.symbolic_shapes import (
99            is_symbolic,
100            ShapeEnv,
101            SymTypes,
102        )
103
104        # Special handling for the constructor event.
105        if self.f is ShapeEnv:
106            assert shape_env is None and self.args is None and self.kwargs is not None
107            return ShapeEnv(**self.kwargs)
108
109        assert shape_env is not None
110        args = list(self.args or [])
111        kwargs = dict(self.kwargs or {})
112
113        # Replace any argument of type ShapeEnv by the given one.
114        args, kwargs = pytree.tree_map_only(
115            ShapeEnv, lambda _: shape_env, (args, kwargs)
116        )
117
118        # Replace any argument of type SymTypes by a new instance,
119        # replacing its ShapeEnv reference.
120        args, kwargs = pytree.tree_map_only(
121            lambda x: isinstance(x, SymTypes) and is_symbolic(x),
122            lambda a: type(a)(a.node.with_shape_env(shape_env)),
123            (args, kwargs),
124        )
125
126        # Converts FX nodes using the mapping argument.
127        def maybe_convert_node(x: Any) -> Any:
128            if not isinstance(x, torch.fx.Node):
129                # Don't do anything to x if it's not an FX node.
130                return x
131
132            # If, at some point, we created an FX node, it means that translation validation is on.
133            # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and
134            # we are tracking node names at shape_env.name_to_node.
135            assert hasattr(shape_env, "name_to_node")
136            name_to_node = shape_env.name_to_node  # type: ignore[attr-defined]
137            assert x.name in name_to_node
138            return name_to_node[x.name]
139
140        # Replaces the value of an specific argument by the result of fn.
141        def replacearg(index: int, key: str, fn: Callable):
142            if index < len(args):
143                args[index] = fn(args[index])
144            if key in kwargs:
145                kwargs[key] = fn(kwargs[key])
146
147        if self.is_create_fx_call_function():
148            # ShapeEnv.create_fx_call_function:
149            # "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv.
150            # They must be replaced, since a "call_function" FX node with this tuple as argument
151            # will be added to the FX graph of the new shape_env.
152            replacearg(
153                index=2,
154                key="args",
155                fn=lambda args: tuple(maybe_convert_node(a) for a in args),
156            )
157        if self.is_evaluate_expr() or self.is_defer_runtime_assert():
158            # ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert:
159            # "fx_node" parameter is an (optional) FX node that represents the evaluate expression.
160            # They must be replaced, since it will be part of a "call_function" FX node for
161            # torch._assert, which will be added to the FX graph of the new shape_env.
162            replacearg(index=3, key="fx_node", fn=maybe_convert_node)
163
164        # Actually call the method with the converted arguments.
165        return self.f(*args, **kwargs)
166
167    def __str__(self) -> str:
168        name = self.name if self.name is not None else self.f.__name__
169        return f"event: {name} ({self.args}, {self.kwargs})"
170
171    def is_create_fx_call_function(self) -> bool:
172        return self.name == "_create_fx_call_function"
173
174    def is_evaluate_expr(self) -> bool:
175        return self.name == "evaluate_expr"
176
177    def is_defer_runtime_assert(self) -> bool:
178        return self.name == "defer_runtime_assert"
179
180
# Nesting depth of the recorded calls currently executing: incremented when a
# function wrapped by record_shapeenv_event is entered and decremented when it
# exits. It is only used to indent the trace_shape_events debug log lines.
NEST = 0
182
183
184# Extracts a ShapeEnv instance inside args and kwargs.
185# Specifically, it looks for:
186#   1. ShapeEnv arguments
187#   2. SymInt, SymFloat, or SymBool arguments
188# If we find more than one object of any of the above types, we
189# also check that the ShapeEnv instance is the same for all of them.
190def _extract_shape_env_and_assert_equal(args, kwargs):
191    from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes
192
193    def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
194        if old is not None:
195            assert old is new, "call with different ShapeEnv"
196        return new
197
198    shape_env = None
199    for val in itertools.chain(args, kwargs.values()):
200        if isinstance(val, ShapeEnv):
201            shape_env = assert_equal(shape_env, val)
202        if isinstance(val, SymTypes) and is_symbolic(val):
203            shape_env = assert_equal(shape_env, val.node.shape_env)
204
205    return shape_env
206
207
208# Decorator for recording the given function as a replayable event.
209#
210# This decorator should be used at every function that mutates the state of
211# ShapeEnv in some way that affects the resulting issued guards (i.e. when
212# ShapeEnv.produce_guards is called).
213#
214# save_tracked_fakes: saves a snapshot of the TrackedFake list.
215# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
216#
217# When to save the list of TrackedFake?
218# =====================================
219# We should save the list of TrackedFake whenever the translation validation
220# bisection may actually stop and call the produce_guards method at the moment
221# right after the recorded function was played. In other words, since the
# bisection bisects through torch._assert calls, we should save in all methods
# that add a torch._assert call to the symbolic shapes FX graph.
224#
225# At the moment, there are 2 methods that save the list:
226#   - ShapeEnv.evaluate_expr
227#   - ShapeEnv.defer_runtime_assert
def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
    """Build a decorator that records calls to a ShapeEnv method as events.

    Args:
        save_tracked_fakes: when True, the result of the ShapeEnv's
            ``_snapshot_tracked_fakes()`` is stored in each recorded event;
            otherwise the event's ``tracked_fakes`` is None.

    Returns:
        A decorator. It asserts that the decorated callable's first parameter
        is named ``self`` and returns a wrapper that logs the call, records
        it as a ShapeEnvEvent on the ShapeEnv's ``events`` list (unless a
        recording is already in progress) and runs it.
    """

    def decorator(fn: Callable) -> Callable:
        assert callable(fn)
        args = inspect.getfullargspec(fn).args
        assert args and args[0] == "self", (
            "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your "
            "code so that it calls into a method on ShapeEnv"
        )
        # Captured once; used by the wrapper for its log messages.
        name = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            assert isinstance(args[0], ShapeEnv)

            global NEST

            # Trace the call, indented by the current nesting depth, then go
            # one level deeper for whatever this call triggers.
            trace_shape_events_log.debug(
                "%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs
            )
            NEST += 1

            # Traces the result at the same indentation as the matching "call"
            # line (NEST was already incremented, hence the -1), and passes
            # the result through unchanged.
            def retlog(r):
                trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r)
                return r

            try:
                if args[0].is_recording:  # type: ignore[has-type]
                    # If ShapeEnv is already recording an event, call the wrapped
                    # function directly.
                    #
                    # NB: here, we skip the check of whether all ShapeEnv instances
                    # are equal, in favor of a faster dispatch.
                    return retlog(fn(*args, **kwargs))

                # Retrieve an instance of ShapeEnv.
                # Assumption: the collection of args and kwargs may not reference
                # different ShapeEnv instances.
                self = _extract_shape_env_and_assert_equal(args, kwargs)

                # If we are calling this function without any ShapeEnv instance
                # alive in its arguments, we don't record and call the original.
                # NOTE(review): args[0] was asserted above to be a ShapeEnv, so
                # this branch looks reachable only when assertions are disabled
                # -- confirm.
                if self is None:
                    return retlog(fn(*args, **kwargs))

                # Otherwise, start recording and call the function.
                with self._recording():
                    # Take a snapshot of the current tracked_fakes.
                    tracked_fakes = (
                        self._snapshot_tracked_fakes() if save_tracked_fakes else None
                    )
                    # Record the event for 'fn'.
                    event = ShapeEnvEvent(
                        fn, list(args), kwargs, tracked_fakes, name=fn.__name__
                    )
                    # Play the event on this ShapeEnv.
                    # NB: It's important to put the event first, because running
                    # the event can trigger internal events that must be ordered
                    # after this event.  However, if an exception happens, we do
                    # NOT want to have the event in the list, so pop it off from
                    # the record if an error happened
                    self.events.append(event)
                    try:
                        return retlog(event.run(self))
                    except Exception:
                        self.events.pop()
                        raise

            except Exception:
                # Report the failing call, then let the exception propagate.
                # The traceback is attached only if INFO is enabled for `log`.
                log.error(  # noqa: G201
                    "failed while running %s(*%s, **%s)",
                    name,
                    args[1:],
                    kwargs,
                    exc_info=log.isEnabledFor(logging.INFO),
                )
                raise

            finally:
                # Restore the nesting depth on both normal and failing exits.
                NEST -= 1

        return wrapper

    return decorator
313
314
# Replays the ShapeEnvEvents list.
# It assumes the first event is the constructor call.
#
# events: sequence of recorded events. The first one creates the new ShapeEnv
# and the remaining ones are replayed on it, in order.
def replay_shape_env_events(events):
    """Rebuild a ShapeEnv by replaying ``events`` from scratch.

    Args:
        events: indexable, sliceable sequence of events. ``events[0]`` must be
            the constructor event, i.e. its ``f`` is the ShapeEnv class.

    Returns:
        The value returned by running the constructor event, after every
        subsequent event has been run on it.

    Raises:
        IndexError: if ``events`` is empty.
        Exception: whatever a replayed event raises is logged and re-raised.
    """
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    constructor_event = events[0]
    # Identity check, the same test ShapeEnvEvent.run uses to recognise the
    # constructor event.
    assert constructor_event.f is ShapeEnv

    # Constructs the new ShapeEnv.
    shape_env = constructor_event.run()

    for event in events[1:]:
        try:
            # Actually replays each event on the newly created ShapeEnv.
            event.run(shape_env)
        except Exception:
            # Name the offending event before propagating the error.
            log.error("failed when running event: %s", event)
            raise

    return shape_env
339
340
341# FakeTensor metadata.
342# This is to be used in place of FakeTensor placeholders when calling
343# ShapeEnv.produce_guards.
@dataclass
class FakeTensorMeta:
    """Size, stride and storage offset standing in for a FakeTensor.

    Meant to be used in place of FakeTensor placeholders when calling
    ShapeEnv.produce_guards. It exposes the ``size()``, ``stride()``,
    ``storage_offset()`` and ``dim()`` accessors plus the ``is_nested`` field.
    """

    tensor_size: Tuple[Union[int, torch.SymInt], ...]
    tensor_stride: Tuple[Union[int, torch.SymInt], ...]
    tensor_storage_offset: Union[int, torch.SymInt]
    is_nested: bool

    @staticmethod
    def from_fake(fake) -> "FakeTensorMeta":
        """Capture the metadata of ``fake`` through its tensor-like accessors."""
        return FakeTensorMeta(
            tensor_size=fake.size(),
            tensor_stride=fake.stride(),
            tensor_storage_offset=fake.storage_offset(),
            is_nested=fake.is_nested,
        )

    def size(self) -> Tuple[Union[int, torch.SymInt], ...]:
        """Return the stored sizes."""
        return self.tensor_size

    def stride(self) -> Tuple[Union[int, torch.SymInt], ...]:
        """Return the stored strides."""
        return self.tensor_stride

    def storage_offset(self) -> Union[int, torch.SymInt]:
        """Return the stored storage offset."""
        return self.tensor_storage_offset

    def dim(self) -> int:
        """Return the number of dimensions, i.e. the length of the sizes."""
        return len(self.tensor_size)
368
369
370# [Note: ShapeEnv State Equality]
371# ===============================
372#
373# What is considered ShapeEnv state?
374# ----------------------------------
375# We consider to be the state of a ShapeEnv instance everything that
376# is not in the inline tuple inside remove_nonstate_variables function.
377# That is: the fields within ShapeEnv that modify the flow of execution
378# of the program.
379#
# So, for example: the replacements field might influence how an
# expression is simplified. That, in turn, may result in a guard being
# statically known (i.e. not added).
383#
# On the other hand, var_to_stack only changes what is printed on the
# screen, i.e. it is used only for debugging purposes. Therefore, we
# should not consider it when comparing states.
387#
388# What to do on NotEqualError?
389# ----------------------------
390# Here are a few possible causes for getting a NotEqualError raised:
391#
392#   1. New field that does not belong in the ShapeEnv state.
393#      For example: log field of type ShapeEnvLoggerAdapter. Different
394#      ShapeEnv instances will always have different ShapeEnvLoggerAdapter
395#      instances, i.e. equality comparison would fail.
396#      Solution: add it to the inlined tuple inside remove_nonstate_variables
397#      function inside check_equal method.
398#
399#   2. New field that is not directly comparable across instances.
400#      For example: guards field of type List[ShapeGuard]. More specifically,
401#      the ShapeGuard type holds an expression and a stack information
#      for debugging purposes. When replaying the event on a new ShapeEnv
#      instance, the stack would be different, which would trigger this error.
404#      Solution: add a special case to the map_value function inside
405#      check_equal function.
406#
407#   3. Mutation of ShapeEnv on some not recorded function.
408#      If a mutation of the state of ShapeEnv happens inside a function
409#      that is not recorded (or that no caller in the stack is recorded),
410#      then, the replayed ShapeEnv won't catch that.
#      Solution: decorate the function with record_shapeenv_event.
412
413
414# Checks whether the state of two ShapeEnv are equal w.r.t. the guards
415# returned by ShapeEnv.produce_guards.
def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
    """Check that two objects hold the same state, field by field.

    Args:
        env1, env2: the instances to compare; their fields are read through
            ``vars()`` and are never modified.
        non_state_variable_names: iterable of field names to ignore.
        map_value: callable ``(field_name, value) -> comparable`` applied to
            the value of every remaining field before comparing.

    Returns:
        None when both instances have the same fields with equal mapped values.

    Raises:
        NotEqualError: if the field sets differ, or if any mapped value
            differs between the two instances.
    """
    # Collect and remove variables that don't necessarily represent the state
    # of a ShapeEnv. Note: we copy the dictionary so that we don't modify the
    # instance itself.
    env1_vars = vars(env1).copy()
    env2_vars = vars(env2).copy()

    for v in non_state_variable_names:
        env1_vars.pop(v, None)
        env2_vars.pop(v, None)

    # Function for transforming the mismatched values into string.
    # Needed, since dict and set entries order might not be the same every time.
    def value_to_str(value: Any) -> str:
        if isinstance(value, dict):
            return (
                "{"
                + ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
                + "}"
            )
        if isinstance(value, set):
            try:
                ordered = sorted(value)
            except TypeError:
                # The elements cannot be ordered among themselves (e.g. mixed
                # types). Order them by their string form, as done for dict
                # keys above, rather than failing while building the report.
                ordered = sorted(value, key=str)
            return "{" + ", ".join(f"{v}" for v in ordered) + "}"
        return str(value)

    env1_set, env2_set = set(env1_vars), set(env2_vars)

    # First, compare the set of keys in each vars dictionary.
    if env1_set != env2_set:
        raise NotEqualError(
            "field set mismatch:",
            [
                (
                    "found unique fields:",
                    str(sorted(env1_set - env2_set)),
                    str(sorted(env2_set - env1_set)),
                ),
            ],
        )

    # Then, map the values of every field, in sorted field order, so that the
    # two sides can be compared appropriately.
    mapped_values = [
        (k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k]))
        for k in sorted(env1_set)
    ]

    # Accumulate the fields that did not match, alongside their mapped values.
    errors = [
        (f"{k}: values don't match.", value_to_str(val1), value_to_str(val2))
        for k, val1, val2 in mapped_values
        if val1 != val2
    ]

    if errors:
        raise NotEqualError("field values don't match:", errors)
485
486
class NotEqualError(Exception):
    """Raised when two ShapeEnv states are found to differ.

    The exception message starts with ``msg`` and then lists, for each entry
    of ``mismatched``, its description followed by the left and right values.
    """

    def __init__(
        self,
        msg: str,
        mismatched: List[Tuple[str, str, str]],
    ) -> None:
        # Three report lines per mismatch: a header, then both sides.
        report_lines: List[str] = []
        for inner_msg, str1, str2 in mismatched:
            report_lines.append(f"==> {inner_msg}")
            report_lines.append(f"  >  Left: {str1}")
            report_lines.append(f"  > Right: {str2}")
        details = "\n".join(report_lines)

        super().__init__(
            f"""\
ShapeEnv not equal: {msg}

{details}
"""
        )
513