Searched defs:_FSDPState (Results 1 – 8 of 8) sorted by relevance
/aosp_15_r20/external/pytorch/torch/distributed/fsdp/ |
H A D | _runtime_utils.py | 96 def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool: 111 state: _FSDPState, 144 def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module): 170 root_state: _FSDPState, 236 state: _FSDPState, 274 state: _FSDPState, 307 state: _FSDPState, 345 state: _FSDPState, 408 state: _FSDPState, 435 state: _FSDPState, [all …]
|
H A D | _state_dict_utils.py | 67 def _should_unshard_params(fsdp_state: _FSDPState) -> bool: 85 module: nn.Module, fsdp_state: _FSDPState 111 fsdp_state: _FSDPState, 138 def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None: 146 fsdp_state: _FSDPState, 159 fsdp_state: _FSDPState, 182 fsdp_state: _FSDPState, 287 fsdp_state: _FSDPState, 314 fsdp_state: _FSDPState, 359 fsdp_state: _FSDPState, [all …]
|
H A D | _init_utils.py | 105 state: _FSDPState, 157 state: _FSDPState, 287 state: _FSDPState, 363 state: _FSDPState, 412 state: _FSDPState, 430 state: _FSDPState, 496 state: _FSDPState, 513 state: _FSDPState, 525 def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState: 541 def _init_state_dict_state(state: _FSDPState) -> _FSDPState: [all …]
|
H A D | _unshard_param_utils.py | 73 def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None: 87 def _register_flat_param(state: _FSDPState, module: nn.Module) -> None: 103 def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator: 126 state: _FSDPState, 161 state: _FSDPState, 238 state: _FSDPState, 308 def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None: 324 def _register_orig_params(state: _FSDPState, module: nn.Module) -> None:
|
H A D | _common_utils.py | 118 class _FSDPState(_State): class 200 def _is_composable(state: _FSDPState): 206 def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamHandle"]: 226 def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool: 368 state: _FSDPState, handle: "FlatParamHandle", logger: logging.Logger 380 state: _FSDPState, handle: "FlatParamHandle" 446 state: _FSDPState,
|
H A D | _optim_utils.py | 340 fsdp_state: _FSDPState, 359 fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup] 1311 fsdp_state: _FSDPState, 2076 fsdp_state: _FSDPState,
|
H A D | _exec_order_utils.py | 63 state: _FSDPState,
|
/aosp_15_r20/external/pytorch/test/distributed/_composable/fully_shard/ |
H A D | test_fully_shard_runtime.py | 189 state: _FSDPState, 199 state: _FSDPState,
|