xref: /aosp_15_r20/external/pytorch/torch/distributed/tensor/experimental/_tp_transform.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3import operator
4from typing import Any, cast, Dict, List, Optional, Sequence, Tuple
5
6import torch
7from torch._subclasses.fake_tensor import FakeTensor
8from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
9from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
10from torch.distributed.tensor._op_schema import (
11    OpSchema,
12    OutputSharding,
13    OutputSpecType,
14    PlacementStrategy,
15)
16from torch.distributed.tensor._redistribute import redistribute_local_tensor
17from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle
18from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
19from torch.export import ExportedProgram
20from torch.export.exported_program import ExportGraphSignature
21from torch.fx import GraphModule
22from torch.fx.experimental.proxy_tensor import make_fx
23from torch.fx.node import Node
24from torch.fx.passes.infra.pass_base import PassBase, PassResult
25from torch.fx.passes.shape_prop import _extract_tensor_metadata
26from torch.utils import _pytree as pytree
27
28
29__all__ = ["tensor_parallel_transformation"]
30
31aten = torch.ops.aten
32
33
34def tensor_parallel_transformation(
35    exported_program: ExportedProgram,
36    rank: int,
37    world_size: int,
38    device_type: str,
39    parallel_strategies: Dict[str, ParallelStyle],
40) -> ExportedProgram:
41    """
42    The entry point function to perform graph transformations on an exported program
43    to transform a single-device graph into a tensor parallel graph.
44
45    .. warning::
46        This API is experimental and subject to change.
47    """
48
49    gm = exported_program.graph_module
50    sig = copy.deepcopy(exported_program.graph_signature)
51    state_dict = copy.copy(exported_program.state_dict)
52
53    with gm._set_replace_hook(sig.get_replace_hook()):
54        res = _TensorParallelTransformPass(
55            rank,
56            world_size,
57            device_type,
58            state_dict,
59            exported_program.graph_signature,
60            parallel_strategies,
61        )(gm)
62        assert res is not None
63        gm = res.graph_module
64
65    return exported_program._update(gm, sig, state_dict=state_dict)
66
67
68class _TensorParallelTransformPass(PassBase):
69    """
70    This pass is responsible for transforming a single-device graph into a tensor parallel
71    graph. It will mark the placement strategy of each node in the graph,
72    partition the graph into distributed graph, then shard the parameters/buffers accordingly.
73    """
74
75    def __init__(
76        self,
77        rank: int,
78        world_size: int,
79        device_type: str,
80        state_dict: Dict[str, torch.Tensor],
81        graph_signature: ExportGraphSignature,
82        parallel_strategies: Dict[str, ParallelStyle],
83    ) -> None:
84        super().__init__()
85        self.rank = rank
86        self.mesh = DeviceMesh(device_type, torch.arange(world_size))
87        self.state_dict: Dict[str, torch.Tensor] = state_dict
88        self.graph_signature = graph_signature
89        self.parallel_strategies = parallel_strategies
90
91    def call(self, graph_module) -> PassResult:
92        gm = copy.deepcopy(graph_module)
93
94        parameter_placements = _generate_parameter_and_buffer_placements(
95            list(self.state_dict.keys()), self.parallel_strategies
96        )
97        placement_strategies = _mark_sharding(
98            gm, self.graph_signature, self.mesh, parameter_placements
99        )
100        _partitioner(gm)
101        _shard_state_dict(
102            self.state_dict, placement_strategies, self.graph_signature, self.mesh
103        )
104        return PassResult(gm, True)
105
106
107def _generate_parameter_and_buffer_placements(
108    params_and_buffers: List[str],
109    parallel_strategies: Dict[str, ParallelStyle],
110) -> Dict[str, Placement]:
111    """
112    Build parameter placements based on the give parallel style of linear layers.
113    """
114    parameter_placements: Dict[str, Placement] = {}
115    for linear_fqn, parallel_style in parallel_strategies.items():
116        weight_fqn = f"{linear_fqn}.weight"
117        bias_fqn = f"{linear_fqn}.bias"
118        assert weight_fqn in params_and_buffers
119        parameter_placements[weight_fqn] = (
120            Shard(0) if parallel_style == ColwiseParallel else Shard(1)
121        )
122        if bias_fqn in params_and_buffers:
123            parameter_placements[bias_fqn] = (
124                Shard(0) if parallel_style == ColwiseParallel else Replicate()
125            )
126    return parameter_placements
127
128
129def _mark_tensor_parallel_shardings(
130    gm: GraphModule,
131    graph_signature: ExportGraphSignature,
132    mesh: DeviceMesh,
133    parameter_placements: Dict[str, Placement],
134) -> Dict[Node, PlacementStrategy]:
135    """
136    Mark the placement strategies of the parameter and buffer placeholder nodes.
137    """
138    placement_strategies: Dict[Node, PlacementStrategy] = {}
139    num_params_and_buffers = len(graph_signature.inputs_to_parameters) + len(
140        graph_signature.inputs_to_buffers
141    )
142    placeholder_idx: int = 0
143    for node in gm.graph.nodes:
144        if node.op == "placeholder":
145            if placeholder_idx < num_params_and_buffers:
146                fqn: str = _get_input_node_fqn(node.name, graph_signature)
147                placement: Placement = (
148                    parameter_placements[fqn]
149                    if fqn in parameter_placements
150                    else Replicate()
151                )
152                placement_strategies[node] = _create_placement_strategy(
153                    node,
154                    mesh,
155                    placements=(placement,),
156                )
157                placeholder_idx += 1
158            else:
159                placement_strategies[node] = _create_placement_strategy(
160                    node,
161                    mesh,
162                    placements=(Replicate(),),
163                )
164    return placement_strategies
165
166
167def _get_input_node_fqn(input_name: str, graph_signature: ExportGraphSignature) -> str:
168    """
169    Return the FQN of an input node.
170    """
171    if input_name in graph_signature.inputs_to_parameters:
172        return graph_signature.inputs_to_parameters[input_name]
173    elif input_name in graph_signature.inputs_to_buffers:
174        return graph_signature.inputs_to_buffers[input_name]
175    else:
176        raise ValueError(
177            f"{input_name} not found in inputs_to_parameters or inputs_to_buffers"
178        )
179
180
181def _mark_sharding(
182    gm: GraphModule,
183    graph_signature: ExportGraphSignature,
184    mesh: DeviceMesh,
185    parameter_placements: Dict[str, Placement],
186) -> Dict[Node, PlacementStrategy]:
187    """
188    Mark the sharding strategy for each node in the graph module.
189    """
190    placement_strategies: Dict[
191        Node, PlacementStrategy
192    ] = _mark_tensor_parallel_shardings(gm, graph_signature, mesh, parameter_placements)
193
194    for node in gm.graph.nodes:
195        if node.op == "placeholder":
196            if node not in placement_strategies:
197                placement_strategies[node] = _create_placement_strategy(
198                    node, mesh, placements=(Replicate(),)
199                )
200            node.meta["sharding"] = placement_strategies[node]
201        elif node.op == "call_function":
202            if node.target == operator.getitem:
203                input_nodes = node.all_input_nodes
204                assert (
205                    len(input_nodes) == 1
206                ), f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
207                arg_strategy = placement_strategies[input_nodes[0]]
208                placement_strategies[node] = _create_placement_strategy(
209                    node,
210                    mesh,
211                    placements=arg_strategy.output_spec.placements,
212                    input_specs=_get_input_node_specs(node, placement_strategies),
213                )
214                node.meta["sharding"] = placement_strategies[node]
215            else:
216                op_schema = _get_op_schema(node, placement_strategies)
217
218                # get DTensor specs for inputs and outputs
219                if (
220                    op_schema.op
221                    not in DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs
222                    and op_schema.op
223                    not in DTensor._op_dispatcher.sharding_propagator.op_to_rules
224                ):
225                    # Mark all as replicated
226                    output_sharding = _generate_default_output_sharding(
227                        node,
228                        mesh,
229                        op_schema,
230                    )
231                else:
232                    output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding(
233                        op_schema,
234                    )
235                placement_strategies[node] = PlacementStrategy(
236                    output_specs=_get_output_spec_from_output_sharding(output_sharding),
237                    input_specs=output_sharding.redistribute_schema.args_spec
238                    if output_sharding.redistribute_schema is not None
239                    else _get_input_node_specs(node, placement_strategies),
240                )
241                node.meta["sharding"] = placement_strategies[node]
242        elif node.op == "output":
243            node.meta["sharding"] = None
244        else:
245            raise RuntimeError(f"op code {node.op} not supported")
246    return placement_strategies
247
248
249def _get_output_spec_from_output_sharding(
250    output_sharding: OutputSharding,
251) -> DTensorSpec:
252    """
253    Util function to extract output spec from output sharding.
254    """
255    if isinstance(output_sharding.output_spec, DTensorSpec):
256        return output_sharding.output_spec
257    else:
258        # For ops that return multiple outputs, the outputs should have the same output spec
259        assert isinstance(output_sharding.output_spec, Sequence)
260        assert output_sharding.output_spec[0] is not None
261        output_sharding.output_spec[0].tensor_meta = None
262        return output_sharding.output_spec[0]
263
264
265def _create_placement_strategy(
266    node: Node,
267    mesh: DeviceMesh,
268    placements: Tuple[Placement, ...],
269    input_specs: Optional[Sequence[DTensorSpec]] = None,
270) -> PlacementStrategy:
271    """
272    Util function to construct a placement strategy for a given node.
273    """
274    placement = PlacementStrategy(
275        input_specs=input_specs,
276        output_specs=DTensorSpec(
277            mesh=mesh,
278            placements=placements,
279        ),
280    )
281    _populate_tensor_meta(node, placement.output_specs)
282    return placement
283
284
285def _populate_tensor_meta(node: Node, output_spec: OutputSpecType) -> None:
286    """
287    Util function to populate tensor meta of output_spec based on node metadata.
288    """
289    if isinstance(node.meta["val"], Sequence):
290        assert isinstance(output_spec, Sequence)
291        for spec, fake_tensor in zip(output_spec, node.meta["val"]):
292            assert spec is not None
293            spec.tensor_meta = TensorMeta(
294                shape=fake_tensor.shape,
295                stride=fake_tensor.stride(),
296                dtype=fake_tensor.dtype,
297            )
298    else:
299        assert isinstance(output_spec, DTensorSpec)
300        output_spec.tensor_meta = TensorMeta(
301            shape=node.meta["val"].shape,
302            stride=node.meta["val"].stride(),
303            dtype=node.meta["val"].dtype,
304        )
305
306
307def _generate_default_output_sharding(
308    node: Node,
309    mesh: DeviceMesh,
310    op_schema: OpSchema,
311) -> OutputSharding:
312    """
313    Util function to create a default output sharding that suggests Replicate placement for both args and outputs.
314    """
315
316    def update_arg_spec(arg_spec: DTensorSpec) -> DTensorSpec:
317        return DTensorSpec(
318            mesh=arg_spec.mesh,
319            placements=(Replicate(),),
320            tensor_meta=arg_spec.tensor_meta,
321        )
322
323    new_op_schema = OpSchema(
324        op=op_schema.op,
325        args_schema=pytree.tree_map_only(
326            DTensorSpec, update_arg_spec, op_schema.args_schema
327        ),
328        kwargs_schema=op_schema.kwargs_schema,
329    )
330
331    def create_output_spec(tensor: FakeTensor) -> DTensorSpec:
332        return DTensorSpec(
333            mesh=mesh,
334            placements=(Replicate(),),
335            tensor_meta=TensorMeta(
336                shape=tensor.shape,
337                stride=tensor.stride(),
338                dtype=tensor.dtype,
339            ),
340        )
341
342    return OutputSharding(
343        output_spec=pytree.tree_map_only(
344            FakeTensor, create_output_spec, node.meta["val"]
345        ),
346        redistribute_schema=new_op_schema,
347        needs_redistribute=True,
348    )
349
350
351def _partitioner(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
352    """
353    Graph partitioner that partitions the single device graph
354    to distributed graph
355    """
356    for node in gm.graph.nodes:
357        node_sharding = node.meta["sharding"]
358        if node.op == "placeholder":
359            out_spec = node_sharding.output_spec
360            local_val = _partition_val(node.meta["val"], out_spec)
361            # update node value
362            node.meta["val"] = local_val
363        elif node.op == "call_function":
364            out_spec = node_sharding.output_spec
365            # check if there's misaligned sharding, insert reshard if there is
366            expected_input_specs = node_sharding.input_specs
367            for idx, input_arg in enumerate(node.all_input_nodes):
368                input_arg_sharding = input_arg.meta["sharding"]
369                input_arg_spec = input_arg_sharding.output_spec
370                desired_spec = (
371                    out_spec
372                    if expected_input_specs is None
373                    else expected_input_specs[idx]
374                )
375                if input_arg_spec != desired_spec:
376                    _insert_reshard_gm(
377                        gm, node, input_arg, input_arg_spec, desired_spec
378                    )
379            # convert output val to its local component
380            output_val = node.meta["val"]
381            node.meta["val"] = _partition_val(output_val, out_spec)
382        elif node.op == "output":
383            for input_arg in node.all_input_nodes:
384                # input args of output should be Replicate, otherwise redistribution is needed.
385                input_args_to_check: Sequence[Node] = (
386                    input_arg if isinstance(input_arg, Sequence) else [input_arg]
387                )
388                for arg in input_args_to_check:
389                    arg_sharding = arg.meta["sharding"]
390                    arg_spec = arg_sharding.output_spec
391                    desired_spec = copy.copy(arg_spec)
392                    desired_spec.placements = (Replicate(),)
393                    if arg_spec != desired_spec:
394                        _insert_reshard_gm(gm, node, arg, arg_spec, desired_spec)
395        else:
396            raise RuntimeError(f"op code {node} not supported")
397
398    _clean_up_graph_metadata(gm)
399    gm.graph.lint()
400    gm.recompile()
401    return gm
402
403
404def _partition_val(val: Any, spec: DTensorSpec) -> Any:
405    """
406    util function to convert a full tensor val to its local component
407    """
408    if isinstance(val, torch.Tensor):
409        local_shard = val
410        if val.ndim == 0:
411            # If it's already a scalar tensor, it is already local, we don't
412            # need to do anything
413            return local_shard
414
415        for idx, placement in enumerate(spec.placements):
416            if placement.is_shard():
417                placement = cast(Shard, placement)
418                num_chunks = spec.mesh.size(mesh_dim=idx)
419                my_coord = spec.mesh.get_coordinate()
420                assert my_coord is not None, "current rank not in mesh!"
421                my_coord_on_mesh_dim = my_coord[idx]
422                local_shard = placement._split_tensor(
423                    local_shard, num_chunks, with_padding=False, contiguous=True
424                )[0][my_coord_on_mesh_dim]
425        return local_shard
426    elif isinstance(val, (list, tuple)):
427        return val.__class__(_partition_val(v, spec) for v in val)
428    else:
429        raise RuntimeError(f"val type {type(val)} not supported")
430
431
432def _insert_reshard_gm(
433    gm: torch.fx.GraphModule,
434    node: Node,
435    input_arg: Node,
436    input_arg_spec: DTensorSpec,
437    desired_spec: DTensorSpec,
438) -> None:
439    """
440    Transform the graph for tensor redistribution.
441    """
442    input_arg_spec.tensor_meta = input_arg.meta["tensor_meta"]
443    desired_spec.tensor_meta = input_arg.meta["tensor_meta"]
444    input_arg_tensor = input_arg.meta["val"]
445
446    # insert reshard operation
447    def reshard_fn(local_tensor: torch.Tensor) -> torch.Tensor:
448        return redistribute_local_tensor(
449            local_tensor,
450            input_arg_spec,
451            desired_spec,
452        )
453
454    reshard_gm = make_fx(reshard_fn)(input_arg_tensor)
455    reshard_gm_nodes = list(reshard_gm.graph.nodes)
456    input_node = reshard_gm_nodes[0]
457    with gm.graph.inserting_before(node):
458        # copy nn_module_stack metadata for output, all-reduce nodes
459        for reshard_node in reshard_gm.graph.nodes:
460            if reshard_node.op not in ["placeholder", "output"]:
461                reshard_node.meta["nn_module_stack"] = (
462                    copy.copy(input_arg.meta["nn_module_stack"])
463                    if not input_arg.op == "placeholder"
464                    else copy.copy(node.meta["nn_module_stack"])
465                )
466        output_node = gm.graph.graph_copy(
467            reshard_gm.graph,
468            val_map={
469                input_node: input_arg,
470            },
471        )
472    node.replace_input_with(input_arg, output_node)  # type: ignore[arg-type]
473
474
475def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None:
476    """
477    Clean up the graph by removing sharding and partitioning related metadata
478    """
479    for node in gm.graph.nodes:
480        if "sharding" in node.meta:
481            del node.meta["sharding"]
482        if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
483            local_tensor_meta = _extract_tensor_metadata(node.meta["val"])
484            node.meta["tensor_meta"] = local_tensor_meta
485
486
487def _get_input_node_specs(
488    node: Node, placement_strategies: Dict[Node, PlacementStrategy]
489) -> Tuple[DTensorSpec, ...]:
490    """
491    Get the input specs of a node.
492    """
493    input_specs_list: List[DTensorSpec] = []
494    for input_arg in node.all_input_nodes:
495        if input_arg in placement_strategies:
496            output_spec = placement_strategies[input_arg].output_specs
497            assert isinstance(output_spec, DTensorSpec)
498            input_specs_list.append(output_spec)
499        else:
500            raise ValueError(f"{input_arg} does not have output_spec populated.")
501    return tuple(input_specs_list)
502
503
504def _get_op_schema(
505    node: Node, placement_strategies: Dict[Node, PlacementStrategy]
506) -> OpSchema:
507    """
508    Util function to construct the operator schema of a node.
509    """
510    args_schema_list = pytree.tree_map_only(
511        Node, lambda arg: placement_strategies[arg].output_specs, node.args
512    )
513    op_schema = OpSchema(
514        op=cast(torch._ops.OpOverload, node.target),
515        args_schema=tuple(args_schema_list),
516        kwargs_schema=cast(Dict[str, object], node.kwargs),
517    )
518    return op_schema
519
520
521def _shard_state_dict(
522    state_dict: Dict[str, torch.Tensor],
523    placement_strategies: Dict[Node, PlacementStrategy],
524    graph_signature: ExportGraphSignature,
525    mesh: DeviceMesh,
526) -> None:
527    """
528    Inplace partition the weights based on the placement strategy
529    """
530    for node, placement_strategy in placement_strategies.items():
531        if node.op != "placeholder":
532            continue
533        if node.name in graph_signature.inputs_to_parameters:
534            fqn = graph_signature.inputs_to_parameters[node.name]
535        elif node.name in graph_signature.inputs_to_buffers:
536            fqn = graph_signature.inputs_to_buffers[node.name]
537        else:
538            continue
539        assert fqn in state_dict, f"{fqn} not found in state dict: {state_dict.keys()}"
540
541        original_param = state_dict[fqn]
542        dtensor_param = distribute_tensor(
543            original_param,
544            mesh,
545            placement_strategy.output_spec.placements,
546        )
547        local_param = dtensor_param.to_local()
548        state_dict[fqn] = (
549            torch.nn.Parameter(local_param)
550            if isinstance(original_param, torch.nn.Parameter)
551            else local_param
552        )
553