DataParallel only outputs half of the total batch size on 2 GPUs

Hi all,

When I run the following conv-LSTM model with an input of shape (20, 1, 4, 128, 128), where the batch is in dimension 0:

# shape (batch, time, channels, height, width)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model, device_ids=[0, 1])
model = model.to(device)
a = torch.randn(20, 1, 4, 128, 128).to(device)
output, _ = model(initial_states, a)

The output shape is (10, 2, 4, 128, 128) rather than (20, 2, 4, 128, 128). Any reason why this is the case?



I found the reason for the error: my forward method has multiple arguments (i.e. model.forward(self, x, h, c) rather than just model.forward(self, x)). However, I am using a conv-LSTM architecture, so the other arguments h and c are required. What is the solution in this case?
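One thing worth noting: nn.DataParallel scatters every tensor argument of forward along dim 0 (its default scatter dimension), not just the first one. So one workaround is to make sure h and c are also batch-first tensors of the same batch size as x; then each replica receives matching slices of all three arguments. Below is a minimal, CPU-runnable sketch using a hypothetical ConvLSTMWrapper (not the poster's actual model) to illustrate the idea:

```python
import torch
import torch.nn as nn

class ConvLSTMWrapper(nn.Module):
    # Hypothetical stand-in for the poster's conv-LSTM: every forward
    # argument is batch-first, so nn.DataParallel can scatter x, h and c
    # consistently along dim 0.
    def __init__(self, channels=4):
        super().__init__()
        # single conv producing the four LSTM gates from [x, h]
        self.gates = nn.Conv2d(2 * channels, 4 * channels, 3, padding=1)

    def forward(self, x, h, c):
        # x, h, c all shaped (batch, channels, height, width)
        i, f, o, g = torch.chunk(self.gates(torch.cat([x, h], dim=1)), 4, dim=1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, (h, c)

model = ConvLSTMWrapper()
if torch.cuda.device_count() > 1:
    # each replica gets a (10, 4, 128, 128) slice of every argument
    model = nn.DataParallel(model, device_ids=[0, 1]).cuda()

x = torch.randn(20, 4, 128, 128)
h = torch.zeros(20, 4, 128, 128)
c = torch.zeros(20, 4, 128, 128)
out, _ = model(x, h, c)
print(out.shape)  # torch.Size([20, 4, 128, 128]) -- full batch again
```

The key point is only the batch-first layout of the extra arguments; if the states are stored e.g. as (num_layers, batch, ...), DataParallel will split them along the wrong dimension and the gathered output will not line up with the input batch.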

cc @ptrblck


Note that nn.DataParallel is in maintenance mode, so you should use DistributedDataParallel instead, which should not suffer from these issues since each process uses its own input.
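For reference, here is a minimal DDP sketch that runs on CPU with the gloo backend (a hypothetical TinyConvLSTMCell stands in for the actual model; the address/port values are arbitrary). Each of the two processes feeds its own 10-sample slice of the 20-sample batch, so no argument scattering is needed and multi-argument forward signatures work unchanged:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class TinyConvLSTMCell(nn.Module):
    # Hypothetical stand-in for the poster's conv-LSTM cell.
    def __init__(self, channels=4):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x, h, c):
        # multi-argument forward is fine under DDP: each process
        # calls it with its own local tensors
        h = torch.tanh(self.conv(x) + h)
        return h, (h, c)

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = DDP(TinyConvLSTMCell())
    # each process owns its slice: 10 samples here, 20 total across 2 processes
    x = torch.randn(10, 4, 128, 128)
    h = torch.zeros(10, 4, 128, 128)
    c = torch.zeros(10, 4, 128, 128)
    out, _ = model(x, h, c)
    assert out.shape == (10, 4, 128, 128)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)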