How to access gradients wrt each element in latent space vector z?

Hello, I’m building a VAE and I’m trying to compute the gradients of my encoder network activations with respect to each element of the latent vector z. For each image in the batch, my encoder outputs a mean and a standard deviation of size [1, latent_size], which I then reparameterize into a latent vector z. I’ve implemented hook functions that save the gradients for each layer, and these naturally have the feature dimension of that particular layer. What I don’t know is how to get from these to the gradients for each individual element of z. I’ve looked at many repositories online, but they always compute the gradients with respect to the full vector z, so I’m a bit lost on how to get where I want.
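A minimal sketch of one common approach: backpropagate each latent element z[:, i] on its own with torch.autograd.grad(..., retain_graph=True). This gives the gradient of z_i with respect to a chosen intermediate activation h, one latent element at a time. The Encoder class, its layer sizes, and the per_element_grads helper are hypothetical stand-ins for the real model. Summing z[:, i] over the batch assumes each sample’s z depends only on its own activations, so it does not hold with cross-sample layers such as BatchNorm in training mode. The same idea applies to existing hooks: call backward on one z element at a time, and the hooks will then capture per-element gradients.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Hypothetical toy VAE encoder; sizes are illustrative assumptions."""

    def __init__(self, in_dim=16, hidden=8, latent_size=4):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc_mu = nn.Linear(hidden, latent_size)
        self.fc_logvar = nn.Linear(hidden, latent_size)

    def forward(self, x):
        h = torch.relu(self.fc1(x))      # intermediate activation of interest
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std               # reparameterization trick
        return z, h


def per_element_grads(z, h):
    """Return a list whose entry i is d z[:, i] / d h, shape [batch, hidden]."""
    grads = []
    for i in range(z.shape[1]):
        # Summing over the batch is fine because sample b's z_i depends only
        # on sample b's activations; retain_graph lets us backprop repeatedly.
        (g,) = torch.autograd.grad(z[:, i].sum(), h, retain_graph=True)
        grads.append(g)
    return grads


torch.manual_seed(0)
enc = Encoder()
x = torch.randn(3, 16)                   # batch of 3 flattened "images"
z, h = enc(x)
grads = per_element_grads(z, h)
print(len(grads), grads[0].shape)        # 4 torch.Size([3, 8])

# Stacking gives a per-sample Jacobian of shape [batch, latent_size, hidden].
jac = torch.stack(grads, dim=1)
```

The loop costs one backward pass per latent element, which is usually cheap for small latent sizes. For larger ones, torch.autograd.functional.jacobian is an alternative worth looking at.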

Could someone help me out? I’m quite new to PyTorch, so if anything is unclear or information is missing, please let me know!