With the help of the code snippet you referred me to, @ptrblck, I rewrote slerp as well as I could. I just don't really understand how slerp works!
My implementation is also wrong; I keep trying and failing.
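For reference, this is the usual textbook form of slerp between two vectors $p_0$ and $p_1$ for $t \in [0, 1]$. $\Omega$ is the angle between the two vectors, so each pair of vectors gets its own single $\Omega$:

$$
\cos\Omega = \frac{p_0 \cdot p_1}{\lVert p_0\rVert\,\lVert p_1\rVert},\qquad
\operatorname{slerp}(p_0, p_1; t) = \frac{\sin\big((1-t)\,\Omega\big)}{\sin\Omega}\,p_0 + \frac{\sin(t\,\Omega)}{\sin\Omega}\,p_1
$$

As $\Omega \to 0$ the two weights tend to $1-t$ and $t$, which is why the snippet falls back to plain LERP when $\sin\Omega = 0$. For a batch of latents, this means one $\Omega$ per row pair, not one matrix for the whole batch.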
```python
def slerp(start, end, val):
    a = start / torch.norm(start)
    b = end / torch.norm(end)
    omega = torch.acos(torch.clamp(torch.mm(a, b.t()), -1, 1))
    so = torch.sin(omega)
    if so == 0:
        return (1.0 - val) * start + val * end  # L'Hopital's rule / LERP
    return torch.sin((1.0 - val) * omega) / so * start + torch.sin(val * omega) / so * end
```
I pass the following inputs to the slerp I tried to rewrite above:
```python
BATCH_SIZE = 64
Z_DIM = 128
z_start = torch.randn(BATCH_SIZE, Z_DIM)
z_end = torch.randn(BATCH_SIZE, Z_DIM)
z_point = slerp(z_start, z_end, 0.5)
```
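A quick shape check with these inputs shows where things go off track. This is a diagnostic sketch, not a fix: `torch.norm` without `dim` collapses the whole batch to one scalar, and `torch.mm(a, b.t())` compares every row of `a` with every row of `b`, while a row-wise dot product gives one cosine per pair. The names `global_norm`, `pairwise` and `cos_per_row` are just illustrative.

```python
import torch

BATCH_SIZE = 64
Z_DIM = 128
z_start = torch.randn(BATCH_SIZE, Z_DIM)
z_end = torch.randn(BATCH_SIZE, Z_DIM)

# torch.norm without a dim reduces over the WHOLE batch -> a single scalar
global_norm = torch.norm(z_start)
print(global_norm.shape)  # torch.Size([])

a = z_start / torch.norm(z_start)
b = z_end / torch.norm(z_end)

# torch.mm(a, b.t()) compares every row of a with every row of b -> (64, 64)
pairwise = torch.mm(a, b.t())
print(pairwise.shape)  # torch.Size([64, 64])

# Per-row normalization + per-row dot product -> one cosine per pair, (64, 1)
a_rows = z_start / z_start.norm(dim=1, keepdim=True)
b_rows = z_end / z_end.norm(dim=1, keepdim=True)
cos_per_row = (a_rows * b_rows).sum(dim=1, keepdim=True)
print(cos_per_row.shape)  # torch.Size([64, 1])
```

The `(64, 1)` shape matters because it broadcasts cleanly against the `(64, 128)` latents, while `(64, 64)` does not.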
I get the following error with the code above:
```
Traceback (most recent call last):
  File "main.py", line 320, in <module>
    interpolate(epoch=70000, mode='slerp', n_latents=20)
  File "main.py", line 306, in interpolate
    z_point = slerp(z_start, z_end, i.item())
  File "main.py", line 267, in slerp
    if so == 0:
RuntimeError: bool value of Tensor with more than one value is ambiguous
```
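`if so == 0:` needs a single Python bool, but with the inputs above `so` holds 64 × 64 values, so PyTorch refuses to pick one. A minimal sketch of the element-wise alternative, assuming the LERP fallback should apply only to the rows where sin(Ω) is (near) zero, uses `torch.where` instead of a Python `if`. The toy `so` values and the `1e-6` threshold are illustrative assumptions.

```python
import torch

so = torch.tensor([[0.5], [0.0], [0.9]])  # toy per-row sin(omega) values

# bool() on a multi-element tensor raises the same RuntimeError as `if so == 0:`
try:
    bool(so == 0)
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True

# Element-wise mask instead of a single Python bool
is_degenerate = so.abs() < 1e-6
# Replace zeros by 1 so a later division is safe; masked rows get LERP weights instead
safe_so = torch.where(is_degenerate, torch.ones_like(so), so)
print(is_degenerate.squeeze(1).tolist())  # [False, True, False]

# If you only need one yes/no answer for the whole batch, reduce first
any_degenerate = bool(is_degenerate.any())
print(any_degenerate)  # True
```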
If I comment out the following snippet (trying to bypass the ambiguous bool comparison):
```python
if so == 0:
    return (1.0 - val) * start + val * end  # L'Hopital's rule / LERP
```
I get a different error instead:
```
Traceback (most recent call last):
  File "main.py", line 320, in <module>
    interpolate(epoch=70000, mode='slerp', n_latents=20)
  File "main.py", line 306, in interpolate
    z_point = slerp(z_start, z_end, i.item())
  File "main.py", line 269, in slerp
    return torch.sin((1.0 - val) * omega) / so * start + torch.sin(val * omega) / so * end
RuntimeError: The size of tensor a (64) must match the size of tensor b (128) at non-singleton dimension 1
```
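This second error comes from `omega` and `so` being `(64, 64)` (because of `torch.mm`), which cannot be multiplied with the `(64, 128)` latents. A minimal sketch of a row-wise slerp, assuming each row of `z_start` should be interpolated with the same row of `z_end`: normalize per row, take one dot product per row (shape `(64, 1)`, which broadcasts), and replace the Python `if` with `torch.where`. The `eps` parameter and the LERP fallback for (anti)parallel rows are choices made for this sketch, not part of the original snippet.

```python
import torch


def slerp(start: torch.Tensor, end: torch.Tensor, val: float, eps: float = 1e-6) -> torch.Tensor:
    """Row-wise spherical linear interpolation between two (N, D) batches.

    Assumption: row i of `start` is paired with row i of `end`.
    """
    # Normalize each row separately (not the whole batch at once)
    a = start / start.norm(dim=1, keepdim=True)
    b = end / end.norm(dim=1, keepdim=True)
    # One cosine per row pair -> shape (N, 1), which broadcasts against (N, D)
    cos_omega = (a * b).sum(dim=1, keepdim=True).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)
    # Rows where sin(omega) ~ 0 (parallel or antiparallel) fall back to LERP
    degenerate = so.abs() < eps
    safe_so = torch.where(degenerate, torch.ones_like(so), so)
    w_start = torch.sin((1.0 - val) * omega) / safe_so
    w_end = torch.sin(val * omega) / safe_so
    w_start = torch.where(degenerate, torch.full_like(so, 1.0 - val), w_start)
    w_end = torch.where(degenerate, torch.full_like(so, val), w_end)
    return w_start * start + w_end * end


torch.manual_seed(0)
BATCH_SIZE = 64
Z_DIM = 128
z_start = torch.randn(BATCH_SIZE, Z_DIM)
z_end = torch.randn(BATCH_SIZE, Z_DIM)
z_point = slerp(z_start, z_end, 0.5)
print(z_point.shape)  # torch.Size([64, 128])
```

Like the original snippet, this sketch computes the angle from the normalized vectors but interpolates the unnormalized latents, so the endpoints `val=0.0` and `val=1.0` return `z_start` and `z_end` unchanged.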
Is there anyone who can help me with this?