Hey,

Is there any way to index a tensor with float tensors while keeping those tensors differentiable?

For example, I have a tensor of shape (3, 400, 400) which stands for an RGB image.

I want to “draw” a square in the image within some boundaries.

Here’s the code:

```python
import torch

img_size = 400  # 400 x 400 RGB image
img = torch.full((3, img_size, img_size), fill_value=1.0, dtype=torch.float32)
```

I also have two float tensors for the center of the square, plus a fixed width:

```python
x_center = torch.rand(1, requires_grad=True)
y_center = torch.rand(1, requires_grad=True)
```

From these I compute the square’s upper-left and lower-right corners in pixel coordinates:

```python
width = 0.1  # fixed width as a fraction of img_size (example value)

x1 = (x_center - width / 2) * img_size
x2 = (x_center + width / 2) * img_size
y1 = (y_center - width / 2) * img_size
y2 = (y_center + width / 2) * img_size
```
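A quick check that, up to this point, the corners are still ordinary float tensors tracked by autograd. The width value here is an assumption (the post only says it is fixed). Since x1 = (x_center - width/2) * img_size, the derivative of x1 with respect to x_center is simply img_size, and that is what backward() reports:

```python
import torch

img_size = 400
width = 0.1  # hypothetical fixed width, as a fraction of img_size

x_center = torch.rand(1, requires_grad=True)
x1 = (x_center - width / 2) * img_size  # still a float tensor in the graph

print(x1.requires_grad, x1.grad_fn is not None)  # True True
x1.sum().backward()
print(x_center.grad)  # tensor([400.]), i.e. d(x1)/d(x_center) == img_size
```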

Finally, I want to do this:

```python
img[:, x1:x2, y1:y2] = some_value
```

This, of course, cannot be done, because slice indices must be integers, not floats. Converting them to long, for example:

```python
x1 = ((x_center - width / 2) * img_size).long()
```

makes the slicing possible, but the conversion is not differentiable, so the computation graph back to x_center and y_center is lost.
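The loss of the graph can be seen directly. A minimal sketch, again assuming a hypothetical width of 0.1: the float corner requires grad, but after `.long()` the integer result has `requires_grad=False` and no `grad_fn`, because autograd does not propagate gradients through integer dtypes.

```python
import torch

img_size = 400
width = 0.1  # hypothetical fixed width, as a fraction of img_size

x_center = torch.rand(1, requires_grad=True)
x1_float = (x_center - width / 2) * img_size
x1_long = x1_float.long()  # truncates toward zero and switches to int64

print(x1_float.requires_grad, x1_long.requires_grad)  # True False
print(x1_long.grad_fn)  # None: nothing links x1_long back to x_center
```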

Is there a way to do this so that, when I call backward(), the gradients of x_center and y_center are computed?
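One workaround sometimes used for this kind of problem (a suggestion, not something from the original post) is to avoid indexing altogether and build a soft, differentiable mask of the square instead: each pixel gets a weight between 0 and 1 from sigmoids of its distance to each edge, and the image is blended with some_value through that mask rather than assigned in place. In this sketch, `soft_box_mask`, `temperature`, the example width and some_value, and the target image used for the loss are all hypothetical choices for illustration:

```python
import torch

img_size = 400
width = 0.1       # hypothetical fixed width, as a fraction of img_size
some_value = 0.0  # hypothetical fill value for the square

img = torch.full((3, img_size, img_size), fill_value=1.0)
x_center = torch.rand(1, requires_grad=True)
y_center = torch.rand(1, requires_grad=True)

# Corners in pixel units, as in the post (still differentiable floats)
x1 = (x_center - width / 2) * img_size
x2 = (x_center + width / 2) * img_size
y1 = (y_center - width / 2) * img_size
y2 = (y_center + width / 2) * img_size


def soft_box_mask(x1, x2, y1, y2, size, temperature=1.0):
    """Hypothetical smooth stand-in for the hard slice [x1:x2, y1:y2].

    Weights are close to 1 well inside the square and close to 0 well
    outside, with a sigmoid ramp roughly `temperature` pixels wide at each
    edge. Rows (dim 0 of the mask) follow x, matching img[:, x1:x2, y1:y2].
    """
    pix = torch.arange(size, dtype=torch.float32) + 0.5  # pixel centers
    in_x = torch.sigmoid((pix - x1) / temperature) * torch.sigmoid((x2 - pix) / temperature)
    in_y = torch.sigmoid((pix - y1) / temperature) * torch.sigmoid((y2 - pix) / temperature)
    return in_x[:, None] * in_y[None, :]  # shape (size, size)


mask = soft_box_mask(x1, x2, y1, y2, img_size)
# Blend instead of assigning; the (H, W) mask broadcasts over the 3 channels
out = img * (1 - mask) + some_value * mask

# Any scalar loss works; this hypothetical target makes the position matter
target = torch.zeros_like(out)
target[:, : img_size // 2, : img_size // 2] = 1.0
loss = ((out - target) ** 2).mean()
loss.backward()
print(x_center.grad, y_center.grad)  # both populated, no integer casts involved
```

Lowering `temperature` makes the edges sharper and the result closer to the hard slice, but the gradient then comes only from a thinner band of pixels near the edges, so there is a trade-off between fidelity and how informative the gradient is.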

Thanks!