Differentiable Indexing

I want to do something like this, but I need it to be differentiable w.r.t. the index tensors. Is there any way to achieve this?

import torch

# initialize tensor
tensor = torch.zeros((1, 400, 400)).double()
tensor.requires_grad_(True)

# create index ranges
x_range = torch.arange(150, 250).double()
x_range.requires_grad_(True)
y_range = torch.arange(150, 250).double()
y_range.requires_grad_(True)

# get indices of flattened tensor
x_range = x_range.long().repeat(100, 1)
y_range = y_range.long().repeat(100, 1)
y_range = y_range.t()
tensor_size = tensor.size()
indices = y_range.sub(1).mul(tensor_size[2]).add(x_range).view((1, -1))

# create patch
patch = torch.ones((1, 100, 100)).double()


# flatten tensor
tensor_flattened = tensor.contiguous().view((1, -1))

# set patch to cells of tensor_flattened at indices and reshape tensor
tensor_flattened.scatter_(1, indices, patch.view(1, -1))
tensor = tensor_flattened.view(tensor_size)

# sum up for scalar output for calling backward()
tensor_sum = tensor.sum()

# calling backward()
tensor_sum.backward()

# alternative to avoid summing tensor:
tensor.backward(torch.ones_like(tensor))



I think you are looking for this thread: "Indexing a variable with a variable".

Not directly. I want to set values in a tensor via indexing, and I need this operation to be differentiable w.r.t. the indices.

Basically, what I want is a version of tensor[indices] = m that is differentiable w.r.t. the indices.

You need a “soft” function to get a meaningful gradient: integer indices are discrete, so the output does not change smoothly as they change. Use something like torch.nn.functional.grid_sample, which interpolates between values in your tensor.
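To illustrate the difference, here is a minimal 1-D sketch. It is not from the original post: the names `values`, `pos`, `lo`, `hi` and `w` are made up for illustration, and the linear blend is just one simple way to make a lookup "soft".

```python
import torch

values = torch.tensor([0.0, 10.0, 20.0, 30.0])
# Hypothetical continuous "index" that we want a gradient for.
pos = torch.tensor(1.25, requires_grad=True)

# Hard indexing: casting to an integer type drops pos out of the autograd graph.
hard_idx = pos.long()
print(hard_idx.requires_grad)  # False

# Soft indexing: blend the two neighbouring entries by their distance to pos.
lo = pos.detach().floor().long()  # left neighbour (integer, no gradient needed)
hi = lo + 1                       # right neighbour
w = pos - lo                      # differentiable weight in [0, 1)
soft = (1 - w) * values[lo] + w * values[hi]

soft.backward()
print(soft.item())      # 12.5
print(pos.grad.item())  # 10.0, i.e. values[hi] - values[lo]
```

The gradient w.r.t. `pos` is the slope between the two neighbouring entries, which is the same idea grid_sample applies in 2-D with bilinear interpolation.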


Thanks. Could you provide a minimal working example of how to use it in this case?
Even if it is only pseudocode, it would help me a lot.

import torch
import torch.nn.functional as F

src = torch.arange(25, dtype=torch.float).reshape(1, 1, 5, 5).requires_grad_()  # 1 x 1 x 5 x 5 with values 0 ... 24
indices = torch.tensor([[-1, -1], [0, 0]], dtype=torch.float).reshape(1, 1, -1, 2)  # 1 x 1 x 2 x 2
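# align_corners=True: -1 and 1 refer to the centers of the corner pixels
# (newer PyTorch versions default to False and warn if it is left unset)
output = F.grid_sample(src, indices, align_corners=True)
print(output)  # values: [[[[ 0., 12.]]]]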

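(-1, -1) is the top-left corner and (0, 0) is the center. src has to be 4-D or 5-D (N x C x IH x IW in the 4-D case). indices must have the same number of dimensions (N x OH x OW x 2 in the 4-D case, with the last dimension holding the (x, y) coordinates). If you don’t have a batch size (N) or channels (C), set those dimensions to size 1.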
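The example above only reads values. To check that a gradient actually reaches the sampling locations, the grid itself has to require grad. A minimal sketch, assuming a single sampling point between pixel centers; the name `grid` and the chosen coordinates are illustrative, and `align_corners=True` is set explicitly as an assumption about the intended coordinate convention:

```python
import torch
import torch.nn.functional as F

# Same 5 x 5 source as above; element (y, x) holds the value 5 * y + x.
src = torch.arange(25, dtype=torch.float).reshape(1, 1, 5, 5)

# One sampling location in normalized (x, y) coordinates, shape 1 x 1 x 1 x 2.
# With align_corners=True this is pixel position x = 2.5, y = 1.5.
grid = torch.tensor([[[[0.25, -0.25]]]], requires_grad=True)

out = F.grid_sample(src, grid, mode="bilinear", align_corners=True)
out.sum().backward()

print(out)        # interpolated value: 5 * 1.5 + 2.5 = 10.0
print(grid.grad)  # d(out)/d(x, y) in normalized units: (2.0, 10.0)
```

Because the output is a bilinear blend of the four neighbouring pixels, moving the sampling point changes the output smoothly, and `grid.grad` tells an optimizer which way to move it.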


Thanks a lot. I will try this ASAP.

Hi @colesbury, will the indexing be differentiable if I need to apply different loss operations depending on the magnitude of the values in the tensor?

e.g. predicted_tensor[mask_1] = loss_fn_1()
and predicted_tensor[mask_2] = loss_fn_2()
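For what it's worth, here is a minimal sketch of one way to read this question, assuming the goal is to apply a different loss to each masked subset. The tensors `pred` and `target`, the threshold of 1.0, and the choice of MSE and L1 losses are all made up for illustration. Selecting values with a boolean mask is differentiable w.r.t. the selected values, but the mask itself comes from a comparison and carries no gradient.

```python
import torch
import torch.nn.functional as F

# Hypothetical predictions and targets, for illustration only.
pred = torch.tensor([0.2, 1.5, -3.0, 0.7], requires_grad=True)
target = torch.zeros_like(pred)

# Masks come from a comparison, so they are boolean and carry no gradient.
small = pred.detach().abs() < 1.0
large = ~small

# Apply a different loss to each subset and add the results.
loss = (
    F.mse_loss(pred[small], target[small], reduction="sum")
    + F.l1_loss(pred[large], target[large], reduction="sum")
)
loss.backward()

# Each element receives the gradient of the loss that was applied to it:
# 2 * pred for the MSE part, sign(pred) for the L1 part.
print(pred.grad)
```

So gradients flow to the values in predicted_tensor, but not to the decision of which mask an element falls into.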