# xref: /aosp_15_r20/external/pytorch/torch/distributed/_sharding_spec/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Keep the old package for BC purposes. This file should be removed once
# everything moves to the `torch.distributed._shard` package.
import sys
import warnings

import torch
from torch.distributed._shard.sharding_spec import *  # noqa: F403


# Announce that this legacy import path is going away. The "always" filter is
# installed only inside the catch_warnings() context, so the warning is shown
# on import regardless of the caller's filters, and those filters are restored
# as soon as the block exits.
with warnings.catch_warnings():
    warnings.simplefilter("always")
    warnings.warn(
        "`torch.distributed._sharding_spec` will be deprecated, "
        "use `torch.distributed._shard.sharding_spec` instead",
        DeprecationWarning,
        # Report the warning against the caller's frame, not this shim.
        stacklevel=2,
    )

import torch.distributed._shard.sharding_spec as _sharding_spec


# Register the new package under the legacy module name so that both import
# paths resolve to the very same module object.
sys.modules["torch.distributed._sharding_spec"] = _sharding_spec
23