Calculating gradients for previous layers with only intermediate gradient input?

I’m trying to solve a problem that is maybe a bit unusual, so I’ve had a lot of difficulty reaching a solution. My model consists of two components: the first is made up of convolutional layers that produce a 3D block of output, and the second uses this output as the weights of its linear layers. The loss is then calculated from the final output of the whole model. I realize how strange this is, but it would be long-winded to explain why I believe the model needs to be set up like this.
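A minimal sketch of the setup described above, under some assumptions not stated in the post: the linear layers are square (`dim` x `dim`), so each channel of the conv output can serve directly as one weight matrix, and the conv part is fed a fixed random input. The class name `ConvToLinear`, the `seed` buffer, and all sizes are hypothetical. Because the conv output is passed straight into `torch.nn.functional.linear` instead of being copied into an `nn.Linear` parameter, autograd in this sketch routes the loss gradient back into the conv layers with no manual step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvToLinear(nn.Module):
    """Hypothetical sketch: conv layers emit a 3D block whose slices act as linear-layer weights."""

    def __init__(self, dim: int = 8, num_linear: int = 2):
        super().__init__()
        # Conv part: maps a (1, 1, dim, dim) input to a (1, num_linear, dim, dim) output.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(4, num_linear, kernel_size=3, padding=1),
        )
        # Fixed input to the conv part (an assumption made only for this sketch).
        self.register_buffer("seed", torch.randn(1, 1, dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        block = self.conv(self.seed)[0]  # the 3D block, shape (num_linear, dim, dim)
        for weight in block:             # each slice is one (dim, dim) weight matrix
            x = torch.relu(F.linear(x, weight))
        return x


model = ConvToLinear()
x = torch.randn(5, 8)
loss = model(x).pow(2).mean()
loss.backward()  # gradients reach the conv parameters through the generated weights
print(model.conv[0].weight.grad.shape)  # torch.Size([4, 1, 3, 3])
```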

On the forward pass, things work all right and I get a sensible output. On the backward pass, I can calculate the gradients with respect to the linear layers’ weights, but I then need to repackage them (essentially just reshaping and concatenation) so that they can be propagated back to the convolutional layers. So what I believe I need is a way to calculate the gradients of the convolutional layers given only these intermediate, repackaged gradients.
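If the two components really have to be run as separate stages, one way to push a repackaged gradient into the convolutional part is `Tensor.backward(gradient=...)`, which computes a vector-Jacobian product through whatever graph produced that tensor. The sketch below reuses the hypothetical square-weight setup (all names and sizes are illustrative): it detaches the conv output to mimic a separate linear stage, stacks the linear-weight gradients back into the block’s shape, feeds that into the conv graph, and checks the result against ordinary end-to-end backpropagation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
dim, num_linear = 8, 2  # hypothetical sizes

conv = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, num_linear, kernel_size=3, padding=1),
)
seed = torch.randn(1, 1, dim, dim)  # hypothetical fixed conv input
x = torch.randn(5, dim)

# Stage 1: conv forward; keep its graph so it can be backpropagated into later.
block = conv(seed)[0]  # shape (num_linear, dim, dim)

# Stage 2: treat each slice as an independent leaf weight, as if the linear part were separate.
weights = [w.detach().requires_grad_() for w in block]
h = x
for w in weights:
    h = torch.relu(F.linear(h, w))
h.pow(2).mean().backward()  # fills w.grad for the linear weights only

# Repackage: stack the per-layer gradients back into the block's shape.
grad_block = torch.stack([w.grad for w in weights])

# Propagate the intermediate gradient into the conv layers (a vector-Jacobian product).
block.backward(gradient=grad_block)
manual = [p.grad.clone() for p in conv.parameters()]

# Reference: ordinary end-to-end backprop through the same computation.
conv.zero_grad()
h = x
for w in conv(seed)[0]:
    h = torch.relu(F.linear(h, w))
h.pow(2).mean().backward()
print(all(torch.allclose(a, p.grad) for a, p in zip(manual, conv.parameters())))  # True
```

`torch.autograd.backward(block, grad_tensors=grad_block)` is the functional equivalent, and `torch.autograd.grad(block, list(conv.parameters()), grad_outputs=grad_block)` returns the gradients instead of accumulating them into `.grad`. The main requirement in this sketch is that the graph behind `block` has not already been freed by an earlier backward call when the intermediate gradient is pushed through it.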

My first question, I suppose, is whether this is even possible. I’ve gone back and forth on it, but I don’t see why it wouldn’t be possible mathematically. If it is, my second question is whether PyTorch already offers a concise way to do it, or whether I would need to implement everything myself.
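On the mathematical side, this is an application of the chain rule. If the conv part maps its parameters $\theta$ to the block $W = f(\theta)$ and the rest of the model produces the loss $L$ from $W$, then the gradient with respect to $\theta$ depends on the rest of the model only through $\partial L / \partial W$. Assuming the repackaging is a pure reshape/concatenation, so that each gradient entry lines up one-to-one with an entry of the block, the repackaged gradients are exactly that quantity:

```latex
\frac{\partial L}{\partial \theta_j}
  = \sum_{k} \frac{\partial L}{\partial W_k}\,\frac{\partial W_k}{\partial \theta_j}
\qquad\Longleftrightarrow\qquad
\nabla_{\theta} L = J_f(\theta)^{\top}\,\nabla_{W} L
```

Here $k$ runs over the flattened entries of $W$ and $J_f(\theta)$ is the Jacobian of the flattened block with respect to $\theta$; the product $J_f^{\top}\,\nabla_W L$ is the vector-Jacobian product that reverse-mode autodiff computes.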

Thank you.