Memory leakage with Custom Loss functions during validation step

from torch.autograd import Variable, Function
# Source: 
def _to_one_hot(y, n_dims=None):
    Take integer y (tensor or variable) with n dims and 
    convert it to 1-hot representation with n+1 dims
    y_tensor = if isinstance(y, Variable) else y
    y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
    n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
    y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)
    y_one_hot = y_one_hot.view(y.size()[0], -1)
    return Variable(y_one_hot) if isinstance(y, Variable) else y_one_hot

class LSEP(Function): 
    Autograd function of LSEP loss. Appropirate for multi-label
    - Reference: Li+2017
    def forward(ctx, input, target, max_num_trials = None):
        batch_size = target.size()[0]
        label_size = target.size()[1]

        ## rank weight 
        rank_weights = [1.0/1]
        for i in range(1, label_size):
            rank_weights.append(rank_weights[i-1] + (1.0/i+1))

        if max_num_trials is None: 
            max_num_trials = target.size()[1] - 1

        positive_indices =
        negative_indices = target.eq(0).float()
        ## summing over all negatives and positives
        loss = 0.
        for i in range(input.size()[0]): # loop over examples
            pos = np.array([j for j,pos in enumerate(positive_indices[i]) if pos != 0])
            neg = np.array([j for j,neg in enumerate(negative_indices[i]) if neg != 0])
            for j,pj in enumerate(pos):
                for k,nj in enumerate(neg):
                    loss += np.exp(input[i,nj]-input[i,pj])
        loss = torch.from_numpy(np.array([np.log(1 + loss)])).float()
        ctx.save_for_backward(input, target)
        ctx.loss = loss
        ctx.positive_indices = positive_indices
        ctx.negative_indices = negative_indices
        return loss

    # This function has only a single output, so it gets only one gradient 
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        loss = Variable(ctx.loss, requires_grad = False)
        positive_indices = ctx.positive_indices
        negative_indices = ctx.negative_indices

        fac  = -1 / loss
        grad_input = torch.zeros(input.size())
        ## make one-hot vectors
        one_hot_pos, one_hot_neg = [],[]
        for i in range(grad_input.size()[0]): # loop over examples
            pos_ind = np.array([j for j,pos in enumerate(positive_indices[i]) if pos != 0])
            neg_ind = np.array([j for j,neg in enumerate(negative_indices[i]) if neg != 0])
        ## grad
        for i in range(grad_input.size()[0]):
            for dum_j,phot in enumerate(one_hot_pos[i]):
                for dum_k,nhot in enumerate(one_hot_neg[i]):
                    grad_input[i] += (phot-nhot)*torch.exp(-input[i].data*(phot-nhot))
        grad_input = Variable(grad_input) * (grad_output * fac)

        return grad_input, None, None
#--- main class
class LSEPLoss(nn.Module): 
    def __init__(self): 
        super(LSEPLoss, self).__init__()
    def forward(self, input, target): 
        return LSEP.apply(input.cpu(), target.cpu())

def val_metrics(model, valid_dl):
    total = 0
    total2 = 0
    sum_loss = 0
    correct2 = 0
    correct = 0 
    predc = 0
    for x, y in valid_dl:
        batch = y.shape[0]
        with torch.cuda.device(1L):
            x = Variable(x.cuda().float())
            y = Variable(y.cuda())  #.unsqueeze(1))
        out = model(x)
        # Number of correct answers
        with torch.cuda.device(1L):
            pred = (out > 0.0).cuda().long()
#         correct += pred.eq(
        correct += pred.long().eq(y.long()).sum().item()  # Based on whole dataset
#         correct2 += pred[y>0].sum().data   # Considering only y lables
        correct2 += (y.long()*pred).sum().item()
        # Loss calculation
        y = y.float()
        loss = criterion(out, y)
        sum_loss += batch*([0])
#         sum_loss += batch*(

        # Total number of data points
        total += batch  # Total items
        total2 += y.sum().item() # Total y labels available 
        # Number of positive predictions per data point
        pred_per_data = pred.float().sum(dim=1).mean().item()
        # F score calculation
        f1 = f1_score(y.cpu().data, pred.cpu().data, average="samples")
        predc= predc+pred.sum().item()
    print("val loss, overall accuracy, y_label_accuracy, pos_pred_per_data, f1_score", round((sum_loss/total),4), 
                                                                                    round((correct/(total*28.0)), 4),
                                                                                     round(correct2/total2, 4), 
                                                                                    round(pred_per_data, 4), round(f1,4))
def train_triangular_policy(model, train_dl, valid_dl, lr_low=1e-5, lr_high=0.01, epochs = 4):
    idx = 0
    iterations = epochs*len(train_dl)
    lrs = get_triangular_lr(lr_low, lr_high, iterations)
    for i in range(epochs):
        total = 0
        sum_loss = 0
        for i, (x, y) in enumerate(train_dl):
            optim = get_optimizer(model, lr = lrs[idx], wd =0)
            batch = y.shape[0]
            with torch.cuda.device(1L):
                x = Variable(x.cuda().float())
                y = Variable(y.cuda().float())    #.unsqueeze(1)
            out = model(x)
            loss = criterion(out, y)
            idx += 1
            total += batch
            sum_loss += batch*([0])
        print("train loss", (sum_loss/total))
        val_metrics(model, valid_dl)
    return sum_loss/total

def training_loop(model, train_dl, valid_dl, steps=3, lr_low=1e-6, lr_high=0.01, epochs = 4):
    for i in range(steps):
        start = 
        loss = train_triangular_policy(model, train_dl, valid_dl, lr_low, lr_high, epochs)
        end =
        t = 'Time elapsed {}'.format(end - start)
        print("----End of step", t)

These are the loss function, training and validation function i am using. I can see memory leakge during each step of validation loop. Not sure what is causing this. The code used to run finr with BCE loss. Facing issue with the custom loss function

@ptrblck Will you be able to help?

Could you post the shapes of all inputs so that I could debug it using random values?

@ptrblck Thanks for the reply. Give me some time. I will share it

The input to the model is:
X is of shape 64,3,224,224
Y is a multi label target with 64,28 dimension. Do you need any other sizes?

Thanks for the shapes. I’ll check it a bit later as I’m currently on my mobile.

@ptrblck Sure thanks.

when you create the variables, can you try passing requires_grad=False?
I suspect that the computation graphs are created even during validation routine.

@InnovArul Thanks for the suggestion. I can try doing that. I was not facing such issues without the custom loss function. But i can recheck and try your suggestion

I just checked your loss function and it seems to work fine without creating a memory leak.

x = torch.randn(64, 28, requires_grad=True, device='cuda')
target = torch.empty(64, 28, device='cuda').random_(2)

criterion = LSEPLoss()

for epoch in range(10):
    loss = criterion(x, target)

Could you run it in your setup and check it?

I couldn’t test the whole script, as some objects and functions are missing, e.g. model, get_optimizer.
Also, Variables are deprecated since PyTorch 0.4.0, so you can just use tensors instead. :wink:

@ptrblck Yeah sure I will do that. I had checked that and found that loss was fine. If I remove val_metrics(model, valid_dl) call in training loop, there is no leakage. However, if I include it, I face issues. I will try to run what you have suggested. I can share the model and optimizer in some time if it still doesn’t work.
I was previously using a different version of pytorch on production system hence the variable. Will drop it :smile:

again, my suspicion is that the computation graphs are created during validation routine.

Can you try to avoid it?

Pytorch <= 0.4:

input = Variable(input, requires_grad=False)

Pytorch >= 1.0:

with torch.no_grad():

@ptrblck: do you have the similar hunch about this?


I’m not sure, as I would assume all local variables are cleared after returning from this method.
However, this will still save some memory so it’s definitely recommended.

@InnovArul @ptrblck Thanks for the help. with torch.no_grad(): solved the problem