import torch.nn as nn
import torch.nn.functional as F


class MNIST(nn.Module):
    """Small convolutional classifier for 28x28 single-channel images.

    Two 5x5 conv layers (1 -> 10 -> 20 channels), each followed by 2x2 max
    pooling and ReLU, then two fully connected layers (320 -> 50 -> 10).
    ``forward`` returns log-probabilities over the 10 output classes.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Channel-wise dropout applied to conv2's feature maps (training only).
        self.conv2_drop = nn.Dropout2d()
        # 320 = 20 channels * 4 * 4: a 28x28 input shrinks 28 -> 24 (conv1)
        # -> 12 (pool) -> 8 (conv2) -> 4 (pool).
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return class log-probabilities for a batch of images.

        Args:
            x: Tensor of shape (N, 1, 28, 28), or a single unbatched image of
                shape (1, 28, 28), which is treated as a batch of one.

        Returns:
            Tensor of shape (N, 10) with log-softmax scores per sample.

        Raises:
            RuntimeError: if the spatial size is not 28x28, because the
                per-sample feature count then no longer matches ``fc1``'s
                320 inputs.
        """
        # Promote an unbatched (C, H, W) image to (1, C, H, W) so Dropout2d
        # receives a 4-D input and the output keeps a (1, 10) batch shape.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample while keeping the batch dimension fixed. Unlike
        # view(-1, 320), a non-28x28 input now fails in fc1 instead of
        # silently regrouping values across samples into a different batch.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)