Hi there! I am studying convolutional networks with reinforcement learning to tackle some simple games. I have already built a simple network made only of fully connected hidden layers. The network for that task was:
```python
import torch.nn as nn
import torch.nn.functional as F

class NeuralNetworkLinear(nn.Module):
    def __init__(self, input_size, output_size):
        super(NeuralNetworkLinear, self).__init__()
        n = 200  # hidden layer width
        self.fc1 = nn.Linear(input_size, n)
        self.fc2 = nn.Linear(n, n)
        self.fc3 = nn.Linear(n, output_size)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        q_values = self.fc3(x)  # one Q-value per action
        return q_values
```
To make a forward pass I would do something like:
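```python
def select_action(self, state):
    # (input_size,) numpy array -> (1, input_size) float tensor (a batch of one)
    state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    # Softmax over the action dimension, with logits scaled by the temperature
    probs = F.softmax(self.model(state) * self.temperature, dim=1)
    # Sample one action index from the resulting distribution
    action = probs.multinomial(num_samples=1)
    return action
```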
where `state` is a 1D NumPy array holding the X and Y positions of a bot and its target.
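A minimal sketch of the shapes involved in `select_action`, assuming the state holds four values (bot X, Y and target X, Y), the game has four actions, and a temperature of 7.0 (an arbitrary choice). The model is a hypothetical `nn.Sequential` stand-in with the same layout as `NeuralNetworkLinear`. Since PyTorch 0.4 tensors no longer need to be wrapped in `Variable`, `F.softmax` should be given `dim` explicitly, and `multinomial` takes `num_samples`.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in mirroring NeuralNetworkLinear (4 inputs and 4 actions assumed)
model = nn.Sequential(
    nn.Linear(4, 200), nn.ReLU(),
    nn.Linear(200, 200), nn.ReLU(),
    nn.Linear(200, 4),
)

state = np.array([2.0, 3.0, 7.0, 5.0])  # example values: bot (x, y), target (x, y)
x = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)  # (4,) -> (1, 4)
print(x.shape)  # torch.Size([1, 4])

temperature = 7.0  # assumed value
with torch.no_grad():
    probs = F.softmax(model(x) * temperature, dim=1)  # (1, 4), the row sums to 1
action = probs.multinomial(num_samples=1)  # (1, 1) tensor holding the sampled index
print(action.item())
```

`nn.Linear` accepts a batch of shape `(batch, features)`, which is why the single 1D state gets a leading batch dimension via `unsqueeze(0)`.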
Now I want to change my approach: I want to feed the 2D grid into a convolutional network. The grid is a 10x10 NumPy 2D array with values between 0 and 255 (or between 0 and 1), a grayscale representation of the game world. I copied this template convolutional network from the PyTorch tutorials:
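```python
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self, input_size, output_size):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=2, stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=2, stride=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=2, stride=1)
        self.bn3 = nn.BatchNorm2d(32)
        # A 10x10 input shrinks to 9x9, 8x8, then 7x7 through the three convs, so the
        # head sees 32 * 7 * 7 features (the template's 448 fits a different input size)
        self.head = nn.Linear(32 * 7 * 7, output_size)

    def forward(self, state):
        # state must be 4D: (batch, channels=1, height, width)
        x = F.relu(self.bn1(self.conv1(state)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))  # flatten to (batch, features)
```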
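The input size of `self.head` has to equal the number of features left after the last convolution, flattened. A quick sketch for checking it on a 10x10 grid, using the standard `nn.Conv2d` output-size formula (dilation 1 assumed) and cross-checking with a dummy forward pass. `BatchNorm2d` and `ReLU` do not change the shape, so only the convolutions are traced; the helper name `conv2d_out` is hypothetical.

```python
import torch
import torch.nn as nn


def conv2d_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    # Output length of nn.Conv2d along one spatial dimension (dilation = 1)
    return (size + 2 * padding - kernel_size) // stride + 1


size = 10  # 10x10 grid
for _ in range(3):  # three convs with kernel_size=2, stride=1
    size = conv2d_out(size, kernel_size=2, stride=1)
print(size, 32 * size * size)  # 7 1568

# Cross-check with a dummy forward pass through the same conv stack
convs = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=2, stride=1),
    nn.Conv2d(16, 32, kernel_size=2, stride=1),
    nn.Conv2d(32, 32, kernel_size=2, stride=1),
)
with torch.no_grad():
    out = convs(torch.zeros(1, 1, 10, 10))
print(out.shape, out.view(1, -1).size(1))  # torch.Size([1, 32, 7, 7]) 1568
```

Running a zero tensor of the real input shape through the layers is often the simplest way to size a linear head when the grid size changes.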
But now, when I use this new network, I get the following error:
“Expected 4D tensor as input, got 2D tensor instead.”
It occurs in `x = F.relu(self.bn1(self.conv1(state)))`.
I don’t really know how to approach this. I just want to convert my 2D NumPy array into the right input for the convolutional network. Can someone help me with this?
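`nn.Conv2d` expects input of shape `(batch, channels, height, width)`, so a single 10x10 grayscale grid needs a channel dimension and a batch dimension, giving `(1, 1, 10, 10)`. A sketch under these assumptions: the hypothetical helper `grid_to_input` rescales to 0..1 whenever a value exceeds 1 (so a 0..255 grid containing only 0s and 1s would not be rescaled), four actions, and a temperature of 7.0 as in the linear version. The `CNN` here repeats the template's layers with the head sized for 10x10 inputs and drops the unused `input_size` argument.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def grid_to_input(grid: np.ndarray) -> torch.Tensor:
    """Turn one H x W grayscale grid into a (1, 1, H, W) float tensor for nn.Conv2d."""
    x = torch.as_tensor(grid, dtype=torch.float32)
    if x.max() > 1.0:  # assumption: values above 1 mean a 0..255 grid, so rescale to 0..1
        x = x / 255.0
    return x.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, H, W) -> (1, 1, H, W)


class CNN(nn.Module):
    # Same layers as the template, with the head sized for 10x10 inputs (32 * 7 * 7)
    def __init__(self, output_size: int = 4):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=2, stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=2, stride=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=2, stride=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.head = nn.Linear(32 * 7 * 7, output_size)

    def forward(self, state):
        x = F.relu(self.bn1(self.conv1(state)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))


def select_action(model: nn.Module, grid: np.ndarray, temperature: float = 7.0) -> int:
    state = grid_to_input(grid)
    model.eval()  # BatchNorm uses its running statistics for this single state
    with torch.no_grad():
        probs = F.softmax(model(state) * temperature, dim=1)
    return probs.multinomial(num_samples=1).item()


rng = np.random.default_rng(0)
grid = rng.integers(0, 256, size=(10, 10))  # hypothetical 10x10 game world
print(grid_to_input(grid).shape)  # torch.Size([1, 1, 10, 10])

model = CNN(output_size=4)
action = select_action(model, grid)  # an int in 0..3

# For training, a batch of B grids stacks to (B, 1, 10, 10)
grids = [rng.integers(0, 256, size=(10, 10)) / 255.0 for _ in range(8)]
batch = torch.as_tensor(np.stack(grids), dtype=torch.float32).unsqueeze(1)
model.train()  # switch back so BatchNorm uses batch statistics while learning
print(model(batch).shape)  # torch.Size([8, 4])
```

Because `select_action` puts the model in eval mode, a training loop would need to call `model.train()` again before computing the loss, as the last lines of the sketch do.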
Thank you in advance