# Owner(s): ["oncall: distributed"] import sys import torch import torch.distributed.checkpoint as dcp import torch.nn as nn from torch.distributed._shard.sharded_tensor import ( Shard, ShardedTensor, ShardedTensorMetadata, ShardMetadata, ) from torch.distributed._shard.sharded_tensor.metadata import ( TensorProperties as TensorProperties_Shard, ) from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans from torch.distributed.checkpoint.api import CheckpointException from torch.distributed.checkpoint.default_planner import ( _create_default_local_metadata, create_default_global_save_plan, create_default_local_load_plan, create_default_local_save_plan, DefaultLoadPlanner, ) from torch.distributed.checkpoint.metadata import ( BytesStorageMetadata, ChunkStorageMetadata, MetadataIndex, TensorProperties, TensorStorageMetadata, ) from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType from torch.distributed.checkpoint.planner_helpers import ( create_read_items_for_chunk_list, ) from torch.testing._internal.common_utils import ( run_tests, TEST_WITH_DEV_DBG_ASAN, TestCase, ) from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir from torch.testing._internal.distributed.distributed_utils import ( with_dist, with_fake_comms, ) if TEST_WITH_DEV_DBG_ASAN: print( "Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr, ) sys.exit(0) def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8): shards_metadata = [] local_shards = [] for idx in range(0, world_size * shards_per_rank): shard_rank = idx // shards_per_rank shard_md = ShardMetadata( shard_offsets=[idx * shard_size], shard_sizes=[shard_size], placement=f"rank:{shard_rank}/cpu", ) shards_metadata.append(shard_md) if shard_rank == rank: shard = Shard.from_tensor_and_offsets( torch.rand(*shard_md.shard_sizes), shard_offsets=shard_md.shard_offsets, rank=rank, ) local_shards.append(shard) sharded_tensor_md = ShardedTensorMetadata( shards_metadata=shards_metadata, size=torch.Size([shard_size * len(shards_metadata)]), tensor_properties=TensorProperties_Shard.create_from_tensor(torch.zeros(1)), ) return ShardedTensor._init_from_local_shards_and_global_metadata( local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md ) class TestSavePlan(TestCase): @with_fake_comms(rank=1, world_size=4) def test_local_plan(self): tensor = torch.rand(10) val = [1, 2, 3] st = create_sharded_tensor(rank=1, world_size=4, shards_per_rank=1) state_dict = {"tensor": tensor, "value": val, "st": st} plan = create_default_local_save_plan(state_dict, False) self.assertEqual(3, len(plan.items)) wi = plan.items[0] self.assertEqual(wi.index, MetadataIndex("tensor", [0])) self.assertEqual(wi.type, WriteItemType.TENSOR) self.assertEqual(wi.tensor_data.size, tensor.size()) self.assertEqual( wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)), ) self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([0])) self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([10])) st_wi = plan.items[2] self.assertEqual(st_wi.index, MetadataIndex("st", [8])) self.assertEqual(st_wi.type, WriteItemType.SHARD) self.assertEqual(st_wi.tensor_data.size, st.size()) self.assertEqual( st_wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)), ) self.assertEqual(st_wi.tensor_data.chunk.offsets, torch.Size([8])) self.assertEqual(st_wi.tensor_data.chunk.sizes, torch.Size([8])) # Coordinator rank, should include replicated items as well plan = 

        # Coordinator rank, should include replicated items as well
        plan = create_default_local_save_plan(state_dict, True)
        self.assertEqual(3, len(plan.items))

        tensor_wi = next(wi for wi in plan.items if wi.type == WriteItemType.TENSOR)
        self.assertEqual(tensor_wi.index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_wi.tensor_data.size, tensor.size())
        self.assertEqual(
            tensor_wi.tensor_data.properties,
            TensorProperties.create_from_tensor(tensor),
        )
        self.assertEqual(tensor_wi.tensor_data.chunk.offsets, torch.Size([0]))
        self.assertEqual(tensor_wi.tensor_data.chunk.sizes, torch.Size([10]))

        bytes_wi = next(wi for wi in plan.items if wi.type == WriteItemType.BYTE_IO)
        self.assertEqual(bytes_wi.index, MetadataIndex("value"))
        self.assertIsNone(bytes_wi.tensor_data)

    def test_global_plan(self):
        def create_data(rank):
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                state_dict = {"tensor": tensor, "value": val, "st": st}
                return create_default_local_save_plan(state_dict, rank == 0)

        all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
        all_plans = dedup_save_plans(all_plans)
        final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)

        # The default global plan updates all indexes to include hints
        for new_plan, old_plan in zip(final_plans, all_plans):
            for new_item, old_item in zip(new_plan.items, old_plan.items):
                self.assertEqual(new_item.index, old_item.index)
                self.assertEqual(new_item.type, old_item.type)
                self.assertEqual(new_item.tensor_data, old_item.tensor_data)
                self.assertIn(new_item.index.fqn, metadata.state_dict_metadata)

                item_md = metadata.state_dict_metadata[new_item.index.fqn]
                if new_item.type == WriteItemType.BYTE_IO:
                    self.assertTrue(isinstance(item_md, BytesStorageMetadata))
                else:
                    self.assertTrue(isinstance(item_md, TensorStorageMetadata))
                    self.assertEqual(item_md.size, old_item.tensor_data.size)
                    self.assertEqual(
                        item_md.properties, old_item.tensor_data.properties
                    )
                    self.assertIsNotNone(new_item.index.index)

                    # Make sure the hint is correct
                    self.assertEqual(
                        item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
                    )

    def test_local_load_plan(self):
        def create_state_dict(rank):
            with with_dist(rank=rank, world_size=4):
                tensor = torch.rand(10)
                val = [1, 2, 3]
                st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
                return {"tensor": tensor, "value": val, "st": st}

        state_dict = create_state_dict(1)
        metadata = _create_default_local_metadata(state_dict)
        load_plan = create_default_local_load_plan(state_dict, metadata)
        # This will create 3 entries
        self.assertEqual(3, len(load_plan.items))
        st_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "st")
        tensor_item = next(
            ri for ri in load_plan.items if ri.dest_index.fqn == "tensor"
        )
        bytes_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "value")

        self.assertEqual(st_item.type, LoadItemType.TENSOR)
        # This is an exact copy
        self.assertEqual(st_item.dest_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.dest_offsets, torch.Size([0]))
        self.assertEqual(st_item.storage_index, MetadataIndex("st", [8]))
        self.assertEqual(st_item.storage_offsets, torch.Size([0]))
        self.assertEqual(st_item.lengths, torch.Size([8]))

        self.assertEqual(tensor_item.type, LoadItemType.TENSOR)
        self.assertEqual(tensor_item.dest_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.dest_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.storage_index, MetadataIndex("tensor", [0]))
        self.assertEqual(tensor_item.storage_offsets, torch.Size([0]))
        self.assertEqual(tensor_item.lengths, torch.Size([10]))

        self.assertEqual(bytes_item.type, LoadItemType.BYTE_IO)
        self.assertEqual(bytes_item.dest_index, MetadataIndex("value"))
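
    # Overview of the resharding tests below: create_sharded_tensor splits the
    # tensor evenly across ranks, so in test_load_with_resharding each rank
    # owns 128 // world_size elements (16 at world_size=8, 32 at world_size=4).
    # Going from a world_size=8 checkpoint to a world_size=4 run therefore has
    # to stitch two stored 16-element chunks into one 32-element shard, while
    # the reverse direction reads only half of a single stored chunk.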

    def test_load_with_resharding(self):
        def create_state_dict(rank, world_size):
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=128 // world_size,
                    )
                }

        # Rank 1 has a 16-element shard covering [16, 32[
        world8_state_dict = create_state_dict(rank=1, world_size=8)
        world8_metadata = _create_default_local_metadata(world8_state_dict)

        # Rank 1 has a 32-element shard covering [32, 64[
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # First scenario, going from world=8 to world=4
        # Each 4-world shard has 32 elements, so it needs to load 2 stored shards
        load_plan = create_default_local_load_plan(world4_state_dict, world8_metadata)
        self.assertEqual(2, len(load_plan.items))
        low_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
        )
        high_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16])
        )

        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([16]))

        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [48]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [32]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([16]))
        self.assertEqual(high_ri.lengths, torch.Size([16]))

        # Second scenario, going from world=4 to world=8, need to load half of 1 shard
        # rank1 on 8-world needs to load the upper half of the rank0 4-world shard
        load_plan = create_default_local_load_plan(world8_state_dict, world4_metadata)
        self.assertEqual(1, len(load_plan.items))
        ri = load_plan.items[0]
        self.assertEqual(ri.storage_index, MetadataIndex("st", [0]))
        self.assertEqual(ri.storage_offsets, torch.Size([16]))
        self.assertEqual(ri.dest_index, MetadataIndex("st", [16]))
        self.assertEqual(ri.dest_offsets, torch.Size([0]))
        self.assertEqual(ri.lengths, torch.Size([16]))

    def test_load_with_world_size_diff_by_one(self):
        def create_state_dict(rank, world_size):
            with with_dist(rank=rank, world_size=world_size):
                return {
                    "st": create_sharded_tensor(
                        rank=rank,
                        world_size=world_size,
                        shards_per_rank=1,
                        shard_size=120 // world_size,
                    )
                }

        # Rank 1 has a 30-element shard covering [30, 60[
        world4_state_dict = create_state_dict(rank=1, world_size=4)
        world4_metadata = _create_default_local_metadata(world4_state_dict)

        # Rank 1 has a 40-element shard covering [40, 80[
        world3_state_dict = create_state_dict(rank=1, world_size=3)

        load_plan = create_default_local_load_plan(world3_state_dict, world4_metadata)
        self.assertEqual(2, len(load_plan.items))
        # reads from the stored [30, 60[ chunk to fill [40, 60[
        low_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
        )
        # reads from the stored [60, 90[ chunk to fill [60, 80[
        high_ri = next(
            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20])
        )

        self.assertEqual(low_ri.storage_index, MetadataIndex("st", [30]))
        self.assertEqual(low_ri.storage_offsets, torch.Size([10]))
        self.assertEqual(low_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(low_ri.dest_offsets, torch.Size([0]))
        self.assertEqual(low_ri.lengths, torch.Size([20]))

        self.assertEqual(high_ri.storage_index, MetadataIndex("st", [60]))
        self.assertEqual(high_ri.storage_offsets, torch.Size([0]))
        self.assertEqual(high_ri.dest_index, MetadataIndex("st", [40]))
        self.assertEqual(high_ri.dest_offsets, torch.Size([20]))
        self.assertEqual(high_ri.lengths, torch.Size([20]))
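

# Worked example for the chunk-overlap arithmetic exercised in
# TestPlannerHelpers below: a requested chunk at offset 4 with length 7 covers
# elements [4, 11[ of a 16-element tensor stored as two 8-element chunks, so it
# overlaps [4, 8[ of the first stored chunk (4 elements, storage offset 4) and
# [8, 11[ of the second (3 elements, storage offset 0), which is why
# create_read_items_for_chunk_list is expected to produce exactly two read items.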


class TestPlannerHelpers(TestCase):
    def test_create_read_item_from_chunks(self):
        tensor_md = TensorStorageMetadata(
            properties=TensorProperties.create_from_tensor(torch.empty([16])),
            size=torch.Size([16]),
            chunks=[
                ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8])),
                ChunkStorageMetadata(offsets=torch.Size([8]), sizes=torch.Size([8])),
            ],
        )

        chunk = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([7]))
        read_items = create_read_items_for_chunk_list("foo", tensor_md, [chunk])

        self.assertEqual(2, len(read_items))
        self.assertEqual(MetadataIndex("foo", [4]), read_items[0].dest_index)
        self.assertEqual(torch.Size([0]), read_items[0].dest_offsets)
        self.assertEqual(MetadataIndex("foo", [0]), read_items[0].storage_index)
        self.assertEqual(torch.Size([4]), read_items[0].storage_offsets)
        self.assertEqual(torch.Size([4]), read_items[0].lengths)

        self.assertEqual(MetadataIndex("foo", [4]), read_items[1].dest_index)
        self.assertEqual(torch.Size([4]), read_items[1].dest_offsets)
        self.assertEqual(MetadataIndex("foo", [8]), read_items[1].storage_index)
        self.assertEqual(torch.Size([0]), read_items[1].storage_offsets)
        self.assertEqual(torch.Size([3]), read_items[1].lengths)


class TestLoadPlanner(TestCase):
    @with_temp_dir
    def test_strict(self):
        original_module = nn.Linear(2, 2)
        dcp.save(state_dict={"module": original_module}, checkpoint_id=self.temp_dir)

        new_module = nn.Linear(2, 2)
        new_module.extra_param = nn.Parameter(torch.randn(2, 2))
        dcp.load(
            state_dict={"module": new_module},
            checkpoint_id=self.temp_dir,
            planner=DefaultLoadPlanner(allow_partial_load=True),
        )

        with self.assertRaisesRegex(CheckpointException, "Missing key in checkpoint"):
            dcp.load(
                state_dict={"module": new_module},
                checkpoint_id=self.temp_dir,
                planner=DefaultLoadPlanner(allow_partial_load=False),
            )


if __name__ == "__main__":
    run_tests()