I am trying to implement k-best selection, which consists of two parts: 1) aggregation / attention, 2) top-k selection. The input data is a tensor of shape (batch_size, channel, img_features). Aggregation outputs the softmax probabilities along the channel dim of the input tensor, so it has shape (batch_size, channel, 1). In top-k I select the top probabilities along the channel dim of that (batch_size, channel, 1) tensor, e.g. setting k = 3 results in (batch_size, 3, 1). After this I want to keep only the k indices output by topk from the input tensor, so the result is (batch_size, 3, img_features), while preserving the original order of the input tensor along the channel dim. But I fail to compute gradients.
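To make the intended shape flow concrete, here is a minimal standalone sketch of the topk → sort → gather step (the toy tensor and sizes are just for illustration, not my actual data):

import torch

x = torch.rand(2, 5, 4)                                        # (batch_size, channel, img_features)
scores = torch.rand(2, 5, 1)                                   # stand-in for the softmax probabilities
_, idx = torch.topk(scores, 3, dim=1)                          # (batch_size, 3, 1), ordered by score
idx, _ = torch.sort(idx, dim=1)                                # back to the original channel order
selected = torch.gather(x, 1, idx.repeat(1, 1, x.shape[-1]))   # (batch_size, 3, img_features)
print(selected.shape)                                          # torch.Size([2, 3, 4])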
import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, img_features: int, aggreg_size=32, selection=True) -> None:
        super(Attention, self).__init__()
        self.img_features = img_features
        self.aggreg_size = aggreg_size
        self.selection = selection
        self.aggregation_layer = self._make_aggregation_layer(self.img_features, self.aggreg_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x) -> torch.Tensor:
        batch_size, channel, img_features = x.shape
        out = torch.clone(x)
        x = self.aggregation_layer(x)
        # This returns the softmax probabilities with shape (batch_size, channel, 1)
        x = self.softmax(x)
        if self.selection:
            # Return only the probabilities; the selection happens downstream
            return x
        x = torch.mul(x, out)
        return x

    def _make_aggregation_layer(self, img_features: int, aggreg_size: int) -> nn.Sequential:
        attention = nn.Sequential(
            nn.Linear(img_features, aggreg_size),
            nn.Tanh(),
            nn.Linear(aggreg_size, 1)
        )
        return attention
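A quick shape check of the attention part on its own (random input, purely illustrative):

att = Attention(img_features=8, aggreg_size=4, selection=True)
probs = att(torch.rand(1, 5, 8))       # input is (batch_size, channel, img_features)
print(probs.shape)                     # torch.Size([1, 5, 1])
print(probs.sum(dim=1))                # softmax over the channel dim sums to 1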
class KBestSelection(nn.Module):
    def __init__(self, k: int, img_features: int, aggreg_size=32, aggreg_method=None) -> None:
        super(KBestSelection, self).__init__()
        self.k = k
        self.img_features = img_features
        self.aggreg_size = aggreg_size
        self.aggreg_method = aggreg_method
        self.aggregation = Attention(self.img_features, self.aggreg_size, selection=True)

    def forward(self, x):
        batch_size, channel, img_features = x.size()
        out = torch.clone(x)
        # Optionally aggregate with mean/max instead of attention
        if self.aggreg_method == "mean":
            x = torch.mean(x, dim=1, keepdim=True)
        elif self.aggreg_method == "max":
            x, _ = torch.max(x, dim=1, keepdim=True)
        else:
            # This outputs a tensor of softmax probabilities
            # with shape (batch_size, channel, 1)
            x = self.aggregation(x)
        # Select the top-k probabilities along the channel dim
        v, i = torch.topk(x, self.k, dim=1)
        # Sort the indices to preserve the original order of out = torch.clone(x) along the channel dim
        topk_sort, _ = torch.sort(i, dim=1)
        # Equivalent loop, to show what torch.gather does here:
        # new_x = torch.empty(out.shape[0], self.k, out.shape[2])
        # for i in range(out.shape[0]):
        #     new_x[i] = out[i, topk_sort[i], :].squeeze(1)
        x = torch.gather(out, 1, topk_sort.repeat(1, 1, out.shape[-1]))
        return x
Initializing the module, running a forward and backward pass, and checking the gradients results in:

x = torch.rand(1, 5, 8, requires_grad=True)   # input needs requires_grad so backward() runs
k = 3
kbest = KBestSelection(k, 8, aggreg_size=4)
y = kbest(x)
y.sum().backward()
for name, param in kbest.named_parameters():
    print("param {}, grad {}".format(name, param.grad))
>> param aggregation.aggregation_layer.0.weight, grad None
param aggregation.aggregation_layer.0.bias, grad None
param aggregation.aggregation_layer.2.weight, grad None
param aggregation.aggregation_layer.2.bias, grad None
Any ideas which torch operation detaches the gradient computation?
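For debugging, the only check I can think of is to look at requires_grad / grad_fn of the attention probabilities and of the selected output (reusing kbest and x from the snippet above):

probs = kbest.aggregation(x)
print(probs.grad_fn)                              # e.g. <SoftmaxBackward0 ...> if the attention part is in the graph
selected = kbest(x)
print(selected.requires_grad, selected.grad_fn)   # shows whether the selected output is still connected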