How to calculate log_softmax for a list of tensors without breaking autograd

I’m trying to calculate the log_softmax function of a list of tensors, i.e., a list [t_1, t_2, …, t_n] where each t_i is a torch.Tensor and the t_i can have different, arbitrary shapes. I do not want to apply log_softmax to each t_i separately, but to all of them jointly, as if they were parts of a single tensor. The output should be a list of tensors with the same shapes as the inputs. Lastly, since I will apply this function to the final layer of a neural network, it must be differentiable, i.e., gradients must flow through it.

PyTorch provides the class torch.nn.LogSoftmax, but I cannot use it because it expects a single tensor as input rather than a list of tensors. I also want to compute log_softmax efficiently and in a numerically stable way, which is why I use the log-sum-exp trick. Finally, the last value of the first element of the list must be ignored (see the code snippet below), i.e., it should neither take part in the normalization nor be transformed by log_softmax.
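For reference, this is the quantity the function is meant to compute, written with the log-sum-exp trick. Here j runs over every entry of every non-None tensor in the list except the excluded last entry of the first tensor, and c is a constant shift, usually the maximum entry:

```latex
\log\operatorname{softmax}(x)_i
  = x_i - \log \sum_j e^{x_j}
  = x_i - \Bigl(c + \log \sum_j e^{x_j - c}\Bigr),
  \qquad c = \max_j x_j
```

Choosing c as the maximum makes every exponent x_j - c at most 0, so the sum cannot overflow, and the value of c cancels out mathematically.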

This is my current implementation:

def log_softmax(pred_tensors):
    minus_inf = -1000 # Constant that represents minus infinity

    # Calculate the max value
    c = max([preds.amax() if preds is not None else minus_inf for preds in pred_tensors])

    # Calculate log(sum(e^(x_i-c)))
    log_sum_exp = 0
    for r in range(len(pred_tensors)):
        if pred_tensors[r] is not None:
            
            # Arity 0 -> ignore nullary predicate corresponding to termination condition
            curr_sum =  torch.sum(torch.exp(pred_tensors[r][:-1] - c))   if r == 0 else \
                        torch.sum(torch.exp(pred_tensors[r] - c))
            log_sum_exp += curr_sum
            
    log_sum_exp = torch.log(log_sum_exp)
        
    # Calculate log_softmax (apply log_softmax to the original tensor) (except to the termination condition)
    for r in range(len(pred_tensors)):
        if pred_tensors[r] is not None:
            # Arity 0 -> ignore nullary predicate corresponding to termination condition
            if r == 0:
                pred_tensors[r][:-1] -= log_sum_exp + c 
            else:    
                pred_tensors[r] -= log_sum_exp + c


    return pred_tensors

I have tested it and it works. However, I suspect my implementation may be breaking PyTorch’s autograd, specifically at the lines c = max([preds.amax() if preds is not None else minus_inf for preds in pred_tensors]) and log_sum_exp += curr_sum.

So, my questions are: Is my implementation really breaking autograd? If it is, can you provide an alternative implementation that works with autograd?
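One way to answer the first question empirically is torch.autograd.gradcheck, which compares the gradients autograd computes against finite differences. The sketch below runs it on a hypothetical out-of-place helper, ref_list_log_softmax, rather than on the code above: gradcheck passes leaf tensors that require grad, and the code above modifies its inputs in place, which PyTorch rejects for such leaves. Wrapping that code so it receives clones of the inputs would be one way to test it directly.

```python
import torch

def ref_list_log_softmax(tensors):
    # Hypothetical out-of-place helper: log_softmax over all entries of all tensors.
    flat = torch.cat([t.reshape(-1) for t in tensors])
    lse = torch.logsumexp(flat, dim=0)   # stable log(sum(exp(...)))
    return [t - lse for t in tensors]    # new tensors; inputs are not modified

# gradcheck needs double-precision inputs that require grad.
a = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
b = torch.randn(4, dtype=torch.float64, requires_grad=True)

# gradcheck calls the function with *inputs and compares autograd's gradients
# with finite differences; it raises an error if they disagree.
ok = torch.autograd.gradcheck(lambda u, v: tuple(ref_list_log_softmax([u, v])), (a, b))
print(ok)  # True
```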

This looks OK to me.

I think the line that computes c should be run under with torch.no_grad(), as c is supposed to be a constant shift with no effect on backpropagation.
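Since the shift cancels out of the result, the gradient that flows through c is zero up to rounding, so computing c under torch.no_grad() should not change the gradients. A small sketch to check this on a single 1-D tensor, using a hypothetical helper log_softmax_shifted and random weights w so that every output contributes to the loss:

```python
import torch

def log_softmax_shifted(x, detach_shift):
    # Log-sum-exp trick on a 1-D tensor: shift by c = max(x) before exponentiating.
    if detach_shift:
        with torch.no_grad():
            c = x.max()   # treated as a constant by autograd
    else:
        c = x.max()       # gradient also flows through the max
    return x - c - torch.log(torch.sum(torch.exp(x - c)))

torch.manual_seed(0)
x = torch.randn(10, dtype=torch.float64, requires_grad=True)
w = torch.randn(10, dtype=torch.float64)  # weights so every output affects the loss

(g_tracked,) = torch.autograd.grad((log_softmax_shifted(x, False) * w).sum(), x)
(g_detached,) = torch.autograd.grad((log_softmax_shifted(x, True) * w).sum(), x)
print(torch.allclose(g_tracked, g_detached))  # True
```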

With the code below, you can compute logsumexp over multiple tensors, although I am not sure it will help in your use case.

x = torch.randn(5,3)
y = torch.randn(5,6)
z = torch.randn(5,9)
composed = torch.cat([x, y, z], dim=-1)
logsumexp = torch.logsumexp(composed, dim=-1, keepdim=True)
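Continuing that snippet, subtracting the logsumexp result gives a log_softmax over each row of the concatenation, and torch.split can cut the result back into pieces of the original widths. Note that this normalizes each of the 5 rows separately and assumes the tensors share every dimension except the last, which may or may not be the grouping you need:

```python
import torch

x = torch.randn(5, 3)
y = torch.randn(5, 6)
z = torch.randn(5, 9)

composed = torch.cat([x, y, z], dim=-1)                       # shape (5, 18)
logsumexp = torch.logsumexp(composed, dim=-1, keepdim=True)  # shape (5, 1)

# Subtraction broadcasts over the last dimension: a row-wise log_softmax.
log_probs = composed - logsumexp

# Split back into pieces with the original widths (3, 6 and 9 columns).
lx, ly, lz = torch.split(log_probs, [x.shape[-1], y.shape[-1], z.shape[-1]], dim=-1)
```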

Thank you very much for your answer!

So, if I have understood you correctly, the operation log_sum_exp += curr_sum computes the gradients correctly, i.e., it does not break autograd. I suppose that’s because, even though I use the in-place operator += instead of something like torch.sum, torch.sum is called under the hood, so the gradient of the sum can still be computed. Is that correct?
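A small check, using hypothetical tensors named parts. On tensors, += is an addition rather than a call to torch.sum: the first += turns the Python int 0 into a new tensor (int has no in-place add), and later ones are in-place adds on that tensor; autograd records both. The gradient of sum(exp(p)) with respect to p is exp(p), so that is what should come out:

```python
import torch

parts = [torch.randn(3, requires_grad=True), torch.randn(2, 2, requires_grad=True)]

total = 0  # a plain Python int, as in the original loop
for p in parts:
    # First iteration: int + Tensor creates a new tensor.
    # Later iterations: Tensor += Tensor is an in-place add that autograd records.
    total += torch.sum(torch.exp(p))

total.backward()
# d/dp sum(exp(p)) = exp(p), so each gradient should equal exp(p).
for p in parts:
    print(torch.allclose(p.grad, torch.exp(p)))  # True
```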

About the constant c, you are right. I thought I had to compute its gradient because it depends on the values in pred_tensors but, since it only acts as a constant (I could have chosen any other value for c, such as 0), I suppose I can ignore the gradients for that part.

Finally, about your logsumexp implementation with torch.cat(): it does not work for me because the tensors can have arbitrary shapes, and torch.cat() requires all dimensions to match except the one being concatenated along. For example, suppose I want to calculate logsumexp for the tensors x = torch.randn(5,3) and y = torch.randn(3, 7). You can’t do composed = torch.cat([x,y]) in this case, which is why I am stuck with my manual implementation.
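To illustrate the constraint: torch.cat only accepts tensors whose shapes agree in every dimension except the concatenation dimension, and raises a RuntimeError otherwise. A small sketch with a hypothetical helper cat_fails:

```python
import torch

def cat_fails(tensors, dim):
    # Hypothetical helper: True if torch.cat rejects these tensors along `dim`.
    try:
        torch.cat(tensors, dim=dim)
        return False
    except RuntimeError:
        return True

x = torch.randn(5, 3)
y = torch.randn(3, 7)
print(cat_fails([x, y], dim=0))  # True: dim 1 differs (3 vs 7)
print(cat_fails([x, y], dim=1))  # True: dim 0 differs (5 vs 3)

# Shapes (5, 3) and (5, 7) agree in dim 0, so concatenating along dim 1 works.
ok = torch.cat([x, torch.randn(5, 7)], dim=1)
print(ok.shape)  # torch.Size([5, 10])
```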

Yes, you are right.

Can you elaborate a bit more on the context? Why are these tensors of different sizes?
In your code, it looks like you are treating all the elements of the (possibly) multi-dimensional tensors as if they formed one flat vector.
So, might the code below work?

x = torch.randn(5,3)
y = torch.randn(5,6)
z = torch.randn(5,9)
composed = torch.cat([x.flatten(), y.flatten(), z.flatten()], dim=-1)
logsumexp = torch.logsumexp(composed, dim=-1, keepdim=True)
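Building on that, one way to get back a list with the original shapes is to record each tensor's numel(), apply torch.log_softmax to the flat concatenation, and cut the result apart with torch.split. A sketch with a hypothetical helper list_log_softmax; it does not handle the None entries or the excluded termination value from the question:

```python
import torch

def list_log_softmax(tensors):
    # Hypothetical helper: log_softmax over all entries of all tensors, treated as one vector.
    # Returns new tensors with the input shapes; the inputs are left untouched.
    sizes = [t.numel() for t in tensors]
    flat = torch.cat([t.flatten() for t in tensors])   # 1-D, so shapes may differ
    log_probs = torch.log_softmax(flat, dim=0)          # numerically stable built-in
    pieces = torch.split(log_probs, sizes)              # one 1-D chunk per input tensor
    return [p.view_as(t) for p, t in zip(pieces, tensors)]

x = torch.randn(5, 3, requires_grad=True)
y = torch.randn(3, 7, requires_grad=True)
lx, ly = list_log_softmax([x, y])
print(lx.shape, ly.shape)  # torch.Size([5, 3]) torch.Size([3, 7])
print(float(lx.exp().sum() + ly.exp().sum()))  # close to 1.0
```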

As for the context: the tensors in the list have different sizes because they correspond to the outputs of a Neural Logic Machine. I want to sample one of those (grounded) predicates, so the probabilities (the exponentials of the log_softmax outputs) must be positive and sum to 1. That’s why I am using log_softmax.
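For the sampling step, one option is torch.distributions.Categorical, which accepts unnormalized logits, so log_softmax outputs can be passed in directly. The sketch below (hypothetical helper sample_predicate) samples a flat index over the concatenated tensors and maps it back to a (list index, element index) pair; how the termination predicate should enter the sampling is not covered here:

```python
import torch
from torch.distributions import Categorical

def sample_predicate(log_probs):
    # Hypothetical helper: sample one entry from a list of log-probability tensors.
    # Returns (list_index, index_tuple) locating the sampled entry.
    sizes = [t.numel() for t in log_probs]
    flat = torch.cat([t.flatten() for t in log_probs])
    k = Categorical(logits=flat).sample().item()  # flat index into the concatenation

    # Walk the sizes to find which tensor the flat index falls into.
    for r, n in enumerate(sizes):
        if k < n:
            break
        k -= n

    # Convert the remaining flat offset into a multi-dimensional index (row-major).
    index = []
    for size in reversed(log_probs[r].shape):
        index.append(k % size)
        k //= size
    return r, tuple(reversed(index))
```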

Regarding the code: before your answer I did not know PyTorch provided the function logsumexp. After using it and adapting my code to yours, this is the result:

def _log_softmax(self, pred_tensors):
    # Work on a shallow copy so the caller's list is not modified
    pred_tensors = list(pred_tensors)

    # Remove the nullary predicate associated with the termination condition, so that it does not
    # affect the log_softmax computation
    term_cond_value = pred_tensors[0][-1]
    pred_tensors[0] = pred_tensors[0][:-1]

    # Calculate log_sum_exp of all the values in the tensors of the list
    # 1) flatten each tensor in the list
    # 2) concatenate them as a unique tensor
    # 3) calculate log_sum_exp
    log_sum_exp = torch.logsumexp(torch.cat([preds.flatten() if preds is not None else torch.empty(0, dtype=torch.float32) for preds in pred_tensors]), dim=-1)

    # Use log_sum_exp to calculate the log_softmax of the tensors in the list.
    # Out-of-place subtraction: an in-place -= would overwrite tensors that
    # earlier operations may have saved for the backward pass.
    for r in range(len(pred_tensors)):
        if pred_tensors[r] is not None:
            pred_tensors[r] = pred_tensors[r] - log_sum_exp

    # Append the nullary predicate corresponding to the termination condition
    pred_tensors[0] = torch.cat([pred_tensors[0], term_cond_value.reshape(1)])  # reshape() turns the 0-dim tensor into a 1-dim one

    return pred_tensors

As you can see, I use the function logsumexp to easily calculate the log_softmax of my list of tensors. Since I do not want to change the value pred_tensors[0][-1], I simply remove it from the first tensor of the list, perform the calculations, and append it back afterwards. Additionally, elements of the list that are None (instead of a valid tensor) are ignored in the expression [preds.flatten() if preds is not None else torch.empty(0, dtype=torch.float32) for preds in pred_tensors].
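A small sketch of the remove-and-append step on a single hypothetical tensor t shows that the termination entry passes through unchanged and that gradients still reach it: backpropagating from the last output gives a gradient of 1 for t[-1] and 0 for the other entries.

```python
import torch

t = torch.randn(4, requires_grad=True)

head, last = t[:-1], t[-1]                           # split off the termination entry
head_log_probs = head - torch.logsumexp(head, dim=0)
out = torch.cat([head_log_probs, last.reshape(1)])   # append it back unchanged

out[-1].backward()
print(t.grad)  # tensor([0., 0., 0., 1.])
```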

I have tested this code and it correctly calculates the log_softmax function. Do you think my implementation works with autograd, or is there any part of it that could break the gradients?
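In-place updates such as tensor -= value on a network's outputs are the part of this kind of code most worth double-checking. They are only safe for autograd when no earlier operation saved the overwritten tensor for its backward pass. For example, torch.sigmoid saves its output, so modifying that output in place makes backward() raise a RuntimeError, while the out-of-place form works. A minimal demonstration (the sigmoid is just an illustrative choice of final layer):

```python
import torch

x = torch.randn(6, requires_grad=True)

# Out-of-place: y keeps the values that sigmoid saved for its backward pass.
y = torch.sigmoid(x)
z = y - torch.logsumexp(y, dim=0)
z.sum().backward()
print(x.grad is not None)  # True

# In-place: sigmoid saves its output for backward, and -= overwrites it.
x.grad = None
y = torch.sigmoid(x)
y -= 1.0
try:
    y.sum().backward()
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True
```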

It looks fine to me.


Thank you very much for your help! I’ll mark the issue as solved, then.