I’m trying to calculate the *log_softmax* function of a list of tensors, i.e., a list [t_1, t_2, …, t_n] where each *t_i* is of type *torch.Tensor* and each *t_i* can have a different, arbitrary shape. **I do not want to apply log_softmax to each t_i separately, but to all of them as if they were part of a single tensor.** The output of this function should be a list of tensors with the same shapes as the input. Lastly, as I will apply this function to the last layer of a neural network, **I want to be able to differentiate it**, i.e., the gradients must flow through it.
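To illustrate the behaviour I expect (the tensors and shapes here are toy values I made up for the example), one can think of it as flattening everything, applying the usual log_softmax, and splitting the result back:

```python
import torch

# Hypothetical toy input: two tensors of different shapes.
t1 = torch.tensor([0.0, 1.0, 2.0])
t2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]])

# Treat all values as one big tensor...
flat = torch.cat([t.reshape(-1) for t in (t1, t2)])
log_probs = torch.log_softmax(flat, dim=0)

# ...then split back into the original shapes.
sizes = [t1.numel(), t2.numel()]
out1, out2 = [p.reshape(t.shape)
              for p, t in zip(torch.split(log_probs, sizes), (t1, t2))]

# All seven entries form a single distribution: exp(outputs) sums to 1.
total = (out1.exp().sum() + out2.exp().sum()).item()
```

This sketch ignores the special case of the termination condition described below; it only shows the "one big softmax over a list of tensors" part.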

PyTorch provides the class torch.nn.LogSoftmax, but I cannot use it because it expects a single tensor as input, not a list of tensors. Additionally, I want to compute log_softmax efficiently and in a numerically stable way; to achieve that, I use the *log-sum-exp* trick. Lastly, I want to ignore the last value of the first tensor in the list (see the code snippet below), i.e., not apply log_softmax to it.
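For reference, the log-sum-exp trick rewrites log(Σ exp(x_i)) as c + log(Σ exp(x_i − c)) with c = max x_i, so every exponent is ≤ 0 and cannot overflow. A stdlib-only sketch (the input values are arbitrary toy numbers):

```python
import math

def logsumexp(xs):
    # Shifting by c = max(xs) keeps every exponent <= 0,
    # so math.exp cannot overflow.
    c = max(xs)
    return c + math.log(sum(math.exp(x - c) for x in xs))

# A naive sum(exp(x)) would overflow for values around 1000;
# the shifted version stays finite.
vals = [1000.0, 1000.5, 999.0]
stable = logsumexp(vals)
```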

This is my current implementation:

```python
def log_softmax(pred_tensors):
    minus_inf = -1000  # Constant that represents minus infinity
    # Calculate the max value c over all tensors
    c = max([preds.amax() if preds is not None else minus_inf for preds in pred_tensors])
    # Calculate log(sum(e^(x_i - c)))
    log_sum_exp = 0
    for r in range(len(pred_tensors)):
        if pred_tensors[r] is not None:
            # Arity 0 -> ignore nullary predicate corresponding to termination condition
            curr_sum = torch.sum(torch.exp(pred_tensors[r][:-1] - c)) if r == 0 else \
                       torch.sum(torch.exp(pred_tensors[r] - c))
            log_sum_exp += curr_sum
    log_sum_exp = torch.log(log_sum_exp)
    # Calculate log_softmax (apply it to the original tensors, except the termination condition)
    for r in range(len(pred_tensors)):
        if pred_tensors[r] is not None:
            # Arity 0 -> ignore nullary predicate corresponding to termination condition
            if r == 0:
                pred_tensors[r][:-1] -= log_sum_exp + c
            else:
                pred_tensors[r] -= log_sum_exp + c
    return pred_tensors
```

I have tested it and it works. However, I think my implementation **may be breaking the autograd** of PyTorch, in the lines `c = max([preds.amax() if preds is not None else minus_inf for preds in pred_tensors])` and `log_sum_exp += curr_sum`.
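For context, this is the kind of minimal check I have in mind for "gradients flow through it" (the leaf tensors and the dummy function below are hypothetical, not my real network outputs):

```python
import torch

def grads_flow(fn):
    # Hypothetical leaf tensors of different shapes, as in my setting.
    t1 = torch.randn(4, requires_grad=True)
    t2 = torch.randn(2, 3, requires_grad=True)
    # clone() so that any in-place updates inside fn hit non-leaf
    # copies rather than the leaf tensors themselves.
    out = fn([t1.clone(), t2.clone()])
    loss = sum(o.sum() for o in out)
    loss.backward()
    return t1.grad is not None and t2.grad is not None
```

My understanding is that if `grads_flow(log_softmax)` raises an error or returns False, autograd is indeed broken somewhere in the function.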

So, my questions are: **Is my implementation really breaking autograd? If it is, can you provide an alternative implementation that works with autograd?**