# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from model_registry import ModelWithKwargs

import torch
from torch.distributed.pipelining import pipeline
from torch.distributed.pipelining.microbatch import (
    merge_chunks,
    split_args_kwargs_into_chunks,
    TensorChunkSpec,
)
from torch.testing._internal.common_utils import run_tests, TestCase


# Feature (last-dim) size shared by every tensor built in these tests.
d_hid = 512
# Fixed seed so the random inputs are reproducible across runs.
torch.manual_seed(0)


class MicrobatchTests(TestCase):
    """Tests for splitting inputs into microbatches and merging them back."""

    def test_split_and_merge(self):
        """Split args/kwargs into two chunks along dim 0, then merge them back.

        Checks the chunk count, the per-chunk shapes, and that merging the
        chunks reproduces the original args and kwargs.
        """
        x0 = torch.randn(128, d_hid)
        x1 = torch.randn(256, d_hid)
        x2 = torch.randn(512, d_hid)

        args = (x0, x1, x2)
        kwargs = {"x0": x0, "x1": x1, "x2": x2}

        # Default chunking: dim 0
        arg_chunks, kwarg_chunks = split_args_kwargs_into_chunks(args, kwargs, 2)

        # Use unittest assertions rather than bare ``assert``: they are not
        # stripped under ``python -O`` and report both values on failure.
        self.assertEqual(len(arg_chunks), 2)
        self.assertEqual(len(kwarg_chunks), 2)

        # Each input is halved along dim 0; the hidden dim is untouched.
        self.assertEqual(arg_chunks[0][0].shape, torch.Size([64, d_hid]))
        self.assertEqual(arg_chunks[1][0].shape, torch.Size([64, d_hid]))
        self.assertEqual(arg_chunks[0][1].shape, torch.Size([128, d_hid]))
        self.assertEqual(arg_chunks[0][2].shape, torch.Size([256, d_hid]))
        self.assertEqual(kwarg_chunks[0]["x0"].shape, torch.Size([64, d_hid]))
        self.assertEqual(kwarg_chunks[0]["x1"].shape, torch.Size([128, d_hid]))
        self.assertEqual(kwarg_chunks[1]["x2"].shape, torch.Size([256, d_hid]))

        # Merge chunks back together
        merged_args = merge_chunks(
            arg_chunks,
            (TensorChunkSpec(0), TensorChunkSpec(0), TensorChunkSpec(0)),
        )
        torch.testing.assert_close(merged_args, args)

        merged_kwargs = merge_chunks(
            kwarg_chunks,
            {
                "x0": TensorChunkSpec(0),
                "x1": TensorChunkSpec(0),
                "x2": TensorChunkSpec(0),
            },
        )
        torch.testing.assert_close(merged_kwargs, kwargs)
        print("Microbatch test passed")

    def test_chunk_spec(self):
        """Build a pipe from an explicitly-specified microbatch and compare to eager.

        The inputs are split with chunk specs created via
        ``TensorChunkSpec.from_tuple`` / ``from_dict``; the first microbatch is
        passed to ``pipeline`` as the example input, and the resulting pipe is
        run on the full batch and compared against the plain module output.
        """
        mod = ModelWithKwargs()
        batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE

        x = torch.randn(batch_size, d_hid)
        y = torch.randn(batch_size, d_hid)

        num_chunks = 4

        # Chunk the positional arg and the "y" kwarg along dim 0.
        args_chunk_spec = TensorChunkSpec.from_tuple((0,))
        kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0})

        args_split, kwargs_split = split_args_kwargs_into_chunks(
            (x,),
            {"y": y},
            num_chunks,
            args_chunk_spec,
            kwargs_chunk_spec,
        )

        # Only the first microbatch is handed to ``pipeline``.
        pipe = pipeline(
            mod,
            mb_args=args_split[0],
            mb_kwargs=kwargs_split[0],
        )

        ref = mod(x, y)
        # The pipe's result is indexed; element 0 is compared to the eager output.
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")


if __name__ == "__main__":
    run_tests()