# Correct input features for fc1
self.fc1_input_features = self._get_fc1_input_features()
self.fc1 = nn.Linear(self.fc1_input_features, 128)
self.fc2 = nn.Linear(128, 6)
def _get_fc1_input_features(self, input_shape=(1, 1, 3, 1024)):
    """Return how many features the conv stack yields per sample.

    Pushes an all-zeros dummy tensor through the same conv -> ReLU -> pool
    stages that ``forward`` applies before ``fc1``, then reports the size of
    the flattened result. This lets ``fc1`` be sized without working out the
    spatial dimensions by hand.

    Args:
        input_shape: Shape ``(batch, channels, height, width)`` of the dummy
            probe. It defaults to ``(1, 1, 3, 1024)``, the value previously
            hard-coded here. Its channel and spatial dimensions must match
            the real inputs, or ``fc1`` will be sized incorrectly.

    Returns:
        Number of flattened features per sample. The batch dimension is kept
        separate, so this does not depend on ``input_shape[0]``.
    """
    # No gradients are needed; this is only a shape probe.
    with torch.no_grad():
        x = torch.zeros(input_shape)
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Keep dim 0 (batch) and merge the rest. Unlike ``view``, this also
        # accepts non-contiguous feature maps.
        x = x.flatten(1)
    return x.size(1)
def forward(self, x):
    """Run the conv feature extractor, then the two-layer fully connected head.

    Args:
        x: Input batch fed to ``conv1``. NOTE(review): the fc1 size probe
            uses shape ``(1, 1, 3, 1024)``, so inputs presumably share those
            channel/spatial dimensions. Confirm against the data pipeline.

    Returns:
        Raw ``fc2`` outputs, one row per sample, with no activation applied.
    """
    for conv in (self.conv1, self.conv2, self.conv3):
        x = self.pool(F.relu(conv(x)))
    # Keep dim 0 (batch) and merge the rest. ``view`` would raise on a
    # non-contiguous feature map (e.g. channels_last memory format);
    # ``flatten`` produces the same values and copies only when it must.
    x = x.flatten(1)
    x = F.relu(self.fc1(x))
    # The final layer is deliberately left without an activation.
    return self.fc2(x)
Your code is incomplete. Could you post the missing code snippets, including the shape of the input, so that we have an executable code snippet that reproduces the issue?