Differentiable slicing operation with float index

I have a 1D tensor with audio data. I would like to train my neural network to output an index, in the form of a float, that I can use to slice a window running from that index to index + 512 out of the audio data in order to calculate a loss. Unfortunately, I have not yet found a way to do this without converting the float index into an int, which prevents gradient calculation.

Hi sndjn!

You probably won’t be able to make this work. The basic problem is that
such an index is discrete, and, as such, is not (usefully) differentiable.

Pytorch knows that the operation of converting a float into an int is not
(usefully) differentiable and does not permit requires_grad = True to
be attached to an int (or a LongTensor, which is the specific type of tensor
you would use as an index).
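This is easy to see in a few lines. The sketch below assumes a reasonably recent PyTorch release (the exact error text may vary between versions). Casting a float tensor to `long` gives a result that no longer tracks gradients, and asking for gradients on an integer tensor directly is rejected.

```python
import torch

# A float tensor can track gradients ...
idx_float = torch.tensor(7.41, requires_grad=True)

# ... but converting it to an integer dtype detaches the result from autograd.
idx_long = idx_float.long()
print(idx_long, idx_long.requires_grad)  # tensor(7) False

# Requesting gradients on an integer tensor directly raises an error.
try:
    torch.tensor(7, dtype=torch.long, requires_grad=True)
    int_grad_allowed = True
except RuntimeError as err:
    int_grad_allowed = False
    print("RuntimeError:", err)
```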

Here’s the problem: Let’s say that your float index is 7.41. Perhaps you
round it to 7 to get an int index. As your float varies from 7.0 to 7.5,
the rounded value, 7, remains constant, so the derivative is zero. Then
at 7.5 the rounded value pops up to 8. (Right at this point, the rounding
function is technically not differentiable.) Then from 7.5 to 8.0, the
rounded value, 8, is constant again, and the derivative is again zero.
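PyTorch's autograd matches this piecewise-constant picture. It defines the derivative of `torch.round` as zero, so the gradient is zero on both sides of the jump at 7.5. Here is a small check, using arbitrarily chosen sample points:

```python
import torch

# Evaluate d(round(x))/dx at a few points on both sides of 7.5.
grads = []
for value in (7.2, 7.41, 7.7, 7.95):
    x = torch.tensor(value, requires_grad=True)
    torch.round(x).backward()
    grads.append(x.grad.item())
print(grads)  # [0.0, 0.0, 0.0, 0.0]
```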

While this function is differentiable “almost everywhere”, with derivative
zero, a derivative of zero isn’t useful for training with a gradient-descent
optimizer, because it doesn’t tell the optimizer in what direction it should
move the model parameters. This is what I mean by “not usefully
differentiable” (and why pytorch doesn’t support requires_grad for
integer tensors).
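The consequence for training can be sketched with an optimizer. If the loss depends on the index only through its rounded value, the gradient is zero, and a gradient-descent step leaves the parameter exactly where it was, even though the loss is far from its minimum. This is a minimal example with plain SGD. The target value of 3.0 is made up for illustration.

```python
import torch

# A learnable "index" and a loss that only sees its rounded value.
idx = torch.tensor(7.41, requires_grad=True)
optimizer = torch.optim.SGD([idx], lr=0.1)
target = 3.0  # hypothetical index the loss would like to reach

before = idx.detach().clone()
loss = (torch.round(idx) - target) ** 2
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(loss.item(), idx.grad.item())  # 16.0 0.0
print(torch.equal(idx.detach(), before))  # True: the step did not move idx
```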

However (and this might not make sense or work for your use case),
you could possibly use a float “index” to interpolate into your audio data.

The idea would be that for an index of 7.0, you would use your audio
values given by data[7], data[8], data[9], .... But for an “index”
of 7.41, you would use di(7.41), di(8.41), di(9.41), ...,
where, for example, di(7.41) is the value obtained by interpolating
between data[7] and data[8].
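One way to write `di` is plain linear interpolation between the two neighbouring samples. The sketch below evaluates the whole window of fractional positions at once. The helper name `interp_window` is hypothetical, and the sketch assumes the window stays inside the data (`0 <= start` and `start + window_size < len(data)`).

```python
import torch

def interp_window(data: torch.Tensor, start: torch.Tensor, window_size: int) -> torch.Tensor:
    """Linearly interpolate `data` at start, start + 1, ..., start + window_size - 1.

    `start` is a float tensor (it may require grad); `data` is a 1D tensor.
    Assumes 0 <= start and start + window_size < len(data).
    """
    positions = start + torch.arange(window_size, dtype=data.dtype)
    left = torch.floor(positions).long()  # integer part, only used for indexing
    frac = positions - left               # fractional part, differentiable w.r.t. start
    return (1 - frac) * data[left] + frac * data[left + 1]

data = torch.tensor([0.0, 10.0, 20.0, 30.0, 40.0])
print(interp_window(data, torch.tensor(1.25), 3))  # tensor([12.5000, 22.5000, 32.5000])
```

The integer part only selects which two samples to blend. The gradient reaches `start` through the fractional weight `frac`.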

Provided that your interpolation scheme (for example, linear interpolation)
is (usefully) differentiable (and the subsequent processing that turns your
interpolated data into a loss is also differentiable), you will be able to
backpropagate through your float “index” and train your model.
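To confirm that gradients really do flow back to a float start position, you can compare autograd against finite differences with `torch.autograd.gradcheck`. The sketch below redefines the same linear-interpolation idea so it runs on its own. The sine "audio", the window size of 64, the target start of 300 and the mean-squared-error loss are all illustrative assumptions. `gradcheck` needs double-precision inputs, and the start is kept away from integer values because the interpolated loss has kinks there.

```python
import torch

def interp_window(data, start, window_size):
    # Linear interpolation of `data` at start, start + 1, ..., start + window_size - 1.
    positions = start + torch.arange(window_size, dtype=data.dtype)
    left = torch.floor(positions).long()
    frac = positions - left
    return (1 - frac) * data[left] + frac * data[left + 1]

# Illustrative "audio": a sine wave, with a target window starting at sample 300.
t = torch.arange(1000, dtype=torch.float64)
data = torch.sin(0.05 * t)
window_size = 64
target = data[300:300 + window_size]

def loss_fn(start):
    window = interp_window(data, start, window_size)
    return ((window - target) ** 2).mean()

start = torch.tensor(290.3, dtype=torch.float64, requires_grad=True)
loss_fn(start).backward()
print(start.grad)  # negative: gradient descent would move the start up toward 300

# Compare the autograd gradient with a finite-difference estimate.
ok = torch.autograd.gradcheck(loss_fn, (start.detach().clone().requires_grad_(),))
print(ok)  # True
```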

(This isn’t the only possible scheme: You could, for example, just use
interpolation for the first and last data points in the window. Or you
could interpolate between (use a weighted average of) the loss value
you get from starting your window at 7 and the one you get from starting
your window at 8. And so on …)
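The last alternative above interpolates between the losses of the two neighbouring integer start positions. A sketch of it is below. `blended_loss` is a hypothetical helper, mean-squared error stands in for whatever loss is actually used, and the window is assumed to fit inside the data.

```python
import torch

def blended_loss(data, target, start):
    """Blend the losses of the windows starting at floor(start) and floor(start) + 1."""
    window_size = target.numel()
    lo = int(torch.floor(start.detach()))  # integer start, no gradient needed
    frac = start - lo                      # differentiable weight in [0, 1)
    loss_lo = ((data[lo:lo + window_size] - target) ** 2).mean()
    loss_hi = ((data[lo + 1:lo + 1 + window_size] - target) ** 2).mean()
    return (1 - frac) * loss_lo + frac * loss_hi

data = torch.arange(10, dtype=torch.float32)
target = data[3:6]  # the window we would like to find: [3., 4., 5.]
start = torch.tensor(2.25, requires_grad=True)

loss = blended_loss(data, target, start)
loss.backward()
print(loss.item(), start.grad.item())  # 0.75 -1.0
```

Here the gradient with respect to `start` is just `loss_hi - loss_lo`. Only the two neighbouring windows affect the update, whereas interpolating the data itself lets every sample in the window contribute.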

Good luck!

K. Frank


Thank you for your answer! I have found a solution that will probably work for my application. I have written a simplified example with only a single float weight to optimize:

For those who only want the solution:

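import torch
import torch.nn.functional as F
from tslearn.metrics import SoftDTWLossPyTorch  # assumption: the Soft-DTW loss used below comes from tslearn

def torch_interp(indices, xp, fp):
    # Differentiable 1D linear interpolation, similar to np.interp, except that
    # values outside [xp[0], xp[-1]] are extrapolated linearly instead of clamped.
    # xp must be sorted in increasing order.
    pos = torch.searchsorted(xp, indices, right=True)  # index of first xp element > each index
    pos = pos.clamp(min=1, max=xp.size(0) - 1)         # keep both neighbours in range
    left = xp[pos - 1]
    right = xp[pos]
    weight = (indices - left) / (right - left)         # gradient flows back through `indices`

    return (1 - weight) * fp[pos - 1] + weight * fp[pos]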
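window_size = 512   # window size for comparison
# `signal` (the 1D audio tensor) and `target_length` are defined elsewhere in the
# original code; target_length is presumably the length of the unpadded signal.
test = F.pad(signal, (0, window_size))  # zero-pad so a window near the end stays in range
# Learnable raw parameter, initialized with a random float in [-10, 10)
float_weight = ((torch.rand(1) - 0.5) * 20.0).requires_grad_()
target_float = 0.1  # to slice from the test tensor for loss calculation and to compare with the optimized float_index
optimizer = torch.optim.AdamW([float_weight], lr=0.1)
start_i = int(target_float * target_length)

# Soft-DTW loss (third-party, not part of PyTorch itself)
loss_fn = SoftDTWLossPyTorch(gamma=0.1, normalize=True)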

for epoch in range(1000):
    float_index = torch.sigmoid(float_weight)

    indices = torch.arange(0, window_size) + (float_index*target_length)
    xp = torch.arange(test.size(0), dtype=torch.float32)
    x = torch_interp(indices, xp, test)
    # Loss-Calculation
    loss = loss_fn(x.unsqueeze(0).unsqueeze(-1),test[start_i:start_i+window_size].unsqueeze(0).unsqueeze(-1)) 
    # Much faster loss function that also works, but not as well: (x-test[start_i:start_i+window_size]).abs().sum()

    # First extra term: regularization. Second extra term: random scaling of the loss to
    # escape local optima (probably unnecessary or even harmful for real networks with different inputs).
    loss = loss + loss * (float_weight - 0.5).abs() + loss * torch.rand(()) * 6
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}, Float Index: {float_index.item()}')
    if (target_float-float_index).abs()<0.001: