Saving and resuming in DDP training

I trained the model for 5 epochs on 3 GPUs using DDP. I saved the model on the first GPU at the end of training to the hard disk. Now, if I try to load the state_dict to the model, I get this error.

    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for AudioCNN:
        Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "conv3.weight", "conv3.bias", "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias".
        Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias", "module.conv2.weight", "module.conv2.bias", "module.conv3.weight", "module.conv3.bias", "module.fc1.weight", "module.fc1.bias", "module.fc2.weight", "module.fc2.bias".

This is essentially my training script

import os
from datetime import datetime
import argparse
import torch.multiprocessing as mp
import torch
import torch.nn as nn
import torch.distributed as dist
from model import AudioCNN
from data import CustomAudioDataset

resume = os.path.isfile("models/model_latest.tar")

if resume:
    checkpoint = torch.load("models/model_latest.tar")

if not resume:
    with open("models/loss.csv", "w") as f:
        f.write("epoch,batch,loss\n")


if resume:
    print(f"Found previous training files, resuming from {checkpoint['epoch'] + 1} epoch.")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n",
        "--nodes",
        default=1,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 4)",
    )
    parser.add_argument(
        "-g", "--gpus", default=1, type=int, help="number of gpus per node"
    )
    parser.add_argument(
        "-nr", "--nr", default=0, type=int, help="ranking within the nodes"
    )
    parser.add_argument(
        "--epochs",
        default=2,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    args = parser.parse_args()
    args.world_size = args.gpus * args.nodes
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "8888"
    mp.spawn(train, nprocs=args.gpus, args=(args,))


def train(gpu, args):
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(
        backend="nccl", init_method="env://", world_size=args.world_size, rank=rank
    )
    torch.manual_seed(0)
    model = AudioCNN()
    optimizer = torch.optim.SGD(model.parameters(), 1e-3)
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    epochs_completed = 0
    current_loss = 0
    if resume:
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        epochs_completed = checkpoint["epoch"]
        current_loss = checkpoint["loss"]
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    batch_size = 8
    # define loss function (criterion) and optimizer
    # Wrap the model
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    # Data loading code
    train_dataset = CustomAudioDataset()
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        sampler=train_sampler,
    )

    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(epochs_completed, args.epochs):
        fname = (
            f"models/model_0{epoch}.tar" if epoch < 10 else f"models/model_{epoch}.tar"
        )
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if gpu == 0:
                current_loss = loss.item()
                with open("models/loss.csv", "a") as f:
                    f.write(f"{epoch},{i},{current_loss}\n")
            if (i + 1) % 100 == 0 and gpu == 0:
                print(
                    "Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(
                        epoch + 1, args.epochs, i + 1, total_step, loss.item()
                    )
                )
        if gpu == 0:
            state_dict = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": current_loss,
            }
            torch.save(state_dict, fname)
            torch.save(state_dict, "models/model_latest.tar")
            torch.save(model, "models/model_latest_model_only")

    if gpu == 0:
        print("Training complete in: " + str(datetime.now() - start))


if __name__ == "__main__":
    main()

what am I doing wrong, how can I fix this?

Apparently, when saving models that’s been training on multiple devices, we have to use

torch.save(model.module.state_dict(), PATH)

and not torch.save(model.state_dict(), PATH).

Refer the end of this page for more.