Hello,
I am trying to apply weight_norm to a GRU, but doing so raises the following error:
  File "<stdin>", line 1, in <module>
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 147, in cuda
    return self._apply(lambda t: t.cuda(device_id))
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/rnn.py", line 116, in _apply
    self.flatten_parameters()
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/nn/modules/rnn.py", line 107, in flatten_parameters
    rnn._copyParams(all_weights, params)
  File "/home/didoyang/anaconda2/lib/python2.7/site-packages/torch/backends/cudnn/rnn.py", line 186, in _copyParams
    assert param_from.type() == param_to.type()
Here is a minimal example that reproduces it:
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm
a = nn.GRU(100, 20)
b = weight_norm(a, name='weight_hh_l0')
b = weight_norm(b, name='weight_ih_l0')
b.cuda()
The error is raised by the b.cuda() call.
I built PyTorch from source.
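In case it helps with diagnosis, here is a minimal CPU-only sketch that lists the GRU's parameters after weight_norm has been applied. It assumes a PyTorch version in which weight_norm replaces the named parameter with two new ones suffixed _g and _v. Whether this change in the parameter list is what trips the type assertion in _copyParams is only a guess on my part, not something I have confirmed.

```python
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm

# Same setup as in the report, but kept on the CPU so that
# the cuDNN flatten_parameters path is never reached.
gru = nn.GRU(100, 20)
gru = weight_norm(gru, name='weight_hh_l0')

# Collect the names of the registered parameters after weight_norm.
# Assumption: 'weight_hh_l0' is no longer a registered parameter and is
# replaced by 'weight_hh_l0_g' (magnitude) and 'weight_hh_l0_v' (direction).
param_names = [name for name, _ in gru.named_parameters()]
for name in param_names:
    print(name)
```

If the list shows weight_hh_l0_g and weight_hh_l0_v in place of weight_hh_l0, then the module no longer holds the original weight as a regular parameter, which may be relevant to how flatten_parameters gathers the weights.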