Hi all,
I am trying to train a model using a style loss with VGG19 in float16 precision. However, when I calculate the Gram matrix, it contains NaN and Inf values. This does not happen in float32.
Here is the code for calculating the Gram matrix:
def getGramMatrix(self, x):
    b, ch, h, w = x.size()
    f = x.view(b, ch, w * h)
    G = f.bmm(f.transpose(1, 2))
    return G
The NaN and Inf values appear during the batched matrix multiplication (the bmm call).
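My guess at the cause, which I have not confirmed on the real VGG features: the largest finite float16 value is 65504, and each Gram entry is a sum of h * w products, so it can easily exceed that range. Here is a minimal sketch that uses made-up activations (the shapes and value range are assumptions, not taken from VGG19), computes the Gram matrix in float32, and then checks what happens when the result is stored as float16:

```python
import torch

torch.manual_seed(0)

# Hypothetical activations standing in for a VGG feature map (not real data).
b, ch, h, w = 1, 4, 64, 64
x = torch.rand(b, ch, h, w) * 10.0  # values in [0, 10)

# Gram matrix computed in float32, same steps as getGramMatrix.
f = x.view(b, ch, h * w)
G32 = f.bmm(f.transpose(1, 2))

# Largest finite value that float16 can represent.
fp16_max = torch.finfo(torch.float16).max

# Storing the float32 result as float16 turns out-of-range entries into inf.
G16 = G32.to(torch.float16)

print("float16 max:", fp16_max)
print("largest Gram entry in float32:", G32.max().item())
print("any inf after cast to float16:", torch.isinf(G16).any().item())
```

If both Gram matrices contain Inf at the same position, the subtraction inside L1Loss gives inf - inf, which would explain the NaN values as well.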
I can think of only two options. The first is to replace the NaNs and Infs with
G = torch.nan_to_num(G, nan=0.0, posinf=0.0, neginf=0.0)
The second is to cast the inputs to float32 before computing the loss.
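For reference, the float32 option could look roughly like the sketch below. The function name gram_matrix_fp32 and the normalize flag are hypothetical, and dividing by ch * h * w is a common convention I am swapping in, not part of my current code; it changes the scale of the loss, so the style weight would probably need retuning. Disabling autocast locally is only relevant if the model runs under torch.autocast, which is an assumption here.

```python
import torch


def gram_matrix_fp32(x: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Compute the Gram matrix in float32, whatever the input dtype is."""
    b, ch, h, w = x.size()
    # Turn autocast off locally so bmm is not cast back to float16.
    with torch.autocast(device_type=x.device.type, enabled=False):
        f = x.float().reshape(b, ch, h * w)
        G = f.bmm(f.transpose(1, 2))
        if normalize:
            # Assumed normalization: keeps entries small regardless of map size.
            G = G / (ch * h * w)
    return G


# Hypothetical float16 activations (not real VGG features).
x = (torch.rand(1, 4, 64, 64) * 10.0).to(torch.float16)
G = gram_matrix_fp32(x)
print(G.dtype, torch.isfinite(G).all().item())
```

Only the Gram computation runs in float32 here, so the VGG forward pass can stay in float16.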
Is there any other solution to fix this issue?
Here is the full code of the loss function:
class StyleLoss(_LossBase):
    def __init__(self, weight=1.0):
        super(StyleLoss, self).__init__(weight)
        self.add_module('vgg', VGG19())
        self.criterion = torch.nn.L1Loss()

    def getGramMatrix(self, x):
        b, ch, h, w = x.size()
        f = x.view(b, ch, w * h)
        G = f.bmm(f.transpose(1, 2))
        return G

    def __call__(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        style_loss = 0.0
        style_loss += self.criterion(self.getGramMatrix(x_vgg['relu2_2']), self.getGramMatrix(y_vgg['relu2_2']))
        style_loss += self.criterion(self.getGramMatrix(x_vgg['relu3_4']), self.getGramMatrix(y_vgg['relu3_4']))
        style_loss += self.criterion(self.getGramMatrix(x_vgg['relu4_4']), self.getGramMatrix(y_vgg['relu4_4']))
        style_loss += self.criterion(self.getGramMatrix(x_vgg['relu5_2']), self.getGramMatrix(y_vgg['relu5_2']))
        return style_loss * self.weight
Thank you.