Training an autoencoder with a weighted loss where the loss weights are trainable and have norm 1

I’m looking to create a weighted loss function where the weights always have norm 1 and are trainable. I see two main avenues for accomplishing this (described below). I have invested significant energy in only the first so far, because I am not comfortable enough with PyTorch to make much progress on the second (yet). I’m posting to ask for help with getting either of the avenues working (and initially, even confirmation that what I want to do is possible).

For the record, it seems at first glance that what I’m doing is similar to what’s done here, and the code linked in this repo has certainly informed my attempts. Still, that code does not account for the normalization of the weights parameter.

Avenue 1: hard-renormalizing the weights

I have tried to compute regular gradients and then simply normalize the weights to have norm 1. In this setting, I run into the following error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

To do this hard-renormalization, I have tried something like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightedLoss(nn.Module):
    ...
    def forward(self, output, input, weights):
        return (((output - input)**2 * weights)
                .view(input.size(0), -1)
                .sum(dim=1, keepdim=True)
                .mean())

weights = torch.empty(n_weights).uniform_(.9, 1.1)
weights.requires_grad = True

for epoch in range(n_epochs):
    for input in dataloader:
        weights = F.normalize(weights, p=2, dim=0)
        output = model(input)
        loss = criterion(output, input, weights)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

This yields the RuntimeError mentioned above. I have also tried something like this:

class WeightedLoss(nn.Module):
    def __init__(self, n_weights):
        super().__init__()
        self.n_weights = n_weights
        self.weights = nn.Parameter(torch.Tensor(n_weights))
        nn.init.uniform_(self.weights, .9, 1.1)

    def forward(self, output, input):
        self.weights = F.normalize(self.weights, p=2, dim=0)
        return (((output - input)**2 * self.weights)
                .view(input.size(0), -1)
                .sum(dim=1, keepdim=True)
                .mean())

criterion = WeightedLoss(n_weights)

for epoch in range(n_epochs):
    for input in dataloader:
        output = model(input)
        loss = criterion(output, input)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Unfortunately, this second version won’t work because F.normalize does not return a Parameter, so the assignment to self.weights fails. I believe a variation on this code (one that keeps self.weights a Parameter) either gives the same RuntimeError or breaks somewhere else… I am not sure how to restructure either of these snippets in a way that makes retain_graph=True sensible without training becoming exceptionally inefficient.

Avenue 2: projecting the gradient updates

Another perfectly acceptable option that I don’t know how to implement would be the following: instantiate the weights (either as a Tensor with requires_grad=True, or as an nn.Parameter inside WeightedLoss); then write a custom backward method for WeightedLoss that computes the usual gradient for self.weights and projects/normalizes it such that weights - alpha * weights.grad still has 2-norm equal to 1. In this case I’m entirely at a loss for how to write the backward function, and I have had trouble finding online examples pertaining to my particular use case.
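(For concreteness, and in case it helps anyone answer: if g is the usual gradient and w is the current unit-norm weight vector, the projection of g onto the tangent space of the unit sphere at w is g_tan = g - (g · w) w, so that a small step w - alpha * g_tan stays on the sphere to first order, with a final renormalization handling the higher-order drift. That’s just the standard formula; I don’t yet have it working in code.)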

Thank you to anyone with useful insight on this problem.

How about changing this line to

self.weights = nn.Parameter(F.normalize(self.weights, p=2, dim=0))

According to the partial snippet, the retain_graph error may be caused by some computation that lives at global scope: its graph is freed by the first .backward(), so when you call backward a second time autograd cannot reach that node to compute the gradient.
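A minimal illustration of that failure mode (not your exact code, just the mechanism):

import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()  # build a graph once
y.backward()       # first backward frees the graph's buffers
y.backward()       # RuntimeError: Trying to backward through the graph a second time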

Hi @MariosOreo,
Thanks for the reply. Unfortunately, redefining the Parameter in this way each time destroys the computation graph associated with it, so it doesn’t get updated by .step() after .backward() is called. Whatever the solution is, it must preserve the computation graph for weights (updating its value without instantiating a brand-new Parameter with the same name).

Progress so far

My current solution is to include a second loss function:

class CoefficientLoss(nn.Module):
    def __init__(self, alpha):
        super().__init__()
        assert isinstance(alpha, torch.Tensor), 'Whoops, alpha is not a Tensor.'
        self.alpha = alpha

    def forward(self, weights):
        # soft penalty pushing the 2-norm of the weights toward 1
        return self.alpha * (1 - weights.norm())**2

Then one simply calls

loss = criterion(output, input) + criterion_aux(criterion.weights)

where weights are no longer normalized using F.normalize (because this new loss function is serving that purpose), and where

criterion = WeightedLoss(n_weights)
criterion_aux = CoefficientLoss(torch.Tensor([1000.]))
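For the weights to be updated at all, the optimizer also has to see them; a sketch of the setup (Adam here only as an example):

# criterion.parameters() yields the weights Parameter defined in WeightedLoss
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(criterion.parameters()), lr=1e-3)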

This gives me a close approximation to what I need (but not precisely what I want!).

Feasible direction

No pun intended.

It turns out that this code (linked in the OP) actually provides a lovely example of defining a backward method manually, though it took me a bit of staring to parse what was going on, and it’s still a little fuzzy. In short, it should be possible to define a WeightedLoss using a WeightedLossFunc whose backward method returns the appropriate gradients. For my particular use case this seems rather involved; that is, I’m still unsure how to proceed along this avenue, and any suggestions to this end would be greatly appreciated.


You could take this as a reference.
The demo in your link is quite approachable, and I tried to adapt it to meet your requirement; I’m not sure whether it works or not. I think a custom Function (as you mentioned) is the better way, because it makes the internal mechanics clearer.

from torch.autograd import Function

class WeightedLoss(nn.Module):
    def __init__(self, weight):
        super(WeightedLoss, self).__init__()
        self.weight = weight
        self.lossfunc = WeightedLossFunc.apply

    def forward(self, input, target):
        return self.lossfunc(input, target, self.weight)

class WeightedLossFunc(Function):

    @staticmethod
    def forward(ctx, input, target, weight):
        ctx.save_for_backward(input, target, weight)
        return (((target - input)**2 * weight)
                .view(target.size(0), -1)
                .sum(dim=1, keepdim=True)
                .mean())

    @staticmethod
    def backward(ctx, grad_output):
        input, target, weight = ctx.saved_tensors
        batch_size = target.size(0)
        input_grad, target_grad, weight_grad = None, None, None
        differ = target - input
        # d(loss)/d(input), averaged over the batch
        input_grad = grad_output * -2 * differ * weight / batch_size
        # d(loss)/d(weight): per-element squared errors, averaged over the batch
        weight_grad = grad_output * (differ ** 2).sum(dim=0) / batch_size

        return input_grad, target_grad, weight_grad
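One way to sanity-check a hand-written backward like this is torch.autograd.gradcheck, which compares it against numerical gradients (double precision is needed for the tolerances to be meaningful); a sketch with made-up shapes:

input = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
target = torch.randn(4, 8, dtype=torch.double)
weight = torch.rand(8, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(WeightedLossFunc.apply, (input, target, weight))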

Here is another thread about building custom loss functions; I did not have time to read it, but I hope it helps you.
If you find a better solution, let me know.
Thank you!

Thank you for pointing to that first link. Between that and the other link I had found, I have been able to assemble something that works. That said, the way I define WeightedLossFunc is a slightly hacky workaround, because I don’t know how to define its backward method when I also need to account for an additional transform that has its own backward to deal with. For instance, I have something like:

class WeightedLossFunc(Function):

    @staticmethod
    def forward(ctx, weights, *differences):
        batch_size = differences[0].size(0)
        # batch-meaned squared errors, one column per difference tensor
        squared_errors = (torch.cat([
            d.view(batch_size, -1).pow(2)
            .sum(dim=1, keepdim=True)
            for d in differences], dim=1)
                          .mean(dim=0, keepdim=True))
        # weights must be positive, norm 1
        # (note: this first line mutates the underlying Parameter in place)
        weights[weights < 0] = 1e-6
        weights = F.normalize(weights, p=2, dim=0)
        ctx.save_for_backward(squared_errors, weights, *differences)
        wse = squared_errors.mm(weights.view(-1, 1)).squeeze()
        return wse

    @staticmethod
    def backward(ctx, grad_output):
        squared_errors, weights = ctx.saved_tensors[:2]
        differences = ctx.saved_tensors[2:]
        batch_size = differences[0].size(0)
        grad_weights = None
        grad_differences = [None] * len(differences)
        if ctx.needs_input_grad[0]:
            # compute grad_weights
            grad_weights = grad_output * squared_errors.squeeze()
            # project the gradient onto the tangent space of the sphere at weights
            grad_weights = WeightedLossFunc._orthoproject(grad_weights, weights)
        if ctx.needs_input_grad[1]:
            # compute all grad_differences
            grad_differences = [grad_output * 2 * d * weights[i] / batch_size
                                for i, d in enumerate(differences)]

        return (grad_weights, *grad_differences)
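(I haven’t reproduced _orthoproject above; a minimal sketch of the tangent-space projection it is meant to perform, written as a staticmethod on the class and assuming weights already has unit norm, would be roughly:)

@staticmethod
def _orthoproject(grad, weights):
    # subtract the radial component of grad, leaving the part tangent
    # to the unit sphere at weights: g - (g . w) w  (assumes ||weights|| == 1)
    return grad - (grad * weights).sum() * weights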

(Actually, I return -grad_weights in my own code, because I don’t want the weights to concentrate on the part of the data that already achieves the lowest loss. I don’t know whether this naïve choice really achieves what I want here, or whether I should implement something GAN-like that trains adversarially; initial tests suggest it’s working fine.) The loss itself is then implemented in the following fashion:

class WeightedLoss(nn.Module):
    def __init__(self, some_parameters):
        super().__init__()
        self.some_transform = SomeTransform(some_parameters)
        # initialize positive weights with unit 2-norm
        weights = F.normalize(torch.rand(self.some_transform.output_length),
                              dim=0)
        self.weights = nn.Parameter(weights)
        self.weighted_loss = WeightedLossFunc.apply

    def forward(self, x, y):
        transformed_differences = self.some_transform(x - y)
        return self.weighted_loss(self.weights, *transformed_differences)

I guess what I’d like to know is whether there’s an easy way to write WeightedLossFunc.backward if I were to move the transformed_differences = self.some_transform(x - y) call into WeightedLossFunc.forward. I think both versions should behave identically, but the latter makes more sense from an organizational/structural perspective; my current way is a little kludgy. I also think knowing this would fill in one of the final missing pieces in PyTorch’s autograd tutorial.
