Save checkpoints trained on multiple GPUs for loading on a single GPU

I am training a GAN model on multiple GPUs using DataParallel and am trying to follow the official guidance here for saving torch.nn.DataParallel Models, since I plan to run evaluation on a single GPU later, which means I need to load checkpoints trained on multiple GPUs onto a single GPU.

The official guidance indicates that, “to save a DataParallel model generically, save the model.module.state_dict(). This way, you have the flexibility to load the model any way you want to any device you want”:

#Save:
torch.save(model.module.state_dict(), PATH)
#Load:
# Load to whatever device you want

And these are my scripts for saving the generator and the discriminator, respectively:

torch.save(G.module.state_dict(), 
              '%s/%s_module.pth' % (root, join_strings('_', ['G', name_suffix])))
torch.save(D.module.state_dict(), 
              '%s/%s_module.pth' % (root, join_strings('_', ['D', name_suffix])))

However, when it comes to saving the checkpoint, I get the following error:

Traceback (most recent call last):
  File "train.py", line 227, in <module>
    main()
  File "train.py", line 224, in main
    run(config)
  File "train.py", line 206, in run
    state_dict, config, experiment_name)
  File "/home/BIGGAN/train_fns.py", line 101, in save_and_sample
    experiment_name, None, G_ema if config['ema'] else None)
  File "/home/BIGGAN/utils.py", line 721, in save_weights
    torch.save(G.module.state_dict(), 
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 594, in __getattr__
    type(self).__name__, name))
AttributeError: 'Generator' object has no attribute 'module'

But the checkpoints can be saved if I use:

torch.save(G.state_dict(), 
              '%s/%s.pth' % (root, join_strings('_', ['G', name_suffix])))
torch.save(D.state_dict(), 
              '%s/%s.pth' % (root, join_strings('_', ['D', name_suffix])))

I am using PyTorch version ‘1.5.0a0+8f84ded’.

I am not sure whether the error has something to do with my PyTorch version, or whether I have missed something in my scripts.

Just in case, if there is another workaround that would allow me to load checkpoints trained on multiple GPUs onto a single GPU, that would also be great.

Any guidance and assistance would be greatly appreciated!

I think the tutorial you linked has a bug when it comes to the loading. You would want to load the state dict back to model.module, i.e.

# Load to whatever device you want
might well be amended as

model.module.load_state_dict(torch.load(PATH))

This way, the state dict matches the model without the DataParallel wrapper, and you can also load it into an unwrapped model on a single GPU (use map_location in torch.load if needed).
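For example, a minimal sketch of loading such a checkpoint on a single-GPU (or CPU) machine; Generator() and the file name 'netG.pth' are placeholders for however you construct the unwrapped model and wherever you saved the weights:

import torch

# Construct the plain, unwrapped generator first (constructor arguments omitted here).
G = Generator()
# map_location moves the saved tensors onto the target device while loading.
state_dict = torch.load('netG.pth', map_location='cuda:0')  # or map_location='cpu'
G.load_state_dict(state_dict)
G.to('cuda:0')
G.eval()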

Best regards

Thomas

@tom Hi, Thomas, thanks a lot for your response.

Actually I am having trouble with the saving, not the loading, namely:

torch.save(model.module.state_dict(), PATH)

Do you happen to know whether this only applies to the latest PyTorch version? I am using version ‘1.5.0a0+8f84ded’ and get the following error when it comes to saving the checkpoint:

Traceback (most recent call last):
  File "train.py", line 227, in <module>
    main()
  File "train.py", line 224, in main
    run(config)
  File "train.py", line 206, in run
    state_dict, config, experiment_name)
  File "/home/BIGGAN/train_fns.py", line 101, in save_and_sample
    experiment_name, None, G_ema if config['ema'] else None)
  File "/home/BIGGAN/utils.py", line 721, in save_weights
    torch.save(G.module.state_dict(), 
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 594, in __getattr__
    type(self).__name__, name))
AttributeError: 'Generator' object has no attribute 'module'

Or do I need to do some extra configuration when setting up the models if I want to use torch.save(model.module.state_dict(), PATH) to save torch.nn.DataParallel Models? Thank you!

Ah, yeah, the generator apparently isn’t wrapped in a DataParallel instance. (The error says you’re trying to access an attribute module (which a DataParallel would have) from an object of type Generator.)
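To illustrate, only the DataParallel wrapper has a module attribute; a bare module does not, which is exactly what the traceback complains about. A quick standalone check, with nn.Linear standing in for your generator:

import torch.nn as nn

net = nn.Linear(2, 2)            # a model that is *not* wrapped
wrapped = nn.DataParallel(net)   # the wrapper stores the original model as .module

print(hasattr(wrapped, 'module'))  # True  -> wrapped.module is the original Linear
print(hasattr(net, 'module'))      # False -> net.module raises the AttributeError above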

@tom Thank you for your kind explanation. Yet as I am confident that I have applied DataParallel to the generator, may I check whether this is indeed a version issue? That is, are torch.save(model.module.state_dict(), PATH) and model.module.load_state_dict(torch.load(PATH)) newer usages that only work on the latest PyTorch version?

For reference, just in case, the following is my code for the training setup and for saving checkpoints:

use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
D = model.DiscriminatorACGAN(x_dim=x_dim, c_dim=c_dim, norm=norm, weight_norm=weight_norm).to(device)
G = model.GeneratorACGAN(z_dim=z_dim, c_dim=c_dim).to(device)
ngpu = 2 # I am using a 2 GPU machine
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    G = nn.DataParallel(G, list(range(ngpu))).to(device)
if (device.type == 'cuda') and (ngpu > 1):
    D = nn.DataParallel(D, list(range(ngpu))).to(device)

# gan loss function
d_loss_fn, g_loss_fn = model.get_losses_fn(loss_mode)

# optimizer
d_optimizer = torch.optim.Adam(D.parameters(), lr=d_learning_rate, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=g_learning_rate, betas=(0.5, 0.999))

# run
z_sample = torch.randn(c_dim * 10, z_dim).to(device)
c_sample = torch.tensor(np.concatenate([np.eye(c_dim)] * 10), dtype=z_sample.dtype).to(device)

for ep in range(start_ep, epoch):
    for i, (x, c_dense) in enumerate(train_loader):

        step = ep * len(train_loader) + i + 1
        D.train()
        G.train()

        x = x.to(device)
        c_dense = c_dense.to(device)
        z = torch.randn(batch_size, z_dim).to(device)
        c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
        x_f = G(z, c)

        # train D
        ...
        # train G
        ...

        if ep % 10 == 0:
            torch.save(G.module.state_dict(), os.path.join(ckpt_dir, 'netG_{}.pth'.format(ep)))
            torch.save(D.module.state_dict(), os.path.join(ckpt_dir, 'netD_{}.pth'.format(ep)))

What’s orig_G/orig_D in torch.save? I don’t think I see these anywhere except in the save.

@tom Sorry, it’s a typo, I have corrected it. It should be

torch.save(G.module.state_dict(), os.path.join(ckpt_dir, 'netG_{}.pth'.format(ep)))
torch.save(D.module.state_dict(), os.path.join(ckpt_dir, 'netD_{}.pth'.format(ep)))

I was just modifying the code while copying it and forgot to change it back, my fault; the error has nothing to do with the typo. :sweat_smile:

Maybe you can do print(type(G)) at various points in your code to see where it becomes (or does not become) a DataParallel.
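For example (reusing the variables from your posted setup; the expected output is only sketched in the comments):

G = model.GeneratorACGAN(z_dim=z_dim, c_dim=c_dim).to(device)
print(type(G))   # e.g. <class 'model.GeneratorACGAN'> -- not wrapped yet

if (device.type == 'cuda') and (ngpu > 1):
    G = nn.DataParallel(G, list(range(ngpu))).to(device)
print(type(G))   # expected: <class 'torch.nn.parallel.data_parallel.DataParallel'>
                 # if this still prints the plain class, the wrapping branch never ran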

Hi @Janine,

This is not related to the PyTorch version but to the DataParallel (and also DistributedDataParallel) class wrapper for PyTorch nn models.
DataParallel encloses the original model as its member variable, self.module.
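A quick sketch of that relationship, with nn.Linear standing in for the real model:

import torch.nn as nn

net = nn.Linear(10, 10)
dp = nn.DataParallel(net)

print(dp.module is net)                 # True: the original model lives in .module
print(list(dp.state_dict())[0])         # 'module.weight' -- keys get a 'module.' prefix
print(list(dp.module.state_dict())[0])  # 'weight'        -- unwrapped keys, load anywhere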

In case you need both single-GPU and multi-GPU model training, you can change saving/loading behavior with if statements.
For example,

if isinstance(G, nn.DataParallel):
    torch.save(G.module.state_dict(), model_save_name)
else:
    torch.save(G.state_dict(), model_save_name)

If the current model is a DataParallel instance, you save G.module.state_dict(); otherwise you save G.state_dict().
Also, when loading pretrained parameters, you could do:

if isinstance(G, nn.DataParallel):
    G.module.load_state_dict(state_dict)
else:
    G.load_state_dict(state_dict)

I would suggest creating a parent class Model that inherits from nn.Module and overrides the default state_dict function with the above method, so that G and D can inherit it and your training code stays simple.
You may want to take a look at my code.
The state_dict, load_state_dict, and save functions are the relevant parts.
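As a rough sketch of that idea (the class and method names here are assumptions, and these helpers wrap the isinstance check rather than literally overriding state_dict):

import torch
import torch.nn as nn


class Model(nn.Module):
    """Hypothetical base class whose helpers hide the DataParallel check."""

    @staticmethod
    def _unwrap(net):
        # Return the underlying module whether or not `net` is a DataParallel wrapper.
        return net.module if isinstance(net, nn.DataParallel) else net

    @classmethod
    def save(cls, net, path):
        # Always saves wrapper-free keys, so the checkpoint loads on any device setup.
        torch.save(cls._unwrap(net).state_dict(), path)

    @classmethod
    def load(cls, net, path, map_location=None):
        state_dict = torch.load(path, map_location=map_location)
        cls._unwrap(net).load_state_dict(state_dict)

G and D would then inherit from Model instead of nn.Module directly, and the training loop can call Model.save(G, path) in both the single-GPU and the DataParallel case.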


Thank you @seungjun, this is indeed a very neat way to avoid conflicts. I now fully understand the module mechanism and have resolved the issue.

Thank you @tom for your kind help; I have found the pitfalls in my original code and the problem has been solved.

Is there a way to do the opposite: a model trained on a single GPU being loaded into a multi-GPU inference script? This is how I’ve been doing it, but it doesn’t work as expected:

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    print("Using", torch.cuda.device_count(), "GPUs")

    model.to(device)
    checkpoint = torch.load(args.pretrained_path)

    if torch.cuda.device_count() > 1:
        model.module.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint['model_state_dict'])

I’m using 2 GPUs and have set the batch size to 50 (so 25 samples per GPU). The input right before being passed into the model has dimensions torch.Size([50, 3, 2048, 2048]). However, the output returned by the model only has 25 items. I’m not sure why this is the case, and would appreciate any insights!

Not sure about the cause of the problem, but I would change the order to:

  1. construct a single-GPU model
  2. load the weights to the single-GPU model
  3. parallelize the model to multi-GPU format (DP or DDP)

In your case, it could be:

model.to(device, ...)
model.load_state_dict(...)
model = nn.DataParallel(model, ...)
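Concretely, a sketch in terms of your snippet (it reuses args.pretrained_path and the 'model_state_dict' key from your code; MyModel is a placeholder for your actual model class):

import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 1. construct the plain, single-GPU model
model = MyModel()  # placeholder constructor
model.to(device)

# 2. load the single-GPU checkpoint into the unwrapped model
checkpoint = torch.load(args.pretrained_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# 3. only then wrap it for multi-GPU inference
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)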