I’ve been struggling with this patch of code where I cannot seem to get the size of my matrices right. This one is based on a batch size of 4.
If you have trouble rendering the code use https://nbviewer.jupyter.org/
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 2)
# super(Net, self).__init__()
# self.conv1 = nn.Conv2d(3, 32, 5)
# self.pool = nn.MaxPool2d(2, 2)
# self.conv2 = nn.Conv2d(32, 64, 5)
# self.fc1 = nn.Linear(64*9*9, 1024)
# self.fc2 = nn.Linear(1024, 7)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
# x = self.pool(F.relu(self.conv1(x)))
# x = self.pool(F.relu(self.conv2(x)))
# x = x.view(x.size(0), -1)
# #x = x.view(-1, 64)
# #x = F.relu(self.fc1(x))
# #x = F.relu(self.fc2(x))
# #x = self.fc2(x)
return x
net = Net()
And when try the net() this way, I get an index out of range error:
class Net(nn.Module):
def __init__(self):
# super().__init__()
# self.conv1 = nn.Conv2d(3, 6, 5)
# self.pool = nn.MaxPool2d(2, 2)
# self.conv2 = nn.Conv2d(6, 16, 5)
# self.fc1 = nn.Linear(16 * 5 * 5, 120)
# self.fc2 = nn.Linear(120, 84)
# self.fc3 = nn.Linear(84, 2)
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(32, 64, 5)
self.fc1 = nn.Linear(64*9*9, 1024)
self.fc2 = nn.Linear(1024, 7)
def forward(self, x):
# x = self.pool(F.relu(self.conv1(x)))
# x = self.pool(F.relu(self.conv2(x)))
# x = torch.flatten(x, 1) # flatten all dimensions except batch
# x = F.relu(self.fc1(x))
# x = F.relu(self.fc2(x))
# x = self.fc3(x)
# return x
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(x.size(0), -1)
#x = x.view(-1, 64)
#x = F.relu(self.fc1(x))
#x = F.relu(self.fc2(x))
#x = self.fc2(x)
return x