I am trying to get the input and output information of a network. When debugging, I got this error: RuntimeError: shape '[-1, 400]' is invalid for input of size 384. I tried different values, but I can't find the correct one. Is there a way to solve this issue? Thanks.
class Net(nn.Module):
    """Small CNN: two conv(5x5) + max-pool(2x2) stages, then three linear layers.

    The width of the first linear layer depends on the spatial size of the
    input images, so it is derived from ``input_size`` instead of being
    hard-coded.

    Args:
        input_size: ``(height, width)`` of the input images. The default
            ``(32, 32)`` gives the original ``16 * 5 * 5 = 400`` features.

    Raises:
        ValueError: if ``input_size`` is too small to survive both
            conv/pool stages (each side must be at least 16).
    """

    def __init__(self, input_size=(32, 32)):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        height, width = input_size
        out_h = self._size_after_conv_stages(height)
        out_w = self._size_after_conv_stages(width)
        # Check each side separately: two negative sides would otherwise
        # multiply into a positive (and meaningless) feature count.
        if min(out_h, out_w) < 1:
            raise ValueError(
                f"input_size {tuple(input_size)} is too small for two "
                "conv(5)/pool(2) stages; each side must be at least 16"
            )

        self.fc1 = nn.Linear(16 * out_h * out_w, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    @staticmethod
    def _size_after_conv_stages(size):
        """Return one spatial side length after both conv/pool stages."""
        for _ in range(2):
            # 5x5 conv without padding removes 4 pixels; 2x2 pool halves (floor).
            size = (size - 4) // 2
        return size

    def forward(self, x):
        """Map a batch of 3-channel images to 10 output scores per image."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten to the width fc1 was built with, so the two always agree.
        x = x.view(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Net.fc1 takes 16*5*5 = 400 features per sample, which the two
# conv(5)/pool(2) stages only produce for 32x32 images:
#   32 -> 28 -> 14 -> 10 -> 5.
# A 21x21 image shrinks to 2x2 (21 -> 17 -> 8 -> 4 -> 2), i.e. 16*2*2 = 64
# features per sample and 6 * 64 = 384 values for the batch, which is why
# view(-1, 400) failed with "invalid for input of size 384".
input_shape = (3, 32, 32)
# Batch of 6 random images, used only to drive the trace.
dummy_input = torch.randn(6, *input_shape)
# NOTE(review): `model` is not defined in this snippet; presumably a Net
# instance created elsewhere. _get_trace_graph is an underscore-prefixed
# (non-public) torch.jit helper, so its signature may change between releases.
graph = torch.jit._get_trace_graph(model, args=dummy_input, _force_outplace=False, _return_inputs_states=False)
Error message:
RuntimeError: shape '[-1, 400]' is invalid for input of size 384