# mypy: allow-untyped-defs
"""
The following example demonstrates how to train a ConvNeXt model
with intermediate activations sharded across multiple GPUs via DTensor.

To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py
"""
import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)


WORLD_SIZE = 4  # expected number of GPUs, matching --nproc-per-node above
ITER_TIME = 20


class LayerNorm(nn.Module):
    def __init__(
        self, normalized_shape, eps=1e-6, data_format=torch.contiguous_format
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in [torch.contiguous_format]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        # Channels-first layer norm: normalize over the channel dimension (dim 1).
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class Block(nn.Module):
    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6, data_format=torch.contiguous_format)
        self.pwconv1 = nn.Conv2d(
            dim, 4 * dim, kernel_size=1, stride=1
        )  # pointwise/1x1 conv, equivalent to nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(
            4 * dim, dim, kernel_size=1, stride=1
        )  # pointwise/1x1 conv, equivalent to nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )
            if layer_scale_init_value > 0
            else None
        )
        # Stochastic depth is disabled in this example; drop_path is kept only
        # to mirror the reference ConvNeXt block signature.
        self.drop_path = nn.Identity()

    def forward(self, x):
        input_x = x
        x = self.dwconv(x)

        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = self.gamma * self.drop_path(x)
        x = input_x + x
        return x


class DownSampling(nn.Module):
    def __init__(self, dim_in=3, dim_out=2, down_scale=4, norm_first=False):
        super().__init__()
        self.norm_first = norm_first
        if norm_first:
            self.norm = LayerNorm(
                dim_in, eps=1e-6, data_format=torch.contiguous_format
            )
            self.conv = nn.Conv2d(
                dim_in, dim_out, kernel_size=down_scale, stride=down_scale
            )
        else:
            self.conv = nn.Conv2d(
                dim_in, dim_out, kernel_size=down_scale, stride=down_scale
            )
            self.norm = LayerNorm(
                dim_out, eps=1e-6, data_format=torch.contiguous_format
            )

    def forward(self, x):
        if self.norm_first:
            return self.conv(self.norm(x))
        else:
            return self.norm(self.conv(x))


@torch.no_grad()
def init_weights(m):
    if type(m) == nn.Conv2d or type(m) == nn.Linear:
        nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


class ConvNeXt(nn.Module):
    def __init__(
        self,
        in_chans=3,
        num_classes=10,
        depths=[1, 1],  # noqa: B006
        dims=[2, 4],  # noqa: B006
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        head_init_scale=1.0,
    ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()
        stem = DownSampling(in_chans, dims[0], 4, norm_first=False)
        self.downsample_layers.append(stem)
        for i in range(len(dims) - 1):
            downsample_layer = DownSampling(dims[i], dims[i + 1], 2, norm_first=True)
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(len(dims)):
            stage = nn.Sequential(
                *[
                    Block(
                        dim=dims[i],
                        drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.head = nn.Linear(dims[-1], num_classes)
        self.apply(init_weights)

    def forward(self, x):
        for i in range(len(self.stages)):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        x = x.mean([-2, -1])  # global average pooling over (H, W)
        x = self.head(x)
        return x


def _conv_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    # Partition function for distribute_module: replicate every parameter of the
    # given submodule across the mesh, and register a gradient hook that keeps
    # gradients in the same (replicated) layout.
    for param_name, param in module.named_parameters():
        dist_spec = [Replicate()]
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, dist_spec)
        )
        # Bind dist_spec via a default argument so each hook keeps its own spec.
        dist_param.register_hook(
            lambda grad, spec=dist_spec: grad.redistribute(placements=spec)
        )
        param_name = "_".join(param_name.split("."))
        module.register_parameter(param_name, dist_param)
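

# Illustrative sketch (not executed by this script): what `_conv_fn` above does
# for a single parameter on a 1-D device mesh. The tensor shape and mesh size
# below are arbitrary illustrative values. Each parameter is wrapped as a
# replicated DTensor, and the gradient hook redistributes the gradient back to
# the same replicated layout:
#
#   mesh = init_device_mesh("cuda", (4,))
#   spec = [Replicate()]
#   weight = torch.nn.Parameter(distribute_tensor(torch.randn(8, 8), mesh, spec))
#   weight.register_hook(lambda grad: grad.redistribute(placements=spec))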


def train_convnext_example():
    device_type = "cuda"
    world_size = int(os.environ["WORLD_SIZE"])
    mesh = init_device_mesh(device_type, (world_size,))
    rank = mesh.get_rank()

    in_shape = [7, 3, 512, 1024]
    output_shape = [7, 1000]

    torch.manual_seed(12)
    model = ConvNeXt(
        depths=[3, 3, 27, 3],
        dims=[256, 512, 1024, 2048],
        drop_path_rate=0.0,
        num_classes=1000,
    ).to(device_type)
    model = distribute_module(model, mesh, _conv_fn, input_fn=None, output_fn=None)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, amsgrad=False)

    # Shard the input activations along the width dimension (dim 3) across the
    # mesh; replicate the target labels on every rank.
    x = torch.randn(*in_shape).to(device_type).requires_grad_()
    y_target = (
        torch.empty(output_shape[0], dtype=torch.long)
        .random_(output_shape[1])
        .to(device_type)
    )
    x = distribute_tensor(x, mesh, [Shard(3)])
    y_target = distribute_tensor(y_target, mesh, [Replicate()])

    # warm up
    y = model(x)
    loss = criterion(y, y_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()

    forward_time = 0.0
    backward_time = 0.0
    start = time.time()
    for i in range(ITER_TIME):
        t1 = time.time()
        y = model(x)
        torch.cuda.synchronize()
        t2 = time.time()

        loss = criterion(y, y_target)
        optimizer.zero_grad()

        t3 = time.time()
        loss.backward()
        torch.cuda.synchronize()
        t4 = time.time()

        optimizer.step()

        forward_time += t2 - t1
        backward_time += t4 - t3
    torch.cuda.synchronize()
    end = time.time()
    max_reserved = torch.cuda.max_memory_reserved()
    max_allocated = torch.cuda.max_memory_allocated()
    print(
        f"rank {rank}, {ITER_TIME} iterations, average latency {(end - start) / ITER_TIME * 1000:10.2f} ms"
    )
    print(
        f"rank {rank}, forward {forward_time / ITER_TIME * 1000:10.2f} ms, backward {backward_time / ITER_TIME * 1000:10.2f} ms"
    )
    print(
        f"rank {rank}, max reserved {max_reserved / 1024 / 1024 / 1024:8.2f} GiB, max allocated {max_allocated / 1024 / 1024 / 1024:8.2f} GiB"
    )
    dist.destroy_process_group()


train_convnext_example()