Hi! I would like to do loss backpropagation through a batch, but this loss value (and hence computation graph) differs per sample, depending on some conditions. I have a way of calculating all my losses, but I’m not sure how to design the backward portion. Let me share a snippet of the code here:
```python
import torch

for _ in range(10000):
    # Mask away first column of all instances to start from there
    mask = torch.zeros((bsz, 20, 20))
    mask_idx = 0
    mask[:, :, mask_idx] = 1.0
    mask = mask.bool()
    # Apply sinkhorn then mask
    # torch.nn.parameter.Parameter is in param_list which I want to optimize via gradient descent
    log_probs = batch_sinkhorn(param_list[0], 0.5, 3, data[1])
    # Find the tour
    # Construct choice mask and combine it with the adjacency matrix mask
    decode_mask = torch.zeros(bsz, nb_nodes)
    decode_mask[:, 0] = 1.0
    decode_mask = decode_mask.bool()
    idx = torch.zeros(bsz).long()
    tour = [torch.zeros(bsz).long()]
    prob_list = []
    valid_tour_flag = torch.ones(bsz).bool()
    for _ in range(nb_nodes - 1):
        log_prob_choice = log_probs.gather(dim=1, index=idx.unsqueeze(1).unsqueeze(2).repeat(1, 1, nb_nodes)).view(bsz, nb_nodes)
        route_mask = data[1].gather(dim=1, index=idx.unsqueeze(1).unsqueeze(2).repeat(1, 1, nb_nodes)).view(bsz, nb_nodes)
        route_mask = (1.0 - route_mask).bool()
        combined_mask = decode_mask + route_mask  # |combined_mask| = (bsz, nb_nodes)
        log_prob_choice = log_prob_choice.masked_fill(combined_mask, -1e29)
        idx = log_prob_choice.argmax(dim=-1)
        # Check if a tour is valid by seeing if we chose a masked action
        valid_tour_flag = ~combined_mask.gather(dim=1, index=idx.unsqueeze(1)).squeeze(1) * valid_tour_flag
        valid_tour_flag = valid_tour_flag.bool()
        decode_mask = decode_mask.float().scatter(dim=1, index=idx.unsqueeze(1), src=torch.ones((bsz, 1))).bool()
        tour.append(idx)
        prob_list.append(log_prob_choice.gather(dim=1, index=idx.unsqueeze(1)).squeeze(1))
    tour = torch.stack(tour, dim=1)
    sumLogProb = torch.stack(prob_list, dim=1).sum(-1)
    if torch.any(valid_tour_flag):
        print("There are some valid tours! They are for instances: {}".format(torch.where(valid_tour_flag == True)))
    # Compute tour lengths
    length = compute_tour_length(data[0], tour)
    log_probs = log_probs.masked_fill(mask, -1e20)
    # Calculate the dag loss
    d_loss = dag_loss(log_probs)
    optimal_tour_len = data[3]
    # This multiplies the log probability of the tour with the tour length and the valid tour flag
    loss = torch.zeros(bsz, requires_grad=True)
    # TODO: find out how to take the loss depending on condition
    loss[torch.where(length > 0)] += length[torch.where(length > 0)]
    loss[torch.where(length <= 0)] += d_loss[torch.where(length <= 0)]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
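A minimal sketch of one way to fill in the TODO, assuming `sumLogProb`, `length`, `d_loss` and `valid_tour_flag` all have shape `(bsz,)` as in the loop above. The current attempt runs into two problems: indexing in place into `torch.zeros(bsz, requires_grad=True)` fails because it is a leaf tensor that requires grad, and `loss.backward()` on a non-scalar tensor needs an explicit `gradient` argument. Building the per-sample loss out of place with `torch.where` and then taking the mean avoids both. The tensors below (`theta` and the toy values) are hypothetical stand-ins, not the real data.

```python
import torch

torch.manual_seed(0)
bsz = 4

# Hypothetical stand-ins for the tensors produced in the training loop.
# Each sample depends only on its own entry of `theta`, which keeps the check simple.
theta = torch.randn(bsz, requires_grad=True)              # plays the role of the parameters
sumLogProb = torch.nn.functional.logsigmoid(theta)        # differentiable, shape (bsz,)
d_loss = theta ** 2                                       # differentiable, shape (bsz,)
length = torch.tensor([3.0, 5.0, 2.0, 4.0])               # no grad, like a length computed from argmax indices
valid_tour_flag = torch.tensor([True, False, True, False])

# Select the loss per sample without any in-place writes:
# sumLogProb * length where the tour is valid, d_loss everywhere else.
per_sample_loss = torch.where(valid_tour_flag, sumLogProb * length, d_loss)

# Pool with a mean so that backward() is called on a scalar.
loss = per_sample_loss.mean()
loss.backward()

# Each sample only receives gradient from the branch torch.where selected for it.
print(theta.grad)
```

In the loop above, the `per_sample_loss` line would stand in for `loss = torch.zeros(...)` and the two indexed `+=` lines, while `optimizer.zero_grad()`, `loss.backward()` and `optimizer.step()` stay as they are. (Dropping `requires_grad=True` from `torch.zeros` would also make the indexed assignment legal, but `torch.where` avoids that bookkeeping.) Whether the condition should be `valid_tour_flag` or `length > 0` depends on what `compute_tour_length` returns for invalid tours, which isn't shown here.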
In general, the loop works as follows:
- I have a bunch of parameters in `param_list[0]` that I wish to optimize. Think of these as rows of data points with some initial value, with shape `(bsz, 20)`.
- I apply an algorithm called `batch_sinkhorn` to get some log probabilities.
- Based on the log probabilities, I construct a sequence of indexes to visit, each time masking out the indexes already visited.
- The mask is based on my visits and on whether the city is allowed to be visited at all (captured in `route_mask`).
- After that, I compute the length of each tour.
- I calculate my first loss value using `dag_loss`, which returns a `d_loss` for each sample.
- If a sample successfully visited all indexes, I wish to take the loss as `sumLogProb * length`.
- If a sample is unsuccessful, I wish to take the loss as `d_loss`.
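One detail worth keeping in mind for the `sumLogProb * length` case: the tour is built with `argmax`, so `length` (computed from the tour indices, assuming `data[0]` itself does not require grad) carries no gradient. The gradient of that term reaches the parameters only through `sumLogProb`, with `length` acting as a per-sample weight, which is the same structure as a REINFORCE-style (score-function) surrogate loss. The sketch below checks this on a toy single-step greedy decode; `scores` and `node_cost` are hypothetical stand-ins for `param_list[0]` and the distance data.

```python
import torch

torch.manual_seed(0)
bsz, nb_nodes = 2, 4

# Hypothetical stand-in for the learnable parameters (param_list[0] in the post).
scores = torch.randn(bsz, nb_nodes, requires_grad=True)
log_probs = torch.log_softmax(scores, dim=-1)

# Greedy choice: argmax returns integer indices, which carry no gradient.
idx = log_probs.argmax(dim=-1)
chosen_log_prob = log_probs.gather(dim=1, index=idx.unsqueeze(1)).squeeze(1)

# Hypothetical per-node cost; anything computed from idx alone has no gradient,
# just like a tour length computed from the decoded indices.
node_cost = torch.tensor([1.0, 2.0, 3.0, 4.0])
length = node_cost[idx]

print(idx.requires_grad, length.requires_grad, chosen_log_prob.requires_grad)
# False False True

# The gradient reaches `scores` only through chosen_log_prob, weighted by length.
loss = (chosen_log_prob * length).mean()
loss.backward()
```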
Hopefully this short write-up is clear. Basically, I want to optimize one loss for the samples that meet some criterion (the legal ones) and a different loss for the samples that don't. Ultimately, all these per-sample losses will be pooled together with a mean.
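"Pooled together via a mean" can mean two slightly different things. A minimal sketch with made-up numbers: averaging the selected per-sample losses over the whole batch (with `torch.where`, or equivalently by concatenating the two boolean-indexed groups) gives one value, while averaging each group separately and then combining weights the valid and invalid groups equally regardless of how many samples fall into each. `tour_loss` and `valid` are hypothetical names for `sumLogProb * length` and the validity mask.

```python
import torch

# Hypothetical per-sample values for a batch of 5.
tour_loss = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])    # stands in for sumLogProb * length
d_loss = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0])
valid = torch.tensor([True, False, True, True, False])

# Option 1: pick the loss per sample, then average over the whole batch.
mean_where = torch.where(valid, tour_loss, d_loss).mean()

# Option 2: gather each group with boolean masks, concatenate, then average.
# Same values and same count as option 1, so the result is identical.
mean_cat = torch.cat([tour_loss[valid], d_loss[~valid]]).mean()

# Option 3: average each group separately, then combine (groups weighted equally).
mean_per_group = (tour_loss[valid].mean() + d_loss[~valid].mean()) / 2

print(mean_where.item(), mean_cat.item(), mean_per_group.item())
```

With the per-group variant, an empty group (for example no valid tours at all, which the `if torch.any(valid_tour_flag)` check in the loop suggests can happen) makes `.mean()` return `nan`, so it would need a guard; the whole-batch mean does not have that problem.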