```python
import torch
import torch.nn as nn

batch_size = 16
# 16 * 12 * 4864 * 4864 ≈ 4.5e9 float32 elements (about 18 GB) on the GPU
a = torch.rand(batch_size, 12, 4864, 4864).cuda()
m = nn.Softmax(dim=3)
print(torch.sum(m(a), dim=3))
```

When batch_size is large (16 here), the sums are not 1; when I change batch_size to 8, the output is all 1s.
Is there any way to make the softmax calculation more accurate?

How inaccurate is Softmax in your test? You are summing over a
large number of values (4864 per row), so some floating-point
round-off error is to be expected.

Look at the largest deviations from the expected value of 1.0. Do your
results agree with 1.0 to about 6 or 7 decimal places, that is, within
expected floating-point round-off error?
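One way to check this is to compute the single largest deviation instead of eyeballing the printed tensor. This is a minimal sketch that assumes a CPU run on a smaller stand-in tensor of shape (2, 3, 64, 4864) so it fits in ordinary memory; the shape and the helper name `max_softmax_deviation` are illustrative, not taken from the question. It also prints float32 machine epsilon for comparison.

```python
import torch

def max_softmax_deviation(x: torch.Tensor, dim: int = -1) -> float:
    """Return the largest |sum(softmax(x)) - 1| along `dim` (hypothetical helper)."""
    sums = torch.softmax(x, dim=dim).sum(dim=dim)
    return (sums - 1.0).abs().max().item()

# Smaller stand-in for the tensor in the question (float32, CPU).
torch.manual_seed(0)
a = torch.rand(2, 3, 64, 4864)

dev = max_softmax_deviation(a)
eps = torch.finfo(torch.float32).eps  # about 1.19e-07

print(f"largest deviation from 1.0: {dev:.3e}")
print(f"float32 machine epsilon:    {eps:.3e}")
```

If the largest deviation is within a small multiple of epsilon, the result is consistent with ordinary round-off; a much larger deviation would point to something other than round-off.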

You can make the softmax calculation more accurate by
performing it in double precision.
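A minimal sketch of the double-precision version, again assuming a small CPU stand-in tensor rather than the full (16, 12, 4864, 4864) GPU tensor from the question. The only change is casting the input with `.double()` before the softmax. Note that float64 doubles memory use, so the original batch would need roughly twice as much GPU memory for both input and output.

```python
import torch

torch.manual_seed(0)
a = torch.rand(2, 3, 64, 4864)  # illustrative stand-in shape

# Softmax and row sums in single precision (float32).
sums32 = torch.softmax(a, dim=-1).sum(dim=-1)

# Same computation in double precision: cast the input to float64 first.
sums64 = torch.softmax(a.double(), dim=-1).sum(dim=-1)

dev32 = (sums32 - 1.0).abs().max().item()
dev64 = (sums64 - 1.0).abs().max().item()

print(f"float32 max deviation: {dev32:.3e}")
print(f"float64 max deviation: {dev64:.3e}")

# If later code expects float32, the probabilities can be cast back:
probs = torch.softmax(a.double(), dim=-1).float()
```

Casting back to float32 at the end keeps downstream code unchanged while doing the exponentials and normalization in double precision.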