Hi, this is my first time writing a Neural Network using PyTorch and I encountered the following error
'Linear' object has no attribute 'log_softmax'
Here’s my code:
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, # 224 x 224 -> 222 x 222
out_channels = 64,
kernel_size = 3)
self.maxpool1 = nn.MaxPool2d(kernel_size=2) # 222 x 222 -> 111 x 111
self.batchnorm1 = nn.BatchNorm2d(64)
self.dropout1 = nn.Dropout(p = 0.5)
self.conv2 = nn.Conv2d(in_channels = 64, # 111 x 111 -> 106 x 106
out_channels = 128,
kernel_size = 5)
self.maxpool2 = nn.MaxPool2d(kernel_size = 2) # 106 x 106 -> 53 x 53
self.batchnorm2 = nn.BatchNorm2d(128)
self.dropout2 = nn.Dropout(p = 0.3)
self.conv3 = nn.Conv2d(in_channels = 128, # 53 x 53 -> 46 x 46
out_channels = 256,
kernel_size = 7)
self.maxpool3 = nn.MaxPool2d(kernel_size=2) # 46 x 46 -> 23 x 23
self.batchnorm3 = nn.BatchNorm2d(256)
self.dropout3 = nn.Dropout(p = 0.2)
self.fc1 = nn.Linear(256*23*23, 1000)
self.dropout4 = nn.Dropout(p = 0.1)
self.fc2 = nn.Linear(1000, 2)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool1(x)
x = self.batchnorm1(x)
x = self.dropout1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.maxpool2(x)
x = self.batchnorm2(x)
x = self.dropout2(x)
x = F.relu(x)
x = self.conv3(x)
x = self.maxpool3(x)
x = self.batchnorm3(x)
x = self.dropout3(x)
x = F.relu(x)
print(x.shape)
x = x.view(x.size(0), -1)
print(x.shape)
x = self.fc1(x)
x = self.dropout4(x)
x = F.relu(x)
x = self.fc2
return x
Here’s the code for my training loop
In_epochs = 10
lr = 0.001
cec_loss = nn.CrossEntropyLoss()
params = model.parameters()
optimizer = optim.Adam(params = params, lr = 0.001)
n_epochs = 3
n_iterations = 0
losses = []
for e in range(n_epochs):
counter = 0
for i, (images, labels) in enumerate(train_loader):
images = Variable(images)
labels = Variable(labels)
output = model(images)
model.zero_grad()
loss = cec_loss(output, labels)
loss.backward()
optimizer.step()
n_iterations +=1
counter += 1
if counter % 100 == 0:
print(loss.item())