Hi Lucas!
I can reproduce your issue using your fishy_tensor.
I believe that it is caused by a somewhat unlikely conspiracy of round-off
error, likely exacerbated by the fact that the rows of your probs tensor
are rather long (length 1600).
In your case your probs (fishy_tensor) are proper probabilities, with rows
normalized to sum to one (and the tensor passes simplex.check()). Categorical
does not require its row-sums to be normalized to one, so it normalizes
them (in your case, redundantly, but Categorical doesn’t know this) before
it applies simplex.check(). The normalization that Categorical applies
changes the round-off error a little bit and just happens to push one of the
row-sums a little bit over simplex.check()’s tolerance of 1.e-6, causing
Categorical to raise the error due to its internal call to simplex.check().
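For reference, the two pieces of pytorch machinery involved behave roughly
like the following sketch (an approximation of pytorch’s behavior, not the
exact source):
import torch

def simplex_check_sketch(value):
    # simplex.check(): entries non-negative, row-sums within 1.e-6 of one
    return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1.0).abs() < 1.e-6)

def categorical_probs_sketch(probs):
    # what Categorical does to probs before validating them
    return probs / probs.sum(-1, keepdim=True)
So even rows that already sum to one get divided by their not-exactly-one
row-sums, which reshuffles the round-off error.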
This script, which uses your fishy_tensor, illustrates these points:
import torch

print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.get_device_name())

fishy = torch.load('fishy_tensor.pt')
print((fishy.sum(dim=1) - 1.0).abs().max())                         # rows sum to one (up to round-off)
print(torch.distributions.constraints.simplex.check(fishy).all())   # passes the simplex test

fishy_norm = fishy / fishy.sum(dim=1).unsqueeze(-1)                 # "normalize" rows to sum to one
print((fishy_norm.sum(dim=1) - 1.0).abs().max())                    # rows now sum slightly less well to one
print(torch.distributions.constraints.simplex.check(fishy_norm))    # row 45 now fails
print((fishy_norm.sum(dim=1) - 1.0).abs() > 1.e-6)                  # simplex uses 1.e-6 as its tolerance
print(fishy_norm.sum(dim=1)[45] - 1.0)                              # row 45 is the culprit
print(fishy_norm[45].sum() - 1.0)                                   # but details of the cuda sum matter
And here is its output:
1.12.0
11.6
GeForce GTX 1050 Ti
tensor(7.1526e-07, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(True, device='cuda:0')
tensor(1.0729e-06, device='cuda:0', grad_fn=<MaxBackward1>)
tensor([ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, False, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True], device='cuda:0')
tensor([False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, True, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False], device='cuda:0')
tensor(-1.0729e-06, device='cuda:0', grad_fn=<SubBackward0>)
tensor(-5.9605e-07, device='cuda:0', grad_fn=<SubBackward0>)
For what it’s worth, running your particular fishy_tensor on my cpu does not
show this behavior.
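In case you want to confirm this on your end, something along these lines
should do it (a sketch; map_location just forces the tensor onto the cpu):
import torch

fishy_cpu = torch.load('fishy_tensor.pt', map_location='cpu')
fishy_cpu_norm = fishy_cpu / fishy_cpu.sum(dim=1).unsqueeze(-1)

# on my machine this prints True -- the cpu's summation order happens
# not to push any row-sum past the 1.e-6 tolerance
print(torch.distributions.constraints.simplex.check(fishy_cpu_norm).all())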
I assume that your code calls simplex.check() to protect against Categorical
raising an error. If you’re comfortable that your probs tensor is legit, you
could instantiate Categorical with validate_args = False. (If you want to stick
with the validation and call simplex.check() for protection, I imagine that
redundantly normalizing your probs tensor, to mimic what Categorical does
internally, before making your call to simplex.check() might be an adequate
work-around.)
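Concretely, the two options might look something like this (a sketch; fishy
is your fishy_tensor loaded as above):
import torch
from torch.distributions import Categorical
from torch.distributions.constraints import simplex

fishy = torch.load('fishy_tensor.pt')

# option 1: skip pytorch's validation entirely
dist = Categorical(probs=fishy, validate_args=False)

# option 2: normalize first, mimicking Categorical's internal step, so that
# your protective simplex.check() sees (nearly) the same values that
# Categorical's own validation will see
fishy_norm = fishy / fishy.sum(dim=-1, keepdim=True)
if simplex.check(fishy_norm).all():
    dist = Categorical(probs=fishy_norm)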
I also note (not that I think you should read anything into this) that summing
row 45 of fishy_tensor (after normalizing it) in isolation does not cross
simplex.check()’s 1.e-6 tolerance, while summing it as a row sum of the
whole tensor does. I assume that this causes a slight difference in innocent
round-off error due to the gpu using differing orders of operations in the two
cases.
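Here’s a self-contained illustration of that general phenomenon (made-up
float32 data, not yours): summing the same numbers with different groupings
typically gives slightly different results.
import torch

torch.manual_seed(0)
x = torch.rand(1600)  # float32, same length as a row of your probs tensor

s_flat = x.sum()                               # one summation order
s_chunked = x.view(16, 100).sum(dim=1).sum()   # a different grouping

# usually a small, non-zero difference -- pure round-off, not a bug
print((s_flat - s_chunked).item())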
I don’t really see this as a bug, but rather as an “expected” consequence
of round-off error. One could argue, however (and I suppose I do), that
Categorical / simplex.check() should take into account the length
of the probability vectors and widen the tolerance as they become
longer.
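Such a length-aware check might look something like this (purely a sketch;
the scaling is illustrative, not anything pytorch actually does):
import torch

def simplex_check_length_aware(value):
    # let the tolerance grow with the length of the probability vectors,
    # since expected round-off in a sum grows with the number of terms
    tol = value.size(-1) * torch.finfo(value.dtype).eps
    return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1.0).abs() < tol)
For your length-1600 float32 rows this would give a tolerance of about
1.9e-4, comfortably above the 1.0729e-06 overshoot seen above.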
Best.
K. Frank