# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
# This file is a model zoo for testing torch.distributed.pipelining.
import torch
from torch.autograd import Function
from torch.distributed.pipelining import pipe_split, SplitPoint


class ExampleCode(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = torch.mm(x, self.mm_param0)
        x = torch.relu(x)
        # try passing a value that doesn't require_grad across skip boundaries
        a_constant = self.cval.clone()
        x = self.lin0(x)
        pipe_split()
        x = torch.relu(x) + a_constant
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


class ModelWithKwargs(torch.nn.Module):
    DEFAULT_DHID = 512
    DEFAULT_BATCH_SIZE = 256

    def __init__(self, d_hid: int = DEFAULT_DHID):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


class ModelWithParamAlias(torch.nn.Module):
    default_dhid = 512
    default_batch_size = 256

    def __init__(self, d_hid: int = default_dhid):
        super().__init__()
        self.mm_param1 = self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = self.lin0 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


# Multi-MLP model
class MultiMLP(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList([MLPModule(d_hid) for _ in range(n_layers)])
        # For testing purposes only; this should be defined by the user
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
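
# Illustrative sketch, not used by the models above: how a zoo model such as
# MultiMLP would typically be cut into stages with the tracer frontend
# `torch.distributed.pipelining.pipeline()`. The helper name and the sizes
# below are arbitrary example values (assumptions, not fixtures from this
# file), and the helper is never called here.
def _example_pipeline_split(d_hid: int = 16, n_layers: int = 2, batch_size: int = 4):
    from torch.distributed.pipelining import pipeline

    model = MultiMLP(d_hid, n_layers=n_layers)
    x = torch.randn(batch_size, d_hid)  # one example microbatch
    # `split_spec` marks "layers.1" .. "layers.{n_layers-1}" as stage
    # boundaries, so the traced program is cut into `n_layers` stages.
    pipe = pipeline(model, mb_args=(x,), split_spec=model.split_spec)
    # Each stage's submodule can then be retrieved (e.g. via
    # pipe.get_stage_module(0)) and run as a pipeline stage on its own rank.
    return pipe
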
# Custom linear op whose backward computes only the input gradient (dx); the
# grad_output and input needed for the weight gradient are cached on the
# owning module so dW can be computed later via compute_dW().
class CustomLinearDx(Function):
    @staticmethod
    def forward(ctx, input_val, weight, bias, module, layer_idx):
        ctx.save_for_backward(input_val, weight, bias)
        ctx.module = module
        ctx.layer_idx = layer_idx
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, bias = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        ctx.module.cached_context[ctx.layer_idx].append(grad_output.clone())
        ctx.module.cached_context[str(ctx.layer_idx) + "_input"].append(
            input_val.clone()
        )
        return grad_input, None, None, None, None


# Custom linear op whose backward computes input, weight, and bias gradients
# in one pass (the reference behavior).
class CustomLinearDxDw(Function):
    @staticmethod
    def forward(ctx, input_val, weight, bias):
        ctx.save_for_backward(input_val, weight, bias)
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, bias = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input_val)
        grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


class MLPModuleWithDw(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.fc1_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc1_bias = torch.nn.Parameter(torch.randn(d_hid))
        self.fc2_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc2_bias = torch.nn.Parameter(torch.randn(d_hid))

        torch.nn.init.uniform_(self.fc1_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc1_bias, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_bias, -0.01, 0.01)

        self.cached_context = {}
        self.cached_context["fc1"] = []
        self.cached_context["fc2"] = []
        self.cached_context["fc1_input"] = []
        self.cached_context["fc2_input"] = []

        self.use_custom_logic = False

    def forward(self, x):
        if not self.use_custom_logic:
            self.hidden = CustomLinearDxDw.apply(x, self.fc1_weight, self.fc1_bias)
            self.hidden = torch.nn.functional.relu(self.hidden)
            output = CustomLinearDxDw.apply(self.hidden, self.fc2_weight, self.fc2_bias)
            return output

        self.hidden = CustomLinearDx.apply(
            x, self.fc1_weight, self.fc1_bias, self, "fc1"
        )
        self.hidden = torch.nn.functional.relu(self.hidden)
        output = CustomLinearDx.apply(
            self.hidden, self.fc2_weight, self.fc2_bias, self, "fc2"
        )
        return output

    def compute_dW(self):
        # Compute the deferred weight/bias gradients from the oldest cached
        # inputs and output gradients, then accumulate them into .grad.
        grad_output_fc1 = self.cached_context["fc1"].pop(0)
        grad_output_fc2 = self.cached_context["fc2"].pop(0)
        cached_input_fc1 = self.cached_context["fc1_input"].pop(0)
        cached_input_fc2 = self.cached_context["fc2_input"].pop(0)

        dW2 = grad_output_fc2.t().mm(cached_input_fc2)
        db2 = grad_output_fc2.sum(0)

        dW1 = grad_output_fc1.t().mm(cached_input_fc1)
        db1 = grad_output_fc1.sum(0)

        if self.fc1_weight.grad is not None:
            self.fc1_weight.grad += dW1
            self.fc1_bias.grad += db1
            self.fc2_weight.grad += dW2
            self.fc2_bias.grad += db2
        else:
            self.fc1_weight.grad = dW1
            self.fc1_bias.grad = db1
            self.fc2_weight.grad = dW2
            self.fc2_bias.grad = db2

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic
# Multi-MLP model with deferred dW computation
class MultiMLPWithDw(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [MLPModuleWithDw(d_hid) for _ in range(n_layers)]
        )
        # For testing purposes only; this should be defined by the user
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }
        self.use_custom_logic = False

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic
        for layer in self.layers:
            layer.toggle()

    def compute_dW(self):
        if not self.use_custom_logic:
            raise RuntimeError("Need to call toggle() to enable custom backward and dW")

        for i in reversed(range(len(self.layers))):
            self.layers[i].compute_dW()
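

# Illustrative sketch, not used by the models above: exercise the deferred
# weight-gradient path of MultiMLPWithDw and compare it against the fused
# dx/dW path. The helper name and the sizes are arbitrary example values
# (assumptions, not fixtures from this file), and the helper is never called
# here.
def _example_deferred_dw(d_hid: int = 8, batch_size: int = 4):
    torch.manual_seed(0)
    x = torch.randn(batch_size, d_hid)

    # Reference path: CustomLinearDxDw computes dx and dW in one backward.
    ref = MultiMLPWithDw(d_hid)
    ref(x).sum().backward()

    # Deferred path: backward() computes only dx and caches inputs / output
    # grads; compute_dW() fills in the weight and bias grads afterwards,
    # similar to the W-pass ordering a zero-bubble-style schedule would use.
    deferred = MultiMLPWithDw(d_hid)
    deferred.load_state_dict(ref.state_dict())
    deferred.toggle()  # switch every layer to CustomLinearDx
    deferred(x).sum().backward()
    deferred.compute_dW()

    # Both paths implement the same math, so the gradients should match.
    for p_ref, p_def in zip(ref.parameters(), deferred.parameters()):
        assert torch.allclose(p_ref.grad, p_def.grad)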