Training of custom quantum feed-forward model giving RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed)

Hi, I’ve written a custom ML model (a quantum circuit); however, when it comes to training I am getting the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.

The error arises during the second epoch, suggesting some part of the graph is being carried over when it shouldn’t be. Previous discussions all seem to focus on RNNs and the hidden state needing to be detached from its history. I assume I need to do something similar here, but detaching state or output does not solve the error.

If anyone can fix or explain the error that would be greatly appreciated!

Below is the minimal code that produces the error:

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset


def train(model, optimizer, batch_size, n_qubits):
    criterion = nn.MSELoss()

    dataset = DataFactory(batch_size, n_qubits)
    #losses = []
    for epoch in range(5):
        print(epoch)
        x, y = dataset.next_batch() 
        
        state, output = model(x)
        loss = criterion(output.real, y.real)
        
        #Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        #losses.append(loss.item())
        
    return None #losses

n_qubits = 3
batch_size = 1
model = BasicModel(n_qubits, batch_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.05)
L = train(model, optimizer, batch_size, n_qubits)
class BasicModel(nn.Module):
    def __init__(self, n_qubits: int, batch_size: int):
        super().__init__()   
        self.n_qubits = n_qubits
        self.batch_size = batch_size
        self.fRx = Rx_layer([*range(n_qubits)], batch_size=batch_size)
        Z = torch.Tensor([[1,0],[0,-1]]).cdouble().reshape((1,1,1,2,2))
        self.Observable = Z.repeat((batch_size, n_qubits, 1, 1, 1))        

    def forward(self, x):
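        # Start every qubit in the |0> state (amplitude 1 in the first basis component)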
        state = torch.zeros((self.batch_size, self.n_qubits, 1, 2, 1), dtype=torch.cdouble)
        state[:, :, :, 0, 0] = 1 
        Rx1 = Rx_layer([*range(self.n_qubits)], weights=x, batch_size=self.batch_size)
        state = Rx1(state)
        state = self.fRx(state)
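        # Expectation value <state| Z |state> for each qubit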
        O = torch.matmul(state.transpose(3,4).conj(), torch.matmul(self.Observable, state))
        return state, O.mean(dim=2)[:,:]
class Rx_layer(nn.Module):
    "A layer applying the Rx gate"
    def __init__(self, qubits: list, batch_size: int, weights = None):
        """
        qubits: a list with the index of every qubit this layer is to be applied to
        weights: a tensor of rotation angles, if given from input data
        """
        
        super().__init__()
        
        self.qubits = qubits
        self.num_qubits = len(qubits)
        self.batch_size = batch_size
        if weights is None:
            self.weights = nn.Parameter(torch.Tensor(self.batch_size, self.num_qubits).cdouble())
            nn.init.uniform_(self.weights, 0, 2*np.pi)        
        else:
            if weights.shape[1] == 1 and self.num_qubits > 1:
                self.weights = weights.repeat(1, self.num_qubits)
            elif weights.shape[0] == batch_size and weights.shape[1] == self.num_qubits:
                self.weights = weights
            else:
                raise RuntimeError("Dimensions of weight tensor are incompatible. Check the input has the right batch size and qubit count")
            
        self.U = self.Rx().cdouble()

    def Rx(self):
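        # Rx(theta) = cos(theta/2)*I - 1j*sin(theta/2)*X, one 2x2 matrix per batch element and qubit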
        a = (self.weights/2).cos().reshape(self.batch_size, self.num_qubits,1,1,1)
        b = (self.weights/2).sin().reshape(self.batch_size, self.num_qubits,1,1,1)
        identity = torch.eye(2).reshape(1,1,1,2,2)
        off_identity = torch.Tensor([[0,1],[1,0]]).reshape(1,1,1,2,2)
        return a*identity - 1j*b*off_identity
        
           
    def forward(self, state):
        """
        Take state to be a tensor with dimension batch x qubits x d&c x 2 x 1
        """
        state[:,self.qubits] = torch.matmul(self.U, state[:,self.qubits]) 

        return state
class DataFactory(Dataset):
    def __init__(self, batch_size: int, input_dim: int):
        self.batch_size = batch_size 
        self.input_dim = input_dim
    
    def next_batch(self):
        X = torch.rand((self.batch_size, self.input_dim)).cdouble() * 2 - 1
        Y = torch.rand((self.batch_size, self.input_dim)).cdouble() * 2 - 1
        return X, Y

This is my first post, so let me know if there is anything I can clarify (e.g. I have omitted a mathematical description of the model) or improve!

Maybe it’s because of this part:

state[:,self.qubits] = torch.matmul(self.U, state[:,self.qubits])

You are doing an in-place operation there. Try rewriting it with torch.cat or another out-of-place function.
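For example, something along these lines (just a sketch, not tested against the rest of your model), which builds a new state tensor with torch.cat instead of writing into the old one:

def forward(self, state):
    # Apply U to the selected qubits and reassemble the full state with
    # torch.cat rather than assigning into `state` in place
    columns = []
    for i in range(state.shape[1]):
        if i in self.qubits:
            j = self.qubits.index(i)
            columns.append(torch.matmul(self.U[:, j:j+1], state[:, i:i+1]))
        else:
            columns.append(state[:, i:i+1])
    return torch.cat(columns, dim=1)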

I’ve tried changing it to
state_out = state.clone().detach()
state_out[:,self.qubits] = torch.matmul(self.U, state[:,self.qubits])
however this still gives the same error. I’ve also tried with just .clone() or just .detach().
Unfortunately the matrix multiplication is key here, and torch.bmm won’t work since state is a 5D tensor. Are there other functions I’ve missed?

Hi, I really don’t know what is wrong with it, but by changing

self.U = self.Rx().cdouble()

to

self.U = nn.Parameter(self.Rx().cdouble())

it no longer raises an error.
I don’t feel it is right to do this, but maybe it helps you find out what’s wrong with it.

It is indeed not right, as self.U should be parameterized by self.weights; however, I have managed to get it working, thanks for the help!
I moved self.U = self.Rx().cdouble() from __init__ to forward in the Rx_layer class. This makes sense, as U is now regenerated from the current parameters every time it is needed, so each forward pass builds a fresh graph instead of reusing the one whose buffers were freed by the first backward().
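In case it helps anyone else, Rx_layer.forward now looks roughly like this (with the self.U = self.Rx().cdouble() line removed from __init__):

def forward(self, state):
    """
    Take state to be a tensor with dimension batch x qubits x d&c x 2 x 1
    """
    # Rebuild U from the current weights on every call, so each forward pass
    # records a fresh autograd graph rather than reusing the one built in
    # __init__, whose saved buffers are freed after the first backward()
    self.U = self.Rx().cdouble()
    state[:, self.qubits] = torch.matmul(self.U, state[:, self.qubits])
    return state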