I tried to run the RAM model on my own dataset. However, modifying the Python code is not easy for me because this is the first time I have used Python.
`train.t7` contains `b'data'` and `b'labels'`. The shape of `b'data'` is 50000×28×28 and the shape of `b'labels'` is 50000×1. I have tried my best to run the model on my dataset, but without success. I would appreciate your help in resolving this issue.
I used the following code to load the data:
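```python
import numpy as np
import torch
import torchfile

# Load the Torch7 file; its keys are bytes objects, hence b'data' / b'labels'
o = torchfile.load('train.t7')
x = torch.tensor(o[b'data'])
y = torch.tensor(o[b'labels'])
dataset = torch.utils.data.TensorDataset(x, y)
num_train = len(dataset)
indices = list(range(num_train))
# valid_size is assumed to come from the surrounding loader function
split = int(np.floor(valid_size * num_train))
```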
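For reference, a sketch of how an index split like this typically continues: shuffle the indices, take the first `split` as validation and the rest as training. `valid_size=0.1` is an assumption inferred from the 45000/5000 split reported in the log, and `split_indices` and `seed` are hypothetical names, not part of the RAM code.

```python
import numpy as np


def split_indices(num_train, valid_size, seed=0):
    """Hypothetical helper: shuffle 0..num_train-1 and split off a validation part."""
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    rng = np.random.RandomState(seed)
    rng.shuffle(indices)  # in-place shuffle of the Python list
    train_idx, valid_idx = indices[split:], indices[:split]
    return train_idx, valid_idx


train_idx, valid_idx = split_indices(50000, 0.1)
print(len(train_idx), len(valid_idx))  # 45000 5000
```

The two index lists would then usually be wrapped in `torch.utils.data.SubsetRandomSampler` and passed to separate `DataLoader`s.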
However, training fails with the following output:
```
[*] Number of model parameters: 209,677
[*] Model Checkpoint Dir: ./ckpt
[*] Param Path: ./ckpt/ram_1_8x8_2_params.json
[*] Train on 45000 samples, validate on 5000 samples
Epoch: 1/200 - LR: 0.000300
0%| | 0/45000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "main.py", line 48, in <module>
    main(config)
  File "main.py", line 39, in main
    trainer.train()
  File "/home/wing92518/Downloads/1_recurrent-visual-attention-master/trainer.py", line 168, in train
    train_loss, train_acc = self.train_one_epoch(epoch)
  File "/home/wing92518/Downloads/1_recurrent-visual-attention-master/trainer.py", line 247, in train_one_epoch
    x, l_t, h_t, last=True
  File "/home/wing92518/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/wing92518/Downloads/1_recurrent-visual-attention-master/model.py", line 101, in forward
    g_t = self.sensor(x, l_t_prev)
  File "/home/wing92518/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/wing92518/Downloads/1_recurrent-visual-attention-master/modules.py", line 207, in forward
    phi = self.retina.foveate(x, l_t_prev)
  File "/home/wing92518/Downloads/1_recurrent-visual-attention-master/modules.py", line 56, in foveate
    phi.append(self.extract_patch(x, l, size))
  File "/home/wing92518/Downloads/1_recurrent-visual-attention-master/modules.py", line 86, in extract_patch
    B, C, H, W = x.shape
ValueError: not enough values to unpack (expected 4, got 3)
```
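The traceback ends at `B, C, H, W = x.shape` in `extract_patch`, which suggests the retina expects 4-D image batches (batch, channels, height, width), while tensors built directly from `train.t7` are 3-D, (N, 28, 28). A minimal sketch, assuming grayscale images and 0-based integer labels, that adds the missing channel axis and flattens the labels before building the `TensorDataset`. `to_ram_tensors` is a hypothetical helper, and the random arrays only stand in for `o[b'data']` and `o[b'labels']`.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset


def to_ram_tensors(data, labels):
    """Hypothetical helper: (N, H, W) images -> (N, 1, H, W) float32,
    (N, 1) labels -> (N,) int64 class indices."""
    x = torch.as_tensor(np.asarray(data), dtype=torch.float32)
    if x.dim() == 3:
        # Grayscale images have no channel axis; insert one at position 1.
        x = x.unsqueeze(1)
    # Class-index targets for nll_loss / cross_entropy must be a flat int64 vector.
    y = torch.as_tensor(np.asarray(labels)).reshape(-1).long()
    return x, y


# Synthetic stand-ins for o[b'data'] and o[b'labels'].
data = np.random.randint(0, 256, size=(8, 28, 28)).astype(np.uint8)
labels = np.random.randint(0, 10, size=(8, 1))

x, y = to_ram_tensors(data, labels)
dataset = TensorDataset(x, y)
print(x.shape, y.shape)  # torch.Size([8, 1, 28, 28]) torch.Size([8])
```

In the loader above, this would replace the two `torch.tensor(...)` lines with `x, y = to_ram_tensors(o[b'data'], o[b'labels'])`. Two further assumptions worth checking: if the pixel values are 0 to 255 integers, dividing `x` by 255 would match the [0, 1] range that torchvision's `ToTensor` produces, and Torch7 datasets often store 1-based labels, in which case `y - 1` may be needed.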