# Differentiable slicing operation with float index

I have a 1D tensor with audio data. I would like to train my neural network to produce an index in the form of a float, which I can then use to slice a window of 512 samples (starting at that index) from the audio data to calculate a loss. Unfortunately, I have not yet found a way to do this without converting the float index into an int, which prevents gradient calculation.

Hi sndjn!

You probably won’t be able to make this work. The basic problem is that
such an index is discrete, and, as such, is not (usefully) differentiable.

Pytorch knows that the operation of converting a float into an int is not
(usefully) differentiable and does not permit `requires_grad = True` to
be attached to an int (or a `LongTensor`, which is the specific type of tensor
you would use as an index).
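
As a quick illustration (the exact error text may vary by pytorch version):

```python
import torch

idx = torch.tensor([7], dtype=torch.long, requires_grad=True)
# RuntimeError: Only Tensors of floating point and complex dtype can require gradients
```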

Here’s the problem: Let’s say that your float index is `7.41`. Perhaps you
round it to `7` to get an int index. As your float varies from `7.0` to `7.5`,
the rounded value, `7`, remains constant, so the derivative is zero. Then
at `7.5` the rounded value pops up to `8`. (Right at this point, the rounding
function is technically not differentiable.) Then from `7.5` to `8.0`, the
rounded value, `8`, is constant again, and the derivative is again zero.

While this function is differentiable “almost everywhere”, with derivative
zero, a derivative of zero isn’t useful for training with a gradient-descent
optimizer, because it doesn’t tell the optimizer in which direction it should
move the model parameters. This is what I mean by “not usefully
differentiable” (and why pytorch doesn’t support `requires_grad` for
integer tensors).
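
You can see the zero gradient directly (pytorch simply defines the gradient
of `round()` to be zero):

```python
import torch

x = torch.tensor(7.41, requires_grad=True)
y = torch.round(x)   # 7.0, and stays 7.0 as x varies within [7.0, 7.5)
y.backward()
print(x.grad)        # tensor(0.) -- no information about which way to move x
```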

However (and this might not make sense or work for your use case),
you could possibly use a float “index” to interpolate into your audio data.

The idea would be that for an index of `7.0`, you would use your audio
values given by `data[7], data[8], data[9], ...`. But for an “index”
of `7.41`, you would use `di (7.41), di (8.41), di (9.41), ...`,
where, for example, `di (7.41)` is the value obtained by interpolating
between `data[7]` and `data[8]`.
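
For concreteness, here is a minimal sketch of such a differentiable lookup
using linear interpolation (the tensor values are just placeholders):

```python
import torch

data = torch.randn(16)                   # stand-in for the audio tensor
i = torch.tensor(7.41, requires_grad=True)
lo = i.detach().floor().long()           # integer part (7), treated as a constant
frac = i - lo                            # fractional part (0.41) carries the gradient
di = (1 - frac) * data[lo] + frac * data[lo + 1]   # interpolate between data[7] and data[8]
di.backward()
print(i.grad)                            # data[8] - data[7]: a useful, nonzero gradient
```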

Provided that your interpolation scheme (for example, linear interpolation)
is (usefully) differentiable (and the subsequent processing that turns your
interpolated data into a loss is also differentiable), you will be able to
backpropagate through the whole pipeline and train your float “index”
with gradient descent.
(This isn’t the only possible scheme: You could, for example, just use
interpolation for the first and last data points in the window. Or you
could interpolate between (use a weighted average of) the loss value
you get from starting your window at `7` and the one you get from starting
your window at `8`. And so on …)
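
A rough sketch of that last scheme might look like this (`loss_for_window`,
`data`, and `target` are made-up placeholders, not anything from your code):

```python
import torch

data = torch.randn(1024)                 # stand-in audio
target = torch.randn(512)                # stand-in comparison window

def loss_for_window(start: int) -> torch.Tensor:
    # loss for a window starting at an integer index (L1 distance, just for illustration)
    return (data[start:start + 512] - target).abs().sum()

i = torch.tensor(7.41, requires_grad=True)
lo = i.detach().floor().long()
w = i - lo                               # differentiable weight between the two windows
loss = (1 - w) * loss_for_window(int(lo)) + w * loss_for_window(int(lo) + 1)
loss.backward()
print(i.grad)                            # difference of the two windows' losses
```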

Good luck!

K. Frank


Thank you for your answer! I have found a solution that will probably work for my application. I have programmed a simplified example with only one float weight to optimize:

For those who only want the solution:

```python
import torch
from tslearn.metrics import SoftDTWLossPyTorch

@torch.jit.script
def torch_interp(indices, xp, fp):
    # differentiable linear interpolation of fp (sampled at xp) at float positions `indices`
    pos = torch.searchsorted(xp, indices, right=True)
    pos = pos.clamp(min=1, max=xp.size(0) - 1)
    left = xp[pos - 1]
    right = xp[pos]
    weight = (indices - left) / (right - left)

    return (1 - weight) * fp[pos - 1] + weight * fp[pos]

test = torch.rand(10000)        # dummy audio data (setup not shown in my original snippet)
target_length = test.size(0)

float_weight = torch.tensor([(torch.rand(()) - 0.5) * 20.0], requires_grad=True)  # initialized with a random float
target_float = 0.1  # to slice from the test tensor for loss calculation and to compare with the optimized float_index
window_size = 512   # window_size for comparison
start_i = int(target_float * target_length)

loss_fn = SoftDTWLossPyTorch(gamma=0.1, normalize=True)
optimizer = torch.optim.Adam([float_weight], lr=0.01)  # assumed; my original snippet did not show the optimizer

for epoch in range(1000):
    optimizer.zero_grad()

    float_index = torch.sigmoid(float_weight)   # squash the raw weight into (0, 1)

    indices = torch.arange(0, window_size) + (float_index * target_length)
    xp = torch.arange(test.size(0), dtype=torch.float32)
    x = torch_interp(indices, xp, test)

    # loss calculation
    loss = loss_fn(x.unsqueeze(0).unsqueeze(-1), test[start_i:start_i + window_size].unsqueeze(0).unsqueeze(-1))
    # Much faster loss function that also works, but not as well:
    # loss = (x - test[start_i:start_i + window_size]).abs().sum()

    # regularization term; the randomness helps prevent local optima here,
    # but is probably unnecessary or even harmful for real networks with different inputs
    loss = loss + loss * (float_weight - 0.5).abs() + loss * torch.rand(()) * 6

    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}, Float Index: {float_index.item()}')
    if (target_float - float_index).abs() < 0.001:
        print('Done')
        break
```