You will have to reconstruct the optimizer if you change the shapes of the parameters; there's no real workaround.
Alternatively, you can simply slice self.W in the forward
function and keep its original shape intact, like:
out = torch.mm(input, self.W[inverse_mask, :]) # just an example
Here you don't repackage or change self.W; you just enforce a mask at runtime.
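A minimal sketch of the runtime-masking idea (the module, shapes, and the `inverse_mask` argument here are illustrative, not from your code). The parameter keeps its full shape, so the optimizer state registered for it stays valid:

```python
import torch
import torch.nn as nn

class MaskedLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Full-size parameter; never resized, so optimizer state stays valid.
        self.W = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x, inverse_mask):
        # Slice rows of W at runtime instead of repackaging the parameter.
        return torch.mm(x, self.W[inverse_mask, :].t())

layer = MaskedLinear(4, 6)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

inverse_mask = torch.tensor([True, False, True, True, False, True])
x = torch.randn(2, 4)
out = layer(x, inverse_mask)   # only the masked-in rows of W are used
out.sum().backward()
opt.step()                     # works: self.W still has its original shape
```

Rows of self.W that the mask excludes simply receive zero gradient, so the optimizer leaves them untouched.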