import torch
import torch.nn as nn


# configurable
# NOTE(review): bsz and imgsz are not referenced anywhere in this module;
# presumably batch size and image side length used by the training script —
# confirm against the caller.
bsz = 64
imgsz = 64
nz = 100  # channels of the latent input Z fed to the generator
ngf = 64  # base number of generator feature maps
ndf = 64  # base number of discriminator feature maps
nc = 3  # channels of the generated / discriminated image


# custom weights initialization called on netG and netD
def weights_init(m):
    """Initialise one module in place; intended for ``net.apply(weights_init)``.

    Modules whose class name contains ``"Conv"`` get weights drawn from
    N(0.0, 0.02). Modules whose class name contains ``"BatchNorm"`` get
    weights drawn from N(1.0, 0.02) and a zero bias. Any other module is
    left untouched.

    Parameters that are absent (``None``), e.g. on ``BatchNorm2d(affine=False)``,
    are skipped instead of raising ``AttributeError``.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)


def _run_main(main, ngpu, input):
    """Run ``main`` on ``input``, splitting across ``ngpu`` GPUs when possible.

    The data-parallel path is taken only when more than one GPU was requested
    and the input already lives on a CUDA device. ``is_cuda`` is used rather
    than a check against ``torch.cuda.FloatTensor`` so that non-float32 CUDA
    inputs (e.g. half precision) are dispatched the same way.
    """
    if ngpu > 1 and input.is_cuda:
        return nn.parallel.data_parallel(main, input, range(ngpu))
    return main(input)


class _netG(nn.Module):
    """Generator: transposed-convolution stack ending in ``Tanh``.

    Maps a latent tensor with ``nz`` channels (1 x 1 spatial, per the layer
    comments below) to an ``nc``-channel 64 x 64 image with values in [-1, 1].

    Args:
        ngpu: Number of GPUs to spread the forward pass over; values <= 1
            disable data parallelism.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        # Layer order is part of the checkpoint format (state_dict keys are
        # "main.<index>.*"), so do not reorder or insert layers.
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Return the generated image batch for the latent batch ``input``."""
        return _run_main(self.main, self.ngpu, input)


class _netD(nn.Module):
    """Discriminator: strided-convolution stack ending in ``Sigmoid``.

    Maps an ``nc``-channel 64 x 64 image batch to one score in [0, 1] per
    sample.

    Args:
        ngpu: Number of GPUs to spread the forward pass over; values <= 1
            disable data parallelism.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        # Layer order is part of the checkpoint format (state_dict keys are
        # "main.<index>.*"), so do not reorder or insert layers.
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return per-sample scores for ``input`` as a tensor of shape (N, 1)."""
        output = _run_main(self.main, self.ngpu, input)
        # Collapse the trailing 1 x 1 x 1 dimensions to a single column.
        return output.view(-1, 1)