# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch._ops import OpOverload
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import Placement


try:
    from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec
except ImportError:
    from torch.utils._pytree import (  # type: ignore[no-redef, assignment]
        tree_leaves,
        tree_map_only,
        TreeSpec,
    )


# Common type aliases
ArgsType = Tuple[object, ...]
KwargsType = Dict[str, object]

PlacementList = List[Optional[Placement]]

# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so the output type should
# be the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]


def _rebuild_tensor_from_dtensor_meta(arg) -> object:
    """
    This is used to propagate tensor metadata; it must be called under fake mode.
    """
    assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta."
    return torch.empty_strided(
        arg.tensor_meta.shape,
        arg.tensor_meta.stride,
        dtype=arg.tensor_meta.dtype,
    )
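# NOTE: the rebuilt tensor mirrors only shape/stride/dtype from the spec's tensor_meta and
# carries no real data; OpSchema.gen_fake_args/gen_fake_kwargs below map this helper over a
# schema so sharding propagation can run the local operator on metadata-only (fake) tensors.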


def _is_inplace_op(op: OpOverload):
    # Simple analysis of the function schema to determine if this is an
    # inplace variant. It might not be entirely correct, but it's good
    # enough for now.
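    # e.g. the schema name of torch.ops.aten.add_.Tensor is "aten::add_", which ends with "_".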
    return op._schema.name[-1] == "_"


def _is_out_variant_op(op: OpOverload):
    # Simple analysis of the function schema to determine if this is an
    # out variant. It might not be entirely correct, but it's good
    # enough for now.
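    # e.g. the overload name of torch.ops.aten.add.out is "out".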
    return "out" in op._schema.overload_name


def _pretty_print_spec(spec: object) -> str:
    if spec is None:
        return "None"
    elif isinstance(spec, DTensorSpec):
        return "".join([str(p) for p in spec.placements])
    elif isinstance(spec, Sequence):
        return "(" + ", ".join([_pretty_print_spec(s) for s in spec]) + ")"
    else:
        raise RuntimeError(f"Unknown spec type to print: spec={spec}")
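
# Illustration (assuming the usual short placement strings, e.g. str(Shard(0)) == "S(0)" and
# str(Replicate()) == "R"): a DTensorSpec placed as (Shard(0), Replicate()) on a 2-D mesh
# prints as "S(0)R", and a sequence of two such specs prints as "(S(0)R, RR)".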


@dataclass
class PlacementStrategy:
    """
    A placement strategy describes acceptable sharding placements of the output
    and the tensor arguments of an operation.

    Note: when the op return value is a single DTensor object, output_specs is
    DTensorSpec; when the return value is a tuple of Optional[DTensor],
    output_specs is a tuple of Optional[DTensorSpec].
    """

    output_specs: Union[DTensorSpec, Tuple[Optional[DTensorSpec], ...]]
    input_specs: Optional[Sequence[DTensorSpec]] = None

    # Redistribute costs for this op placement strategy: a nested list that records,
    # for each operand of this operator, the cost of redistributing that operand from
    # each of its candidate placement strategies.
    redistribute_cost: Optional[List[List[float]]] = None

    @cached_property
    def output_spec(self) -> DTensorSpec:
        """
        This function requires that the strategy have exactly one DTensorSpec as the
        output spec. If output_specs is a tuple, we throw an exception.
        """
        if isinstance(self.output_specs, DTensorSpec):
            return self.output_specs
        else:
            raise ValueError(
                f"function output_spec expects a single DTensorSpec but got: {self.output_specs}"
            )

    def input_spec(self, index: int = 0) -> DTensorSpec:
        assert self.input_specs is not None, "input_specs of PlacementStrategy is None!"
        assert len(self.input_specs) > index, (
            f"Invalid index {index} for input_specs of length "
            f"{len(self.input_specs)}: {self.input_specs}"
        )
        return self.input_specs[index]

    def __str__(self) -> str:
        if self.input_specs is not None:
            input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> "
        else:
            input_specs_str = ""
        output_spec_str = _pretty_print_spec(self.output_specs)
        return f"{input_specs_str}{output_spec_str}"
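
# A rough construction sketch (hypothetical `mesh`, and assuming a Shard placement from
# torch.distributed.tensor.placement_types and a DTensorSpec(mesh, placements) constructor,
# neither of which is imported here):
#   PlacementStrategy(
#       output_specs=DTensorSpec(mesh, (Shard(0),)),
#       input_specs=(DTensorSpec(mesh, (Shard(0),)),),
#   )
# describes "take a dim-0 sharded input, produce a dim-0 sharded output" on a 1-D mesh;
# its __str__ above would render roughly as "(S(0)) -> S(0)".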


class StrategyType:
    """
    Base class type for op strategy. We have two StrategyType subclasses:
        OpStrategy and TupleStrategy
    """


class OpStrategy(StrategyType):
    """
    OpStrategy consists of a list of candidate placement strategies associated with the op.
    """

    def __init__(self, strategies: List[PlacementStrategy]) -> None:
        super().__init__()
        self.strategies: List[PlacementStrategy] = strategies

    def __str__(self) -> str:
        strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
        mesh_shape = self.mesh_shape
        return f"[{strategy_list_str}] @ mesh: {mesh_shape}"

    def max_num_shards(self) -> int:
        """
        Returns the max number of shards across all placement strategies
        """
        return max(strategy.output_spec.num_shards for strategy in self.strategies)

    @property
    def mesh_shape(self):
        output_spec = self.strategies[0].output_specs
        if isinstance(output_spec, DTensorSpec):
            return output_spec.mesh.shape
        else:
            assert isinstance(
                output_spec, tuple
            ), "found no DTensorSpec in the OpStrategy!"
            assert output_spec[0] is not None
            return output_spec[0].mesh.shape

    @property
    def ndim(self):
        return self.strategies[0].output_spec.ndim

    @property
    def shape(self):
        return self.strategies[0].output_spec.shape
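
    # NOTE: mesh_shape/ndim/shape above only inspect strategies[0]; they implicitly assume
    # every candidate strategy in this OpStrategy lives on the same mesh and describes an
    # output of the same shape, differing only in placements.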


class TupleStrategy(StrategyType):
    """
    TupleStrategy represents the case where the output strategy of an op is a tuple
    of strategies, i.e. if the output of this op is a tuple of tensors or list of tensors
    with possibly different placement strategies, we should return a TupleStrategy that
    contains a tuple of OpStrategy, where each child represents the sharding strategy
    of "each element" of the tuple/list of tensors the op returns.

    NOTE: if the output of the op is a List[Tensor] and all elements share the same placement
    strategy, then we should return a single OpStrategy instead of a TupleStrategy.
    """

    def __init__(self, childs: Sequence[StrategyType]) -> None:
        super().__init__()
        self.childs: Sequence[StrategyType] = childs

    def __str__(self) -> str:
        child_strategies_str = ", ".join([str(strat) for strat in self.childs])
        return f"TupleStrategy({child_strategies_str})"
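
# For example (a sketch, not tied to a specific op): an op returning two tensors whose
# shardings may differ element-by-element could be described as
#   TupleStrategy([OpStrategy([...]), OpStrategy([...])])
# with one child OpStrategy per returned tensor.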


@dataclass
class RuntimeSchemaInfo:
    """
    RuntimeSchemaInfo stores the operator schema related information for runtime (eager)
    execution. This is mainly used in two ways: 1. to generate a hash of the args to determine
    whether to re-run sharding prop or not; 2. to determine if we need pytree.
    """

    # This static_argnum records the static arg "starting index" for ops that have non-tensor
    # args/kwargs which would affect sharding propagation results. All args starting from
    # this index would be hashed into our sharding cache.
    # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
    static_argnum: int = 100
    # This static_kwargkey records static kwarg names which would affect sharding prop.
    static_kwargkey: Optional[List[str]] = None
    # Each op can decide if it wants to use pytree flatten/unflatten during operator
    # eager execution. By default we don't do flatten/unflatten; we only do it if the
    # op indicates it needs to. This is to accelerate eager performance.
    needs_pytree: bool = False
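
# A rough usage sketch (hypothetical values, not taken from actual op registrations): an op
# whose sharding depends on a non-tensor `dim` argument at positional index 1 could register
#   RuntimeSchemaInfo(static_argnum=1)
# so every arg from index 1 onward is folded into the sharding-prop cache key, while
#   RuntimeSchemaInfo(static_kwargkey=["dtype"])
# would additionally hash the `dtype` kwarg.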


@dataclass
class OpSchema:
    """
    OpSchema is a data class that describes an operator's input schema; it includes
    DTensorSpecs (instead of DTensors) and non-tensor args/kwargs (positional order
    preserved). It is mainly used by DTensor's dispatching logic to perform various
    actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.)

    NOTE: this should be used as a read only data class
    TODO: make this a frozen dataclass

    Args:
        op: the operator overload we are intercepting
        args_schema: contains args except that the DTensor args have been replaced
            with their DTensorSpec or OpStrategy
        kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced
            with their DTensorSpec or OpStrategy
    """

    op: OpOverload
    args_schema: ArgsType
    kwargs_schema: KwargsType

    schema_info: Optional[RuntimeSchemaInfo] = None

    @property
    def args_spec(self) -> Tuple[DTensorSpec, ...]:
        """
        args_spec: Tuple[DTensorSpec, ...]: a clean tuple of args specs
            with NO non-DTensor positional arguments (i.e. int/float/tuple, etc.),
            mainly used by sharding propagation to propagate the output spec
        """
        args = (
            tree_leaves(self.args_schema)
            if self.schema_info is not None and self.schema_info.needs_pytree
            else self.args_schema
        )
        return tuple(item for item in args if isinstance(item, DTensorSpec))

    @property
    def args_strategy(self) -> Tuple[OpStrategy, ...]:
        # Filter out non-relevant values from the args schema to get a clean OpStrategy list.
        # Kept separate from args_spec for ease of type annotation.
        # TODO: see if we should merge this with args_spec
        args = (
            tree_leaves(self.args_schema)
            if self.schema_info is not None and self.schema_info.needs_pytree
            else self.args_schema
        )
        return tuple(item for item in args if isinstance(item, OpStrategy))

    def __repr__(self) -> str:
        args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema])
        return (
            f"OpSchema(op={self.op},"
            f" args_schema=({args_schema}),"
            f" kwargs_schema={self.kwargs_schema})"
        )

    def __str__(self) -> str:
        args_schema: List[str] = []
        mesh_shape = None
        for arg in self.args_schema:
            if isinstance(arg, DTensorSpec):
                args_schema.append(str(arg))
                mesh_shape = arg.mesh.shape
            elif isinstance(arg, OpStrategy):
                assert len(arg.strategies) == 1
                args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
                mesh_shape = arg.mesh_shape
            elif isinstance(arg, TupleStrategy):
                first_op_strtgy = arg.childs[0]
                assert isinstance(first_op_strtgy, OpStrategy)
                mesh_shape = first_op_strtgy.mesh_shape
                args_schema.append(str(arg))
            else:
                args_schema.append(str(arg))
        return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})"

    def __post_init__(self) -> None:
        has_symints = False
        for a in self.args_schema:
            if isinstance(a, DTensorSpec) and a.tensor_meta is not None:
                if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape):
                    has_symints = True
                    break
        self.has_symints = has_symints

    def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool:
        arg = self.args_schema[arg_idx]
        is_tensor = isinstance(arg, DTensorSpec)
        if is_tensor:
            return True

        if not isinstance(arg, list):
            return False

        return all(isinstance(e, DTensorSpec) or e is None for e in arg)

    def return_type_tuple_tensor_like(self) -> bool:
        # all dispatch ops could only return Tuple[Tensor] or have None/ints/floats
        # in the tuple, but the first element must be a Tensor, so this check is enough
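        # e.g. (illustrative) torch.ops.aten.topk.default returns (Tensor values, Tensor
        # indices), so len(returns) > 1 and returns[0].type is a Tensor type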
        return_types = self.op._schema.returns
        return len(return_types) > 1 and isinstance(
            return_types[0].type, torch.TensorType
        )

    def return_type_tensor(self) -> bool:
        return_types = self.op._schema.returns
        # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like
        # return types, so this check is enough for tensor like types
        return isinstance(return_types[0].type, torch.TensorType)

    def __hash__(self) -> int:
        # Only hash args and kwargs that the op indicates to hash
        if not self.schema_info:
            static_argnum = len(self.args_schema)
            static_kwargkey = None
        else:
            static_argnum = self.schema_info.static_argnum
            static_kwargkey = self.schema_info.static_kwargkey

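        # lists are not hashable, so list-typed args (e.g. tensor lists) are converted to
        # tuples before being folded into the cache key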
        args_to_hash = tuple(
            tuple(e) if isinstance(e, list) else e
            for i, e in enumerate(self.args_schema)
            if self.arg_type_tensor_or_tensor_list_like(i) or i >= static_argnum
        )
        if static_kwargkey is not None:
            kwargs_to_hash = tuple(
                self.kwargs_schema.get(k, None) for k in static_kwargkey
            )
            return hash((self.op, args_to_hash, kwargs_to_hash))
        else:
            return hash((self.op, args_to_hash))

    def __eq__(self, other: object) -> bool:
        # early return checks
        if not isinstance(other, OpSchema):
            return False

        if self.op != other.op:
            return False

        if len(self.args_schema) != len(other.args_schema):
            return False

        # compare each element and early return if any of them is different
        if not self.schema_info:
            static_argnum = len(self.args_schema)
            static_kwargkey = None
        else:
            static_argnum = self.schema_info.static_argnum
            static_kwargkey = self.schema_info.static_kwargkey

        for i, (self_arg, other_arg) in enumerate(
            zip(self.args_schema, other.args_schema)
        ):
            if isinstance(self_arg, DTensorSpec) and self_arg != other_arg:
                return False
            elif i >= static_argnum and self_arg != other_arg:
                return False

        # check kwarg equality when there's a static kwarg key
        if static_kwargkey:
            for key in static_kwargkey:
                if self.kwargs_schema.get(key, None) != other.kwargs_schema.get(
                    key, None
                ):
                    return False

        return True

    def gen_fake_args(self) -> ArgsType:
        """
        gen_fake_args: generate fake args for the operator; this is mainly used
            by sharding propagation rules to run the local tensor operator
            and get the output spec.
        """
        return tree_map_only(
            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.args_schema
        )

    def gen_fake_kwargs(self) -> KwargsType:
        """
        gen_fake_kwargs: generate fake kwargs for the operator; this is mainly used
            by sharding propagation rules to run the local tensor operator
            and get the output spec.
        """
        return tree_map_only(
            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
        )

    def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
        suggestion_args_spec = self.args_spec
        new_arg_schema: List[object] = []
        idx_of_args_spec = 0
        if (
            origin_schema.schema_info is not None
            and origin_schema.schema_info.needs_pytree
        ):
            args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema)
        else:
            args_schema = origin_schema.args_schema
        for arg in args_schema:
            if isinstance(arg, DTensorSpec):
                new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
                idx_of_args_spec += 1
            else:
                new_arg_schema.append(arg)
        self.args_schema = tuple(new_arg_schema)
        self.kwargs_schema = origin_schema.kwargs_schema
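
# A rough end-to-end sketch (hypothetical specs; the real construction happens inside
# DTensor's dispatching logic): for a matmul the dispatcher would build something like
#   OpSchema(torch.ops.aten.mm.default, args_schema=(lhs_spec, rhs_spec), kwargs_schema={})
# where lhs_spec/rhs_spec are the DTensorSpecs of the two DTensor operands, hand it to
# sharding propagation, and get back an OutputSharding (defined below).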


@dataclass
class OutputSharding:
    """
    OutputSharding is a data class that is used by sharding propagation;
    it could set the output_spec upon successful propagation. If needs_redistribute
    is set to True, a redistribute_schema would be returned together to indicate that
    the input arguments need to be redistributed before the op execution.

    NOTE: the redistribute_schema generated by sharding propagation should be
    exactly the same as the operator OpSchema, except for the DTensorSpecs.
    """

    output_spec: OutputSpecType
    redistribute_schema: Optional[OpSchema] = None
    needs_redistribute: bool = False
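
# In other words: when needs_redistribute is True, the inputs are expected to be
# redistributed to match the placements in redistribute_schema before the local op runs;
# otherwise output_spec can be consumed directly.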


@dataclass
class OpInfo:
    """
    All runtime op execution info is packed here
    """

    mesh: DeviceMesh
    schema: OpSchema
    flat_args_schema: List[object]
    local_args: Sequence[object]
    local_kwargs: Dict[str, object]
    args_tree_spec: Optional[TreeSpec] = None

    # the output sharding info
    output_sharding: Optional[OutputSharding] = None
