Memory overflow during Jacobian regularization

Hello PyTorch community,
Currently, I am trying to implement a Jacobian-based regularizer.
The Jacobian computation itself seems to work, but on MNIST roughly 4 GB of RAM are added in every epoch. I tried several things such as gc.collect(), del variable, etc. to free the memory after an epoch, but none of them helps. After a few epochs I run out of either RAM or GPU memory. It looks like the computational graph is being kept around somewhere, and I have not managed to release it yet. Any help is welcome, either to fix the problem in the current code (preferably) or by pointing to a different formula for the Jacobian computation that does not have this memory issue. Is it possible or helpful to delete the computational graph after an epoch, or is there another way to avoid the excessive memory use / the memory overflow?

The computation goes as follows:

  1. Compute the Jacobian; the derivatives w.r.t. the inputs are what matters here.

  2. Reduce the Jacobian to a scalar loss with standard PyTorch ops such as torch.sum().

  3. Backprop the loss (a compact sketch of these three steps follows below).

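In compact form, what I want per batch is roughly this (just a sketch; the reduction here is simplified to a squared sum, the actual one I use is in jacReg below):

import torch

def jacobian_penalty(net, x):
    """Sketch of steps 1-3 for a single batch of flattened inputs."""
    x = x.detach().clone().requires_grad_(True)   # track gradients w.r.t. the inputs
    y = net(x)                                    # (b, to)
    rows = []
    for o in range(y.shape[1]):                   # one autograd.grad call per output dimension
        g, = torch.autograd.grad(y[:, o].sum(), x, create_graph=True)
        rows.append(g)                            # each g has shape (b, ti)
    jac = torch.stack(rows, dim=1)                # (b, to, ti)
    return jac.pow(2).sum()                       # step 2: reduce the Jacobian to a scalar loss

# step 3: (criterion(net(x), targets) + lbd * jacobian_penalty(net, x)).backward()
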
For the Jacobian calculation I used the following code, adapted from torch_jacobian · GitHub,
but slightly changed so that it also runs on the GPU.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision


def get_batch_jacobian(net, x, to, device):
    """
    Computes the Jacobian of a batch w.r.t. the inputs.
    """
    # b:  batch size
    # i:  per-sample input shape
    # o:  per-sample output shape
    # ti: total input dim
    # to: total output dim (e.g. if net(x).shape == (b, 1, 4, 4), then to = 1*4*4)

    x_batch = x.shape[0]
    x_shape = x.shape[1:]
    x = x.unsqueeze(1)  # (b, 1, *i)
    x = x.repeat(1, to, *(1,) * len(x.shape[2:]))  # (b, to, *i): one copy of the input per output dim
    x.requires_grad_(True)
    tmp_shape = x.shape
    y = net(x.reshape(-1, *tmp_shape[2:]))  # input flattened to (b*to, *i), y.shape = (b*to, to)
    y_shape = y.shape[1:]  # (to,)
    y = y.reshape(x_batch, to, to)  # (b, to, to)
    input_val = torch.eye(to).reshape(1, to, to).repeat(x_batch, 1, 1).to(device)  # (b, to, to), one identity per sample
    y.backward(input_val, create_graph=True)  # create_graph so the regularizer itself can be backpropagated
    return x.grad.reshape(x_batch, *y_shape, *x_shape)  # (b, to, *i)
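For MNIST I call this with to = 10 on flattened 784-dimensional inputs, so a quick shape check (batch size 64 picked just for illustration, model and device as set up further below) should give:

x = torch.randn(64, 784, device=device)  # dummy MNIST-sized batch
jac = get_batch_jacobian(model, x, 10, device)
print(jac.shape)  # torch.Size([64, 10, 784]) -> (b, to, ti)
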
def train(model, criterion, optimizer, trainloader, device, lbda, regularizer):
    model.train()
    for i, (inputs, labels) in enumerate(trainloader, 0):
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = inputs.view(inputs.shape[0], -1)  # flatten the 28x28 images to 784-dim vectors
        optimizer.zero_grad()
        loss = regularizer(model, criterion, lbda, inputs, labels, inputs.shape[0], device)
        loss.backward()
        optimizer.step()
def jacReg(model, criterion, lbd, inputs, targets, batch_size, device):
    y_pred = model(inputs)
    jac = get_batch_jacobian(model, inputs, 10, device).reshape(batch_size, 10, 784)
    j_sum = torch.sum(jac, dim=2)  # sum over the input dimension
    loss = torch.sum(torch.div(jac, j_sum.unsqueeze(-1)).flatten())
    return criterion(y_pred, targets) + lbd * loss

class ShallowNN(nn.Module):
    def __init__(self):
        """
        This is a simple neural net.
        """

        super(ShallowNN, self).__init__()
        self.linear1 = nn.Linear(784, 300)
        self.linear2 = nn.Linear(300, 10)

    def forward(self, x):
        h1_relu = self.linear1(x).clamp(min=0)
        out = self.linear2(h1_relu)
        return out
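The model and device are created in the usual way, roughly:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # standard setup
model = ShallowNN().to(device)
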
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
regularizer = jacReg
train_set = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)

for epoch in range(args.epochs):
    train(
        model,
        criterion,
        optimizer,
        trainloader,
        device=device,
        lbda=args.lbda,
        regularizer=regularizer,
    )