I have a 1D tensor with audio data. I would like to train my neural network to produce an index in the form of a float, which I can then use to slice a window of 512 samples, starting at that index, from the audio data in order to calculate a loss. Unfortunately, I have not yet found a way to do this without converting the float index into an int, which prevents gradient calculation.

Hi sndjn!

You probably won't be able to make this work. The basic problem is that such an index is discrete and, as such, is not (usefully) differentiable. PyTorch knows that the operation of converting a float into an int is not (usefully) differentiable and does not permit `requires_grad = True` to be attached to an int (or a `LongTensor`, which is the specific type of tensor you would use as an index).

Here's the problem: Let's say that your float index is `7.41`. Perhaps you round it to `7` to get an int index. As your float varies from `7.0` to `7.5`, the rounded value, `7`, remains constant, so the derivative is zero. Then at `7.5` the rounded value pops up to `8`. (Right at this point, the rounding function is technically not differentiable.) Then from `7.5` to `8.0`, the rounded value, `8`, is constant again, and the derivative is again zero.

While this function is differentiable "almost everywhere", with derivative zero, a derivative of zero isn't useful for training with a gradient-descent optimizer, because it doesn't tell the optimizer in what direction it should move the model parameters. This is what I mean by "not usefully differentiable" (and why PyTorch doesn't support `requires_grad` for integer tensors).
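The two points above can be checked directly. The following is only a small sketch (the variable names and the factor `3.0` are made up for illustration): it backpropagates through `torch.round` to show the zero gradient, and then tries to request gradients on an integer tensor.

```python
import torch

# A float "index" that we would like to optimize.
x = torch.tensor(7.41, requires_grad=True)

# Rounding is piecewise constant, so autograd reports a zero gradient.
y = 3.0 * torch.round(x)
y.backward()
print(x.grad)

# Integer tensors cannot carry gradients at all.
int_error = None
try:
    torch.tensor([7], requires_grad=True)
except RuntimeError as err:
    int_error = err
print(int_error)
```

The backward pass runs without complaint, but the gradient it returns gives the optimizer no direction to move in.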

However (and this might not make sense or work for your use case), you could possibly use a float "index" to interpolate into your audio data. The idea would be that for an index of `7.0`, you would use your audio values given by `data[7], data[8], data[9], ...`. But for an "index" of `7.41`, you would use `di (7.41), di (8.41), di (9.41), ...`, where, for example, `di (7.41)` is the value obtained by interpolating between `data[7]` and `data[8]`.

Provided that your interpolation scheme (for example, linear interpolation) is (usefully) differentiable (and the subsequent processing that turns your interpolated data into a loss is also differentiable), you will be able to backpropagate through your float "index" and train your model.
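As a minimal sketch of the interpolation idea, assuming linear interpolation and a window that stays inside the data: the helper name `interp_window` and the toy `data` tensor are hypothetical and only serve the illustration. The integer part of the index selects the samples, and the fractional part, which still depends on the float index, carries the gradient.

```python
import torch

def interp_window(data, start, window_size):
    """Linearly interpolated window data[start : start + window_size] for a float start."""
    base = torch.floor(start).detach()      # integer part, treated as a constant
    frac = start - base                     # fractional part, differentiable w.r.t. start
    idx = base.long() + torch.arange(window_size)
    # di(start + k) = (1 - frac) * data[idx_k] + frac * data[idx_k + 1]
    return (1 - frac) * data[idx] + frac * data[idx + 1]

data = 2.0 * torch.arange(20.0)             # toy "audio": data[i] = 2 * i
start = torch.tensor(7.41, requires_grad=True)

window = interp_window(data, start, window_size=4)
window.sum().backward()
print(window)       # values between data[7..10] and data[8..11]
print(start.grad)   # non-zero, so an optimizer can move the index
```

Note that this sketch does no bounds checking: `idx + 1` has to remain a valid index, so the data would need padding (or the index clamping) near the end.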

(This isn't the only possible scheme: You could, for example, just use interpolation for the first and last data points in the window. Or you could interpolate between (use a weighted average of) the loss value you get from starting your window at `7` and the one you get from starting your window at `8`. And so on ...)
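The loss-interpolation variant could look roughly like the sketch below. The function name `interp_loss`, the toy data, and the choice of a summed absolute difference as the loss are all assumptions made for illustration; any loss computed on the two integer-aligned windows could be blended the same way.

```python
import torch

def interp_loss(data, start, target, window_size):
    """Blend the losses of the two integer-aligned windows around a float start."""
    base = int(start.detach().floor().item())   # e.g. 7 for start = 7.41
    frac = start - base                         # e.g. 0.41, differentiable w.r.t. start
    loss_lo = (data[base:base + window_size] - target).abs().sum()
    loss_hi = (data[base + 1:base + 1 + window_size] - target).abs().sum()
    return (1 - frac) * loss_lo + frac * loss_hi

data = torch.arange(20.0)                       # toy "audio"
target = data[8:12]                             # the window we want to find
start = torch.tensor(7.41, requires_grad=True)

loss = interp_loss(data, start, target, window_size=4)
loss.backward()
print(loss)         # weighted mix of the loss at 7 and the loss at 8
print(start.grad)   # negative here: increasing the index lowers the loss
```

In this toy case the window at `8` matches the target exactly, so the gradient points toward larger index values, as one would hope.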

Good luck!

K. Frank

Thank you for your answer, I have found a solution that will probably work for my application. I have programmed a simplified example with only one float weight for optimization:

For those who only want the solution:

```
import torch
import torch.nn.functional as F

# Not shown in the original post and assumed to be defined elsewhere:
#   signal             - the 1D float tensor of audio samples
#   target_length      - scale mapping the (0, 1) index to a sample position (presumably signal.size(0))
#   SoftDTWLossPyTorch - a soft-DTW loss from a third-party package, not part of PyTorch

@torch.jit.script
def torch_interp(indices, xp, fp):
    # Linear interpolation of the points (xp, fp) at the float positions in `indices`;
    # differentiable with respect to `indices` through `weight`.
    pos = torch.searchsorted(xp, indices, right=True)
    pos = pos.clamp(min=1, max=xp.size(0)-1)
    left = xp[pos - 1]
    right = xp[pos]
    weight = (indices - left) / (right - left)
    return (1 - weight) * fp[pos - 1] + weight * fp[pos]

test = F.pad(signal, (0, 512))  # zero-pad so that a window starting near the end still fits
float_weight = torch.tensor([(torch.rand(())-0.5)*20.0], requires_grad=True) # initialized with random float
target_float = 0.1 # to slice from the test tensor for loss calculation and to compare with the optimized float_index
window_size = 512 # window_size for comparison
optimizer = torch.optim.AdamW([float_weight], lr=0.1)
start_i = int(target_float*target_length)
loss_fn = SoftDTWLossPyTorch(gamma=0.1,normalize=True)

for epoch in range(1000):
    optimizer.zero_grad()
    float_index = torch.sigmoid(float_weight)  # squash the weight into (0, 1)
    indices = torch.arange(0, window_size) + (float_index*target_length)
    xp = torch.arange(test.size(0), dtype=torch.float32)
    x = torch_interp(indices, xp, test)  # interpolated window starting at the float index
    # Loss-Calculation
    loss = loss_fn(x.unsqueeze(0).unsqueeze(-1),test[start_i:start_i+window_size].unsqueeze(0).unsqueeze(-1))
    # Much faster loss function that also works, but not as well: (x-test[start_i:start_i+window_size]).abs().sum()
    # regularization term # prevent local optima with randomness: probably unnecessary or even harmful for real networks with different inputs
    loss = loss + loss * (float_weight-0.5).abs() + loss * torch.rand(())*6
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}, Float Index: {float_index.item()}')
    if (target_float-float_index).abs()<0.001:
        print("Done")
        break
```
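For anyone who wants to sanity-check the `torch_interp` helper on its own, here is a small sketch with made-up sample points (`xp`, `fp`, and `indices` are toy values, and the function is copied from above without the `@torch.jit.script` decorator purely to keep the example short). It evaluates the interpolation at two float positions and confirms that a gradient reaches the float indices.

```python
import torch

def torch_interp(indices, xp, fp):
    # Same body as the helper above, without TorchScript compilation.
    pos = torch.searchsorted(xp, indices, right=True)
    pos = pos.clamp(min=1, max=xp.size(0) - 1)
    left = xp[pos - 1]
    right = xp[pos]
    weight = (indices - left) / (right - left)
    return (1 - weight) * fp[pos - 1] + weight * fp[pos]

xp = torch.arange(4, dtype=torch.float32)       # sample positions 0, 1, 2, 3
fp = torch.tensor([0.0, 10.0, 20.0, 40.0])      # sample values at those positions
indices = torch.tensor([0.5, 2.25], requires_grad=True)

values = torch_interp(indices, xp, fp)
values.sum().backward()
print(values)        # halfway between fp[0] and fp[1]; a quarter of the way from fp[2] to fp[3]
print(indices.grad)  # local slope of the piecewise-linear curve at each position
```

The gradient at each position is just the slope of the segment the index falls into, which is what lets the optimizer shift the window.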