Hello,
I am currently trying to implement a version of the paper Titans: Learning to Memorize at Test Time. Essentially, there is an inner loop that computes something like $M_t = M_{t-1} - \nabla_{M_{t-1}}\,\ell(M_{t-1}Kx - Vx)$, where $\ell$ is the inner (surprise) loss. Here, $M$ is a linear layer, and $K$ and $V$ are both matrices. $M_{t-1}$ is the old state, which a gradient step (computed with backprop) updates to $M_t$. $K$ and $V$ act as hyperparameters of the inner loss, so they must be updated by the outer loop.
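For a linear memory, the inner step can be written out in closed form, which makes the dependence of $M_t$ on $K$ and $V$ explicit. This assumes $\ell(e) = \tfrac{1}{2}\lVert e\rVert_2^2$ (the inner loss is not fixed above) and a unit step size, as in the update written above:

```latex
\begin{aligned}
\ell(M_{t-1}Kx - Vx) &= \tfrac{1}{2}\,\lVert M_{t-1}Kx - Vx \rVert_2^2, \\
\nabla_{M_{t-1}}\,\ell(M_{t-1}Kx - Vx) &= (M_{t-1}Kx - Vx)\,(Kx)^{\top}, \\
M_t &= M_{t-1} - (M_{t-1}Kx - Vx)\,(Kx)^{\top}.
\end{aligned}
```

The correction term depends on both $K$ and $V$, so $\partial M_t/\partial K$ and $\partial M_t/\partial V$ are generally nonzero. Autograd can only use them if $M_t$ is kept as this expression in the graph rather than copied into the layer's weights.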
The outer loop launches the inner loop and then computes $y = M_tQx$, where $Q$ is a matrix. It then computes a standard loss on $y$, backpropagates it, and steps the optimizer. My understanding is that the loss should propagate back through the inner loop, so that $K$ and $V$ are updated to become “better loss hyperparameters”.
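A minimal sketch of the two loops with plain tensors, assuming a squared-error inner loss, a hypothetical inner step size `lr_inner`, and an L1 outer loss as in the test code below. The points that matter are `create_graph=True` and the out-of-place update, which keep $M_t$ connected to $K$ and $V$ in the autograd graph:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 10

# Outer-loop parameters; only the outer loss should produce .grad for these.
K = torch.randn(d, d, requires_grad=True)
V = torch.randn(d, d, requires_grad=True)
Q = torch.randn(d, d, requires_grad=True)

# Memory state M_{t-1}; requires_grad so the inner gradient can be taken w.r.t. it.
M_prev = torch.zeros(d, d, requires_grad=True)
x = torch.randn(d)
lr_inner = 0.1  # hypothetical inner step size

# Inner loop: one gradient step on 0.5 * ||M_{t-1} K x - V x||^2.
inner_loss = 0.5 * (M_prev @ (K @ x) - V @ x).pow(2).sum()
# create_graph=True keeps grad_M itself differentiable w.r.t. K and V.
(grad_M,) = torch.autograd.grad(inner_loss, M_prev, create_graph=True)
M_t = M_prev - lr_inner * grad_M  # out-of-place update: M_t stays in the graph

# Outer loop: read the memory with Q and backprop an ordinary loss.
y = M_t @ (Q @ x)
outer_loss = F.l1_loss(y, x)
outer_loss.backward()

print(K.grad is not None, V.grad is not None, Q.grad is not None)  # prints: True True True
```

Because `torch.autograd.grad` returns the inner gradient instead of accumulating it into `.grad`, `K` and `V` only receive gradient from the outer loss, which matches the requirement that they should not be trained on the inner loss.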
```python
import torch
import torch.nn as nn


class NeuralMemory(nn.Module):
    # __init__ (defining key, value, query, model, lr, surprise_metric, st_past_state) omitted

    def _update_memory(self, grads, eta, alpha):
        """
        Optimizer for the neural memory.
        Updates based on the gradient of the surprise metric.
        M_t = (1-alpha)*M_{t-1} + s_t
        s_t = eta*s_{t-1} - lr*grad

        Args:
            - grads: the gradient of the surprise metric
            - eta: data-dependent momentum
            - alpha: decay factor
        """
        for (pname, param), grad in zip(self.model.named_parameters(), grads):
            past_surprise = self.st_past_state.get_value(pname)
            # surprise_t = eta*surprise_{t-1} - lr*grad (minus: step downhill on the inner loss)
            surprise = eta * past_surprise - self.lr * grad
            self.st_past_state.register_buffer(pname, surprise)
            # Now apply the weights with decay: (1-alpha)*param + surprise
            param.data = (1 - alpha) * param.data + surprise

    def condition(self, x) -> torch.Tensor:
        """
        Condition the model on the input x.

        Returns:
            - surprise: the surprise metric
        """
        self.model.requires_grad_(True)
        # Prepare the grad; the inner loop only updates the model
        k = self.key(x)
        v = self.value(x)
        s_t = self.surprise_metric(self.model(k), v)
        # Compute gradients w.r.t. model params
        grads = torch.autograd.grad(s_t, self.model.parameters(), create_graph=True, retain_graph=True)
        self._update_memory(grads, eta=0.9, alpha=0.1)
        self.model.requires_grad_(False)
        return s_t

    def forward(self, x):
        q = self.query(x)
        return self.model(q)
```
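One possible way to keep this update differentiable, sketched under assumptions: instead of writing the new weights into `param.data` (a write that autograd does not record), keep the updated weights as ordinary graph-connected tensors and run the memory layer through `torch.func.functional_call`. The class name `DiffMemory`, the mean-squared surprise loss, the bias-free key/value/query layers, and the default `lr` are hypothetical stand-ins for the pieces of `NeuralMemory` not shown above:

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class DiffMemory(nn.Module):
    """Hypothetical variant of NeuralMemory whose inner update stays in the autograd graph."""

    def __init__(self, dim: int, lr: float = 0.1):
        super().__init__()
        self.key = nn.Linear(dim, dim, bias=False)
        self.value = nn.Linear(dim, dim, bias=False)
        self.query = nn.Linear(dim, dim, bias=False)
        self.model = nn.Linear(dim, dim)  # the memory M; its initial weights are still learned
        self.lr = lr
        self.fast = None      # current memory weights, kept as graph-connected tensors
        self.surprise = None  # momentum buffers s_{t-1}

    def condition(self, x: torch.Tensor, eta: float = 0.9, alpha: float = 0.1) -> torch.Tensor:
        if self.fast is None:
            # Start from the layer's own parameters (leaves the outer optimizer updates).
            self.fast = dict(self.model.named_parameters())
            self.surprise = {n: torch.zeros_like(p) for n, p in self.fast.items()}
        k, v = self.key(x), self.value(x)
        s_t = (functional_call(self.model, self.fast, (k,)) - v).pow(2).mean()  # assumed MSE surprise
        # create_graph=True makes the gradients themselves functions of K and V.
        grads = torch.autograd.grad(s_t, list(self.fast.values()), create_graph=True)
        new_fast, new_surprise = {}, {}
        for (name, p), g in zip(self.fast.items(), grads):
            s = eta * self.surprise[name] - self.lr * g  # s_t = eta*s_{t-1} - lr*grad
            new_surprise[name] = s
            new_fast[name] = (1 - alpha) * p + s  # M_t = (1-alpha)*M_{t-1} + s_t, out of place
        self.fast, self.surprise = new_fast, new_surprise
        return s_t

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        params = self.fast if self.fast is not None else dict(self.model.named_parameters())
        return functional_call(self.model, params, (self.query(x),))


torch.manual_seed(0)
x = torch.randn(1, 10)
mem = DiffMemory(10)
mem.condition(x)
loss = nn.L1Loss()(mem(x), x)
loss.backward()
print(mem.key.weight.grad is not None, mem.value.weight.grad is not None)  # prints: True True
```

In this sketch the graph keeps growing with every `condition` call, so for long sequences you would probably want to reset `fast` to `None` between sequences or detach it every few steps (truncated backprop through the inner loop).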
These are the important parts of the algorithm. My test code is as follows:
```python
import torch
import torch.nn as nn

x = torch.randn(1, 10)  # tokens: 1 x 10
model = NeuralMemory(dim_in=10, dim_out=10)
s = model.condition(x)
y = model(x)
l = nn.L1Loss()
loss = l(y, x)
loss.backward()
print(model.key.weight.grad)
print(model.value.weight.grad)
print(model.query.weight.grad)
```
The key and value gradients are both `None` in the printout. The query does get a gradient, as expected, since it sits on the ordinary forward path.
I am struggling to figure out how to make the outer loss propagate through to K and V. K and V should not be updated based on the inner loss.
Thank you for any help!