Different forward and backward weights

I have a use case where I need to use a different set of weights to compute the backward pass. An instance of where this is used is this work https://www.nature.com/articles/ncomms13276 and numerous follow-up works, or this https://www.siarez.com/projects/random-backpropogation

Is there any way to do this with the current API? If not, what would be the best angle of attack?

Thanks

You could try to load the backward state_dict before executing the backward operation, but it's quite a hacky approach:

import copy

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(1, 1, bias=False),
    nn.Linear(1, 1, bias=False)
)

# set the forward weights
with torch.no_grad():
    model[0].weight.fill_(1.)
    model[1].weight.fill_(1.)

sd_forward = copy.deepcopy(model.state_dict())

# create a separate set of backward weights
sd_backward = copy.deepcopy(sd_forward)
sd_backward['0.weight'].fill_(10.)
sd_backward['1.weight'].fill_(10.)


# one train step: forward with the forward weights,
# then swap in the backward weights before calling backward
output = model(torch.ones(1, 1))
model.load_state_dict(sd_backward)
output.mean().backward()
print(model[0].weight.grad)  # tensor([[10.]])
print(model[1].weight.grad)  # tensor([[1.]])

Also note that the last gradient is wrong, since the output was calculated using the old (forward) weights.
Would this approach work for you or did I misunderstand your question?
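For completeness, if you wanted to run several iterations with this hack, you would presumably also have to restore the forward weights before each forward pass. A rough sketch reusing sd_forward and sd_backward from above (the optimizer handling is left out and would need extra care, since sd_forward itself has to be kept in sync with the updated parameters):

for step in range(3):  # illustrative number of steps
    model.zero_grad()
    model.load_state_dict(sd_forward)    # forward pass with the forward weights
    output = model(torch.ones(1, 1))
    model.load_state_dict(sd_backward)   # swap in the backward weights before backward
    output.mean().backward()
    # optimizer.step() would go here; sd_forward would then need to be
    # refreshed from the updated parameters to stay in sync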

@ptrblck Thanks for your reply. I actually figured out the right way to do this: I wrote my own autograd Function, similar to the example here: https://pytorch.org/docs/stable/notes/extending.html
Then, inside its backward method, I used a different set of weights to compute grad_input. The backward now looks like this:

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, b_weights, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            # use the fixed backward weights instead of the forward weights
            grad_input = grad_output.mm(b_weights)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, torch.zeros_like(b_weights), grad_bias

b_weights are the fixed backward weights; they are passed to the forward function along with the regular weights and saved in ctx.
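For context, a minimal sketch of what the matching forward method could look like, following the linear-function example from the extending note (the exact signature and argument order here are my assumption, chosen to match the backward above):

    @staticmethod
    def forward(ctx, input, weight, b_weights, bias=None):
        # save the backward weights alongside everything else backward needs
        ctx.save_for_backward(input, weight, b_weights, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

In a module wrapper, b_weights could be registered as a non-trainable buffer holding the fixed random feedback weights.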


That looks like a good approach! Thanks for sharing it!