Any way to update nn.Parameter while keeping gradient / Change nn.Parameter to torch.Tensor

Hi everyone,

I’m currently working on a project that requires the ability to update the weights of a model while maintaining the gradient.

For example,

Model A outputs a vector that is used as an update for Model B’s weights (in a differentiable fashion; for more context see: https://mediatum.ub.tum.de/doc/814768/file.pdf). Model B then makes a prediction, and the error of that prediction is backpropagated all the way to Model A.

This is possible when the weights of Model B are plain torch.Tensor objects, since they can be updated while maintaining the gradient, but the gradient breaks when using nn.Parameter, even when using .clone_(). The problem is that all of the pre-implemented nn.Module objects store their weights as nn.Parameter. To fix this in the case of a linear layer, we can simply recreate it with the weight and bias terms stored as torch.Tensor:

import torch
import torch.nn as nn


class Linear(nn.Module):
	def __init__(self, in_features, out_features, bias=False):
		super(Linear, self).__init__()
		# Plain tensors (not nn.Parameter), initialised uniformly in (-1, 1]
		self.weight = (-1 - 1) * torch.rand([in_features, out_features]) + 1
		if bias:
			self.bias = (-1 - 1) * torch.rand([1, out_features]) + 1
		else:
			self.bias = None

		self.in_features = in_features
		self.out_features = out_features

		# Used in Brute Force: total number of fast weights in this layer
		self.no_fast_weights = in_features * out_features if not bias else (in_features * out_features) + out_features

		# Used in FROM/TO Architecture
		self.no_from = in_features
		self.no_to = out_features

	def update_weights(self, update, idx, update_func):
		# `update` is a flat column vector; this layer's slice starts at `idx`.
		weight_idx = idx + self.no_fast_weights
		if self.bias is not None:
			weight_idx -= self.out_features
			bias_update = update[weight_idx: weight_idx + self.out_features, :].reshape(self.bias.shape)
			self.bias = update_func(self.bias, bias_update)
		weight_update = update[idx:weight_idx, :].reshape(self.weight.shape)
		self.weight = update_func(self.weight, weight_update)

	def forward(self, x):
		ret = torch.matmul(x, self.weight)
		if self.bias is not None:
			return ret + self.bias
		return ret
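
As a quick sanity check (my own sketch, not from the original post; the additive update_func here is just a hypothetical rule), continuing from the class above, gradients from a loss on this layer’s output flow back into the model that produced the update:

model_a = nn.Linear(4, 3 * 2)                        # produces the flat update vector
layer_b = Linear(3, 2)                               # the Linear defined above (no bias)

update = model_a(torch.randn(1, 4)).reshape(-1, 1)   # shape [no_fast_weights, 1]
layer_b.update_weights(update, idx=0, update_func=lambda w, u: w + u)

loss = layer_b(torch.randn(5, 3)).sum()
loss.backward()
print(model_a.weight.grad is not None)               # True: the graph reaches Model A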

Now, however, I’m ready to move on to a more advanced model, in particular updating the weights of an RNN. This is proving to be challenging without re-implementing the forward pass of the RNN, which I would like to avoid: any re-implementation is likely to be slower than the original and more prone to bugs, and the forward pass doesn’t actually change anyway, only the objects behind it do.

What I want to do is build some structure around the base pytorch implementations, e.g.

class RNNUpdatable(nn.RNN):
	def __init__(self, *args, **kwargs):
		super() etc etc
		self.hidden_weight = torch.Tensor([hidden sizes, etc...])
		etc etc

	def weight_update(self, update):
		self.hidden_weight = update
		etc etc

This would leave the forward pass up to the original PyTorch implementation.

I can’t simply do this, because when the parent class is initialized it automatically registers the RNN’s weights as nn.Parameter objects, and this seems to be impossible to change. I also can’t create a new variable, as it then wouldn’t be used in the forward pass and I would have to reimplement everything, which is the problem I’m trying to avoid.
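
As a minimal illustration of that restriction (my own snippet, not from the original post): nn.Module refuses to let a plain tensor replace an attribute that is already registered as a parameter.

import torch
import torch.nn as nn

rnn = nn.RNN(10, 20)
try:
    rnn.weight_hh_l0 = torch.rand(20, 20)   # plain tensor, not an nn.Parameter
except TypeError as e:
    print(e)  # cannot assign 'torch.FloatTensor' as parameter 'weight_hh_l0' ...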

Is there any way around this issue in PyTorch?

You can use nn.LSTM and turn off the requires_grad flag for its parameters:

rnn = nn.LSTM(100, 100)
for param in rnn.parameters():
    param.requires_grad = False

and they are then left as plain tensors.

Unfortunately, this breaks the gradient computation, which I need to be retained.

You can either update parameter.data in place, which is invisible to autograd, or put something like an optimizer between your models and pass gradients back to Model A using autograd.grad (with the grad_outputs parameter).
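
A rough sketch of that second suggestion (my own interpretation, not code from the reply): apply the update outside the graph via parameter.data, then hand the resulting gradient back to Model A with torch.autograd.grad and grad_outputs.

import torch
import torch.nn as nn

model_a = nn.Linear(10, 6)                       # produces the update
model_b = nn.Linear(3, 2, bias=False)            # 2 * 3 = 6 weights to update

update = model_a(torch.randn(1, 10)).reshape(2, 3)
model_b.weight.data += update.detach()           # in-place update, invisible to autograd

loss = model_b(torch.randn(4, 3)).sum()
loss.backward()                                  # fills model_b.weight.grad

# Since weight_new = weight_old + update, dLoss/dUpdate equals dLoss/dWeight, so we
# feed model_b.weight.grad in as grad_outputs and continue the chain into Model A.
grads_a = torch.autograd.grad(update, list(model_a.parameters()),
                              grad_outputs=model_b.weight.grad)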

Here is what I believe to be the best way to do it:

import torch
import torch.nn as nn
from torch import _VF  # internal, undocumented API that backs nn.RNN's forward pass


class RNN(nn.Module):
	def __init__(self, input_size, hidden_size):
		super(RNN, self).__init__()
		self.hidden_size = hidden_size
		self.hh_w = torch.rand([hidden_size, hidden_size])   # hidden-to-hidden weights
		self.hi_w = torch.rand([hidden_size, input_size])    # input-to-hidden weights
		self.no_fast_weights = hidden_size * hidden_size + hidden_size * input_size

	def update_weights(self, update, idx, update_func):
		end_of_weight_idx = idx + self.no_fast_weights
		end_of_hh_idx = idx + self.hidden_size * self.hidden_size
		hh_update = update[idx:end_of_hh_idx, :].reshape(self.hh_w.shape)
		hi_update = update[end_of_hh_idx:end_of_weight_idx, :].reshape(self.hi_w.shape)
		self.hh_w = update_func(self.hh_w, hh_update)
		self.hi_w = update_func(self.hi_w, hi_update)

	def forward(self, x):
		# x is (seq_len, batch, input_size) since batch_first=False below,
		# so the initial hidden state uses x.shape[1] as the batch dimension.
		hx = torch.zeros(1, x.shape[1], self.hidden_size)
		# Arguments: input, hx, flat weights [w_ih, w_hh], has_biases, num_layers,
		# dropout, train, bidirectional, batch_first
		out, hid = _VF.rnn_tanh(x, hx, [self.hi_w, self.hh_w], False, 1, 0.0, True, False, False)
		return out, hid

This is based on a hypernetwork implementation (GitHub: g1910/HyperNetworks, a PyTorch implementation of HyperNetworks, Ha et al., ICLR 2017), which is more straightforward since it predicts the weights themselves rather than weight updates and uses the built-in functional module, whereas this version uses some of PyTorch’s internals (torch._VF).
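
For completeness, a quick check of the RNN class above (my own sketch, under the assumption that _VF.rnn_tanh is differentiable with respect to the flat weight list, and again using a hypothetical additive update_func):

rnn = RNN(input_size=4, hidden_size=3)
model_a = nn.Linear(8, rnn.no_fast_weights)           # produces the flat update

update = model_a(torch.randn(1, 8)).reshape(-1, 1)
rnn.update_weights(update, idx=0, update_func=lambda w, u: w + u)

out, hid = rnn(torch.randn(5, 2, 4))                  # (seq_len, batch, input_size)
out.sum().backward()
print(model_a.weight.grad is not None)                # True: gradient reaches Model A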