This is my model:
class CNNModel(nn.Module):
def __init__(self):
super (CNNModel, self).__init__()
self.cnn1 = nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 7, stride = 4, padding = 3)
self.relu1 = nn.ReLU()
nn.init.xavier_uniform_(self.cnn1.weight)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.cnn2 = nn.Conv2d(in_channels = 16, out_channels = 64, kernel_size = 5, stride = 1, padding = 2)
self.relu2 = nn.ReLU()
nn.init.xavier_uniform_(self.cnn2.weight)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.fc1 = nn.Linear(3136, 10)
self.fcrelu = nn.ReLU()
self.fc2 = nn.Linear(3136, 10)
def forward(self, x):
out = self.cnn1(x)
out = self.relu1(out)
out = self.maxpool1(out)
out = self.cnn2(out)
out = self.relu2(out)
out = self.maxpool2(out)
out = out.view(out.size(0),-1)
out = self.fc1(out)
out = self.fcrelu(out)
out = self.fc2(out)
return out
When I run the following code i get the error as stated in the title:
Note: The shape of state is (1,4,110,84)
state = env.reset()
state, stacked_frames = stack_frames(stacked_frames, state, True)
state = torch.from_numpy(state)
state = state.unsqueeze(0)
print(state.shape)
Qs = model(Variable(state))
_, predicted = torch.max(output.data,1)
print(predicted)