# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import functools
import logging
import operator
import warnings
from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING

import torch
import torch.distributed as dist
import torch.distributed.tensor._api as dtensor
import torch.distributed.tensor._random as random
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
    _is_inplace_op,
    _is_out_variant_op,
    OpInfo,
    OpSchema,
    OutputSpecType,
)
from torch.distributed.tensor._random import is_rng_supported_mesh
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor._sharding_prop import ShardingPropagator
from torch.distributed.tensor._tp_conv import (
    convolution_backward_handler,
    convolution_handler,
)
from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate


if TYPE_CHECKING:
    from torch.distributed.device_mesh import DeviceMesh

try:
    from torch.utils import _cxx_pytree as pytree
except ImportError:
    from torch.utils import _pytree as pytree  # type: ignore[no-redef]

aten = torch.ops.aten
logger = logging.getLogger(__name__)


def decompose_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    """
    Decompose an op into core ATen ops. This handler is mostly here for
    inference-mode usage, where the incoming ops are not core ATen ops.
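
    Example::
        >>> # A minimal sketch (not exercised in CI): ``x``, ``w`` and ``b`` are
        >>> # hypothetical DTensor inputs. aten.linear carries a composite
        >>> # core-ATen decomposition, so the handler simply re-runs the
        >>> # decomposed form on the same args:
        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> out = decompose_handler(torch.ops.aten.linear.default, (x, w, b), {})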
52    """
53    r = op_call.decompose(*args, **kwargs)
54    if r is not NotImplemented:
55        return r
56    else:
57        raise RuntimeError("Decomposition failed")
58

def is_same_size_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> bool:
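    """
    Compare the shapes of the first two tensor args directly. DTensor reports
    its global shape, so aten.is_same_size can be answered without a full
    sharding propagation.
    """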
    lhs = cast(torch.Tensor, args[0])
    rhs = cast(torch.Tensor, args[1])
    return lhs.shape == rhs.shape


def found_inf_reduce_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> None:
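    """
    Handler for aten._amp_foreach_non_finite_check_and_unscale_: run the op on
    the local shards, then reduce the found_inf flag across the mesh (Partial
    "max" on sharded dims) so that every rank observes the same result.
    """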
    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    local_tensor_args = pytree.tree_unflatten(
        cast(List[object], op_info.local_args), op_info.args_tree_spec
    )
    local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)

    grad_dtensor = cast(List[dtensor.DTensor], args[0])[0]
    grad_placements = grad_dtensor.placements
    mesh = grad_dtensor.device_mesh

    found_inf_placements: List[Placement] = []
    for placement in grad_placements:
        if isinstance(placement, Replicate):
            found_inf_placements.append(placement)
        else:
            found_inf_placements.append(Partial("max"))

    target_tensor = cast(torch.Tensor, args[1])
    spec = DTensorSpec(
        mesh=mesh,
        placements=tuple(found_inf_placements),
        tensor_meta=TensorMeta(
            shape=target_tensor.size(),
            stride=target_tensor.stride(),
            dtype=target_tensor.dtype,
        ),
    )
    found_inf_dtensor = dtensor.DTensor(
        local_tensor=target_tensor, spec=spec, requires_grad=False
    )
    found_inf = found_inf_dtensor.full_tensor()
    target_tensor.copy_(found_inf)


class OpDispatcher:
    """
    Op dispatching class instance that handles args/kwargs pre-processing (unwrapping),
    sharding propagation, redistribution of local args, local compute, and post-processing
    (re-wrapping). It also handles any op-specific logic where necessary.

    NOTE: Given the runtime overhead of the Tensor subclass (__torch_dispatch__), the
    OpDispatcher is designed to minimize CPU overhead by unflattening inputs only when
    needed, using the faster C++ pytree where available, and leveraging the caching
    mechanisms implemented in the sharding propagation and redistribute modules. The CPU
    overhead is critical to eager-mode performance, so one needs to carefully measure it
    when making significant changes to the OpDispatcher and ShardingPropagator.
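
    Example::
        >>> # A sketch of the path every DTensor op takes: DTensor's
        >>> # __torch_dispatch__ forwards each ATen call to the dispatcher
        >>> # singleton held on the DTensor class. ``dt_a`` and ``dt_b`` are
        >>> # hypothetical DTensors living on the same mesh.
        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> out = dtensor.DTensor._op_dispatcher.dispatch(
        ...     torch.ops.aten.add.Tensor, (dt_a, dt_b), {}
        ... )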
122    """
123
124    def __init__(self) -> None:
125        self.sharding_propagator = ShardingPropagator()
126        self._random_ops = {
127            aten.native_dropout.default,
128            aten.normal_.default,
129            aten.rand_like.default,
130            aten.randn_like.default,
131            aten.randint_like.default,
132            aten.randint_like.low_dtype,
133            aten.randint_like.low_dtype_out,
134            aten.uniform_.default,
135            aten.bernoulli.default,
136            aten.bernoulli_.float,
137        }
138        self._custom_op_handlers = {
139            aten.linear.default: decompose_handler,
140            aten.is_same_size.default: is_same_size_handler,
141            aten.convolution.default: convolution_handler,
142            aten.convolution_backward.default: convolution_backward_handler,
143            aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
144        }
145
        # This flag is used internally to control whether we treat a plain torch.Tensor
        # (non-DTensor) as implicitly replicated, or raise an error to the user.
        # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default, so we
        # intentionally leave it as False.
        self._allow_implicit_replication = False
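
        # For illustration (a sketch, not executed here): with the flag off,
        # mixing a plain non-scalar torch.Tensor into a DTensor op raises a
        # RuntimeError from _try_replicate_spec_for_scalar_tensor; with it on,
        # the plain tensor is treated as Replicate() across the whole mesh.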

    def dispatch(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> object:
        """
        Main dispatching logic: unwrap inputs into an OpInfo, run sharding
        propagation, redistribute local args if needed, perform the local
        computation, and re-wrap the outputs into DTensors.
        """
        # operators that do not need to go through sharding propagation
        if op_call in self._custom_op_handlers:
            return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]

        # extract the local tensors and sharding info into an OpInfo
        op_info = self.unwrap_to_op_info(op_call, args, kwargs)
        logger.debug("Dispatching op_call: %s", op_info.schema)

        self.sharding_propagator.propagate(op_info)
        output_sharding = op_info.output_sharding
        logger.debug("output_sharding for %s: %s", op_call, output_sharding)
        assert output_sharding is not None, "output sharding should not be None"

        mesh = op_info.mesh
        if mesh.get_coordinate() is not None:
            # computation that happens on a rank within the mesh, the normal case
            if output_sharding.needs_redistribute:
                # If the sharding propagation decision requires a redistribute, perform
                # it on the args first, which could potentially modify them (e.g.
                # all-gather a certain arg)
                assert output_sharding.redistribute_schema is not None
                self.redistribute_local_args(
                    op_info, output_sharding.redistribute_schema
                )

            local_tensor_args = (
                pytree.tree_unflatten(
                    cast(List[object], op_info.local_args), op_info.args_tree_spec
                )
                if op_info.args_tree_spec
                else op_info.local_args
            )

            # run local op computation with potentially modified args/kwargs
            local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
            if op_call in self._random_ops:
                if not random._rng_tracker and is_rng_supported_mesh(mesh):
                    # Default to `OffsetBasedRNGTracker` if the parallelism API
                    # did not already construct one
                    random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)
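
                    # NOTE (a summary; see torch/distributed/tensor/_random.py
                    # for the authoritative details): OffsetBasedRNGTracker keeps
                    # the same seed on every rank and advances per-shard offsets
                    # into one philox stream, so each shard draws distinct but
                    # reproducible random values.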

                first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
                    torch.Tensor, local_tensor_args[0]
                )
                rng_context = (
                    random._rng_tracker._distribute_region(first_arg._spec)
                    if random._rng_tracker and not first_local_arg.is_meta
                    else contextlib.nullcontext()
                )
                # For a DTensor random operator, run it within an RNGTracker context
                # to ensure the random number generator is properly distributed.
                with rng_context:
                    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
            else:
                # normal case, run the local sharded op computation
                local_results = op_call(*local_tensor_args, **op_info.local_kwargs)

        else:
            # For a non-participating device (which happens on ranks that do not
            # belong to the device mesh), we do the following:
            #   1. if the return type is scalar, set the local result to None.
            #   2. if the return type is Tensor or List[Tensor], return empty
            #   tensor(s) with the correct dtype.
            spec = output_sharding.output_spec
            ret_list = op_info.schema.op._schema.returns

            if spec is None:
                # For a scalar return type, the non-participating device has None
                # as its local result
                local_results = None
            else:

                def default_tensor(spec: DTensorSpec) -> torch.Tensor:
                    if spec.tensor_meta is not None:
                        shape = spec.tensor_meta.shape
                        dtype = spec.tensor_meta.dtype
                        if len(shape) == 0:
                            # scalar tensor
                            return torch.zeros((), dtype=dtype)
                        else:
                            # non-scalar tensor
                            return torch.tensor([], dtype=dtype)
                    else:
                        raise RuntimeError(f"{spec} has no tensor metadata.")

                if isinstance(spec, DTensorSpec):
                    # return a Tensor value
                    local_results = default_tensor(spec)
                elif isinstance(spec, Sequence):
                    # return a List[Tensor] value
                    local_results = [
                        default_tensor(s) if s is not None else None for s in spec
                    ]
                    assert isinstance(local_results, list)
                    if None in local_results:
                        ret_type = str(ret_list[0].type)
                        raise NotImplementedError(
                            f"return type {ret_type} in DTensor op is not supported"
                        )

        if output_sharding.output_spec is None:
            if op_call == aten.equal.default:
                # For the equal operator, the local results from all devices should be
                # all-gathered and a reduce op (AND) performed on the list of results to
                # ensure SPMD execution. We can extend this to more ops if necessary.
                obj_list = [None for _ in range(dist.get_world_size())]
                dist.all_gather_object(obj_list, local_results)  # type: ignore[possibly-undefined]
                obj_list = list(filter(lambda x: x is not None, obj_list))
                # perform the reduce on the collection with the AND op
                local_results = functools.reduce(operator.and_, obj_list, True)

        if _is_inplace_op(op_call):
            # an inplace op should return self instead of re-wrapping
            if output_sharding.output_spec is not None:
                return args[0]
            else:
                return None
        elif _is_out_variant_op(op_call):
            # an out variant could possibly have multiple out args (e.g. lu_unpack.out)
            output_specs = (
                (output_sharding.output_spec,)
                if not isinstance(output_sharding.output_spec, tuple)
                else output_sharding.output_spec
            )
            out_dts = []
            spec_idx = 0
            for argument in op_call._schema.arguments:
                if argument.is_out:
                    out_dt = cast(dtensor.DTensor, kwargs[argument.name])
                    out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
                    out_dts.append(out_dt)
                    spec_idx += 1

            assert len(out_dts) >= 1, "out variant should have at least one out arg"
            return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
        else:
            return self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]

    @staticmethod
    def redistribute_local_args(
        op_info: OpInfo,
        suggested_input_schema: OpSchema,
    ) -> None:
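        """
        Redistribute op_info.local_args in place: for every DTensor arg whose
        current spec differs from the one suggested by sharding propagation,
        replace the local tensor with its redistributed counterpart.
        """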
        # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
        if op_info.args_tree_spec is not None:
            flatten_args_schema_to_reshard = tuple(
                pytree.tree_leaves(suggested_input_schema.args_schema)
            )
        else:
            flatten_args_schema_to_reshard = suggested_input_schema.args_schema

        new_local_args: List[object] = []
        for i, arg_spec in enumerate(op_info.flat_args_schema):
            reshard_arg_spec = flatten_args_schema_to_reshard[i]
            if isinstance(arg_spec, DTensorSpec):
                local_tensor = cast(torch.Tensor, op_info.local_args[i])
                if arg_spec != reshard_arg_spec:
                    resharded_local_tensor = redistribute_local_tensor(
                        local_tensor, arg_spec, reshard_arg_spec
                    )
                    new_local_args.append(resharded_local_tensor)
                else:
                    new_local_args.append(local_tensor)
            else:
                new_local_args.append(reshard_arg_spec)

        op_info.local_args = tuple(new_local_args)

    def unwrap_to_op_info(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> OpInfo:
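        """
        Unwrap DTensor args/kwargs into their local tensors and collect their
        specs into an OpSchema, producing the OpInfo that drives sharding
        propagation. Plain tensors and non-tensor args pass through, with plain
        tensors promoted to replicated specs when it is safe to do so.
        """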
        # get runtime schema info to determine whether to use pytree to flatten inputs
        runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
            op_call, None
        )

        if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
            # flatten args/kwargs when the op says it's necessary
            tree_args, args_spec = pytree.tree_flatten(args)
            args_list: Sequence[object] = tree_args
        else:
            args_list, args_spec = args, None

        args_schema: List[object] = []
        kwargs_schema: Dict[str, object] = {}
        local_args: List[object] = []
        local_kwargs: Dict[str, object] = {}
        mesh: Optional[DeviceMesh] = None

        for arg in args_list:
            if isinstance(arg, dtensor.DTensor):
                local_args.append(arg._local_tensor)
                if mesh is not None and mesh != arg.device_mesh:
                    # TODO: replicating the dtensor spec in the missing dimension would
                    # work for most foreach cases, except when the first DTensor in the
                    # list is one that also needs to be replicated. We need to revisit
                    # how we want to handle this corner case. For now, this case hits
                    # the cross-mesh error even if implicit replication is turned on.
                    spec = self._try_replicate_dtensor_spec_in_missing_dim(
                        op_call, arg, mesh
                    )
                    args_schema.append(spec)
                else:
                    mesh = arg.device_mesh
                    args_schema.append(arg._spec)
            elif isinstance(arg, torch.Tensor):
                mesh = mesh or try_find_mesh_from_args(op_call, args_list)
                args_schema.append(
                    self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
                )
                local_args.append(arg)
            else:
                args_schema.append(arg)
                local_args.append(arg)

        for k, v in kwargs.items():
            if isinstance(v, dtensor.DTensor):
                local_kwargs[k] = v._local_tensor
                if mesh is not None and mesh != v.device_mesh:
                    spec = self._try_replicate_dtensor_spec_in_missing_dim(
                        op_call, v, mesh
                    )
                    kwargs_schema[k] = spec
                else:
                    mesh = v.device_mesh
                    kwargs_schema[k] = v._spec
            elif isinstance(v, torch.Tensor):
                mesh = mesh or try_find_mesh_from_args(op_call, args_list)
                kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
                    op_call, v, mesh
                )
                local_kwargs[k] = v
            else:
                kwargs_schema[k] = v
                local_kwargs[k] = v

        assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!"
        op_info = OpInfo(
            mesh,
            OpSchema(
                op_call,
                pytree.tree_unflatten(args_schema, args_spec)
                if args_spec
                else tuple(args_schema),
                kwargs_schema,
                schema_info=runtime_schema_info,
            ),
            args_schema,
            tuple(local_args),
            local_kwargs,
            args_spec,
        )
        return op_info

    @staticmethod
    def wrap(res: object, spec: OutputSpecType) -> object:
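        """
        Re-wrap the local results of an op into DTensors according to the
        output spec(s), recursing into lists/tuples; non-tensor values
        (e.g. int/float/None) are returned as-is.
        """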
        if isinstance(res, torch.Tensor):
            if spec is not None:
                assert isinstance(
                    spec, DTensorSpec
                ), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
                return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
            else:
                # if the output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
                assert res.ndim == 0, "output tensor should be scalar!"
                return res
        elif isinstance(res, (list, tuple)):
            assert spec is not None and isinstance(
                spec, (list, tuple)
            ), f"output spec does not match with output! Expected list/tuple, got {spec}."
            res_list = []
            for e, s in zip(res, spec):
                res_list.append(OpDispatcher.wrap(e, s))

            return tuple(res_list) if isinstance(res, tuple) else res_list
        else:
            # if the result contains only non-tensor values (e.g. int/float/None), we
            # simply return it without re-wrapping into a DTensor.
            return res

    def _try_replicate_spec_for_scalar_tensor(
        self,
        op_call: torch._ops.OpOverload,
        tensor_arg: torch.Tensor,
        mesh: "DeviceMesh",
    ) -> DTensorSpec:
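        """
        Produce a replicated spec so that a plain torch.Tensor arg/kwarg can
        participate in a DTensor op. This is only considered safe for scalar
        (numel == 1) tensors, or for any tensor when implicit replication is
        explicitly enabled; anything else raises a RuntimeError.
        """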
        if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
            warnings.warn(
                "Found a non-scalar tensor with numel=1 and ndim!=0, "
                "we are implicitly creating a replicated DTensor for it. "
                "However, please consider changing it to a scalar tensor "
                "or explicitly creating a DTensor under the distributed environment."
            )

        if tensor_arg.numel() == 1 or self._allow_implicit_replication:
            # a scalar tensor can be safely treated as replicated
            replication_spec = DTensorSpec(
                mesh,
                (Replicate(),) * mesh.ndim,
                tensor_meta=TensorMeta(
                    shape=tensor_arg.shape,
                    stride=tensor_arg.stride(),
                    dtype=tensor_arg.dtype,
                ),
            )
        else:
            raise RuntimeError(
                f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
                " torch.Tensor to DTensor before calling distributed operators!"
            )
        return replication_spec

    def _try_replicate_dtensor_spec_in_missing_dim(
        self,
        op_call: torch._ops.OpOverload,
        dtensor_arg: "dtensor.DTensor",
        mesh: "DeviceMesh",
    ) -> DTensorSpec:
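        """
        For foreach ops whose DTensor args live on different (sub-)meshes of
        the same root mesh, build a spec on the root mesh that keeps the arg's
        own placement on its mesh dim and fills the missing dims with
        Replicate(). This is only attempted when implicit replication is
        enabled; otherwise the cross-mesh error is raised.
        """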
        from torch.distributed.device_mesh import _mesh_resources

        cur_mesh = dtensor_arg.device_mesh
        root_mesh = _mesh_resources.get_root_mesh(cur_mesh)
        if (
            self._allow_implicit_replication
            and "foreach" in op_call.__name__
            and root_mesh == mesh
        ):
            placements = [Replicate() for _ in range(root_mesh.ndim)]
            cur_mesh_root_idx = _mesh_resources.get_root_mesh_dim(cur_mesh)
            placements[cur_mesh_root_idx] = dtensor_arg.placements[0]  # type: ignore[call-overload]
            replicate_spec = DTensorSpec(
                root_mesh,
                tuple(placements),
                tensor_meta=TensorMeta(
                    shape=dtensor_arg.shape,
                    stride=dtensor_arg.stride(),
                    dtype=dtensor_arg.dtype,
                ),
            )
        else:
            raise NotImplementedError(
                f"{op_call}: DTensor does not support cross-mesh operation yet! "
                f"Got meshes: {mesh} {cur_mesh}"
            )
        return replicate_spec