Is there a PyTorch implementation of TensorFlow segment softmax?

I have a 1-D tensor and a sorted list of the same length, with numbers indicating which elements belong to the same group. I'd like to compute the softmax within each group and return a tensor of the same length as the input tensor.

For example:

logits = torch.tensor([1., 2., 1., 3., 1.])
segment_id = np.array([0, 0, 0, 1, 1])

Desired output:
tensor([0.2119, 0.5761, 0.2119, 0.8808, 0.1192])

(0.2119, 0.5761, 0.2119 is the softmax of [1, 2, 1], and 0.8808, 0.1192 is the softmax of [3, 1].)
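Written out, the requested operation is an ordinary softmax whose normalizing sum runs only over the elements that share a segment id. Here $x_i$ denotes the logits and $s_i$ the segment ids (notation introduced just for this formula):

```latex
\operatorname{softmax}_{\text{seg}}(x)_i = \frac{e^{x_i}}{\sum_{j \,:\, s_j = s_i} e^{x_j}}
% Segment 0: e^{1}/(e^{1}+e^{2}+e^{1}) = 1/(2+e) \approx 0.2119, \quad e^{2}/(e^{1}+e^{2}+e^{1}) = e/(2+e) \approx 0.5761
% Segment 1: e^{3}/(e^{3}+e^{1}) = 1/(1+e^{-2}) \approx 0.8808, \quad e^{1}/(e^{3}+e^{1}) \approx 0.1192
```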

I haven't found this kind of method in PyTorch so far. Is there such a method in PyTorch? If not, what's the best practice for implementing it?

Many thanks

If your tensors are only moderately sized, broadcasting is probably an easy way to achieve this:

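import math
import torch

logits = torch.tensor([1.,2,1,3,1])
segment_id = torch.tensor([0,0,0,1,1])
segment_all_ids = torch.arange(int(segment_id.max()) + 1)  # one row per segment
neginf = torch.full((), -math.inf, device=logits.device, dtype=logits.dtype)
# (num_segments x N) matrix: row s keeps the logits of segment s and is -inf elsewhere
logits_per_seg = torch.where(segment_id==segment_all_ids[:, None], logits, neginf)
# -inf entries get probability 0, so each row is the softmax of one segment
p_per_seg = logits_per_seg.softmax(1)
# each column has exactly one non-zero entry, so summing over rows gives a length-N result
p = p_per_seg.sum(0)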

Best regards


Hi Thomas,

Thank you for your reply. Unfortunately, I have to deal with large tensors of roughly 250,000 elements. I have tried broadcasting; it's slow, and computing the gradient is even slower.

Do you have any idea how to handle tensors this large?

Best regards

Hi Peng,

Implementing custom kernels aside, the main alternative is probably looping over the segments. If your segments are sorted and there are not too many of them (say 10-100), it should not be unreasonably inefficient.
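One possible way to write such a loop, as a minimal sketch: it assumes `segment_ids` is a sorted 1-D integer tensor (so each segment is one contiguous run), and `segment_softmax_loop` is a hypothetical name. `torch.unique_consecutive(..., return_counts=True)` gives the length of each run, and `torch.split` cuts `logits` into per-segment pieces:

```python
import torch

def segment_softmax_loop(logits, segment_ids):
    # Hypothetical helper: assumes segment_ids is sorted, so each segment is a contiguous run.
    _, counts = torch.unique_consecutive(segment_ids, return_counts=True)
    # torch.split takes a list of Python ints as the chunk sizes.
    chunks = torch.split(logits, counts.tolist())
    # Ordinary softmax on each chunk, then concatenate back to length N.
    return torch.cat([chunk.softmax(0) for chunk in chunks])

logits = torch.tensor([1., 2., 1., 3., 1.])
segment_ids = torch.tensor([0, 0, 0, 1, 1])
print(segment_softmax_loop(logits, segment_ids))
```

The Python-level loop runs once per segment, which is why this stays reasonable only while the number of segments is small.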

Best regards


Hi Thomas,

Thank you for your reply. I have tried looping over the segments; it does the job, but it isn't very efficient.

In case someone else is interested in this question: in the end I found a relatively more efficient way. Build a sparse matrix to compute the softmax denominator of each group first, then do the division element-wise.

import numpy as np
import torch

def segment_softmax(logits, segment_ids):
    # logits: 1-D float tensor of length N; segment_ids: 1-D NumPy integer array of length N
    logits_len = len(segment_ids)
    num_segments = int(max(segment_ids)) + 1
    logits_exp = torch.exp(logits).unsqueeze(1)  # e^{logit}, N x 1

    # calculate summation of exponential of logits value for each group:
    # (num_segments x N) 0/1 matrix with a 1 at (segment_ids[i], i)
    sparse_index = torch.as_tensor(np.stack([segment_ids, np.arange(logits_len)]), dtype=torch.long)
    sparse_value = torch.ones(logits_len, dtype=logits.dtype)
    trans_matrix_sparse = torch.sparse_coo_tensor(sparse_index, sparse_value,
                                                  (num_segments, logits_len)).to(logits.device)
    softmax_den = torch.sparse.mm(trans_matrix_sparse, logits_exp)  # num_segments x 1

    # repeat softmax denominator to have the same length as logits:
    # (N x num_segments) 0/1 matrix with a 1 at (i, segment_ids[i])
    sparse_index2 = torch.as_tensor(np.stack([np.arange(logits_len), segment_ids]), dtype=torch.long)
    sparse_value2 = torch.ones(logits_len, dtype=logits.dtype)
    trans_matrix_sparse2 = torch.sparse_coo_tensor(sparse_index2, sparse_value2,
                                                   (logits_len, num_segments)).to(logits.device)
    softmax_den_repeat = torch.sparse.mm(trans_matrix_sparse2, softmax_den)  # N x 1

    return torch.squeeze(logits_exp / softmax_den_repeat, 1)
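For comparison, the same two steps can be sketched without building sparse matrices. This is a swapped-in technique, not the code above: `Tensor.index_add` sums the exponentials per segment (the first matrix product), and indexing the per-segment sums with `segment_ids` repeats them back to length N (the second product). It assumes `segment_ids` is a 1-D `torch.long` tensor on the same device as `logits`; `segment_softmax_index_add` is a hypothetical name:

```python
import torch

def segment_softmax_index_add(logits, segment_ids):
    # Hypothetical helper; segment_ids: 1-D torch.long tensor, same device as logits.
    num_segments = int(segment_ids.max()) + 1
    logits_exp = torch.exp(logits)  # no max subtraction here, same as the version above
    # softmax_den[s] = sum of logits_exp[i] over all i with segment_ids[i] == s
    softmax_den = torch.zeros(num_segments, dtype=logits.dtype, device=logits.device)
    softmax_den = softmax_den.index_add(0, segment_ids, logits_exp)
    # Look up each element's segment denominator (length N) and divide.
    return logits_exp / softmax_den[segment_ids]

logits = torch.tensor([1., 2., 1., 3., 1.], requires_grad=True)
out = segment_softmax_index_add(logits, torch.tensor([0, 0, 0, 1, 1]))
out.sum().backward()  # gradients flow back through index_add and the indexing
```

Here `softmax_den[segment_ids]` plays the role of `trans_matrix_sparse2`: multiplying by an N x num_segments 0/1 matrix with one 1 per row is just a gather.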



Oh, that's a cool solution! Note that it's numerically less stable than the usual softmax, though. If that becomes a problem, you can take the max and subtract it before exponentiating.
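A sketch of the max-subtraction idea, done per segment. It assumes a recent PyTorch (1.12+) for the `src`/`include_self` form of `Tensor.scatter_reduce`, that `segment_ids` is a 1-D `torch.long` tensor on the same device as `logits`, and `segment_softmax_stable` is a hypothetical name:

```python
import torch

def segment_softmax_stable(logits, segment_ids):
    # Hypothetical helper; segment_ids: 1-D torch.long tensor, same device as logits.
    num_segments = int(segment_ids.max()) + 1
    # Per-segment maximum. With include_self=False the -inf fill value is ignored
    # for every segment that receives at least one element.
    seg_max = torch.full((num_segments,), float("-inf"), dtype=logits.dtype, device=logits.device)
    seg_max = seg_max.scatter_reduce(0, segment_ids, logits.detach(), reduce="amax", include_self=False)
    # Shifting a segment by a constant leaves its softmax unchanged, so the max can be
    # detached; every exponent is now <= 0 and exp() cannot overflow.
    logits_exp = torch.exp(logits - seg_max[segment_ids])
    # Each denominator is >= 1 (the max element contributes exp(0) = 1), so no 0/0.
    softmax_den = torch.zeros(num_segments, dtype=logits.dtype, device=logits.device)
    softmax_den = softmax_den.index_add(0, segment_ids, logits_exp)
    return logits_exp / softmax_den[segment_ids]

ids = torch.tensor([0, 0, 0, 1, 1])
print(segment_softmax_stable(torch.tensor([1000., 1001., 1000., 1003., 1001.]), ids))
```

Without the shift, `torch.exp` overflows to `inf` in float32 for logits above roughly 88.7, and the division then yields NaN. A per-segment max, rather than one global max, also keeps segments with much smaller logits from underflowing to 0/0.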