Masked nn.Linear

I cannot comment on the design, but if the input (not the weight) is masked to 0 in the forward pass, the gradient contribution to the weight entries applied to that masked part should be 0: the backward computation matrix-multiplies the gradient of the output with the input (in the right order, with the right transpositions) to get the weight gradient, so wherever the input is zero, the corresponding weight gradient is zero as well.
If you have NaNs/infs, that won’t work, though, and you need to deal with those separately (0 times NaN is NaN, so the zeros in the masked input don’t protect you).
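A quick sketch to check this (the mask pattern and sizes are just made up for illustration): masking two input features to zero before the linear layer leaves the corresponding columns of `weight.grad` at exactly zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

lin = nn.Linear(4, 3)
x = torch.randn(2, 4)

# Hypothetical mask: zero out input features 1 and 3 before the forward pass.
mask = torch.tensor([1.0, 0.0, 1.0, 0.0])
lin(x * mask).sum().backward()

# The weight gradient is grad_out.T @ input, so columns of the weight
# that only ever saw zeros in the input get an exactly-zero gradient.
print(lin.weight.grad[:, 1])  # zeros
print(lin.weight.grad[:, 3])  # zeros
print(lin.weight.grad[:, 0])  # generally nonzero
```

Note that the bias gradient is unaffected by the mask, since it does not multiply the input.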

Best regards

Thomas