I'm using a model that computes its loss inside the forward function, like this:
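```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)  # placeholder for the real layers

    def forward(self, input, target):
        predict = self.fc(input)  # some computing
        loss = nn.functional.cross_entropy(predict, target)
        return predict, loss
```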
When I tried to train this model with DataParallel, it threw the following error:
```
  File "/home/zhangyu/codes/relation_video/main.py", line 335, in <module>
    main()
  File "/home/zhangyu/codes/relation_video/main.py", line 128, in main
    train(train_loader, model, criterion, optimizer, epoch, log_training)
  File "/home/zhangyu/codes/relation_video/main.py", line 175, in train
    model(input_var, box, target_var)
  File "/home/zhangyu/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/zhangyu/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 115, in forward
    return self.gather(outputs, self.output_device)
  File "/home/zhangyu/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 127, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/zhangyu/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    return gather_map(outputs)
  File "/home/zhangyu/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/home/zhangyu/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 54, in forward
    ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
  File "/home/zhangyu/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 54, in <lambda>
    ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
RuntimeError: dimension specified as 0 but tensor has no dimensions
```
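The last frame is the lambda `i.size(ctx.dim)` in `Gather.forward`, which DataParallel applies to every tensor each replica returns, with `dim=0` here. The sketch below is a reconstruction for illustration, not PyTorch's own code, and `size_along_dim0` is a hypothetical helper that mimics that call on CPU. It shows the likely mismatch: `predict` has a batch dimension, while a mean-reduced loss is a 0-dim tensor with no dimension 0 to measure.

```python
import torch
import torch.nn.functional as F

def size_along_dim0(t):
    """Mimic Gather's `i.size(ctx.dim)` with dim=0; return None if it fails."""
    try:
        return t.size(0)
    except (IndexError, RuntimeError):  # exception type varies across PyTorch versions
        return None

predict = torch.randn(4, 2)              # one replica's predictions for a batch of 4
target = torch.tensor([0, 1, 0, 1])
loss = F.cross_entropy(predict, target)  # default mean reduction -> 0-dim tensor

print(predict.dim(), size_along_dim0(predict))  # 2 4
print(loss.dim(), size_along_dim0(loss))        # 0 None
print(size_along_dim0(loss.unsqueeze(0)))       # 1
```

This would also be consistent with returning only `predict` working fine: every tensor it gathers has a dimension 0.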
Note that when I removed DataParallel and trained on a single GPU, the model worked fine. When I removed the loss from the return statement of forward (returning only predict), it also worked fine.
So I suspect something goes wrong when a model wrapped in DataParallel returns the loss from its forward function. Any suggestions?
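One commonly used workaround is to give the loss a leading dimension inside `forward` (e.g. `loss.unsqueeze(0)`) so gather can concatenate one value per replica, then reduce with `.mean()` in the training loop. This is offered as an assumption, not a verified fix for this particular model. The `nn.Linear` placeholder, the layer sizes, and the random data below are hypothetical. With fewer than two GPUs, the sketch runs unwrapped on the default device.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)  # hypothetical placeholder layer

    def forward(self, input, target):
        predict = self.fc(input)
        loss = nn.functional.cross_entropy(predict, target)  # 0-dim tensor
        # Add a dim 0 so DataParallel's gather can concatenate per-replica losses.
        return predict, loss.unsqueeze(0)

model = MyModel()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model).cuda()
device = next(model.parameters()).device

inputs = torch.randn(8, 10, device=device)
targets = torch.randint(0, 2, (8,), device=device)
predict, loss = model(inputs, targets)
# loss holds one value per replica (shape [1] when unwrapped); average them.
loss = loss.mean()
loss.backward()
```

Averaging per-replica losses reproduces the single-GPU loss only when the batch splits evenly across GPUs, because each replica's loss is already a mean over its own chunk.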