hi this is just piece of code but it does the job.
It is just stuck at lines where losses are used. it always throw
d_loss_hr = adversarial_loss(hr_output, real_label)
TypeError: 'Tensor' object is not callable
or same with pixel loss when i delete all adversarial losses
What is causing this problem ?
def define_loss() -> [nn.MSELoss, nn.MSELoss, ContentLoss, nn.BCEWithLogitsLoss]:
psnr_criterion = nn.MSELoss().to("cuda")
pixel_criterion = nn.MSELoss().to("cuda")
content_criterion = ContentLoss().to("cuda")
adversarial_criterion = nn.BCEWithLogitsLoss().to("cuda")
return psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion
psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion = define_loss()
def train_model(generator,
discriminator,
g_optimizer,
d_optimizer,
pixel_loss,
content_loss,
adversarial_loss,
loader,
batch_size,
best_loss,
best_psnr,
scaler,
loss_chart,
batch_count,
epoch):
for batch in range(batch_count):
lr, hr = loader.get_training_batch(batch_size)
hr = hr.to("cuda")
lr = lr.to("cuda")
real_label = torch.full([lr.size(0), 1], 1.0, dtype=lr.dtype, device="cuda")
fake_label = torch.full([lr.size(0), 1], 0.0, dtype=lr.dtype, device="cuda")
sr = generator(lr)
# Initialize the discriminator optimizer gradient
d_optimizer.zero_grad()
# Calculate the loss of the discriminator on the high-resolution image
with amp.autocast():
hr_output = discriminator(hr)
d_loss_hr = adversarial_loss(hr_output, real_label)
# Gradient zoom
scaler.scale(d_loss_hr).backward()
# Calculate the loss of the discriminator on the super-resolution image.
with amp.autocast():
sr_output = discriminator(sr.detach())
d_loss_sr = adversarial_loss(sr_output, fake_label)
# Gradient zoom
scaler.scale(d_loss_sr).backward()
# Update discriminator parameters
scaler.step(d_optimizer)
scaler.update()
g_optimizer.zero_grad()
with amp.autocast():
output = discriminator(sr)
pixel_loss = 1.0 * pixel_loss(sr, hr.detach())
content_loss = 1.0 * content_loss(sr, hr.detach())
adversarial_loss = 0.001 * adversarial_loss(output, real_label)
# Count discriminator total loss
g_loss = pixel_loss + content_loss + adversarial_loss
# Gradient zoom
scaler.scale(g_loss).backward()
# Update generator parameters
scaler.step(g_optimizer)
scaler.update()