Taking a different loss per sample based on conditions in a batch

Hi! I would like to backpropagate a loss through a batch, but the loss value (and hence the computation graph) differs per sample depending on some conditions. I have a way of calculating all my losses, but I'm not sure how to design the backward portion. Here is a snippet of the code:

```python
for _ in range(10000):
    # Mask away the first column of all instances to start from there
    mask = torch.zeros((bsz, nb_nodes, nb_nodes), dtype=torch.bool)
    mask_idx = 0
    mask[:, :, mask_idx] = True
    # Apply Sinkhorn, then mask
    # param_list[0] is the torch.nn.parameter.Parameter I want to optimize via gradient descent
    log_probs = batch_sinkhorn(param_list[0], 0.5, 3, data[1])

    # Find the tour
    # Construct the choice mask and combine it with the adjacency-matrix mask
    decode_mask = torch.zeros(bsz, nb_nodes, dtype=torch.bool)
    decode_mask[:, 0] = True
    idx = torch.zeros(bsz, dtype=torch.long)
    tour = [torch.zeros(bsz, dtype=torch.long)]
    prob_list = []

    valid_tour_flag = torch.ones(bsz, dtype=torch.bool)
    for _ in range(nb_nodes - 1):
        log_prob_choice = log_probs.gather(dim=1, index=idx.unsqueeze(1).unsqueeze(2).repeat(1, 1, nb_nodes)).view(bsz, nb_nodes)

        route_mask = data[1].gather(dim=1, index=idx.unsqueeze(1).unsqueeze(2).repeat(1, 1, nb_nodes)).view(bsz, nb_nodes)
        route_mask = (1.0 - route_mask).bool()

        combined_mask = decode_mask | route_mask  # |combined_mask| = (bsz, nb_nodes)
        log_prob_choice = log_prob_choice.masked_fill(combined_mask, -1e29)
        idx = log_prob_choice.argmax(dim=-1)

        # A tour is invalid as soon as argmax picks a masked action
        valid_tour_flag = valid_tour_flag & ~combined_mask.gather(dim=1, index=idx.unsqueeze(1)).squeeze(1)
        decode_mask = decode_mask.float().scatter(dim=1, index=idx.unsqueeze(1), src=torch.ones((bsz, 1))).bool()
        tour.append(idx)
        prob_list.append(log_prob_choice.gather(dim=1, index=idx.unsqueeze(1)).squeeze(1))
    tour = torch.stack(tour, dim=1)
    sumLogProb = torch.stack(prob_list, dim=1).sum(-1)
    if torch.any(valid_tour_flag):
        print("There are some valid tours! They are for instances: {}".format(torch.where(valid_tour_flag)[0]))
    # Compute tour lengths
    length = compute_tour_length(data[0], tour)

    log_probs = log_probs.masked_fill(mask, -1e20)

    # Calculate the DAG loss (one value per sample)
    d_loss = dag_loss(log_probs)
    optimal_tour_len = data[3]

    # This multiplies the log probability of the tour with the tour length and the valid tour flag

    loss = torch.zeros(bsz, requires_grad=True)
    # TODO: find out how to take the loss depending on condition
    loss[torch.where(length > 0)] += length[torch.where(length > 0)]
    loss[torch.where(length <= 0)] += d_loss[torch.where(length <= 0)]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
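
A common first attempt at a per-sample loss is to allocate `loss = torch.zeros(bsz, requires_grad=True)` and fill it in place by index. The toy sketch below (made-up `length` values standing in for the real tour lengths) shows why that tends to fail: the indexed `+=` ends in an in-place write into a leaf tensor that requires grad, which autograd rejects with a `RuntimeError`. It also shows the out-of-place alternative, reducing the per-sample vector to a scalar with `.mean()` so `backward()` needs no extra argument.

```python
import torch

bsz = 4
# Hypothetical per-sample values standing in for `length`
length = torch.tensor([2.0, -1.0, 3.0, 0.5], requires_grad=True)

# First attempt: a leaf tensor that requires grad, filled in place by index
loss = torch.zeros(bsz, requires_grad=True)
try:
    loss[length > 0] += length[length > 0]
    inplace_failed = False
except RuntimeError as err:
    # autograd refuses in-place writes into a leaf that requires grad
    inplace_failed = True
    print("RuntimeError:", err)

# Out-of-place alternative: build the per-sample vector with torch.where
# (zero for samples failing the condition), then reduce it to a scalar
per_sample = torch.where(length > 0, length, torch.zeros_like(length))
loss = per_sample.mean()
loss.backward()     # no grad argument needed because loss is a scalar
print(length.grad)  # 1/bsz for samples with length > 0, 0 otherwise
```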

In general, the loop works as follows:

  • I have a bunch of parameters in param_list[0] that I wish to optimize; think of them as rows of data points with some initial values, with shape (bsz, 20)
  • I apply an algorithm called batch_sinkhorn to get some log probabilities
  • Based on the log probabilities, I construct a sequence of indices to visit, each time masking out the indices already visited
  • The mask is based on my previous visits and on whether a city is allowed to be visited at all (captured in route_mask)
  • After that, I compute the length of each tour
  • I calculate my first loss value with dag_loss, which returns a d_loss for each sample
  • If a sample successfully visited all indices, I wish to take the loss as sumLogProb * length
  • If a sample is unsuccessful, I wish to take the loss as d_loss
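
The masked greedy decoding in the middle of that loop can be sketched in isolation. This is a simplified, hypothetical version (`greedy_decode`, `fill` and the toy 3-node probabilities are illustrative, not the poster's code): it replaces the `gather` + `repeat` pattern with advanced indexing `log_probs[rows, idx]`, which selects the same outgoing row per sample, and it marks a tour invalid when `argmax` lands on a blocked node. That only happens when every node is blocked, assuming real log-probabilities stay above `fill`.

```python
import torch

def greedy_decode(log_probs, adjacency, fill=-1e9):
    """Greedy masked tour construction (simplified sketch).

    log_probs: (bsz, n, n) log-probability of moving from node i to node j
    adjacency: (bsz, n, n) 1.0 where edge i -> j exists, 0.0 otherwise
    Returns the tour (bsz, n), its summed log-probability (bsz,), and a validity flag (bsz,).
    """
    bsz, n, _ = log_probs.shape
    rows = torch.arange(bsz)
    visited = torch.zeros(bsz, n, dtype=torch.bool)
    visited[:, 0] = True                      # every tour starts at node 0
    idx = torch.zeros(bsz, dtype=torch.long)
    tour, picked = [idx], []
    valid = torch.ones(bsz, dtype=torch.bool)
    for _ in range(n - 1):
        choice = log_probs[rows, idx]         # (bsz, n): outgoing row of the current node
        blocked = visited | (adjacency[rows, idx] == 0)
        choice = choice.masked_fill(blocked, fill)
        idx = choice.argmax(dim=-1)           # ties resolve to the first index
        # landing on a blocked node means nothing unblocked was left
        valid = valid & ~blocked[rows, idx]
        visited[rows, idx] = True
        tour.append(idx)
        picked.append(choice[rows, idx])
    return torch.stack(tour, dim=1), torch.stack(picked, dim=1).sum(-1), valid

probs = torch.tensor([[0.0, 0.7, 0.3],
                      [0.2, 0.0, 0.8],
                      [0.5, 0.5, 0.0]])
log_probs = torch.log(probs).expand(2, 3, 3)  # same probabilities for both instances
adjacency = torch.ones(2, 3, 3)
adjacency[1, 1, 2] = 0.0                      # instance 1 has no edge 1 -> 2
tour, sum_log_prob, valid = greedy_decode(log_probs, adjacency)
print(tour.tolist(), valid.tolist())          # [[0, 1, 2], [0, 1, 0]] [True, False]
```

Instance 1 gets stuck at node 1 (node 0 is visited and the edge to node 2 is missing), so its flag flips to `False`, which is exactly the per-sample condition the loss needs.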

Hopefully this short write-up is clear. Basically, for samples that produced a legal tour I want to optimize one loss, and for the remaining samples I want to optimize another. Ultimately, all these per-sample losses will be pooled together via a mean.
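
One way to express "one loss per sample, then a mean" is to let a boolean flag select between two per-sample tensors with `torch.where`. The sketch below uses hypothetical stand-in tensors for `sumLogProb`, `length`, `d_loss` and `valid_tour_flag`, assuming each has shape `(bsz,)`. Because `torch.where` builds a new tensor, there is no in-place write, and each sample's gradient flows only into the branch it selected.

```python
import torch

# Hypothetical stand-ins for the per-sample quantities in the post
sum_log_prob = torch.tensor([-1.0, -2.0, -0.5, -3.0], requires_grad=True)  # sumLogProb
length = torch.tensor([3.0, 5.0, 4.0, 6.0])                                # tour lengths
d_loss = torch.tensor([0.9, 0.4, 0.7, 0.2], requires_grad=True)           # dag_loss output
valid = torch.tensor([True, False, True, False])                           # valid_tour_flag

# Pick one loss per sample without in-place writes, then pool with a mean
per_sample = torch.where(valid, sum_log_prob * length, d_loss)
loss = per_sample.mean()
loss.backward()

print(sum_log_prob.grad)  # length / 4 where valid, 0 elsewhere
print(d_loss.grad)        # 1/4 where invalid, 0 elsewhere
```

In the training loop this would take the place of the `loss = torch.zeros(...)` block, followed by the usual `optimizer.zero_grad()`, `loss.backward()`, `optimizer.step()`.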

Hi, I recently replied to a few similar topics. You can check them out; they should give you some insight into how to solve your problem :slight_smile:

If not, I can write something more personalised.

I did it! I re-used a flag that I had. Thanks for this nifty idea!
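
A flag such as `valid_tour_flag` can be re-used in two ways: as a 0/1 float weight, or as a selector in `torch.where`. This small sketch with made-up values contrasts them. The two agree when both branches are finite, but the weighted version evaluates both branches for every sample, so an `inf` in the unselected branch becomes `nan` (since `0 * inf = nan`).

```python
import torch

valid = torch.tensor([True, False, True])                             # e.g. valid_tour_flag
reinforce = torch.tensor([2.0, 1.0, 3.0], requires_grad=True)         # e.g. sumLogProb * length
d_loss = torch.tensor([0.5, 0.4, float("inf")], requires_grad=True)   # sample 2 diverged

# Flag re-used as a 0/1 weight: every sample touches both branches
flag = valid.float()
blended = flag * reinforce + (1.0 - flag) * d_loss

# Flag re-used as a selector: only the chosen branch's value is kept
selected = torch.where(valid, reinforce, d_loss)

print(blended)   # sample 2 is nan because 0 * inf = nan
print(selected)  # sample 2 keeps the finite value 3.0
```

`torch.where` only protects the forward pass: the unselected branch still receives a zero gradient, and if that branch's own backward produces `inf` (for example `sqrt` at 0), the result can still be `nan`. Masking the inputs of that branch is a common extra safeguard.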