Using torch.topk for k-best selection, but parameter gradients are None

I am trying to implement k-best selection, which consists of two parts: 1) aggregation / attention, 2) topk selection. The input is a tensor of size (batch_size, channel, img_features). Aggregation outputs the softmax probabilities along the channel dim of the input tensor, so its output has size (batch_size, channel, 1). In topk I select the top probabilities along channel from (batch_size, channel, 1); e.g. setting k = 3 results in (batch_size, 3, 1).
After this I want to keep only the k indices returned by topk in the input tensor, so the result will be (batch_size, 3, img_features), preserving the original order of the input tensor along the channel dim. But I fail to compute gradients.

class Attention(nn.Module):
    def __init__(self, img_features : int, aggreg_size=32, selection=True) -> None:
        super(Attention, self).__init__()
        
        self.img_features = img_features
        self.aggreg_size = aggreg_size
        self.selection = selection
        
        self.aggregation_layer = self._make_aggragation_layer(self.img_features, self.aggreg_size)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x) -> torch.Tensor:
        batch_size, channel, img_features = x.shape
        out = torch.clone(x)
        x = self.aggregation_layer(x)

        # This returns the softmax probabilities with shape (batch_size, channel, 1)
        x = self.softmax(x)
        if self.selection:
            return x
        x = torch.mul(x, out)
        return x
        
    def _make_aggragation_layer(self, img_features : int, aggreg_size : int) -> nn.Sequential:
        attention = nn.Sequential(
            nn.Linear(img_features, aggreg_size),
            nn.Tanh(),
            nn.Linear(aggreg_size, 1)
        )

        return attention


class KBestSelection(nn.Module):
    def __init__(self, k : int, img_features : int, aggreg_size=32, aggreg_method="attention") -> None:

        super(KBestSelection, self).__init__()
        self.k = k
        self.img_features = img_features
        self.aggreg_size = aggreg_size
        self.aggreg_method = aggreg_method
        self.aggregation = Attention(self.img_features, self.aggreg_size, selection=True)

            
    def forward(self, x):
        batch_size, channel, img_features = x.size()
        out = torch.clone(x)
        # Optional aggregation methods, chosen via the aggreg_method argument
        if self.aggreg_method == "mean":
            x = torch.mean(x, axis=1, keepdim=True)
        elif self.aggreg_method == "max":
            x, _ = torch.max(x, axis=1, keepdim=True)
        else:
            # This outputs a tensor of softmax probabilities
            # with shape (batch_size, channels, 1)
            x = self.aggregation(x)

        # Select the top k entries of x along the channel dim
        v, i = torch.topk(x, self.k, dim=1)

        # Sort the returned indices to preserve the original order
        # of the out = torch.clone(x) tensor
        topk_sort, _ = torch.sort(i, dim=1)

        # This is to understand what torch.gather does:
        # new_x = torch.empty(out.shape[0], self.k, out.shape[2])
        # for i in range(out.shape[0]):
        #     new_x[i]=out[i,topk_sort[i],:].squeeze(1)

        x = torch.gather(out, 1, topk_sort.repeat(1,1,out.shape[-1]))
  
        return x

Initializing the module and checking the gradients results in:

x = torch.rand(1,5,8)
k = 3
kbest = KBestSelection(k, 8, aggreg_size=4)

for name, param in kbest.named_parameters():
    print("param {}, grad {}".format(name, param.grad))


>> param aggregation.aggregation_layer.0.weight, grad None
    param aggregation.aggregation_layer.0.bias, grad None
    param aggregation.aggregation_layer.2.weight, grad None
    param aggregation.aggregation_layer.2.bias, grad None

Any ideas which torch operation detaches the gradient computation?

Can anyone elaborate?

It seems you are using the non-differentiable returned indices from the topk operation, which will break the computation graph, so the behavior is expected.
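
To make the point about indices concrete, here is a minimal sketch (the tensor name `scores` and the shapes are illustrative assumptions, not taken from the code above). It shows that the values returned by torch.topk stay attached to autograd, while the indices are an integer tensor that carries no gradient:

```python
import torch

torch.manual_seed(0)

# Illustrative stand-in for the softmax scores: (batch_size, channel, 1)
scores = torch.rand(1, 5, 1, requires_grad=True)

values, indices = torch.topk(scores, k=3, dim=1)

# The values are a differentiable function of scores; the indices are not.
values_requires_grad = values.requires_grad
indices_requires_grad = indices.requires_grad
indices_dtype = indices.dtype

# Backpropagating through the values reaches scores:
# the 3 selected positions get gradient 1, the others get 0.
values.sum().backward()
grad_sum = scores.grad.sum().item()

print(values_requires_grad, indices_requires_grad, indices_dtype, grad_sum)
```

So any gradient that should reach the layer producing the scores has to flow through the values (or through the scores themselves), not through the indices.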

@ptrblck thanks for the reply. I found out that using the indices from topk detaches the result from autograd. I found a workaround: zero out the positions of the input x tensor that are not in the output of torch.topk, like this:

val, ind = torch.topk(a, 3, dim=1)
masked_scores = torch.zeros_like(a)
masked_scores.scatter_(1, ind, val)
x = torch.mul(x,masked_scores)

However, this is not what I would like to output. The output is the input tensor multiplied by val (the softmax values). I would like to keep the rows of the input x tensor at the topk indices, in their original order. Could you please guide me on whether this is possible?
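
For reference, the masking workaround can be written as a small self-contained script. This is only a sketch: the `scorer` module and the shapes are assumptions standing in for the attention layer, and the out-of-place `scatter` is used instead of `scatter_`. It shows that the parameters do receive gradients here, because the topk values (not just the indices) enter the output:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the aggregation layer
scorer = nn.Sequential(nn.Linear(8, 4), nn.Tanh(), nn.Linear(4, 1))

x = torch.rand(1, 5, 8)                       # (batch_size, channel, features)
a = torch.softmax(scorer(x), dim=1)           # (1, 5, 1)

val, ind = torch.topk(a, 3, dim=1)

# Put the top-k scores back at their original positions, zeros elsewhere
masked_scores = torch.zeros_like(a).scatter(1, ind, val)

out = x * masked_scores                       # still (1, 5, 8)
out.sum().backward()

grads_present = all(p.grad is not None for p in scorer.parameters())
nonzero_rows = int((masked_scores.squeeze(-1) != 0).sum())
print(out.shape, grads_present, nonzero_rows)
```

Note that the output keeps all 5 rows (the unselected ones are zeroed) and the selected rows are scaled by their scores, which is exactly the difference from the desired (batch_size, 3, img_features) result described above.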

I might misunderstand the use case as using the topk indices to index a tensor should work.
Your code isn’t executable but also checks the .grad attributes without ever executing a forward and backward pass, so could you post a minimal and executable code snippet reproducing the issue?

@ptrblck Sure, here it is:

import torch
from torch import nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, embed_size : int, aggreg_size=32, selection=True) -> None:
        super(Attention, self).__init__()
        
        self.embed_size = embed_size
        self.aggreg_size = aggreg_size
        self.selection = selection
        
        self.aggregation_layer = self._make_aggragation_layer(self.embed_size, self.aggreg_size)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x) -> torch.Tensor:
        batch_size, segment_size, embed_sized = x.shape
        #out = torch.clone(x)
        a = self.aggregation_layer(x)
        #a = torch.transpose(a,1,2)
        a = self.softmax(a)
        if self.selection:
            return a
        #x = torch.mul(x, out)
        x = torch.mul(a, x)
        return x
        
    def _make_aggragation_layer(self, embed_size : int, aggreg_size : int) -> nn.Sequential:
        attention = nn.Sequential(
            nn.Linear(embed_size, aggreg_size),
            nn.Tanh(),
            nn.Linear(aggreg_size, 1)
        )

        return attention
    

class KBestSelection(nn.Module):
    def __init__(self, k : int, embed_size : int, aggreg_size=32, aggreg_method="attention") -> None:

        super(KBestSelection, self).__init__()
        self.k = k
        self.embed_size = embed_size
        self.aggreg_size = aggreg_size
        self.aggreg_method = aggreg_method
        if aggreg_method == "attention":
            self.aggregation = Attention(self.embed_size, self.aggreg_size, selection=True)
        if aggreg_method == "gated_attention":
            self.aggregation = GatedAttention(self.embed_size, self.aggreg_size, selection=True)
            
    def forward(self, x):
        batch_size, segment, embed = x.size()
        if self.aggreg_method == "mean":
            x = torch.mean(x, axis=1, keepdim=True)
        elif self.aggreg_method == "max":
            x, _ = torch.max(x, axis=1, keepdim=True)
        else:
            a = self.aggregation(x)
    
        v, i = torch.topk(a, self.k, dim=1)
        topk_sort, _ = torch.sort(i, dim=1)
        x = torch.gather(x, 1, topk_sort.repeat(1,1,x.shape[-1]))
        
        #val, ind = torch.topk(a, self.k, dim=1)
        #masked_scores = torch.zeros_like(a)
        #masked_scores.scatter_(1, ind, val)
        #x = torch.mul(x,masked_scores)
        return x

Here is also an example of initializing the class, and the output:

x_global = torch.rand((1,7,8),requires_grad=True)
criterion = torch.nn.CrossEntropyLoss()
kbest = KBestSelection(3, 8, aggreg_size=4, aggreg_method="attention")

Check weights:

output_k = kbest(x_global)
target_k = output_k
loss = criterion(output_k, target_k)
loss.backward()
for name, param in kbest.named_parameters():
    print("param {}, grad {}".format(name, param.grad))

input x_global tensor:

tensor([[[0.0704, 0.2302, 0.2573, 0.1347, 0.8778, 0.3715, 0.0537, 0.6822],
         [0.5025, 0.4463, 0.3607, 0.1861, 0.3675, 0.5545, 0.3682, 0.1448],
         [0.7487, 0.7322, 0.5235, 0.2246, 0.1231, 0.9058, 0.0487, 0.3481],
         [0.1350, 0.0776, 0.0839, 0.0784, 0.9185, 0.0921, 0.9626, 0.3209],
         [0.3911, 0.5254, 0.5977, 0.6108, 0.7147, 0.1242, 0.5805, 0.8559],
         [0.2151, 0.4266, 0.7918, 0.3405, 0.8684, 0.3049, 0.7841, 0.1766],
         [0.6038, 0.9951, 0.8347, 0.8354, 0.6527, 0.1021, 0.7711, 0.8940]]],
       requires_grad=True)

In the output of KBestSelection's forward we have 3 selected vectors (topk from attention, indices (3,4,6)) along dim 1, since we passed k = 3 to KBestSelection, and they keep the original ordering, but it is not differentiable:

tensor([[[0.1350, 0.0776, 0.0839, 0.0784, 0.9185, 0.0921, 0.9626, 0.3209],
         [0.3911, 0.5254, 0.5977, 0.6108, 0.7147, 0.1242, 0.5805, 0.8559],
         [0.6038, 0.9951, 0.8347, 0.8354, 0.6527, 0.1021, 0.7711, 0.8940]]],
       grad_fn=<GatherBackward0>)

Hi Neuro!

What, in detail, do you mean by “it is not differentiable?”

The tensor you posted carries a grad_fn. Its requires_grad property
is not displayed automatically, but if you print it out, you will see that it’s
True. This indicates to me that (at least part of) what you are doing
is differentiable.

One point to bear in mind: You will be able to backpropagate through
the values returned by topk(), but not through the indices. (In general,
you can’t backpropagate through integers.) So you will be able to
backpropagate through the tensor you posted and obtain gradients
with respect to x_global.
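
As a small illustration of this point (a sketch with made-up numbers; `x` plays the role of x_global and `scores` is a fixed tensor rather than the output of a layer), gathering rows with sorted topk indices backpropagates to the gathered tensor itself:

```python
import torch

# Stand-in for x_global: (batch_size=1, channel=5, features=2)
x = torch.arange(10.0).reshape(1, 5, 2).requires_grad_()

# Fixed illustrative scores; the three largest sit at rows 1, 3 and 4
scores = torch.tensor([[[0.10], [0.40], [0.05], [0.30], [0.15]]])

_, idx = torch.topk(scores, k=3, dim=1)
idx_sorted, _ = torch.sort(idx, dim=1)        # restore original row order

selected = torch.gather(x, 1, idx_sorted.expand(-1, -1, x.shape[-1]))
selected.sum().backward()

# Gradient w.r.t. x: 1 for the selected rows, 0 for the rest
row_grads = x.grad[0, :, 0].tolist()
print(selected.shape, row_grads)
```

The gradient reaches `x` through gather, but nothing in this computation depends differentiably on `scores`, since only the integer indices derived from them are used.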

If this doesn’t address your question, could you post a much-simplified
example – just topk() with some simple input data, a few operations,
and a .backward() call – in a fully-self-contained script that prints out
whatever it is that shows that something isn’t differentiable, together
with the output that you get when you run that script?

Best.

K. Frank

Hi, @KFrank. Thanks for your reply!
What I meant by “not differentiable” is that the gradients of the weights in the model are None after calling loss.backward(), as I mentioned here:

Hi Neuro!

There are a lot of moving parts in your code.

Can you reproduce this behavior just using topk() and a single Linear
or maybe a Sequential with Linear, Tanh, Linear. Then post such a
simplified version as a fully-self-contained, runnable script, together with
its output.

Best.

K. Frank

@KFrank OK, this is the minimal script. I can’t drop the Attention part, as it creates the softmax tensor. Attention on its own works fine and gets gradients from autograd.

import torch
from torch import nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, embed_size : int, aggreg_size=32) -> None:
        super(Attention, self).__init__()
        
        self.embed_size = embed_size
        self.aggreg_size = aggreg_size
        
        self.aggregation_layer = self._make_aggragation_layer(self.embed_size, self.aggreg_size)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x) -> torch.Tensor:
        batch_size, segment_size, embed_sized = x.shape

        a = self.aggregation_layer(x)

        a = self.softmax(a)

        return a
    
        
    def _make_aggragation_layer(self, embed_size : int, aggreg_size : int) -> nn.Sequential:
        attention = nn.Sequential(
            nn.Linear(embed_size, aggreg_size),
            nn.Tanh(),
            nn.Linear(aggreg_size, 1)
        )

        return attention


class KBestSelection(nn.Module):
    def __init__(self, k : int, embed_size : int, aggreg_size=32, aggreg_method="attention") -> None:

        super(KBestSelection, self).__init__()
        self.k = k
        self.embed_size = embed_size
        self.aggreg_size = aggreg_size
        self.aggreg_method = aggreg_method
        self.aggregation = Attention(self.embed_size, self.aggreg_size)
            
    def forward(self, x):
        batch_size, segment, embed = x.size()
        a = self.aggregation(x)
    
        v, i = torch.topk(a, self.k, dim=1)
        topk_sort, _ = torch.sort(i, dim=1)
        x = torch.gather(x, 1, topk_sort.repeat(1,1,x.shape[-1]))
        
        return x
    
    
if __name__ == '__main__':
    batch_size = 1
    channel_size = 5
    embed_size = 8 
    k_best_select = 3
    device = torch.device("cpu")
    
    x_global = torch.rand((batch_size, channel_size, embed_size), requires_grad=True).to(device)
    kbest = KBestSelection(k_best_select, embed_size, aggreg_size=4).to(device)
    att = Attention(8, 4)
    
    output_k = kbest(x_global)
    target_k = output_k
    
    output_att = att(x_global)
    target_att = output_att
    
    criterion = torch.nn.CrossEntropyLoss()
    loss_k = criterion(output_k, target_k)
    loss_k.backward()
    
    criterion = torch.nn.CrossEntropyLoss()
    loss_att = criterion(output_att, target_att)
    loss_att.backward()
    
    for name, param in kbest.named_parameters():
        print("param {}, grad {}".format(name, param.grad))
        
    for name, param in att.named_parameters():
        print("param {}, grad {}".format(name, param.grad))

output:

param aggregation.aggregation_layer.0.weight, grad None
param aggregation.aggregation_layer.0.bias, grad None
param aggregation.aggregation_layer.2.weight, grad None
param aggregation.aggregation_layer.2.bias, grad None
param aggregation_layer.0.weight, grad tensor([[-7.8084e-05, -5.6759e-04,  3.3589e-04, -9.4865e-05,  4.7848e-04,
         -3.9725e-04, -1.8808e-04, -6.5792e-04],
        [-6.7099e-06,  3.7607e-05, -3.6818e-05,  1.7142e-06, -5.1913e-05,
          2.6254e-05,  4.7702e-07,  4.4708e-05],
        [ 1.9961e-05, -5.0867e-04,  3.8582e-04, -4.1080e-05,  5.4957e-04,
         -3.4451e-04, -8.3687e-05, -6.0323e-04],
        [-3.0132e-06,  2.1018e-05, -1.6579e-05,  1.0212e-06, -2.1265e-05,
          1.3015e-05,  2.6332e-06,  2.1147e-05]])
param aggregation_layer.0.bias, grad tensor([-1.4359e-04, -1.0613e-05, -9.8242e-06, -1.7339e-06])
param aggregation_layer.2.weight, grad tensor([[-0.0009,  0.0007,  0.0004,  0.0002]])
param aggregation_layer.2.bias, grad tensor([5.5879e-08])

Hi Neuro!

Forgive my bluntness, but no, it’s not.

You really don’t need to define any classes (or functions, for that matter).
Just instantiate a Linear (or Sequential, if need be), set up some data,
call your Linear, run some computations, do your topk(), and run the
rest of what you need to reproduce your issue, without wrapping any
of it in a class or function.

(As an aside, your “minimal script” contains an instance property,
KBestSelection.aggreg_method that isn’t used anywhere.)

So drop the Attention part and just call softmax() (or instantiate and
call Softmax) as part of your script.

Please, build the script from the bottom up by adding, step by step, the
minimal functionality needed to reproduce your issue (rather than top
down by whittling away at your existing code).

Best.

K. Frank
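
A bottom-up script of the kind requested might look like the following. It is a sketch only: a single Linear named `lin` replaces the attention layer, and the loss is a plain sum, both of which are simplifying assumptions. It uses no classes or functions, just a Linear, softmax, topk, sort, gather and a backward call:

```python
import torch

torch.manual_seed(0)

lin = torch.nn.Linear(8, 1)
x = torch.rand(1, 5, 8, requires_grad=True)

a = torch.softmax(lin(x), dim=1)              # scores, shape (1, 5, 1)
v, i = torch.topk(a, 3, dim=1)                # v is never used below
i_sorted, _ = torch.sort(i, dim=1)

# The output is built from x and the integer indices only
out = torch.gather(x, 1, i_sorted.expand(-1, -1, x.shape[-1]))
out.sum().backward()

print("lin.weight.grad:", lin.weight.grad)
print("lin.bias.grad:", lin.bias.grad)
print("x.grad is None:", x.grad is None)
```

In this reduced form the only path from `lin` to `out` goes through the integer indices, so `lin.weight.grad` and `lin.bias.grad` stay None, while `x` does receive a gradient through gather.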

@KFrank, OK, forgive my persistence, but what is the point of this minimal script? I suspect the problem is that I sort the indices with topk_sort, _ = torch.sort(i, dim=1).
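
One possible direction, offered as a sketch rather than as an answer given in this thread: the sort is applied to integer indices, which carry no gradient with or without sorting, so gradients can only reach the scoring layer if the scores themselves enter the output. The two options below are assumptions about what might be acceptable. Option A changes the output values; option B is a straight-through-style trick that keeps the forward values equal to the gathered rows. `lin` is a hypothetical stand-in for the attention layer.

```python
import torch

torch.manual_seed(0)

lin = torch.nn.Linear(8, 1)                   # hypothetical scoring layer
x = torch.rand(1, 5, 8)

a = torch.softmax(lin(x), dim=1)              # (1, 5, 1)
_, i = torch.topk(a, 3, dim=1)
i_sorted, _ = torch.sort(i, dim=1)            # original row order

x_sel = torch.gather(x, 1, i_sorted.expand(-1, -1, x.shape[-1]))  # (1, 3, 8)
a_sel = torch.gather(a, 1, i_sorted)          # (1, 3, 1), differentiable w.r.t. a

# Option A: weight the selected rows by their scores (output values change)
out_weighted = x_sel * a_sel

# Option B: forward value equals x_sel, because (a_sel - a_sel.detach()) is 0,
# but the backward pass still sends a gradient into a_sel
out_st = x_sel * ((a_sel - a_sel.detach()) + 1.0)

out_st.sum().backward()

same_forward = torch.allclose(out_st, x_sel)
has_grad = lin.weight.grad is not None
print(out_st.shape, same_forward, has_grad)
```

Whether the gradient produced by option B is a useful training signal depends on the task; it only tells the scoring layer how scaling each selected row would change the loss, not how changing the selection would.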