# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
# This file is a model zoo for testing torch.distributed.pipelining.
import torch
from torch.autograd import Function
from torch.distributed.pipelining import pipe_split, SplitPoint


class ExampleCode(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = torch.mm(x, self.mm_param0)
        x = torch.relu(x)
        # try passing a value that doesn't require_grad across skip boundaries
        a_constant = self.cval.clone()
        x = self.lin0(x)
        pipe_split()
        x = torch.relu(x) + a_constant
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x

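
# Illustrative sketch (not used by the tests): models with inline pipe_split()
# markers are typically split with the pipeline() tracer frontend. The helper
# name and the shapes below are made up for the example, and the pipeline()
# signature is assumed from recent torch.distributed.pipelining releases.
def _example_split_example_code(d_hid: int = 16, batch_size: int = 4):
    from torch.distributed.pipelining import pipeline

    mod = ExampleCode(d_hid)
    x = torch.randn(batch_size, d_hid)
    # Each pipe_split() call in forward() becomes a stage boundary, so this
    # produces a 2-stage pipe.
    return pipeline(mod, mb_args=(x,))
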

class ModelWithKwargs(torch.nn.Module):
    DEFAULT_DHID = 512
    DEFAULT_BATCH_SIZE = 256

    def __init__(self, d_hid: int = DEFAULT_DHID):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    # Note: the tensor default for `y` is evaluated once at class-definition
    # time; it exists so tests can exercise keyword-argument handling.
    def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x

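
# Illustrative sketch (not used by the tests): splitting a model that takes a
# keyword argument. mb_kwargs is assumed from the recent pipeline() API; the
# helper name is made up for the example.
def _example_split_model_with_kwargs(batch_size: int = 4):
    from torch.distributed.pipelining import pipeline

    mod = ModelWithKwargs()
    x = torch.randn(batch_size, ModelWithKwargs.DEFAULT_DHID)
    y = torch.randn(batch_size, ModelWithKwargs.DEFAULT_DHID)
    return pipeline(mod, mb_args=(x,), mb_kwargs={"y": y})
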

class ModelWithParamAlias(torch.nn.Module):
    default_dhid = 512
    default_batch_size = 256

    def __init__(self, d_hid: int = default_dhid):
        super().__init__()
        # Deliberate aliasing: both names refer to the same Parameter / Linear,
        # to exercise shared-parameter handling when the model is split.
        self.mm_param1 = self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = self.lin0 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x

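
# Illustrative check (not used by the tests): the aliased attributes really are
# the same objects, so the module holds a single shared Parameter and a single
# shared Linear. The helper name is made up for the example.
def _example_check_param_alias(d_hid: int = 16):
    mod = ModelWithParamAlias(d_hid)
    assert mod.mm_param0 is mod.mm_param1
    assert mod.lin0 is mod.lin1
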

# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


# Multi-MLP model
class MultiMLP(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList([MLPModule(d_hid) for _ in range(n_layers)])
        # For testing purposes only; in real use this should be defined by the user.
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

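
# Illustrative sketch (not used by the tests): splitting MultiMLP with its manual
# split_spec instead of inline pipe_split() calls. Passing split_spec to
# pipeline() is assumed from recent torch.distributed.pipelining; the helper
# name and shapes are made up for the example.
def _example_split_multi_mlp(d_hid: int = 16, n_layers: int = 4, batch_size: int = 8):
    from torch.distributed.pipelining import pipeline

    mod = MultiMLP(d_hid, n_layers=n_layers)
    x = torch.randn(batch_size, d_hid)
    # SplitPoint.BEGINNING at layers.1 .. layers.{n_layers-1} yields n_layers stages.
    return pipeline(mod, mb_args=(x,), split_spec=mod.split_spec)
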

class CustomLinearDx(Function):
    """Linear op whose backward computes only the input gradient (dx).

    grad_output and the layer input are cached on the owning module so the
    weight/bias gradients can be computed later via compute_dW().
    """

    @staticmethod
    def forward(ctx, input_val, weight, bias, module, layer_idx):
        ctx.save_for_backward(input_val, weight, bias)
        ctx.module = module
        ctx.layer_idx = layer_idx
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, bias = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        ctx.module.cached_context[ctx.layer_idx].append(grad_output.clone())
        ctx.module.cached_context[str(ctx.layer_idx) + "_input"].append(
            input_val.clone()
        )
        return grad_input, None, None, None, None


class CustomLinearDxDw(Function):
    """Linear op whose backward computes dx, dW and db in a single pass."""

    @staticmethod
    def forward(ctx, input_val, weight, bias):
        ctx.save_for_backward(input_val, weight, bias)
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, bias = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input_val)
        grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

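
# Illustrative check (not used by the tests): CustomLinearDxDw matches a standard
# linear layer in both forward output and input gradient. The helper name and
# shapes are made up for the example.
def _example_custom_linear_matches_linear(d_hid: int = 8, batch_size: int = 4):
    x = torch.randn(batch_size, d_hid, requires_grad=True)
    w = torch.randn(d_hid, d_hid, requires_grad=True)
    b = torch.randn(d_hid, requires_grad=True)
    out = CustomLinearDxDw.apply(x, w, b)
    ref = torch.nn.functional.linear(x, w, b)
    assert torch.allclose(out, ref)
    (gx,) = torch.autograd.grad(out.sum(), x)
    (gx_ref,) = torch.autograd.grad(ref.sum(), x)
    assert torch.allclose(gx, gx_ref)
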

class MLPModuleWithDw(torch.nn.Module):
    """MLP whose weight-gradient computation can be deferred.

    With use_custom_logic=False (the default), forward uses CustomLinearDxDw and
    backward produces all gradients as usual. After toggle(), forward uses
    CustomLinearDx instead: backward only produces input gradients, and
    compute_dW() later accumulates dW/db from the cached inputs and grad_outputs.
    """

    def __init__(self, d_hid: int):
        super().__init__()
        self.fc1_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc1_bias = torch.nn.Parameter(torch.randn(d_hid))
        self.fc2_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc2_bias = torch.nn.Parameter(torch.randn(d_hid))

        torch.nn.init.uniform_(self.fc1_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc1_bias, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_bias, -0.01, 0.01)

        self.cached_context = {}
        self.cached_context["fc1"] = []
        self.cached_context["fc2"] = []
        self.cached_context["fc1_input"] = []
        self.cached_context["fc2_input"] = []

        self.use_custom_logic = False

    def forward(self, x):
        if not self.use_custom_logic:
            self.hidden = CustomLinearDxDw.apply(x, self.fc1_weight, self.fc1_bias)
            self.hidden = torch.nn.functional.relu(self.hidden)
            output = CustomLinearDxDw.apply(self.hidden, self.fc2_weight, self.fc2_bias)
            return output

        self.hidden = CustomLinearDx.apply(
            x, self.fc1_weight, self.fc1_bias, self, "fc1"
        )
        self.hidden = torch.nn.functional.relu(self.hidden)
        output = CustomLinearDx.apply(
            self.hidden, self.fc2_weight, self.fc2_bias, self, "fc2"
        )
        return output

    def compute_dW(self):
        grad_output_fc1 = self.cached_context["fc1"].pop(0)
        grad_output_fc2 = self.cached_context["fc2"].pop(0)
        cached_input_fc1 = self.cached_context["fc1_input"].pop(0)
        cached_input_fc2 = self.cached_context["fc2_input"].pop(0)

        dW2 = grad_output_fc2.t().mm(cached_input_fc2)
        db2 = grad_output_fc2.sum(0)

        dW1 = grad_output_fc1.t().mm(cached_input_fc1)
        db1 = grad_output_fc1.sum(0)

        if self.fc1_weight.grad is not None:
            self.fc1_weight.grad += dW1
            self.fc1_bias.grad += db1
            self.fc2_weight.grad += dW2
            self.fc2_bias.grad += db2
        else:
            self.fc1_weight.grad = dW1
            self.fc1_bias.grad = db1
            self.fc2_weight.grad = dW2
            self.fc2_bias.grad = db2

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic


# Multi-MLP model with deferred dW
class MultiMLPWithDw(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [MLPModuleWithDw(d_hid) for _ in range(n_layers)]
        )
        # For testing purposes only; in real use this should be defined by the user.
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }
        self.use_custom_logic = False

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic
        for layer in self.layers:
            layer.toggle()

    def compute_dW(self):
        if not self.use_custom_logic:
            raise RuntimeError("Need to call toggle() to enable custom backward and dW")

        for i in reversed(range(len(self.layers))):
            self.layers[i].compute_dW()

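
# Illustrative sketch (not used by the tests): the deferred-dW flow. After
# toggle(), backward() only produces input gradients and caches what compute_dW()
# needs; compute_dW() then writes the weight/bias gradients. The helper name and
# shapes are made up for the example.
def _example_deferred_dw(d_hid: int = 8, batch_size: int = 4):
    mod = MultiMLPWithDw(d_hid, n_layers=2)
    mod.toggle()  # switch every layer to the dx-only backward
    x = torch.randn(batch_size, d_hid)
    out = mod(x)
    out.sum().backward()  # parameter .grad stays None; inputs/grad_outputs are cached
    mod.compute_dW()  # now dW/db are computed from the cache and written to .grad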