I am modifying vgg.py for my own use. I created a mask on the output of the feature extractor that keeps only the top 10 elements, but I am unsure whether this affects the gradients and the training process. Any help would be appreciated.
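def forward(self, x):
    x = self.features(x)  # feature maps, shape (N, C, H, W)
    # 10 largest activations per sample over the flattened C*H*W values, sorted descending
    val, ind = torch.topk(torch.flatten(x, 1), 10)
    # keep every element >= the 10th-largest value of its own sample
    # (val[:, -1] has shape (N,), so move the batch dim last to broadcast against it)
    mask_tensor = x.permute(1, 2, 3, 0) >= val[:, -1]
    x = x * mask_tensor.permute(3, 0, 1, 2)  # back to (N, C, H, W)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)
    return x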
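A minimal sketch for checking what the mask does to the gradients, assuming a random tensor stands in for the output of `self.features` (the shape `(2, 3, 4, 4)` and `k = 10` are arbitrary choices for illustration). It builds the same mask with plain broadcasting, confirms it matches the permute version, and then inspects `x.grad` after a backward pass.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for self.features(x): batch of 2, 3 channels, 4x4 maps
x = torch.randn(2, 3, 4, 4, requires_grad=True)
k = 10

# k-th largest value per sample (topk returns values sorted in descending order)
val, _ = torch.topk(torch.flatten(x, 1), k)
threshold = val[:, -1].view(-1, 1, 1, 1)  # shape (N, 1, 1, 1) broadcasts over C, H, W

# Same mask as the permute version in the question, built with broadcasting instead
mask = x >= threshold
mask_permute = (x.permute(1, 2, 3, 0) >= val[:, -1]).permute(3, 0, 1, 2)
print(torch.equal(mask, mask_permute))

# The bool mask is promoted to x's dtype and acts as a constant for autograd;
# the comparison itself is not differentiable, so no gradient flows through topk
out = x * mask
out.sum().backward()

# d(out.sum())/dx is 1 where an element was kept and 0 everywhere else
print(torch.equal(x.grad, mask.float()))
print(x.grad.flatten(1).sum(dim=1))  # number of kept elements per sample
```

So the masking does not break training: the kept activations receive their normal gradient, and every other position receives exactly zero, which means the convolution weights are updated only through the selected activations. One caveat of the `>=` comparison is ties. After VGG's final ReLU and max-pool, many activations are exactly zero, so if a sample has fewer than 10 positive values the threshold is 0 and all zeros are "kept" as well; this is harmless for the forward pass since those values are zero anyway.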