Propagate loss through inner loop Meta-Learning

Hello,

I am currently trying to implement a version of the Titans: Learning to Memorize at Test Time paper. Essentially, there is an inner loop which computes something like M_t = M_{t-1} - grad[ loss(M_{t-1} K x, V x) ]. Here, M is a linear layer, and K and V are both matrices. M_{t-1} is the old state, updated through backprop to M_t. In the inner loop, K and V act as hyperparameters of the loss, and thus must be updated by the outer loop.

The outer loop launches the inner loop, then runs y = M_t Q x, where Q is a matrix. It then computes a standard loss on y, backprops it, and steps the optimizer. My understanding is that we should be able to propagate the loss through the inner loop, so that K and V are updated to be "better loss hyperparameters".
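
In a toy version of this pattern, gradients do reach the outer parameters (hypothetical names: theta plays the role of K and V, w plays the role of M):

    import torch
    
    theta = torch.randn(3, requires_grad=True)  # outer "loss hyperparameter" (like K/V)
    w = torch.randn(3, requires_grad=True)      # inner parameter (like M)
    
    inner_loss = ((w * theta).sum() - 1.0) ** 2
    # create_graph=True builds a graph for the gradient itself,
    # so the inner update stays differentiable
    (g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
    w_new = w - 0.1 * g  # functional (out-of-place) update: w_new depends on theta
    
    outer_loss = (w_new ** 2).sum()
    outer_loss.backward()
    print(theta.grad)  # non-None: the outer loss reached theta through the inner step

This is the behavior I want for K and V in the code below.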

    def _update_memory(self, grads, eta, alpha):
        """
        Optimizer for the neural memory.
        
        Updates based on the gradient of the surprise metric.
        
        M_t = (1-alpha)*M_{t-1} + s_t
        s_t = eta*s_{t-1} + lr*grad
        
        Args:
        - grads: the gradient of the surprise metric
        - eta: data dependent momentum
        - alpha: decay factor
        """
        for (pname, param), grad in zip(self.model.named_parameters(), grads):
            past_surprise = self.st_past_state.get_value(pname)
            surprise = eta*past_surprise + self.lr*grad # surprise_t = eta*surprise_{t-1} + lr*grad
            
            self.st_past_state.register_buffer(pname, surprise)
            # now the weights with decay: (1-alpha)*param + surprise
            param.data = (1-alpha)*param.data + surprise
                
    def condition(self, x) -> torch.Tensor:
        """
        Condition the model on the input x
        
        Returns:
        - surprise: the surprise metric
        """
        self.model.requires_grad_(True)
        # prepare the grad. inner loop only updates the model
        k = self.key(x)
        v = self.value(x)
        
        s_t = self.surprise_metric(self.model(k), v)
        # Compute gradients w.r.t. model params
        grads = torch.autograd.grad(s_t, self.model.parameters(), create_graph=True, retain_graph=True)
        self._update_memory(grads, eta=0.9, alpha=0.1)
        self.model.requires_grad_(False)
        return s_t 

    
    def forward(self, x):
        q = self.query(x)
        return self.model(q)

These are the important parts of the algorithm. Then, my testing code is as follows:

    x = torch.randn(1, 10) # tokens 1 x 10
    model = NeuralMemory(dim_in=10, dim_out=10)
    
    s = model.condition(x)
    y = model(x)
    
    l = nn.L1Loss()
    loss = l(y, x)
    loss.backward()
    
    print(model.key.weight.grad)
    print(model.value.weight.grad)
    print(model.query.weight.grad)    

I find that key and value have no grad (None and None in the print). Query, of course, does, as this is just a standard pass for it.

I am struggling to figure out how to have the outer loss propagate through to K and V. K and V should not be updated based on the inner loss.

Thank you for any help!

Hi aheschl!

I can’t tell what you are trying to do from the code fragments you have posted.

Could you post a fully-self-contained, runnable, greatly-simplified example script
that demonstrates the core issue you are facing, together with the output you get
when you run it?

In condition(), model, key, and value all appear to be properties of whatever self is.

But in your testing code, key and value appear to be properties of model (whatever that may be).

Are model, key, and value the same throughout your post, or do they refer to
different things in different places?

With a super-simple, fully-self-contained script, we wouldn’t have to guess about what
is going on.

Best.

K. Frank

Hi Frank,

Thank you so much for taking the time to reply. You’re right, I should’ve given a better example script :sweat_smile:

I have actually managed to solve my problem. In case anyone runs into this thread and wants an answer, here is a self-contained example script as you suggested:

    import torch
    import torch.nn as nn
    from logging import warning

    class NeuralMemory(nn.Module):
        def __init__(self, 
                    dim_in: int, 
                    dim_out: int, 
                    lr: float=1e-3
                ):
            super(NeuralMemory, self).__init__()
            self.model = nn.Linear(dim_in, dim_out, bias=True)
            self.key = nn.Linear(dim_in, dim_in, bias=False)
            self.value = nn.Linear(dim_in, dim_in, bias=False)
            self.query = nn.Linear(dim_in, dim_in, bias=False)

            self.p_surprise = {}  # past surprise state, keyed by parameter name
            self.lr = lr
            self.surprise_metric = nn.L1Loss(reduction='mean')
            self.model.requires_grad_(False)

        def _update_memory(self, grads, eta, alpha):
            """
            Optimizer for the neural memory.

            Updates based on the gradient of the surprise metric.

            M_t = (1-alpha)*M_{t-1} + s_t
            s_t = eta*s_{t-1} + lr*grad

            Args:
            - grads: the gradients of the surprise metric
            - eta: data-dependent momentum
            - alpha: decay factor
            """
            for (pname, param), grad in zip(self.model.named_parameters(), grads):
                if grad is None:
                    warning(f"Gradient for {pname} is None. Skipping update.")
                    continue

                # look up the past surprise for this parameter, initializing if absent
                past_surprise = self.p_surprise.get(pname)
                if past_surprise is None:
                    past_surprise = torch.zeros_like(param.data)

                surprise = eta*past_surprise + self.lr*grad # surprise_t = eta*surprise_{t-1} + lr*grad
                self.p_surprise[pname] = surprise

                # now the weights with decay: (1-alpha)*param + surprise
                param.data = (1-alpha)*param.data + surprise

        def condition(self, x) -> torch.Tensor:
            """
            Condition the model on the input x

            Returns:
            - surprise: the surprise metric
            """
            self.model.requires_grad_(True)
            # prepare the grad. the inner loop only updates the model
            k = self.key(x)
            v = self.value(x)

            s_t = self.surprise_metric(self.model(k), v)
            # compute gradients w.r.t. the model params
            grads = torch.autograd.grad(s_t, self.model.parameters(), create_graph=True, retain_graph=True)
            self._update_memory(grads, eta=0.9, alpha=0.1)
            self.model.requires_grad_(False)
            return s_t

        def forward(self, x):
            q = self.query(x)
            return self.model(q)

If you test the above with

    if __name__ == "__main__":
        x = torch.randn(1, 10) # tokens 1 x 10
        model = NeuralMemory(dim_in=10, dim_out=10)

        s = model.condition(x)
        y = model(x)

        l = nn.L1Loss()
        loss = l(y, x)
        loss.backward()

        print(model.key.weight.grad)
        print(model.value.weight.grad)
        print(model.query.weight.grad)

You will find that the only weight with a grad is query. This is an issue, as key and value should be loss hyperparameters for the inner loop condition().
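
One quick sanity check (using the class above): after condition(), the updated memory weights carry no autograd history at all, so nothing upstream can receive gradients through them:

    m = NeuralMemory(dim_in=10, dim_out=10)
    m.condition(torch.randn(1, 10))
    # the .data assignment updated the weights outside of autograd, so the
    # parameter keeps no history that could lead back to key or value
    print(m.model.weight.grad_fn)        # None
    print(m.model.weight.requires_grad)  # False (turned back off in condition)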

To allow the gradient to flow, and to compute the second derivative, I found that self.model should be a tensor. When running param.data = (1-alpha)*param.data + surprise, the assignment to .data happens outside of autograd, so the inner-loop update is cut out of the graph. If you change self.model to torch.tensor(..., requires_grad=True) and update the matrix directly (out of place), the gradient will flow through to key and value in the outer loop. You can also use nn.Parameter.

Maybe my understanding of WHY the fix works isn’t 100% correct, but it works.
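
Here is a minimal sketch of the fixed version (not my full implementation; the class name is made up, and I swap the surprise metric to MSE, since L1's second derivative is zero almost everywhere and would zero out the gradient path into value):

    import torch
    import torch.nn as nn

    class NeuralMemoryFixed(nn.Module):
        def __init__(self, dim: int, lr: float = 1e-3):
            super().__init__()
            self.key = nn.Linear(dim, dim, bias=False)
            self.value = nn.Linear(dim, dim, bias=False)
            self.query = nn.Linear(dim, dim, bias=False)
            # memory as a plain tensor, not an nn.Linear/Parameter
            self.M = torch.zeros(dim, dim, requires_grad=True)
            self.lr = lr
            self.surprise_metric = nn.MSELoss()

        def condition(self, x):
            k, v = self.key(x), self.value(x)
            s_t = self.surprise_metric(k @ self.M, v)
            (grad,) = torch.autograd.grad(s_t, self.M, create_graph=True)
            # out-of-place update: the new M stays in the graph and
            # therefore depends on key and value
            self.M = self.M - self.lr * grad
            return s_t

        def forward(self, x):
            return self.query(x) @ self.M

    m = NeuralMemoryFixed(10)
    x = torch.randn(1, 10)
    m.condition(x)
    nn.L1Loss()(m(x), x).backward()
    print(m.key.weight.grad is None)    # False
    print(m.value.weight.grad is None)  # False
    print(m.query.weight.grad is None)  # False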

In case anyone wants to see the full code for such a method, you can check out this repo: GitHub - aheschl1/Titan: Unofficial Implementation of Titans: Learning to Memorize at Test Time

Best,
Andrew