@wgharbieh @smth I have added `self.lstm1.flatten_parameters()`
above the line `out, hidden = self.lstm1(input[i], hidden)`,
but I now face another error, with the following message:
Traceback (most recent call last):
File "steps/run_dpcl.py", line 600, in <module>
main(device)
File "steps/run_dpcl.py", line 417, in main
train(model, mix_mean, mix_var, clean_mean, clean_var, device, ema)
File "steps/run_dpcl.py", line 149, in train
output, hidden = model(mix_v, hidden, lengths)
File "/mnt/tools/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/mnt/tools/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 114, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/mnt/tools/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 124, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/mnt/tools/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 65, in parallel_apply
raise output
File "/mnt/tools/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 41, in _worker
output = module(*input, **kwargs)
File "/mnt/tools/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/mnt/workspace/pytorch/deep_clustering/model/blstm_dpcl.py", line 58, in forward
output, hidden = self.blstm(inputs, hidden)
File "/mnt/tools/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/mnt/tools/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 192, in forward
output, hidden = func(input, self.all_weights, hx, batch_sizes)
File "/mnt/tools/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 323, in forward
return func(input, *fargs, **fkwargs)
File "/mnt/tools/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 287, in forward
dropout_ts)
RuntimeError: torch/csrc/autograd/variable.cpp:115: get_grad_fn: Assertion `output_nr == 0` failed.