I trained the model for 5 epochs on 3 GPUs using DDP. I saved the model on the first GPU at the end of training to the hard disk. Now, if I try to load the state_dict to the model, I get this error.
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for AudioCNN:
Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "conv3.weight", "conv3.bias", "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias".
Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias", "module.conv2.weight", "module.conv2.bias", "module.conv3.weight", "module.conv3.bias", "module.fc1.weight", "module.fc1.bias", "module.fc2.weight", "module.fc2.bias".
This is essentially my training script
import os
from datetime import datetime
import argparse
import torch.multiprocessing as mp
import torch
import torch.nn as nn
import torch.distributed as dist
from model import AudioCNN
from data import CustomAudioDataset
resume = os.path.isfile("models/model_latest.tar")
if resume:
checkpoint = torch.load("models/model_latest.tar")
if not resume:
with open("models/loss.csv", "w") as f:
f.write("epoch,batch,loss\n")
if resume:
print(f"Found previous training files, resuming from {checkpoint['epoch'] + 1} epoch.")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"-n",
"--nodes",
default=1,
type=int,
metavar="N",
help="number of data loading workers (default: 4)",
)
parser.add_argument(
"-g", "--gpus", default=1, type=int, help="number of gpus per node"
)
parser.add_argument(
"-nr", "--nr", default=0, type=int, help="ranking within the nodes"
)
parser.add_argument(
"--epochs",
default=2,
type=int,
metavar="N",
help="number of total epochs to run",
)
args = parser.parse_args()
args.world_size = args.gpus * args.nodes
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8888"
mp.spawn(train, nprocs=args.gpus, args=(args,))
def train(gpu, args):
rank = args.nr * args.gpus + gpu
dist.init_process_group(
backend="nccl", init_method="env://", world_size=args.world_size, rank=rank
)
torch.manual_seed(0)
model = AudioCNN()
optimizer = torch.optim.SGD(model.parameters(), 1e-3)
criterion = nn.CrossEntropyLoss().cuda(gpu)
epochs_completed = 0
current_loss = 0
if resume:
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epochs_completed = checkpoint["epoch"]
current_loss = checkpoint["loss"]
torch.cuda.set_device(gpu)
model.cuda(gpu)
batch_size = 8
# define loss function (criterion) and optimizer
# Wrap the model
model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
# Data loading code
train_dataset = CustomAudioDataset()
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset, num_replicas=args.world_size, rank=rank
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
sampler=train_sampler,
)
start = datetime.now()
total_step = len(train_loader)
for epoch in range(epochs_completed, args.epochs):
fname = (
f"models/model_0{epoch}.tar" if epoch < 10 else f"models/model_{epoch}.tar"
)
for i, (images, labels) in enumerate(train_loader):
images = images.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if gpu == 0:
current_loss = loss.item()
with open("models/loss.csv", "a") as f:
f.write(f"{epoch},{i},{current_loss}\n")
if (i + 1) % 100 == 0 and gpu == 0:
print(
"Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(
epoch + 1, args.epochs, i + 1, total_step, loss.item()
)
)
if gpu == 0:
state_dict = {
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": current_loss,
}
torch.save(state_dict, fname)
torch.save(state_dict, "models/model_latest.tar")
torch.save(model, "models/model_latest_model_only")
if gpu == 0:
print("Training complete in: " + str(datetime.now() - start))
if __name__ == "__main__":
main()
what am I doing wrong, how can I fix this?