Home
last modified time | relevance | path

Searched defs:_FSDPState (Results 1 – 8 of 8) sorted by relevance

/aosp_15_r20/external/pytorch/torch/distributed/fsdp/
H A D_runtime_utils.py96 def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool:
111 state: _FSDPState,
144 def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
170 root_state: _FSDPState,
236 state: _FSDPState,
274 state: _FSDPState,
307 state: _FSDPState,
345 state: _FSDPState,
408 state: _FSDPState,
435 state: _FSDPState,
[all …]
H A D_state_dict_utils.py67 def _should_unshard_params(fsdp_state: _FSDPState) -> bool:
85 module: nn.Module, fsdp_state: _FSDPState
111 fsdp_state: _FSDPState,
138 def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None:
146 fsdp_state: _FSDPState,
159 fsdp_state: _FSDPState,
182 fsdp_state: _FSDPState,
287 fsdp_state: _FSDPState,
314 fsdp_state: _FSDPState,
359 fsdp_state: _FSDPState,
[all …]
H A D_init_utils.py105 state: _FSDPState,
157 state: _FSDPState,
287 state: _FSDPState,
363 state: _FSDPState,
412 state: _FSDPState,
430 state: _FSDPState,
496 state: _FSDPState,
513 state: _FSDPState,
525 def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
541 def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
[all …]
H A D_unshard_param_utils.py73 def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None:
87 def _register_flat_param(state: _FSDPState, module: nn.Module) -> None:
103 def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator:
126 state: _FSDPState,
161 state: _FSDPState,
238 state: _FSDPState,
308 def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
324 def _register_orig_params(state: _FSDPState, module: nn.Module) -> None:
H A D_common_utils.py118 class _FSDPState(_State): class
200 def _is_composable(state: _FSDPState):
206 def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamHandle"]:
226 def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool:
368 state: _FSDPState, handle: "FlatParamHandle", logger: logging.Logger
380 state: _FSDPState, handle: "FlatParamHandle"
446 state: _FSDPState,
H A D_optim_utils.py340 fsdp_state: _FSDPState,
359 fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup]
1311 fsdp_state: _FSDPState,
2076 fsdp_state: _FSDPState,
H A D_exec_order_utils.py63 state: _FSDPState,
/aosp_15_r20/external/pytorch/test/distributed/_composable/fully_shard/
H A Dtest_fully_shard_runtime.py189 state: _FSDPState,
199 state: _FSDPState,