import argparse
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad_and_value, stack_module_state, vmap


# Adapted from http://willwhitney.com/parallel-training-jax.html, which is a
# tutorial on Model Ensembling with JAX by Will Whitney.
#
# The original code comes with the following citation:
# @misc{Whitney2021Parallelizing,
#     author = {William F. Whitney},
#     title = { {Parallelizing neural networks on one GPU with JAX} },
#     year = {2021},
#     url = {http://willwhitney.com/parallel-training-jax.html},
# }

# GOAL: Demonstrate that it is possible to use eager-mode vmap
# to parallelize training over models.

parser = argparse.ArgumentParser(description="Functorch Ensembled Models")
parser.add_argument(
    "--device",
    type=str,
    default="cpu",
    help="CPU or GPU ID for this process (default: 'cpu')",
)
args = parser.parse_args()

DEVICE = args.device

# Step 1: Make some spirals


def make_spirals(n_samples, noise_std=0.0, rotations=1.0):
    ts = torch.linspace(0, 1, n_samples, device=DEVICE)
    rs = ts**0.5  # radius grows as the square root of t
    thetas = rs * rotations * 2 * math.pi
    # Each point is randomly assigned to one of two mirrored spirals (+1/-1),
    # and that assignment is also its class label.
    signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1
    labels = (signs > 0).to(torch.long).to(DEVICE)

    # Gaussian noise jitters the points off the ideal spirals.
    xs = (
        rs * signs * torch.cos(thetas)
        + torch.randn(n_samples, device=DEVICE) * noise_std
    )
    ys = (
        rs * signs * torch.sin(thetas)
        + torch.randn(n_samples, device=DEVICE) * noise_std
    )
    points = torch.stack([xs, ys], dim=1)
    return points, labels


points, labels = make_spirals(100, noise_std=0.05)


# Step 2: Define two-layer MLP and loss function
class MLPClassifier(nn.Module):
    def __init__(self, hidden_dim=32, n_classes=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes

        self.fc1 = nn.Linear(2, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.log_softmax(x, -1)
        return x


loss_fn = nn.NLLLoss()
model = MLPClassifier().to(DEVICE)


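# Step 3: Define a functional training step. functional_call runs the model
# with an explicit parameter dict, grad_and_value computes both the gradients
# and the loss, and a hand-rolled SGD update produces the new weights.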
def train_step_fn(weights, batch, targets, lr=0.2):
    def compute_loss(weights, batch, targets):
        output = functional_call(model, weights, batch)
        loss = loss_fn(output, targets)
        return loss

    grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)

    # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
    # so we are going to re-implement SGD here.
    new_weights = {}
    with torch.no_grad():
        for key in grad_weights:
            new_weights[key] = weights[key] - grad_weights[key] * lr

    return loss, new_weights


# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
def step4():
    global weights
    # Start from the model's current parameters and feed the updated weights
    # back in on every iteration so that training actually makes progress.
    weights = dict(model.named_parameters())
    for i in range(2000):
        loss, weights = train_step_fn(weights, points, labels)
        if i % 100 == 0:
            print(loss)


step4()

# Step 5: We're ready for multiple models. Let's define an init_fn that, given
# a number of models, returns the stacked weights of all of them.


def init_fn(num_models):
    models = [MLPClassifier().to(DEVICE) for _ in range(num_models)]
    # stack_module_state also returns stacked buffers; this MLP has none, so we
    # keep only the parameters.
    params, _ = stack_module_state(models)
    return params

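
# A quick illustrative check (an addition, not part of the original example):
# the stacked parameters gain a leading "model" dimension, which is what
# vmap's in_dims=0 maps over in step 6 below.
assert init_fn(num_models=2)["fc1.weight"].shape == (2, 32, 2)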

# Step 6: Now, can we try multiple models at the same time? The answer is:
# yes! `loss` now holds two values, one per model, and we can see that both
# keep decreasing.


def step6():
    # vmap over the stacked weights (dim 0) while broadcasting the same data
    # to every model (in_dims=None).
    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
    batched_weights = init_fn(num_models=2)
    for i in range(2000):
        loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels)
        if i % 200 == 0:
            print(loss)


step6()

# Step 7: Now, the flaw with step 6 is that we were training on the same exact
# data. This can lead to all of the models in the ensemble overfitting in the
# same way. The solution that http://willwhitney.com/parallel-training-jax.html
# applies is to randomly subset the data so that the models do not receive
# exactly the same data in each training step!
# Because the goal of this doc is to show that we can use eager-mode vmap to
# achieve similar things to the JAX tutorial, the rest of this is left as an
# exercise to the reader; one possible sketch is given below.

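# One possible sketch of that exercise (an addition, not from the original
# tutorial): draw an independent random minibatch for each model and vmap over
# the data as well as the stacked weights. The names `step7` and `batch_size`
# below are illustrative choices.
def step7():
    # Map over dim 0 of the stacked weights, the per-model batches, and the
    # per-model targets.
    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, 0, 0))
    num_models = 2
    batch_size = 32
    batched_weights = init_fn(num_models=num_models)
    for i in range(2000):
        # Each model trains on its own randomly sampled subset of the data.
        idx = torch.randint(
            0, points.shape[0], (num_models, batch_size), device=DEVICE
        )
        loss, batched_weights = parallel_train_step_fn(
            batched_weights, points[idx], labels[idx]
        )
        if i % 200 == 0:
            print(loss)


step7()
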
# In conclusion, to achieve what http://willwhitney.com/parallel-training-jax.html
# does, we used the following ingredients beyond the standard nn.Module and
# torch.optim workflow:
# 1. An NN module functional API that turns a module into a (state, stateless_fn)
#    pair (functional_call and stack_module_state)
# 2. Functional optimizers (re-implemented by hand as plain SGD above)
# 3. A "functional" grad API (grad_and_value, which effectively wraps autograd.grad)
# 4. Composability between the functional grad API and torch.vmap
152