# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import functools
import itertools
import operator
from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union

import torch
from torch.distributed.tensor._api import DTensor
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
    OpSchema,
    OpStrategy,
    PlacementList,
    PlacementStrategy,
    RuntimeSchemaInfo,
)
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)

# convenient wrapper to register sharding propagation rules
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def register_prop_rule(op, schema_info=None):
    # pyre-fixme[53]: Captured variable `func` is not annotated.
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def wrapper(impl):
        overloads = op if isinstance(op, list) else [op]
        for overload in overloads:
            DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule(
                overload, impl, schema_info
            )
        return impl

    return wrapper
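
# Illustrative usage sketch (added for clarity; not from the original file). The
# op overload and rule body below are hypothetical, but they follow the pattern
# used by prop rules in this package: the decorated rule takes an OpSchema and
# returns an OutputSharding (both defined in torch.distributed.tensor._op_schema).
#
#     @register_prop_rule(torch.ops.aten.squeeze.default)
#     def _my_squeeze_rule(op_schema: OpSchema) -> OutputSharding:
#         ...  # compute and return the output sharding for this op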


def register_op_strategy(op, schema_info=None):
    # For every ATen op that accepts any arg in this list, that arg itself can
    # impact the strides (and potentially the sharding strategy) of the output
    # tensor. Thus, we detect ATen schemas with any of these args and ensure
    # that they get specialized here.
    arg_names_that_require_specializing_cache_strategy = [
        "memory_format",
    ]

    # pyre-fixme[53]: Captured variable `func` is not annotated.
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def wrapper(impl):
        if isinstance(op, list):
            overloads = op
        else:
            overloads = [op]

        for overload in overloads:
            curr_schema_info = None
            if schema_info is None:
                specialized_args = [
                    a.name
                    for a in overload._schema.arguments
                    if a.name in arg_names_that_require_specializing_cache_strategy
                ]
                if specialized_args:
                    curr_schema_info = RuntimeSchemaInfo(
                        static_kwargkey=specialized_args
                    )
            else:
                curr_schema_info = schema_info
            DTensor._op_dispatcher.sharding_propagator.register_op_strategy(
                overload, impl, curr_schema_info
            )
        return impl

    return wrapper
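
# Illustrative usage sketch (added for clarity; not from the original file). The
# overload, schema_info value, and strategy body are hypothetical; at this
# revision, strategy functions registered this way typically take a DeviceMesh
# and an OpSchema and return an OpStrategy:
#
#     @register_op_strategy(
#         torch.ops.aten.native_layer_norm.default, schema_info=RuntimeSchemaInfo(1)
#     )
#     def _my_layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
#         ...  # enumerate candidate PlacementStrategy entries for this op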


def as_list(
    x: Union[List[object], object]
    # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]:  # type: ignore[valid-type]
    # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
    # which the tracer treats as a list. Therefore, keep `immutable_list` intact
    # here as well rather than wrapping it in a new list.
    if type(x) is list or isinstance(x, torch.fx.immutable_collections.immutable_list):
        return x
    else:
        return [x]
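
# Illustrative behavior (added note): as_list(3) returns [3], while as_list([1, 2])
# and as_list(torch.fx.immutable_collections.immutable_list([1, 2])) are returned
# unchanged.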


def normalize_dim(dim: int, ndim: int) -> int:
    return dim if dim >= 0 else dim + ndim


def normalize_dims(dims: Union[int, Sequence[int]], ndim: int) -> Sequence[int]:
    """Normalize a dim or a sequence of dims, so that they are all non-negative."""
    if isinstance(dims, int):
        dims = (normalize_dim(dims, ndim),)
    elif isinstance(dims, list):
        dims = [normalize_dim(dim, ndim) for dim in dims]
    elif isinstance(dims, tuple):
        dims = tuple([normalize_dim(dim, ndim) for dim in dims])
    return dims
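
# Illustrative behavior (added note): with ndim=4, normalize_dim(-1, 4) == 3,
# normalize_dims((-1, 0), 4) == (3, 0), and an int input is wrapped into a
# single-element tuple, e.g. normalize_dims(-2, 4) == (2,).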


def prod(xs: Iterable[int]) -> int:
    return functools.reduce(operator.mul, xs, 1)
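
# Illustrative behavior (added note): prod([2, 3, 4]) == 24 and prod([]) == 1,
# since 1 is supplied as the reduce initializer.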


def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
    """Check if the shape is shardable according to the spec."""
    # number of shards in each tensor dimension
    shards_map = [1] * len(shape)
    for i, placement in enumerate(spec.placements):
        if placement.is_shard():
            shard_dim = cast(Shard, placement).dim
            shards_map[shard_dim] *= spec.mesh.size(i)

    for i, dim_size in enumerate(shape):
        # TODO: maybe we should determine is_shardable based on
        #       whether it's evenly sharded or not
        if shards_map[i] > 1 and dim_size < shards_map[i]:
            return False

    return True
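
# Illustrative example (added note; a real DTensorSpec needs an initialized
# DeviceMesh, so this is described rather than run): for shape (5, 16) with
# placements (Shard(0),) on a mesh dimension of size 8, shards_map is [8, 1]
# and 5 < 8, so the tensor is not shardable; with (Shard(1),) it is, since 16 >= 8.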


def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
    """Check if the shape is evenly shardable according to the spec."""
    # number of shards in each tensor dimension
    shards_map = [1] * len(shape)
    for i, placement in enumerate(spec.placements):
        if placement.is_shard():
            shard_dim = cast(Shard, placement).dim
            shards_map[shard_dim] *= spec.mesh.size(i)

    for i, dim_size in enumerate(shape):
        if shards_map[i] > 1 and (dim_size % shards_map[i] != 0):
            return False

    return True
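
# Illustrative example (added note): for shape (12, 7) with placements (Shard(0),)
# on a mesh dimension of size 4, 12 % 4 == 0, so the tensor is evenly shardable;
# with (Shard(1),) it is not, since 7 % 4 != 0.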


def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool:
    """Return True if tensor dim is sharded."""
    return any(p.is_shard(dim) for p in spec.placements)


def is_tensor_partial(spec: DTensorSpec) -> bool:
    """Return True if tensor is partial on the mesh."""
    return any(p.is_partial() for p in spec.placements)
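
# Illustrative example (added note): for a spec with placements (Shard(0), Partial()),
# is_tensor_dim_sharded(spec, dim=0) is True (and False for dim=1), and
# is_tensor_partial(spec) is True because of the Partial placement.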


def infer_broadcast_dims_map(
    common_shape: torch.Size, input_shape: torch.Size
) -> List[int]:
    # Infer the broadcast dims map, which maps each common shape dim to the
    # corresponding input shape dim, following broadcast semantics: dims are
    # aligned from the right, and a -1 entry means there is no matching input
    # dim (either missing or broadcast from size 1).
    common_ndim = len(common_shape)
    input_ndim = len(input_shape)
    broadcast_dims_map = [-1] * common_ndim
    for idx in range(-1, -1 - input_ndim, -1):
        if input_shape[idx] == common_shape[idx]:
            broadcast_dims_map[common_ndim + idx] = input_ndim + idx
    return broadcast_dims_map
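
# Illustrative behavior (added note): infer_broadcast_dims_map(torch.Size([2, 3, 4]),
# torch.Size([3, 1])) returns [-1, 0, -1]: common dim 1 maps back to input dim 0,
# while common dims 0 and 2 have no non-broadcast counterpart in the input.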


def map_placements_after_broadcast(
    placements: Tuple[Placement, ...],
    shape: torch.Size,
    broadcast_dims_map: List[int],
) -> Tuple[Placement, ...]:
    """Map each placement based on the output shape after broadcast."""
    new_placements: List[Placement] = []
    for placement in placements:
        if isinstance(placement, (Replicate, Partial)):
            new_placements.append(placement)
        else:
            assert isinstance(placement, Shard)
            shard_dim = normalize_dim(placement.dim, len(shape))
            new_shard_dim = broadcast_dims_map[shard_dim]
            if new_shard_dim != -1:
                # there's a map from the common shape shard dim to
                # the input shape shard dim before broadcasting,
                # so use that dim instead
                new_placements.append(Shard(new_shard_dim))
            else:
                # there's no mapping from the common shape shard dim to
                # the input shape shard dim before broadcasting, which
                # means implicit broadcasting happened on this dim, so
                # we can simply mark it as Replicate and the implicit
                # broadcast will expand it to the sharded shape
                new_placements.append(Replicate())

    return tuple(new_placements)
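
# Illustrative behavior (added note): with broadcast_dims_map == [-1, 0, -1] from
# the example above, a common-shape placement (Shard(1),) maps to (Shard(0),) on
# the input, while (Shard(2),) maps to (Replicate(),) because that dim was
# implicitly broadcast.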


def generate_redistribute_costs(
    src_strategy: OpStrategy, dst_spec: DTensorSpec
) -> List[float]:
    redistribute_costs: List[float] = []
    for strat in src_strategy.strategies:
        redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec))

    return redistribute_costs
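
# Note (added for clarity): for a src_strategy with N candidate placement
# strategies, this returns a list of N floats, where entry i is the estimated
# cost of redistributing the i-th candidate's output_spec to dst_spec.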


def expand_to_full_mesh_op_strategy(
    mesh: DeviceMesh,
    op_schema: OpSchema,
    single_mesh_dim_strategies: List[PlacementList],
    *,
    input_index: int = 1,
    inplace_op: bool = False,
) -> OpStrategy:
    # Expand the single_mesh_dim_strategies to full mesh dim strategies.
    all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim

    strategy_combs = itertools.product(*all_mesh_dim_strategies)

    all_strategies = []
    for strategy_comb in strategy_combs:
        spec_list: List[Optional[DTensorSpec]] = []
        for specs in zip(*strategy_comb):
            if specs[0] is not None:
                spec_list.append(DTensorSpec(mesh, specs))
            else:
                spec_list.append(None)

        input_specs: List[DTensorSpec] = [
            s for s in spec_list[input_index:] if isinstance(s, DTensorSpec)
        ]

        input_args_strategy = op_schema.args_strategy
        assert len(input_specs) == len(input_args_strategy)
        self_spec = input_args_strategy[0].strategies[0].output_spec

        if inplace_op and self_spec.placements != input_specs[0].placements:
            # for an inplace op, only add the placement strategy when the first
            # input_spec matches the first argument's runtime sharding;
            # otherwise skip this candidate
            continue

        # check that all inputs are shardable
        inputs_shardable = all(
            is_tensor_shardable(inp.shape, s)
            for inp, s in zip(input_args_strategy, input_specs)
        )

        # only add to the all_strategies list when all inputs are shardable
        if inputs_shardable:
            redistribute_costs = [
                generate_redistribute_costs(input_strategy, input_spec)
                for input_strategy, input_spec in zip(input_args_strategy, input_specs)
            ]
            if input_index > 1:
                output_specs = tuple(spec_list[:input_index])
            else:
                if spec_list[0] is not None:
                    output_specs = spec_list[0]  # type: ignore[assignment]
                else:
                    raise RuntimeError("output spec is None")
            strategy = PlacementStrategy(
                output_specs=output_specs,
                input_specs=input_specs,
                redistribute_cost=redistribute_costs,
            )
            all_strategies.append(strategy)

    return OpStrategy(all_strategies)
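
# Illustrative example (added sketch; the values are hypothetical): for a binary
# op with a single output (input_index=1), a per-mesh-dim strategy list such as
#
#     single_mesh_dim_strategies: List[PlacementList] = [
#         [Replicate(), Replicate(), Replicate()],  # output, lhs, rhs all replicated
#         [Shard(0), Shard(0), Shard(0)],           # all sharded on tensor dim 0
#         [Partial(), Partial(), Partial()],        # all pending partial reduction
#     ]
#
# is expanded via itertools.product over mesh.ndim, so a 2-D mesh yields up to
# 3 * 3 = 9 candidate PlacementStrategy entries; candidates are then filtered by
# input shardability and, for inplace ops, by matching the first argument's
# runtime placements. Entries before input_index describe outputs, the rest
# describe inputs.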