Load a DDP model trained with 8 GPUs on only 2 GPUs?

I have many Distributed Data Parallel models (NOT Data Parallel!) trained with 8 GPUs on a cluster. I have no problem restoring them correctly with the same number of GPUs (8), but the wait time to get 8 GPUs is too long, so I want to restore them with only two.

Is this even possible? If so, what is the correct way to do it?

The script below (test.py) works fine with 8 GPUs but produces erroneous results with 2 GPUs (in the latter case, the results are the same as those of a model initialized with random weights). I run it from the terminal with “python -m torch.distributed.launch --nproc_per_node=num_gpus test.py”.

import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.models import resnet18

def cleanup():
    dist.destroy_process_group()

def main():
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(args.local_rank)
    model = resnet18()
    model = model.to(args.local_rank)
    model = DDP(model, device_ids=[args.local_rank],
                output_device=args.local_rank)

    # load the checkpoint (load_path is the path to the saved checkpoint)
    checkpoint = torch.load(load_path)
    state_dict = checkpoint['model_state_dict']
    model.load_state_dict(state_dict)
    dist.barrier()

    cleanup()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="blah")
    parser.add_argument("--local_rank", type=int)
    args, _ = parser.parse_known_args()
    main()

This should be possible; there is a map_location argument in torch.load. Check this out.

The map_location can be a device, a function, a dict, etc. [API]
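
For example (a rough sketch; "checkpoint.pt" is just a placeholder path):

import torch

# load onto a specific device
state = torch.load("checkpoint.pt", map_location=torch.device("cuda:1"))

# remap storages by device name with a dict
state = torch.load("checkpoint.pt", map_location={"cuda:0": "cuda:1"})

# or pass a function taking (storage, location)
state = torch.load("checkpoint.pt",
                   map_location=lambda storage, loc: storage.cuda(1))

# load onto the CPU
state = torch.load("checkpoint.pt", map_location="cpu")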

Thank you for your answer. The documentation does not include a working example for DDP. I have already tried many approaches using the map function, none of which have worked so far. If you could show me a simple working example with the MNIST dataset that maps 8 GPUs to 1, 2, or 4 GPUs (or to CPU) with DistributedDataParallel, I would greatly appreciate it.

There are a few things to clarify.

  1. As you are using the resnet18 from torchvision, the model only lives on a single GPU.
  2. The launcher script you use starts num_gpus processes, and each process has its own DDP instance, dataloader, and model replica.
  3. Given 1 and 2, your training script only needs to put the model on one GPU (you can use the rank as the device id) and load the data onto one GPU; the DDP instance will handle the communication for you and make sure that all model replicas stay synchronized.
  4. With the above 3, the question then becomes “how do I load a model onto a specific GPU device?”. The answer is to use map_location=torch.device(rank).

The following code works for me with the launch command

python -m torch.distributed.launch --nproc_per_node=2 test.py

import argparse
from torchvision.models import resnet18
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch

def cleanup():
    dist.destroy_process_group()

def main(args):
    torch.distributed.init_process_group(backend='nccl', init_method='tcp://localhost:23456', rank=args.local_rank, world_size=2)
    torch.cuda.set_device(args.local_rank)
    model = resnet18()

    path = "save_model.pt"
    if args.local_rank == 0:
        # save CPU model
        torch.save(model, path)

    dist.barrier()
    # load the model onto this process's GPU
    loaded_model = torch.load(path, map_location=torch.device(args.local_rank))

    model = DDP(loaded_model, device_ids=[args.local_rank])
    print(f"Rank {args.local_rank} traning on device {list(model.parameters())[0].device}")

    # create a dedicated data loader for each process

    cleanup()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="blah")
    parser.add_argument("--local_rank", type=int)
    args, _ = parser.parse_known_args()
    main(args)

@mrshenli thanks for your reply. I tried your method after a few minor corrections, but it still gives me the same erroneous result. I use this resnet script to call the model.

I trained it on a large dataset and decided to save it periodically during training. Since testing slowed training down, I decided to test later using the saved models. When I train DDP with 8 GPUs and later test DDP with 8 GPUs, there is no issue. However, when I train DDP with 8 GPUs and later test DDP with 2 GPUs, the problem occurs.

Also, I only want to save and load the state_dict, not the entire model, since the latter takes a lot of space.

I will create a working example for MNIST shortly.

I tried your method after a few minor corrections, but it still gives me the same erroneous result. I use this resnet script to call the model.

You mean you saw an error when running the script as is? What error did you see and what fix did you apply?

Also, I only want to save and load the state_dict, not the entire model, since the latter takes a lot of space.

It should be doable by just modifying two lines (save and load).
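
For instance, relative to the script above, the save and load lines might become something like this (just a sketch; it keeps the rest of the script unchanged and saves/restores the state_dict of the plain, unwrapped model):

    if args.local_rank == 0:
        # save only the state_dict instead of the whole model
        torch.save(model.state_dict(), path)

    dist.barrier()
    # restore the state_dict into a fresh model on this process's GPU
    state_dict = torch.load(path, map_location=torch.device(args.local_rank))
    loaded_model = resnet18().to(args.local_rank)
    loaded_model.load_state_dict(state_dict)

    model = DDP(loaded_model, device_ids=[args.local_rank])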

When I train DDP with 8 GPUs and later test DDP with 8 GPUs, there is no issue. However, when I train DDP with 8 GPUs and later test DDP with 2 GPUs, the problem occurs.

The resnet link you posted points to the torchvision resnet, so the model only lives on a single device. How did you go from training on 8 GPUs to testing on 2 GPUs? Did you do the following?

  1. After training, use only rank 0 to save ddp.module to a file.
  2. For testing, as you no longer need communication across model replicas, you don't need DDP. You can spawn two processes, each loading the saved module from the file onto its dedicated device by setting map_location, and then use something like all_gather to collect the loss/accuracy data on rank 0 (see the sketch after this list).
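
For reference, here is a minimal sketch of such a DDP-free test setup. It assumes rank 0 saved the plain module to "model.pt" (a placeholder path), two GPUs, and a hypothetical evaluate() helper that returns this process's sums of loss, correct predictions, and sample count; it uses all_reduce instead of all_gather for simplicity:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def test_worker(rank, world_size):
    dist.init_process_group(backend='nccl',
                            init_method='tcp://localhost:23456',
                            rank=rank,
                            world_size=world_size)
    torch.cuda.set_device(rank)

    # load the plain (non-DDP) module straight onto this process's GPU
    model = torch.load("model.pt", map_location=torch.device(rank))
    model.eval()

    # hypothetical helper: evaluates this process's shard of the test set
    loss_sum, correct, total = evaluate(model, rank)

    # sum the per-rank statistics so rank 0 can report the overall numbers
    stats = torch.tensor([loss_sum, correct, total],
                         dtype=torch.float64, device=rank)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    if rank == 0:
        print(f"loss: {stats[0] / stats[2]:.4f}, "
              f"accuracy: {stats[1] / stats[2]:.4f}")

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(test_worker, args=(2,), nprocs=2)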

@mrshenli thanks again. I will try to answer all your questions in more detail later today.

Unfortunately, I could not use your script as is, because my already-saved DDP model (without “.module”) was saved with the state_dict method.
So as for the minor changes, I did the following:

def main(args):
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True)

    model = get_model()
    #############################################################
    # My changes
    torch.cuda.set_device(args.local_rank)
    model = model.to(args.local_rank)
    model = DDP(model, device_ids=[args.local_rank],
                output_device=args.local_rank)
    checkpoint = torch.load(args.load_path)  # , map_location=map_location)
    state_dict = checkpoint['model_state_dict']
    model.load_state_dict(state_dict)
##############################################################
    dist.barrier()
    test_function(model, test_loader, args.local_rank,args.load_path.with_suffix('.csv'))

I trained resnet18 from scratch. I just copied and used the resnet script locally.

As for your last two comments, I did use just rank 0 to save the DDP model, but I saved the state_dict() of the DDP instance itself (without .module). That is why, when I used your script, I also had to strip the .module prefix, similar to this:
[solved] KeyError: ‘unexpected key “module.encoder.embedding.weight” in state_dict’
Is it correct to do so?

Yes, that is correct. The saved and loaded model types need to match.

This line might cause a problem if the model was saved from a device that is not available on the machine that loads it. It should be OK in your case, as the model was saved from rank 0 (i.e., “cuda:0”), whose device is available in both environments. However, without map_location, doesn't that mean the two DDP processes in testing are operating on the same GPU? That could also cause problems.
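
For example, one way to remap the checkpoint onto each rank's own GPU (a sketch, assuming the checkpoint was saved from rank 0, i.e. “cuda:0”, and reusing args.load_path and the DDP-wrapped model from your snippet):

    # remap every tensor that was saved from cuda:0 onto this process's GPU
    map_location = {'cuda:0': f'cuda:{args.local_rank}'}
    checkpoint = torch.load(args.load_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])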

@mrshenli Sorry for the late reply. Say I want to train the DDP model on 4 GPUs and restore it as DDP on 2. I created an MNIST example to illustrate my case while following your example. The whole thing is borrowed from the PyTorch MNIST example, modified, and split into three scripts:

  1. mnist_common.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import argparse
from torchvision import datasets, transforms
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader


def cleanup():
    dist.destroy_process_group()


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), \
                       target.to(device, non_blocking=True)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device, non_blocking=True), \
                           target.to(device, non_blocking=True)
            output = model(data)
            test_loss += F.nll_loss(
                output,
                target,
                reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    if args.local_rank == 0:
        print('Test set: Average loss: {:.4f},'
              ' Accuracy: {}/{} ({:.2f}%)'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))


# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size',
                    type=int,
                    default=64,
                    metavar='N',
                    help='input batch size for training')
parser.add_argument('--test-batch-size',
                    type=int,
                    default=1000,
                    metavar='N',
                    help='input batch size for testing')
parser.add_argument('--epochs', type=int, default=14, metavar='N',
                    help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                    help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                    help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--local_rank', type=int)
args = parser.parse_args()

train_dataset = datasets.MNIST(
    '../data',
    train=True,
    download=False,
    transform=transforms.Compose([
       transforms.ToTensor(),
       transforms.Normalize((0.1307,), (0.3081,))
    ]))
train_sampler = DistributedSampler(
    train_dataset,
    num_replicas=torch.cuda.device_count(),
    rank=args.local_rank)
train_loader = DataLoader(train_dataset,
                          batch_size=args.batch_size,
                          shuffle=(train_sampler is None),
                          num_workers=0,
                          pin_memory=True,
                          sampler=train_sampler)
test_loader = DataLoader(
    datasets.MNIST(
        '../data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
    batch_size=args.test_batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,)

  2. mnist_train.py
from __future__ import print_function
import torch
import torch.optim as optim
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import StepLR
from torch.nn.parallel import DistributedDataParallel as DDP
from mnist_common import args, Net, train_loader, train_sampler,\
    test_loader, train, test, cleanup


def main(args):
    dist.init_process_group(backend='nccl',
                            init_method='tcp://localhost:23456',
                            rank=args.local_rank,
                            world_size=torch.cuda.device_count())
    torch.manual_seed(args.seed)
    torch.cuda.set_device(args.local_rank)
    cudnn.benchmark = True

    model = Net()
    model = model.to(args.local_rank)  # move the model to this process's GPU
    # Should we set the output_device value in DDP?
    model = DDP(model, device_ids=[args.local_rank])
    # , output_device=args.local_rank)

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train_sampler.set_epoch(epoch)
        train(args, model, args.local_rank,
              train_loader, optimizer, epoch)
        test(args, model, args.local_rank, test_loader)
        scheduler.step(epoch)

    # I intend to save the model
    # AFTER some training, not before
    if args.local_rank == 0:
        torch.save(model, "mnist_cnn.pt")
    dist.barrier()
    cleanup()


if __name__ == '__main__':
    main(args)

Also, I intend to test the model only after training (which sometimes takes up to a few days) has finished, by restoring the saved weights (or the model).
  3. mnist_test.py

from __future__ import print_function
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from mnist_common import args, Net, test_loader, test, cleanup


def main(args):
    dist.init_process_group(backend='nccl',
                            init_method='tcp://localhost:23456',
                            rank=args.local_rank,
                            world_size=2)
    torch.manual_seed(args.seed)
    torch.cuda.set_device(args.local_rank)

    model = torch.load("mnist_cnn.pt",
                       map_location=torch.device(args.local_rank))
    model = DDP(model, device_ids=[args.local_rank])
    print(f"Rank {args.local_rank} "
          f"test on device {list(model.parameters())[0].device}")

    test(args, model, args.local_rank, test_loader)
    cleanup()


if __name__ == '__main__':
    main(args)

mnist_train.py runs successfully using
python -m torch.distributed.launch --nproc_per_node=4 (or 2) mnist_train.py
but when I run the test script using
python -m torch.distributed.launch --nproc_per_node=2 mnist_test.py
I get the following:

Rank 0 test on device cuda:0
Rank 1 test on device cuda:1
Test set: Average loss: 0.0274, Accuracy: 9913/10000 (99.13%)

RuntimeError: Expected tensor for argument #1 input to have 
the same device as tensor for argument #2 weight; 
but device 0 does not equal 1 
(while checking arguments for cudnn_convolution)

This means the first parameters of both models are placed on the correct device. Can you do the same check for all parameters, i.e., make sure that every parameter is placed on the correct device?
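
For example, something like this on each rank (a sketch, reusing model and args.local_rank from your test script):

    expected = torch.device(f'cuda:{args.local_rank}')
    for name, param in model.named_parameters():
        if param.device != expected:
            print(f"Rank {args.local_rank}: {name} is on {param.device}, "
                  f"expected {expected}")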

output = model(data)

Before the line above in test(...), can you print the device id of the data as well? It looks like the model and data devices do not match on rank 1.

I see. But that should not be the case, since both are moved to args.local_rank. Anyway, I did what you suggested and also changed the test-batch-size to 1024. Here's the outcome:

Rank 0 test on device cuda:0
Rank 1 test on device cuda:1
after data=data.to(device,), before output=model(data) in test function,  batch_idx: 0 device: 1
after data=data.to(device,), before output=model(data) in test function,  batch_idx: 0 device: 0
after data=data.to(device,), before output=model(data) in test function,  batch_idx: 1 device: 0
after data=data.to(device,), before output=model(data) in test function,  batch_idx: 2 device: 0
after data=data.to(device,), before output=model(data) in test function,  batch_idx: 3 device: 0
after data=data.to(device,), before output=model(data) in test function,  batch_idx: 4 device: 0
after data=data.to(device,), before output=model(data) in test function,  batch_idx: 5 device: 0
after data=data.to(device,), before output=model(data) in test function,  batch_idx: 6 device: 0
after data=data.to(device,), before output=model(data) in test function,  batch_idx: 7 device: 0
after data=data.to(device,), before output=model(data) in test function,  batch_idx: 8 device: 0
after data=data.to(device,), before output=model(data) in test function,  batch_idx: 9 device: 0
Test set: Average loss: 0.0275, Accuracy: 9913/10000 (99.13%)

RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 0 does not equal 1 (while checking arguments for cudnn_convolution)

Rank 1 test on device cuda:1
after data=data.to(device,), before output=model(data) in test function, batch_idx: 0 device: 1

This is weird. It means all model parameters are on cuda:1 and the input batch is also on cuda:1, but somehow one of the conv layers still throws a device mismatch. I am not sure what happened here, but since the error says the mismatch occurs in cudnn_convolution, I would check whether the input (x) and the parameters of the two conv layers (self.conv1 and self.conv2) are on the same device inside the forward() function during testing.
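
One way to check this without editing forward() is to register forward pre-hooks on the conv layers (a sketch; run it on each rank right after loading model and before calling test(...)):

    import torch.nn as nn

    def report_devices(module, inputs):
        # print the devices of the conv input and weight just before it runs
        print(f"{module.__class__.__name__}: input on {inputs[0].device}, "
              f"weight on {module.weight.device}")

    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            layer.register_forward_pre_hook(report_devices)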

BTW, two more comments on the script:

  1. As you are only doing forward passes during testing, it is not necessary to use DDP there, since all communication in DDP occurs during the backward pass.
  2. I noticed you save a DDP module and then load that DDP module and wrap it with another DDP module. Is this intentional? Shouldn't mnist_train.py save model.module instead (or use model.module to initialize the DDP instances in testing)? See the sketch below.
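
For reference, a sketch of that change, based on the mnist_train.py and mnist_test.py scripts above:

    # mnist_train.py: save the inner module, not the DDP wrapper
    if args.local_rank == 0:
        torch.save(model.module, "mnist_cnn.pt")

    # mnist_test.py: load it directly onto this rank's GPU; since testing is
    # forward-only, the DDP wrapper can be dropped entirely
    model = torch.load("mnist_cnn.pt",
                       map_location=torch.device(args.local_rank))
    model.eval()
    test(args, model, args.local_rank, test_loader)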