Custom optimizer raises "RuntimeError: leaf variable has been moved into the graph interior"

Hi, I am trying to create a custom optimizer that simulates the non-linear and asymmetric behavior of Resistive RAM (RRAM). Simply speaking, when a weight is updated, the change in the weight is a function of the weight's current value.
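
Concretely, the step size applied to each weight w depends on the sign of its gradient and on the current value of w. In pseudo-code (using the constants from my code below, delta_w_0 = 0.012 and slope_p = slope_n = 9.63):

delta_w_pos = delta_w_0 * (1 - slope_p * w)   # step magnitude when the gradient is positive
delta_w_neg = delta_w_0 * (1 + slope_n * w)   # step magnitude when the gradient is negative

and the weight is then updated by w -= lr * delta_w.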

Below is my code for the custom optimizer. It is based on the SGD source code, so the only part I changed is the step function.

'''Custom Optimizer'''

import torch
from torch.optim.optimizer import Optimizer, required

class RPU_SGD(Optimizer):
    def __init__(self, params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov)
        super(RPU_SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RPU_SGD, self).__setstate__(state)
        
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
        
            for p in group['params']:
                # RRAM device model constants
                delta_w_0 = 0.012
                slope_p = 9.63
                slope_n = 9.63
    
                if p.grad is None:
                    continue
                d_p = p.grad

                # only 2-D weight matrices are updated here
                if len(p.size()) == 2:
                    for i in range(p.size()[0]):
                        for j in range(p.size()[1]):
                            # step size depends on the sign of the gradient and on
                            # the current weight value (asymmetric device model)
                            delta_w_pos = delta_w_0 * (1 - slope_p * p[i][j])
                            delta_w_neg = delta_w_0 * (1 + slope_n * p[i][j])

                            if d_p[i][j] > 0:
                                gradient = delta_w_pos
                                p[i][j].add_(gradient, alpha=-group['lr'])

                            elif d_p[i][j] < 0:
                                gradient = delta_w_neg
                                p[i][j].add_(gradient, alpha=-group['lr'])

                            elif d_p[i][j] == 0:
                                gradient = 0
                                p[i][j].add_(gradient, alpha=-group['lr'])

   
        return loss

And below is my training code.

'''Training the Network'''
net.train()

for epoch in range(epochs):
#     net.train()
    
    running_loss = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        
        optimizer.zero_grad()
        pred = net(inputs)
        loss = criterion(pred, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            print(f"epoch: {epoch}/{epochs} | step: {i+1}/{len(train_loader)}, loss: {running_loss/100:.4f}")
            running_loss = 0
        
        new_params = list(net.parameters())

print("End of Training")

When I run this, I get the following error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-64-b27ef8dbc022> in <module>
     12         pred = net(inputs)
     13         loss = criterion(pred, labels)
---> 14         loss.backward()
     15 #         import pdb; pdb.set_trace()
     16 #         with torch.no_grad():

~/anaconda3/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    193                 products. Defaults to ``False``.
    194         """
--> 195         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    196 
    197     def register_hook(self, hook):

~/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     95         retain_graph = create_graph
     96 
---> 97     Variable._execution_engine.run_backward(
     98         tensors, grad_tensors, retain_graph, create_graph,
     99         allow_unreachable=True)  # allow_unreachable flag

RuntimeError: leaf variable has been moved into the graph interior

I have two questions:

1. I am not sure why a leaf variable has been moved into the graph interior (to be honest, I don't know what that means either).
2. I used a nested for loop to update each weight parameter, which I reckon will be very time-consuming. Is there a better way to do this?

Thanks in advance:)

Hi,

We should update this error message :stuck_out_tongue: We have an issue open for that somewhere.
It just means that a leaf Tensor (a Tensor with no history that requires grad) has been modified in place in a differentiable manner.
In your case, this is because your optimizer runs in a differentiable manner while it shouldn’t.
I think you're missing the @torch.no_grad() decorator on your step function that the built-in PyTorch optimizers have :wink:
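
For reference, here is a minimal, untested sketch of what I mean, to drop inside your RPU_SGD class. Only the decorator is the actual fix; the body is just a suggestion for your second question, replacing the Python double loop with elementwise tensor ops via torch.where (note it updates parameters of any shape, not just 2-D weights):

    @torch.no_grad()  # the fix: run the update without recording it in the autograd graph
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad

                # RRAM device model constants
                delta_w_0 = 0.012
                slope_p = 9.63
                slope_n = 9.63

                # elementwise step sizes computed for the whole tensor at once
                delta_w_pos = delta_w_0 * (1 - slope_p * p)
                delta_w_neg = delta_w_0 * (1 + slope_n * p)

                # pick the step size from the sign of the gradient; zero where the gradient is zero
                step_size = torch.where(d_p > 0, delta_w_pos,
                                        torch.where(d_p < 0, delta_w_neg, torch.zeros_like(p)))
                p.add_(step_size, alpha=-lr)

        return loss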

Thank you for your reply :) That one line fixed it all.