I’m trying to optimize two models, a generator and a discriminator both of which use the embeddings generated by the generator model.
This is my code:
generator_source_output = generator_task_model(source_data,
token_type_ids = None,
attention_mask = source_masks,
labels = source_labels)
source_CE_loss = generator_source_output[0]
source_last_hidden_state = generator_source_output[2][-1][:, 0, :]
source_contrastive_loss = compute_contrastive_loss(temperature,
source_last_hidden_state.cpu().detach().numpy(), source_labels)
task_loss = (lambd * source_contrastive_loss) + (1-lambd)*(source_CE_loss)
#Invoking the generator (+ task) model on target data
generator_target_output = generator_task_model(target_data,
token_type_ids = None,
attention_mask = target_masks)
target_last_hidden_state = generator_target_output[1][-1][:, 0, :]
# Training the language discriminator
d_english = discriminator_model(source_last_hidden_state)
d_target = discriminator_model(target_last_hidden_state)
d_english_loss = criterion(d_english.to(device), torch.ones(d_english.size()).to(device))
d_target_loss = criterion(d_target.to(device), torch.zeros(d_target.size()).to(device))
discriminator_loss = d_english_loss + d_target_loss
discriminator_optimizer.zero_grad()
discriminator_loss.backward()
discriminator_optimizer.step()
total_d_loss = total_d_loss + discriminator_loss.item()
# Training the generator
g_english = discriminator_model(source_last_hidden_state)
g_target = discriminator_model(target_last_hidden_state)
g_english_loss = criterion(g_english.to(device), torch.zeros(g_english.size()).to(device))
g_target_loss = criterion(g_target.to(device), torch.ones(g_target.size()).to(device))
generator_loss = g_english_loss + g_target_loss + task_loss
generator_optimizer.zero_grad()
generator_loss.backward()
generator_optimizer.step()
I’m trying to optimize two models, a generator and a discriminator both of which use the embeddings generated by the generator model.
This is my code:
generator_source_output = generator_task_model(source_data,
token_type_ids = None,
attention_mask = source_masks,
labels = source_labels)
source_CE_loss = generator_source_output[0]
source_last_hidden_state = generator_source_output[2][-1][:, 0, :]
source_contrastive_loss = compute_contrastive_loss(temperature,
source_last_hidden_state.cpu().detach().numpy(), source_labels)
task_loss = (lambd * source_contrastive_loss) + (1-lambd)*(source_CE_loss)
#Invoking the generator (+ task) model on target data
generator_target_output = generator_task_model(target_data,
token_type_ids = None,
attention_mask = target_masks)
target_last_hidden_state = generator_target_output[1][-1][:, 0, :]
# Training the language discriminator
d_english = discriminator_model(source_last_hidden_state)
d_target = discriminator_model(target_last_hidden_state)
d_english_loss = criterion(d_english.to(device), torch.ones(d_english.size()).to(device))
d_target_loss = criterion(d_target.to(device), torch.zeros(d_target.size()).to(device))
discriminator_loss = d_english_loss + d_target_loss
discriminator_optimizer.zero_grad()
discriminator_loss.backward()
discriminator_optimizer.step()
total_d_loss = total_d_loss + discriminator_loss.item()
# Training the generator
g_english = discriminator_model(source_last_hidden_state)
g_target = discriminator_model(target_last_hidden_state)
g_english_loss = criterion(g_english.to(device), torch.zeros(g_english.size()).to(device))
g_target_loss = criterion(g_target.to(device), torch.ones(g_target.size()).to(device))
generator_loss = g_english_loss + g_target_loss + task_loss
generator_optimizer.zero_grad()
generator_loss.backward()
generator_optimizer.step()
On running this, I get the error: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
on this line generator_loss.backward()
.
I tried setting discriminator_loss.backward(retain_graph=True)
but it takes too long to train and I read that it not the most appropriate solution. So, can someone please help me get rid of this error?