Hi there! I am studying convolutional networks with reinforcement learning to tackle some simple games. I have already built a simple network made up only of fully connected hidden layers. The network for that task was:
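```python
import torch.nn as nn
import torch.nn.functional as F


class NeuralNetworkLinear(nn.Module):
    def __init__(self, input_size, output_size):
        super(NeuralNetworkLinear, self).__init__()
        n = 200  # width of each hidden layer
        self.fc1 = nn.Linear(input_size, n)
        self.fc2 = nn.Linear(n, n)
        self.fc3 = nn.Linear(n, output_size)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        q_values = self.fc3(x)  # one Q-value per action
        return q_values
```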
To make a forward pass I would do something like:
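```python
import torch
import torch.nn.functional as F


def select_action(self, state):
    # 1D NumPy state -> float tensor with a batch dimension: (input_size,) -> (1, input_size)
    state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    # Softmax over the action dimension, scaled by self.temperature
    probs = F.softmax(self.model(state) * self.temperature, dim=1)
    action = probs.multinomial(num_samples=1)  # sample one action index, shape (1, 1)
    return action
```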
where `state` is a 1D NumPy array containing the X and Y positions of a bot and its target.
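For reference, here is a minimal, self-contained sketch of that forward pass with the current PyTorch API. It assumes a 4-element state (bot X/Y plus target X/Y), 4 actions and a temperature of 1.0; those sizes and the standalone `select_action(model, state, temperature)` helper are illustrative assumptions, not part of the original code. `Variable` is no longer needed, and `F.softmax` and `multinomial` take explicit `dim` and `num_samples` arguments.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class NeuralNetworkLinear(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        n = 200
        self.fc1 = nn.Linear(input_size, n)
        self.fc2 = nn.Linear(n, n)
        self.fc3 = nn.Linear(n, output_size)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def select_action(model, state, temperature=1.0):
    # Hypothetical standalone helper; state is a 1D NumPy array of shape (input_size,)
    # Add a batch dimension: (input_size,) -> (1, input_size)
    state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        q_values = model(state_t)                        # shape (1, output_size)
    # As in the original, Q-values are multiplied by the temperature,
    # so a larger value gives a sharper (greedier) distribution
    probs = F.softmax(q_values * temperature, dim=1)
    action = probs.multinomial(num_samples=1)            # shape (1, 1)
    return action.item()                                 # plain Python int


# Assumed state layout: [bot_x, bot_y, target_x, target_y]
model = NeuralNetworkLinear(input_size=4, output_size=4)
state = np.array([1.0, 2.0, 7.0, 5.0])
action = select_action(model, state, temperature=1.0)
print(action)  # a random integer between 0 and 3
```

The key point for what follows is the `unsqueeze(0)`: PyTorch layers expect a leading batch dimension even when acting on a single state.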
Now I want to change my approach and feed the 2D grid to a convolutional network instead. The grid is a 10x10 NumPy 2D array with values between 0 and 255 (or between 0 and 1), i.e. a gray-scale representation of the game world. I copied this template convolutional network from the PyTorch tutorials:
```python
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    def __init__(self, input_size, output_size):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=2, stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=2, stride=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=2, stride=1)
        self.bn3 = nn.BatchNorm2d(32)
        # For a 10x10 input each conv shrinks the sides by 1: 10 -> 9 -> 8 -> 7,
        # so the flattened size is 32 * 7 * 7 = 1568 (the tutorial's 448 does not fit here)
        self.head = nn.Linear(32 * 7 * 7, output_size)

    def forward(self, state):
        x = F.relu(self.bn1(self.conv1(state)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))
```
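The input size of `self.head` depends on the grid size, so one way to check it is to push a dummy input through the convolutional layers and read off the flattened size. A sketch, assuming a single-channel 10x10 grid: each `kernel_size=2, stride=1` convolution shrinks each spatial side by 1, so the output is 32 × 7 × 7. The 448 in the tutorial was computed for that tutorial's own input size and layer settings.

```python
import torch
import torch.nn as nn

# Same convolutional stack as the CNN above (BatchNorm and ReLU do not change shapes)
features = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=2, stride=1),
    nn.Conv2d(16, 32, kernel_size=2, stride=1),
    nn.Conv2d(32, 32, kernel_size=2, stride=1),
)

# Dummy batch laid out as (N, C, H, W) = (1 sample, 1 channel, 10 rows, 10 columns)
dummy = torch.zeros(1, 1, 10, 10)
with torch.no_grad():
    out = features(dummy)

print(out.shape)           # torch.Size([1, 32, 7, 7])
flat_size = out[0].numel()
print(flat_size)           # 1568
```

The result can then be plugged in as `nn.Linear(flat_size, output_size)`, which keeps the head correct if the grid size or kernel settings change later.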
But now, when I use this new network, I get the following error:
“Expected 4D tensor as input, got 2D tensor instead.” It is raised at the line `x = F.relu(self.bn1(self.conv1(state)))`.
I don’t really know how to approach this. I just want to convert my 2D NumPy array into the right input for the convolutional network. Can someone help me with this?
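A minimal sketch of one way to do the conversion, assuming a 10x10 grid with values in 0..255 and 4 actions (`output_size=4` and the helper name `grid_to_input` are illustrative, not from the original code). `nn.Conv2d` expects input shaped (batch, channels, height, width), while a bare grid is only (height, width), which is why a 2D tensor is reported. Adding a batch dimension and a single channel dimension with `unsqueeze` gives (1, 1, 10, 10). The `CNN` below is the network above with `self.head` sized for a 10x10 input.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    def __init__(self, output_size):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=2, stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=2, stride=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=2, stride=1)
        self.bn3 = nn.BatchNorm2d(32)
        # Assumes a 10x10 input: 10 -> 9 -> 8 -> 7, so 32 * 7 * 7 features
        self.head = nn.Linear(32 * 7 * 7, output_size)

    def forward(self, state):
        x = F.relu(self.bn1(self.conv1(state)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))


def grid_to_input(grid):
    # Hypothetical helper; grid is a 2D NumPy array of shape (H, W) with values 0..255.
    # Skip the division if the grid is already in [0, 1].
    t = torch.as_tensor(grid, dtype=torch.float32) / 255.0
    return t.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W): batch, then channel


grid = np.random.randint(0, 256, size=(10, 10))  # stand-in for the game world
x = grid_to_input(grid)
print(x.shape)  # torch.Size([1, 1, 10, 10])

# A batch of grids (e.g. sampled from a replay memory): (N, H, W) -> (N, 1, H, W)
batch = np.stack([grid, grid])
xb = torch.as_tensor(batch, dtype=torch.float32).unsqueeze(1) / 255.0

model = CNN(output_size=4)
model.eval()  # BatchNorm uses its running statistics instead of per-batch ones
with torch.no_grad():
    q_values = model(x)
    q_batch = model(xb)
print(q_values.shape)  # torch.Size([1, 4])
print(q_batch.shape)   # torch.Size([2, 4])
```

The same `unsqueeze(0)` used for the 1D state still adds the batch dimension; the extra `unsqueeze` supplies the channel dimension that `nn.Conv2d(1, ...)` expects.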
Thank you in advance