1# mypy: allow-untyped-defs 2import copy 3import operator 4from typing import Any, cast, Dict, List, Optional, Sequence, Tuple 5 6import torch 7from torch._subclasses.fake_tensor import FakeTensor 8from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor 9from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta 10from torch.distributed.tensor._op_schema import ( 11 OpSchema, 12 OutputSharding, 13 OutputSpecType, 14 PlacementStrategy, 15) 16from torch.distributed.tensor._redistribute import redistribute_local_tensor 17from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle 18from torch.distributed.tensor.placement_types import Placement, Replicate, Shard 19from torch.export import ExportedProgram 20from torch.export.exported_program import ExportGraphSignature 21from torch.fx import GraphModule 22from torch.fx.experimental.proxy_tensor import make_fx 23from torch.fx.node import Node 24from torch.fx.passes.infra.pass_base import PassBase, PassResult 25from torch.fx.passes.shape_prop import _extract_tensor_metadata 26from torch.utils import _pytree as pytree 27 28 29__all__ = ["tensor_parallel_transformation"] 30 31aten = torch.ops.aten 32 33 34def tensor_parallel_transformation( 35 exported_program: ExportedProgram, 36 rank: int, 37 world_size: int, 38 device_type: str, 39 parallel_strategies: Dict[str, ParallelStyle], 40) -> ExportedProgram: 41 """ 42 The entry point function to perform graph transformations on an exported program 43 to transform a single-device graph into a tensor parallel graph. 44 45 .. warning:: 46 This API is experimental and subject to change. 47 """ 48 49 gm = exported_program.graph_module 50 sig = copy.deepcopy(exported_program.graph_signature) 51 state_dict = copy.copy(exported_program.state_dict) 52 53 with gm._set_replace_hook(sig.get_replace_hook()): 54 res = _TensorParallelTransformPass( 55 rank, 56 world_size, 57 device_type, 58 state_dict, 59 exported_program.graph_signature, 60 parallel_strategies, 61 )(gm) 62 assert res is not None 63 gm = res.graph_module 64 65 return exported_program._update(gm, sig, state_dict=state_dict) 66 67 68class _TensorParallelTransformPass(PassBase): 69 """ 70 This pass is responsible for transforming a single-device graph into a tensor parallel 71 graph. It will mark the placement strategy of each node in the graph, 72 partition the graph into distributed graph, then shard the parameters/buffers accordingly. 73 """ 74 75 def __init__( 76 self, 77 rank: int, 78 world_size: int, 79 device_type: str, 80 state_dict: Dict[str, torch.Tensor], 81 graph_signature: ExportGraphSignature, 82 parallel_strategies: Dict[str, ParallelStyle], 83 ) -> None: 84 super().__init__() 85 self.rank = rank 86 self.mesh = DeviceMesh(device_type, torch.arange(world_size)) 87 self.state_dict: Dict[str, torch.Tensor] = state_dict 88 self.graph_signature = graph_signature 89 self.parallel_strategies = parallel_strategies 90 91 def call(self, graph_module) -> PassResult: 92 gm = copy.deepcopy(graph_module) 93 94 parameter_placements = _generate_parameter_and_buffer_placements( 95 list(self.state_dict.keys()), self.parallel_strategies 96 ) 97 placement_strategies = _mark_sharding( 98 gm, self.graph_signature, self.mesh, parameter_placements 99 ) 100 _partitioner(gm) 101 _shard_state_dict( 102 self.state_dict, placement_strategies, self.graph_signature, self.mesh 103 ) 104 return PassResult(gm, True) 105 106 107def _generate_parameter_and_buffer_placements( 108 params_and_buffers: List[str], 109 parallel_strategies: Dict[str, ParallelStyle], 110) -> Dict[str, Placement]: 111 """ 112 Build parameter placements based on the give parallel style of linear layers. 113 """ 114 parameter_placements: Dict[str, Placement] = {} 115 for linear_fqn, parallel_style in parallel_strategies.items(): 116 weight_fqn = f"{linear_fqn}.weight" 117 bias_fqn = f"{linear_fqn}.bias" 118 assert weight_fqn in params_and_buffers 119 parameter_placements[weight_fqn] = ( 120 Shard(0) if parallel_style == ColwiseParallel else Shard(1) 121 ) 122 if bias_fqn in params_and_buffers: 123 parameter_placements[bias_fqn] = ( 124 Shard(0) if parallel_style == ColwiseParallel else Replicate() 125 ) 126 return parameter_placements 127 128 129def _mark_tensor_parallel_shardings( 130 gm: GraphModule, 131 graph_signature: ExportGraphSignature, 132 mesh: DeviceMesh, 133 parameter_placements: Dict[str, Placement], 134) -> Dict[Node, PlacementStrategy]: 135 """ 136 Mark the placement strategies of the parameter and buffer placeholder nodes. 137 """ 138 placement_strategies: Dict[Node, PlacementStrategy] = {} 139 num_params_and_buffers = len(graph_signature.inputs_to_parameters) + len( 140 graph_signature.inputs_to_buffers 141 ) 142 placeholder_idx: int = 0 143 for node in gm.graph.nodes: 144 if node.op == "placeholder": 145 if placeholder_idx < num_params_and_buffers: 146 fqn: str = _get_input_node_fqn(node.name, graph_signature) 147 placement: Placement = ( 148 parameter_placements[fqn] 149 if fqn in parameter_placements 150 else Replicate() 151 ) 152 placement_strategies[node] = _create_placement_strategy( 153 node, 154 mesh, 155 placements=(placement,), 156 ) 157 placeholder_idx += 1 158 else: 159 placement_strategies[node] = _create_placement_strategy( 160 node, 161 mesh, 162 placements=(Replicate(),), 163 ) 164 return placement_strategies 165 166 167def _get_input_node_fqn(input_name: str, graph_signature: ExportGraphSignature) -> str: 168 """ 169 Return the FQN of an input node. 170 """ 171 if input_name in graph_signature.inputs_to_parameters: 172 return graph_signature.inputs_to_parameters[input_name] 173 elif input_name in graph_signature.inputs_to_buffers: 174 return graph_signature.inputs_to_buffers[input_name] 175 else: 176 raise ValueError( 177 f"{input_name} not found in inputs_to_parameters or inputs_to_buffers" 178 ) 179 180 181def _mark_sharding( 182 gm: GraphModule, 183 graph_signature: ExportGraphSignature, 184 mesh: DeviceMesh, 185 parameter_placements: Dict[str, Placement], 186) -> Dict[Node, PlacementStrategy]: 187 """ 188 Mark the sharding strategy for each node in the graph module. 189 """ 190 placement_strategies: Dict[ 191 Node, PlacementStrategy 192 ] = _mark_tensor_parallel_shardings(gm, graph_signature, mesh, parameter_placements) 193 194 for node in gm.graph.nodes: 195 if node.op == "placeholder": 196 if node not in placement_strategies: 197 placement_strategies[node] = _create_placement_strategy( 198 node, mesh, placements=(Replicate(),) 199 ) 200 node.meta["sharding"] = placement_strategies[node] 201 elif node.op == "call_function": 202 if node.target == operator.getitem: 203 input_nodes = node.all_input_nodes 204 assert ( 205 len(input_nodes) == 1 206 ), f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}" 207 arg_strategy = placement_strategies[input_nodes[0]] 208 placement_strategies[node] = _create_placement_strategy( 209 node, 210 mesh, 211 placements=arg_strategy.output_spec.placements, 212 input_specs=_get_input_node_specs(node, placement_strategies), 213 ) 214 node.meta["sharding"] = placement_strategies[node] 215 else: 216 op_schema = _get_op_schema(node, placement_strategies) 217 218 # get DTensor specs for inputs and outputs 219 if ( 220 op_schema.op 221 not in DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs 222 and op_schema.op 223 not in DTensor._op_dispatcher.sharding_propagator.op_to_rules 224 ): 225 # Mark all as replicated 226 output_sharding = _generate_default_output_sharding( 227 node, 228 mesh, 229 op_schema, 230 ) 231 else: 232 output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding( 233 op_schema, 234 ) 235 placement_strategies[node] = PlacementStrategy( 236 output_specs=_get_output_spec_from_output_sharding(output_sharding), 237 input_specs=output_sharding.redistribute_schema.args_spec 238 if output_sharding.redistribute_schema is not None 239 else _get_input_node_specs(node, placement_strategies), 240 ) 241 node.meta["sharding"] = placement_strategies[node] 242 elif node.op == "output": 243 node.meta["sharding"] = None 244 else: 245 raise RuntimeError(f"op code {node.op} not supported") 246 return placement_strategies 247 248 249def _get_output_spec_from_output_sharding( 250 output_sharding: OutputSharding, 251) -> DTensorSpec: 252 """ 253 Util function to extract output spec from output sharding. 254 """ 255 if isinstance(output_sharding.output_spec, DTensorSpec): 256 return output_sharding.output_spec 257 else: 258 # For ops that return multiple outputs, the outputs should have the same output spec 259 assert isinstance(output_sharding.output_spec, Sequence) 260 assert output_sharding.output_spec[0] is not None 261 output_sharding.output_spec[0].tensor_meta = None 262 return output_sharding.output_spec[0] 263 264 265def _create_placement_strategy( 266 node: Node, 267 mesh: DeviceMesh, 268 placements: Tuple[Placement, ...], 269 input_specs: Optional[Sequence[DTensorSpec]] = None, 270) -> PlacementStrategy: 271 """ 272 Util function to construct a placement strategy for a given node. 273 """ 274 placement = PlacementStrategy( 275 input_specs=input_specs, 276 output_specs=DTensorSpec( 277 mesh=mesh, 278 placements=placements, 279 ), 280 ) 281 _populate_tensor_meta(node, placement.output_specs) 282 return placement 283 284 285def _populate_tensor_meta(node: Node, output_spec: OutputSpecType) -> None: 286 """ 287 Util function to populate tensor meta of output_spec based on node metadata. 288 """ 289 if isinstance(node.meta["val"], Sequence): 290 assert isinstance(output_spec, Sequence) 291 for spec, fake_tensor in zip(output_spec, node.meta["val"]): 292 assert spec is not None 293 spec.tensor_meta = TensorMeta( 294 shape=fake_tensor.shape, 295 stride=fake_tensor.stride(), 296 dtype=fake_tensor.dtype, 297 ) 298 else: 299 assert isinstance(output_spec, DTensorSpec) 300 output_spec.tensor_meta = TensorMeta( 301 shape=node.meta["val"].shape, 302 stride=node.meta["val"].stride(), 303 dtype=node.meta["val"].dtype, 304 ) 305 306 307def _generate_default_output_sharding( 308 node: Node, 309 mesh: DeviceMesh, 310 op_schema: OpSchema, 311) -> OutputSharding: 312 """ 313 Util function to create a default output sharding that suggests Replicate placement for both args and outputs. 314 """ 315 316 def update_arg_spec(arg_spec: DTensorSpec) -> DTensorSpec: 317 return DTensorSpec( 318 mesh=arg_spec.mesh, 319 placements=(Replicate(),), 320 tensor_meta=arg_spec.tensor_meta, 321 ) 322 323 new_op_schema = OpSchema( 324 op=op_schema.op, 325 args_schema=pytree.tree_map_only( 326 DTensorSpec, update_arg_spec, op_schema.args_schema 327 ), 328 kwargs_schema=op_schema.kwargs_schema, 329 ) 330 331 def create_output_spec(tensor: FakeTensor) -> DTensorSpec: 332 return DTensorSpec( 333 mesh=mesh, 334 placements=(Replicate(),), 335 tensor_meta=TensorMeta( 336 shape=tensor.shape, 337 stride=tensor.stride(), 338 dtype=tensor.dtype, 339 ), 340 ) 341 342 return OutputSharding( 343 output_spec=pytree.tree_map_only( 344 FakeTensor, create_output_spec, node.meta["val"] 345 ), 346 redistribute_schema=new_op_schema, 347 needs_redistribute=True, 348 ) 349 350 351def _partitioner(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: 352 """ 353 Graph partitioner that partitions the single device graph 354 to distributed graph 355 """ 356 for node in gm.graph.nodes: 357 node_sharding = node.meta["sharding"] 358 if node.op == "placeholder": 359 out_spec = node_sharding.output_spec 360 local_val = _partition_val(node.meta["val"], out_spec) 361 # update node value 362 node.meta["val"] = local_val 363 elif node.op == "call_function": 364 out_spec = node_sharding.output_spec 365 # check if there's misaligned sharding, insert reshard if there is 366 expected_input_specs = node_sharding.input_specs 367 for idx, input_arg in enumerate(node.all_input_nodes): 368 input_arg_sharding = input_arg.meta["sharding"] 369 input_arg_spec = input_arg_sharding.output_spec 370 desired_spec = ( 371 out_spec 372 if expected_input_specs is None 373 else expected_input_specs[idx] 374 ) 375 if input_arg_spec != desired_spec: 376 _insert_reshard_gm( 377 gm, node, input_arg, input_arg_spec, desired_spec 378 ) 379 # convert output val to its local component 380 output_val = node.meta["val"] 381 node.meta["val"] = _partition_val(output_val, out_spec) 382 elif node.op == "output": 383 for input_arg in node.all_input_nodes: 384 # input args of output should be Replicate, otherwise redistribution is needed. 385 input_args_to_check: Sequence[Node] = ( 386 input_arg if isinstance(input_arg, Sequence) else [input_arg] 387 ) 388 for arg in input_args_to_check: 389 arg_sharding = arg.meta["sharding"] 390 arg_spec = arg_sharding.output_spec 391 desired_spec = copy.copy(arg_spec) 392 desired_spec.placements = (Replicate(),) 393 if arg_spec != desired_spec: 394 _insert_reshard_gm(gm, node, arg, arg_spec, desired_spec) 395 else: 396 raise RuntimeError(f"op code {node} not supported") 397 398 _clean_up_graph_metadata(gm) 399 gm.graph.lint() 400 gm.recompile() 401 return gm 402 403 404def _partition_val(val: Any, spec: DTensorSpec) -> Any: 405 """ 406 util function to convert a full tensor val to its local component 407 """ 408 if isinstance(val, torch.Tensor): 409 local_shard = val 410 if val.ndim == 0: 411 # If it's already a scalar tensor, it is already local, we don't 412 # need to do anything 413 return local_shard 414 415 for idx, placement in enumerate(spec.placements): 416 if placement.is_shard(): 417 placement = cast(Shard, placement) 418 num_chunks = spec.mesh.size(mesh_dim=idx) 419 my_coord = spec.mesh.get_coordinate() 420 assert my_coord is not None, "current rank not in mesh!" 421 my_coord_on_mesh_dim = my_coord[idx] 422 local_shard = placement._split_tensor( 423 local_shard, num_chunks, with_padding=False, contiguous=True 424 )[0][my_coord_on_mesh_dim] 425 return local_shard 426 elif isinstance(val, (list, tuple)): 427 return val.__class__(_partition_val(v, spec) for v in val) 428 else: 429 raise RuntimeError(f"val type {type(val)} not supported") 430 431 432def _insert_reshard_gm( 433 gm: torch.fx.GraphModule, 434 node: Node, 435 input_arg: Node, 436 input_arg_spec: DTensorSpec, 437 desired_spec: DTensorSpec, 438) -> None: 439 """ 440 Transform the graph for tensor redistribution. 441 """ 442 input_arg_spec.tensor_meta = input_arg.meta["tensor_meta"] 443 desired_spec.tensor_meta = input_arg.meta["tensor_meta"] 444 input_arg_tensor = input_arg.meta["val"] 445 446 # insert reshard operation 447 def reshard_fn(local_tensor: torch.Tensor) -> torch.Tensor: 448 return redistribute_local_tensor( 449 local_tensor, 450 input_arg_spec, 451 desired_spec, 452 ) 453 454 reshard_gm = make_fx(reshard_fn)(input_arg_tensor) 455 reshard_gm_nodes = list(reshard_gm.graph.nodes) 456 input_node = reshard_gm_nodes[0] 457 with gm.graph.inserting_before(node): 458 # copy nn_module_stack metadata for output, all-reduce nodes 459 for reshard_node in reshard_gm.graph.nodes: 460 if reshard_node.op not in ["placeholder", "output"]: 461 reshard_node.meta["nn_module_stack"] = ( 462 copy.copy(input_arg.meta["nn_module_stack"]) 463 if not input_arg.op == "placeholder" 464 else copy.copy(node.meta["nn_module_stack"]) 465 ) 466 output_node = gm.graph.graph_copy( 467 reshard_gm.graph, 468 val_map={ 469 input_node: input_arg, 470 }, 471 ) 472 node.replace_input_with(input_arg, output_node) # type: ignore[arg-type] 473 474 475def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None: 476 """ 477 Clean up the graph by removing sharding and partitioning related metadata 478 """ 479 for node in gm.graph.nodes: 480 if "sharding" in node.meta: 481 del node.meta["sharding"] 482 if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor): 483 local_tensor_meta = _extract_tensor_metadata(node.meta["val"]) 484 node.meta["tensor_meta"] = local_tensor_meta 485 486 487def _get_input_node_specs( 488 node: Node, placement_strategies: Dict[Node, PlacementStrategy] 489) -> Tuple[DTensorSpec, ...]: 490 """ 491 Get the input specs of a node. 492 """ 493 input_specs_list: List[DTensorSpec] = [] 494 for input_arg in node.all_input_nodes: 495 if input_arg in placement_strategies: 496 output_spec = placement_strategies[input_arg].output_specs 497 assert isinstance(output_spec, DTensorSpec) 498 input_specs_list.append(output_spec) 499 else: 500 raise ValueError(f"{input_arg} does not have output_spec populated.") 501 return tuple(input_specs_list) 502 503 504def _get_op_schema( 505 node: Node, placement_strategies: Dict[Node, PlacementStrategy] 506) -> OpSchema: 507 """ 508 Util function to construct the operator schema of a node. 509 """ 510 args_schema_list = pytree.tree_map_only( 511 Node, lambda arg: placement_strategies[arg].output_specs, node.args 512 ) 513 op_schema = OpSchema( 514 op=cast(torch._ops.OpOverload, node.target), 515 args_schema=tuple(args_schema_list), 516 kwargs_schema=cast(Dict[str, object], node.kwargs), 517 ) 518 return op_schema 519 520 521def _shard_state_dict( 522 state_dict: Dict[str, torch.Tensor], 523 placement_strategies: Dict[Node, PlacementStrategy], 524 graph_signature: ExportGraphSignature, 525 mesh: DeviceMesh, 526) -> None: 527 """ 528 Inplace partition the weights based on the placement strategy 529 """ 530 for node, placement_strategy in placement_strategies.items(): 531 if node.op != "placeholder": 532 continue 533 if node.name in graph_signature.inputs_to_parameters: 534 fqn = graph_signature.inputs_to_parameters[node.name] 535 elif node.name in graph_signature.inputs_to_buffers: 536 fqn = graph_signature.inputs_to_buffers[node.name] 537 else: 538 continue 539 assert fqn in state_dict, f"{fqn} not found in state dict: {state_dict.keys()}" 540 541 original_param = state_dict[fqn] 542 dtensor_param = distribute_tensor( 543 original_param, 544 mesh, 545 placement_strategy.output_spec.placements, 546 ) 547 local_param = dtensor_param.to_local() 548 state_dict[fqn] = ( 549 torch.nn.Parameter(local_param) 550 if isinstance(original_param, torch.nn.Parameter) 551 else local_param 552 ) 553