xref: /aosp_15_r20/external/pytorch/torch/_higher_order_ops/effects.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union
from weakref import WeakKeyDictionary

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


class _EffectType(Enum):
    ORDERED = "Ordered"


OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]


SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary(
    {
        torch.ops.aten._print.default: _EffectType.ORDERED,
        call_torchbind: _EffectType.ORDERED,
    }
)


def _register_effectful_op(op: OpType, effect: _EffectType):
    assert isinstance(
        op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
    ) and not has_aliasing(op)
    if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
        raise RuntimeError(
            f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
            f"trying to register a different effect type {effect}."
        )
    SIDE_EFFECTS[op] = effect


def _deregister_effectful_op(op: OpType):
    if op not in SIDE_EFFECTS:
        raise RuntimeError(f"Op {op} is not registered as effectful")

    del SIDE_EFFECTS[op]
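

# Example usage (a sketch, not part of the original module): a custom op can be
# marked as ordered-effectful so that tracing threads a token through it.
# `torch.ops.my_lib.record_event.default` is a hypothetical, non-aliasing
# OpOverload used purely for illustration.
#
#   _register_effectful_op(torch.ops.my_lib.record_event.default, _EffectType.ORDERED)
#   ...
#   _deregister_effectful_op(torch.ops.my_lib.record_event.default)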


class WithEffects(HigherOrderOperator):
    """
    with_effects(token, op, args, kwargs) -> (new_token, op_results)

    This HOP helps ensure ordering between side-effectful ops, like prints or
    ops using torchbind objects. It is needed so that the graph traced by
    AOTAutograd is functional and later optimization passes do not reorder
    these operators. This is done by threading "effect tokens" through the
    graph to enforce a data dependence between side-effectful ops.

    The tokens are basically dummy values (torch.tensor([])). We create one
    token per "effect type"; the effect types are enumerated in the
    _EffectType enum.
    """

    def __init__(self) -> None:
        super().__init__("with_effects")

    def __call__(
        self,
        token,
        op: OpType,
        *args: Tuple[Any, ...],
        **kwargs: Dict[str, Any],
    ) -> Tuple[Any, ...]:
        assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
        assert not has_aliasing(op), "Ops with aliasing are not supported"
        assert has_effects(op, args, kwargs)
        assert isinstance(kwargs, dict)
        return super().__call__(token, op, *args, **kwargs)


with_effects = WithEffects()
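
# Calling-convention sketch (commented out so nothing runs at import time):
# `aten._print` is registered as ORDERED above, so calling the HOP returns a
# fresh token alongside the op's (empty) result. In practice this HOP is
# inserted by handle_effects() during functionalization rather than called by
# hand.
#
#   token = new_token_tensor()
#   new_token, result = with_effects(token, torch.ops.aten._print.default, "hello")
#   # result is None because aten._print returns nothing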


def has_aliasing(op: OpType):
    # NOT FOR PUBLIC USE
    if isinstance(op, torch._ops.HigherOrderOperator):
        return op not in SIDE_EFFECTS

    for arg in op._schema.arguments:
        if arg.alias_info is not None:
            return True
    for arg in op._schema.returns:
        if arg.alias_info is not None:
            return True
    return False


def has_effects(op, args, kwargs) -> bool:
    # Skip over the profiler's RecordFunction ops, as they should not show up
    # in the graph.
    _skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction}
    if op in _skip_ops:
        return False

    return (
        isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
        and not has_aliasing(op)
        and get_effect_key(op, args, kwargs) is not None
    )


def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
    if op in SIDE_EFFECTS:
        return SIDE_EFFECTS[op]

    for arg in args:
        if isinstance(arg, torch.ScriptObject):
            # Add it to the table so that next time we see the same op we don't
            # have to parse through the args again
            SIDE_EFFECTS[op] = _EffectType.ORDERED
            return _EffectType.ORDERED

    return None
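
# Example (sketch): an op that takes a torch.ScriptObject argument is treated
# as ORDERED even if it was never registered explicitly, and the first call
# caches that decision in SIDE_EFFECTS. `torch.ops.my_lib.queue_push.default`,
# `tq` (a torchbind object), and `x` are hypothetical.
#
#   get_effect_key(torch.ops.my_lib.queue_push.default, (tq, x), {})
#   # -> _EffectType.ORDERED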


def new_token_tensor() -> torch.Tensor:
    # Use dtype bool to not affect Inductor dtype promotions
    return torch.tensor([], dtype=torch.bool)


@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd)
def with_effects_dense(
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    out = op(*args, **kwargs)
    new_token = new_token_tensor()
    if isinstance(out, tuple):
        return (new_token, *out)
    return (new_token, out)
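
# For example (sketch): aten._print returns nothing, so the dense impl yields
# a (new_token, None) pair, while an op that returns a tuple has its results
# flattened in after the token rather than nested.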


@with_effects.py_impl(FakeTensorMode)
def with_effects_fake(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    with mode:
        result = with_effects_dense(token, op, *args, **kwargs)
        return result


@with_effects.py_impl(ProxyTorchDispatchMode)
def with_effects_proxy(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    with disable_proxy_modes_tracing():
        out = with_effects(token, op, *args, **kwargs)

    proxy_token = mode.tracer.unwrap_proxy(token)
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

    from torch.fx.node import has_side_effect

    # Mark the op as side-effectful so it does not get DCEed by
    # graph.eliminate_dead_code when it has no outputs or its outputs are
    # unused.
    has_side_effect(op)

    out_proxy = mode.tracer.create_proxy(
        "call_function",
        with_effects,
        (proxy_token, op, *proxy_args),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result
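
# A graph traced through the proxy impl then contains nodes of roughly this
# shape (a sketch; exact argument and node names depend on the tracer):
#
#   %with_effects = call_function[target=torch.ops.higher_order.with_effects](
#       args = (%arg0_1, aten._print.default, "msg"), kwargs = {})
#   %getitem = call_function[target=operator.getitem](
#       args = (%with_effects, 0), kwargs = {})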


with_effects.fallthrough(DispatchKey.AutogradCPU)
with_effects.fallthrough(DispatchKey.AutogradCUDA)


def _get_schema(op, args) -> torch.FunctionSchema:
    if isinstance(op, torch._ops.OpOverload):
        return op._schema
    elif op == call_torchbind:
        return getattr(args[0], args[1]).schema
    else:
        raise RuntimeError(f"Unable to get schema for op {op}")


def handle_effects(
    allow_token_discovery: bool,
    tokens: Dict[_EffectType, torch.Tensor],
    op: OpType,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Args:
        allow_token_discovery: Whether or not we are discovering tokens. If this
        is true, we will create a token for every side effect type seen that
        does not yet have a token assigned. If this is false, the tokens should
        all have been created ahead of time, so we will error if there is no
        token mapped to a given effect type.

        tokens: Map of effect type to token. This is used to chain operators
        with the same effect type together so that they do not get reordered
        by later optimization passes.
    """

    # Get a token. We can't do `tokens.get(key, torch.tensor([]))` because
    # that would create an empty tensor during proxy mode tracing if the token
    # doesn't exist. But the tokens should always exist during proxy mode tracing.
    key = get_effect_key(op, args, kwargs)
    assert key is not None
    if key not in tokens:
        assert (
            allow_token_discovery
        ), f"Could not find a token for effect {key} which came from the function {op}"
        proxy_tensor_mode = torch._C._get_dispatch_mode(
            torch._C._TorchDispatchModeKey.PROXY
        )
        if proxy_tensor_mode is not None:
            # If we discover a new token during tracing, we are in the backward
            # pass, so we patch the graph by adding an extra tangents_token
            # input to the joint graph.
            tracer = proxy_tensor_mode.tracer

            from torch.fx.experimental.proxy_tensor import (
                disable_proxy_modes_tracing,
                track_tensor_tree,
            )

            with disable_proxy_modes_tracing():
                token_tensor = new_token_tensor()

            token_proxy = proxy_tensor_mode.tracer.create_proxy(
                "placeholder", "tangents_token", (), {}, name="tangents_token"
            )
            track_tensor_tree(token_tensor, token_proxy, constant=None, tracer=tracer)

            tokens[key] = token_tensor
        else:
            tokens[key] = new_token_tensor()

    token = tokens[key]

    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    unwrapped_token = ctx.unwrap_tensors([token])[0]  # type: ignore[arg-type]
    unwrapped_args = ctx.unwrap_tensors(args)  # type: ignore[arg-type]
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        (new_token, *unwrapped_outs) = with_effects(
            unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    schema = _get_schema(op, unwrapped_args)
    if len(schema.returns) == 0:
        assert unwrapped_outs[0] is None
        unwrapped_outs = None  # type: ignore[assignment]
    elif len(schema.returns) == 1:
        assert len(unwrapped_outs) == 1
        unwrapped_outs = unwrapped_outs[0]
    else:
        assert len(unwrapped_outs) == len(schema.returns)

    # Add the newly created token into the tokens map so that a subsequent
    # call can pick it up and chain off of it.
    wrapped_token = ctx.wrap_tensors(new_token)
    assert isinstance(wrapped_token, torch.Tensor)
    tokens[key] = wrapped_token

    return ctx.wrap_tensors(unwrapped_outs)  # type: ignore[arg-type]
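
# End-to-end sketch (hypothetical, for illustration only): functionalizing two
# ordered ops chains them through the same token so later passes cannot
# reorder them.
#
#   tokens: Dict[_EffectType, torch.Tensor] = {}
#   out1 = handle_effects(True, tokens, torch.ops.aten._print.default, ("a",), {})
#   out2 = handle_effects(True, tokens, torch.ops.aten._print.default, ("b",), {})
#   # tokens[_EffectType.ORDERED] now holds the token produced by the second
#   # call, and the second with_effects consumed the token produced by the first.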