tl;dr: Using autocast causes issues with a specific loss, and forcing it to fp32 didn't fix it. Help needed.
Problem description:
Autocast and GradScaler are applied in the training loop, as the AMP documentation describes:
```python
# GradScaler would normally be created once, outside the training loop
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

self.optimizer_g.zero_grad(set_to_none=True)
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
    self.output = self.net_g(self.lq)

    l_g_total = 0
    loss_dict = OrderedDict()
    if self.cri_perceptual:
        l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
        if l_g_percep is not None:
            l_g_total += l_g_percep
            loss_dict['l_percep'] = l_g_percep

# Backward, step, and update happen outside the autocast region
scaler.scale(l_g_total).backward()
scaler.step(self.optimizer_g)
scaler.update()
```
`cri_perceptual` is the perceptual loss, which is where the issue shows up:
```python
def forward(self, x, gt):
    # Extract VGG features from the prediction and the (detached) ground truth
    x_features = self.vgg(x)
    gt_features = self.vgg(gt.detach())

    percep_loss = 0
    for k in x_features.keys():
        percep_loss += self.criterion(
            x_features[k], gt_features[k]) * self.layer_weights[k]
    return percep_loss
```
It uses a torchvision VGG model to extract the features:
```python
from torchvision.models import vgg
from torchvision.models import VGG19_Weights

self.vgg_net = getattr(vgg, vgg_type)(weights=VGG19_Weights.DEFAULT)
self.vgg_net.eval()

def forward(self, x):
    if self.range_norm:
        x = (x + 1) / 2
    if self.use_input_norm:
        x = (x - self.mean) / self.std

    output = {}
    for key, layer in self.vgg_net._modules.items():
        x = layer(x)
        if key in self.layer_name_list:
            output[key] = x.clone()
    return output
```
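
For context on why autocast might matter here: under fp16 autocast the VGG convolutions run in half precision, whose maximum representable value is about 65504, so large intermediate feature magnitudes can overflow to inf and turn the loss into nan. A small diagnostic along these lines (a hypothetical helper, not part of the codebase) can show whether the extracted features themselves are overflowing:

```python
import torch

def check_feature_range(features):
    # Hypothetical helper: print dtype, magnitude, and finiteness of each
    # extracted feature map. fp16 overflows above ~65504, so max values near
    # that bound (or non-finite entries) would point at the VGG features.
    for name, feat in features.items():
        print(f"{name}: dtype={feat.dtype}, "
              f"max_abs={feat.abs().max().item():.1f}, "
              f"finite={torch.isfinite(feat).all().item()}")
```

Calling this on `x_features` inside the perceptual `forward` while autocast is active would show whether fp16 is where the values blow up.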
Things I’ve tried so far, without success:
- Adding the decorator `@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)` to both the VGG feature extraction `forward` and the perceptual loss `forward` (see the first sketch after this list). It didn't work, though I did not pair it with `custom_bwd` as the documentation says to, since that would require some modifications to the current `.backward()` call.
- Forcing the VGG inference with `.float()` or `.half()` (second sketch below).
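
A minimal sketch of the first attempt. Note that `torch.cuda.amp.custom_fwd` is documented for the `forward` of a `torch.autograd.Function` subclass, so applying it directly to a module `forward` (as below) may not behave exactly as the docs describe; the class name is illustrative:

```python
import torch
from torch import nn

class VGGFeatureExtractor(nn.Module):  # illustrative name
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, x):
        # With cast_inputs set, incoming CUDA float tensors are cast to fp32
        # and the body runs with autocast disabled
        ...
```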
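
And a sketch of the second attempt; the exact placement of the casts in my real code may have differed, and the disabled-autocast wrapper shown here is my addition:

```python
# If the cast happens while autocast is still enabled, autocast re-casts
# the conv inputs back to fp16, which might explain why .float() alone
# had no effect; a disabled-autocast region keeps the computation in fp32.
with torch.autocast(device_type='cuda', enabled=False):
    x_features = self.vgg(x.float())
    gt_features = self.vgg(gt.detach().float())
```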
Question:
- Any idea what might be causing the issue and how to solve it?