class CNN(nn.Module):
    """Two-stage convolutional classifier producing 10 output scores.

    Pipeline: conv(1->32, k=7, s=4, p=3) -> ReLU -> maxpool(2)
              -> conv(32->64, k=5, s=1, p=2) -> ReLU -> maxpool(2)
              -> flatten -> linear(3136->10) -> ReLU.

    Input is expected to be a single-channel batch shaped (N, 1, H, W).
    The linear layer is fixed at 3136 = 64 * 7 * 7 input features, so the
    feature map after the second pooling must be 7x7. With the stride-4
    first convolution this holds for H and W between 109 and 124
    (e.g. 112x112). Other sizes (for instance 28x28, which flattens to
    64 features) will raise a shape error at ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: aggressive spatial reduction (stride 4, then pool 2).
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=7, stride=4, padding=3)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        # Stage 2: in_channels must equal cnn1.out_channels (32); the
        # previous value of 16 made forward() fail with a channel mismatch.
        # kernel 5 / padding 2 / stride 1 keeps the spatial size unchanged.
        self.cnn2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        # 3136 = 64 channels * 7 * 7 spatial positions after maxpool2.
        self.fc1 = nn.Linear(3136, 10)
        self.fcrelu = nn.ReLU()

    def forward(self, x):
        """Run the network on ``x`` and return a (N, 10) tensor.

        Args:
            x: Tensor shaped (N, 1, H, W); see the class docstring for the
                H/W range compatible with ``fc1``.

        Returns:
            Tensor of shape (N, 10). Values are non-negative because a ReLU
            is applied after the final linear layer.
        """
        out = self.cnn1(x)
        out = self.relu1(out)
        out = self.maxpool1(out)
        out = self.cnn2(out)
        out = self.relu2(out)
        out = self.maxpool2(out)
        # Flatten everything except the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        # NOTE(review): ReLU on the final scores clamps negatives to zero;
        # if these are fed to a cross-entropy loss as logits this is
        # probably unintended - confirm against the training code.
        out = self.fcrelu(out)
        return out