Searched defs:FSDPParamGroup (Results 1 – 4 of 4) sorted by relevance

/aosp_15_r20/external/pytorch/torch/distributed/_composable/fsdp/
  _fsdp_param_group.py
       96  class FSDPParamGroup:
      437  target_fsdp_param_group: "FSDPParamGroup", pass_type: str
      622  def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
  fully_shard.py
      391  def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]):
/aosp_15_r20/external/pytorch/torch/distributed/_tools/
  fsdp2_mem_tracker.py
      170  self, fsdp_param_group: FSDPParamGroup
/aosp_15_r20/external/pytorch/test/distributed/_composable/fsdp/
  test_fully_shard_comm.py
      151  def all_gather(fsdp_param_group: FSDPParamGroup, group: dist.ProcessGroup):
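The hits in this listing fall into two kinds: the class definition itself (`_fsdp_param_group.py`, line 96) and function parameters annotated with `FSDPParamGroup`, either directly, inside `Optional[...]`, or as a string forward reference (`"FSDPParamGroup"`). A rough local approximation of that split can be sketched with Python's standard-library `ast` module. This is only an illustrative sketch: `find_symbol_hits`, `_annotation_mentions` and `SAMPLE` are hypothetical names, not part of PyTorch or of the search tool that produced this listing, and the sketch only looks at class definitions and parameter annotations.

```python
import ast
from typing import List, Tuple


def _annotation_mentions(node: ast.AST, name: str) -> bool:
    """Return True if an annotation expression refers to `name`.

    Covers plain names (FSDPParamGroup), attribute access (mod.FSDPParamGroup),
    subscripted types (Optional[FSDPParamGroup]) and string forward
    references ("FSDPParamGroup"), which are parsed and searched recursively.
    """
    for sub in ast.walk(node):
        if isinstance(sub, ast.Name) and sub.id == name:
            return True
        if isinstance(sub, ast.Attribute) and sub.attr == name:
            return True
        if isinstance(sub, ast.Constant) and isinstance(sub.value, str):
            try:
                inner = ast.parse(sub.value, mode="eval")
            except SyntaxError:
                continue  # an ordinary string, not a type expression
            if _annotation_mentions(inner, name):
                return True
    return False


def find_symbol_hits(source: str, name: str) -> List[Tuple[int, str]]:
    """Return sorted (line number, kind) pairs where `name` is defined as a
    class ("class") or used to annotate a function parameter ("param")."""
    hits: List[Tuple[int, str]] = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.ClassDef) and node.name == name:
            hits.append((node.lineno, "class"))
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            args = node.args
            # New list, so appending *args / **kwargs does not mutate the AST.
            params = args.posonlyargs + args.args + args.kwonlyargs
            if args.vararg is not None:
                params.append(args.vararg)
            if args.kwarg is not None:
                params.append(args.kwarg)
            for arg in params:
                if arg.annotation is not None and _annotation_mentions(
                    arg.annotation, name
                ):
                    hits.append((arg.lineno, "param"))
    return sorted(hits)


# Hypothetical snippet mimicking the shapes of the hits in the listing.
SAMPLE = '''
from typing import Optional

class FSDPParamGroup:
    pass

def hook(target: "FSDPParamGroup", pass_type: str):
    pass

class State:
    def __init__(self, group: Optional[FSDPParamGroup]):
        self.group = group
'''

if __name__ == "__main__":
    print(find_symbol_hits(SAMPLE, "FSDPParamGroup"))
    # [(4, 'class'), (7, 'param'), (11, 'param')]
```

Walking the syntax tree instead of grepping raw text means the quoted forward reference and the `Optional[...]` wrapper are both matched, while mentions in comments or unrelated strings are ignored.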