2D mask gradient issue

I am modifying vgg.py for my own use. I created a 2D mask for one of the filters that marks its top 10 elements, but I am not sure whether this affects the gradients and the training process. Any help would be appreciated.

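    def forward(self, x):
        x = self.features(x)  # shape (N, C, H, W)
        # 10 largest activations of each flattened sample; val[:, -1] is the 10th largest
        val, ind = torch.topk(torch.flatten(x, 1), 10)
        # Boolean mask: True where an activation is >= that sample's threshold
        # (`temp` was undefined here; the mask is built from x itself)
        mask_tensor = x.permute(1, 2, 3, 0) >= val[:, -1]
        x = x * mask_tensor.permute(3, 0, 1, 2)  # back to (N, C, H, W)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x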
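
The masking step can be tried in isolation, outside the VGG model. This is a minimal sketch, assuming a random `(N, C, H, W)` tensor stands in for the output of `self.features`; `topk_mask` is a hypothetical helper name. It reshapes the threshold with `view(-1, 1, 1, 1)` so it broadcasts over `C, H, W`, which avoids the two `permute` calls, and checks that the result matches the permute-based version.

```python
import torch


def topk_mask(x: torch.Tensor, k: int = 10) -> torch.Tensor:
    """Keep the k largest activations of each sample in an (N, C, H, W) tensor; zero the rest.

    Hypothetical helper for illustration only.
    """
    # k largest values of each flattened sample; val[:, -1] is the k-th largest, shape (N,)
    val, _ = torch.topk(torch.flatten(x, 1), k)
    # Reshape to (N, 1, 1, 1) so the threshold broadcasts over C, H, W
    threshold = val[:, -1].view(-1, 1, 1, 1)
    mask = x >= threshold  # boolean mask with the same shape as x
    return x * mask


torch.manual_seed(0)
x = torch.randn(2, 4, 5, 5)
out = topk_mask(x)

# Same computation using the permute-based masking from forward()
mask_perm = x.permute(1, 2, 3, 0) >= torch.topk(torch.flatten(x, 1), 10)[0][:, -1]
ref = x * mask_perm.permute(3, 0, 1, 2)

print(torch.equal(out, ref))             # True
print((out != 0).flatten(1).sum(dim=1))  # 10 nonzero entries per sample when there are no ties
```

With continuous random inputs ties are extremely unlikely; with real activations (e.g. many zeros after ReLU) the `>=` comparison can keep more than 10 elements per sample.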

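Since `>=` is a comparison, it is not differentiable: its result is a boolean tensor with no `grad_fn`, so no gradient flows back through `mask_tensor`, and autograd treats it as a constant.
In `x = x * mask_tensor...`, the gradient with respect to `x` is the incoming gradient multiplied by the mask. It is therefore zeroed out wherever the mask is 0 and passed through unchanged wherever the mask is 1, so the layers in `self.features` are still trained, but only through the selected top 10 activations of each sample.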
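
A small autograd check can confirm this. This is a minimal sketch, assuming a random `(2, 3, 4, 4)` tensor with `requires_grad=True` stands in for the output of `self.features`, and using `view` instead of the permutes to broadcast the threshold:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 4, 4, requires_grad=True)

# Threshold = 10th-largest value per sample, broadcast over C, H, W
val, _ = torch.topk(torch.flatten(x, 1), 10)
mask = x >= val[:, -1].view(-1, 1, 1, 1)

# The comparison result is a plain boolean tensor outside the autograd graph
print(mask.requires_grad, mask.grad_fn)   # False None

y = x * mask
y.sum().backward()

# d(sum(x * mask))/dx equals the mask: 1 for the top 10 entries, 0 elsewhere
print(torch.equal(x.grad, mask.float()))  # True
print(int(x.grad.sum()))                  # 20 (10 per sample, assuming no ties)
```

Here the loss is just `y.sum()`, so the upstream gradient is all ones and `x.grad` equals the mask exactly; with a real loss, `x.grad` is the upstream gradient with the non-selected positions zeroed.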