Generate groupby index for a tensor

Hi guys, I am trying to implement a function that generates group indices for the elements of a tensor.

The rules are simple:

  1. Assign a new unique value to each position whose value equals 2.
  2. A 1 that continues a sequence started by a 2 keeps the value of that sequence.
  3. A 1 with no 2 ahead of it in its sequence is ignored (gets 0), as are all 0s.

Here is an example:
Input: [2, 0, 0, 2, 1, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1]
Expected output: [1, 0, 0, 2, 2, 0, 0, 3, 3, 3, 0, 0, 0, 0, 0]
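For reference, the rules can be pinned down with a single left-to-right scan in plain Python. This is only a sketch, not part of the original question: it is not differentiable, it assumes the input contains only 0, 1 and 2, and `group_ids` is a hypothetical name.

```python
def group_ids(values):
    """Assign group ids following the three rules (plain Python, no autograd)."""
    out = []
    group = 0       # id of the most recent group, i.e. how many 2s we have seen
    in_run = False  # True while we are inside a sequence that started with a 2
    for v in values:
        if v == 2:                # rule 1: every 2 opens a new group
            group += 1
            in_run = True
            out.append(group)
        elif v == 1 and in_run:   # rule 2: a 1 continuing a 2-started sequence
            out.append(group)
        else:                     # rule 3: 0s end the sequence, orphan 1s get 0
            in_run = False
            out.append(0)
    return out


print(group_ids([2, 0, 0, 2, 1, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1]))
# [1, 0, 0, 2, 2, 0, 0, 3, 3, 3, 0, 0, 0, 0, 0]
```

A reference like this is handy for checking any vectorized, autograd-friendly version against the intended behavior.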

I would like the whole pipeline to work with autograd, so that gradients can flow through it.

Right now I can select all positions equal to 2 and assign them unique values, but I don't know how to select the 1s that are led by a 2.

>>> import torch as tc
>>> a = tc.Tensor([2., 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1])
>>> B = tc.where(a == 2.0, tc.cumsum(a, -1).double(), 0.0)
>>> print(B) 
tensor([2., 0., 0., 4., 0., 0., 0., 0., 7., 0., 0., 0., 0., 0., 0., 0.])
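The values 2, 4 and 7 are unique only because `cumsum` adds up the raw values. Counting the 2s instead, i.e. taking the cumulative sum of the indicator `a == 2`, yields consecutive ids 1, 2, 3. A plain-Python sketch of the difference, with `itertools.accumulate` standing in for `cumsum`:

```python
from itertools import accumulate

a = [2, 0, 0, 2, 1, 0, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1]

# cumsum of the raw values: unique at the 2s, but not consecutive
raw = [s if v == 2 else 0 for v, s in zip(a, accumulate(a))]

# cumsum of the indicator (v == 2): the running count of 2s seen so far
count = list(accumulate(int(v == 2) for v in a))
ids = [c if v == 2 else 0 for v, c in zip(a, count)]

print(raw)  # [2, 0, 0, 4, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0]
print(ids)  # [1, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0]
```

Note that `count` already holds the right id for every 1 that continues a 2-started sequence; what remains is a mask that zeroes everything else.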

I am not sure what you would like to achieve.
AFAIK, the moment you hardcode/assign a value within a tensor, there is a break in the computation graph and the gradient for those elements will be 0.
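A minimal PyTorch illustration of that point, with made-up values: assigning a constant through a boolean mask is recorded by autograd, but the assigned elements receive zero gradient, while untouched elements keep theirs.

```python
import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)
y = a * 2         # dy/da = 2 for every element
y[y > 3] = 0      # overwrite the last two elements with a constant
y.sum().backward()
print(a.grad)     # tensor([2., 0., 0.])
```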

Assume that tensor a is the output of a model and carries gradients.
I am trying to compute another tensor b from a that groups the values in a.

Here is the example:
a = [2, 0, 0, 2, 1, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1]
I need: b = [1, 0, 0, 2, 2, 0, 0, 3, 3, 3, 0, 0, 0, 0, 0]

I would like the algorithm to use only PyTorch operations, so that b carries the gradient from a.

If I understood your problem correctly, then you could do something like this.

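import torch

def group_by_index(a):
    b = a.unsqueeze(-2)
    c = b.mT @ b  # c[..., i, j] = a[i] * a[j]
    # 1 strictly below the diagonal, a[i] * a[j] on and above it
    x = torch.tril(torch.ones_like(c), -1) + torch.triu(c)
    # in each row, zero everything from the first 0 onwards
    tmp = (x == 0).cumsum(dim=-1)
    x[tmp > 0] = 0
    # column-wise max is >= 2 iff the position is a 2, or a 1 preceded by a 2 in the same sequence
    x, _ = torch.max(x, dim=-2)
    x[x < 2] = 0
    x[x > 1] = 1

    # group id = number of 2s seen so far
    vals = (a == 2).cumsum(dim=-1)
    return x * vals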


a = torch.tensor([2., 0, 0, 2, 1, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1], requires_grad=True)
x = group_by_index(a)
print(x)
# Output
tensor([1., 0., 0., 2., 2., 0., 0., 3., 3., 3., 0., 0., 0., 0., 0.],
       grad_fn=<MulBackward0>)

This also works for tensors with more dims, since every step operates on the last dimension(s) and treats leading dims as a batch.

# Input
a = torch.tensor([[2., 1, 0, 2, 2],
                  [1, 2, 1, 0, 1],
                  [2, 0, 2, 1, 1]], requires_grad=True)

# Output
tensor([[1., 1., 0., 2., 3.],
        [0., 1., 1., 0., 0.],
        [1., 0., 2., 2., 2.]], grad_fn=<MulBackward0>)

There is definitely room for improvement in this code, but it is just a workaround to get what you want and keep the grad.
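To see why the mask in `group_by_index` works, here is the same construction restated in plain Python for a 1-D input (a sketch for illustration only; `outer_product_mask` is a hypothetical name). Row `i` of `x` holds `a[i] * a[j]` for `j >= i`, cut off at the first zero, and 1 for `j < i`. Taking the column-wise max and keeping values `>= 2` marks position `j` exactly when `a[j]` is a 2, or a 1 whose nonzero sequence contains an earlier 2. Multiplying by the running count of 2s then gives the group ids.

```python
from itertools import accumulate

def outer_product_mask(a):
    n = len(a)
    # x[i][j] = 1 below the diagonal, a[i] * a[j] on and above it
    x = [[1 if j < i else a[i] * a[j] for j in range(n)] for i in range(n)]
    # zero every entry from the first zero of the row onwards
    for row in x:
        seen_zero = False
        for j in range(n):
            seen_zero = seen_zero or row[j] == 0
            if seen_zero:
                row[j] = 0
    # column-wise max, then keep only values >= 2
    col_max = [max(x[i][j] for i in range(n)) for j in range(n)]
    return [1 if m >= 2 else 0 for m in col_max]


a = [2, 0, 0, 2, 1, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1]
mask = outer_product_mask(a)
vals = list(accumulate(int(v == 2) for v in a))  # running count of 2s
print([m * v for m, v in zip(mask, vals)])
# [1, 0, 0, 2, 2, 0, 0, 3, 3, 3, 0, 0, 0, 0, 0]
```

The n x n matrix is what lets a single `max` look back along each sequence; the cost is O(n^2) memory per row of the input.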


Thank you Matias!
I really appreciate your quick response!
I think this is a perfect solution!


Hi Matias,

I tried your solution in my program, but I found that it still breaks the backward pass.
Here is the code I used to check the gradient:

import torch

def group_by_index(a):
    b = a.unsqueeze(-2)
    c = b.mT @ b
    x = torch.tril(torch.ones_like(c), -1) + torch.triu(c)
    tmp = (x == 0).cumsum(axis=-1)
    x[tmp>0] = 0
    x, _ = torch.max(x, dim=-2)
    x[x<2] = 0
    x[x>1] = 1
    
    vals = (a==2).cumsum(axis=-1)
    return x * vals

tags = torch.tensor([2., 0, 0, 2, 1, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1], requires_grad=True)
groups = group_by_index(a)
fake_loss = groups.sum() # assume this is the final loss of my whole model
fake_loss.backward()
assert tags.grad is not None  # tags.grad currently is None

I think the reason is that x[tmp > 0] = 0 is not differentiable, because "index operation won't be differentiable" (from "None gradients right after parameter initialization", post #6 by ptrblck).

Thank you!

Hi,

I think you forgot to change the variable that you pass to the function.

a is not defined in your snippet, but I assume that it did not have requires_grad=True in the first place.
I tried it with your variable tags and it seems to work.

However, the gradient turned out to be 0, which might have to do with what you mentioned.
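A small PyTorch sketch of why the first version gives an all-zero gradient (toy values, not from the thread): `vals` is built from a comparison, which returns a bool tensor outside the graph, and after `x[x < 2] = 0` and `x[x > 1] = 1` every element of `x` has been overwritten with a constant, so no gradient reaches `a` through `x` either.

```python
import torch

a = torch.tensor([2., 0., 1.], requires_grad=True)

# comparisons return bool tensors, which are not part of the autograd graph
vals = (a == 2).cumsum(dim=-1)
print(vals.requires_grad)  # False

# overwriting every element with a constant leaves no gradient path through x
x = a * a
x[x < 2] = 0
x[x > 1] = 1
(x * vals).sum().backward()
print(a.grad)  # tensor([0., 0., 0.])
```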

I changed it a little to avoid these operations:

import torch

def group_by_index(a):
    b = a.unsqueeze(-2)
    c = b.mT @ b
    x = torch.tril(torch.ones_like(c), -1) + torch.triu(c)
    # relu(1 - x) is 1 where x == 0; relu(1 - cumsum) is 1 until the first zero in the row
    tmp = torch.relu(1 - torch.relu(1 - x).cumsum(dim=-1))
    x = x * tmp
    x, _ = torch.max(x, dim=-2)
    x = torch.relu(x - 1)  # 0 where the column max is 0 or 1 (e.g. a trailing 0)
    x = x / (x + 1e-10)  # ~1 where x > 0, 0 where x == 0; small value added for numerical stability

    vals = torch.relu(a - 1).cumsum(dim=-1)  # running count of 2s
    return x * vals

tags = torch.tensor([2., 0, 0, 2, 1, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1], requires_grad=True)
groups = group_by_index(tags)
print(groups)
fake_loss = groups.sum() # assume this is the final loss of my whole model
fake_loss.backward()
print(tags.grad)
assert tags.grad is not None  # passes with this version

Hope this works :smile:
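The ReLU version replaces the boolean steps with arithmetic that only behaves like them because the entries of `x` are non-negative integers (0, 1, 2 or 4 when `a` holds 0, 1 and 2): `relu(1 - v)` is 1 when `v == 0` and 0 otherwise, and `relu(1 - running_count)` is 1 until the first zero has been seen. A plain-Python sketch of one row (the helper names are hypothetical):

```python
from itertools import accumulate

def relu(v):
    return max(0, v)

def keep_until_first_zero(row):
    # relu(1 - v): 1 where v == 0, 0 where v >= 1 (integer entries assumed)
    is_zero = [relu(1 - v) for v in row]
    # relu(1 - running count of zeros): 1 before the first zero, 0 from it on
    return [relu(1 - c) for c in accumulate(is_zero)]


print(keep_until_first_zero([1, 1, 4, 2, 0, 1]))  # [1, 1, 1, 1, 0, 0]
```

For non-integer entries (for example 0.5) these expressions return fractional values, so the trick relies on `a` holding exact 0/1/2 values.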

This might be a better approach.

import torch

def group_by_index(a):
    x = a.unsqueeze(-2).mT @ a.unsqueeze(-2)
    x = torch.tril(torch.ones_like(x), -1) + torch.triu(x)
    # 1 until the first zero in each row, 0 from there on
    mask = ((x-1)*-1).clamp(0, 1).cumsum(dim=-1).clamp(max=1)*-1+1
    x, _ = (x*mask).max(dim=-2)
    x = (x-1).clamp(min=0, max=1)  # 1 where the column max is >= 2, else 0

    vals = (a-1).clamp(min=0, max=1).cumsum(dim=-1)  # running count of 2s
    return x * vals
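A quick PyTorch check of the clamp version, as a sketch: the function body repeats the clamp version from the post, and the batched test input (whose first row ends with a 0) is illustrative. It confirms that the forward output matches the expected ids and that `tags.grad` is populated after `backward()`; it says nothing about whether that gradient is useful for training.

```python
import torch

def group_by_index(a):
    # clamp version from the post
    x = a.unsqueeze(-2).mT @ a.unsqueeze(-2)
    x = torch.tril(torch.ones_like(x), -1) + torch.triu(x)
    mask = ((x - 1) * -1).clamp(0, 1).cumsum(dim=-1).clamp(max=1) * -1 + 1
    x, _ = (x * mask).max(dim=-2)
    x = (x - 1).clamp(min=0, max=1)
    vals = (a - 1).clamp(min=0, max=1).cumsum(dim=-1)
    return x * vals


tags = torch.tensor([2., 0, 0, 2, 1, 0, 0, 2, 1, 1, 0, 0, 1, 1, 1], requires_grad=True)
groups = group_by_index(tags)
print(groups.tolist())
# [1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0, 3.0, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0]

groups.sum().backward()
print(tags.grad is not None)  # True

# batched input whose first row ends with a 0
batch = torch.tensor([[2., 1, 0, 2, 0],
                      [1, 2, 1, 0, 1]])
print(group_by_index(batch).tolist())
# [[1.0, 1.0, 0.0, 2.0, 0.0], [0.0, 1.0, 1.0, 0.0, 0.0]]
```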

Thanks for your suggestions! I will look at it tomorrow.