DataParallel freezes

I encountered the same issue as above, and I resolved it by changing the following:

model = ResNet().to(device)
model = nn.DataParallel(model)

to

model = ResNet()
model = nn.DataParallel(model).to(device)

In my case, it works.

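BTW, the model's parameters and buffers should be on device_ids[0] before running the DataParallel module, as the official documentation mentions. Since output_device defaults to device_ids[0], this is also the output device unless you set it explicitly.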
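Here is a minimal self-contained sketch of the ordering from the post, assuming a single machine with zero or more GPUs. TinyNet is a hypothetical stand-in for the ResNet in the post, and build_model is an illustrative helper name, not part of PyTorch. Calling .to(device) on the DataParallel wrapper moves the wrapped module's parameters and buffers, so the final check confirms they end up on the first device. On a CPU-only machine DataParallel simply calls the wrapped module, so the sketch still runs there.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical small model standing in for the ResNet in the post."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # -> (N, 8, 1, 1) for any input size
        )
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.fc(torch.flatten(x, 1))


def build_model():
    # cuda:0 is device_ids[0] when device_ids is left at its default
    # (all visible GPUs); fall back to CPU when no GPU is available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = TinyNet()
    # Wrap first, then move the wrapper (the ordering from the post).
    model = nn.DataParallel(model).to(device)
    # The docs require parameters/buffers on device_ids[0] before forward.
    assert next(model.parameters()).device == device
    return model, device


if __name__ == "__main__":
    model, device = build_model()
    x = torch.randn(4, 3, 32, 32, device=device)
    out = model(x)
    print(out.shape)  # torch.Size([4, 10])
```

With more than one GPU, each forward call splits the batch along dimension 0 across the visible devices and gathers the outputs back on output_device.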

Thanks,
