Gram matrix in mixed precision

Hi all,
I was trying to train a model using a style loss with VGG19 in float16 precision. However, when I calculate the Gram matrix, it contains NaN and Inf values. This does not happen in float32.
Here is the code that calculates the Gram matrix:

def getGramMatrix(self, x):
    b, ch, h, w = x.size()
    f = x.view(b, ch, w * h)
    G = f.bmm(f.transpose(1, 2))
    return G

The NaN and Inf values appear in the multiplication (the bmm call).
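A minimal sketch of why float16 overflows here. Each Gram entry G[i, j] is a sum over the h*w spatial positions of products f[i, k] * f[j, k], so it grows with the size of the feature map. The largest finite float16 value is 65504. When the bmm result is stored in float16, any entry above that becomes inf, and later operations such as inf - inf inside the L1 loss give NaN. The all-4.0 activation map below is a made-up example chosen for easy arithmetic. The sketch computes in float32 and then casts to float16, so it shows the overflow on any device.

```python
import torch

# Largest finite float16 value: 65504.0
fp16_max = torch.finfo(torch.float16).max

# Hypothetical activation map: 1 image, 4 channels, 64x64 spatial, every value 4.0.
x = torch.full((1, 4, 64, 64), 4.0)
b, ch, h, w = x.size()
f = x.view(b, ch, h * w)

# Gram matrix in float32: each entry sums h*w = 4096 products of 4.0 * 4.0 = 16.0.
G32 = f.bmm(f.transpose(1, 2))
print(G32[0, 0, 0].item())  # 65536.0, just above the float16 maximum

# Storing the same values in float16 overflows every entry to inf.
G16 = G32.half()
print(torch.isinf(G16).all().item())  # True
```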

I can only think of two options. The first is to replace NaNs and Infs with zeros:
G = torch.nan_to_num(G, nan=0.0, posinf=0.0, neginf=0.0)

The second is to cast the inputs to float32 before computing the loss.
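The float32 option can be sketched as a drop-in replacement for getGramMatrix. This assumes the model runs under torch.autocast. On CUDA, autocast runs bmm in float16 even when its inputs are float32, so the sketch upcasts the features and also disables autocast locally. The helper name gram_matrix_fp32 is hypothetical.

```python
import torch


def gram_matrix_fp32(x: torch.Tensor) -> torch.Tensor:
    """Gram matrix computed in float32, whatever dtype the activations arrive in."""
    b, ch, h, w = x.size()
    # reshape (instead of view) also accepts non-contiguous inputs; .float() upcasts
    f = x.reshape(b, ch, h * w).float()
    # Disable autocast locally so bmm is not cast back to float16.
    with torch.autocast(device_type=x.device.type, enabled=False):
        return f.bmm(f.transpose(1, 2))
```

The rest of the loss can stay as it is. The cost is an extra float32 copy of each feature map, while the Gram matrix itself is only ch x ch.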

Is there any other solution to fix this issue?

Here is the full code of the loss function:

import torch

# _LossBase and VGG19 are defined elsewhere in my code;
# self.vgg(x) returns a dict of VGG19 activations keyed by layer name.
class StyleLoss(_LossBase):
    def __init__(self, weight=1.0):
        super(StyleLoss, self).__init__(weight)
        self.add_module('vgg', VGG19())
        self.criterion = torch.nn.L1Loss()

    def getGramMatrix(self, x):
        # (b, ch, h, w) -> (b, ch, h*w), then G = F F^T with shape (b, ch, ch)
        b, ch, h, w = x.size()
        f = x.view(b, ch, w * h)
        G = f.bmm(f.transpose(1, 2))
        return G

    def __call__(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)

        # L1 distance between the Gram matrices at four VGG19 layers
        style_loss = 0.0
        for layer in ('relu2_2', 'relu3_4', 'relu4_4', 'relu5_2'):
            style_loss += self.criterion(self.getGramMatrix(x_vgg[layer]),
                                         self.getGramMatrix(y_vgg[layer]))

        return style_loss * self.weight

Thank you.

Try this:

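import torch

def remove_inf_nan(x):
    # torch.isfinite is False for NaN, +Inf and -Inf, so this one mask
    # zeroes all three (a separate x != x check for NaN is not needed).
    # Note: this modifies x in place.
    x[~torch.isfinite(x)] = 0
    return x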
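For reference, remove_inf_nan zeroes exactly the same entries as the torch.nan_to_num call from the question, because torch.isfinite is False for NaN, +Inf and -Inf. A small check with a made-up tensor. Because the function writes into its argument, the sketch passes a clone; torch.nan_to_num returns a new tensor instead.

```python
import torch


def remove_inf_nan(x):
    # Zero every non-finite entry (NaN, +Inf, -Inf) in place.
    x[~torch.isfinite(x)] = 0
    return x


t = torch.tensor([1.0, float("nan"), float("inf"), -float("inf"), -2.0])
a = remove_inf_nan(t.clone())  # clone so t itself is left untouched
b = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0)
print(torch.equal(a, b))  # True
```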

Another option is to clamp the tensors before the multiplication. This is probably a better solution than setting the values to zero.


x = torch.clamp(x, max=1.0e2, min=-1.0e2)
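Clamping bounds each factor, but each Gram entry still sums h*w products. Whether 1.0e2 is small enough therefore depends on the spatial size of the feature map. Here is a back-of-the-envelope check. It assumes a 224x224 input, so that VGG19's relu2_2, relu3_4, relu4_4 and relu5_2 maps are 112x112, 56x56, 28x28 and 14x14. The helper names are hypothetical.

```python
import math

FP16_MAX = 65504.0  # largest finite float16 value


def worst_case_gram_entry(clamp_max: float, h: int, w: int) -> float:
    # After clamping to [-clamp_max, clamp_max], each product is at most clamp_max**2
    # in magnitude, and a Gram entry sums h*w of them.
    return clamp_max ** 2 * h * w


def max_safe_clamp(h: int, w: int) -> float:
    # Largest clamp bound whose worst-case Gram entry still fits in float16.
    return math.sqrt(FP16_MAX / (h * w))


# Spatial sizes of the four style layers, assuming a 224x224 input.
for name, side in [("relu2_2", 112), ("relu3_4", 56), ("relu4_4", 28), ("relu5_2", 14)]:
    worst = worst_case_gram_entry(1.0e2, side, side)
    print(f"{name}: worst case {worst:.3g}, fits in float16: {worst <= FP16_MAX}, "
          f"max safe clamp {max_safe_clamp(side, side):.3g}")
```

Real activations rarely all sit at the bound, so this is a pessimistic estimate. Still, it shows that a fixed clamp at 1.0e2 does not by itself guarantee finite float16 Gram entries at these sizes.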

Or, to keep the gradients bounded during gradient descent, clip their norm:


torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0e2)

https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
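Here is a sketch of where clip_grad_norm_ goes in a mixed-precision training step. The linear model, random data and the train_step name are placeholders, not the style-transfer setup above. With GradScaler the gradients are scaled, so the PyTorch AMP examples call scaler.unscale_(optimizer) before clipping. That way max_norm applies to the true gradients. Clipping only rescales gradients after backward(), so it does not remove an inf or NaN that already appeared in the forward-pass Gram matrix.

```python
import torch

use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

model = torch.nn.Linear(8, 1).to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # a no-op when disabled (CPU)


def train_step(x, y):
    optimizer.zero_grad()
    # float16 autocast on CUDA; disabled on CPU
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    # Unscale first so max_norm is compared against the real gradient norm.
    scaler.unscale_(optimizer)
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0e2)
    scaler.step(optimizer)  # skipped by the scaler if the gradients contain inf/NaN
    scaler.update()
    return loss.item(), total_norm.item()
```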