# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import functools
import logging
import operator
import warnings
from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING

import torch
import torch.distributed as dist
import torch.distributed.tensor._api as dtensor
import torch.distributed.tensor._random as random
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
    _is_inplace_op,
    _is_out_variant_op,
    OpInfo,
    OpSchema,
    OutputSpecType,
)
from torch.distributed.tensor._random import is_rng_supported_mesh
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor._sharding_prop import ShardingPropagator
from torch.distributed.tensor._tp_conv import (
    convolution_backward_handler,
    convolution_handler,
)
from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate


if TYPE_CHECKING:
    from torch.distributed.device_mesh import DeviceMesh

try:
    from torch.utils import _cxx_pytree as pytree
except ImportError:
    from torch.utils import _pytree as pytree  # type: ignore[no-redef]

aten = torch.ops.aten
logger = logging.getLogger(__name__)


def decompose_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    """
    Decompose an op into core ATen ops; this handler is mostly here for
    inference mode usage where the ops are not core ATen ops.
    """
    r = op_call.decompose(*args, **kwargs)
    if r is not NotImplemented:
        return r
    else:
        raise RuntimeError("Decomposition failed")


def is_same_size_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> bool:
    # aten.is_same_size only needs the global shapes, which DTensor already
    # carries, so compare them directly without any local compute.
    lhs = cast(torch.Tensor, args[0])
    rhs = cast(torch.Tensor, args[1])
    return lhs.shape == rhs.shape


def found_inf_reduce_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> None:
    # Handler for aten._amp_foreach_non_finite_check_and_unscale_: run the local
    # check/unscale in place, then combine the found_inf flag across the mesh by
    # treating it as Partial("max") on the sharded mesh dims and materializing
    # the full tensor, so every rank observes the same found_inf value.
    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    local_tensor_args = pytree.tree_unflatten(
        cast(List[object], op_info.local_args), op_info.args_tree_spec
    )
    local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)

    grad_dtensor = cast(list[dtensor.DTensor], args[0])[0]
    grad_placements = grad_dtensor.placements
    mesh = grad_dtensor.device_mesh

    found_inf_placements: list[Placement] = []
    for placement in grad_placements:
        if isinstance(placement, Replicate):
            found_inf_placements.append(placement)
        else:
            found_inf_placements.append(Partial("max"))

    target_tensor = cast(torch.Tensor, args[1])
    spec = DTensorSpec(
        mesh=mesh,
        placements=tuple(found_inf_placements),
        tensor_meta=TensorMeta(
            shape=target_tensor.size(),
            stride=target_tensor.stride(),
            dtype=target_tensor.dtype,
        ),
    )
    found_inf_dtensor = dtensor.DTensor(
        local_tensor=target_tensor, spec=spec, requires_grad=False
    )
    found_inf = found_inf_dtensor.full_tensor()
    target_tensor.copy_(found_inf)
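

# NOTE: The handlers above all follow the calling convention expected by
# `OpDispatcher._custom_op_handlers`: they receive the resolved OpOverload plus
# the original (DTensor-level) args/kwargs and return the final result,
# bypassing sharding propagation entirely. An illustrative sketch of such a
# handler (`my_op_handler` is hypothetical and not registered anywhere):
#
#     def my_op_handler(
#         op_call: torch._ops.OpOverload,
#         args: Tuple[object, ...],
#         kwargs: Dict[str, object],
#     ) -> object:
#         # unwrap to local tensors and run the op locally on each rank
#         op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(
#             op_call, args, kwargs
#         )
#         return op_call(*op_info.local_args, **op_info.local_kwargs)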


class OpDispatcher:
    """
    Op dispatching class instance that handles args/kwargs pre-processing (unwrapping), sharding
    propagation, redistribution of local args, local compute, and post-processing (re-wrapping).
    It also handles any op-specific logic where necessary.

    NOTE: Given the runtime overhead of the Tensor subclass (__torch_dispatch__), the OpDispatcher
    is designed to minimize CPU overhead by unflattening inputs only when needed, using the faster
    pytree implementation when available, and leveraging the caching mechanisms implemented in the
    sharding propagation and redistribute modules. The CPU overhead is critical to eager mode
    performance, so measure it carefully when making significant changes to the OpDispatcher and
    ShardingPropagator.
    """

    def __init__(self) -> None:
        self.sharding_propagator = ShardingPropagator()
        self._random_ops = {
            aten.native_dropout.default,
            aten.normal_.default,
            aten.rand_like.default,
            aten.randn_like.default,
            aten.randint_like.default,
            aten.randint_like.low_dtype,
            aten.randint_like.low_dtype_out,
            aten.uniform_.default,
            aten.bernoulli.default,
            aten.bernoulli_.float,
        }
        self._custom_op_handlers = {
            aten.linear.default: decompose_handler,
            aten.is_same_size.default: is_same_size_handler,
            aten.convolution.default: convolution_handler,
            aten.convolution_backward.default: convolution_backward_handler,
            aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
        }

        # This flag is used internally to control whether we treat a torch.Tensor
        # (non-DTensor) argument as implicitly replicated or raise an error to the user.
        # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default, so we
        # intentionally leave it as False.
        self._allow_implicit_replication = False
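
    # Illustrative sketch (not executed here) of how `dispatch` is reached in
    # practice, assuming an initialized process group; `world_size` is a
    # placeholder for the number of ranks:
    #
    #     from torch.distributed.device_mesh import init_device_mesh
    #     from torch.distributed.tensor import distribute_tensor, Shard
    #
    #     mesh = init_device_mesh("cuda", (world_size,))
    #     dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
    #     out = torch.add(dt, dt)
    #
    # The call to torch.add resolves to aten.add.Tensor and is routed through
    # DTensor.__torch_dispatch__, which calls
    # DTensor._op_dispatcher.dispatch(op, args, kwargs), the method below.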

    def dispatch(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> object:
        """
        Main dispatching logic.
        """
        # operators that do not need to go through sharding propagation
        if op_call in self._custom_op_handlers:
            return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]

        # extract local tensors and sharding info into an OpInfo
        op_info = self.unwrap_to_op_info(op_call, args, kwargs)
        logger.debug("Dispatching op_call: %s", op_info.schema)

        self.sharding_propagator.propagate(op_info)
        output_sharding = op_info.output_sharding
        logger.debug("output_sharding for %s: %s", op_call, output_sharding)
        assert output_sharding is not None, "output sharding should not be None"

        mesh = op_info.mesh
        if mesh.get_coordinate() is not None:
            # normal case: computation happens on the current rank of the mesh
            if output_sharding.needs_redistribute:
                # If the sharding propagation decision requires a redistribute,
                # perform it on the args first; this can modify the local args
                # (e.g. all-gather a certain arg).
                assert output_sharding.redistribute_schema is not None
                self.redistribute_local_args(
                    op_info, output_sharding.redistribute_schema
                )

            local_tensor_args = (
                pytree.tree_unflatten(
                    cast(List[object], op_info.local_args), op_info.args_tree_spec
                )
                if op_info.args_tree_spec
                else op_info.local_args
            )

            # run local op computation with potentially modified args/kwargs
            local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
            if op_call in self._random_ops:
                if not random._rng_tracker and is_rng_supported_mesh(mesh):
                    # Default to `OffsetBasedRNGTracker` if the parallelism API
                    # did not already construct one
                    random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)

                first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
                    torch.Tensor, local_tensor_args[0]
                )
                rng_context = (
                    random._rng_tracker._distribute_region(first_arg._spec)
                    if random._rng_tracker and not first_local_arg.is_meta
                    else contextlib.nullcontext()
                )
                # For a DTensor random operator, run it within an RNGTracker context
                # to ensure the random number generator is properly distributed.
                with rng_context:
                    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
            else:
                # normal case, run local sharded op computation
                local_results = op_call(*local_tensor_args, **op_info.local_kwargs)

        else:
            # For a non-participating device (happens on a rank that does not belong
            # to the device mesh), we do:
            #   1. if the return type is scalar, set the local result to None.
            #   2. if the return type is Tensor or List[Tensor], return empty
            #      tensor(s) with the correct dtype.
            spec = output_sharding.output_spec
            ret_list = op_info.schema.op._schema.returns

            if spec is None:
                # For a scalar return type, the non-participating device has None
                # as its local result
                local_results = None
            else:

                def default_tensor(spec: DTensorSpec) -> torch.Tensor:
                    if spec.tensor_meta is not None:
                        shape = spec.tensor_meta.shape
                        dtype = spec.tensor_meta.dtype
                        if len(shape) == 0:
                            # scalar tensor
                            return torch.zeros((), dtype=dtype)
                        else:
                            # non-scalar tensor
                            return torch.tensor([], dtype=dtype)
                    else:
                        raise RuntimeError(f"{spec} has no tensor metadata.")

                if isinstance(spec, DTensorSpec):
                    # return a Tensor value
                    local_results = default_tensor(spec)
                elif isinstance(spec, Sequence):
                    # return a List[Tensor] value
                    local_results = [
                        default_tensor(s) if s is not None else None for s in spec
                    ]
                    assert isinstance(local_results, List)
                    if None in local_results:
                        ret_type = str(ret_list[0].type)
                        raise NotImplementedError(
                            f"return type {ret_type} in DTensor op is not supported"
                        )

        if output_sharding.output_spec is None:
            if op_call == aten.equal.default:
                # For the equal operator, the local results from all devices are
                # all-gathered and a reduce op (AND) is performed on the list of
                # results to ensure SPMD execution. We can extend this for more
                # ops if necessary.
                obj_list = [None for _ in range(dist.get_world_size())]
                dist.all_gather_object(obj_list, local_results)  # type: ignore[possibly-undefined]
                obj_list = list(filter(lambda x: x is not None, obj_list))
                # perform a reduce on the collection with the AND op
                local_results = functools.reduce(operator.and_, obj_list, True)

        if _is_inplace_op(op_call):
            # inplace op should return self instead of re-wrapping
            if output_sharding.output_spec is not None:
                return args[0]
            else:
                return None
        elif _is_out_variant_op(op_call):
            # out variant could possibly have multiple out args (e.g. lu_unpack.out)
            output_specs = (
                (output_sharding.output_spec,)
                if not isinstance(output_sharding.output_spec, tuple)
                else output_sharding.output_spec
            )
            out_dts = []
            spec_idx = 0
            for argument in op_call._schema.arguments:
                if argument.is_out:
                    out_dt = cast(dtensor.DTensor, kwargs[argument.name])
                    out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
                    out_dts.append(out_dt)
                    spec_idx += 1

            assert len(out_dts) >= 1, "out variant should have at least one out arg"
            return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
        else:
            return self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]
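
    # The redistribute step triggered by `dispatch` above is the per-argument
    # analogue of the public `DTensor.redistribute` API: changing an argument's
    # placements may require a collective on its local shard. A rough sketch of
    # the common cases (assuming a 1-D mesh of size 4 and an (8, 8) tensor):
    #
    #     Shard(0)  -> Replicate(): all-gather the (2, 8) local shards into (8, 8)
    #     Partial() -> Replicate(): all-reduce the local partial results
    #
    # `redistribute_local_args` below applies such transformations to the
    # flattened local args, following the specs suggested by sharding propagation.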

    @staticmethod
    def redistribute_local_args(
        op_info: OpInfo,
        suggested_input_schema: OpSchema,
    ) -> None:
        # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
        if op_info.args_tree_spec is not None:
            flatten_args_schema_to_reshard = tuple(
                pytree.tree_leaves(suggested_input_schema.args_schema)
            )
        else:
            flatten_args_schema_to_reshard = suggested_input_schema.args_schema

        new_local_args: List[object] = []
        for i, arg_spec in enumerate(op_info.flat_args_schema):
            reshard_arg_spec = flatten_args_schema_to_reshard[i]
            if isinstance(arg_spec, DTensorSpec):
                local_tensor = cast(torch.Tensor, op_info.local_args[i])
                if arg_spec != reshard_arg_spec:
                    resharded_local_tensor = redistribute_local_tensor(
                        local_tensor, arg_spec, reshard_arg_spec
                    )
                    new_local_args.append(resharded_local_tensor)
                else:
                    new_local_args.append(local_tensor)
            else:
                new_local_args.append(reshard_arg_spec)

        op_info.local_args = tuple(new_local_args)
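
    # Sketch of what `unwrap_to_op_info` below produces for a call such as
    # aten.add.Tensor(dt, 2), where `dt` is a DTensor (the names below are the
    # OpInfo fields populated by the method):
    #
    #     mesh         -> dt.device_mesh
    #     args_schema  -> (dt._spec, 2)          # specs used for sharding prop
    #     local_args   -> (dt._local_tensor, 2)  # inputs for the local compute
    #     kwargs_schema / local_kwargs -> {}     # no tensor kwargs in this case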

    def unwrap_to_op_info(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> OpInfo:
        # get runtime schema info to determine whether to use pytree to flatten inputs
        runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
            op_call, None
        )

        if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
            # flatten args/kwargs when the op says it is necessary
            tree_args, args_spec = pytree.tree_flatten(args)
            args_list: Sequence[object] = tree_args
        else:
            args_list, args_spec = args, None

        args_schema: List[object] = []
        kwargs_schema: Dict[str, object] = {}
        local_args: List[object] = []
        local_kwargs: Dict[str, object] = {}
        mesh: Optional[DeviceMesh] = None

        for arg in args_list:
            if isinstance(arg, dtensor.DTensor):
                local_args.append(arg._local_tensor)
                if mesh is not None and mesh != arg.device_mesh:
                    # TODO: trying to replicate the dtensor spec in the missing
                    # dimension would work for most foreach cases, except when the
                    # first DTensor in the list is one that also needs to be
                    # replicated. We need to revisit how we want to handle this
                    # corner case. For now, this case would hit the cross-mesh
                    # error even if implicit replication is turned on.
                    spec = self._try_replicate_dtensor_spec_in_missing_dim(
                        op_call, arg, mesh
                    )
                    args_schema.append(spec)
                else:
                    mesh = arg.device_mesh
                    args_schema.append(arg._spec)
            elif isinstance(arg, torch.Tensor):
                mesh = mesh or try_find_mesh_from_args(op_call, args_list)
                args_schema.append(
                    self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
                )
                local_args.append(arg)
            else:
                args_schema.append(arg)
                local_args.append(arg)

        for k, v in kwargs.items():
            if isinstance(v, dtensor.DTensor):
                local_kwargs[k] = v._local_tensor
                if mesh is not None and mesh != v.device_mesh:
                    spec = self._try_replicate_dtensor_spec_in_missing_dim(
                        op_call, v, mesh
                    )
                    kwargs_schema[k] = spec
                else:
                    mesh = v.device_mesh
                    kwargs_schema[k] = v._spec
            elif isinstance(v, torch.Tensor):
                mesh = mesh or try_find_mesh_from_args(op_call, args_list)
                kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
                    op_call, v, mesh
                )
                local_kwargs[k] = v
            else:
                kwargs_schema[k] = v
                local_kwargs[k] = v

        assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!"
        op_info = OpInfo(
            mesh,
            OpSchema(
                op_call,
                pytree.tree_unflatten(args_schema, args_spec)
                if args_spec
                else tuple(args_schema),
                kwargs_schema,
                schema_info=runtime_schema_info,
            ),
            args_schema,
            tuple(local_args),
            local_kwargs,
            args_spec,
        )
        return op_info

    @staticmethod
    def wrap(res: object, spec: OutputSpecType) -> object:
        if isinstance(res, torch.Tensor):
            if spec is not None:
                assert isinstance(
                    spec, DTensorSpec
                ), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
                return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
            else:
                # if the output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
                assert res.ndim == 0, "output tensor should be scalar!"
                return res
        elif isinstance(res, (list, tuple)):
            assert spec is not None and isinstance(
                spec, (list, tuple)
            ), f"output spec does not match with output! Expected list/tuple, got {spec}."
            res_list = []
            for e, s in zip(res, spec):
                res_list.append(OpDispatcher.wrap(e, s))

            return tuple(res_list) if isinstance(res, tuple) else res_list
        else:
            # if the result contains only non-tensor values (e.g. int/float/None),
            # we simply return it without re-wrapping it into a DTensor.
            return res
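
    # `wrap` above is the inverse of the unwrapping step: local results are
    # re-wrapped into DTensors using the output specs produced by sharding
    # propagation. For example (a sketch; the specs are assumed to come from
    # sharding propagation):
    #
    #     wrap(local_tensor, spec)        -> DTensor(local_tensor, spec, ...)
    #     wrap((t1, t2), (spec1, spec2))  -> (DTensor(t1, ...), DTensor(t2, ...))
    #     wrap(3.14, None)                -> 3.14  # non-tensor results pass through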

    def _try_replicate_spec_for_scalar_tensor(
        self,
        op_call: torch._ops.OpOverload,
        tensor_arg: torch.Tensor,
        mesh: "DeviceMesh",
    ) -> DTensorSpec:
        # util function to produce a replicate spec for a scalar tensor arg/kwarg
        if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
            warnings.warn(
                "Found a non-scalar tensor with numel=1 and ndim!=0, "
                "we are implicitly creating a replicated DTensor for it. "
                "However, please consider changing it to a scalar tensor "
                "or explicitly create a DTensor under the distributed environment."
            )

        if tensor_arg.numel() == 1 or self._allow_implicit_replication:
            # scalar tensor can be safely treated as replicated
            replication_spec = DTensorSpec(
                mesh,
                (Replicate(),) * mesh.ndim,
                tensor_meta=TensorMeta(
                    shape=tensor_arg.shape,
                    stride=tensor_arg.stride(),
                    dtype=tensor_arg.dtype,
                ),
            )
        else:
            raise RuntimeError(
                f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                " torch.Tensor to DTensor before calling distributed operators!"
            )
        return replication_spec

    def _try_replicate_dtensor_spec_in_missing_dim(
        self,
        op_call: torch._ops.OpOverload,
        dtensor_arg: "dtensor.DTensor",
        mesh: "DeviceMesh",
    ) -> DTensorSpec:
        # util function to produce a new spec for a DTensor arg/kwarg
        # that puts a Replicate() placement in the missing dimension for foreach ops
        from torch.distributed.device_mesh import _mesh_resources

        cur_mesh = dtensor_arg.device_mesh
        root_mesh = _mesh_resources.get_root_mesh(cur_mesh)
        if (
            self._allow_implicit_replication
            and "foreach" in op_call.__name__
            and root_mesh == mesh
        ):
            placements = [Replicate() for _ in range(root_mesh.ndim)]
            cur_mesh_root_idx = _mesh_resources.get_root_mesh_dim(cur_mesh)
            placements[cur_mesh_root_idx] = dtensor_arg.placements[0]  # type: ignore[call-overload]
            replicate_spec = DTensorSpec(
                root_mesh,
                tuple(placements),
                tensor_meta=TensorMeta(
                    shape=dtensor_arg.shape,
                    stride=dtensor_arg.stride(),
                    dtype=dtensor_arg.dtype,
                ),
            )
        else:
            raise NotImplementedError(
                f"{op_call}: DTensor does not support cross-mesh operation yet! "
                f"Got meshes: {mesh} {cur_mesh}"
            )
        return replicate_spec
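

# Example of the implicit-replication behavior implemented by the two helpers
# above (an illustrative sketch, nothing here executes; `dt` is assumed to be a
# DTensor on a 1-D mesh):
#
#     dt + torch.tensor(2.0)  # ok: a numel-1 tensor is given a Replicate() spec
#     dt + torch.randn(4)     # RuntimeError: mixed torch.Tensor and DTensor,
#                             # unless implicit replication is enabled (e.g. via
#                             # torch.distributed.tensor.experimental.implicit_replication())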