1# Copyright (c) Facebook, Inc. and its affiliates. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation: 16# https://github.com/dragen1860/MAML-Pytorch 17# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py 18# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py 19 20import errno 21import os 22import os.path 23 24import numpy as np 25import torchvision.transforms as transforms 26from PIL import Image 27 28import torch 29import torch.utils.data as data 30 31 32class Omniglot(data.Dataset): 33 urls = [ 34 "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip", 35 "https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip", 36 ] 37 raw_folder = "raw" 38 processed_folder = "processed" 39 training_file = "training.pt" 40 test_file = "test.pt" 41 42 """ 43 The items are (filename,category). The index of all the categories can be found in self.idx_classes 44 Args: 45 - root: the directory where the dataset will be stored 46 - transform: how to transform the input 47 - target_transform: how to transform the target 48 - download: need to download the dataset 49 """ 50 51 def __init__(self, root, transform=None, target_transform=None, download=False): 52 self.root = root 53 self.transform = transform 54 self.target_transform = target_transform 55 56 if not self._check_exists(): 57 if download: 58 self.download() 59 else: 60 raise RuntimeError( 61 "Dataset not found." + " You can use download=True to download it" 62 ) 63 64 self.all_items = find_classes(os.path.join(self.root, self.processed_folder)) 65 self.idx_classes = index_classes(self.all_items) 66 67 def __getitem__(self, index): 68 filename = self.all_items[index][0] 69 img = str.join("/", [self.all_items[index][2], filename]) 70 71 target = self.idx_classes[self.all_items[index][1]] 72 if self.transform is not None: 73 img = self.transform(img) 74 if self.target_transform is not None: 75 target = self.target_transform(target) 76 77 return img, target 78 79 def __len__(self): 80 return len(self.all_items) 81 82 def _check_exists(self): 83 return os.path.exists( 84 os.path.join(self.root, self.processed_folder, "images_evaluation") 85 ) and os.path.exists( 86 os.path.join(self.root, self.processed_folder, "images_background") 87 ) 88 89 def download(self): 90 import urllib 91 import zipfile 92 93 if self._check_exists(): 94 return 95 96 # download files 97 try: 98 os.makedirs(os.path.join(self.root, self.raw_folder)) 99 os.makedirs(os.path.join(self.root, self.processed_folder)) 100 except OSError as e: 101 if e.errno == errno.EEXIST: 102 pass 103 else: 104 raise 105 106 for url in self.urls: 107 print("== Downloading " + url) 108 data = urllib.request.urlopen(url) 109 filename = url.rpartition("/")[2] 110 file_path = os.path.join(self.root, self.raw_folder, filename) 111 with open(file_path, "wb") as f: 112 f.write(data.read()) 113 file_processed = os.path.join(self.root, self.processed_folder) 114 print("== Unzip from " + file_path + " to " + file_processed) 115 zip_ref = zipfile.ZipFile(file_path, "r") 116 zip_ref.extractall(file_processed) 117 zip_ref.close() 118 print("Download finished.") 119 120 121def find_classes(root_dir): 122 retour = [] 123 for root, dirs, files in os.walk(root_dir): 124 for f in files: 125 if f.endswith("png"): 126 r = root.split("/") 127 lr = len(r) 128 retour.append((f, r[lr - 2] + "/" + r[lr - 1], root)) 129 print(f"== Found {len(retour)} items ") 130 return retour 131 132 133def index_classes(items): 134 idx = {} 135 for i in items: 136 if i[1] not in idx: 137 idx[i[1]] = len(idx) 138 print(f"== Found {len(idx)} classes") 139 return idx 140 141 142class OmniglotNShot: 143 def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None): 144 """ 145 Different from mnistNShot, the 146 :param root: 147 :param batchsz: task num 148 :param n_way: 149 :param k_shot: 150 :param k_query: 151 :param imgsz: 152 """ 153 154 self.resize = imgsz 155 self.device = device 156 if not os.path.isfile(os.path.join(root, "omniglot.npy")): 157 # if root/data.npy does not exist, just download it 158 self.x = Omniglot( 159 root, 160 download=True, 161 transform=transforms.Compose( 162 [ 163 lambda x: Image.open(x).convert("L"), 164 lambda x: x.resize((imgsz, imgsz)), 165 lambda x: np.reshape(x, (imgsz, imgsz, 1)), 166 lambda x: np.transpose(x, [2, 0, 1]), 167 lambda x: x / 255.0, 168 ] 169 ), 170 ) 171 172 temp = ( 173 {} 174 ) # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label} 175 for img, label in self.x: 176 if label in temp.keys(): 177 temp[label].append(img) 178 else: 179 temp[label] = [img] 180 181 self.x = [] 182 for ( 183 label, 184 imgs, 185 ) in temp.items(): # labels info deserted , each label contains 20imgs 186 self.x.append(np.array(imgs)) 187 188 # as different class may have different number of imgs 189 self.x = np.array(self.x).astype( 190 np.float64 191 ) # [[20 imgs],..., 1623 classes in total] 192 # each character contains 20 imgs 193 print("data shape:", self.x.shape) # [1623, 20, 84, 84, 1] 194 temp = [] # Free memory 195 # save all dataset into npy file. 196 np.save(os.path.join(root, "omniglot.npy"), self.x) 197 print("write into omniglot.npy.") 198 else: 199 # if data.npy exists, just load it. 200 self.x = np.load(os.path.join(root, "omniglot.npy")) 201 print("load from omniglot.npy.") 202 203 # [1623, 20, 84, 84, 1] 204 # TODO: can not shuffle here, we must keep training and test set distinct! 205 self.x_train, self.x_test = self.x[:1200], self.x[1200:] 206 207 # self.normalization() 208 209 self.batchsz = batchsz 210 self.n_cls = self.x.shape[0] # 1623 211 self.n_way = n_way # n way 212 self.k_shot = k_shot # k shot 213 self.k_query = k_query # k query 214 assert (k_shot + k_query) <= 20 215 216 # save pointer of current read batch in total cache 217 self.indexes = {"train": 0, "test": 0} 218 self.datasets = { 219 "train": self.x_train, 220 "test": self.x_test, 221 } # original data cached 222 print("DB: train", self.x_train.shape, "test", self.x_test.shape) 223 224 self.datasets_cache = { 225 "train": self.load_data_cache( 226 self.datasets["train"] 227 ), # current epoch data cached 228 "test": self.load_data_cache(self.datasets["test"]), 229 } 230 231 def normalization(self): 232 """ 233 Normalizes our data, to have a mean of 0 and sdt of 1 234 """ 235 self.mean = np.mean(self.x_train) 236 self.std = np.std(self.x_train) 237 self.max = np.max(self.x_train) 238 self.min = np.min(self.x_train) 239 # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) 240 self.x_train = (self.x_train - self.mean) / self.std 241 self.x_test = (self.x_test - self.mean) / self.std 242 243 self.mean = np.mean(self.x_train) 244 self.std = np.std(self.x_train) 245 self.max = np.max(self.x_train) 246 self.min = np.min(self.x_train) 247 248 # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std) 249 250 def load_data_cache(self, data_pack): 251 """ 252 Collects several batches data for N-shot learning 253 :param data_pack: [cls_num, 20, 84, 84, 1] 254 :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks 255 """ 256 # take 5 way 1 shot as example: 5 * 1 257 setsz = self.k_shot * self.n_way 258 querysz = self.k_query * self.n_way 259 data_cache = [] 260 261 # print('preload next 50 caches of batchsz of batch.') 262 for sample in range(10): # num of episodes 263 x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] 264 for i in range(self.batchsz): # one batch means one set 265 x_spt, y_spt, x_qry, y_qry = [], [], [], [] 266 selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False) 267 268 for j, cur_class in enumerate(selected_cls): 269 selected_img = np.random.choice( 270 20, self.k_shot + self.k_query, False 271 ) 272 273 # meta-training and meta-test 274 x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]]) 275 x_qry.append(data_pack[cur_class][selected_img[self.k_shot :]]) 276 y_spt.append([j for _ in range(self.k_shot)]) 277 y_qry.append([j for _ in range(self.k_query)]) 278 279 # shuffle inside a batch 280 perm = np.random.permutation(self.n_way * self.k_shot) 281 x_spt = np.array(x_spt).reshape( 282 self.n_way * self.k_shot, 1, self.resize, self.resize 283 )[perm] 284 y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm] 285 perm = np.random.permutation(self.n_way * self.k_query) 286 x_qry = np.array(x_qry).reshape( 287 self.n_way * self.k_query, 1, self.resize, self.resize 288 )[perm] 289 y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm] 290 291 # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84] 292 x_spts.append(x_spt) 293 y_spts.append(y_spt) 294 x_qrys.append(x_qry) 295 y_qrys.append(y_qry) 296 297 # [b, setsz, 1, 84, 84] 298 x_spts = ( 299 np.array(x_spts) 300 .astype(np.float32) 301 .reshape(self.batchsz, setsz, 1, self.resize, self.resize) 302 ) 303 y_spts = np.array(y_spts).astype(int).reshape(self.batchsz, setsz) 304 # [b, qrysz, 1, 84, 84] 305 x_qrys = ( 306 np.array(x_qrys) 307 .astype(np.float32) 308 .reshape(self.batchsz, querysz, 1, self.resize, self.resize) 309 ) 310 y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz) 311 312 x_spts, y_spts, x_qrys, y_qrys = ( 313 torch.from_numpy(z).to(self.device) 314 for z in [x_spts, y_spts, x_qrys, y_qrys] 315 ) 316 317 data_cache.append([x_spts, y_spts, x_qrys, y_qrys]) 318 319 return data_cache 320 321 def next(self, mode="train"): 322 """ 323 Gets next batch from the dataset with name. 324 :param mode: The name of the splitting (one of "train", "val", "test") 325 :return: 326 """ 327 # update cache if indexes is larger cached num 328 if self.indexes[mode] >= len(self.datasets_cache[mode]): 329 self.indexes[mode] = 0 330 self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode]) 331 332 next_batch = self.datasets_cache[mode][self.indexes[mode]] 333 self.indexes[mode] += 1 334 335 return next_batch 336