Hello, I currently have an Encoder-Decoder architecture using ResNet-34 as the CNN Encoder and an LSTM with Soft Attention as the Decoder. When I train the model on Google Colab everything works well, and I save the models’ state dicts as training progresses. If I then load those weights in the same runtime I trained in and evaluate, performance looks fine. However, when I restart the runtime, reload the saved weights, and evaluate again, it's as if the weights were never trained or saved at all. At first I thought it was an issue with training on the GPU in Colab and then loading with map_location on the CPU, but the same problem shows up even when training and loading solely on the CPU. Any insight into the issue would be appreciated, thank you!
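For context, the round trip I'm describing is roughly the sketch below. The paths and models are the same as in the training code further down; fingerprint is just an illustrative helper (not part of my actual code) to compare parameters before saving and after reloading in a fresh runtime:

import torch

def fingerprint(model):
    # rough check: sum of absolute parameter values, to see whether the loaded weights differ from a fresh model
    return sum(p.detach().abs().sum().item() for p in model.parameters())

# before restarting the runtime: save and note the fingerprints
torch.save(encoder.state_dict(), ENCODER_PATH)
torch.save(decoder.state_dict(), DECODER_PATH)
print(fingerprint(encoder), fingerprint(decoder))

# after restarting the runtime: rebuild the models and reload the saved weights
encoder = Encoder().to(DEVICE)
decoder = Decoder(256, 256, num_tokens, 300, DEVICE).to(DEVICE)
encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=DEVICE))
decoder.load_state_dict(torch.load(DECODER_PATH, map_location=DEVICE))
# these values should match the ones printed before the restart
print(fingerprint(encoder), fingerprint(decoder))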
Training Code:
import os

import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

ENCODER_PATH = CDIR + '/ImageCaptioning/encoder.pth'
DECODER_PATH = CDIR + '/ImageCaptioning/decoder.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ALPHA_COEF = 1

encoder = Encoder().to(DEVICE)
decoder = Decoder(256, 256, num_tokens, 300, DEVICE).to(DEVICE)

# resume from existing checkpoints if they are present
if os.path.exists(ENCODER_PATH) and os.path.exists(DECODER_PATH):
    encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=DEVICE))
    decoder.load_state_dict(torch.load(DECODER_PATH, map_location=DEVICE))

encoder_optim = torch.optim.Adam(encoder.parameters(), lr=1e-4)
decoder_optim = torch.optim.Adam(decoder.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss().to(DEVICE)
def train(epochs):
    total_loss = []
    # steps between checkpointing/plotting
    iters = 5
    encoder.train()
    decoder.train()
    for e in range(epochs):
        for i, data in enumerate(trainloader):
            images, captions = data
            images = images.to(DEVICE)
            matrix, lengths = get_matrix_and_lengths(captions)
            encoded_captions = torch.tensor(matrix, dtype=torch.int64).to(DEVICE)
            caption_lengths = torch.tensor(lengths, dtype=torch.int64).to(DEVICE)
            # loop over the reference captions for this batch of images
            for c in range(len(encoded_captions)):
                caption = encoded_captions[c]
                caption_length = caption_lengths[c]
                features = encoder(images)
                logits, alphas, sorted_caption, decode_lengths = decoder(features, caption, caption_length)
                # targets are the tokens following each input position (drop the first token)
                next_tokens = sorted_caption[:, 1:]
                next_tokens = pack_padded_sequence(next_tokens, decode_lengths, batch_first=True)[0]
                logits = pack_padded_sequence(logits, decode_lengths, batch_first=True)[0]
                loss = criterion(logits, next_tokens)
                # encourage attention weights to sum to ~1 over time for each location
                loss += ALPHA_COEF * ((1 - alphas.sum(dim=1)).pow(2)).mean()
                encoder_optim.zero_grad()
                decoder_optim.zero_grad()
                loss.backward()
                encoder_optim.step()
                decoder_optim.step()
                total_loss.append(loss.detach().cpu().numpy())
            if (i + 1) % iters == 0:
                torch.save(encoder.state_dict(), ENCODER_PATH)
                torch.save(decoder.state_dict(), DECODER_PATH)
                print(f'[{e+1}, {i+1}] Loss: {np.mean(total_loss[-iters:]):.3f}')
    plot(total_loss)
Networks:
import torch
import torch.nn as nn
from torchvision.models import resnet34, ResNet34_Weights

# ResNet-34 CNN Encoder
class Encoder(nn.Module):
    def __init__(self, output_dim=14):
        super().__init__()
        resnet = resnet34(weights=ResNet34_Weights.DEFAULT)
        # drop the average-pool and fully-connected head
        layers = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*layers)
        # adaptive pool layer so encoder can take images of different sizes
        self.resize = nn.AdaptiveAvgPool2d((output_dim, output_dim))
        self.fine_tune()

    def forward(self, x):
        x = self.resnet(x)
        x = self.resize(x)
        # [b_size, output_dim, output_dim, encoder_dim]
        x = x.permute(0, 2, 3, 1)
        return x

    # disable learning for the stem and the first three residual blocks
    def fine_tune(self):
        for l in list(self.resnet.children())[:5]:
            for p in l.parameters():
                p.requires_grad = False
# Soft-Attention Network
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # [b_size, image_size, encoder_dim]
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        # [b_size, decoder_dim]
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, features, hidden):
        att_features = self.encoder_att(features)
        att_hidden = self.decoder_att(hidden)
        att_cat = self.relu(att_features + att_hidden.unsqueeze(1))
        alpha_logits = self.att(att_cat).squeeze(2)
        # [b_size, image_size]
        alpha = self.softmax(alpha_logits)
        # attention-weighted sum of encoder features: [b_size, encoder_dim]
        features_weighted = (features * alpha.unsqueeze(2)).sum(dim=1)
        return features_weighted, alpha
class Decoder(nn.Module):
    def __init__(self, decoder_dim, attention_dim, num_tokens, embed_size, device, encoder_dim=512):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.attention_dim = attention_dim
        self.num_tokens = num_tokens
        self.device = device
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim).to(self.device)
        self.init_h0 = nn.Linear(encoder_dim, decoder_dim)
        self.init_c0 = nn.Linear(encoder_dim, decoder_dim)
        self.embedding = nn.Embedding(num_tokens, embed_size)
        self.lstm = nn.LSTMCell(embed_size + encoder_dim, decoder_dim)
        self.dropout = nn.Dropout(p=0.4)
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, num_tokens)

    # initialize the LSTM state from the mean of the encoder features
    def initialize(self, features):
        # [b_size, image_size, encoder_dim]
        features = features.mean(dim=1)
        h0 = self.init_h0(features)
        c0 = self.init_c0(features)
        return h0, c0

    def forward(self, features, captions, caption_lengths):
        batch_size = features.shape[0]
        # [b_size, image_size, encoder_dim]
        features = features.reshape(batch_size, -1, self.encoder_dim)
        # sort captions and features in descending order by caption length
        caption_lengths, sort_indices = caption_lengths.sort(descending=True)
        captions = captions[sort_indices]
        features = features[sort_indices]
        h, c = self.initialize(features)
        # [b_size, max_length, embed_size]
        embedding = self.embedding(captions)
        decode_lengths = (caption_lengths - 1).tolist()
        max_length = max(decode_lengths)
        logits = torch.zeros(batch_size, max_length, self.num_tokens).to(self.device)
        alphas = torch.zeros(batch_size, max_length, features.shape[1]).to(self.device)
        for t in range(max_length):
            # number of captions still being decoded at time step t
            batch_t = sum([l > t for l in decode_lengths])
            # [b_size, encoder_dim]
            features_weighted, alpha = self.attention(features[:batch_t], h[:batch_t])
            # sigmoid gate scaling the attended features
            gate = self.sigmoid(self.f_beta(h[:batch_t]))
            features_weighted = features_weighted * gate
            # cat: [b_size, embed_size], [b_size, encoder_dim]
            lstm_input = torch.cat((embedding[:batch_t, t, :], features_weighted), dim=1)
            h, c = self.lstm(lstm_input, (h[:batch_t], c[:batch_t]))
            logit = self.fc(self.dropout(h))
            logits[:batch_t, t, :] = logit
            alphas[:batch_t, t, :] = alpha
        return logits, alphas, captions, decode_lengths
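For completeness, the reload I do before evaluating in a fresh runtime is essentially the sketch below (simplified; the constructor arguments mirror the ones used in training, and num_tokens is rebuilt from the same vocabulary):

import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder().to(DEVICE)
decoder = Decoder(256, 256, num_tokens, 300, DEVICE).to(DEVICE)

encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=DEVICE))
decoder.load_state_dict(torch.load(DECODER_PATH, map_location=DEVICE))

# inference mode: disables dropout in the decoder
encoder.eval()
decoder.eval()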