Saving torch models

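import torch
net = Net(...).cuda()  # Net is your own nn.Module subclass
torch.save(net.state_dict(), './net.pth')  # save only the learned parameters and buffers
net.load_state_dict(torch.load('./net.pth'))  # restore them into a model with the same architecture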
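A self-contained sketch of the same save/load round trip follows. It uses a small hypothetical `Net` (a single `nn.Linear`) in place of your own class, and the helper `save_and_reload` is also hypothetical. Two additions go beyond the snippet above. The first is `map_location="cpu"`, which lets a checkpoint saved from a GPU model load on a CPU-only machine. The second is `eval()`, which switches layers such as dropout and batch norm to inference behaviour.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical stand-in for your own model; any nn.Module behaves the same way."""

    def __init__(self, in_features: int = 4, out_features: int = 2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def save_and_reload(path: str) -> tuple[Net, Net]:
    """Save a model's state_dict to `path`, then load it into a fresh instance."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = Net().to(device)
    torch.save(net.state_dict(), path)

    # A state_dict holds tensors only, so the architecture must be rebuilt first.
    restored = Net()
    # map_location="cpu" remaps GPU tensors so the file loads without a GPU.
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()  # inference mode: affects dropout / batch norm layers
    return net, restored


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        original, restored = save_and_reload(os.path.join(tmp, "net.pth"))
        for name, tensor in original.state_dict().items():
            # Compare on CPU so this also works when the original lives on a GPU.
            print(name, torch.equal(tensor.cpu(), restored.state_dict()[name]))
```

Saving the `state_dict` instead of the whole model object (`torch.save(net, ...)`) keeps the checkpoint independent of the exact class definition's pickled location. The trade-off is that you must construct `Net` yourself before calling `load_state_dict`.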