# Help regarding Slerp function for generative model sampling

It looks like the `torch.lerp` function is available, but `torch.slerp` does not exist in PyTorch 1.0. I'm inspired to try slerp, as described in the Sampling Generative Networks paper, for interpolation in latent space. The recently introduced GAN 2.0 by NVIDIA also uses Slerp interpolation.
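
For context, `torch.lerp(start, end, weight)` computes `start + weight * (end - start)`. One common motivation for slerp in Gaussian latent spaces is that the norm of a `d`-dimensional `N(0, I)` sample concentrates near `sqrt(d)`, while the linear midpoint of two independent samples has a norm near `sqrt(d/2)`. The sketch below shows this; the dimension 512 and the seed are arbitrary illustrative choices.

```python
import torch

torch.manual_seed(0)
d = 512  # latent dimension, chosen arbitrarily for this illustration
z_start = torch.randn(d)
z_end = torch.randn(d)

# torch.lerp(start, end, weight) == start + weight * (end - start)
z_mid = torch.lerp(z_start, z_end, 0.5)

print(z_start.norm().item(), z_end.norm().item())  # both roughly sqrt(512), about 22.6
print(z_mid.norm().item())                          # roughly sqrt(256) = 16
```

The midpoint lands well inside the shell where most samples lie, which is what slerp avoids by moving along the arc between the endpoints instead of the chord.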

Has anyone implemented Slerp in PyTorch? There is a NumPy code snippet on Wikipedia for reference, but I'm new to Slerp, so I don't understand what changes I need to make to get it working in PyTorch.
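
For reference, the standard slerp formula can be written as follows, where $p_0, p_1$ are the two endpoints, $t \in [0, 1]$ is the interpolation fraction and $\Omega$ is the angle between the endpoints:

$$
\operatorname{slerp}(p_0, p_1; t) = \frac{\sin\big((1-t)\,\Omega\big)}{\sin \Omega}\, p_0 + \frac{\sin(t\,\Omega)}{\sin \Omega}\, p_1,
\qquad
\cos \Omega = \frac{p_0 \cdot p_1}{\lVert p_0 \rVert\, \lVert p_1 \rVert}.
$$

As $\Omega \to 0$ the two weights tend to $1-t$ and $t$, so slerp reduces to ordinary linear interpolation; this is the "L'Hopital's rule / LERP" fallback used when $\sin \Omega = 0$.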

I think it should be implemented directly in PyTorch >= 1.0. What do you think? Please vote below:

- Implement
- Don't Implement

0 voters

Tom White used slerp for the DCGAN interpolation and posted the NumPy code.
It can be translated to PyTorch pretty easily, since all the methods it uses are available.
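
A minimal sketch of such a line-by-line translation for a single pair of 1-D latent vectors, assuming the NumPy version computes the angle from the normalized vectors and falls back to linear interpolation when `sin(omega)` is zero. The name `slerp_1d` is hypothetical.

```python
import torch

def slerp_1d(val, low, high):
    """Spherical interpolation between two 1-D tensors `low` and `high` at fraction `val`."""
    # Cosine of the angle between the normalized vectors, clamped to acos's domain
    dot = torch.dot(low / torch.norm(low), high / torch.norm(high))
    omega = torch.acos(torch.clamp(dot, -1.0, 1.0))
    so = torch.sin(omega)
    if so == 0:  # fine here: `so` is a one-element tensor
        # Parallel vectors: slerp degenerates to linear interpolation (L'Hopital / LERP)
        return (1.0 - val) * low + val * high
    return torch.sin((1.0 - val) * omega) / so * low + torch.sin(val * omega) / so * high


low = torch.tensor([1.0, 0.0])
high = torch.tensor([0.0, 1.0])
print(slerp_1d(0.5, low, high))  # roughly [0.7071, 0.7071]
```

This only works for one vector at a time: `torch.dot` requires 1-D inputs and `if so == 0` needs a single-element tensor, so a batched `[64, 128]` input needs per-row operations instead.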

1 Like

With the help of the code snippet @ptrblck referred to, I rewrote Slerp as best I could. I just don't know how Slerp works!
This implementation of mine is also wrong; I'm just trying and failing.

``````
import torch

def slerp(start, end, val):
    a = start / torch.norm(start)
    b = end / torch.norm(end)
    omega = torch.acos(torch.clamp(torch.mm(a, b.t()), -1, 1))
    so = torch.sin(omega)
    if so == 0:
        return (1.0 - val) * start + val * end  # L'Hopital's rule / LERP
    return torch.sin((1.0 - val) * omega) / so * start + torch.sin(val * omega) / so * end
``````

I pass the following inputs to the `slerp` I tried to rewrite above:

``````
import torch

BATCH_SIZE = 64
Z_DIM = 128
z_start = torch.randn(BATCH_SIZE, Z_DIM)
z_end = torch.randn(BATCH_SIZE, Z_DIM)
z_point = slerp(z_start, z_end, 0.5)
``````

I get the following error with the above code:

``````
Traceback (most recent call last):
  File "main.py", line 320, in <module>
    interpolate(epoch=70000, mode='slerp', n_latents=20)
  File "main.py", line 306, in interpolate
    z_point = slerp(z_start, z_end, i.item())
  File "main.py", line 267, in slerp
    if so == 0:
RuntimeError: bool value of Tensor with more than one value is ambiguous
``````
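
The error comes from using a multi-element tensor in an `if`. `so == 0` returns a boolean tensor with one entry per element, and Python cannot reduce several values to a single `True`/`False`. The sketch below reproduces this and shows two common workarounds: reducing explicitly with `.any()`/`.all()`, or choosing per element with `torch.where`. The tensor values are arbitrary.

```python
import torch

so = torch.tensor([0.5, 0.0, 0.8])  # e.g. one sin(omega) per latent vector

try:
    if so == 0:  # element-wise comparison -> tensor of 3 bools
        pass
except RuntimeError as err:
    print(err)  # the "... value of Tensor with more than one value is ambiguous" error

# Reduce to a single bool explicitly ...
print(bool((so == 0).any()))  # True

# ... or pick a value per element instead of branching
safe_so = torch.where(so == 0, torch.ones_like(so), so)
print(safe_so)  # tensor([0.5000, 1.0000, 0.8000])
```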

If I comment out the following snippet (trying to bypass the `bool` comparison ambiguity):

``````
if so == 0:
    return (1.0 - val) * start + val * end  # L'Hopital's rule / LERP
``````

I still get an error, this time the following one:

``````
Traceback (most recent call last):
  File "main.py", line 320, in <module>
    interpolate(epoch=70000, mode='slerp', n_latents=20)
  File "main.py", line 306, in interpolate
    z_point = slerp(z_start, z_end, i.item())
  File "main.py", line 269, in slerp
    return torch.sin((1.0 - val) * omega) / so * start + torch.sin(val * omega) / so * end
RuntimeError: The size of tensor a (64) must match the size of tensor b (128) at non-singleton dimension 1
``````
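
This second error is a shape problem. `torch.mm(a, b.t())` multiplies a `[64, 128]` matrix by a `[128, 64]` matrix, so `omega` and `so` become `[64, 64]` (every start vector against every end vector). That result cannot be broadcast against the `[64, 128]` latents. A quick sketch contrasting it with a per-row dot product, which gives the single angle per pair that slerp needs:

```python
import torch

torch.manual_seed(0)
BATCH_SIZE, Z_DIM = 64, 128
a = torch.randn(BATCH_SIZE, Z_DIM)
b = torch.randn(BATCH_SIZE, Z_DIM)

pairwise = torch.mm(a, b.t())  # all 64 x 64 dot products
rowwise = (a * b).sum(dim=1)   # only the dot product of a[i] with b[i]

print(pairwise.shape)  # torch.Size([64, 64])
print(rowwise.shape)   # torch.Size([64])

# The row-wise values are the diagonal of the pairwise matrix
print(torch.allclose(rowwise, pairwise.diagonal(), atol=1e-3))  # True
```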

Is there anyone who can help me with this?

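There are some minor errors regarding the shapes: the norms and the dot products have to be computed per row (per latent vector), not over the whole batch.
Here is a working solution for batched inputs: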

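``````
import torch

def slerp(val, low, high):
    # Normalize each row (latent vector) separately
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    # One angle per row; clamp guards acos against rounding just outside [-1, 1]
    omega = torch.acos((low_norm * high_norm).sum(1).clamp(-1, 1))
    so = torch.sin(omega)
    # unsqueeze(1) turns the [batch] weights into [batch, 1] so they broadcast over the latent dim
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
``````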
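
The batched version divides by `sin(omega)`, so a pair of (nearly) parallel latent vectors would produce `nan` or `inf`. Below is a minimal sketch of one way to keep the LERP fallback from the single-vector snippet in batched form, using `torch.where` per row. The name `slerp_batched` and the `eps` threshold are hypothetical choices, not from this thread. A small interpolation walk follows the function.

```python
import torch

def slerp_batched(val, low, high, eps=1e-6):
    """Row-wise slerp between `low` and `high` (both [batch, dim]) at fraction `val`."""
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(dim=1).clamp(-1.0, 1.0)
    omega = torch.acos(dot)  # [batch]
    so = torch.sin(omega)    # [batch]
    # (Anti)parallel rows have sin(omega) ~ 0: use a safe denominator there ...
    parallel = so.abs() < eps
    safe_so = torch.where(parallel, torch.ones_like(so), so)
    w_low = torch.sin((1.0 - val) * omega) / safe_so
    w_high = torch.sin(val * omega) / safe_so
    # ... and fall back to linear weights (L'Hopital / LERP) for those rows
    w_low = torch.where(parallel, torch.full_like(w_low, 1.0 - val), w_low)
    w_high = torch.where(parallel, torch.full_like(w_high, val), w_high)
    # [batch] -> [batch, 1] so the weights broadcast over the latent dimension
    return w_low.unsqueeze(1) * low + w_high.unsqueeze(1) * high


# Usage: walk from z_start to z_end in 5 steps
torch.manual_seed(0)
z_start = torch.randn(64, 128)
z_end = torch.randn(64, 128)
steps = [slerp_batched(float(t), z_start, z_end) for t in torch.linspace(0, 1, 5)]
print(steps[2].shape)  # torch.Size([64, 128])
```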

Wow! It is working. Thank you so much for helping me out.