#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This example shows how to use torch.func to do Model-Agnostic Meta-Learning
(MAML) for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400

This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py

Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""

import argparse
import functools
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from support.omniglot_loaders import OmniglotNShot

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.func import functional_call, grad, vmap


mpl.use("Agg")
plt.style.use("bmh")


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--n-way", "--n_way", type=int, help="n way", default=5)
    argparser.add_argument(
        "--k-spt", "--k_spt", type=int, help="k shot for support set", default=5
    )
    argparser.add_argument(
        "--k-qry", "--k_qry", type=int, help="k shot for query set", default=15
    )
    argparser.add_argument("--device", type=str, help="device", default="cuda")
    argparser.add_argument(
        "--task-num",
        "--task_num",
        type=int,
        help="meta batch size, namely task num",
        default=32,
    )
    argparser.add_argument("--seed", type=int, help="random seed", default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    device = args.device
    db = OmniglotNShot(
        "/tmp/omniglot-data",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network.
    inplace_relu = True
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(64, args.n_way),
    ).to(device)

    net.train()

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
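    # MAML optimizes at two levels: the inner loop (in `loss_for_task` and
    # `test` below) adapts a per-task copy of these parameters with plain
    # gradient steps, while this Adam optimizer performs only the outer
    # (meta) update on the shared initialization.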
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, net, device, meta_opt, epoch, log)
        test(db, net, device, epoch, log)
        plot(log)


# Adapts the model for n_inner_iter steps on the support set and returns
# the resulting loss and accuracy on the query set.
def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
    params = dict(net.named_parameters())
    buffers = dict(net.named_buffers())
    querysz = x_qry.size(0)

    def compute_loss(new_params, buffers, x, y):
        logits = functional_call(net, (new_params, buffers), x)
        loss = F.cross_entropy(logits, y)
        return loss

    new_params = params
    for _ in range(n_inner_iter):
        grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
        new_params = {k: new_params[k] - g * 1e-1 for k, g in grads.items()}

    # The final set of adapted parameters will induce some
    # final loss and accuracy on the query dataset.
    # These will be used to update the model's meta-parameters.
    qry_logits = functional_call(net, (new_params, buffers), x_qry)
    qry_loss = F.cross_entropy(qry_logits, y_qry)
    qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz

    return qry_loss, qry_acc


def train(db, net, device, meta_opt, epoch, log):
    params = dict(net.named_parameters())
    buffers = dict(net.named_buffers())
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()

        n_inner_iter = 5
        meta_opt.zero_grad()

        # In parallel, train one adapted model per task. There is a support
        # (x, y) and a query (x, y) for each task.
        compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter)
        qry_losses, qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry)

        # Compute the MAML loss by summing together the returned losses.
        qry_losses.sum().backward()

        meta_opt.step()
        qry_losses = qry_losses.detach().sum() / task_num
        qry_accs = 100.0 * qry_accs.sum() / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f"[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}"
            )

        log.append(
            {
                "epoch": i,
                "loss": qry_losses,
                "acc": qry_accs,
                "mode": "train",
                "time": time.time(),
            }
        )


def test(db, net, device, epoch, log):
    # Crucially, in our testing procedure here we do *not* fine-tune
    # the model during testing, for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    params = dict(net.named_parameters())
    buffers = dict(net.named_buffers())
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next("test")
        task_num, setsz, c_, h, w = x_spt.size()

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
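        # Unlike `train`, adaptation here runs one task at a time with
        # `torch.autograd.grad` rather than the vmapped `grad` transform,
        # and the adapted parameters are discarded after each task, so the
        # shared meta-parameters are never modified during testing.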
        n_inner_iter = 5

        for i in range(task_num):
            new_params = params
            for _ in range(n_inner_iter):
                spt_logits = functional_call(net, (new_params, buffers), x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                grads = torch.autograd.grad(spt_loss, new_params.values())
                new_params = {
                    k: new_params[k] - g * 1e-1 for k, g in zip(new_params, grads)
                }

            # The query loss and acc induced by these parameters.
            qry_logits = functional_call(net, (new_params, buffers), x_qry[i]).detach()
            qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction="none")
            qry_losses.append(qry_loss.detach())
            qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
    print(f"[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}")
    log.append(
        {
            "epoch": epoch + 1,
            "loss": qry_losses,
            "acc": qry_accs,
            "mode": "test",
            "time": time.time(),
        }
    )


def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df["mode"] == "train"]
    test_df = df[df["mode"] == "test"]
    ax.plot(train_df["epoch"], train_df["acc"], label="Train")
    ax.plot(test_df["epoch"], test_df["acc"], label="Test")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc="lower right")
    fig.tight_layout()
    fname = "maml-accs.png"
    print(f"--- Plotting accuracy to {fname}")
    fig.savefig(fname)
    plt.close(fig)


if __name__ == "__main__":
    main()
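
# Example invocation with the default hyperparameters defined above (the
# script filename here is a placeholder; pass --device cpu if CUDA is
# unavailable):
#   python maml_omniglot.py --n-way 5 --k-spt 5 --k-qry 15 --task-num 32 --seed 1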