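You are probably using `DataParallel` but returning a scalar (0-dim tensor) from your network's `forward()`. `DataParallel` gathers the per-GPU outputs by concatenating them along dimension 0, so you should return a batched output (one value per sample) and reduce it after the forward call.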
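A minimal sketch of that pattern, assuming your model computes its loss inside `forward()`. `ModelWithLoss` and its sizes are hypothetical. The key change is `reduction="none"`, which keeps one loss per sample so each replica returns shape `(batch,)`. The `.mean()` is then taken on the gathered result. Without CUDA, `DataParallel` simply calls the wrapped module, so this also runs on CPU.

```python
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    """Hypothetical model that computes its loss inside forward()."""

    def __init__(self, in_features: int, num_classes: int):
        super().__init__()
        self.linear = nn.Linear(in_features, num_classes)
        # reduction="none" keeps one loss value per sample instead of a scalar
        self.criterion = nn.CrossEntropyLoss(reduction="none")

    def forward(self, x, target):
        logits = self.linear(x)
        # Shape (batch,), not 0-dim, so DataParallel can concatenate
        # the per-replica outputs along dimension 0.
        return self.criterion(logits, target)


model = nn.DataParallel(ModelWithLoss(in_features=8, num_classes=3))
if torch.cuda.is_available():
    model = model.cuda()

device = next(model.parameters()).device
x = torch.randn(16, 8, device=device)
y = torch.randint(0, 3, (16,), device=device)

per_sample_loss = model(x, y)   # shape: (16,), one entry per sample
loss = per_sample_loss.mean()   # reduce only after the outputs are gathered
loss.backward()
```

If each replica returned `loss.mean()` instead, `DataParallel` would warn that all inputs were scalars and return a vector with one entry per GPU. Averaging that vector only matches the true batch mean when every GPU gets an equal-sized chunk.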