Masking per-example gradients within a minibatch in a Linear layer

I’d like some advice from this forum.

Here’s my problem: I want to selectively mask (set to zero) certain elements of the gradient of a dense linear (matmul + bias) forward operator. Which elements get masked is determined by a low-cardinality categorical variable attached to each training example; a vector of these categories is passed in alongside the inputs.

This means that somewhere in the backward pass I need access to the per-example gradients, so I can mask them individually and then sum over the minibatch. A hook that post-multiplies the gradient isn’t enough, because the masking has to be selected per example, before the sum.
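(Just to illustrate what I mean: a gradient hook on the weight only ever sees the already batch-summed gradient, so the per-example information is gone by that point.)

```python
import torch

lin = torch.nn.Linear(16, 8)
# The hook fires with the gradient already summed over the batch,
# shape (8, 16) -- there is no per-example dimension left to mask.
lin.weight.register_hook(lambda g: print(g.shape))

out = lin(torch.randn(32, 16))
out.sum().backward()  # prints torch.Size([8, 16])
```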

I think I may need to write a custom backward() operator, since the backward I need is not exactly the gradient of the forward operator.
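If I write out the standard Linear backward, the weight gradient is already a sum of per-example outer products:

$$\frac{\partial L}{\partial W} \;=\; \sum_{b} g_b\, x_b^{\top}, \qquad \frac{\partial L}{\partial b} \;=\; \sum_{b} g_b,$$

where $x_b$ is the input and $g_b = \partial L / \partial y_b$ the output gradient for example $b$. The mask would have to be applied to each $g_b\, x_b^{\top}$ term before the sum collapses the batch dimension.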

An alternative source-code implementation of nn.Linear with an explicit backward pass, where the sum over the minibatch is written out (so I can selectively mask right before that step), would be helpful. I haven’t been able to find one; it’s a difficult thing to search for, since most results match the keywords but lack these particular features. I’ve seen libraries that extract per-example gradients post hoc, but that isn’t enough: the gradients need to be modified during backprop.
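For reference, this is roughly what I have in mind, as an untested sketch. The `mask_table` layout (one 0/1 pattern over the weight matrix per category) and all the names are just my own assumptions:

```python
import torch

class MaskedLinearFunction(torch.autograd.Function):
    """Linear forward; backward masks per-example weight gradients
    by category before summing over the batch."""

    @staticmethod
    def forward(ctx, input, weight, bias, cat, mask_table):
        # input: (B, in), weight: (out, in), bias: (out,)
        # cat: (B,) LongTensor of category indices
        # mask_table: (num_categories, out, in) 0/1 tensor
        ctx.save_for_backward(input, weight, cat, mask_table)
        return input @ weight.t() + bias

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, cat, mask_table = ctx.saved_tensors
        # Per-example weight gradients: (B, out, in) outer products.
        per_example_grad_w = torch.einsum('bo,bi->boi', grad_output, input)
        # Look up each example's mask by its category, zero out the
        # unwanted entries, then sum over the batch as usual.
        grad_weight = (per_example_grad_w * mask_table[cat]).sum(dim=0)
        grad_input = grad_output @ weight
        grad_bias = grad_output.sum(dim=0)
        # No gradients for the categorical input or the mask table.
        return grad_input, grad_weight, grad_bias, None, None
```

It would be called as `MaskedLinearFunction.apply(x, weight, bias, cat, mask_table)`. Does something like this look reasonable, or is there a cleaner way?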

The overall goal is to guide the training process with knowledge of the categories; at inference/deployment time the categorical variable is not available, and the forward pass is just a standard matrix multiplication. This is for a real commercial product, by the way.