1# mypy: allow-untyped-defs 2from dataclasses import dataclass 3from functools import cached_property 4from typing import Any, Dict, List, Optional, Sequence, Tuple, Union 5 6import torch 7from torch._ops import OpOverload 8from torch.distributed.device_mesh import DeviceMesh 9from torch.distributed.tensor._dtensor_spec import DTensorSpec 10from torch.distributed.tensor.placement_types import Placement 11 12 13try: 14 from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec 15except ImportError: 16 from torch.utils._pytree import ( # type: ignore[no-redef, assignment] 17 tree_leaves, 18 tree_map_only, 19 TreeSpec, 20 ) 21 22 23# Common type aliases 24ArgsType = Tuple[object, ...] 25KwargsType = Dict[str, object] 26 27PlacementList = List[Optional[Placement]] 28 29# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type sould 30# be the same set of possibilities. 31OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]] 32 33 34def _rebuild_tensor_from_dtensor_meta(arg) -> object: 35 """ 36 This is used to propagate tensor metadata, must be under fake mode 37 """ 38 assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta." 39 return torch.empty_strided( 40 arg.tensor_meta.shape, 41 arg.tensor_meta.stride, 42 dtype=arg.tensor_meta.dtype, 43 ) 44 45 46def _is_inplace_op(op: OpOverload): 47 # simple analysis of function schema to determine 48 # if this is an inplace variant, it might not 49 # be entirely correct, but it's good enough for now. 50 return op._schema.name[-1] == "_" 51 52 53def _is_out_variant_op(op: OpOverload): 54 # simple analysis of function schema to determine 55 # if this is an out variant, it might not 56 # be entirely correct, but it's good enough for now. 57 return "out" in op._schema.overload_name 58 59 60def _pretty_print_spec(spec: object) -> str: 61 if spec is None: 62 return "None" 63 elif isinstance(spec, DTensorSpec): 64 return "".join([str(p) for p in spec.placements]) 65 elif isinstance(spec, Sequence): 66 return "(" + ", ".join([_pretty_print_spec(s) for s in spec]) + ")" 67 else: 68 raise RuntimeError(f"Unknown spec type to print: spec={spec}") 69 70 71@dataclass 72class PlacementStrategy: 73 """ 74 A placement strategy describes acceptable sharding placements of the output 75 and the tensor arguments of an operation. 76 77 note: when the op return value is a single DTensor object, output_specs is 78 DTensorSpec; when the return value is a tuple of Optional[DTensor], 79 output_specs is a tuple of Optional[DTensorSpec]. 80 """ 81 82 output_specs: Union[DTensorSpec, Tuple[Optional[DTensorSpec], ...]] 83 input_specs: Optional[Sequence[DTensorSpec]] = None 84 85 # redistribute costs for this op placement strategy 86 # we need a nested list to record the cost for each 87 # operand of this operator, and for each operand of 88 # this operator it might have multiple placement strategies 89 redistribute_cost: Optional[List[List[float]]] = None 90 91 @cached_property 92 def output_spec(self) -> DTensorSpec: 93 """ 94 This function requires that the strategy have exactly one DTensorSpec as the 95 output spec. If the output_specs is a tuple, we throw an exception. 96 """ 97 if isinstance(self.output_specs, DTensorSpec): 98 return self.output_specs 99 else: 100 raise ValueError( 101 f"function output_spec expects a single DTensorSpec but got: {self.output_specs}" 102 ) 103 104 def input_spec(self, index: int = 0) -> DTensorSpec: 105 assert self.input_specs is not None, "input_specs of PlacementStrategy is None!" 106 assert len(self.input_specs) > index, ( 107 f"Invalid index {index} for input_specs of length " 108 f"{len(self.input_specs)}: {self.input_specs}" 109 ) 110 return self.input_specs[index] 111 112 def __str__(self) -> str: 113 if self.input_specs is not None: 114 input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> " 115 else: 116 input_specs_str = "" 117 output_spec_str = _pretty_print_spec(self.output_specs) 118 return f"{input_specs_str}{output_spec_str}" 119 120 121class StrategyType: 122 """ 123 Base class type for op strategy, We have two StrategyType: 124 OpStrategy and TupleStrategy 125 """ 126 127 128class OpStrategy(StrategyType): 129 """ 130 OpStrategy that consists of a list of placement strategies associated with the op 131 """ 132 133 def __init__(self, strategies: List[PlacementStrategy]) -> None: 134 super().__init__() 135 self.strategies: List[PlacementStrategy] = strategies 136 137 def __str__(self) -> str: 138 strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies]) 139 mesh_shape = self.mesh_shape 140 return f"[{strategy_list_str}] @ mesh: {mesh_shape}" 141 142 def max_num_shards(self) -> int: 143 """ 144 Returns the max number of shards across all placement strategies 145 """ 146 return max(strategy.output_spec.num_shards for strategy in self.strategies) 147 148 @property 149 def mesh_shape(self): 150 output_spec = self.strategies[0].output_specs 151 if isinstance(output_spec, DTensorSpec): 152 return output_spec.mesh.shape 153 else: 154 assert isinstance( 155 output_spec, tuple 156 ), "found no DTensorSpec in the OpStrategy!" 157 assert output_spec[0] is not None 158 return output_spec[0].mesh.shape 159 160 @property 161 def ndim(self): 162 return self.strategies[0].output_spec.ndim 163 164 @property 165 def shape(self): 166 return self.strategies[0].output_spec.shape 167 168 169class TupleStrategy(StrategyType): 170 """ 171 TupleStrategy represents the output strategy of this op is a tuple 172 of strategy, i.e. If the output of this op is a tuple of tensors or list of tensors 173 with possibly different placement strategies, we should return a TupleStrategy that 174 contains a tuple of OpStrategy, where each child represents the sharding strategy 175 of "each element" of the tuple/list of tensors the op returns. 176 177 NOTE: if the output of the op is a List[Tensor] and they share the same placement 178 strategy, then we should return a single OpStrategy instead of a TupleStrategy 179 """ 180 181 def __init__(self, childs: Sequence[StrategyType]) -> None: 182 super().__init__() 183 self.childs: Sequence[StrategyType] = childs 184 185 def __str__(self) -> str: 186 child_strategies_str = ", ".join( 187 [f"{str(strat)}" for idx, strat in enumerate(self.childs)] 188 ) 189 return f"TupleStrategy({child_strategies_str})" 190 191 192@dataclass 193class RuntimeSchemaInfo: 194 """ 195 RuntimeSchemaInfo stores the operator schema related information for runtime (eager) 196 execution. This is mainly used for two ways: 1. to generate hash for args to determine 197 whether to re-run sharding prop or not 2. to determine if we need pytree 198 """ 199 200 # This static_argnum records static arg "starting index" for ops that have non-tensor 201 # args/kwargs which would affect sharding propagation results. All args starting from 202 # this index would be hashed to our sharding cache. 203 # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc. 204 static_argnum: int = 100 205 # This static_kwargkey records static kwarg names which would affect sharding prop 206 static_kwargkey: Optional[List[str]] = None 207 # each op can decide if it wants to use pytree flatten/unflatten during operator 208 # eager execution, by default we don't need to do flatten/unflatten, only if the 209 # op indicate it needs to, this is to accelerate eager performance. 210 needs_pytree: bool = False 211 212 213@dataclass 214class OpSchema: 215 """ 216 OpSchema is a data class that describes an operator input schemas, it includes 217 DTensorSpecs (instead of DTensor) and non-tensor args/kwargs (positional order 218 preserved). It is mainly used by the DTensor's dispatching logic to perform various 219 actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.) 220 221 NOTE: this should be used as a read only data class 222 TODO: make this a frozen dataclass 223 224 Args: 225 op: the operator overload we are intercepting 226 args_schema: contains args except that the DTensor args have been replaced 227 with its DTensorSpec or OpStrategy 228 kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced 229 with its DTensorSpec or OpStrategy 230 """ 231 232 op: OpOverload 233 args_schema: ArgsType 234 kwargs_schema: KwargsType 235 236 schema_info: Optional[RuntimeSchemaInfo] = None 237 238 @property 239 def args_spec(self) -> Tuple[DTensorSpec, ...]: 240 """ 241 args_spec: Tuple[DTensorSpec, ...]: contains a clean list of args spec list 242 with NO non-DTensor positional arguments (i.e. int/float/tuple, etc) 243 mainly used by sharding propagation to propagate the output spec 244 """ 245 args = ( 246 tree_leaves(self.args_schema) 247 if self.schema_info is not None and self.schema_info.needs_pytree 248 else self.args_schema 249 ) 250 return tuple(item for item in args if isinstance(item, DTensorSpec)) 251 252 @property 253 def args_strategy(self) -> Tuple[OpStrategy, ...]: 254 # filter out non-relevant values from args schema to get a clean OpStrategy list 255 # separate with args_spec for the ease of type annotation 256 # TODO: see if we should merge this with args_spec 257 args = ( 258 tree_leaves(self.args_schema) 259 if self.schema_info is not None and self.schema_info.needs_pytree 260 else self.args_schema 261 ) 262 return tuple(item for item in args if isinstance(item, OpStrategy)) 263 264 def __repr__(self) -> str: 265 args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema]) 266 return ( 267 f"OpSchema(op={self.op}," 268 f" args_schema=({args_schema})," 269 f" kwargs_schema={self.kwargs_schema})" 270 ) 271 272 def __str__(self) -> str: 273 args_schema: List[str] = [] 274 mesh_shape = None 275 for arg in self.args_schema: 276 if isinstance(arg, DTensorSpec): 277 args_schema.append(str(arg)) 278 mesh_shape = arg.mesh.shape 279 elif isinstance(arg, OpStrategy): 280 assert len(arg.strategies) == 1 281 args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs)) 282 mesh_shape = arg.mesh_shape 283 elif isinstance(arg, TupleStrategy): 284 first_op_strtgy = arg.childs[0] 285 assert isinstance(first_op_strtgy, OpStrategy) 286 mesh_shape = first_op_strtgy.mesh_shape 287 args_schema.append(str(arg)) 288 else: 289 args_schema.append(str(arg)) 290 return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})" 291 292 def __post_init__(self) -> None: 293 has_symints = False 294 for a in self.args_schema: 295 if isinstance(a, DTensorSpec) and a.tensor_meta is not None: 296 if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape): 297 has_symints = True 298 break 299 self.has_symints = has_symints 300 301 def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool: 302 arg = self.args_schema[arg_idx] 303 is_tensor = isinstance(arg, DTensorSpec) 304 if is_tensor: 305 return True 306 307 if not isinstance(arg, list): 308 return False 309 310 return all(isinstance(e, DTensorSpec) or e is None for e in arg) 311 312 def return_type_tuple_tensor_like(self) -> bool: 313 # all dispatch ops could only return Tuple[Tensor] or have None/ints/floats 314 # in the tuple, but the first element must be a Tensor, so this check is enough 315 return_types = self.op._schema.returns 316 return len(return_types) > 1 and isinstance( 317 return_types[0].type, torch.TensorType 318 ) 319 320 def return_type_tensor(self) -> bool: 321 return_types = self.op._schema.returns 322 # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like 323 # return types, so this check is enough for tensor like types 324 return isinstance(return_types[0].type, torch.TensorType) 325 326 def __hash__(self) -> int: 327 # Only hash args and kwargs that op indicates to hash 328 if not self.schema_info: 329 static_argnum = len(self.args_schema) 330 static_kwargkey = None 331 else: 332 static_argnum = self.schema_info.static_argnum 333 static_kwargkey = self.schema_info.static_kwargkey 334 335 args_to_hash = tuple( 336 tuple(e) if isinstance(e, list) else e 337 for i, e in enumerate(self.args_schema) 338 if self.arg_type_tensor_or_tensor_list_like(i) or i >= static_argnum 339 ) 340 if static_kwargkey is not None: 341 kwargs_to_hash = tuple( 342 self.kwargs_schema.get(k, None) for k in static_kwargkey 343 ) 344 return hash((self.op, args_to_hash, kwargs_to_hash)) 345 else: 346 return hash((self.op, args_to_hash)) 347 348 def __eq__(self, other: object) -> bool: 349 # early return checks 350 if not isinstance(other, OpSchema): 351 return False 352 353 if self.op != other.op: 354 return False 355 356 if len(self.args_schema) != len(other.args_schema): 357 return False 358 359 # compare each element and early return if any of them is different 360 if not self.schema_info: 361 static_argnum = len(self.args_schema) 362 static_kwargkey = None 363 else: 364 static_argnum = self.schema_info.static_argnum 365 static_kwargkey = self.schema_info.static_kwargkey 366 367 for i, (self_arg, other_arg) in enumerate( 368 zip(self.args_schema, other.args_schema) 369 ): 370 if isinstance(self_arg, DTensorSpec) and self_arg != other_arg: 371 return False 372 elif i >= static_argnum and self_arg != other_arg: 373 return False 374 375 # check kwarg equality when there's a static kwarg key 376 if static_kwargkey: 377 for key in static_kwargkey: 378 if self.kwargs_schema.get(key, None) != other.kwargs_schema.get( 379 key, None 380 ): 381 return False 382 383 return True 384 385 def gen_fake_args(self) -> ArgsType: 386 """ 387 gen_fake_args: generate fake args for the operator, this is mainly used 388 by sharding propagation rules to generate fake args for the operator 389 to run the local tensor operator and get the output spec. 390 """ 391 return tree_map_only( 392 DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.args_schema 393 ) 394 395 def gen_fake_kwargs(self) -> KwargsType: 396 """ 397 gen_fake_kwargs: generate fake kwargs for the operator, this is mainly used 398 by sharding propagation rules to generate fake kwargs for the operator 399 to run the local tensor operator and get the output spec. 400 """ 401 return tree_map_only( 402 DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema 403 ) 404 405 def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None: 406 suggestion_args_spec = self.args_spec 407 new_arg_schema: List[object] = [] 408 idx_of_args_spec = 0 409 if ( 410 origin_schema.schema_info is not None 411 and origin_schema.schema_info.needs_pytree 412 ): 413 args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema) 414 else: 415 args_schema = origin_schema.args_schema 416 for arg in args_schema: 417 if isinstance(arg, DTensorSpec): 418 new_arg_schema.append(suggestion_args_spec[idx_of_args_spec]) 419 idx_of_args_spec += 1 420 else: 421 new_arg_schema.append(arg) 422 self.args_schema = tuple(new_arg_schema) 423 self.kwargs_schema = origin_schema.kwargs_schema 424 425 426@dataclass 427class OutputSharding: 428 """ 429 OutputSharding is a data class that is used by the sharding propagation, 430 it could set the output_spec upon successful propagation. If needs_redistribute 431 is set to True, a redistribute_schema would be returned together to indicate 432 the input arguments needs to be redistributed before the op execution. 433 434 NOTE: the redistribute_schema generated by sharding propagation should be 435 exactly the same as the operator OpSchema, except the DTensorSpecs 436 """ 437 438 output_spec: OutputSpecType 439 redistribute_schema: Optional[OpSchema] = None 440 needs_redistribute: bool = False 441 442 443@dataclass 444class OpInfo: 445 """ 446 All Runtime Op execution info are packed here 447 """ 448 449 mesh: DeviceMesh 450 schema: OpSchema 451 flat_args_schema: List[object] 452 local_args: Sequence[object] 453 local_kwargs: Dict[str, object] 454 args_tree_spec: Optional[TreeSpec] = None 455 456 # the output sharding info 457 output_sharding: Optional[OutputSharding] = None 458