Hi.
in torch 1.7.0 , I can’t debug this error.
class DQN(nn.Module):
def __init__(self, num_states, num_actions):
super(DQN, self).__init__()
self.conv1 = nn.Conv2d(3,8,kernel_size=3,stride=2)
self.bn1 = nn.BatchNorm2d(8)
self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=2)
self.bn2 = nn.BatchNorm2d(16)
self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=2)
self.bn3 = nn.BatchNorm2d(32)
def conv2d_size_out(size, kernel_size = 3, stride = 1):
return (size - (kernel_size - 1) - 1) // stride + 1
convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(2)))
convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(num_states)))
linear_input_size = convw * convh * 32
self.head = nn.Linear(linear_input_size, num_actions)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
return self.head(x.view(x.size(0), -1))
input : [[num1,num2,num3,num4,num5],[num6,num7,num8,num9,num10]]