I have been using the Python Optimal Transport library. I want to define a loss function that iterates over every sample in my batch and computes the Sinkhorn distance between that sample and its ground-truth value. What I was using before was a for-loop:
for i in range(len(P_batch)):
    if i == 0:
        # first sample initializes the loss
        loss = ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)
    else:
        # remaining samples are accumulated
        loss += ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)
But this is way too slow for my application. I was reading through the functorch docs, and apparently I should be able to use the vmap functionality. But after wrapping my function in vmap, I get this weird error that everyone else is talking about:
RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257.
Does anyone have a workaround?
Maybe you could pull out the if condition and execute it before executing the loop to remove the data-dependency.
Following on from what @ptrblck said, vmap will have a problem with the if statement, and it seems you’re only using the if statement to initialize the loss from the first sample and then add all other samples in your batch in place.
What you should do is vmap over the ot.sinkhorn2 function, then return a batch of outputs over which you sum (rather than summing within vmap itself).
You can try this and see if it works:
import torch
import ot
from torch.func import vmap
P = P_batch.flatten(start_dim=1)  # flatten tensors (but not the batch dim)
Q = Q_batch.flatten(start_dim=1)
# assumes C and epsilon are constants (shared by all samples)
losses = vmap(ot.sinkhorn2, in_dims=(0, 0, None, None))(P, Q, C, epsilon)
loss = torch.sum(losses)
FYI, if your ot.sinkhorn2 function is an nn.Module object, you’ll need to use torch.func.functional_call with your vmap call.
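In case it helps, here is a minimal sketch of what the functional_call route could look like. PairLoss is a made-up stand-in module (it is not part of POT, and ot.sinkhorn2 appears to be a plain function, so this may not be needed in your case); the point is only to show how a module is called statelessly inside vmap.

```python
import torch
from torch import nn
from torch.func import functional_call, vmap


class PairLoss(nn.Module):
    """Hypothetical module-based loss, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(2.0))

    def forward(self, p, q):
        # Scalar loss for a single (unbatched) pair of vectors.
        return self.scale * (p - q).abs().sum()


model = PairLoss()
params = dict(model.named_parameters())


def single_loss(p, q):
    # Call the module with an explicit parameter dict instead of its own state.
    return functional_call(model, params, (p, q))


P = torch.rand(4, 5)
Q = torch.rand(4, 5)
losses = vmap(single_loss, in_dims=(0, 0))(P, Q)  # one loss per sample
loss = losses.sum()
```

The per-sample function sees unbatched inputs, and the sum happens outside vmap, same as in the snippet above.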
Thanks @ptrblck and @AlphaBetaGamma96!
The issue is more related to what @ptrblck described. The sinkhorn function in the ot package uses a data-dependent if-statement internally, and that’s why I’m getting the error:
File /anaconda3/envs/my_env/lib/python3.8/site-packages/ot/bregman.py:505, in sinkhorn_knopp(a, b, M, reg, numItermax, stopThr, verbose, log, warn, warmstart, **kwargs)
502 v = b / KtransposeU
503 u = 1. / nx.dot(Kp, v)
--> 505 if (nx.any(KtransposeU == 0)
506 or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
507 or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
508 # we have reached the machine precision
509 # come back to previous solution and quit loop
510 warnings.warn('Warning: numerical errors at iteration %d' % ii)
511 u = uprev
RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
I guess I might need an alternative route for computing the sinkhorn function that doesn’t involve if-statements. If you have any ideas, they would be much appreciated!
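For reference, the same error can be reproduced outside of POT with a toy function. The two helpers below (safe_inverse_branch and safe_inverse_where) are made-up names for illustration: the first branches in Python on a tensor value, the second expresses the same choice with torch.where. Note that this only covers element selection; the early exit from the loop in bregman.py cannot be rewritten this way.

```python
import torch
from torch.func import vmap


def safe_inverse_branch(x):
    # Python `if` on a tensor value: vmap cannot handle this.
    if torch.any(x == 0):
        return torch.zeros_like(x)
    return 1.0 / x


def safe_inverse_where(x):
    # Same decision written as tensor ops, so there is no Python branch.
    has_zero = torch.any(x == 0)
    return torch.where(has_zero, torch.zeros_like(x), 1.0 / x)


X = torch.tensor([[1.0, 2.0], [0.0, 4.0]])

try:
    vmap(safe_inverse_branch)(X)
    branch_failed = False
except RuntimeError:
    # The data-dependent control flow error from the traceback above.
    branch_failed = True

out = vmap(safe_inverse_where)(X)  # first row inverted, second row zeroed
```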
The Python Optimal Transport library has a GitHub repo, so I’d open an issue there and ask. Unfortunately, it seems like you need a version of the solver that doesn’t contain an if statement.
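In the meantime, one possible workaround is to write the Sinkhorn iterations directly in PyTorch with a fixed number of steps, so there is no stopping test or numerical-error check to trip over. The sketch below is my own (sinkhorn2_batched and n_iters are not POT names) and it handles the batch dimension with plain matrix products, so vmap isn't needed at all. It assumes every row of P and Q is strictly positive with the same total mass, and that epsilon is not so small that exp(-C / epsilon) underflows; a log-domain variant would be safer there.

```python
import torch


def sinkhorn2_batched(P, Q, C, epsilon, n_iters=200):
    """Entropic OT cost <T, C> for each row pair of P and Q.

    P: (B, n), Q: (B, m), C: (n, m). Uses a fixed iteration count,
    so there is no data-dependent control flow.
    """
    K = torch.exp(-C / epsilon)          # Gibbs kernel, shape (n, m)
    u = torch.ones_like(P)
    v = torch.ones_like(Q)
    for _ in range(n_iters):             # fixed count, no convergence check
        v = Q / (u @ K)                  # v = b / (K^T u), shape (B, m)
        u = P / (v @ K.T)                # u = a / (K v),   shape (B, n)
    # Plan T_b = diag(u_b) K diag(v_b); cost_b = sum(T_b * C).
    return torch.einsum("bi,ij,bj->b", u, K * C, v)


# Toy usage with made-up data: 8 pairs of 5-bin histograms.
torch.manual_seed(0)
P_batch = torch.softmax(torch.randn(8, 5), dim=1)
Q_batch = torch.softmax(torch.randn(8, 5), dim=1)
x = torch.linspace(0, 1, 5)
C = (x[:, None] - x[None, :]) ** 2       # squared-distance cost matrix
losses = sinkhorn2_batched(P_batch, Q_batch, C, epsilon=0.1)
loss = losses.sum()
```

Since the iteration count is fixed, the result is only as converged as n_iters allows, so it is worth comparing against ot.sinkhorn2 on a few samples before relying on it.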