I am trying to implement cell gradient clamp for nn.LSTM.
Specifically, I need to clamp cell.grad whenever it is computed for every time step.
This is implemented in Alex Graves's rnnlib.
(https://sourceforge.net/projects/rnnl/files/ > LstmLayer.hpp line 315)
//constrain errors to be in [-1,1] for stability
bound_range(inErrs, -1.0, 1.0);
This is different from calling torch.nn.utils.clip_grad_norm() once after loss.backward().
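For contrast, here is a minimal sketch of the global-norm approach: it rescales all parameter gradients once, after the full backward pass, instead of clamping the cell gradient at each time step. The network sizes are arbitrary; modern PyTorch names the in-place variant clip_grad_norm_.

```python
import torch
import torch.nn as nn

# Toy LSTM: clip the total gradient norm once, after backward().
lstm = nn.LSTM(input_size=4, hidden_size=8)
x = torch.randn(5, 3, 4)  # (seq_len, batch, input_size)

out, _ = lstm(x)
loss = out.sum()
loss.backward()

# Rescales every parameter gradient so the total norm is at most 1.0.
nn.utils.clip_grad_norm_(lstm.parameters(), max_norm=1.0)
```

Note that this never touches the per-time-step cell gradients inside the recurrence; it only acts on the accumulated parameter gradients.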
Inspired by the Gradient clipping thread, I tried setting register_hook() on the ‘hidden’ variable to modify its gradient on the fly. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py : line 181
Here is the example code :
def clamp_grad(v, min, max):
    v.register_hook(lambda grad: grad.clamp(min, max))
    return v

output, hidden = func(input, self.all_weights, hx, batch_sizes)
hidden = clamp_grad(hidden, -1.0, 1.0)
The problem I’m having is that the ‘hidden’ variable has a None grad during the backward phase
(i.e. v.grad is None).
Is there any other way to implement cell gradient clamp?
I don’t think you want the assignment; just register the hook.
Thanks for your comment.
But I don’t get your point. Could you show me simple example code?
I really appreciate your help.
So somewhere in my implementation of Graves handwriting generation (I should just put it on github, I know…), I have
hidden.register_hook(lambda x: x.clamp(min=-GRAD_CLIPPING, max=GRAD_CLIPPING))
Note that it looks mostly like yours, but
- I check whether the parameter requires a gradient (though I have not recently checked whether that is necessary; your error seems to suggest it may be),
- I don’t overwrite hidden with the return value of register_hook(); the documentation says it returns a handle. You can and should just keep using the variable you called register_hook() on. [Edit: I see you are returning v, not the return of register_hook(), so you are good on that; sorry for the confusion.]
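A small standalone demonstration of the point above (my own toy example, not from either implementation): the hook fires during backward and rewrites the gradient in flight, even though the non-leaf tensor's own .grad stays None, which is exactly the symptom described earlier.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x * 5.0  # non-leaf tensor: its .grad is not retained

# Clamp the gradient flowing through y into [-1, 1].
y.register_hook(lambda grad: grad.clamp(min=-1.0, max=1.0))

loss = (3.0 * y).sum()
loss.backward()

# dL/dy = 3 per element, clamped to 1 by the hook,
# so dL/dx = 1 * 5 = 5 instead of the unclamped 15.
# y.grad is still None: hooks modify gradients without retaining them.
```

So a None grad on the hooked variable does not mean the hook failed; checking the leaf gradients is the way to verify it ran.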
It seems that we cannot register a hook on nn.RNN / nn.LSTM because cuDNN runs the whole sequence in a single fused CUDA kernel. However, we can register a hook on nn.RNNCell / nn.LSTMCell.
(Ref : https://github.com/pytorch/pytorch/issues/1407)
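Putting the pieces together, here is a hedged sketch (my own, under the assumptions above) of per-time-step cell-gradient clamping with nn.LSTMCell, since hooks cannot reach inside the fused nn.LSTM kernel:

```python
import torch
import torch.nn as nn

# Unroll the recurrence manually with LSTMCell so a hook can be
# attached to the cell state c at every time step.
cell = nn.LSTMCell(input_size=4, hidden_size=8)
x = torch.randn(5, 3, 4)  # (seq_len, batch, input_size)
h = torch.zeros(3, 8)
c = torch.zeros(3, 8)

outputs = []
for t in range(x.size(0)):
    h, c = cell(x[t], (h, c))
    # Clamp the cell-state gradient into [-1, 1], as rnnlib does.
    c.register_hook(lambda grad: grad.clamp(min=-1.0, max=1.0))
    outputs.append(h)

loss = torch.stack(outputs).sum()
loss.backward()
```

Each time step registers a hook on that step's cell state, so the backward pass clamps the error signal at every step of the unrolled loop rather than once at the end.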
Ah, good catch. I used LSTMCell in the handwriting generation, so I didn’t notice.