Ways to calculate softmax function over very large matrix

So I have a matrix of size n x m, say mat = torch.rand(n, m), and I want to compute the softmax over the second dimension:

exp_mat = torch.exp(mat)
soft_max_mat = exp_mat / exp_mat.sum(1).unsqueeze(1).repeat(1, exp_mat.size(1))

but this is too slow, even on GPUs. I believe there should be some workaround, like vectorization in PyTorch. What is it exactly?
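For reference, the manual computation can be vectorized with broadcasting instead of `repeat`, and stabilized by subtracting the row maximum (a minimal sketch on a small example matrix; the `n x m` sizes are placeholders):

```python
import torch

mat = torch.rand(4, 5)  # stand-in for the large n x m matrix

# Subtract the per-row max before exponentiating for numerical stability,
# then divide by the row sums via broadcasting (keepdim=True keeps the
# row axis so no repeat() copy is materialized).
shifted = mat - mat.max(dim=1, keepdim=True).values
exp_mat = torch.exp(shifted)
soft = exp_mat / exp_mat.sum(dim=1, keepdim=True)
```

Broadcasting makes the division operate row-wise without allocating the repeated `n x m` denominator tensor.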

Why don't you just use

import torch
import torch.nn.functional as F

mat = torch.rand(n, m)
sftmx = F.softmax(mat, dim=1)

This will fail with an out-of-memory error because the matrix is very large.
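If the full result does not fit in GPU memory in one shot, one workaround is to apply the softmax in row chunks, since each row's softmax is independent of the others (a sketch; `chunk_size` is a tunable assumption, not from the original post):

```python
import torch

def chunked_softmax(mat, chunk_size=1024, dim=1):
    # Compute softmax over `dim` one row-chunk at a time, so only a
    # chunk's worth of intermediate tensors is alive at any moment.
    out = torch.empty_like(mat)
    for start in range(0, mat.size(0), chunk_size):
        end = min(start + chunk_size, mat.size(0))
        out[start:end] = torch.softmax(mat[start:end], dim=dim)
    return out
```

Writing each chunk's result directly into a preallocated output tensor keeps peak memory close to the size of one chunk plus the output; the result is identical to the all-at-once softmax because rows are processed independently.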