# mypy: allow-untyped-defs
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    _maybe_run_with_interpreter,
    reenter_make_fx,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    make_fx,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)

from .utils import (
    _from_fun,
    _stack_pytree,
    _unstack_pytree,
    clone_outputs_aliasing_inputs,
    prepare_fw_with_masks,
)


# TODO: We add this to prevent dynamo from tracing into map_wrapper;
# remove the wrapper call when it's ready.
class MapWrapper(HigherOrderOperator):
    def __init__(self):
        super().__init__("map")

    def __call__(self, f, xs, *args):
        return map_wrapper(f, xs, *args)


class MapImpl(HigherOrderOperator):
    def __init__(self):
        super().__init__("map_impl")

    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)


map = MapWrapper()

map_impl = MapImpl()

dummy_aot_config = AOTConfig(
    fw_compiler=None,  # type: ignore[arg-type]
    bw_compiler=None,  # type: ignore[arg-type]
    partition_fn=None,  # type: ignore[arg-type]
    decompositions={},
    num_params_buffers=0,
    aot_id=0,
    keep_inference_input_mutations=False,
)


def create_fw_bw_graph(f, num_mapped_args, *args):
    # Trace a forward graph and a joint (forward + backward) graph for the map
    # body `f`, using a single unstacked slice of the mapped inputs as example
    # arguments.
    mapped_xs = args[:num_mapped_args]
    pos_args = args[num_mapped_args:]

    # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs)
            example_xs = _unstack_pytree(unwrapped_mapped_xs)[0]

            example_pos_args = [
                _from_fun(arg) if isinstance(arg, torch.Tensor) else arg
                for arg in pos_args
            ]
            example_flat_out = pytree.tree_map(
                _from_fun, f(*example_xs, *example_pos_args)
            )
            if any(
                not isinstance(out, torch.Tensor)
                for out in example_flat_out
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of map to only contain tensors or None. "
                    f"Got types {[type(out) for out in example_flat_out]}."
                )
            example_grad = [_from_fun(out) for out in example_flat_out]

            fw_graph = make_fx(f)(*example_xs, *example_pos_args)

        def joint_f(*example_args):
            joint_mapped_args = example_args[:joint_num_mapped]
            args = example_args[joint_num_mapped:]

            mapped_input = joint_mapped_args[:num_mapped_args]
            mapped_grads = joint_mapped_args[num_mapped_args:]

            joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config)
            _, grads = joint(
                list(mapped_input) + list(args),
                [
                    grad
                    for grad in mapped_grads
                    if grad is not None and grad.requires_grad
                ],
            )

            # In order to keep map functional for the backward graph,
            # we clone outputs that alias inputs.
            maybe_clone = clone_outputs_aliasing_inputs(example_args)

            return pytree.tree_map(maybe_clone, grads)

        joint_num_mapped = len(example_grad) + len(example_xs)
        joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
        return fw_graph, joint_graph
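

# --------------------------------------------------------------------------
# Illustrative sketch only (hypothetical names, never called by this module):
# what the fw_graph/joint_graph pair built above computes for one unstacked
# slice, expressed with plain autograd instead of make_fx/create_joint.
def _fw_bw_semantics_sketch():
    def body(x, w):
        # hypothetical map body: one slice of xs plus one positional tensor arg
        return (x * w).sin()

    x = torch.randn(3, requires_grad=True)  # a single slice of the mapped input
    w = torch.randn(3, requires_grad=True)  # a positional argument

    # fw_graph: slice -> outputs
    out = body(x, w)

    # joint_graph: (slice, upstream grads, pos args) -> input gradients
    grad_out = torch.ones_like(out)
    grad_x, grad_w = torch.autograd.grad(out, (x, w), grad_out)
    return out, grad_x, grad_w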


def map_wrapper(f, xs, *args):
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    if not all(isinstance(t, torch.Tensor) for t in flat_xs):
        raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")

    num_mapped_args = len(flat_xs)
    shapes = [x.shape for x in flat_xs]
    leading_dim_size = shapes[0][0]
    if leading_dim_size == 0:
        raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")

    if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
        raise RuntimeError(
            f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
        )

    out_spec = None

    def flat_fn(*flat_args):
        xs = pytree.tree_unflatten(list(flat_args[:num_mapped_args]), xs_spec)
        unflattened_out = f(xs, *flat_args[num_mapped_args:])
        flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)

        nonlocal out_spec
        out_spec = tmp_out_spec
        return flat_out

    return pytree.tree_unflatten(
        map_impl(flat_fn, flat_xs, args), out_spec  # type: ignore[arg-type]
    )


class MapAutogradOp(torch.autograd.Function):
    # Autograd bridge for map: forward runs the traced fw_graph per slice via
    # map_impl (below the Autograd key); backward runs the traced joint_graph
    # per slice to produce input gradients.
    @staticmethod
    def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
        ctx.save_for_backward(*flat_args)
        ctx._joint_graph = joint_graph
        ctx._num_mapped_args = num_mapped_args
        with torch._C._AutoDispatchBelowAutograd():
            return (
                *map_impl(
                    fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:]
                ),
            )

    @staticmethod
    def backward(ctx, *flat_grads):
        fw_args = ctx.saved_tensors
        fw_mapped_args = fw_args[: ctx._num_mapped_args]
        pos_args = fw_args[ctx._num_mapped_args :]

        # The joint graph is mapped over (forward mapped inputs, output grads),
        # with pos_args passed through unmapped; it returns gradients for both
        # the mapped inputs and the positional args.
        grads = map_impl(
            ctx._joint_graph,
            fw_mapped_args + flat_grads,
            pos_args,
        )
        return None, None, None, *grads
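

# --------------------------------------------------------------------------
# Illustrative usage sketch (hypothetical body and shapes; defined but never
# called here): how the `map` wrapper above is typically invoked, and how
# gradients are expected to route through MapAutogradOp and the joint graph.
def _map_usage_sketch():
    def add_scale(x, scale):
        # body receives one slice of xs plus the extra positional arg
        return x * scale + 1.0

    xs = torch.randn(4, 3, requires_grad=True)  # mapped over the leading dim (4)
    scale = torch.tensor(2.0, requires_grad=True)

    out = map(add_scale, xs, scale)  # expected shape (4, 3), same leading dim as xs
    out.sum().backward()  # backward routes through MapAutogradOp / joint graph
    return out, xs.grad, scale.grad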


def trace_map(proxy_mode, func_overload, f, xs, pos_args):
    # Proxy-mode tracing: trace the body once on an example slice, register it
    # as a submodule on the tracer root, and emit a map_impl call_function node
    # whose example outputs are the per-slice outputs expanded along the
    # leading dimension.
    leading_dim_size = xs[0].shape[0]

    example_input = _unstack_pytree(xs)[0]
    body_graph = f

    body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args)

    next_name = proxy_mode.tracer.get_fresh_qualname("body_graph_")

    proxy_mode.tracer.root.register_module(next_name, body_graph)

    with disable_proxy_modes_tracing():
        example_outs = body_graph(*example_input, *pos_args)

        def expand_tensor(t):
            if isinstance(t, torch.Tensor):
                return t.expand(leading_dim_size, *t.shape)
            return t

        expanded_outs = pytree.tree_map(expand_tensor, example_outs)

    node_args = (body_graph, list(xs), list(pos_args))
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="map_impl"
    )
    return track_tensor_tree(
        expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, xs, pos_args):
    # Dense kernel: unstack xs along the leading dim, run f on each slice, and
    # stack the results back up.
    pytrees = []
    for inp in _unstack_pytree(xs):
        pytrees.append(f(*inp, *pos_args))
    return _stack_pytree(pytrees)


@map_impl.py_impl(DispatchKey.Autograd)
def map_autograd(f, xs, pos_args):
    # Autograd kernel: build forward/joint graphs for the body and dispatch
    # through MapAutogradOp.
    num_mapped_args = len(xs)
    fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args)
    flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args)
    return flat_out


@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(mode, f, xs, args):
    return trace_map(mode, map_impl, f, xs, args)


@map_impl.py_impl(FakeTensorMode)
def map_fake_tensor_mode(mode, f, xs, args):
    with mode:
        return map_dense(f, xs, args)


@map_impl.py_functionalize_impl
def map_functionalize(ctx, f, xs, pos_args):
    # Functionalization: unwrap functional tensors, reject bodies that mutate
    # or alias their inputs, then redispatch to the next key.
    unwrapped_xs = ctx.unwrap_tensors(xs)
    unwrapped_args = ctx.unwrap_tensors(pos_args)
    wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f))

    with ctx.redispatch_to_next():
        with disable_proxy_modes_tracing():
            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        if _has_potential_branch_input_mutation(
            f, example_inputs, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException("torch.map is mutating the input!")

        if _has_potential_branch_input_alias(
            f, example_inputs, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException("torch.map is aliasing the input!")

        map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
        return ctx.wrap_tensors(map_return)
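

# --------------------------------------------------------------------------
# Hedged smoke check (illustrative only; it runs only when this module is
# executed directly, e.g. `python -m torch._higher_order_ops.map`): the dense
# kernel above should agree with a plain Python loop over the leading
# dimension followed by a stack. The body `double` and shapes are made up.
if __name__ == "__main__":

    def double(x):
        return x * 2

    example_xs = torch.randn(5, 2)
    # map_dense expects the already-flattened list of mapped tensors.
    dense_out = map_dense(double, [example_xs], ())
    loop_out = torch.stack([double(x) for x in example_xs.unbind(0)])
    torch.testing.assert_close(dense_out, loop_out)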