import argparse
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad_and_value, stack_module_state, vmap


# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
# tutorial on Model Ensembling with JAX by Will Whitney.
#
# The original code comes with the following citation:
# @misc{Whitney2021Parallelizing,
#     author = {William F. Whitney},
#     title = { {Parallelizing neural networks on one GPU with JAX} },
#     year = {2021},
#     url = {http://willwhitney.com/parallel-training-jax.html},
# }

# GOAL: Demonstrate that it is possible to use eager-mode vmap
# to parallelize training over models.

parser = argparse.ArgumentParser(description="Functorch Ensembled Models")
parser.add_argument(
    "--device",
    type=str,
    default="cpu",
    help="CPU or GPU ID for this process (default: 'cpu')",
)
args = parser.parse_args()

DEVICE = args.device

# Step 1: Make some spirals


def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
    ts = torch.linspace(0, 1, n_samples, device=DEVICE)
    rs = ts**0.5
    thetas = rs * rotations * 2 * math.pi
    signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1
    labels = (signs > 0).to(torch.long).to(DEVICE)

    xs = (
        rs * signs * torch.cos(thetas)
        + torch.randn(n_samples, device=DEVICE) * noise_std
    )
    ys = (
        rs * signs * torch.sin(thetas)
        + torch.randn(n_samples, device=DEVICE) * noise_std
    )
    points = torch.stack([xs, ys], dim=1)
    return points, labels


points, labels = make_spirals(100, noise_std=0.05)


# Step 2: Define a two-layer MLP and loss function
class MLPClassifier(nn.Module):
    def __init__(self, hidden_dim=32, n_classes=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes

        self.fc1 = nn.Linear(2, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.log_softmax(x, -1)
        return x


loss_fn = nn.NLLLoss()
model = MLPClassifier().to(DEVICE)


# Step 3: Define a functional training step. functional_call runs the module
# with an explicit weights dict, and grad_and_value returns both the gradients
# and the loss, so one SGD update becomes a pure function of the weights.
def train_step_fn(weights, batch, targets, lr=0.2):
    def compute_loss(weights, batch, targets):
        output = functional_call(model, weights, batch)
        loss = loss_fn(output, targets)
        return loss

    grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)

    # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
    # so we are going to re-implement SGD here.
    new_weights = {}
    with torch.no_grad():
        for key in grad_weights:
            new_weights[key] = weights[key] - grad_weights[key] * lr

    return loss, new_weights


# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
def step4():
    global weights
    weights = dict(model.named_parameters())
    for i in range(2000):
        loss, weights = train_step_fn(weights, points, labels)
        if i % 100 == 0:
            print(loss)


step4()

# Step 5: We're ready for multiple models. Let's define an init_fn
# that, given a number of models, returns to us all of the weights.


def init_fn(num_models):
    models = [MLPClassifier().to(DEVICE) for _ in range(num_models)]
    params, _ = stack_module_state(models)
    return params
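

# A quick, optional sanity check (added for illustration; `show_stacked_shapes`
# is not part of the original example): stack_module_state stacks each parameter
# across models, so every entry in the returned dict gains a leading dimension
# of size num_models, e.g. fc1.weight goes from (hidden_dim, 2) per model to
# (num_models, hidden_dim, 2) once stacked.
def show_stacked_shapes(num_models=3):
    stacked = init_fn(num_models)
    for name, p in stacked.items():
        # Every stacked parameter is shaped (num_models, *original_shape).
        assert p.shape[0] == num_models
        print(name, tuple(p.shape))


# Uncomment to run the sanity check:
# show_stacked_shapes()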


# Step 6: Now, can we try multiple models at the same time?
# The answer is: yes! `loss` now holds one value per model (vmap stacks the
# per-model losses into a single tensor), and we can see that the values keep
# on decreasing.


def step6():
    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
    batched_weights = init_fn(num_models=2)
    for i in range(2000):
        loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels)
        if i % 200 == 0:
            print(loss)


step6()

# Step 7: Now, the flaw with step 6 is that we were training on the exact same
# data. This can lead to all of the models in the ensemble overfitting in the
# same way. The solution that http://willwhitney.com/parallel-training-jax.html
# applies is to randomly subset the data so that the models do not receive
# exactly the same data in each training step!
# Because the goal of this doc is to show that we can use eager-mode vmap to
# achieve similar things as JAX, the rest of this is left as an exercise to the
# reader (one possible sketch appears after the conclusion below).

# In conclusion, to achieve what http://willwhitney.com/parallel-training-jax.html
# does, we used the following items (all but the functional optimizer are now
# provided by torch.func):
# 1. An NN module functional API that turns a module into a (state, stateless_fn)
#    pair (functional_call, stack_module_state)
# 2. Functional optimizers (we re-implemented SGD by hand above)
# 3. A "functional" grad API (grad_and_value, which effectively wraps autograd.grad)
# 4. Composability between the functional grad API and torch.vmap.
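

# Below is one possible sketch of Step 7, added for illustration only; the
# original leaves this as an exercise. Following the JAX tutorial's idea, each
# model draws its own random minibatch per step, so the ensemble members do not
# see exactly the same data. The name `step7_sketch` and the batch_size choice
# are ours, not part of the original example.
def step7_sketch(num_models=2, batch_size=32, n_steps=2000):
    # vmap over the data arguments too (in_dims=(0, 0, 0)): each model gets its
    # own slice of the weights, its own batch, and its own targets.
    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, 0, 0))
    batched_weights = init_fn(num_models=num_models)
    n = points.shape[0]
    for i in range(n_steps):
        # Independent random subset of the data for each model.
        idx = torch.randint(0, n, (num_models, batch_size), device=DEVICE)
        batch = points[idx]  # (num_models, batch_size, 2)
        targets = labels[idx]  # (num_models, batch_size)
        loss, batched_weights = parallel_train_step_fn(batched_weights, batch, targets)
        if i % 200 == 0:
            print(loss)


# Uncomment to try the sketch:
# step7_sketch()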