Yet another post on custom loss functions

I've read many posts on this forum discussing the different ways to implement a custom loss function in PyTorch. I'd like to confirm that my project works correctly with the implementation below:

import torch
from torch.nn import Module


class TVLoss(Module):
    """
    Total variation loss: an L1 penalty on the difference between
    the vertical gradients of the prediction and the target.
    """
    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, yhat, y):
        bsize, chan, height, width = y.size()
        errors = []
        for h in range(height - 1):
            # Absolute difference between adjacent rows, for target and prediction
            dy = torch.abs(y[:, :, h + 1, :] - y[:, :, h, :])
            dyhat = torch.abs(yhat[:, :, h + 1, :] - yhat[:, :, h, :])
            # L1 norm of the gradient mismatch for this pair of rows
            error = torch.norm(dy - dyhat, 1)
            errors.append(error)

        return sum(errors) / height
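
For context, this is roughly how I call it during training (a minimal sketch; `model`, `loader`, and `optimizer` are placeholders from my setup, with `loader` yielding NCHW image batches):

criterion = TVLoss()
for x, y in loader:
    optimizer.zero_grad()
    yhat = model(x)            # prediction, same shape as y
    loss = criterion(yhat, y)
    loss.backward()
    optimizer.step()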

Is there any inefficiency or missing component? Training appears to work, but I haven't had the chance to run it for long enough to be sure.


This looks fine. If the height is large, you may get better performance by avoiding the for-loop:

def forward(self, yhat, y):
    bsize, chan, height, width = y.size()
    # Difference all adjacent rows at once instead of looping over height
    dy = torch.abs(y[:, :, 1:, :] - y[:, :, :-1, :])
    dyhat = torch.abs(yhat[:, :, 1:, :] - yhat[:, :, :-1, :])
    error = torch.norm(dy - dyhat, 1)
    return error / height
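
As a quick sanity check (a sketch; `TVLossVectorized` is a hypothetical copy of your module with this forward in place of the loop), the two versions agree on random inputs, since summing the per-row L1 norms equals one L1 norm over the stacked row differences:

# TVLossVectorized is assumed to be TVLoss with the vectorized forward above
loop_loss = TVLoss()
vec_loss = TVLossVectorized()
yhat = torch.rand(4, 3, 32, 32)
y = torch.rand(4, 3, 32, 32)
assert torch.allclose(loop_loss(yhat, y), vec_loss(yhat, y))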

You don't actually need to subclass Module; you could just write it as a plain function, like this:

def tvloss(yhat, y):
    bsize, chan, height, width = y.size()
    dy = torch.abs(y[:, :, 1:, :] - y[:, :, :-1, :])
    dyhat = torch.abs(yhat[:, :, 1:, :] - yhat[:, :, :-1, :])
    return torch.norm(dy - dyhat, 1) / height
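
A plain function also drops straight into a training loop wherever you'd use a module instance (sketch; `model`, `x`, and `y` come from your own setup):

loss = tvloss(model(x), y)
loss.backward()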