# Build the model and move its parameters to the default CUDA device.
# NOTE(review): `Net(...)` passes a literal Ellipsis to the constructor;
# presumably a placeholder for the real constructor arguments — confirm.
net = Net(...).cuda()

# Save only the model's state dict (its parameters/buffers), not the
# module object itself.
torch.save(net.state_dict(), './net.pth')

# Reload the saved state dict into the same model instance.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints
# (on torch >= 1.13, `weights_only=True` restricts what can be unpickled).
# Tensors saved from CUDA are restored onto CUDA by default, so this call
# needs a GPU unless a `map_location` is supplied.
net.load_state_dict(torch.load('./net.pth'))