I am trying to replicate https://github.com/Pranshu258/Deep_Image_Captioning, which is an A2C (advantage actor-critic) model applied to image captioning.
The training loop code is below. With each epoch my GPU memory keeps filling up, and after several iterations training fails because the GPU runs out of memory.
I tried calling ‘del’ on the captions_in_v and features_in_v tensors at the end of the episode loop, but the GPU memory is still not freed. I do not understand why GPU memory is not released after each episode loop. Any input would be helpful.
def train_a2cNetwork(train_data=None, epoch_count=10, episodes=100, model_save_rate=0.10):
    """Train the actor-critic captioning network against the reward network.

    For every epoch a minibatch of ``episodes`` samples is drawn; each sample
    is one episode in which the policy's next-word distribution is sampled,
    the extended caption is scored by the reward network, and one optimizer
    step is taken on the combined actor + critic loss.

    Args:
        train_data: Dataset object handed to ``sample_coco_minibatch``.
        epoch_count: Number of epochs (minibatches) to train for.
        episodes: Number of single-sample episodes per epoch.
        model_save_rate: Fraction of ``epoch_count`` between checkpoints;
            e.g. ``0.10`` with 10 epochs saves every epoch.

    Returns:
        None. Checkpoints are written under ``LOG_DIR``.
    """
    # NOTE(review): the vocabulary is read from the module-level `data`, not
    # from `train_data` -- confirm both refer to the same dataset.
    rewardNet = RewardNetwork(data["word_to_idx"]).to(device)
    policyNet = PolicyNetwork(data["word_to_idx"]).to(device)
    valueNet = ValueNetwork(data["word_to_idx"]).to(device)

    rewardNet.load_state_dict(torch.load('models/rewardNetwork.pt'))
    policyNet.load_state_dict(torch.load('models/policyNetwork.pt'))
    valueNet.load_state_dict(torch.load('models/valueNetwork.pt'))

    # The reward network is only used as a scorer: its parameters are not
    # given to the optimizer below, so keep it in inference mode.
    rewardNet.eval()

    a2cNetwork = AdvantageActorCriticNetwork(valueNet, policyNet).to(device)
    a2cNetwork.train(True)

    optimizer = optim.Adam(a2cNetwork.parameters(), lr=0.0001)

    # Checkpoint interval in epochs; clamp to 1 so the modulo below can never
    # divide by zero for small epoch counts or rates.
    epoch_count_for_save = max(1, int(epoch_count * model_save_rate))

    for epoch in range(epoch_count):
        episodicAvgLoss = 0

        captions, features, _ = sample_coco_minibatch(
            train_data, batch_size=episodes, split='train'
        )

        for episode in range(episodes):
            log_probs = []
            values = []
            rewards = []

            # Slices (not scalar indexing) keep the leading batch axis of size 1.
            captions_in = captions[episode:episode + 1, :]
            features_in = features[episode:episode + 1]
            captions_in_v = torch.tensor(captions_in, device=device).long()
            features_in_v = torch.tensor(features_in, device=device).float()

            value, logits = a2cNetwork(features_in_v, captions_in_v)
            probs = F.softmax(logits, dim=2)
            # log_softmax avoids the -inf that log(softmax(x)) can produce.
            step_log_probs = F.log_softmax(logits, dim=2)

            # Sample the next word on the CPU from the detached distribution.
            dist = probs.detach().cpu().numpy()[0, 0]
            action = int(np.random.choice(probs.shape[-1], p=dist))

            gen_cap = torch.tensor([[action]], dtype=torch.long, device=device)
            captions_in_v = torch.cat((captions_in_v, gen_cap), dim=1)

            # Only the scalar reward is used, so skip building an autograd
            # graph through the reward network.
            with torch.no_grad():
                reward = GetRewards(features_in_v, captions_in_v, rewardNet)

            rewards.append(float(reward.detach().cpu().numpy()[0, 0]))
            values.append(value)
            log_probs.append(step_log_probs[0, 0, action])

            # Keep the critic outputs attached to the graph: wrapping them in
            # torch.FloatTensor(...) would copy the numbers and silently stop
            # any gradient from reaching the value network.
            # reshape(1) requires each value to hold exactly one element.
            values = torch.cat([v.reshape(1) for v in values])
            rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
            log_probs = torch.stack(log_probs)

            # Advantage = observed reward - baseline estimate. It is detached
            # in the actor term so the policy gradient does not leak into the
            # critic; the critic is trained by the squared-error term only.
            advantage = rewards - values
            actorLoss = -(log_probs * advantage.detach()).mean()
            criticLoss = 0.5 * advantage.pow(2).mean()
            loss = actorLoss + criticLoss

            # .item() yields a plain float, so no graph is kept alive by the
            # running average.
            episodicAvgLoss += loss.item() / episodes

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"[training] epoch:{epoch} episodicAvgLoss: {episodicAvgLoss}")

        if epoch % epoch_count_for_save == 0:
            model_name = "epoch-%d_%s" % (epoch, A2CNETWORK_WEIGHTS_FILE)
            print_green(f"[training] Saving intermediate model : {model_name}")
            torch.save(a2cNetwork.state_dict(), os.path.join(LOG_DIR, model_name))