# mypy: allow-untyped-defs
"""
The following example demonstrates how to train a ConvNeXt model
with intermediate activations sharded across multiple GPUs via DTensor.

To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py
"""
import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)


WORLD_SIZE = 4
ITER_TIME = 20


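# LayerNorm applied over the channel dimension of an (N, C, H, W) tensor
# (channels-first), followed by a learnable per-channel scale and bias.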
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format=torch.contiguous_format):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in [torch.contiguous_format]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


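# ConvNeXt block: 7x7 depthwise conv -> channels-first LayerNorm -> 1x1 conv
# expansion (4x) -> GELU -> 1x1 conv projection, optionally scaled by a
# learnable per-channel gamma, with a residual connection around the block.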
class Block(nn.Module):
    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6, data_format=torch.contiguous_format)
        self.pwconv1 = nn.Conv2d(
            dim, 4 * dim, kernel_size=1, stride=1
        )  # pointwise/1x1 conv, equivalent to nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(
            4 * dim, dim, kernel_size=1, stride=1
        )  # pointwise/1x1 conv, equivalent to nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = nn.Identity()  # stochastic depth disabled in this example

    def forward(self, x):
        input_x = x
        x = self.dwconv(x)

        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = self.gamma * self.drop_path(x)
        x = input_x + x
        return x


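# Downsampling layer: a strided convolution (kernel_size == stride == down_scale)
# paired with a channels-first LayerNorm, applied either norm-first (between
# stages) or conv-first (for the stem).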
class DownSampling(nn.Module):
    def __init__(self, dim_in=3, dim_out=2, down_scale=4, norm_first=False):
        super().__init__()
        self.norm_first = norm_first
        if norm_first:
            self.norm = LayerNorm(dim_in, eps=1e-6, data_format=torch.contiguous_format)
            self.conv = nn.Conv2d(
                dim_in, dim_out, kernel_size=down_scale, stride=down_scale
            )
        else:
            self.conv = nn.Conv2d(
                dim_in, dim_out, kernel_size=down_scale, stride=down_scale
            )
            self.norm = LayerNorm(
                dim_out, eps=1e-6, data_format=torch.contiguous_format
            )

    def forward(self, x):
        if self.norm_first:
            return self.conv(self.norm(x))
        else:
            return self.norm(self.conv(x))


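# Deterministic initialization: set conv/linear weights to ones and biases to
# zeros, which keeps the example reproducible across runs and ranks.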
@torch.no_grad()
def init_weights(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


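# ConvNeXt backbone: a stem plus downsampling layers between stages of Blocks,
# followed by global average pooling over the spatial dimensions and a linear
# classification head.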
class ConvNeXt(nn.Module):
    def __init__(
        self,
        in_chans=3,
        num_classes=10,
        depths=[1, 1],  # noqa: B006
        dims=[2, 4],  # noqa: B006
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        head_init_scale=1.0,
    ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()
        stem = DownSampling(in_chans, dims[0], 4, norm_first=False)
        self.downsample_layers.append(stem)
        for i in range(len(dims) - 1):
            downsample_layer = DownSampling(dims[i], dims[i + 1], 2, norm_first=True)
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(len(dims)):
            stage = nn.Sequential(
                *[
                    Block(
                        dim=dims[i],
                        drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.head = nn.Linear(dims[-1], num_classes)
        self.apply(init_weights)

    def forward(self, x):
        for i in range(len(self.stages)):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        x = x.mean([-2, -1])
        x = self.head(x)
        return x


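# Partition function passed to distribute_module: replaces each parameter of a
# module with a replicated DTensor parameter, registers a gradient hook that
# redistributes gradients back to the replicated placement, and re-registers
# the parameter under a flattened name ("." replaced with "_").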
def _conv_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    for param_name, param in module.named_parameters():
        dist_spec = [Replicate()]
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, dist_spec)
        )
        dist_param.register_hook(lambda grad: grad.redistribute(placements=dist_spec))
        param_name = "_".join(param_name.split("."))
        module.register_parameter(param_name, dist_param)


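# End-to-end example: replicate the model parameters across a 1-D device mesh,
# shard the input batch along its width dimension so intermediate activations
# are sharded across GPUs, then measure forward/backward latency and peak
# memory over ITER_TIME iterations.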
def train_convnext_example():
    device_type = "cuda"
    world_size = int(os.environ["WORLD_SIZE"])
    # Build a 1-D device mesh spanning all ranks; this also initializes the
    # default process group if one is not already running.
    mesh = init_device_mesh(device_type, (world_size,))
    rank = mesh.get_rank()

    in_shape = [7, 3, 512, 1024]
    output_shape = [7, 1000]

    torch.manual_seed(12)
    model = ConvNeXt(
        depths=[3, 3, 27, 3],
        dims=[256, 512, 1024, 2048],
        drop_path_rate=0.0,
        num_classes=1000,
    ).to(device_type)
    # Convert every parameter into a replicated DTensor via _conv_fn.
    model = distribute_module(model, mesh, _conv_fn, input_fn=None, output_fn=None)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, amsgrad=False)

    x = torch.randn(*in_shape).to(device_type).requires_grad_()
    y_target = (
        torch.empty(output_shape[0], dtype=torch.long)
        .random_(output_shape[1])
        .to(device_type)
    )
    # Shard the input along the width dimension (dim 3) so intermediate
    # activations are sharded across GPUs; replicate the targets.
    x = distribute_tensor(x, mesh, [Shard(3)])
    y_target = distribute_tensor(y_target, mesh, [Replicate()])

    # warm up: one full training step before timing
    y = model(x)
    loss = criterion(y, y_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()

    # Time forward and backward passes separately over ITER_TIME iterations.
    forward_time = 0.0
    backward_time = 0.0
    start = time.time()
    for i in range(ITER_TIME):
        t1 = time.time()
        y = model(x)
        torch.cuda.synchronize()
        t2 = time.time()

        loss = criterion(y, y_target)
        optimizer.zero_grad()

        t3 = time.time()
        loss.backward()
        torch.cuda.synchronize()
        t4 = time.time()

        optimizer.step()

        forward_time += t2 - t1
        backward_time += t4 - t3
    torch.cuda.synchronize()
    end = time.time()
    # Report per-rank latency and peak CUDA memory usage.
    max_reserved = torch.cuda.max_memory_reserved()
    max_allocated = torch.cuda.max_memory_allocated()
    print(
        f"rank {rank}, {ITER_TIME} iterations, average latency {(end - start)/ITER_TIME*1000:10.2f} ms"
    )
    print(
        f"rank {rank}, forward {forward_time/ITER_TIME*1000:10.2f} ms, backward {backward_time/ITER_TIME*1000:10.2f} ms"
    )
    print(
        f"rank {rank}, max reserved {max_reserved/1024/1024/1024:8.2f} GiB, max allocated {max_allocated/1024/1024/1024:8.2f} GiB"
    )
    dist.destroy_process_group()


train_convnext_example()