Sparse linear module?

In my network, the last layer is an nn.Linear classifier with in_features=32 and out_features=64*64=4096 (very large).

My input batch is very large, and for each input x only some of the class labels are admissible. I have an external algorithm that filters out the impossible class labels.

As an example, say I have x1, x2, x3 … xN as input. For xi, there are only 3 possible labels: yi1, yi2, yi3. For xj, there are only 2 possible labels: yj1, yj2, and so on.

I can get the logits of all samples for all classes by calling logit = classifier([x1, x2 … xN]) = [ [z1,1, z1,2 … z1,4096], [z2,1, z2,2 … z2,4096] … [zN,1, zN,2 … zN,4096] ]. Since I don’t need most of these values, is there a faster way to do this, e.g. using a torch sparse tensor?
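To make the dense baseline concrete, here is a minimal sketch of computing all logits and then picking out only the admissible entries. The toy batch size and the flat index tensors sample_idx and class_idx are assumptions made for illustration; they are one possible way to encode the output of the external filter, not something given above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

classifier = nn.Linear(in_features=32, out_features=4096)
x = torch.randn(5, 32)  # toy batch with N = 5 samples

# Hypothetical flat encoding of admissible (sample, class) pairs:
# sample 0 -> classes 7, 100, 4095; sample 1 -> classes 3, 42
sample_idx = torch.tensor([0, 0, 0, 1, 1])
class_idx = torch.tensor([7, 100, 4095, 3, 42])

# Dense call: computes every class logit for every sample, shape (N, 4096)
logit = classifier(x)

# Advanced indexing keeps one logit per admissible pair, shape (5,)
admissible = logit[sample_idx, class_idx]
```

This still pays for the full N x 4096 output, so it only removes the Python loop, not the wasted computation.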

My current solution uses a loop which is very slow:
logit_of_admissible_class = [ classifier_k(xi) for (xi, k) in zip(x, y) ]   # pseudocode: k is an admissible label yik of xi
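One possible way to vectorize this loop, sketched under assumptions: gather only the weight rows and bias entries for the admissible classes and take row-wise dot products, so the full N x 4096 logit matrix is never built. The helper name admissible_logits, the flat index tensors, and the toy per-sample label lists are hypothetical; the sketch also assumes the layer was created with bias=True. Whether this beats the dense matmul in practice depends on hardware and on how many admissible pairs there are, so it would need benchmarking.

```python
import torch
import torch.nn as nn

def admissible_logits(classifier, x, sample_idx, class_idx):
    """Logits for the given (sample, class) pairs only.

    Assumes classifier is an nn.Linear with bias=True.
    sample_idx and class_idx are 1-D LongTensors of equal length M.
    """
    w = classifier.weight[class_idx]  # (M, in_features), one row per pair
    b = classifier.bias[class_idx]    # (M,)
    xs = x[sample_idx]                # (M, in_features), matching inputs
    return (xs * w).sum(dim=1) + b    # row-wise dot product plus bias

torch.manual_seed(0)
classifier = nn.Linear(in_features=32, out_features=4096)
x = torch.randn(3, 32)

# Hypothetical ragged admissible labels, one list per sample
labels = [[7, 100, 4095], [3, 42], [512]]
counts = torch.tensor([len(l) for l in labels])
sample_idx = torch.repeat_interleave(torch.arange(len(labels)), counts)
class_idx = torch.tensor([k for l in labels for k in l])

sparse_out = admissible_logits(classifier, x, sample_idx, class_idx)

# Sanity check against the dense computation
dense_out = classifier(x)[sample_idx, class_idx]
```

The result is a flat tensor with one logit per admissible pair; counts can be used to split it back per sample (for example with torch.split) if a ragged layout is needed.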