xref: /aosp_15_r20/external/pytorch/functorch/examples/maml_omniglot/maml-omniglot-transforms.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2#
3# Copyright (c) Facebook, Inc. and its affiliates.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""
This example shows how to use torch.func transforms (functional_call, grad, vmap) to do Model Agnostic Meta Learning (MAML)
19for few-shot Omniglot classification.
20For more details see the original MAML paper:
21https://arxiv.org/abs/1703.03400
22
23This code has been modified from Jackie Loong's PyTorch MAML implementation:
24https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
25
26Our MAML++ fork and experiments are available at:
27https://github.com/bamos/HowToTrainYourMAMLPytorch
28"""
29
30import argparse
31import functools
32import time
33
34import matplotlib as mpl
35import matplotlib.pyplot as plt
36import numpy as np
37import pandas as pd
38from support.omniglot_loaders import OmniglotNShot
39
40import torch
41import torch.nn.functional as F
42import torch.optim as optim
43from torch import nn
44from torch.func import functional_call, grad, vmap
45
46
# Select the non-interactive "Agg" backend so figures can be saved to files
# without a display, and apply the "bmh" style sheet to every plot.
mpl.use(backend="Agg")
plt.style.use(style="bmh")
49
50
def main():
    """Parse CLI options, build the Omniglot loader and model, and run MAML.

    Every epoch runs one meta-training pass (``train``), one evaluation pass
    (``test``), and re-renders the accuracy plot (``plot``).
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
    argparser.add_argument(
        "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
    )
    argparser.add_argument(
        "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
    )
    argparser.add_argument("--device", type=str, help="device", default="cuda")
    argparser.add_argument(
        "--task-num",
        "--task_num",
        type=int,
        help="meta batch size, namely task num",
        default=32,
    )
    argparser.add_argument("--seed", type=int, help="random seed", default=1)
    # The epoch count used to be hard-coded; the default preserves that value.
    argparser.add_argument(
        "--epochs", type=int, help="number of training epochs", default=100
    )
    args = argparser.parse_args()

    # Seed the torch (CPU and, if present, CUDA) and NumPy RNGs.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    device = args.device
    db = OmniglotNShot(
        "/tmp/omniglot-data",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network.
    # track_running_stats=False makes every BatchNorm layer normalize with the
    # statistics of the current batch instead of running averages.
    # With 28x28 inputs, the three conv(3) + maxpool(2) stages shrink the
    # spatial size 28 -> 13 -> 5 -> 1, so Flatten yields 64 features.
    inplace_relu = True
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(64, args.n_way),
    ).to(device)

    net.train()

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)

    log = []
    for epoch in range(args.epochs):
        train(db, net, device, meta_opt, epoch, log)
        test(db, net, device, epoch, log)
        plot(log)
118
119
# Adapts a copy of the model's parameters to one task using its support set,
# then returns the loss/accuracy of the adapted parameters on its query set.
def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry, inner_lr=1e-1):
    """Adapt ``net`` to a single task and evaluate it on that task's query set.

    Starting from the module's current parameters, take ``n_inner_iter`` plain
    SGD steps of size ``inner_lr`` on the support-set cross-entropy, then run
    the adapted parameters on the query set. The module itself is never
    modified: adaptation happens on a dict of tensors fed to
    ``functional_call``. ``train`` backpropagates through the returned loss
    to update the module's (meta-)parameters.

    Args:
        net: ``nn.Module`` whose current parameters are the starting point.
        n_inner_iter: number of inner-loop SGD steps on the support set.
        x_spt, y_spt: support inputs and integer class labels.
        x_qry, y_qry: query inputs and integer class labels.
        inner_lr: inner-loop SGD step size (defaults to the previously
            hard-coded 0.1).

    Returns:
        ``(qry_loss, qry_acc)``: mean query cross-entropy and the fraction of
        query examples whose arg-max prediction matches the label.
    """
    params = dict(net.named_parameters())
    buffers = dict(net.named_buffers())
    querysz = x_qry.size(0)

    def compute_loss(new_params, buffers, x, y):
        logits = functional_call(net, (new_params, buffers), x)
        return F.cross_entropy(logits, y)

    new_params = params
    for _ in range(n_inner_iter):
        # Gradient of the support loss w.r.t. the current adapted parameters.
        grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
        # One plain SGD step per parameter tensor.
        new_params = {k: new_params[k] - g * inner_lr for k, g in grads.items()}

    # The final set of adapted parameters will induce some
    # final loss and accuracy on the query dataset.
    # These will be used to update the model's meta-parameters.
    qry_logits = functional_call(net, (new_params, buffers), x_qry)
    qry_loss = F.cross_entropy(qry_logits, y_qry)
    qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz

    return qry_loss, qry_acc
145
146
def train(db, net, device, meta_opt, epoch, log):
    """Run one epoch of MAML meta-training and append per-batch stats to ``log``.

    For each meta-batch drawn from ``db``, every task is adapted to its support
    set and scored on its query set by ``loss_for_task`` (vectorized across
    tasks with ``vmap``). The per-task query losses are summed and
    backpropagated, and ``meta_opt`` updates ``net``'s parameters.

    Args:
        db: Omniglot loader; this function reads ``db.x_train`` and
            ``db.batchsz`` and calls ``db.next()``, which returns
            ``(x_spt, y_spt, x_qry, y_qry)`` with tasks on the first axis.
        net: module whose parameters are meta-learned.
        device: unused here (``main`` hands the device to the loader); kept
            so existing callers keep working.
        meta_opt: optimizer over ``net``'s parameters.
        epoch: zero-based epoch index, used for the fractional epoch in logs.
        log: list that receives one dict per batch with ``"epoch"``,
            ``"loss"``, ``"acc"`` (percent), ``"mode"`` and ``"time"`` keys;
            loss and accuracy are stored as Python floats.
    """
    n_train_iter = db.x_train.shape[0] // db.batchsz

    # Loop-invariant: number of inner-loop SGD steps per task, and the
    # per-task loss with everything but the task data already bound.
    n_inner_iter = 5
    compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter)

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()
        task_num = x_spt.size(0)

        meta_opt.zero_grad()

        # In parallel, trains one model per task. There is a support (x, y)
        # for each task and a query (x, y) for each task.
        qry_losses, qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry)

        # Compute the maml loss by summing together the returned losses.
        qry_losses.sum().backward()

        meta_opt.step()

        # Convert to Python floats. .item() waits for the device, so the timing
        # below covers the finished step, and the log does not hold on to
        # device tensors (test() logs plain floats as well).
        qry_losses = (qry_losses.detach().sum() / task_num).item()
        qry_accs = (100.0 * qry_accs.sum() / task_num).item()
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
            )

        log.append(
            {
                "epoch": i,
                "loss": qry_losses,
                "acc": qry_accs,
                "mode": "train",
                "time": time.time(),
            }
        )
189
190
def test(db, net, device, epoch, log):
    """Evaluate the meta-learned initialization on the test split.

    Each test task is adapted with the same inner loop used in training
    (``loss_for_task``) and scored on its query set. The mean query loss and
    accuracy (percent) over the whole split are printed and appended to
    ``log``. ``net``'s parameters are not updated.

    Args:
        db: Omniglot loader; this function reads ``db.x_test`` and
            ``db.batchsz`` and calls ``db.next("test")``, which returns
            ``(x_spt, y_spt, x_qry, y_qry)`` with tasks on the first axis.
        net: module holding the meta-learned parameters.
        device: unused here (``main`` hands the device to the loader); kept
            so existing callers keep working.
        epoch: zero-based epoch index; the entry is logged as ``epoch + 1``.
        log: list that receives one dict with ``"epoch"``, ``"loss"``,
            ``"acc"``, ``"mode"`` and ``"time"`` keys.
    """
    # At test time we only run the same short inner-loop adaptation used in
    # training (n_inner_iter SGD steps on each task's support set) and never
    # update the meta-parameters. Most research papers using MAML for this
    # task do an extra stage of fine-tuning here that should be added if you
    # are adapting this code for research.
    n_test_iter = db.x_test.shape[0] // db.batchsz

    # Reuse the training inner loop instead of duplicating it here.
    n_inner_iter = 5
    compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter)

    qry_losses = []
    qry_accs = []

    # No meta-gradient is needed. torch.func.grad (used inside loss_for_task)
    # ignores this outer no_grad, so per-task adaptation still works, but no
    # autograd graph back to net's parameters is recorded.
    with torch.no_grad():
        for _ in range(n_test_iter):
            x_spt, y_spt, x_qry, y_qry = db.next("test")
            task_losses, task_accs = vmap(compute_loss_for_task)(
                x_spt, y_spt, x_qry, y_qry
            )
            qry_losses.append(task_losses)
            qry_accs.append(task_accs)

    # All tasks in a batch share one query tensor and therefore one query-set
    # size, so the mean of per-task means is the mean over all query examples.
    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
    print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
    log.append(
        {
            "epoch": epoch + 1,
            "loss": qry_losses,
            "acc": qry_accs,
            "mode": "test",
            "time": time.time(),
        }
    )
240
241
def plot(log, fname="maml-accs.png"):
    """Plot train and test accuracy from ``log`` and save the figure.

    Args:
        log: list of dicts with at least ``"epoch"``, ``"acc"`` and ``"mode"``
            (``"train"`` or ``"test"``) keys, as appended by ``train``/``test``.
        fname: output image path; defaults to the previously hard-coded
            ``"maml-accs.png"``.
    """
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    if not log:
        # Nothing logged yet; an empty DataFrame would have no "mode" column.
        return
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df["mode"] == "train"]
    test_df = df[df["mode"] == "test"]
    ax.plot(train_df["epoch"], train_df["acc"], label="Train")
    ax.plot(test_df["epoch"], test_df["acc"], label="Test")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    # Fixed y-range: accuracies below 70% fall outside the plotted area.
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc="lower right")
    fig.tight_layout()
    print(f"--- Plotting accuracy to {fname}")
    fig.savefig(fname)
    # main() re-plots every epoch; close the figure so open figures don't pile up.
    plt.close(fig)
261
262
# Run the training script only when executed directly, not when imported.
if __name__ == "__main__":
    main()
265