print(model(Variable(obs, volatile=True)))
I encounter when executing this line
obs is a size(1, 4, 84, 84) ByteTensor corresponding to (input_size, channels, height, width)
Here is my model definition:
class DQN(nn.Module):
def __init__(self, in_channels=4, num_actions=18):
super(DQN, self).__init__()
self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
self.fc4 = nn.Linear(7 * 7 * 64, 512)
self.fc5 = nn.Linear(512, num_actions)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = F.relu(self.fc4(x.view(x.size(0), -1)))
return self.fc5(x)
Below are errors:
/Users/mac/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
208
209 def __call__(self, *input, **kwargs):
--> 210 result = self.forward(*input, **kwargs)
211 for hook in self._forward_hooks.values():
212 hook_result = hook(self, input, result)
<ipython-input-25-3131e91fac28> in forward(self, x)
17
18 def forward(self, x):
---> 19 x = F.relu(self.conv1(x))
20 x = F.relu(self.conv2(x))
21 x = F.relu(self.conv3(x))
/Users/mac/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
208
209 def __call__(self, *input, **kwargs):
--> 210 result = self.forward(*input, **kwargs)
211 for hook in self._forward_hooks.values():
212 hook_result = hook(self, input, result)
/Users/mac/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/modules/conv.py in forward(self, input)
235 def forward(self, input):
236 return F.conv2d(input, self.weight, self.bias, self.stride,
--> 237 self.padding, self.dilation, self.groups)
238
239
/Users/mac/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/functional.py in conv2d(input, weight, bias, stride, padding, dilation, groups)
35 f = ConvNd(_pair(stride), _pair(padding), _pair(dilation), False,
36 _pair(0), groups)
---> 37 return f(input, weight, bias) if bias is not None else f(input, weight)
38
39
/Users/mac/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/_functions/conv.py in forward(self, input, weight, bias)
32 if k == 3:
33 input, weight = _view4d(input, weight)
---> 34 output = self._update_output(input, weight, bias)
35 if k == 3:
36 output, = _view3d(output)
/Users/mac/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/_functions/conv.py in _update_output(self, input, weight, bias)
89
90 self._bufs = [[] for g in range(self.groups)]
---> 91 return self._thnn('update_output', input, weight, bias)
92
93 def _grad_input(self, input, weight, grad_output):
/Users/mac/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/_functions/conv.py in _thnn(self, fn_name, input, weight, *args)
148 impl = _thnn_convs[self.thnn_class_name(input)]
149 if self.groups == 1:
--> 150 return impl[fn_name](self, self._bufs[0], input, weight, *args)
151 else:
152 res = []
/Users/mac/anaconda/envs/py35/lib/python3.5/site-packages/torch/nn/_functions/conv.py in call_update_output(self, bufs, input, weight, bias)
220 def make_update_output(fn):
221 def call_update_output(self, bufs, input, weight, bias):
--> 222 backend = type2backend[type(input)]
223 bufs.extend([input.new(), input.new()])
224 output = input.new(*self._output_size(input, weight))
/Users/mac/anaconda/envs/py35/lib/python3.5/site-packages/torch/_thnn/__init__.py in __getitem__(self, name)
13
14 def __getitem__(self, name):
---> 15 return self.backends[name].load()
16
17
KeyError: <class 'torch.ByteTensor'>