Though I have gone through the code of gradcheck (https://github.com/pytorch/pytorch/blob/master/torch/autograd/gradcheck.py), I am still not sure how to use it.
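From reading the source, my understanding is that gradcheck numerically compares the analytical Jacobian against finite differences, and that it expects a function plus a tuple of double-precision inputs with requires_grad=True. Here is a minimal sketch of how I believe a standard call looks (my_func and the shapes here are just placeholders I made up):

import torch
from torch.autograd import Variable, gradcheck

def my_func(x):
    return (x * x).sum()

# gradcheck perturbs each input element by eps, so inputs should be
# double precision with requires_grad=True
x = Variable(torch.randn(4, 3).double(), requires_grad=True)
print gradcheck(my_func, (x,), eps=1e-6, atol=1e-4)  # True if gradients match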
Now I have a model defined like this:
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(           # if nn.RNN() is used instead, it hardly learns
            input_size=INPUT_SIZE,
            hidden_size=HIDDEN_SIZE,  # rnn hidden unit
            num_layers=NUM_LAYER,     # number of rnn layers
            batch_first=True,         # input & output have batch size as the 1st dimension, e.g. (batch, time_step, input_size)
            dropout=0.5
        )
        self.out = nn.Linear(HIDDEN_SIZE, NUM_CLASS)

    def forward(self, x, l):
        x = pack_padded_sequence(x, list(l.data), batch_first=True)
        r_out, (h_n, h_c) = self.rnn(x, None)
        r_out, _ = pad_packed_sequence(r_out, batch_first=True)
        # gather the output at the last valid time step of each sequence
        idx = (l - 1).view(-1, 1).expand(r_out.size(0), r_out.size(2)).unsqueeze(1).long()
        r_out = r_out.gather(1, idx).squeeze().unsqueeze(1)
        out = self.out(r_out[:, -1, :])
        return out
How should I use gradcheck to debug my code? Should I use it like this?
for step, (x, y, l) in enumerate(dset_loaders['Train']):
    b_x = Variable(x.cuda(GPU_ID).view(-1, TIME_STEP, INPUT_SIZE))
    b_y = torch.squeeze(Variable(y.cuda(GPU_ID)))
    b_l = torch.squeeze(Variable(l.cuda(GPU_ID)))
    b_x, b_l = sort_batch(b_x, b_l)
    output = model(b_x, b_l)
    loss = loss_func(output, b_y)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm(model.parameters(), 1)
    print "gradCheck :", gradcheck(model, b_x)  # where gradcheck is called
    optimizer.step()
However, when I use it this way, I get the error shown below:
Traceback (most recent call last):
File "train.py", line 163, in <module>
train_model(rnn, dataloaders, optimizer, loss_func)
File "train.py", line 94, in train_model
print gradcheck(model, b_x)
File "/torch/autograd/gradcheck.py", line 154, in gradcheck
output = func(*inputs)
File "/torch/nn/modules/module.py", line 224, in __call__
result = self.forward(*input, **kwargs)
TypeError: forward() takes exactly 3 arguments (257 given)
My batch size is 256, which seems to match the 257 arguments in the error (256 unpacked rows plus self?), but I am not sure what is wrong.
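Looking at line 154 of gradcheck.py in the traceback, it calls output = func(*inputs), so my guess (and it is only a guess) is that the inputs have to be wrapped in a tuple so they unpack into forward(x, l), something like:

# my guess: wrap both forward() arguments in a tuple so that
# func(*inputs) becomes model(b_x, b_l) instead of unpacking b_x
# row by row into 256 separate positional arguments
print "gradCheck :", gradcheck(model, (b_x, b_l))

Is this the right way to call it?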
Thank you!