import torch
from torch.utils.data import DataLoader

# encoder, decoder, their optimizers, the helper functions (indexTensor, lengthTensor,
# targetTensor, get_summary_stats_tensor) and the constants (PAD, SOS, EOS, CHARACTERS,
# DEVICE, MINI_BATCH_SZ, ITER, NUM_SAMPLE) are defined elsewhere in my script.

def sample_model(x: list):
    log_prob_sum = 0
    padded_len = len(max(x, key=len))
    # Right-pad every word to the length of the longest word in the batch
    padded_x = [[char for char in s] + [PAD] * (padded_len - len(s)) for s in x]
    src = indexTensor(padded_x, padded_len, CHARACTERS).to(DEVICE)
    lng = lengthTensor(x).to(DEVICE)
    hidden = encoder.forward(src, lng)
    lstm_input = targetTensor([SOS] * MINI_BATCH_SZ, 1, CHARACTERS).to(DEVICE)
    names = [''] * MINI_BATCH_SZ
    # padded_len + 1: as the length of the word increases, the Levenshtein distance goes down
    for i in range(padded_len + 1):
        lstm_probs, hidden = decoder.forward(lstm_input, hidden)
        # decoder outputs log-probabilities; exp() turns them into probabilities for Categorical
        categorical = torch.distributions.Categorical(probs=lstm_probs.squeeze().exp())
        sample = categorical.sample()
        log_prob_sum += categorical.log_prob(sample).sum()
        for j in range(MINI_BATCH_SZ):
            names[j] += CHARACTERS[sample[j].item()]
        lstm_input = sample.unsqueeze(0)
    return names, log_prob_sum
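For what it's worth, the reason I think log_prob_sum carries a graph is that the Categorical is built from the decoder output. Here's the kind of toy check I have in mind (the dummy logits and the toy_* names are just for illustration, not my real model):

import torch

# Dummy stand-in for the decoder output: a tensor that requires grad.
toy_logits = torch.randn(4, 10, requires_grad=True)
toy_dist = torch.distributions.Categorical(probs=toy_logits.softmax(dim=-1))
toy_sample = toy_dist.sample()
toy_log_prob_sum = toy_dist.log_prob(toy_sample).sum()
print(toy_log_prob_sum.grad_fn)  # not None, so a graph exists behind the summed log-probs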
def iterate_train(dl: DataLoader, path: str = "Checkpoints/"):
    all_losses = []
    num_model_iterations = 0
    scores_list = []
    for epoch_index in range(1, ITER + 1):
        for batch_index, x in enumerate(dl):
            # Zero the gradients in both models
            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            # Generate noised outputs
            generated_names, log_prob_sum = sample_model(x)
            # Cut each generated name at the first EOS
            noised_list = [name.split(EOS)[0] for name in generated_names]
            # Get summary stats of the batch
            sample_stats_sum_tensor = get_summary_stats_tensor(noised_list, x)
            # Score the batch
            distance = torch.dist(sample_stats_sum_tensor, obs_stats_sum_tensor, p=2).detach()
            score = distance * log_prob_sum
            scores_list.append(score)
            if batch_index % NUM_SAMPLE == 0:
                # Multiply by -1 because we're doing gradient descent
                reinforce_loss = -1 * torch.mean(torch.FloatTensor(scores_list))
                reinforce_loss.backward()
                encoder_opt.step()
                decoder_opt.step()
                # Zero out the metrics
                scores_list = []
I'm getting an error at the backward() call. I'm assuming it's because I'm multiplying a detached value by a non-detached one? But distance needs to be detached, because the way it's calculated doesn't allow a gradient to flow through it, whereas log_prob_sum does have a gradient flowing through it.
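In case it's easier to look at without the encoder/decoder, here's a stripped-down version of just the loss computation with dummy stand-ins (lp_sum and dist_value are made-up values playing the roles of log_prob_sum and the detached distance):

import torch

lp_sum = torch.randn(5, requires_grad=True).sum()  # stand-in for log_prob_sum: carries a grad_fn
dist_value = torch.tensor(3.7)                     # stand-in for distance after .detach(): no graph
score = dist_value * lp_sum                        # same product as score = distance * log_prob_sum
scores_list = [score]
reinforce_loss = -1 * torch.mean(torch.FloatTensor(scores_list))
# reinforce_loss.backward()  # in my full code, this backward() is the call that errors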