Very bad performance after loading from pth file

I use this function to save the model:

save_checkpoint(
    {
        "epoch": epoch + 1,
        "state_dict": self.model.module.state_dict(),  # Model layers / weights
        "best_score": self.args[
            "best_pred"
        ],  # Save best score (scalar)
        "optimizer": self.optimizer.state_dict(),
    },
    is_best,  # Boolean. True if the predicted MAE is better than the last saved MAE
    self.args["save_path"],  # Where to save checkpoint
)


def save_checkpoint(state, is_best, save_path, filename="checkpoint.pth"):
    torch.save(state, "./" + str(save_path) + "/" + filename)
    if is_best:
        shutil.copyfile(
            "./" + str(save_path) + "/" + filename,
            "./" + str(save_path) + "/" + "model_best.pth",
        )

And this is how I load it:

def load_checkpoint(self, weight_path):
    if self.gpu_id == 0:
        print("### Loading Checkpoint ###")

    location = f"cuda:{self.gpu_id}"
    checkpoint = torch.load(weight_path, map_location=location)
    self.model.load_state_dict(checkpoint["state_dict"], strict=False)

When using strict=False I get no errors, but the model performs as if it had never been trained. When I use strict=True I get a lot of missing keys, e.g. Missing key(s) in state_dict: …

When I started working on this project I was using DP. Lately I migrated to DDP and it was working great. My last training session using DDP was successful and I wanted to run a validation run with it, so naturally I tried loading the pre-trained weights and all hell broke loose.

If anyone wants to see the actual error, it's this:

Missing key(s) in state_dict: "module.conv1.weight", "module.in1.weight", "module.in1.bias", "module.conv2.weight", "module.in2.weight", "module.in2.bias", "module.layer1.0.conv1.weight", "module.layer1.0.in1.weight", "module.layer1.0.in1.bias", "module.layer1.0.conv2.weight", "module.layer1.0.in2.weight", "module.layer1.0.in2.bias", "module.layer1.0.conv3.weight", "module.layer1.0.in3.weight", "module.layer1.0.in3.bias", "module.layer1.0.downsample.0.weight", "module.layer1.0.downsample.1.weight", "module.layer1.0.downsample.1.bias", "module.layer1.1.conv1.weight", "module.layer1.1.in1.weight", "module.layer1.1.in1.bias", "module.layer1.1.conv2.weight", "module.layer1.1.in2.weight", "module.layer1.1.in2.bias", "module.layer1.1.conv3.weight", "module.layer1.1.in3.weight", "module.layer1.1.in3.bias", "module.layer1.2.conv1.weight", "module.layer1.2.in1.weight", "module.layer1.2.in1.bias", "module.layer1.2.conv2.weight", "module.layer1.2.in2.weight", "module.layer1.2.in2.bias", "module.layer1.2.conv3.weight", "module.layer1.2.in3.weight", "module.layer1.2.in3.bias", "module.layer1.3.conv1.weight", "module.layer1.3.in1.weight", "module.layer1.3.in1.bias", "module.layer1.3.conv2.weight", "module.layer1.3.in2.weight", "module.layer1.3.in2.bias", "module.layer1.3.conv3.weight", "module.layer1.3.in3.weight", "module.layer1.3.in3.bias", "module.transition1.0.0.weight", "module.transition1.0.1.weight", "module.transition1.0.1.bias", "module.transition1.1.0.0.weight", "module.transition1.1.0.1.weight", "module.transition1.1.0.1.bias", "module.stage2.0.branches.0.0.conv1.weight", "module.stage2.0.branches.0.0.in1.weight", "module.stage2.0.branches.0.0.in1.bias", "module.stage2.0.branches.0.0.conv2.weight", "module.stage2.0.branches.0.0.in2.weight", "module.stage2.0.branches.0.0.in2.bias", "module.stage2.0.branches.0.1.conv1.weight", "module.stage2.0.branches.0.1.in1.weight", "module.stage2.0.branches.0.1.in1.bias", "module.stage2.0.branches.0.1.conv2.weight", 
"module.stage2.0.branches.0.1.in2.weight", "module.stage2.0.branches.0.1.in2.bias", "module.stage2.0.branches.0.2.conv1.weight", "module.stage2.0.branches.0.2.in1.weight", "module.stage2.0.branches.0.2.in1.bias", "module.stage2.0.branches.0.2.conv2.weight", "module.stage2.0.branches.0.2.in2.weight", "module.stage2.0.branches.0.2.in2.bias", "module.stage2.0.branches.0.3.conv1.weight", "module.stage2.0.branches.0.3.in1.weight", "module.stage2.0.branches.0.3.in1.bias", "module.stage2.0.branches.0.3.conv2.weight", "module.stage2.0.branches.0.3.in2.weight", "module.stage2.0.branches.0.3.in2.bias", "module.stage2.0.branches.1.0.conv1.weight", "module.stage2.0.branches.1.0.in1.weight", "module.stage2.0.branches.1.0.in1.bias", "module.stage2.0.branches.1.0.conv2.weight", "module.stage2.0.branches.1.0.in2.weight", "module.stage2.0.branches.1.0.in2.bias", "module.stage2.0.branches.1.1.conv1.weight", "module.stage2.0.branches.1.1.in1.weight", "module.stage2.0.branches.1.1.in1.bias", "module.stage2.0.branches.1.1.conv2.weight", "module.stage2.0.branches.1.1.in2.weight", "module.stage2.0.branches.1.1.in2.bias", "module.stage2.0.branches.1.2.conv1.weight", "module.stage2.0.branches.1.2.in1.weight", "module.stage2.0.branches.1.2.in1.bias", "module.stage2.0.branches.1.2.conv2.weight", "module.stage2.0.branches.1.2.in2.weight", "module.stage2.0.branches.1.2.in2.bias", "module.stage2.0.branches.1.3.conv1.weight", "module.stage2.0.branches.1.3.in1.weight", "module.stage2.0.branches.1.3.in1.bias", "module.stage2.0.branches.1.3.conv2.weight", "module.stage2.0.branches.1.3.in2.weight", "module.stage2.0.branches.1.3.in2.bias", "module.stage2.0.fuse_layers.0.1.0.weight", "module.stage2.0.fuse_layers.0.1.1.weight", "module.stage2.0.fuse_layers.0.1.1.bias", "module.stage2.0.fuse_layers.1.0.0.0.weight", "module.stage2.0.fuse_layers.1.0.0.1.weight", "module.stage2.0.fuse_layers.1.0.0.1.bias", "module.transition2.2.0.0.weight", "module.transition2.2.0.1.weight", 
"module.transition2.2.0.1.bias", "module.stage3.0.branches.0.0.conv1.weight", "module.stage3.0.branches.0.0.in1.weight", "module.stage3.0.branches.0.0.in1.bias", "module.stage3.0.branches.0.0.conv2.weight", "module.stage3.0.branches.0.0.in2.weight", "module.stage3.0.branches.0.0.in2.bias", "module.stage3.0.branches.0.1.conv1.weight", "module.stage3.0.branches.0.1.in1.weight", "module.stage3.0.branches.0.1.in1.bias", "module.stage3.0.branches.0.1.conv2.weight", "module.stage3.0.branches.0.1.in2.weight", "module.stage3.0.branches.0.1.in2.bias", "module.stage3.0.branches.0.2.conv1.weight", "module.stage3.0.branches.0.2.in1.weight", "module.stage3.0.branches.0.2.in1.bias", "module.stage3.0.branches.0.2.conv2.weight", "module.stage3.0.branches.0.2.in2.weight", "module.stage3.0.branches.0.2.in2.bias", "module.stage3.0.branches.0.3.conv1.weight", "module.stage3.0.branches.0.3.in1.weight", "module.stage3.0.branches.0.3.in1.bias", "module.stage3.0.branches.0.3.conv2.weight", "module.stage3.0.branches.0.3.in2.weight", "module.stage3.0.branches.0.3.in2.bias", "module.stage3.0.branches.1.0.conv1.weight", "module.stage3.0.branches.1.0.in1.weight", "module.stage3.0.branches.1.0.in1.bias", "module.stage3.0.branches.1.0.conv2.weight", "module.stage3.0.branches.1.0.in2.weight", "module.stage3.0.branches.1.0.in2.bias", "module.stage3.0.branches.1.1.conv1.weight", "module.stage3.0.branches.1.1.in1.weight", "module.stage3.0.branches.1.1.in1.bias", "module.stage3.0.branches.1.1.conv2.weight", "module.stage3.0.branches.1.1.in2.weight", "module.stage3.0.branches.1.1.in2.bias", "module.stage3.0.branches.1.2.conv1.weight", "module.stage3.0.branches.1.2.in1.weight", "module.stage3.0.branches.1.2.in1.bias", "module.stage3.0.branches.1.2.conv2.weight", "module.stage3.0.branches.1.2.in2.weight", "module.stage3.0.branches.1.2.in2.bias", "module.stage3.0.branches.1.3.conv1.weight", "module.stage3.0.branches.1.3.in1.weight", "module.stage3.0.branches.1.3.in1.bias", 
"module.stage3.0.branches.1.3.conv2.weight", "module.stage3.0.branches.1.3.in2.weight", "module.stage3.0.branches.1.3.in2.bias", "module.stage3.0.branches.2.0.conv1.weight", "module.stage3.0.branches.2.0.in1.weight", "module.stage3.0.branches.2.0.in1.bias", "module.stage3.0.branches.2.0.conv2.weight", "module.stage3.0.branches.2.0.in2.weight", "module.stage3.0.branches.2.0.in2.bias", "module.stage3.0.branches.2.1.conv1.weight", "module.stage3.0.branches.2.1.in1.weight", "module.stage3.0.branches.2.1.in1.bias", "module.stage3.0.branches.2.1.conv2.weight", "module.stage3.0.branches.2.1.in2.weight", "module.stage3.0.branches.2.1.in2.bias", "module.stage3.0.branches.2.2.conv1.weight", "module.stage3.0.branches.2.2.in1.weight", "module.stage3.0.branches.2.2.in1.bias", "module.stage3.0.branches.2.2.conv2.weight", "module.stage3.0.branches.2.2.in2.weight", "module.stage3.0.branches.2.2.in2.bias", "module.stage3.0.branches.2.3.conv1.weight", "module.stage3.0.branches.2.3.in1.weight", "module.stage3.0.branches.2.3.in1.bias", "module.stage3.0.branches.2.3.conv2.weight", "module.stage3.0.branches.2.3.in2.weight", "module.stage3.0.branches.2.3.in2.bias", "module.stage3.0.fuse_layers.0.1.0.weight", "module.stage3.0.fuse_layers.0.1.1.weight", "module.stage3.0.fuse_layers.0.1.1.bias", "module.stage3.0.fuse_layers.0.2.0.weight", "module.stage3.0.fuse_layers.0.2.1.weight", "module.stage3.0.fuse_layers.0.2.1.bias", "module.stage3.0.fuse_layers.1.0.0.0.weight", "module.stage3.0.fuse_layers.1.0.0.1.weight", "module.stage3.0.fuse_layers.1.0.0.1.bias", "module.stage3.0.fuse_layers.1.2.0.weight", "module.stage3.0.fuse_layers.1.2.1.weight", "module.stage3.0.fuse_layers.1.2.1.bias", "module.stage3.0.fuse_layers.2.0.0.0.weight", "module.stage3.0.fuse_layers.2.0.0.1.weight", "module.stage3.0.fuse_layers.2.0.0.1.bias", "module.stage3.0.fuse_layers.2.0.1.0.weight", "module.stage3.0.fuse_layers.2.0.1.1.weight", "module.stage3.0.fuse_layers.2.0.1.1.bias", 
"module.stage3.0.fuse_layers.2.1.0.0.weight", "module.stage3.0.fuse_layers.2.1.0.1.weight", "module.stage3.0.fuse_layers.2.1.0.1.bias", "module.stage3.1.branches.0.0.conv1.weight", "module.stage3.1.branches.0.0.in1.weight", "module.stage3.1.branches.0.0.in1.bias", "module.stage3.1.branches.0.0.conv2.weight", "module.stage3.1.branches.0.0.in2.weight", "module.stage3.1.branches.0.0.in2.bias", "module.stage3.1.branches.0.1.conv1.weight", "module.stage3.1.branches.0.1.in1.weight", "module.stage3.1.branches.0.1.in1.bias", "module.stage3.1.branches.0.1.conv2.weight", "module.stage3.1.branches.0.1.in2.weight", "module.stage3.1.branches.0.1.in2.bias", "module.stage3.1.branches.0.2.conv1.weight", "module.stage3.1.branches.0.2.in1.weight", "module.stage3.1.branches.0.2.in1.bias", "module.stage3.1.branches.0.2.conv2.weight", "module.stage3.1.branches.0.2.in2.weight", "module.stage3.1.branches.0.2.in2.bias", "module.stage3.1.branches.0.3.conv1.weight", "module.stage3.1.branches.0.3.in1.weight", "module.stage3.1.branches.0.3.in1.bias", "module.stage3.1.branches.0.3.conv2.weight", "module.stage3.1.branches.0.3.in2.weight", "module.stage3.1.branches.0.3.in2.bias", "module.stage3.1.branches.1.0.conv1.weight", "module.stage3.1.branches.1.0.in1.weight", "module.stage3.1.branches.1.0.in1.bias", "module.stage3.1.branches.1.0.conv2.weight", "module.stage3.1.branches.1.0.in2.weight", "module.stage3.1.branches.1.0.in2.bias", "module.stage3.1.branches.1.1.conv1.weight", "module.stage3.1.branches.1.1.in1.weight", "module.stage3.1.branches.1.1.in1.bias", "module.stage3.1.branches.1.1.conv2.weight", "module.stage3.1.branches.1.1.in2.weight", "module.stage3.1.branches.1.1.in2.bias", "module.stage3.1.branches.1.2.conv1.weight", "module.stage3.1.branches.1.2.in1.weight", "module.stage3.1.branches.1.2.in1.bias", "module.stage3.1.branches.1.2.conv2.weight", "module.stage3.1.branches.1.2.in2.weight", "module.stage3.1.branches.1.2.in2.bias", "module.stage3.1.branches.1.3.conv1.weight", 
"module.stage3.1.branches.1.3.in1.weight", "module.stage3.1.branches.1.3.in1.bias", "module.stage3.1.branches.1.3.conv2.weight", "module.stage3.1.branches.1.3.in2.weight", "module.stage3.1.branches.1.3.in2.bias", "module.stage3.1.branches.2.0.conv1.weight", "module.stage3.1.branches.2.0.in1.weight", "module.stage3.1.branches.2.0.in1.bias", "module.stage3.1.branches.2.0.conv2.weight", "module.stage3.1.branches.2.0.in2.weight", "module.stage3.1.branches.2.0.in2.bias", "module.stage3.1.branches.2.1.conv1.weight", "module.stage3.1.branches.2.1.in1.weight", "module.stage3.1.branches.2.1.in1.bias", "module.stage3.1.branches.2.1.conv2.weight", "module.stage3.1.branches.2.1.in2.weight", "module.stage3.1.branches.2.1.in2.bias", "module.stage3.1.branches.2.2.conv1.weight", "module.stage3.1.branches.2.2.in1.weight", "module.stage3.1.branches.2.2.in1.bias", "module.stage3.1.branches.2.2.conv2.weight", "module.stage3.1.branches.2.2.in2.weight", "module.stage3.1.branches.2.2.in2.bias", "module.stage3.1.branches.2.3.conv1.weight", "module.stage3.1.branches.2.3.in1.weight", "module.stage3.1.branches.2.3.in1.bias", "module.stage3.1.branches.2.3.conv2.weight", "module.stage3.1.branches.2.3.in2.weight", "module.stage3.1.branches.2.3.in2.bias", "module.stage3.1.fuse_layers.0.1.0.weight", "module.stage3.1.fuse_layers.0.1.1.weight", "module.stage3.1.fuse_layers.0.1.1.bias", "module.stage3.1.fuse_layers.0.2.0.weight", "module.stage3.1.fuse_layers.0.2.1.weight", "module.stage3.1.fuse_layers.0.2.1.bias", "module.stage3.1.fuse_layers.1.0.0.0.weight", "module.stage3.1.fuse_layers.1.0.0.1.weight", "module.stage3.1.fuse_layers.1.0.0.1.bias", "module.stage3.1.fuse_layers.1.2.0.weight", "module.stage3.1.fuse_layers.1.2.1.weight", "module.stage3.1.fuse_layers.1.2.1.bias", "module.stage3.1.fuse_layers.2.0.0.0.weight", "module.stage3.1.fuse_layers.2.0.0.1.weight", "module.stage3.1.fuse_layers.2.0.0.1.bias", "module.stage3.1.fuse_layers.2.0.1.0.weight", "module.stage3.1.fuse_layers.2.0.1.1.weight", 
"module.stage3.1.fuse_layers.2.0.1.1.bias", "module.stage3.1.fuse_layers.2.1.0.0.weight", "module.stage3.1.fuse_layers.2.1.0.1.weight", "module.stage3.1.fuse_layers.2.1.0.1.bias", "module.stage3.2.branches.0.0.conv1.weight", "module.stage3.2.branches.0.0.in1.weight", "module.stage3.2.branches.0.0.in1.bias", "module.stage3.2.branches.0.0.conv2.weight", "module.stage3.2.branches.0.0.in2.weight", "module.stage3.2.branches.0.0.in2.bias", "module.stage3.2.branches.0.1.conv1.weight", "module.stage3.2.branches.0.1.in1.weight", "module.stage3.2.branches.0.1.in1.bias", "module.stage3.2.branches.0.1.conv2.weight", "module.stage3.2.branches.0.1.in2.weight", "module.stage3.2.branches.0.1.in2.bias", "module.stage3.2.branches.0.2.conv1.weight", "module.stage3.2.branches.0.2.in1.weight", "module.stage3.2.branches.0.2.in1.bias", "module.stage3.2.branches.0.2.conv2.weight", "module.stage3.2.branches.0.2.in2.weight", "module.stage3.2.branches.0.2.in2.bias", "module.stage3.2.branches.0.3.conv1.weight", "module.stage3.2.branches.0.3.in1.weight", "module.stage3.2.branches.0.3.in1.bias", "module.stage3.2.branches.0.3.conv2.weight", "module.stage3.2.branches.0.3.in2.weight", "module.stage3.2.branches.0.3.in2.bias", "module.stage3.2.branches.1.0.conv1.weight", "module.stage3.2.branches.1.0.in1.weight", "module.stage3.2.branches.1.0.in1.bias", "module.stage3.2.branches.1.0.conv2.weight", "module.stage3.2.branches.1.0.in2.weight", "module.stage3.2.branches.1.0.in2.bias", "module.stage3.2.branches.1.1.conv1.weight", "module.stage3.2.branches.1.1.in1.weight", "module.stage3.2.branches.1.1.in1.bias", "module.stage3.2.branches.1.1.conv2.weight", "module.stage3.2.branches.1.1.in2.weight", "module.stage3.2.branches.1.1.in2.bias", "module.stage3.2.branches.1.2.conv1.weight", "module.stage3.2.branches.1.2.in1.weight", "module.stage3.2.branches.1.2.in1.bias", "module.stage3.2.branches.1.2.conv2.weight", "module.stage3.2.branches.1.2.in2.weight", "module.stage3.2.branches.1.2.in2.bias", 
"module.stage3.2.branches.1.3.conv1.weight", "module.stage3.2.branches.1.3.in1.weight", "module.stage3.2.branches.1.3.in1.bias", "module.stage3.2.branches.1.3.conv2.weight", "module.stage3.2.branches.1.3.in2.weight", "module.stage3.2.branches.1.3.in2.bias", "module.stage3.2.branches.2.0.conv1.weight", "module.stage3.2.branches.2.0.in1.weight", "module.stage3.2.branches.2.0.in1.bias", "module.stage3.2.branches.2.0.conv2.weight", "module.stage3.2.branches.2.0.in2.weight", "module.stage3.2.branches.2.0.in2.bias", "module.stage3.2.branches.2.1.conv1.weight", "module.stage3.2.branches.2.1.in1.weight", "module.stage3.2.branches.2.1.in1.bias", "module.stage3.2.branches.2.1.conv2.weight", "module.stage3.2.branches.2.1.in2.weight", "module.stage3.2.branches.2.1.in2.bias", "module.stage3.2.branches.2.2.conv1.weight", "module.stage3.2.branches.2.2.in1.weight", "module.stage3.2.branches.2.2.in1.bias", "module.stage3.2.branches.2.2.conv2.weight", "module.stage3.2.branches.2.2.in2.weight", "module.stage3.2.branches.2.2.in2.bias", "module.stage3.2.branches.2.3.conv1.weight", "module.stage3.2.branches.2.3.in1.weight", "module.stage3.2.branches.2.3.in1.bias", "module.stage3.2.branches.2.3.conv2.weight", "module.stage3.2.branches.2.3.in2.weight", "module.stage3.2.branches.2.3.in2.bias", "module.stage3.2.fuse_layers.0.1.0.weight", "module.stage3.2.fuse_layers.0.1.1.weight", "module.stage3.2.fuse_layers.0.1.1.bias", "module.stage3.2.fuse_layers.0.2.0.weight", "module.stage3.2.fuse_layers.0.2.1.weight", "module.stage3.2.fuse_layers.0.2.1.bias", "module.stage3.2.fuse_layers.1.0.0.0.weight", "module.stage3.2.fuse_layers.1.0.0.1.weight", "module.stage3.2.fuse_layers.1.0.0.1.bias", "module.stage3.2.fuse_layers.1.2.0.weight", "module.stage3.2.fuse_layers.1.2.1.weight", "module.stage3.2.fuse_layers.1.2.1.bias", "module.stage3.2.fuse_layers.2.0.0.0.weight", "module.stage3.2.fuse_layers.2.0.0.1.weight", "module.stage3.2.fuse_layers.2.0.0.1.bias", "module.stage3.2.fuse_layers.2.0.1.0.weight", 
"module.stage3.2.fuse_layers.2.0.1.1.weight", "module.stage3.2.fuse_layers.2.0.1.1.bias", "module.stage3.2.fuse_layers.2.1.0.0.weight", "module.stage3.2.fuse_layers.2.1.0.1.weight", "module.stage3.2.fuse_layers.2.1.0.1.bias", "module.stage3.3.branches.0.0.conv1.weight", "module.stage3.3.branches.0.0.in1.weight", "module.stage3.3.branches.0.0.in1.bias", "module.stage3.3.branches.0.0.conv2.weight", "module.stage3.3.branches.0.0.in2.weight", "module.stage3.3.branches.0.0.in2.bias", "module.stage3.3.branches.0.1.conv1.weight", "module.stage3.3.branches.0.1.in1.weight", "module.stage3.3.branches.0.1.in1.bias", "module.stage3.3.branches.0.1.conv2.weight", "module.stage3.3.branches.0.1.in2.weight", "module.stage3.3.branches.0.1.in2.bias", "module.stage3.3.branches.0.2.conv1.weight", "module.stage3.3.branches.0.2.in1.weight", "module.stage3.3.branches.0.2.in1.bias", "module.stage3.3.branches.0.2.conv2.weight", "module.stage3.3.branches.0.2.in2.weight", "module.stage3.3.branches.0.2.in2.bias", "module.stage3.3.branches.0.3.conv1.weight", "module.stage3.3.branches.0.3.in1.weight", "module.stage3.3.branches.0.3.in1.bias", "module.stage3.3.branches.0.3.conv2.weight", "module.stage3.3.branches.0.3.in2.weight", "module.stage3.3.branches.0.3.in2.bias", "module.stage3.3.branches.1.0.conv1.weight", "module.stage3.3.branches.1.0.in1.weight", "module.stage3.3.branches.1.0.in1.bias", "module.stage3.3.branches.1.0.conv2.weight", "module.stage3.3.branches.1.0.in2.weight", "module.stage3.3.branches.1.0.in2.bias", "module.stage3.3.branches.1.1.conv1.weight", "module.stage3.3.branches.1.1.in1.weight", "module.stage3.3.branches.1.1.in1.bias", "module.stage3.3.branches.1.1.conv2.weight", "module.stage3.3.branches.1.1.in2.weight", "module.stage3.3.branches.1.1.in2.bias", "module.stage3.3.branches.1.2.conv1.weight", "module.stage3.3.branches.1.2.in1.weight", "module.stage3.3.branches.1.2.in1.bias", "module.stage3.3.branches.1.2.conv2.weight", "module.stage3.3.branches.1.2.in2.weight", 
"module.stage3.3.branches.1.2.in2.bias", "module.stage3.3.branches.1.3.conv1.weight", "module.stage3.3.branches.1.3.in1.weight", "module.stage3.3.branches.1.3.in1.bias", "module.stage3.3.branches.1.3.conv2.weight", "module.stage3.3.branches.1.3.in2.weight", "module.stage3.3.branches.1.3.in2.bias", "module.stage3.3.branches.2.0.conv1.weight", "module.stage3.3.branches.2.0.in1.weight", "module.stage3.3.branches.2.0.in1.bias", "module.stage3.3.branches.2.0.conv2.weight", "module.stage3.3.branches.2.0.in2.weight", "module.stage3.3.branches.2.0.in2.bias", "module.stage3.3.branches.2.1.conv1.weight", "module.stage3.3.branches.2.1.in1.weight", "module.stage3.3.branches.2.1.in1.bias", "module.stage3.3.branches.2.1.conv2.weight", "module.stage3.3.branches.2.1.in2.weight", "module.stage3.3.branches.2.1.in2.bias", "module.stage3.3.branches.2.2.conv1.weight", "module.stage3.3.branches.2.2.in1.weight", "module.stage3.3.branches.2.2.in1.bias", "module.stage3.3.branches.2.2.conv2.weight", "module.stage3.3.branches.2.2.in2.weight", "module.stage3.3.branches.2.2.in2.bias", "module.stage3.3.branches.2.3.conv1.weight", "module.stage3.3.branches.2.3.in1.weight", "module.stage3.3.branches.2.3.in1.bias", "module.stage3.3.branches.2.3.conv2.weight", "module.stage3.3.branches.2.3.in2.weight", "module.stage3.3.branches.2.3.in2.bias", "module.stage3.3.fuse_layers.0.1.0.weight", "module.stage3.3.fuse_layers.0.1.1.weight", "module.stage3.3.fuse_layers.0.1.1.bias", "module.stage3.3.fuse_layers.0.2.0.weight", "module.stage3.3.fuse_layers.0.2.1.weight", "module.stage3.3.fuse_layers.0.2.1.bias", "module.stage3.3.fuse_layers.1.0.0.0.weight", "module.stage3.3.fuse_layers.1.0.0.1.weight", "module.stage3.3.fuse_layers.1.0.0.1.bias", "module.stage3.3.fuse_layers.1.2.0.weight", "module.stage3.3.fuse_layers.1.2.1.weight", "module.stage3.3.fuse_layers.1.2.1.bias", "module.stage3.3.fuse_layers.2.0.0.0.weight", "module.stage3.3.fuse_layers.2.0.0.1.weight", "module.stage3.3.fuse_layers.2.0.0.1.bias", 
"module.stage3.3.fuse_layers.2.0.1.0.weight", "module.stage3.3.fuse_layers.2.0.1.1.weight", "module.stage3.3.fuse_layers.2.0.1.1.bias", "module.stage3.3.fuse_layers.2.1.0.0.weight", "module.stage3.3.fuse_layers.2.1.0.1.weight", "module.stage3.3.fuse_layers.2.1.0.1.bias", "module.transition3.3.0.0.weight", "module.transition3.3.0.1.weight", "module.transition3.3.0.1.bias", 

and this

Unexpected key(s) in state_dict: "conv1.weight", "in1.weight", "in1.bias", "conv2.weight", "in2.weight", "in2.bias", "layer1.0.conv1.weight", "layer1.0.in1.weight", "layer1.0.in1.bias", 

This is not the full error; I cut some of it out due to the character limit.

I hope I can solve this without re-training the model. It took almost a week to finish.


Looking inside the state_dict shows that the layers are there, but their keys are missing the “module.” prefix for some reason…

Doing this fixed it:

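        state_dict = checkpoint["state_dict"]
        new_state_dict = {}
        for k, v in state_dict.items():
            # The DDP-wrapped model expects every key to start with "module."
            if not k.startswith("module."):
                k = f"module.{k}"
            new_state_dict[k] = v
        checkpoint["state_dict"] = new_state_dict
        self.model.load_state_dict(checkpoint["state_dict"])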

I basically added what it was looking for. I don’t know what happened initially but it works.
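
The same key renaming can be written as a small reusable helper. This is only a sketch: the names `add_prefix` and `strip_prefix` are made up for illustration, and plain integers stand in for tensors because only the keys matter here.

```python
from collections import OrderedDict

PREFIX = "module."  # prefix that DataParallel / DDP wrappers put in front of every key


def add_prefix(state_dict, prefix=PREFIX):
    """Return a copy whose keys all start with `prefix` (for loading into a wrapped model)."""
    return OrderedDict(
        (k if k.startswith(prefix) else prefix + k, v) for k, v in state_dict.items()
    )


def strip_prefix(state_dict, prefix=PREFIX):
    """Return a copy with a leading `prefix` removed (for loading into a plain model)."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )


# Plain values stand in for tensors; the helpers never look at them.
saved = OrderedDict([("conv1.weight", 1), ("in1.bias", 2)])
wrapped = add_prefix(saved)
print(list(wrapped))                # keys now carry the "module." prefix
print(list(strip_prefix(wrapped)))  # and this undoes it
```

Both helpers return a new dictionary instead of modifying the checkpoint in place, so the original checkpoint stays usable for either kind of model.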

The additional “module.” prefix in the keys is added by wrappers such as nn.DataParallel and DistributedDataParallel, and manipulating the state_dict is a valid approach to fix the mismatch. Use strict=False only if you are sure you can ignore missing or unexpected keys.
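
To see why strict=False made the model behave as if it were untrained, here is a minimal sketch. The `Wrapper` class is a hypothetical stand-in for DataParallel/DDP that only nests the model under `.module`, and the tiny `nn.Linear` is an assumption in place of the real network. `load_state_dict` returns an object with `missing_keys` and `unexpected_keys`, which is worth printing whenever strict=False is used.

```python
import torch
from torch import nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for DataParallel/DDP: it only nests the model under `.module`."""

    def __init__(self, module):
        super().__init__()
        self.module = module


torch.manual_seed(0)
trained = nn.Linear(4, 2)          # pretend these are the trained weights
checkpoint = trained.state_dict()  # keys: "weight", "bias" (no prefix)

model = Wrapper(nn.Linear(4, 2))   # expects "module.weight", "module.bias"
before = model.module.weight.detach().clone()

# With strict=False, keys that do not match are skipped without an error.
result = model.load_state_dict(checkpoint, strict=False)
print("missing:   ", result.missing_keys)
print("unexpected:", result.unexpected_keys)

# Nothing matched, so the wrapped model still has its random initial weights.
unchanged = torch.equal(before, model.module.weight)
print("weights unchanged:", unchanged)
```

If both lists are non-empty and look like the same names with and without a prefix, the checkpoint was most likely saved from a different wrapping level than the model it is being loaded into.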

I must have messed something up since the .module was missing from the state_dict keys! Generally when using DDP, is this a valid basic approach to saving / loading models?

Yes, storing the state_dict from the internal .module allows you to directly load it into a “standard” (i.e. non-DDP) model and is also described here. Otherwise, if you store the state_dict directly from the DDP wrapper you would either need to remove the .module keys before loading the state_dict into a standard model or you could load it into a DDP model.
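
An alternative to renaming keys is to load the checkpoint into the wrapped model's inner `.module`, which has the same key names as the model that was saved. The sketch below assumes a helper called `unwrap` (a made-up name) and uses `nn.DataParallel` around a tiny `nn.Linear` purely because it can be constructed without a process group; the same attribute access applies to a DDP-wrapped model.

```python
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel


def unwrap(model):
    """Return the underlying model if `model` is a DataParallel/DDP wrapper."""
    if isinstance(model, (nn.DataParallel, DistributedDataParallel)):
        return model.module
    return model


plain = nn.Linear(4, 2)
# Saved from the inner model, so the keys have no "module." prefix.
checkpoint = {"state_dict": plain.state_dict()}

wrapped = nn.DataParallel(nn.Linear(4, 2))

# Load into the inner model; strict=True is the default, so any real
# mismatch still raises an error instead of being skipped.
unwrap(wrapped).load_state_dict(checkpoint["state_dict"])
```

In the loading function from the question, this would correspond to calling `load_state_dict` on `self.model.module` rather than on `self.model`, assuming `self.model` is already wrapped at that point.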
