# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable

import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._ops import OperatorBase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.multiprocessing.reductions import StorageWeakRef


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad, this will either
    raise an error or return a DelayedError depending on the value of delayed_error.

    Args:
        operator: The Operator to call with the *args and **kwargs
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If delayed_error is False, autograd is enabled, and any of the
            arguments to the Operator require grad
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result


def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    def inner(*args, **kwargs):
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner
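

# Usage sketch (illustrative comments only, nothing here runs at import time): a
# higher-order op without a backward formula typically plugs
# ``autograd_not_implemented`` into its Autograd dispatch entry.  Assuming a
# hypothetical HigherOrderOperator ``my_hop`` and the usual ``py_impl``
# registration hook, the wiring would look roughly like:
#
#     my_hop.py_impl(torch._C.DispatchKey.Autograd)(
#         autograd_not_implemented(my_hop, deferred_error=True)
#     )
#
# With ``deferred_error=True`` the forward call still succeeds; the RuntimeError
# is only raised if autograd later tries to backpropagate through the
# DelayedError node attached to the outputs.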


def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running the graph with the interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn


def reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped


def _maybe_reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    if _CURRENT_MAKE_FX_TRACER is not None:
        return reenter_make_fx(fn)
    else:
        return make_fx(fn)


@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing


def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if the
    resulting graph has a mutable op on the input. This is a
    bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_mutation(gm):
        input_nodes = set()
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                input_nodes.add(node)
            if node.op == "call_function":
                target = node.target
                if (
                    isinstance(target, torch._ops.OpOverload)
                    and target._schema.is_mutable
                ):
                    for arg in node.args:
                        if arg in input_nodes:
                            return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule):
                if _detect_input_mutation(module):
                    return True

        return False

    return _detect_input_mutation(gm)


def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check if the
    resulting graph has an output aliasing the branch input. This is a
    bit restrictive as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_alias(gm):
        input_storages = set()
        for node in gm.graph.nodes:
            # We need to check the existence of "val" because we reuse this logic
            # for the map operator, where num_mapped_args is a scalar
            # and doesn't have a "val" meta.
            if node.op == "placeholder" and "val" in node.meta:
                input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
            if node.op == "output":

                def check_alias(out):
                    if out is not None and "val" in out.meta:
                        out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                        return out_storage in input_storages
                    return False

                if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
                    return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
                return True

        return False

    return _detect_input_alias(gm)
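

# Illustrative sketch (comments only, nothing runs at import time): for a real
# tensor ``x = torch.randn(3)`` the mutation check flags an in-place branch but
# not a pure one:
#
#     _has_potential_branch_input_mutation(lambda x: x.add_(1), (x,))  # True
#     _has_potential_branch_input_mutation(lambda x: x + 1, (x,))      # False
#
# ``_has_potential_branch_input_alias`` plays the analogous role for outputs that
# share storage with an input; higher-order ops such as cond use both checks to
# reject branches that would break functionalization.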


def unique_graph_id(proxy_mode, prefix):
    """Returns a unique id and name for a graph to be added to a proxy_mode tracer"""
    # There are probably better ways - I know that create_arg has some self-incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{prefix}_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate
    return i, next_name


def _from_fun(t):
    from torch._functorch.aot_autograd import from_fun
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, torch.Tensor):
        if t.dtype != torch.bool:
            return torch.empty_strided(
                t.size(),
                t.stride(),
                dtype=t.dtype,
                requires_grad=t.requires_grad,
            )
        else:
            # a clone of a functional tensor produces a functional tensor,
            # but we want to avoid that, so we clone a non-functional version
            maybe_unfunc_t = t
            if isinstance(t, FunctionalTensor):
                torch._sync(t)
                maybe_unfunc_t = from_fun(t)
            elif torch._is_functional_tensor(t):
                # need to handle both types of functionalization here:
                # these are the tensors that came from the user,
                # which could be either FunctionalTensorWrapper or FunctionalTensor
                torch._sync(t)
                maybe_unfunc_t = torch._from_functional_tensor(t)
            return maybe_unfunc_t.clone()
    return t


def clone_outputs_aliasing_inputs(args):
    input_storage = {
        StorageWeakRef(arg._typed_storage())
        for arg in args
        if isinstance(arg, torch.Tensor)
    }

    def maybe_clone(t):
        if (
            isinstance(t, torch.Tensor)
            and StorageWeakRef(t._typed_storage()) in input_storage
        ):
            return t.clone()
        return t

    return maybe_clone


def prepare_fw_with_masks(fn):
    def fw_with_masks(*args):
        fw_out = fn(*args)
        return fw_out, [
            True if isinstance(ret, torch.Tensor) and ret.requires_grad else False
            for ret in fw_out
        ]

    return fw_with_masks
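

# Illustrative sketch (comments only): ``clone_outputs_aliasing_inputs`` is meant to
# be used with ``pytree.tree_map`` so that any output sharing storage with an input
# is cloned, keeping a traced backward graph functional:
#
#     args = (torch.randn(3),)
#     maybe_clone = clone_outputs_aliasing_inputs(args)
#     outs = (args[0].view(-1), torch.ones(3))
#     outs = pytree.tree_map(maybe_clone, outs)  # first output cloned, second passes through
#
# ``prepare_fw_with_masks`` similarly adapts a forward fn into the
# (outputs, requires-grad mask) pairing consumed by create_joint in
# create_fw_bw_graph below.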


# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
    from torch._functorch.aot_autograd import AOTConfig, create_joint

    # Note: [HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. We will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
    # fetch the proxy for the inputs and thus fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
    # only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create _tensor_constant as output.

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    example_grad = [_from_fun(out) for out in fw_outputs]
    num_grads = len(example_grad)
    fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)

    def joint_fn(*joint_operands_grads):
        if use_output_and_grad_bw:
            grads = joint_operands_grads[0]
            inputs = joint_operands_grads[1][-1:]
        else:
            grads = joint_operands_grads[:num_grads]
            inputs = joint_operands_grads[num_grads:]

        joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
        _, grads = joint(
            list(inputs),
            [grad for grad in grads if grad is not None and grad.requires_grad],
        )

        # In order to keep map functional for the backward graph,
        # we clone outputs that alias inputs
        maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)

        return pytree.tree_map(maybe_clone, grads)

    if use_output_and_grad_bw:
        example_xs_out = list(fw_inputs) + list(fw_outputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            (list(example_grad), list(example_xs_out))
        )
    else:
        example_xs_out = list(fw_inputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            *(list(example_grad) + list(example_xs_out))
        )

    return fw_graph, joint_graph


def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(leaf, torch.Tensor) for leaf in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(leaf.shape[0] == flat_xs[0].shape[0] for leaf in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[leaf.shape for leaf in flat_xs]}"
        )

    a = zip(*flat_xs)

    pytrees = []
    for leaves in a:
        pytrees.append(pytree.tree_unflatten(leaves, inspec))
    return pytrees


def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)
    assert out_spec is not None
    b = zip(*flat_out)
    stacked_out = []
    for leaves in b:
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # The backward graph can return None outputs when the forward inputs don't require grad.
            # When we eagerly execute the backward graph, we need to call _stack_pytree on its output,
            # therefore we need to deal with None outputs.
            stacked_out.append(None)  # type: ignore[arg-type]
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)
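

# Illustrative round trip (comments only): ``_unstack_pytree`` splits a pytree of
# batched tensors along dim 0 and ``_stack_pytree`` puts the slices back together,
# which is the kind of round trip the eager map implementation relies on:
#
#     xs = {"a": torch.arange(6).reshape(3, 2), "b": torch.zeros(3)}
#     slices = _unstack_pytree(xs)
#     # -> [{"a": tensor([0, 1]), "b": tensor(0.)}, ..., {"a": tensor([4, 5]), "b": tensor(0.)}]
#     assert torch.equal(_stack_pytree(slices)["a"], xs["a"])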