Hi,
In the following code, I expect a sum to be exactly 1, but it is only ALMOST 1. This happens on both CPU and GPU, so it is most likely a rounding issue. Any idea why this happens, and how to fix it? Thanks.
Code:
import sys
import numpy as np
import torch
from torch.distributions.uniform import Uniform
low_b = 1e-10
upp_b = 0.95 / 9.
nbr = 10
n = 100
# for reproducibility
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# end reproducibility
print("Pytorch version:{}".format(torch.__version__))
print("Python version:{}".format(sys.version))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device {}".format(device))
target = torch.zeros((n, nbr), device=device, dtype=torch.float32,
                     requires_grad=False)
dist_unif = Uniform(torch.tensor([low_b]), torch.tensor([upp_b]))
print("Start issue:")
for i in range(n):
    smooth = dist_unif.sample((nbr, )).squeeze()  # sample random numbers
    smooth[2] = 0.  # set position 2 to 0.
    target[i, :] = smooth
    target[i, 2] = 1. - smooth.sum()  # replace position 2 by: 1 - sum(all)
    # now, the sum(all) is expected to be 1. but it is not the case.
    # the sum is ALMOST 1. Numerical instability?
    print("SUM {}: {:.16f}".format(i, target[i].sum()))
    assert target[i].sum() == 1., "sum {:.16f} is not 1.".format(
        target[i].sum())
Output on CPU:
$ python issues.py
Pytorch version:1.4.0+cu100
Python version:3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
Device cpu
Start issue:
SUM 0: 0.9999999403953552
Traceback (most recent call last):
  File "issues.py", line 44, in <module>
    target[i].sum())
AssertionError: sum 0.9999999403953552 is not 1.
Output on GPU:
$ python issues.py
Pytorch version:1.4.0+cu100
Python version:3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
Device cuda:0
Start issue:
SUM 0: 1.0000001192092896
Traceback (most recent call last):
  File "issues.py", line 44, in <module>
    target[i].sum())
AssertionError: sum 1.0000001192092896 is not 1.
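
A note on what the two printed sums seem to show, as a rough sketch rather than a definitive diagnosis. It uses only the Python standard library, and `sums_to_one` is a hypothetical helper name, not a PyTorch function. Both values are the float32 numbers immediately next to 1.0, which is what a single rounding step would look like. The CPU and GPU presumably add the ten terms in a different order, so they round in different directions. The tolerance of `n_terms * eps` is an assumption chosen for illustration.

```python
import math

# float32 machine epsilon: the gap between 1.0 and the next larger float32.
EPS32 = 2.0 ** -23

# Sums copied from the CPU and GPU outputs above.
cpu_sum = 0.9999999403953552
gpu_sum = 1.0000001192092896

# Both are the float32 values immediately next to 1.0, so each run is
# off by one rounding step.
print(cpu_sum == 1.0 - EPS32 / 2)  # nearest float32 below 1.0
print(gpu_sum == 1.0 + EPS32)      # nearest float32 above 1.0


def sums_to_one(total, n_terms, eps=EPS32):
    """Hypothetical helper: compare against 1.0 with an absolute tolerance
    that grows with the number of summed terms, instead of using `==`."""
    return math.isclose(total, 1.0, rel_tol=0.0, abs_tol=n_terms * eps)


print(sums_to_one(cpu_sum, n_terms=10))
print(sums_to_one(gpu_sum, n_terms=10))
```

On tensors, the equivalent check would be `torch.isclose` or `torch.allclose` in place of `==` in the assert. Switching to `dtype=torch.float64` should shrink the error but is not expected to make the sum exactly 1 in general.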