Help with autocast causing a loss issue

tl;dr Using autocast causes issues with one specific loss. Forcing that loss to fp32 didn't fix it. Help needed.

Problem description:

Autocast and GradScaler are applied in the training loop, as the AMP documentation explains:

        self.optimizer_g.zero_grad(set_to_none=True)
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
            self.output = self.net_g(self.lq)
            l_g_total = 0
            loss_dict = OrderedDict()
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_percep'] = l_g_percep

        scaler.scale(l_g_total).backward()
        scaler.step(self.optimizer_g)
        scaler.update()
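For comparison, here is a minimal sketch of the pattern used in the AMP examples, where the GradScaler is constructed once, before the loop, so its current scale (including any backoff after an inf/NaN step) carries over between iterations. In the snippet above the scaler is created inside the step itself, so each iteration may start again from init_scale. The tiny linear model, the random data and the CPU fallback (AMP disabled when no GPU is present) are illustrative assumptions, not part of the original code.

```python
import torch
import torch.nn as nn


def train(steps: int = 5) -> list:
    # Use CUDA AMP only when a GPU is present; on CPU, autocast and the
    # scaler are both disabled and the loop runs in plain fp32.
    use_amp = torch.cuda.is_available()
    device = "cuda" if use_amp else "cpu"
    amp_dtype = torch.float16 if use_amp else torch.bfloat16  # ignored when disabled

    # Hypothetical stand-ins for net_g, the loss and optimizer_g.
    model = nn.Linear(16, 16).to(device)
    criterion = nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Constructed ONCE, before the loop, so the scale state persists
    # from one iteration to the next.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    losses = []
    for _ in range(steps):
        x = torch.randn(8, 16, device=device)
        target = torch.randn(8, 16, device=device)

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
            loss = criterion(model(x), target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)  # skipped if the unscaled grads contain inf/NaN
        scaler.update()         # backs off (or periodically grows) the scale
        losses.append(loss.item())
    return losses
```

In the question's code, the equivalent change would be moving the `GradScaler(...)` construction out of the per-iteration method (e.g. into the model's setup) and reusing it.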

cri_perceptual is the perceptual loss, which is the one causing issues:

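def forward(self, x, gt):
    x_features = self.vgg(x)
    gt_features = self.vgg(gt.detach())
    percep_loss = 0  # must be initialised before accumulating
    for k in x_features.keys():
        percep_loss += self.criterion(
            x_features[k], gt_features[k]) * self.layer_weights[k]
    # style term omitted in this excerpt; the caller unpacks (percep, style)
    return percep_loss, None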

It uses a torchvision VGG model to extract features:

from torchvision.models import vgg
from torchvision.models import VGG19_Weights

self.vgg_net = getattr(vgg, vgg_type)(weights=VGG19_Weights.DEFAULT)
self.vgg_net.eval()

def forward(self, x):
    if self.range_norm:
        x = (x + 1) / 2
    if self.use_input_norm:
        x = (x - self.mean) / self.std

    output = {}
    for key, layer in self.vgg_net._modules.items():
        x = layer(x)
        if key in self.layer_name_list:
            output[key] = x.clone()

    return output
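The two forward methods above can be sketched as one self-contained example: a frozen feature extractor that walks its named children and keeps the activations of selected layers, plus a weighted L1 feature-matching loss. To avoid downloading weights, a tiny hypothetical conv stack with made-up layer names stands in for VGG (with torchvision, `vgg19(weights=...).features` is an `nn.Sequential` whose children are named "0", "1", ...). `named_children()` is used as the public equivalent of iterating `_modules.items()`; the L1 criterion and equal layer weights are assumptions.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class FeatureExtractor(nn.Module):
    """Return intermediate activations for the requested layer names."""

    def __init__(self, layers: nn.Sequential, layer_name_list):
        super().__init__()
        self.net = layers
        self.layer_name_list = set(layer_name_list)

    def forward(self, x):
        output = {}
        for name, layer in self.net.named_children():
            x = layer(x)
            if name in self.layer_name_list:
                output[name] = x.clone()
        return output


class PerceptualLoss(nn.Module):
    """Weighted sum of per-layer L1 distances between feature maps."""

    def __init__(self, extractor, layer_weights):
        super().__init__()
        self.extractor = extractor
        self.layer_weights = layer_weights
        self.criterion = nn.L1Loss()

    def forward(self, x, gt):
        x_features = self.extractor(x)
        gt_features = self.extractor(gt.detach())
        percep_loss = 0.0
        for k in x_features:
            percep_loss = percep_loss + self.criterion(
                x_features[k], gt_features[k]) * self.layer_weights[k]
        return percep_loss


# Hypothetical stand-in for a truncated VGG: two conv blocks with named layers.
backbone = nn.Sequential(OrderedDict([
    ("conv1", nn.Conv2d(3, 8, 3, padding=1)),
    ("relu1", nn.ReLU()),
    ("conv2", nn.Conv2d(8, 8, 3, padding=1)),
    ("relu2", nn.ReLU()),
]))
extractor = FeatureExtractor(backbone, ["relu1", "relu2"]).eval()
for p in extractor.parameters():
    p.requires_grad_(False)  # frozen, as with a pretrained VGG
loss_fn = PerceptualLoss(extractor, {"relu1": 1.0, "relu2": 1.0})
```

Calling `loss_fn(output, gt)` on two `(N, 3, H, W)` tensors returns a scalar tensor; gradients flow back to `output` even though the extractor's own weights are frozen.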

Things I’ve tried so far, without success:

  • Adding the decorator @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) to both the VGG feature-extraction forward and the perceptual-loss forward. It didn't work, although I did not add custom_bwd as the documentation says, since that would require some modifications to the current .backward() call.
  • Forcing VGG inference to fp32 or fp16 with .float() or .half().
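For reference, the approach the autocast documentation describes for forcing a region to run in fp32 inside an autocast block is a nested `torch.autocast(..., enabled=False)` plus explicit `.float()` casts on that region's inputs. A minimal sketch, with hypothetical single-conv stand-ins for net_g and VGG and a CPU fallback when no GPU is present. Note that this only controls forward-pass precision; the backward pass through the fp16 parts of the generator is still scaled by GradScaler.

```python
import torch
import torch.nn as nn

use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"
amp_dtype = torch.float16 if use_amp else torch.bfloat16

# Hypothetical stand-ins for net_g and the frozen VGG feature extractor.
net = nn.Conv2d(3, 3, 3, padding=1).to(device)
feature_net = nn.Conv2d(3, 8, 3, padding=1).to(device).requires_grad_(False)

lq = torch.randn(1, 3, 16, 16, device=device)
gt = torch.randn(1, 3, 16, 16, device=device)

with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
    output = net(lq)  # float16 when CUDA autocast is active
    # Locally disable autocast: ops now run in their inputs' dtype,
    # so cast the inputs to float32 explicitly before the loss.
    with torch.autocast(device_type=device, enabled=False):
        loss = nn.functional.l1_loss(feature_net(output.float()),
                                     feature_net(gt.float()))

loss.backward()  # gradients still reach net through the .float() cast
```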

Question:

  • Any idea how to solve this and what might be causing the issue?

Update: I was able to trace the origin of the issue to the GradScaler backward pass (scaler.scale(l_g_total).backward()).
I'm still not sure how to solve it.
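One way to narrow this kind of problem down is to unscale the gradients before stepping and check them for inf/NaN, and to compare the scale before and after `update()`: with the default settings the scale only shrinks when the scaler found inf/NaN and skipped the step. The helper names below are hypothetical; `unscale_`, `step`, `update` and `get_scale` are existing GradScaler methods. `get_scale()` synchronises with the GPU, so this is meant for debugging rather than for every production step.

```python
import torch
import torch.nn as nn


def grads_are_finite(model: nn.Module) -> bool:
    """True if no parameter gradient contains inf or NaN."""
    return all(
        bool(torch.isfinite(p.grad).all())
        for p in model.parameters()
        if p.grad is not None
    )


def debug_amp_step(model, optimizer, scaler, loss):
    """Scaled backward + step that reports (grads_finite, step_skipped)."""
    scale_before = scaler.get_scale()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # .grad now holds the true, unscaled gradients
    finite = grads_are_finite(model)
    scaler.step(optimizer)      # skipped internally if unscale_ found inf/NaN
    scaler.update()
    # With default settings the scale only shrinks after an inf/NaN step.
    skipped = scaler.get_scale() < scale_before
    return finite, skipped
```

Logging these two values every iteration shows whether the optimizer is silently skipping steps, and what the scale is at the time.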

A solution has been found by one of my project's users (@terrainer).
The initial GradScaler scale was too high, which led to inf/NaN values that training could not recover from, even after backoff_factor was applied. Decreasing init_scale from the default 2.**16 to 2.**11 fixes the issue. Thanks again to @terrainer for finding, testing and reporting the solution.
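To illustrate why init_scale matters, here is a toy, pure-Python model of the scale-update rule described in the GradScaler documentation: on inf/NaN the step is skipped and the scale is multiplied by backoff_factor; after growth_interval clean steps it is multiplied by growth_factor. It is not PyTorch's implementation: overflow is reduced to one hypothetical gradient magnitude times the scale exceeding the fp16 maximum of 65504, whereas in practice the inf appears in fp16 intermediates during backward. With the default 2.**16, any scaled value above 65504 / 2**16 ≈ 0.9995 overflows; with 2.**11 the limit is about 31.98. The `persistent=False` case mimics constructing a new GradScaler on every iteration. In code, the reported fix amounts to passing `init_scale=2.**11` to the `torch.cuda.amp.GradScaler` constructor.

```python
FP16_MAX = 65504.0  # largest finite float16 value


def overflows(grad_magnitude: float, scale: float) -> bool:
    """Toy overflow test: a scaled fp16 value becomes inf above FP16_MAX."""
    return grad_magnitude * scale > FP16_MAX


def simulate(grad_magnitude: float, steps: int, init_scale: float = 2.0**16,
             backoff_factor: float = 0.5, growth_factor: float = 2.0,
             growth_interval: int = 2000, persistent: bool = True) -> int:
    """Toy model of GradScaler's documented scale-update rule.

    Returns how many optimizer steps are actually taken. persistent=False
    mimics constructing a fresh scaler on every iteration.
    """
    scale = init_scale
    good_steps = 0
    taken = 0
    for _ in range(steps):
        if not persistent:
            scale, good_steps = init_scale, 0  # a new scaler forgets earlier backoff
        if overflows(grad_magnitude, scale):
            scale *= backoff_factor  # inf/NaN found: skip the step, shrink the scale
            good_steps = 0
        else:
            taken += 1
            good_steps += 1
            if good_steps == growth_interval:  # long run without overflow: grow
                scale *= growth_factor
                good_steps = 0
    return taken
```

With a hypothetical gradient magnitude of 2.0 over 10 iterations, `simulate(2.0, 10)` returns 8 (two steps are skipped while the scale backs off to 2**14), `simulate(2.0, 10, persistent=False)` returns 0 because every fresh scaler starts at 2**16 again, and `simulate(2.0, 10, init_scale=2.0**11, persistent=False)` returns 10.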
