Built-in weight_norm on RNN

Hello,
I am trying to use weight_norm on a GRU, but moving the wrapped module to the GPU raises an error:

  File "<stdin>", line 1, in <module>
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 147, in cuda
    return self._apply(lambda t: t.cuda(device_id))
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/rnn.py", line 116, in _apply
    self.flatten_parameters()
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/rnn.py", line 107, in flatten_parameters
    rnn._copyParams(all_weights, params)
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/backends/cudnn/rnn.py", line 186, in _copyParams
    assert param_from.type() == param_to.type()

Here is a minimal example:

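import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm

a = nn.GRU(100, 20)  # GRU created on the CPU
# Reparameterize the layer-0 hidden-to-hidden and input-to-hidden weights
b = weight_norm(a, name='weight_hh_l0')
b = weight_norm(b, name='weight_ih_l0')
b.cuda()  # fails with the assertion shown above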

The error is raised by the call to b.cuda().

I built PyTorch from source.

I haven’t looked at this in detail, but if you change the third line of your example to

a = nn.GRU(100, 20).cuda()

it works, i.e. move the GRU to the GPU before applying weight_norm.

Yes, that does work. Thank you!
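A minimal sketch of the workaround from this thread, assuming a recent PyTorch where `torch.nn.utils.weight_norm.weight_norm` still exists (newer releases mark it deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`). The module is moved to its device first and only then reparameterized. The `device` fallback to CPU is an addition so the sketch also runs without a GPU.

```python
import torch
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm

# Use the GPU when available; the sketch also runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the GRU to its device *before* applying weight_norm.
gru = nn.GRU(100, 20).to(device)
gru = weight_norm(gru, name="weight_hh_l0")
gru = weight_norm(gru, name="weight_ih_l0")

# weight_hh_l0 is no longer a Parameter: it is recomputed before every
# forward pass from weight_hh_l0_g (magnitude) and weight_hh_l0_v (direction).
params = dict(gru.named_parameters())

x = torch.randn(5, 3, 100, device=device)  # (seq_len, batch, input_size)
output, h_n = gru(x)
```

Because `weight_norm` removes the original parameter and registers `_g` and `_v` parameters in its place, an optimizer built from `gru.parameters()` after this point trains the magnitude and direction rather than the raw weight.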

Hi,
What about the reverse weights, weight_ih_l0_reverse and weight_hh_l0_reverse? Do they also need to be reparameterized?
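weight_norm only reparameterizes the single parameter passed as `name`, so in a bidirectional GRU the `_reverse` weights stay ordinary parameters unless they are wrapped too. A minimal sketch, assuming you want every weight matrix normalized and the same recent PyTorch as before: it collects the weight names first and then applies `weight_norm` to each. Skipping the biases is a choice made here for illustration, not something the thread specifies.

```python
import torch
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Bidirectional GRU: layer 0 has forward and *_reverse weight matrices.
gru = nn.GRU(100, 20, bidirectional=True).to(device)

# Build the name list *before* modifying the module: weight_norm adds new
# parameters (e.g. weight_ih_l0_g) that also start with "weight".
# Expected names: weight_ih_l0, weight_hh_l0,
# weight_ih_l0_reverse, weight_hh_l0_reverse (biases are skipped).
weight_names = [name for name, _ in gru.named_parameters() if name.startswith("weight")]

for name in weight_names:
    weight_norm(gru, name=name)

params = dict(gru.named_parameters())

x = torch.randn(5, 3, 100, device=device)  # (seq_len, batch, input_size)
output, h_n = gru(x)  # output feature size doubles: 2 directions x 20
```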