Model Weights not Updating Correctly After Loading

Hello, I currently have an encoder-decoder architecture using ResNet-34 as the CNN encoder and an LSTM with soft attention as the decoder. When I train the model on Google Colab everything works well during training, and I save the models' state dicts accordingly. If I then load the weights in the same runtime I trained in and evaluate, performance looks fine. However, when I restart the runtime, reload the saved weights, and evaluate again, it's as if the weights were never trained or saved at all. At first I thought it was an issue with training on the GPU in Colab and then loading with map_location on the CPU, but the same problem appears even when training and loading solely on the CPU. Any insight would be appreciated, thank you!

Training Code:

import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

ENCODER_PATH = CDIR + '/ImageCaptioning/encoder.pth'
DECODER_PATH = CDIR + '/ImageCaptioning/decoder.pth'

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ALPHA_COEF = 1

encoder = Encoder().to(DEVICE)
decoder = Decoder(256, 256, num_tokens, 300, DEVICE).to(DEVICE)

if os.path.exists(ENCODER_PATH) and os.path.exists(DECODER_PATH):
    encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=DEVICE))
    decoder.load_state_dict(torch.load(DECODER_PATH, map_location=DEVICE))

encoder_optim = torch.optim.Adam(encoder.parameters(), lr=1e-4)
decoder_optim = torch.optim.Adam(decoder.parameters(), lr=1e-3)

criterion = nn.CrossEntropyLoss().to(DEVICE)

def train(epochs):
    total_loss = []
    # steps before evaluating/plotting
    iters = 5
    encoder.train()
    decoder.train()
    for e in range(epochs):
        for i, data in enumerate(trainloader):
            images, captions = data
            images = images.to(DEVICE)
            matrix, lengths = get_matrix_and_lengths(captions)
            encoded_captions = torch.tensor(matrix, dtype=torch.int64).to(DEVICE)
            caption_lengths = torch.tensor(lengths, dtype=torch.int64).to(DEVICE)
            
            for c in range(len(encoded_captions)):
                caption = encoded_captions[c]
                caption_length = caption_lengths[c]

                features = encoder(images)
                logits, alphas, sorted_caption, decode_lengths = decoder(features, caption, caption_length)

                next_tokens = sorted_caption[:, 1:]
                next_tokens = pack_padded_sequence(next_tokens, decode_lengths, batch_first=True)[0]
                logits = pack_padded_sequence(logits, decode_lengths, batch_first=True)[0]

                loss = criterion(logits, next_tokens)
                loss += ALPHA_COEF * ((1 - alphas.sum(dim=1)).pow(2)).mean()

                encoder_optim.zero_grad()
                decoder_optim.zero_grad()

                loss.backward()

                encoder_optim.step()
                decoder_optim.step()

                total_loss.append(loss.detach().cpu().numpy())

            if (i+1) % iters == 0:
                torch.save(encoder.state_dict(), ENCODER_PATH)
                torch.save(decoder.state_dict(), DECODER_PATH)

                print(f'[{e+1}, {i+1}] Loss: {np.mean(total_loss[-iters:]):.3f}')
                plot(total_loss)

Networks:

import torch
import torch.nn as nn
from torchvision.models import resnet34, ResNet34_Weights

# ResNet-34 CNN Encoder
class Encoder(nn.Module):
    def __init__(self, output_dim=14):
        super().__init__()
        resnet = resnet34(weights=ResNet34_Weights.DEFAULT)
        layers = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*layers)
        # adaptive pool layer so encoder can take images of different sizes
        self.resize = nn.AdaptiveAvgPool2d((output_dim, output_dim))

        self.fine_tune()

    def forward(self, x):
        x = self.resnet(x)
        x = self.resize(x)
        x = x.permute(0, 2, 3, 1)
        return x

    # freeze the first five children: the stem (conv1, bn1, relu, maxpool) and the first residual stage
    def fine_tune(self):
        for l in list(self.resnet.children())[:5]:
            for p in l.parameters():
                p.requires_grad = False

# Soft-Attention Network
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # [b_size, image_size, encoder_dim]
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        # [b_size, decoder_dim]
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, features, hidden):
        att_features = self.encoder_att(features)
        att_hidden = self.decoder_att(hidden)
        att_cat = self.relu(att_features + att_hidden.unsqueeze(1))
        alpha_logits = self.att(att_cat).squeeze(2)
        # [b_size, image_size]
        alpha = self.softmax(alpha_logits)
        features_weighted = (features * alpha.unsqueeze(2)).sum(dim=1)

        return features_weighted, alpha


class Decoder(nn.Module):
    def __init__(self, decoder_dim, attention_dim, num_tokens, embed_size, device, encoder_dim=512):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.attention_dim = attention_dim
        self.num_tokens = num_tokens
        self.device = device

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim).to(self.device)

        self.init_h0 = nn.Linear(encoder_dim, decoder_dim)
        self.init_c0 = nn.Linear(encoder_dim, decoder_dim)

        self.embedding = nn.Embedding(num_tokens, embed_size)
        self.lstm = nn.LSTMCell(embed_size + encoder_dim, decoder_dim)
        self.dropout = nn.Dropout(p=0.4)
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, num_tokens)
        
    def initialize(self, features):
        # [b_size, image_size, encoder_dim]
        features = features.mean(dim=1)
        h0 = self.init_h0(features)
        c0 = self.init_c0(features)
        
        return h0, c0

    def forward(self, features, captions, caption_lengths):
        batch_size = features.shape[0]

        # [b_size, image_size, encoder_dim]
        features = features.reshape(batch_size, -1, self.encoder_dim)

        # sort captions and features in descending order by caption length
        caption_lengths, sort_indices = caption_lengths.sort(descending=True)
        captions = captions[sort_indices]
        features = features[sort_indices]

        h, c = self.initialize(features)
        # [b_size, max_length, embed_size]
        embedding = self.embedding(captions)

        decode_lengths = (caption_lengths - 1).tolist()
        max_length = max(decode_lengths)
        logits = torch.zeros(batch_size, max_length, self.num_tokens).to(self.device)
        alphas = torch.zeros(batch_size, max_length, features.shape[1]).to(self.device)

        for t in range(max_length):
            batch_t = sum([l > t for l in decode_lengths])
            # [b_size, encoder_dim]
            features_weighted, alpha = self.attention(features[:batch_t], h[:batch_t])
            gate = self.sigmoid(self.f_beta(h[:batch_t]))
            features_weighted = features_weighted * gate

            # cat: [b_size, embed_size], [b_size, encoder_dim]
            lstm_input = torch.cat((embedding[:batch_t, t, :], features_weighted), dim=1)
            h, c = self.lstm(lstm_input, (h[:batch_t], c[:batch_t]))

            logit = self.fc(self.dropout(h))
            logits[:batch_t, t, :] = logit
            alphas[:batch_t, t, :] = alpha

        return logits, alphas, captions, decode_lengths

I would probably start by verifying that the pretrained weights are actually being loaded.
Checking whether the state dicts exist before loading them is a good approach:

if os.path.exists(ENCODER_PATH) and os.path.exists(DECODER_PATH):
    encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=DEVICE))
    decoder.load_state_dict(torch.load(DECODER_PATH, map_location=DEVICE))

but you never check for errors (e.g. if the file location cannot be accessed, the if branch is silently skipped), so add a debug print statement to the loading logic, or an else branch that raises an error.
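For example, something along these lines (reusing the paths and device from above) makes a silent skip impossible to miss:

if os.path.exists(ENCODER_PATH) and os.path.exists(DECODER_PATH):
    encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=DEVICE))
    decoder.load_state_dict(torch.load(DECODER_PATH, map_location=DEVICE))
    print(f'Loaded weights from {ENCODER_PATH} and {DECODER_PATH}')
else:
    # fail loudly instead of silently running with freshly initialized weights
    raise FileNotFoundError(f'Checkpoint(s) not found: {ENCODER_PATH}, {DECODER_PATH}')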

Also, once this is verified, check the model outputs using a static input (e.g. torch.ones) after calling model.eval(), and make sure they are approximately equal in both scenarios.
If they match, I would continue by checking the data loading and preprocessing, since the difference might be coming from that part of the code.
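A minimal sketch of that check (the 1×3×224×224 input shape is just an assumption; use whatever shape your trainloader actually produces):

encoder.eval()
decoder.eval()
with torch.no_grad():
    dummy = torch.ones(1, 3, 224, 224, device=DEVICE)  # fixed, deterministic input
    features = encoder(dummy)
    # print a cheap fingerprint of the output; it should be (approximately) the same
    # before and after restarting the runtime and reloading the saved weights
    print(features.sum().item())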

Thank you so much for the response! When you mentioned that the difference could be coming from the preprocessing, I realized that I had completely forgotten to sort the vocabulary after processing the dataset, so it was producing a different token ordering on every run.
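For anyone else running into this: if the token-to-index mapping is rebuilt from an unordered set on every run, the saved embedding and output-layer weights no longer line up with the token indices. A minimal sketch of a deterministic vocabulary build (all_tokens and word2idx are hypothetical names):

vocab = sorted(set(all_tokens))  # fixed ordering across runs
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}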