Help regarding Slerp function for generative model sampling

It looks like the torch.lerp function is available, but torch.slerp does not exist in PyTorch 1.0.

I'd like to try implementing it, following the Sampling Generative Networks paper, for interpolation in latent space. NVIDIA's recently introduced GAN 2.0 also uses Slerp interpolation.
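One common argument for slerp with Gaussian latents can be checked quickly. The sketch below assumes a standard normal prior, like the torch.randn inputs used later in this thread. It shows that the torch.lerp midpoint between two random latents has a noticeably smaller norm than the endpoints, so the straight line passes closer to the origin than typical samples do:

```python
import torch

torch.manual_seed(0)
BATCH_SIZE, Z_DIM = 64, 128
z_start = torch.randn(BATCH_SIZE, Z_DIM)
z_end = torch.randn(BATCH_SIZE, Z_DIM)

# Linear interpolation halfway between the two batches of latents.
z_mid = torch.lerp(z_start, z_end, 0.5)

# Per-row ratio of the midpoint's norm to the start point's norm.
ratio = torch.norm(z_mid, dim=1) / torch.norm(z_start, dim=1)
print(ratio.mean().item())  # roughly 0.7 for independent Gaussian latents
```

For independent standard normal vectors, the midpoint has per-coordinate variance 1/2, so its norm is about 1/sqrt(2) of an endpoint's norm.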

Has anyone implemented Slerp in PyTorch? There is a NumPy code snippet on Wikipedia for reference, but I'm new to Slerp and don't understand what I need to change to make it work in PyTorch.
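For reference, here is the standard slerp formula as given on Wikipedia. It interpolates between two vectors p_0 and p_1 at a fraction t in [0, 1], where Ω is the angle between the vectors. When sin Ω = 0 the weights are undefined, so implementations fall back to linear interpolation, which is the limit of the weights as Ω → 0. In a batch of latents, each row pair has its own Ω.

```latex
\Omega = \arccos\!\left(\frac{p_0 \cdot p_1}{\lVert p_0 \rVert \, \lVert p_1 \rVert}\right),
\qquad
\operatorname{slerp}(p_0, p_1; t) = \frac{\sin\big((1-t)\,\Omega\big)}{\sin \Omega}\, p_0 + \frac{\sin(t\,\Omega)}{\sin \Omega}\, p_1,
\qquad
\lim_{\Omega \to 0} \frac{\sin\big((1-t)\,\Omega\big)}{\sin \Omega} = 1 - t,
\quad
\lim_{\Omega \to 0} \frac{\sin(t\,\Omega)}{\sin \Omega} = t .
```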

Please help! @smth suggested, in a PyTorch issue I created earlier, that I ask here for help.

I think it should be built into PyTorch (>= 1.0). What do you think? Please vote below:

  • Implement
  • Don’t Implement


Tom White used slerp for the DCGAN interpolation and posted the NumPy code here.
The code can be translated to PyTorch pretty easily, since all the methods it uses are available.
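Here is a minimal sketch of that translation for a single pair of 1-D latent vectors. It follows the same structure as the slerp code posted later in this thread. slerp_1d is a hypothetical name, and the NumPy original itself is not reproduced here:

```python
import torch


def slerp_1d(val, low, high):
    """Slerp between two 1-D tensors low and high at a float fraction val in [0, 1]."""
    # Cosine of the angle between the unit vectors, clamped to acos's domain.
    cos_omega = torch.dot(low / torch.norm(low), high / torch.norm(high)).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)
    # so is a single scalar here, so this comparison is unambiguous.
    if so.item() == 0:
        return (1.0 - val) * low + val * high  # L'Hopital's rule / LERP
    return torch.sin((1.0 - val) * omega) / so * low + torch.sin(val * omega) / so * high
```

This will not work as written on batched (BATCH_SIZE, Z_DIM) inputs. torch.dot only accepts 1-D tensors, and both the norm and the so == 0 test are meant for a single vector, so a batch has to be handled per row.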


With the help of the code snippet @ptrblck referred to, I rewrote Slerp as best I could, although I don't really know how Slerp works.
This implementation is also wrong; I'm just trying and failing.

def slerp(start, end, val):
    a = start / torch.norm(start)
    b = end / torch.norm(end)
    omega = torch.acos(torch.clamp(torch.mm(a, b.t()), -1, 1))
    so = torch.sin(omega)
    if so == 0:
        return (1.0 - val) * start + val * end  # L'Hopital's rule / LERP
    return torch.sin((1.0 - val) * omega) / so * start + torch.sin(val * omega) / so * end

I pass the following inputs to my attempted rewrite above:

import torch

BATCH_SIZE = 64
Z_DIM = 128
z_start = torch.randn(BATCH_SIZE, Z_DIM)
z_end = torch.randn(BATCH_SIZE, Z_DIM)
z_point = slerp(z_start, z_end, 0.5)

I get the following error with the above code:

Traceback (most recent call last):
  File "main.py", line 320, in <module>
    interpolate(epoch=70000, mode='slerp', n_latents=20)
  File "main.py", line 306, in interpolate
    z_point = slerp(z_start, z_end, i.item())
  File "main.py", line 267, in slerp
    if so == 0:
RuntimeError: bool value of Tensor with more than one value is ambiguous

If I comment out the following snippet (trying to bypass the ambiguous bool comparison):

if so == 0:
    return (1.0 - val) * start + val * end  # L'Hopital's rule / LERP

I still get the following error:

Traceback (most recent call last):
  File "main.py", line 320, in <module>
    interpolate(epoch=70000, mode='slerp', n_latents=20)
  File "main.py", line 306, in interpolate
    z_point = slerp(z_start, z_end, i.item())
  File "main.py", line 269, in slerp
    return torch.sin((1.0 - val) * omega) / so * start + torch.sin(val * omega) / so * end
RuntimeError: The size of tensor a (64) must match the size of tensor b (128) at non-singleton dimension 1

Can anyone help me with this?
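This small sketch prints the intermediate shapes for the same inputs, which shows where the mismatch in the second traceback comes from. The variable names pairwise and per_row are just illustrative:

```python
import torch

BATCH_SIZE, Z_DIM = 64, 128
z_start = torch.randn(BATCH_SIZE, Z_DIM)
z_end = torch.randn(BATCH_SIZE, Z_DIM)

a = z_start / torch.norm(z_start)  # a single norm for the whole batch (0-dim tensor)
b = z_end / torch.norm(z_end)
pairwise = torch.mm(a, b.t())      # every row of a against every row of b
per_row = (a * b).sum(dim=1)       # one dot product per row pair

print(torch.norm(z_start).shape)   # torch.Size([])
print(pairwise.shape)              # torch.Size([64, 64])
print(per_row.shape)               # torch.Size([64])
```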

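There are some minor errors regarding the shapes. torch.norm(start) without a dim argument returns one norm for the whole batch rather than one per row. torch.mm(a, b.t()) computes a 64 x 64 matrix of all pairwise dot products instead of one dot product per row pair. That 64 x 64 omega is why so == 0 is ambiguous, and why it cannot broadcast against the (64, 128) latents.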
Here is a working solution for batched inputs:

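import torch

def slerp(val, low, high):
    # Normalize each row (dim=1) so the row-wise dot products below are cosines.
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    # Per-row angle; clamp keeps acos from returning NaN when rounding pushes the cosine past +/-1.
    omega = torch.acos((low_norm*high_norm).sum(1).clamp(-1, 1))
    so = torch.sin(omega)
    # unsqueeze(1) turns the (batch,) weights into (batch, 1) so they broadcast over the latent dimension.
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res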
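The batched version above can still produce NaN when a row of low and high points in the same (or exactly opposite) direction, because sin(omega) is then 0. Here is a minimal sketch of one way to handle that. It assumes the same (batch, dim) inputs and a float val; slerp_batched and eps are hypothetical names, not a PyTorch API:

```python
import torch


def slerp_batched(val, low, high, eps=1e-7):
    """Row-wise slerp between two (batch, dim) tensors at a float fraction val in [0, 1].

    Hypothetical helper: rows where |sin(omega)| < eps (parallel or anti-parallel
    rows) fall back to plain linear interpolation, mirroring the so == 0 branch
    of the single-vector code earlier in this thread.
    """
    # Unit-normalize each row so the row-wise dot product is a cosine.
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    # Clamp guards acos against values just outside [-1, 1] from rounding.
    cos_omega = (low_norm * high_norm).sum(dim=1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)  # shape (batch,)
    so = torch.sin(omega)

    # Rows where the slerp weights would divide by (almost) zero.
    degenerate = so.abs() < eps
    # Use 1 as the denominator for those rows so no inf/NaN is produced;
    # their slerp results are discarded by torch.where below anyway.
    safe_so = torch.where(degenerate, torch.ones_like(so), so)

    w_low = (torch.sin((1.0 - val) * omega) / safe_so).unsqueeze(1)
    w_high = (torch.sin(val * omega) / safe_so).unsqueeze(1)
    slerp_res = w_low * low + w_high * high
    lerp_res = (1.0 - val) * low + val * high

    # The (batch, 1) mask broadcasts across the latent dimension.
    return torch.where(degenerate.unsqueeze(1), lerp_res, slerp_res)
```

For an interpolation sequence like the interpolate loop in the traceback, you could use [slerp_batched(t.item(), z_start, z_end) for t in torch.linspace(0, 1, 20)] to get 20 evenly spaced steps. Falling back to LERP for anti-parallel rows matches the single-vector code, but it is a design choice: the great-circle path between exactly opposite vectors is not unique.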

Wow! It is working. Thank you so much for helping me out.
