Vmap and its compatibility issues

Hi All,

I have been using the Python Optimal Transport library. I want to define a loss function that iterates over every sample in my batch and calculates the Sinkhorn distance between that sample and its ground-truth value. What I was using before was a for-loop:

import ot

for i in range(len(P_batch)):
    sample_loss = ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)
    if i == 0:
        loss = sample_loss   # initialize with the first sample
    else:
        loss += sample_loss  # accumulate the remaining samples

but this is far too slow for my application. I was reading through the functorch docs, and apparently I should be able to use vmap instead. But after wrapping my function in vmap, I get the same error that others have been reporting:

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257.

Does anyone have a workaround?

Cheers.

Maybe you could pull out the if condition and evaluate it before the loop to remove the data dependency.
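A minimal sketch of why the `if` breaks vmap and what the usual branch-free rewrite looks like. Inside vmap each sample may take a different branch, so a Python `if` on a tensor value has no single answer; `torch.where` computes both branches and selects per sample instead. The functions `keep_if_positive_branchy` and `keep_if_positive_where` are hypothetical illustrations, not part of any library:

```python
import torch
from torch.func import vmap


def keep_if_positive_branchy(x):
    # Python `if` on a tensor value: data-dependent control flow
    if x.sum() > 0:
        return x
    return torch.zeros_like(x)


def keep_if_positive_where(x):
    # Same logic as an elementwise select; both branches are computed
    return torch.where(x.sum() > 0, x, torch.zeros_like(x))


batch = torch.tensor([[1.0, 2.0], [-3.0, 1.0]])

try:
    vmap(keep_if_positive_branchy)(batch)
    branchy_ok = True
except RuntimeError as err:
    # vmap refuses to convert a batched tensor to a Python bool
    branchy_ok = False
    print("vmap failed:", err)

out = vmap(keep_if_positive_where)(batch)
print(out)  # first row kept (sum 3), second row zeroed (sum -2)
```

The catch is that a library function has to be written this way internally; you cannot fix an `if` inside someone else's code from the outside.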

Hi @hmdolatabadi,

Following on from what @ptrblck said, vmap will have a problem with the if statement, and it seems you’re only using the if statement to initialize the loss from the first sample, then in-place add all other samples in your batch.

What you should do is vmap over the ot.sinkhorn2 function, then return a batch of outputs over which you sum (rather than summing within vmap itself).
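The pattern described above (vmap a per-sample function, then reduce outside vmap) can be sketched with a stand-in loss. `per_sample_loss` is a hypothetical placeholder (a quadratic form in the difference), not `ot.sinkhorn2`; the point is only that the loop and the vmap version give the same total, and that the sum happens after vmap returns:

```python
import torch
from torch.func import vmap


def per_sample_loss(p, q, C):
    # Hypothetical stand-in for a per-sample distance: p and q are 1-D
    # (one sample each); C is shared across the whole batch.
    diff = p - q
    return diff @ C @ diff


torch.manual_seed(0)
batch_size, n = 8, 5
P_batch = torch.rand(batch_size, n)
Q_batch = torch.rand(batch_size, n)
C = torch.eye(n)

# Loop version: accumulate one sample at a time
loss_loop = sum(per_sample_loss(P_batch[i], Q_batch[i], C) for i in range(batch_size))

# vmap version: map over dim 0 of P and Q, reuse C, then sum the batch of losses
losses = vmap(per_sample_loss, in_dims=(0, 0, None))(P_batch, Q_batch, C)
loss_vmap = losses.sum()

print(losses.shape, loss_loop.item(), loss_vmap.item())
```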

You can try this and see if it works,

import torch
import ot
from torch.func import vmap

P = P_batch.flatten(start_dim=1)  # flatten each sample (but not the batch dim)
Q = Q_batch.flatten(start_dim=1)

# in_dims=(0, 0, None, None): map over dim 0 of P and Q; C and epsilon are
# assumed to be constants shared by all samples
losses = vmap(ot.sinkhorn2, in_dims=(0, 0, None, None))(P, Q, C, epsilon)
loss = torch.sum(losses)

FYI, if your ot.sinkhorn2 function is an nn.Module object, you’ll need to use torch.func.functional_call with your vmap call.
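For the nn.Module case, a minimal sketch of combining `torch.func.functional_call` with vmap: the parameters are passed explicitly as a dict and are not batched (`in_dims=None` for that argument), while the inputs are mapped over dim 0. The `nn.Linear` model and the `predict` wrapper are illustrative assumptions:

```python
import torch
from torch.func import functional_call, vmap

model = torch.nn.Linear(4, 1)
# Parameters as a plain dict {name: tensor}, shared by all samples
params = {name: p.detach() for name, p in model.named_parameters()}


def predict(params, x):
    # Inside vmap, x is a single sample of shape (4,); nn.Linear accepts 1-D input
    return functional_call(model, params, (x,))


x_batch = torch.randn(8, 4)
# in_dims=(None, 0): do not map over params, map over dim 0 of x_batch
out = vmap(predict, in_dims=(None, 0))(params, x_batch)
print(out.shape)
```

Passing the parameters explicitly is what later lets you vmap or differentiate with respect to them; for a plain forward pass over a batch, calling the module directly is usually enough.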

Thanks @ptrblck and @AlphaBetaGamma96!
The issue is more related to what @ptrblck described.
The sinkhorn function in the ot package uses a data-dependent if-statement, and that’s why I’m getting an error:

File /anaconda3/envs/my_env/lib/python3.8/site-packages/ot/bregman.py:505, in sinkhorn_knopp(a, b, M, reg, numItermax, stopThr, verbose, log, warn, warmstart, **kwargs)
    502 v = b / KtransposeU
    503 u = 1. / nx.dot(Kp, v)
--> 505 if (nx.any(KtransposeU == 0)
    506         or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
    507         or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
    508     # we have reached the machine precision
    509     # come back to previous solution and quit loop
    510     warnings.warn('Warning: numerical errors at iteration %d' % ii)
    511     u = uprev

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .

I guess I might need an alternative route for computing the Sinkhorn distance that doesn't involve if-statements. If you have any ideas, they would be much appreciated!
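One possible workaround, offered as a sketch rather than a drop-in replacement for `ot.sinkhorn2`: write the Sinkhorn-Knopp iterations directly in PyTorch with a fixed number of iterations, so there is no convergence test and no NaN/Inf check, and therefore no Python `if` on tensor values. `sinkhorn_cost` and its `n_iters` parameter are hypothetical; it returns sum(plan * C) for the entropic plan, and it assumes each row of `P_batch` and `Q_batch` is a histogram summing to 1. It would be worth comparing its values against `ot.sinkhorn2` on a few samples before relying on it:

```python
import torch
from torch.func import vmap


def sinkhorn_cost(a, b, C, epsilon, n_iters=200):
    # Hypothetical helper: entropic OT between histograms a (n,) and b (m,)
    # with cost matrix C (n, m), using a fixed number of Sinkhorn-Knopp
    # iterations. No convergence test and no NaN/Inf guard, so there is no
    # data-dependent Python control flow and vmap can handle it.
    K = torch.exp(-C / epsilon)          # Gibbs kernel
    u = torch.ones_like(a)
    v = torch.ones_like(b)
    for _ in range(n_iters):             # loop bound is a Python int, not a tensor
        v = b / (K.T @ u)
        u = a / (K @ v)
    plan = u[:, None] * K * v[None, :]   # diag(u) @ K @ diag(v)
    return torch.sum(plan * C)           # transport cost of the entropic plan


torch.manual_seed(0)
n = 5
x = torch.linspace(0, 1, n)
C = (x[:, None] - x[None, :]) ** 2       # squared-distance cost on a 1-D grid
P_batch = torch.softmax(torch.randn(8, n), dim=1)
Q_batch = torch.softmax(torch.randn(8, n), dim=1)
epsilon = 0.1

losses = vmap(sinkhorn_cost, in_dims=(0, 0, None, None))(P_batch, Q_batch, C, epsilon)
loss = losses.sum()
```

The design trade-off: the fixed iteration count gives up POT's early stopping in exchange for vmap compatibility. For small epsilon, `exp(-C / epsilon)` can underflow to zero and lead to division by zero, which is the kind of case POT's check catches; a log-domain variant would likely be the safer choice there.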

Thanks.

If the Python Optimal Transport library has a repo, I’d open an issue and ask, because it seems like you need a version that doesn’t have an if statement, unfortunately.