# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union
from weakref import WeakKeyDictionary

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


class _EffectType(Enum):
    ORDERED = "Ordered"


OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]


SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary(
    {
        torch.ops.aten._print.default: _EffectType.ORDERED,
        call_torchbind: _EffectType.ORDERED,
    }
)


def _register_effectful_op(op: OpType, effect: _EffectType):
    assert isinstance(
        op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
    ) and not has_aliasing(op)
    if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
        raise RuntimeError(
            f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
            f"trying to register a different effect type {effect}."
        )
    SIDE_EFFECTS[op] = effect


def _deregister_effectful_op(op: OpType):
    if op not in SIDE_EFFECTS:
        raise RuntimeError(f"Op {op} is not registered as effectful")

    del SIDE_EFFECTS[op]


class WithEffects(HigherOrderOperator):
    """
    with_effects(token, op, args, kwargs) -> (new_token, op_results)

    This HOP helps ensure ordering between side effectful ops like prints or ops
    using torchbind objects. This is needed to ensure a traced graph from
    AOTAutograd is functional so that future optimization passes do not reorder
    these operators. This is done through threading "effect tokens" through the
    graph to enforce data dependence between side effectful ops.

    The tokens are basically dummy values (torch.tensor([])). We create a token
    per "effect type", which are enumerated in the _EffectType enum.
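
    For example, two ordered prints would be traced roughly as the following
    sketch (illustrative only; the real graph uses FX node names rather than
    these variables)::

        token1, _ = with_effects(token0, torch.ops.aten._print.default, "moo")
        token2, _ = with_effects(token1, torch.ops.aten._print.default, "boo")
        # token2 is threaded onward (and eventually returned as a graph
        # output), so the two prints cannot be reordered past each other.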
66 """ 67 68 def __init__(self) -> None: 69 super().__init__("with_effects") 70 71 def __call__( 72 self, 73 token, 74 op: OpType, 75 *args: Tuple[Any, ...], 76 **kwargs: Dict[str, Any], 77 ) -> Tuple[Any, ...]: 78 assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) 79 assert not has_aliasing(op), "Ops with aliasing is not supported" 80 assert has_effects(op, args, kwargs) 81 assert isinstance(kwargs, dict) 82 return super().__call__(token, op, *args, **kwargs) 83 84 85with_effects = WithEffects() 86 87 88def has_aliasing(op: OpType): 89 # NOT FOR PUBLIC USE 90 if isinstance(op, torch._ops.HigherOrderOperator): 91 return op not in SIDE_EFFECTS 92 93 for arg in op._schema.arguments: 94 if arg.alias_info is not None: 95 return True 96 for arg in op._schema.returns: 97 if arg.alias_info is not None: 98 return True 99 return False 100 101 102def has_effects(op, args, kwargs) -> bool: 103 # Skip over the profiler's RecordFunction as they should not show up in the graph 104 _skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction} 105 if op in _skip_ops: 106 return False 107 108 return ( 109 isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) 110 and not has_aliasing(op) 111 and get_effect_key(op, args, kwargs) is not None 112 ) 113 114 115def get_effect_key(op, args, kwargs) -> Optional[_EffectType]: 116 if op in SIDE_EFFECTS: 117 return SIDE_EFFECTS[op] 118 119 for arg in args: 120 if isinstance(arg, torch.ScriptObject): 121 # Add it to the table so that next time we see the same op we don't 122 # have to parse through the args again 123 SIDE_EFFECTS[op] = _EffectType.ORDERED 124 return _EffectType.ORDERED 125 126 return None 127 128 129def new_token_tensor() -> torch.Tensor: 130 # Use dtype bool to not affect Inductor dtype promotions 131 return torch.tensor([], dtype=torch.bool) 132 133 134@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd) 135def with_effects_dense( 136 token: torch.Tensor, 137 op: torch._ops.OpOverload, 138 *args: Tuple[Any, ...], 139 **kwargs: Dict[str, Any], 140) -> Tuple[torch.Tensor, ...]: 141 out = op(*args, **kwargs) 142 new_token = new_token_tensor() 143 if isinstance(out, tuple): 144 return (new_token, *out) 145 return (new_token, out) 146 147 148@with_effects.py_impl(FakeTensorMode) 149def with_effects_fake( 150 mode, 151 token: torch.Tensor, 152 op: torch._ops.OpOverload, 153 *args: Tuple[Any, ...], 154 **kwargs: Dict[str, Any], 155) -> Tuple[torch.Tensor, ...]: 156 with mode: 157 result = with_effects_dense(token, op, *args, **kwargs) 158 return result 159 160 161@with_effects.py_impl(ProxyTorchDispatchMode) 162def with_effects_proxy( 163 mode, 164 token: torch.Tensor, 165 op: torch._ops.OpOverload, 166 *args: Tuple[Any, ...], 167 **kwargs: Dict[str, Any], 168) -> Tuple[torch.Tensor, ...]: 169 with disable_proxy_modes_tracing(): 170 out = with_effects(token, op, *args, **kwargs) 171 172 proxy_token = mode.tracer.unwrap_proxy(token) 173 proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) 174 proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) 175 176 from torch.fx.node import has_side_effect 177 178 # To avoid the being DCEed by graph.eliminate_dead_code if they. 179 # don't have output or their outputs are not used. 
    has_side_effect(op)

    out_proxy = mode.tracer.create_proxy(
        "call_function",
        with_effects,
        (proxy_token, op, *proxy_args),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


with_effects.fallthrough(DispatchKey.AutogradCPU)
with_effects.fallthrough(DispatchKey.AutogradCUDA)


def _get_schema(op, args) -> torch.FunctionSchema:
    if isinstance(op, torch._ops.OpOverload):
        return op._schema
    elif op == call_torchbind:
        return getattr(args[0], args[1]).schema
    else:
        raise RuntimeError(f"Unable to get schema for op {op}")


def handle_effects(
    allow_token_discovery: bool,
    tokens: Dict[_EffectType, torch.Tensor],
    op: OpType,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Args:
        allow_token_discovery: Whether or not we are discovering tokens. If this
            is true, we will create a token for every side-effect type seen that
            does not have a token assigned yet. If this is false, the tokens
            should have all been created ahead of time, so we will error if there
            is no token mapped to an effect type.

        tokens: Map of effect type to tokens. This is used to chain operators of
            the same effect type together so that they do not get reordered in
            later optimization passes.
    """

    # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
    # this would create an empty tensor during proxy-mode tracing if the token
    # doesn't exist. But the tokens should always exist during proxy-mode tracing.
    key = get_effect_key(op, args, kwargs)
    assert key is not None
    if key not in tokens:
        assert (
            allow_token_discovery
        ), f"Could not find a token for effect {key} which came from the function {op}"
        proxy_tensor_mode = torch._C._get_dispatch_mode(
            torch._C._TorchDispatchModeKey.PROXY
        )
        if proxy_tensor_mode is not None:
            # If we discovered a new token during tracing, we are in the backward
            # pass, so we patch the graph by adding an additional tangents_token
            # input to the joint graph.
            tracer = proxy_tensor_mode.tracer

            from torch.fx.experimental.proxy_tensor import (
                disable_proxy_modes_tracing,
                track_tensor_tree,
            )

            with disable_proxy_modes_tracing():
                token_tensor = new_token_tensor()

            token_proxy = proxy_tensor_mode.tracer.create_proxy(
                "placeholder", "tangents_token", (), {}, name="tangents_token"
            )
            track_tensor_tree(token_tensor, token_proxy, constant=None, tracer=tracer)

            tokens[key] = token_tensor
        else:
            tokens[key] = new_token_tensor()

    token = tokens[key]

    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    unwrapped_token = ctx.unwrap_tensors([token])[0]  # type: ignore[arg-type]
    unwrapped_args = ctx.unwrap_tensors(args)  # type: ignore[arg-type]
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        (new_token, *unwrapped_outs) = with_effects(
            unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    schema = _get_schema(op, unwrapped_args)
    if len(schema.returns) == 0:
        assert unwrapped_outs[0] is None
        unwrapped_outs = None  # type: ignore[assignment]
    elif len(schema.returns) == 1:
        assert len(unwrapped_outs) == 1
        unwrapped_outs = unwrapped_outs[0]
    else:
        assert len(unwrapped_outs) == len(schema.returns)

    # Add the newly created token into the tokens map for a following call to
    # use this token.
    wrapped_token = ctx.wrap_tensors(new_token)
    assert isinstance(wrapped_token, torch.Tensor)
    tokens[key] = wrapped_token

    return ctx.wrap_tensors(unwrapped_outs)  # type: ignore[arg-type]
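

# Example (illustrative sketch, not part of this module's API): a custom op can
# be marked as effectful so that with_effects threads a token through it during
# tracing. The "mylib::record" op below is hypothetical.
#
#     lib = torch.library.Library("mylib", "FRAGMENT")
#     lib.define("record(str msg) -> ()")
#     lib.impl("record", lambda msg: print(msg), "CompositeExplicitAutograd")
#     _register_effectful_op(torch.ops.mylib.record.default, _EffectType.ORDERED)
#
# Once registered, tracing a function that calls torch.ops.mylib.record should
# produce with_effects(token, mylib.record.default, ...) nodes, keeping the
# calls ordered relative to other ops with the same effect type.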