Edit: clarification:
The question is how to compute `dout/da` given only `dout/dp` (i.e. `p.grad`).

That is, we want `dout/da`, but we only have `dout/dp` (we do not have `out`), so we can’t compute it straightforwardly.

Mathematically:
`dout/da = dout/dp * dp/da`
All we need is `dp/da`, and we are done.

Note 1: `a` is arbitrary; it can be an activation or another parameter.
Note 2: we can create the computational graph only from `a -> p`.
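For intuition, the chain rule above can be checked numerically on a toy scalar graph (a hypothetical `p = 3*a`, `out = p**2`, not from the thread):

```python
import torch

# Toy graph a -> p -> out, with p = 3*a and out = p**2.
a = torch.tensor(2.0, requires_grad=True)
p = 3 * a
p.retain_grad()           # keep dout/dp on the non-leaf tensor p
out = p ** 2
out.backward()

# Chain rule: dout/da = dout/dp * dp/da = (2*p) * 3
dout_dp = p.grad          # 2*p = 12
dp_da = 3.0
print(dout_dp * dp_da)    # tensor(36.), same as a.grad
```

Here `dout/dp * dp/da` reproduces `a.grad` exactly, which is the identity the question relies on.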

Original post:

I know of two main ways to do backprop in PyTorch:

1. from loss: `loss.backward()`
2. from an activation and a gradient of the same size (or broadcastable?): `torch.autograd.backward(x, g)`
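The two entry points above can be demonstrated side by side on a small tensor (a minimal sketch, not from the original post):

```python
import torch

# 1) Scalar loss: backward() seeds the gradient with 1.0 automatically.
x = torch.tensor([2.0, 3.0], requires_grad=True)
loss = (x * x).sum()
loss.backward()
print(x.grad)  # tensor([4., 6.])

# 2) Non-scalar tensor: supply an explicit gradient of the same shape.
y = torch.tensor([2.0, 3.0], requires_grad=True)
out = y * y
g = torch.ones_like(out)  # plays the role of d(loss)/d(out)
torch.autograd.backward(out, g)
print(y.grad)  # tensor([4., 6.])
```

With `g = torch.ones_like(out)` the two calls produce identical gradients, since summing a tensor and seeding with ones are the same vector-Jacobian product.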

Now, let’s say we have some `p = torch.nn.Parameter()` for which some gradient (`p.grad`) has already been accumulated, and we want to continue the back-propagation from it.
Let’s say we can construct the entire computational graph up to `p`, so that we now have it as some activation: `a = f(input)`.

Note that mathematically it’s a very simple chain rule; I’m just not sure how to do it in PyTorch (at least from Python).

Any suggestions?
Here is some pseudo code for the problem:

``````input = torch.tensor(..., requires_grad=True)
a = f(input)
p.data = another_value_we_want
``````

How do we calculate `dout/da` in PyTorch?

assuming
`out = p ? f(input)`, where `?` is some unknown function.
(The shape of `a` is not the same as the shape of `p.grad`.)…

You can get `dout/da` by calling `a.retain_grad()`:

``````import torch

input = torch.ones(1)
f = torch.nn.Linear(1, 1, bias=False)
f.weight.data.copy_(torch.ones(1))
a = 2 * f(input)
a.retain_grad()   # keep the gradient on the non-leaf tensor a
out = a * a / 2
out.backward()
print(a.grad)     # dout/da = a = tensor([2.])
``````

That did not answer the question.

It seems you asked how to calculate `dout/da`, which @Yaroslav_Bulatov tried to answer.

Sure, I’ll clarify: the question is how to compute `dout/da` given only `dout/dp` (i.e. `p.grad`).

That is, we want `dout/da`, but we only have `dout/dp` (we do not have `out`), so we can’t compute it straightforwardly.

Mathematically:
`dout/da = dout/dp * dp/da`
All we need is `dp/da`, and we are done.

Note 1: `a` is arbitrary; it can be an activation or another parameter.
Note 2: we can create the computational graph only from `a -> p`.

Maybe something like this?

``````p.retain_grad()
out.backward()
p.backward(dout_dp)
``````

This is closer, but it did not work.
The reason is that `dout_dp` should be detached from the computation graph of `out` when connected to `a`, and vice versa. (Mathematically it’s possible; we have all the information.)
Let me illustrate further with a code example:

``````import torch

m = 2
n = 3
batch = 4

a = torch.ones(batch, m)
layer = torch.nn.Linear(m, n, bias=False)
layer.weight.data.copy_(torch.ones(n, m))

dout_dp = torch.ones(n, m) * 5  # A **given** dout_dp, detached from `out` and `a`.
v2 = layer(a)   # reconstruct a computation graph a->p

# Problem: `p.grad` is detached from the reconstructed computation graph of `a`,
# so a backward call from here is not going to work.
``````
I think that doing `p.backward(p.grad)` is a step forward and probably part of the future solution, but we are lacking a connection to the gradient, which must be added artificially somehow:
How do we tell PyTorch that `a` has a connection to `p.grad` (and have it figure that connection out automatically from the new computation graph `a -> p`)?
I see you used `p.retain_grad()` there; I guess this saves the “connection” from `a -> p -> out`, which is what we want to avoid.
From what I understand, it seems like autograd does not “use” parameter gradients (like `p.grad`) directly in its computation, but only accumulates into them (via `AccumulateGrad` nodes).
Note we can’t really restore per-sample information from accumulated gradients when `batch > 1`, because we lose all the batch-related info,
so the question should be for `batch = 1`.
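For the `batch = 1` case, one hedged sketch of a solution uses `torch.autograd.grad` with its `grad_outputs` argument, which computes exactly the vector-Jacobian product `dout/dp * dp/da` over the reconstructed sub-graph `a -> p`, with no `out` in sight. The generator `proj` standing in for `a -> p` below is a made-up example, not from the thread:

```python
import torch

m, n = 2, 3
batch = 1  # per the note above, batch must be 1

# Hypothetical sub-graph a -> p: the parameter value is produced from `a`.
a = torch.ones(batch, m, requires_grad=True)
proj = torch.nn.Linear(m, n * m, bias=False)  # assumed generator g: a -> p

# Ground truth for comparison: run the full graph a -> p -> out once.
p = proj(a).view(n, m)
out = (p * p).sum() / 2
dout_da_true = torch.autograd.grad(out, a, retain_graph=True)[0]

# What the thread assumes we actually have: dout/dp, detached from `out`.
dout_dp = torch.autograd.grad(out, p, retain_graph=True)[0].detach()

# Reconstruct only the sub-graph a -> p and apply the chain rule as a
# vector-Jacobian product: dout/da = dout/dp * dp/da.
p2 = proj(a).view(n, m)
dout_da = torch.autograd.grad(p2, a, grad_outputs=dout_dp)[0]

print(torch.allclose(dout_da, dout_da_true))  # True
```

The key point is that `grad_outputs=dout_dp` seeds the backward pass through `p2` with the given gradient, so `out` itself is never needed; only the sub-graph `a -> p` and `dout/dp` are.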