Block Gradient Flow For Specific Dimensions

Suppose I have a VAE with a latent dimension of size 16, and I want to block the gradient flow through a certain subset of those 16 dimensions. How can I do that? During the forward pass I have access to the input as well as a mask tensor that says which dimensions' gradient flow should be blocked. Manually zeroing out the gradients for these dimensions after calling `backward()` is insufficient, because by that point the gradient has already influenced all of the layers before it. As far as I know, setting `requires_grad=False` would solve this problem if I wanted to block the gradient for the whole tensor, but that is not the case here: I only want to block certain dimensions.
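
To make the setup concrete, here is a minimal sketch of what I mean (the encoder layer, shapes, and mask are just placeholders for my actual model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder encoder producing the 16-dim latent.
encoder = nn.Linear(1024, 16)
x = torch.randn(4, 1024)

# 1 = gradient may flow, 0 = gradient should be blocked for that dimension.
mask = torch.zeros(16)
mask[:8] = 1.0

z = encoder(x)               # shape (4, 16)
loss = z.pow(2).sum()
loss.backward()

# Zeroing per-dimension gradients here is too late: encoder.weight.grad
# (shape (16, 1024)) has already been filled for all 16 output dimensions,
# and any layers before the encoder have already received the unmasked
# gradient as well.
print(encoder.weight.grad.abs().sum(dim=1))  # nonzero for every dimension
```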

I was thinking of creating 16 linear layers of size (1024, 1) instead of one single linear layer of size (1024, 16) to solve this problem. I would then dynamically set the `requires_grad` flag depending on the input's mask tensor, as in the sketch below. Is this the right way to do it, or are there better solutions?
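
Roughly, the per-dimension idea would look like this (untested sketch; the class and argument names are hypothetical):

```python
import torch
import torch.nn as nn

class PerDimHead(nn.Module):
    """Hypothetical sketch: one Linear(in_features, 1) per latent dimension,
    so that requires_grad can be toggled per dimension via the mask."""

    def __init__(self, in_features=1024, latent_dim=16):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Linear(in_features, 1) for _ in range(latent_dim)
        )

    def forward(self, x, mask):
        # mask: 1-D tensor of length latent_dim; 0 = block that dimension.
        outs = []
        for head, m in zip(self.heads, mask):
            for p in head.parameters():
                p.requires_grad_(bool(m))
            outs.append(head(x))
        return torch.cat(outs, dim=-1)  # shape (batch, latent_dim)

head = PerDimHead()
x = torch.randn(4, 1024)
mask = torch.tensor([1.0] * 8 + [0.0] * 8)
z = head(x, mask)  # (4, 16)
```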