I wrap my module with `nn.DataParallel`, using
`device_ids=[2, 3]`.
During the forward pass, the GRU raises an `AssertionError`:
class Multi(nn.Module):
    """Single-layer bidirectional GRU encoder that takes batch-first tensors.

    The caller passes both the input sequence and the initial hidden state
    with the batch dimension first. ``nn.DataParallel`` scatters positional
    inputs along dim 0, so this way both tensors are split by batch.
    """

    def __init__(self):
        super(Multi, self).__init__()
        self.encoder_gru = nn.GRU(input_size=4, hidden_size=2,
                                  num_layers=1, bidirectional=True)

    def forward(self, input, hidden):
        """Run the GRU over a batch-first input.

        Args:
            input: tensor of shape (batch, seq_len, 4).
            hidden: initial hidden state, batch-first, of shape
                (batch, num_layers * num_directions, hidden_size),
                i.e. (batch, 2, 2) for this module.

        Returns:
            GRU output of shape (batch, seq_len, 2 * hidden_size),
            i.e. (batch, seq_len, 4). The final hidden state is discarded.
        """
        # nn.GRU uses batch_first=False here, so it expects (seq_len, batch, feature).
        input = input.transpose(0, 1)
        # transpose() returns a non-contiguous view. The cuDNN RNN path in some
        # PyTorch versions asserts that the hidden state is contiguous, so make
        # a contiguous copy before handing it to the GRU.
        # NOTE(review): presumed cause of the AssertionError in the report;
        # confirm against the installed PyTorch version.
        hidden = hidden.transpose(0, 1).contiguous()
        gru_output, _ = self.encoder_gru(input, hidden)
        # (seq_len, batch, 2 * hidden_size) -> (batch, seq_len, 2 * hidden_size)
        return gru_output.transpose(0, 1)
# Reproduction: run Multi under DataParallel on GPUs 2 and 3, then backprop.
# requires_grad=True makes `data` a leaf that receives a gradient; without it,
# data.grad below is always None.
data = Variable(torch.ones(6, 2, 4), requires_grad=True)
cuda_data = data.cuda(2)
model0 = Multi()
# Batch-first initial hidden state: (batch=6, num_layers * num_directions=2, hidden=2).
a = Variable(torch.ones(6, 2, 2)) * 2
model = nn.DataParallel(model0, device_ids=[2, 3])
# DataParallel expects the wrapped module's parameters on device_ids[0].
model.cuda(2)
output = model(cuda_data, a.cuda(2))
loss = output.sum()
loss.backward()
# Single-argument print(...) and %-formatting behave identically on Python 2 and 3.
print(output)
print(data.grad)
for name, param in model0.named_parameters():
    print('%s %s' % (name, param.grad))
Error output:
Traceback (most recent call last):
File "/home/ryk/programming/kbp_torch/test.py", line 223, in <module>
output = model(cuda_data, a.cuda(2))
File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 206, in __call__
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/data_parallel.py", line 61, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/data_parallel.py", line 71, in parallel_apply
return parallel_apply(replicas, inputs, kwargs)
File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/parallel_apply.py", line 46, in parallel_apply
raise output
AssertionError
Process finished with exit code 1