Home
last modified time | relevance | path

Searched refs:ModuleWrapPolicy (Results 1 – 25 of 30) sorted by relevance

12

/aosp_15_r20/external/pytorch/test/distributed/_composable/fully_shard/
H A Dtest_fully_shard_init.py14 from torch.distributed.fsdp.wrap import _Policy, CustomPolicy, ModuleWrapPolicy
60 ModuleWrapPolicy({UnitModule}),
61 ModuleWrapPolicy({nn.Sequential}),
181 policy=ModuleWrapPolicy({UnitModule}),
198 policy = ModuleWrapPolicy({UnitModule})
250 auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
256 policy=ModuleWrapPolicy({UnitModule}),
290 fully_shard(composable_module, policy=ModuleWrapPolicy({UnitModule}))
H A Dtest_fully_shard_model_checkpoint.py15 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
61 fully_shard(save_composable, policy=ModuleWrapPolicy({UnitModule}))
68 copy.deepcopy(local_model), policy=ModuleWrapPolicy({UnitModule})
140 policy=ModuleWrapPolicy({TransformerEncoderLayer, TransformerDecoderLayer}),
163 policy=ModuleWrapPolicy({TransformerDecoderLayer, TransformerEncoderLayer}),
H A Dtest_fully_shard_util.py11 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
44 policy=ModuleWrapPolicy({UnitModule}),
H A Dtest_fully_shard_runtime.py18 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
64 auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
70 policy=ModuleWrapPolicy({UnitModule}),
H A Dtest_fully_shard_compile.py12 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
67 "policy": ModuleWrapPolicy(
H A Dtest_fully_shard_optim_checkpoint.py11 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
90 fully_shard(composable_model, policy=ModuleWrapPolicy({UnitModule}))
/aosp_15_r20/external/pytorch/test/distributed/_composable/
H A Dtest_compose.py14 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
96 policy=ModuleWrapPolicy({nn.Linear}),
179 fully_shard(model.u1, policy=ModuleWrapPolicy({nn.Linear}))
180 fully_shard(model.u2, policy=ModuleWrapPolicy({nn.Linear}))
303 fully_shard(test_model, policy=ModuleWrapPolicy({UnitModule}))
314 fully_shard(test_model.u2, policy=ModuleWrapPolicy({UnitModule}))
/aosp_15_r20/external/pytorch/test/distributed/fsdp/
H A Dtest_fsdp_fine_tune.py15 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
111 policy = ModuleWrapPolicy({nn.Linear})
232 policy = ModuleWrapPolicy({nn.Linear})
280 policy = ModuleWrapPolicy({nn.Linear})
347 "auto_wrap_policy": ModuleWrapPolicy({LinearUnusedInput}),
H A Dtest_fsdp_hybrid_shard.py25 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
214 pol = ModuleWrapPolicy({nn.Linear})
376 auto_wrap_policy = ModuleWrapPolicy(
404 auto_wrap_policy = ModuleWrapPolicy(
H A Dtest_fsdp_backward_prefetch.py17 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
86 policy = ModuleWrapPolicy(
H A Dtest_fsdp_ignored_modules.py14 from torch.distributed.fsdp.wrap import ModuleWrapPolicy, transformer_auto_wrap_policy
160 ] = ModuleWrapPolicy({nn.Linear})
286 "policy": [transformer_policy, ModuleWrapPolicy((nn.Sequential,))],
H A Dtest_fsdp_misc.py27 ModuleWrapPolicy,
189 auto_wrap_policy=ModuleWrapPolicy([nn.Linear, nn.Conv2d]),
608 auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
683 auto_wrap_policy = ModuleWrapPolicy(module_classes)
727 auto_wrap_policy = ModuleWrapPolicy({nn.Sequential})
H A Dtest_fsdp_clip_grad_norm.py16 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
141 "auto_wrap_policy": ModuleWrapPolicy(
H A Dtest_fsdp_use_orig_params.py28 from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy
134 "auto_wrap_policy": ModuleWrapPolicy(
241 "auto_wrap_policy": ModuleWrapPolicy(
1062 "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
1112 "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
H A Dtest_fsdp_comm_hooks.py14 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
283 auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
H A Dtest_fsdp_sharded_grad_scaler.py19 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
227 "auto_wrap_policy": ModuleWrapPolicy(
H A Dtest_wrap.py29 ModuleWrapPolicy,
460 auto_wrap_policy = ModuleWrapPolicy(
468 auto_wrap_policy = ModuleWrapPolicy(
827 module_wrap_policy = ModuleWrapPolicy(module_classes)
H A Dtest_fsdp_comm.py15 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
361 auto_wrap_policy=ModuleWrapPolicy((MLP,)),
H A Dtest_fsdp_meta.py14 ModuleWrapPolicy,
412 auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
H A Dtest_fsdp_core.py20 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
477 "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
H A Dtest_checkpoint_wrapper.py18 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
276 auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
H A Dtest_fsdp_mixed_precision.py25 from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy
834 auto_wrap_policy = ModuleWrapPolicy(
964 auto_wrap_policy = ModuleWrapPolicy(
1327 policy = ModuleWrapPolicy(
H A Dtest_fsdp_state_dict.py39 from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
435 auto_wrap_policy = ModuleWrapPolicy(
461 auto_wrap_policy = ModuleWrapPolicy(
1200 auto_wrap_policy = ModuleWrapPolicy(
/aosp_15_r20/external/pytorch/test/distributed/checkpoint/
H A Dtest_state_dict.py38 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
184 copy.deepcopy(orig_model), policy=ModuleWrapPolicy(strategy)
191 auto_wrap_policy=ModuleWrapPolicy(strategy),
198 auto_wrap_policy=ModuleWrapPolicy(strategy),
325 fully_shard(dist_model, policy=ModuleWrapPolicy({UnitModule}))
330 auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
440 auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
/aosp_15_r20/external/pytorch/benchmarks/dynamo/
H A Ddist_util.py16 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
142 wrap_policy = ModuleWrapPolicy(blocks)

12