Hi everyone,
I’m currently working on a project that requires the ability to update the weights of a model while maintaining the gradient.
For example: Model A outputs a vector that is used as an update for Model B’s weights (in a differentiable fashion; for more context see: https://mediatum.ub.tum.de/doc/814768/file.pdf). Model B then makes a prediction, and the prediction error is backpropagated to Model A.
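To make this concrete, here is a minimal sketch (the module sizes and names are made up purely for illustration): a small "Model A" produces an update, the update is folded into a plain-tensor weight of "Model B", and the loss on B’s prediction backpropagates into A.

import torch
import torch.nn as nn

model_a = nn.Linear(3, 4)                      # "Model A": generates the update
update = model_a(torch.randn(3)).view(2, 2)    # update built from A's output

weight_b = torch.zeros(2, 2)                   # plain-tensor weight of "Model B"
weight_b = weight_b + update                   # functional update keeps the graph
# (wrapping the update in nn.Parameter(...) here would create a new leaf tensor and cut the graph)
loss = (torch.randn(2) @ weight_b).sum()       # "Model B" prediction and a dummy loss
loss.backward()
print(model_a.weight.grad is not None)         # True: the error reached Model A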
This is possible when Model B’s weights are plain torch.Tensor objects, as they can be updated while keeping the computation graph intact, but the gradient breaks when the weights are nn.Parameter objects, even when using .clone(). The problem is that all of the pre-implemented nn.Module classes store their weights as nn.Parameter. To work around this for a linear layer, we can simply recreate it with the weight and bias stored as plain torch.Tensor objects:
class Linear(nn.Module):
    def __init__(self, in_features, out_features, bias=False):
        super(Linear, self).__init__()
        # Plain tensors (not nn.Parameter), initialised uniformly in (-1, 1]
        self.weight = (-1 - 1) * torch.rand([in_features, out_features]) + 1
        if bias:
            self.bias = (-1 - 1) * torch.rand([1, out_features]) + 1
        else:
            self.bias = None
        self.in_features = in_features
        self.out_features = out_features
        # Used in Brute Force
        self.no_fast_weights = in_features * out_features if not bias else (in_features * out_features) + out_features
        # Used in FROM/TO Architecture
        self.no_from = in_features
        self.no_to = out_features

    def update_weights(self, update, idx, update_func):
        # `update` is a flat column of fast weights; this layer's slice starts at `idx`
        # and `update_func` combines the current values with their updates.
        weight_idx = idx + self.no_fast_weights
        if self.bias is not None:
            weight_idx -= self.out_features
            bias_update = update[weight_idx: weight_idx + self.out_features, :].reshape(self.bias.shape)
            self.bias = update_func(self.bias, bias_update)
        weight_update = update[idx:weight_idx, :].reshape(self.weight.shape)
        self.weight = update_func(self.weight, weight_update)

    def forward(self, x):
        ret = torch.matmul(x, self.weight)
        if self.bias is not None:
            return ret + self.bias
        return ret
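For reference, a usage sketch of this layer (assuming an additive update_func and a hypothetical "slow" network producing the flat update column):

slow_net = nn.Linear(8, 4 * 3)                            # Model A: emits the 12 fast weights
fast_layer = Linear(4, 3)                                 # the Linear defined above (no bias)

update = slow_net(torch.randn(8)).view(-1, 1)             # flat column, shape [no_fast_weights, 1]
fast_layer.update_weights(update, 0, lambda w, u: w + u)  # additive update_func

loss = fast_layer(torch.randn(2, 4)).sum()
loss.backward()
print(slow_net.weight.grad is not None)                   # True: the gradient flows back to Model A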
Now, however, I’m ready to move on to a more advanced model: updating the weights of an RNN. This is proving to be challenging without re-implementing the RNN’s forward pass, which I would like to avoid, since any re-implementation is likely to be slower than the original and more prone to bugs, and the forward pass itself doesn’t actually change; only the objects behind it do.
What I want to do is build some structure around the base PyTorch implementations, e.g.
class RNNUpdatable(nn.RNN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # etc.
        self.hidden_weight = torch.Tensor(...)  # hidden sizes, etc.
        # etc.

    def weight_update(self, update):
        self.hidden_weight = update
        # etc.
This would leave the forward pass up to the original PyTorch implementation.
I can’t simply do this because, when the parent class is initialized, it automatically registers the RNN’s weights as nn.Parameter objects, and this seems impossible to change. I also can’t create a new attribute, because it wouldn’t be used in the forward pass, and I would again have to reimplement everything, which is exactly what I’m trying to avoid.
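The blocker is easy to reproduce (the exact error message may differ between versions):

rnn = nn.RNN(input_size=4, hidden_size=3)
print(type(rnn.weight_hh_l0))                             # <class 'torch.nn.parameter.Parameter'>
# nn.Module.__setattr__ refuses a plain tensor for a name registered as a parameter,
# so the graph-preserving reassignment used in the Linear above is rejected here:
rnn.weight_hh_l0 = torch.zeros_like(rnn.weight_hh_l0)     # raises TypeError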
Is there any way around this issue in PyTorch?