# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation:
#     https://github.com/dragen1860/MAML-Pytorch
#     https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py
#     https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py

import os
import os.path

import numpy as np
import torchvision.transforms as transforms
from PIL import Image

import torch
import torch.utils.data as data


class Omniglot(data.Dataset):
    """
    The items are (filename, category). The index of all the categories can
    be found in self.idx_classes.
    Args:
    - root: the directory where the dataset will be stored
    - transform: how to transform the input
    - target_transform: how to transform the target
    - download: whether to download the dataset if it is not found locally
    """

    urls = [
        "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip",
        "https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip",
    ]
    raw_folder = "raw"
    processed_folder = "processed"
    training_file = "training.pt"
    test_file = "test.pt"
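
    # A minimal usage sketch (assumes network access and a writable root
    # directory; with no transform, each item is a (file path, class index) pair):
    #
    #     dataset = Omniglot("./omniglot", download=True)
    #     path, label = dataset[0]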

    def __init__(self, root, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if not self._check_exists():
            if download:
                self.download()
            else:
                raise RuntimeError(
                    "Dataset not found. You can use download=True to download it"
                )

        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
        self.idx_classes = index_classes(self.all_items)

    def __getitem__(self, index):
        filename = self.all_items[index][0]
        # full path to the image file; the transform, if given, is expected to
        # load the image from this path
        img = os.path.join(self.all_items[index][2], filename)

        target = self.idx_classes[self.all_items[index][1]]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.all_items)

    def _check_exists(self):
        return os.path.exists(
            os.path.join(self.root, self.processed_folder, "images_evaluation")
        ) and os.path.exists(
            os.path.join(self.root, self.processed_folder, "images_background")
        )

    def download(self):
        import urllib.request
        import zipfile

        if self._check_exists():
            return

        # create the raw and processed folders (no-op if they already exist)
        os.makedirs(os.path.join(self.root, self.raw_folder), exist_ok=True)
        os.makedirs(os.path.join(self.root, self.processed_folder), exist_ok=True)

        for url in self.urls:
            print("== Downloading " + url)
            response = urllib.request.urlopen(url)
            filename = url.rpartition("/")[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, "wb") as f:
                f.write(response.read())
            file_processed = os.path.join(self.root, self.processed_folder)
            print("== Unzip from " + file_path + " to " + file_processed)
            with zipfile.ZipFile(file_path, "r") as zip_ref:
                zip_ref.extractall(file_processed)
        print("Download finished.")


def find_classes(root_dir):
    # Each entry is (filename, "alphabet/character" class name, containing directory).
    retour = []
    for root, dirs, files in os.walk(root_dir):
        for f in files:
            if f.endswith(".png"):
                r = root.split(os.sep)
                retour.append((f, r[-2] + "/" + r[-1], root))
    print(f"== Found {len(retour)} items")
    return retour


def index_classes(items):
    # Map each class name to a contiguous integer index, in order of first appearance.
    idx = {}
    for i in items:
        if i[1] not in idx:
            idx[i[1]] = len(idx)
    print(f"== Found {len(idx)} classes")
    return idx
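
# For illustration (hypothetical filenames, assuming the standard Omniglot
# layout), find_classes might return entries such as
#     ("0709_01.png", "Futurama/character01", "<root>/images_background/Futurama/character01")
# and index_classes would then map "Futurama/character01" -> 0, the next new
# class name -> 1, and so on.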


class OmniglotNShot:
    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None):
        """
        Different from mnistNShot, the Omniglot images are cached on disk as a
        single numpy array, and episodes are resampled from it.
        :param root: directory that stores (or will store) omniglot.npy
        :param batchsz: number of tasks per batch (task num)
        :param n_way: number of classes per task
        :param k_shot: number of support images per class
        :param k_query: number of query images per class
        :param imgsz: images are resized to imgsz x imgsz
        """

        self.resize = imgsz
        self.device = device
        if not os.path.isfile(os.path.join(root, "omniglot.npy")):
            # if root/omniglot.npy does not exist, download the raw data and build it
            self.x = Omniglot(
                root,
                download=True,
                transform=transforms.Compose(
                    [
                        lambda x: Image.open(x).convert("L"),
                        lambda x: x.resize((imgsz, imgsz)),
                        lambda x: np.reshape(x, (imgsz, imgsz, 1)),
                        lambda x: np.transpose(x, [2, 0, 1]),
                        lambda x: x / 255.0,
                    ]
                ),
            )

            # group images by label: {label: [img1, img2, ...]};
            # 20 images per label, 1623 labels in total
            temp = {}
            for img, label in self.x:
                if label in temp.keys():
                    temp[label].append(img)
                else:
                    temp[label] = [img]

            self.x = []
            # label info is discarded; each label contributes its 20 images
            for label, imgs in temp.items():
                self.x.append(np.array(imgs))

            # every class has exactly 20 images, so this stacks into a regular
            # array: [[20 imgs], ..., 1623 classes in total]
            self.x = np.array(self.x).astype(np.float64)
            print("data shape:", self.x.shape)  # [1623, 20, 1, imgsz, imgsz]
            temp = []  # Free memory
            # save the whole dataset into one npy file.
            np.save(os.path.join(root, "omniglot.npy"), self.x)
            print("write into omniglot.npy.")
        else:
            # if omniglot.npy exists, just load it.
            self.x = np.load(os.path.join(root, "omniglot.npy"))
            print("load from omniglot.npy.")

        # [1623, 20, 1, imgsz, imgsz]
        # NOTE: do not shuffle here; the training and test sets must stay distinct!
        # The first 1200 character classes are used for training, the rest for testing.
        self.x_train, self.x_test = self.x[:1200], self.x[1200:]

        # self.normalization()

        self.batchsz = batchsz
        self.n_cls = self.x.shape[0]  # 1623
        self.n_way = n_way  # n way
        self.k_shot = k_shot  # k shot
        self.k_query = k_query  # k query
        assert (k_shot + k_query) <= 20

        # save pointer of current read batch in total cache
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {
            "train": self.x_train,
            "test": self.x_test,
        }  # original data cached
        print("DB: train", self.x_train.shape, "test", self.x_test.shape)

        self.datasets_cache = {
            "train": self.load_data_cache(
                self.datasets["train"]
            ),  # current epoch data cached
            "test": self.load_data_cache(self.datasets["test"]),
        }

    def normalization(self):
        """
        Normalizes our data to have a mean of 0 and std of 1
        """
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
        self.x_train = (self.x_train - self.mean) / self.std
        self.x_test = (self.x_test - self.mean) / self.std

        # recompute the statistics after normalization
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)

    def load_data_cache(self, data_pack):
        """
        Collects several batches of data for N-shot learning
        :param data_pack: [cls_num, 20, 1, resize, resize]
        :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
        """
        # e.g. for 5-way 1-shot, the support set size is 5 * 1
        setsz = self.k_shot * self.n_way
        querysz = self.k_query * self.n_way
        data_cache = []

        # preload 10 episodes, each holding batchsz tasks
        for _ in range(10):  # num of episodes
            x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
            for _ in range(self.batchsz):  # one batch means one task set
                x_spt, y_spt, x_qry, y_qry = [], [], [], []
                # sample n_way distinct classes for this task
                selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False)

                for j, cur_class in enumerate(selected_cls):
                    # sample k_shot + k_query distinct images from the class
                    selected_img = np.random.choice(
                        20, self.k_shot + self.k_query, False
                    )

                    # meta-training (support) and meta-test (query) splits
                    x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]])
                    x_qry.append(data_pack[cur_class][selected_img[self.k_shot :]])
                    y_spt.append([j for _ in range(self.k_shot)])
                    y_qry.append([j for _ in range(self.k_query)])

                # shuffle inside a batch
                perm = np.random.permutation(self.n_way * self.k_shot)
                x_spt = np.array(x_spt).reshape(
                    self.n_way * self.k_shot, 1, self.resize, self.resize
                )[perm]
                y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
                perm = np.random.permutation(self.n_way * self.k_query)
                x_qry = np.array(x_qry).reshape(
                    self.n_way * self.k_query, 1, self.resize, self.resize
                )[perm]
                y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]

                # append [setsz, 1, resize, resize] => [b, setsz, 1, resize, resize]
                x_spts.append(x_spt)
                y_spts.append(y_spt)
                x_qrys.append(x_qry)
                y_qrys.append(y_qry)

            # [b, setsz, 1, resize, resize]
            x_spts = (
                np.array(x_spts)
                .astype(np.float32)
                .reshape(self.batchsz, setsz, 1, self.resize, self.resize)
            )
            y_spts = np.array(y_spts).astype(int).reshape(self.batchsz, setsz)
            # [b, qrysz, 1, resize, resize]
            x_qrys = (
                np.array(x_qrys)
                .astype(np.float32)
                .reshape(self.batchsz, querysz, 1, self.resize, self.resize)
            )
            y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz)

            x_spts, y_spts, x_qrys, y_qrys = (
                torch.from_numpy(z).to(self.device)
                for z in [x_spts, y_spts, x_qrys, y_qrys]
            )

            data_cache.append([x_spts, y_spts, x_qrys, y_qrys])

        return data_cache
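
    # As a concrete illustration (hypothetical values: batchsz=32, n_way=5,
    # k_shot=1, k_query=15, imgsz=28), each cached episode would hold tensors
    # of shape:
    #     x_spts: [32, 5, 1, 28, 28]     y_spts: [32, 5]
    #     x_qrys: [32, 75, 1, 28, 28]    y_qrys: [32, 75]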
    def next(self, mode="train"):
        """
        Gets the next batch from the dataset with the given name.
        :param mode: The name of the split (one of "train", "test")
        :return: the next [x_spts, y_spts, x_qrys, y_qrys] batch
        """
        # refresh the cache once every preloaded episode has been consumed
        if self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])

        next_batch = self.datasets_cache[mode][self.indexes[mode]]
        self.indexes[mode] += 1

        return next_batch
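

# A minimal smoke-test sketch (hypothetical values; the first run assumes
# network access so the Omniglot archives can be downloaded into "./omniglot"):
if __name__ == "__main__":
    db = OmniglotNShot(
        "./omniglot", batchsz=32, n_way=5, k_shot=1, k_query=15, imgsz=28
    )
    x_spt, y_spt, x_qry, y_qry = db.next("train")
    print(x_spt.shape, y_spt.shape, x_qry.shape, y_qry.shape)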
336