xref: /aosp_15_r20/external/pytorch/test/onnx/model_defs/dcgan.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2import torch.nn as nn
3
4
# Configurable hyperparameters read by the model definitions in this file.
nz = 100  # channel count of the Z input fed to _netG's first layer
nc = 3  # channel count of the image _netG emits and _netD consumes
ngf = 64  # base feature-map width of _netG (layers use ngf * {8, 4, 2, 1})
ndf = 64  # base feature-map width of _netD (layers use ndf * {1, 2, 4, 8})
# Not referenced anywhere in this file; presumably the batch size and image
# side length expected by callers -- TODO confirm against the users of this
# module.
bsz = 64
imgsz = 64
12
13
14# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialize the parameters of a single module in place.

    The decision is made purely on the module's class name:

    * a name containing ``"Conv"`` gets its weight redrawn from a normal
      distribution with mean 0.0 and std 0.02;
    * otherwise, a name containing ``"BatchNorm"`` gets its weight redrawn
      from a normal distribution with mean 1.0 and std 0.02 and its bias
      filled with zeros;
    * any other module is left untouched.

    Args:
        m: the module to initialize (typically passed in by
            ``nn.Module.apply``).

    Returns:
        None. The module's tensors are modified in place.
    """
    kind = m.__class__.__name__

    # "Conv" is checked first, so it wins if a name happens to match both.
    if "Conv" in kind:
        m.weight.data.normal_(0.0, 0.02)
        return

    if "BatchNorm" in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
22
23
class _netG(nn.Module):
    """DCGAN generator built from a stack of transposed convolutions.

    ``self.main`` upsamples an input with ``nz`` channels through four
    ConvTranspose2d/BatchNorm2d/ReLU stages and a final
    ConvTranspose2d/Tanh stage that emits ``nc`` channels. Per the
    per-layer size comments, a ``nz x 1 x 1`` input yields a
    ``nc x 64 x 64`` output.

    Args:
        ngpu: number of GPUs to spread the forward pass over. Values
            greater than 1 enable ``nn.parallel.data_parallel`` for CUDA
            inputs; otherwise the network runs on the input's device.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(inplace=True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(inplace=True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(inplace=True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Run the generator on ``input`` and return the Tanh-activated output.

        Args:
            input: tensor whose channel dimension has size ``nz``.

        Returns:
            The output of ``self.main`` for ``input``.
        """
        # Check the device rather than the legacy torch.cuda.FloatTensor
        # type, so CUDA inputs of any dtype take the multi-GPU path.
        if self.ngpu > 1 and input.is_cuda:
            return nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        return self.main(input)
57
58
class _netD(nn.Module):
    """DCGAN discriminator built from a stack of strided convolutions.

    ``self.main`` downsamples an input with ``nc`` channels through four
    stride-2 Conv2d stages (LeakyReLU(0.2) activations, BatchNorm2d on all
    but the first) and a final Conv2d/Sigmoid stage with a single output
    channel. Per the per-layer size comments, the expected input is
    ``nc x 64 x 64``.

    Args:
        ngpu: number of GPUs to spread the forward pass over. Values
            greater than 1 enable ``nn.parallel.data_parallel`` for CUDA
            inputs; otherwise the network runs on the input's device.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Score ``input`` and return the sigmoid outputs as a column.

        Args:
            input: tensor whose channel dimension has size ``nc``.

        Returns:
            The output of ``self.main`` reshaped with ``view(-1, 1)``, i.e.
            a two-dimensional tensor with a single column.
        """
        # Check the device rather than the legacy torch.cuda.FloatTensor
        # type, so CUDA inputs of any dtype take the multi-GPU path.
        if self.ngpu > 1 and input.is_cuda:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1)
91