# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
# This file is a model zoo for testing torch.distributed.pipelining.
import torch
from torch.autograd import Function
from torch.distributed.pipelining import pipe_split, SplitPoint


class ExampleCode(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = torch.mm(x, self.mm_param0)
        x = torch.relu(x)
        # Try passing a value that does not require grad across skip boundaries
        a_constant = self.cval.clone()
        x = self.lin0(x)
        pipe_split()
        x = torch.relu(x) + a_constant
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


class ModelWithKwargs(torch.nn.Module):
    DEFAULT_DHID = 512
    DEFAULT_BATCH_SIZE = 256

    def __init__(self, d_hid: int = DEFAULT_DHID):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


class ModelWithParamAlias(torch.nn.Module):
    default_dhid = 512
    default_batch_size = 256

    def __init__(self, d_hid: int = default_dhid):
        super().__init__()
        # Intentional aliases: mm_param1/mm_param0 and lin1/lin0 refer to the
        # same parameter and the same submodule, respectively.
        self.mm_param1 = self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = self.lin0 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


# Multi-MLP model
class MultiMLP(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList([MLPModule(d_hid) for _ in range(n_layers)])
        # For testing purposes only; in practice this should be defined by the user
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# Custom linear op whose backward returns only dX; the grad_output and input
# needed for dW/db are cached on the owning module for a later compute_dW().
class CustomLinearDx(Function):
    @staticmethod
    def forward(ctx, input_val, weight, bias, module, layer_idx):
        ctx.save_for_backward(input_val, weight, bias)
        ctx.module = module
        ctx.layer_idx = layer_idx
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, bias = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        ctx.module.cached_context[ctx.layer_idx].append(grad_output.clone())
        ctx.module.cached_context[str(ctx.layer_idx) + "_input"].append(
            input_val.clone()
        )
        return grad_input, None, None, None, None


# Custom linear op whose backward returns dX, dW, and db in a single pass.
class CustomLinearDxDw(Function):
    @staticmethod
    def forward(ctx, input_val, weight, bias):
        ctx.save_for_backward(input_val, weight, bias)
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, bias = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input_val)
        grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


# MLP layer that can defer its weight-gradient computation to compute_dW()
class MLPModuleWithDw(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.fc1_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc1_bias = torch.nn.Parameter(torch.randn(d_hid))
        self.fc2_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc2_bias = torch.nn.Parameter(torch.randn(d_hid))

        torch.nn.init.uniform_(self.fc1_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc1_bias, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_bias, -0.01, 0.01)

        self.cached_context = {}
        self.cached_context["fc1"] = []
        self.cached_context["fc2"] = []
        self.cached_context["fc1_input"] = []
        self.cached_context["fc2_input"] = []

        self.use_custom_logic = False

    def forward(self, x):
        if not self.use_custom_logic:
            self.hidden = CustomLinearDxDw.apply(x, self.fc1_weight, self.fc1_bias)
            self.hidden = torch.nn.functional.relu(self.hidden)
            output = CustomLinearDxDw.apply(self.hidden, self.fc2_weight, self.fc2_bias)
            return output

        self.hidden = CustomLinearDx.apply(
            x, self.fc1_weight, self.fc1_bias, self, "fc1"
        )
        self.hidden = torch.nn.functional.relu(self.hidden)
        output = CustomLinearDx.apply(
            self.hidden, self.fc2_weight, self.fc2_bias, self, "fc2"
        )
        return output

    def compute_dW(self):
        grad_output_fc1 = self.cached_context["fc1"].pop(0)
        grad_output_fc2 = self.cached_context["fc2"].pop(0)
        cached_input_fc1 = self.cached_context["fc1_input"].pop(0)
        cached_input_fc2 = self.cached_context["fc2_input"].pop(0)

        dW2 = grad_output_fc2.t().mm(cached_input_fc2)
        db2 = grad_output_fc2.sum(0)

        dW1 = grad_output_fc1.t().mm(cached_input_fc1)
        db1 = grad_output_fc1.sum(0)

        if self.fc1_weight.grad is not None:
            self.fc1_weight.grad += dW1
            self.fc1_bias.grad += db1
            self.fc2_weight.grad += dW2
            self.fc2_bias.grad += db2
        else:
            self.fc1_weight.grad = dW1
            self.fc1_bias.grad = db1
            self.fc2_weight.grad = dW2
            self.fc2_bias.grad = db2

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic


# Multi-MLP model with dW
class MultiMLPWithDw(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [MLPModuleWithDw(d_hid) for _ in range(n_layers)]
        )
        # For testing purposes only; in practice this should be defined by the user
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }
        self.use_custom_logic = False

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic
        for layer in self.layers:
            layer.toggle()

    def compute_dW(self):
        if not self.use_custom_logic:
            raise RuntimeError("Need to call toggle() to enable custom backward and dW")

        for i in reversed(range(len(self.layers))):
            self.layers[i].compute_dW()
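

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the model zoo itself): one way a model
# above could be split into pipeline stages. It assumes the tracer frontend
# `pipeline(module, mb_args, ..., split_spec=...)` exported by
# torch.distributed.pipelining in recent PyTorch releases; verify the API
# against the version under test before relying on it. The helper name and
# the toy sizes below are illustrative only.
def _sketch_split_multi_mlp(d_hid: int = 16, n_layers: int = 2, batch_size: int = 4):
    # Assumed import: only present when the pipelining frontend is available.
    from torch.distributed.pipelining import pipeline

    model = MultiMLP(d_hid, n_layers=n_layers)
    x = torch.randn(batch_size, d_hid)
    # MultiMLP carries a split_spec (one split before each layers.{i}, i >= 1);
    # ExampleCode and ModelWithKwargs instead rely on the explicit pipe_split()
    # calls inside their forward().
    pipe = pipeline(model, mb_args=(x,), split_spec=model.split_spec)
    return pipe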
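

# Hedged sanity-check sketch (also not part of the zoo): the deferred dW path
# of MLPModuleWithDw (CustomLinearDx + compute_dW) should reproduce the
# gradients of the default CustomLinearDxDw path. This helper is an
# illustrative assumption about how one might exercise that invariant; it
# uses only standard autograd and torch.testing.
def _sketch_check_deferred_dw(d_hid: int = 8, batch_size: int = 4):
    torch.manual_seed(0)
    ref = MLPModuleWithDw(d_hid)
    dut = MLPModuleWithDw(d_hid)
    # Start both modules from identical weights.
    with torch.no_grad():
        for p_ref, p_dut in zip(ref.parameters(), dut.parameters()):
            p_dut.copy_(p_ref)

    x = torch.randn(batch_size, d_hid)

    # Reference path: dX and dW are produced together during backward.
    ref(x).sum().backward()

    # Deferred path: backward produces dX only; dW/db are filled in afterwards.
    dut.toggle()
    dut(x).sum().backward()
    dut.compute_dW()

    for p_ref, p_dut in zip(ref.parameters(), dut.parameters()):
        torch.testing.assert_close(p_ref.grad, p_dut.grad)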