PyTorch custom loss function using gradient: Combined Function and Derivative Approximation (CFDA), memory leaks

I am trying to use a neural network to approximate both a function and its derivative in PyTorch. Derivatives are needed, among other things, in Molecular Dynamics (MD) simulations. A common approach is to penalize not only the network output relative to the labels, but also the derivative of the network with respect to its inputs. This is done with a custom loss function that weights the two terms with prefactors p, which change as an exponentially decaying learning rate r decreases.
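The formula for this loss did not survive in the post. The following is reconstructed from the code, so treat it as an inference rather than the original equation. Notation: network output $\hat y$, labels $y$, label derivatives $y'$, initial learning rate $r_0$ (1E-3 in the code), current learning rate $r$ (`self.lr_cur`), decay factor $\gamma$ (`self.gamma`), and $k$ the number of scheduler steps taken so far:

$$
\mathcal{L} = p_e \,\mathrm{MSE}\big(\hat{y},\, y\big) + p_f \,\mathrm{MSE}\Big(\frac{\partial \hat{y}}{\partial x},\, y'\Big)
$$

$$
p_e = p_e^{\mathrm{limit}}\Big(1 - \frac{r}{r_0}\Big) + p_e^{\mathrm{start}}\,\frac{r}{r_0},
\qquad
p_f = p_f^{\mathrm{limit}}\Big(1 - \frac{r}{r_0}\Big) + p_f^{\mathrm{start}}\,\frac{r}{r_0},
\qquad
r = r_0\,\gamma^{k}
$$

Because $r/r_0$ decays from 1 towards 0, the weights move from their start values to their limit values. With the posted settings, training begins dominated by the derivative term ($p_f = 1000$, $p_e = 1$) and ends dominated by the output term ($p_e \to 400$, $p_f \to 1$).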

My code seems to work, but I am not sure whether it is an efficient implementation. It also seems to leak memory, and my computer becomes unusable after a certain number of iterations. I would also be very interested in any insight into using neural networks for regression, since most of the attention seems to go to classification.
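One likely source of the leak is that calling `outputs.backward(..., create_graph=True)` writes differentiable gradients into every parameter's `.grad`. PyTorch warns that this creates a reference cycle between a parameter and its gradient, which can leak memory, and it recommends `torch.autograd.grad` instead. The sketch below takes the input derivative with `torch.autograd.grad`, which returns the gradient directly and does not touch `.grad`. It assumes the network maps a batch of feature vectors to one scalar per sample. The helper name `output_and_input_grad` is hypothetical.

```python
import torch
import torch.nn as nn


def output_and_input_grad(net, features):
    """Return net(features) and d net / d features, keeping the graph for a later backward().

    Assumes net maps (batch, n_features) -> (batch, 1) and that samples do not
    interact inside the network (e.g. no BatchNorm in training mode).
    """
    features = features.detach().requires_grad_(True)
    outputs = net(features)
    # grad_outputs of ones sums the outputs over the batch; because each output
    # depends only on its own input row, row i of the result is d outputs[i] / d features[i].
    (derivatives,) = torch.autograd.grad(
        outputs,
        features,
        grad_outputs=torch.ones_like(outputs),
        create_graph=True,  # keep the graph so the derivative loss can be backpropagated
    )
    return outputs, derivatives


# Usage on a toy linear model: its input derivative is simply its weight row.
net = nn.Linear(3, 1)
x = torch.randn(5, 3)
y, dy_dx = output_and_input_grad(net, x)

# Both loss terms backpropagate into the parameters through one ordinary backward().
loss = nn.functional.mse_loss(y, torch.zeros_like(y)) + nn.functional.mse_loss(
    dy_dx, torch.zeros_like(dy_dx)
)
loss.backward()
print(y.shape, dy_dx.shape)
```

The design choice is to build the derivative graph once, use it only inside the loss, and let a single `loss.backward()` populate the parameter gradients. When the next batch starts, the whole graph can be freed.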

Relevant code:

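```python
import torch
import torch.nn as nn
import torch.optim as optim


class Trainer:
    def __init__(self, net):
        self.net = net
        self.lr = 1E-3          # initial learning rate r_0
        self.lr_cur = self.lr   # current learning rate r
        self.gamma = 0.75
        self.pe_start = 1       # prefactor on the output (label) loss
        self.pe_limit = 400
        self.pf_start = 1000    # prefactor on the derivative loss
        self.pf_limit = 1
        self.pe = self.pe_start
        self.pf = self.pf_start

        self.criterion = nn.MSELoss(reduction="mean")
        self.optimizer = optim.Adam(net.parameters(), lr=self.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.gamma)

    def loss_fn(self, outputs, labels, derivatives, label_derivatives):
        label_loss = self.criterion(outputs, labels)
        derivative_loss = self.criterion(derivatives, label_derivatives)

        return self.pe*label_loss + self.pf*derivative_loss

    def train(self, train_loader, epoch, epochs):
        for i, data in enumerate(train_loader):
            features, labels, label_derivatives = data

            features.requires_grad = True
            outputs = self.net(features)
            # d(outputs)/d(features), kept differentiable (create_graph=True) so the
            # derivative loss can be backpropagated. Unlike outputs.backward(...), this
            # does not write into the parameters' .grad.
            derivatives = torch.autograd.grad(outputs, features,
                                              grad_outputs=torch.ones_like(outputs),
                                              create_graph=True)[0]

            self.loss = self.loss_fn(outputs, labels, derivatives, label_derivatives)
            self.optimizer.zero_grad()
            self.loss.backward()
            self.optimizer.step()

        if (epoch - 1) % epochs == 200:
            self.scheduler.step()                          # r <- gamma * r
            self.lr_cur = self.scheduler.get_last_lr()[0]
            ratio = self.lr_cur/self.lr                    # r / r_0, decays from 1 towards 0
            self.pe = self.pe_limit*(1 - ratio) + self.pe_start*ratio
            self.pf = self.pf_limit*(1 - ratio) + self.pf_start*ratio

        return self.loss.item()
```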
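The prefactor schedule can be checked in isolation, without a network. The sketch below is in plain Python and assumes the learning rate after k scheduler steps is r = r_0 · γ^k, which is how `ExponentialLR` decays it. The function name `prefactors` and its keyword defaults are hypothetical; the defaults simply mirror the values in the `Trainer` above. The sketch shows the output-loss weight growing from `pe_start` towards `pe_limit` while the derivative-loss weight shrinks from `pf_start` towards `pf_limit`.

```python
def prefactors(k, lr0=1e-3, gamma=0.75,
               pe_start=1.0, pe_limit=400.0,
               pf_start=1000.0, pf_limit=1.0):
    """Return (pe, pf) after k exponential learning-rate decay steps."""
    lr_cur = lr0 * gamma ** k          # learning rate ExponentialLR would give after k steps
    ratio = lr_cur / lr0               # r / r_0: 1 at the start, decays towards 0
    pe = pe_limit * (1 - ratio) + pe_start * ratio
    pf = pf_limit * (1 - ratio) + pf_start * ratio
    return pe, pf


for k in (0, 1, 5, 20):
    pe, pf = prefactors(k)
    print(f"step {k:2d}: pe = {pe:7.2f}, pf = {pf:7.2f}")
```

With gamma = 0.75, r / r_0 is already below 0.01 after 17 steps. The schedule is therefore effectively finished after a few dozen decay steps, no matter how many epochs are run.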