Suppose I have a layer like layer(x) = MyLinearLayer(x) + nn.Linear{1}(x) + ... + nn.Linear{k}(x), with no biases.
Suppose all nn.Linear{i}.weight tensors have the same shape, and MyLinearLayer represents a proxy of a weight of the very same shape. Because of the linear structure, all the weights in these modules should receive the very same gradient. I compute this gradient in a custom MyLinearLayer.backward. What would be the most efficient way to re-use this gradient for all the modules nn.Linear{i}? Of course, I could write a custom layer.backward to do that, but are there other, simpler ways?
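For concreteness, here is roughly the structure I have in mind (a minimal sketch; MyLinearLayer here is just a stand-in for the real proxy with the custom backward):

```python
import torch
import torch.nn as nn

class MyLinearLayer(nn.Module):
    """Stand-in for the proxy layer; the real one computes its wgrad in a custom backward."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        return x @ self.weight.T

class SumOfLinears(nn.Module):
    def __init__(self, in_features, out_features, k):
        super().__init__()
        self.proxy = MyLinearLayer(in_features, out_features)
        self.linears = nn.ModuleList(
            nn.Linear(in_features, out_features, bias=False) for _ in range(k)
        )

    def forward(self, x):
        # layer(x) = MyLinearLayer(x) + nn.Linear{1}(x) + ... + nn.Linear{k}(x)
        return self.proxy(x) + sum(lin(x) for lin in self.linears)
```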
I don’t understand the question completely. The incoming dgrad will be the same, as you already explained, and will be passed to the backward function of each module. The wgrad computation is specific to each module, and you won’t be able to reuse anything since it depends on the forward activation. The outgoing dgrad computation uses the parameters, so I’m also unsure what you want to reuse there.
Could you clarify your use case a bit more?
Let me clarify. Assume Linear: (x, W) → Wx =: y, so y = Wx. Then the JVP is dy = dW x + W dx. Let gy be the in-flowing grad; then for the VJP,
Tr(gy^T dy) = Tr(gy^T dW x) + Tr(gy^T W dx) = Tr(x gy^T dW) + [the sensitivity for x, which is irrelevant here],
so the VJP for W, gW, equals gy x^T. It does not involve W itself, only the input x and gy. Since the input x is shared among the linear layers, and so is gy (the branch outputs are simply summed), the grad for each weight should be the same.
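A quick numerical sanity check of that claim (toy sizes, plain nn.Linear modules standing in for all the branches):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin1 = nn.Linear(4, 3, bias=False)
lin2 = nn.Linear(4, 3, bias=False)

x = torch.randn(8, 4)
y = lin1(x) + lin2(x)   # shared input, summed outputs -> same gy for both branches
y.sum().backward()      # any scalar loss

# gW = gy x^T for each branch, and both gy and x are shared:
print(torch.allclose(lin1.weight.grad, lin2.weight.grad))  # True
```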
And the question is: what would be the easiest way to tell PyTorch that I have this structure? One solution would be to write a custom Module that represents a linear combination of other Linear modules. That, however, might be problematic in my use case, because I have very limited ability to modify the model’s graph due to the rather rigid structure of subsequent non-trivial external optimizers with all sorts of backward hooks… Basically, I can only replace an nn.Linear with another nn.Linear. And I am not sure whether direct manipulation of the .grad attribute is safe. A custom nn.Linear{i} subclass might work as well, with some API for grad swapping and a custom backward. But maybe there are other ways that are not subclass-based?
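For reference, this is the kind of direct .grad manipulation I had in mind: freeze the duplicated weights so autograd skips their wgrad, then copy the proxy’s gradient in by hand before the optimizer step. Just a toy sketch; I’m not sure how safe it is next to the external optimizer’s backward hooks, which is exactly my question:

```python
import torch
import torch.nn as nn

# Toy stand-in: an ordinary nn.Linear plays the role of the proxy here.
proxy = nn.Linear(4, 3, bias=False)
copies = nn.ModuleList(nn.Linear(4, 3, bias=False) for _ in range(2))
for lin in copies:
    lin.weight.requires_grad_(False)   # autograd skips their wgrad entirely

opt = torch.optim.SGD([proxy.weight] + [lin.weight for lin in copies], lr=0.1)

x = torch.randn(8, 4)
y = proxy(x) + sum(lin(x) for lin in copies)
y.sum().backward()                     # only proxy.weight.grad gets populated

# Re-use the proxy's gradient for the frozen copies before stepping.
for lin in copies:
    lin.weight.grad = proxy.weight.grad.clone()
opt.step()
```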