import torch
import torch.nn as nn
import torch.nn.functional as F
class TERLUFunction(torch.autograd.Function):
    """Piecewise TERLU activation with an explicit backward pass.

    Forward definition (``alpha``, ``beta`` and ``mu`` are tensors that
    broadcast against ``x``, e.g. 0-dim tensors)::

        f(x) = beta * (mu + 1 - exp(-(x - mu)))   if x >= mu
        f(x) = alpha * (exp(x) - 1)               if x <= 0
        f(x) = x                                  otherwise

    The branches are tested in that order, so the ``x >= mu`` branch wins
    if ``mu <= 0`` makes the first two overlap. NaN inputs satisfy neither
    comparison, take the identity branch and therefore stay NaN.
    """

    @staticmethod
    def _sum_to_shape(grad, reference):
        """Sum an element-wise gradient down to the shape of ``reference``."""
        if reference.dim() == 0:
            reduced = grad.sum()
        else:
            reduced = grad.sum_to_size(reference.shape)
        return reduced.to(reference.dtype)

    @staticmethod
    def forward(ctx, x, alpha, beta, mu):
        """Evaluate the activation element-wise and save inputs for backward.

        Args:
            ctx: Autograd context.
            x: Input tensor.
            alpha: Scale of the ``x <= 0`` branch.
            beta: Scale of the ``x >= mu`` branch.
            mu: Threshold at which the upper branch starts.

        Returns:
            A new tensor holding ``f(x)``.
        """
        ctx.save_for_backward(x, alpha, beta, mu)

        # Clamp the exponent arguments so the branch that is *not* selected
        # can never overflow to inf.
        lower_branch = alpha * torch.expm1(x.clamp(max=0))
        decay = torch.exp(-(x - mu).clamp(min=0))
        upper_branch = beta * (mu + 1 - decay)

        return torch.where(
            x >= mu,
            upper_branch,
            torch.where(x <= 0, lower_branch, x),
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, alpha, beta, mu)``.

        Derivatives of the forward definition:

        * ``df/dx``     = ``beta * exp(-(x - mu))`` / ``alpha * exp(x)`` / ``1``
        * ``df/dalpha`` = ``exp(x) - 1`` on the lower branch, else 0
        * ``df/dbeta``  = ``mu + 1 - exp(-(x - mu))`` on the upper branch, else 0
        * ``df/dmu``    = ``beta * (1 - exp(-(x - mu)))`` on the upper branch, else 0

        Parameter gradients are summed down to the parameter's own shape.
        A gradient is only computed when autograd asks for it; otherwise
        ``None`` is returned in that slot.
        """
        x, alpha, beta, mu = ctx.saved_tensors

        upper = x >= mu
        lower = (x <= 0) & ~upper
        exp_lower = torch.exp(x.clamp(max=0))
        decay = torch.exp(-(x - mu).clamp(min=0))
        zeros = torch.zeros_like(x)

        grad_x = grad_alpha = grad_beta = grad_mu = None

        if ctx.needs_input_grad[0]:
            slope = torch.where(
                upper,
                beta * decay,
                torch.where(lower, alpha * exp_lower, torch.ones_like(x)),
            )
            grad_x = grad_output * slope

        if ctx.needs_input_grad[1]:
            d_alpha = torch.where(lower, exp_lower - 1, zeros)
            grad_alpha = TERLUFunction._sum_to_shape(grad_output * d_alpha, alpha)

        if ctx.needs_input_grad[2]:
            d_beta = torch.where(upper, mu + 1 - decay, zeros)
            grad_beta = TERLUFunction._sum_to_shape(grad_output * d_beta, beta)

        if ctx.needs_input_grad[3]:
            d_mu = torch.where(upper, beta * (1 - decay), zeros)
            grad_mu = TERLUFunction._sum_to_shape(grad_output * d_mu, mu)

        return grad_x, grad_alpha, grad_beta, grad_mu
class TERLU(nn.Module):
    """Module wrapper around ``TERLUFunction`` with a trainable ``beta``.

    ``alpha`` and ``mu`` are fixed (non-trainable) tensors; ``beta`` is an
    ``nn.Parameter`` initialised to 1.0.

    Args:
        alpha: Value passed to ``torch.tensor`` and forwarded as ``alpha``.
        mu: Value passed to ``torch.tensor`` and forwarded as ``mu``.
    """

    def __init__(self, alpha=1.0, mu=1.0):
        super().__init__()
        # Buffers (rather than plain attributes) so that ``.to(device)``,
        # ``.cuda()`` and dtype casts move them together with the module.
        # ``persistent=False`` keeps them out of ``state_dict`` so existing
        # checkpoints, which only contain ``beta``, still load strictly.
        self.register_buffer("alpha", torch.tensor(alpha), persistent=False)
        self.beta = nn.Parameter(torch.tensor(1.0))  # trainable
        self.register_buffer("mu", torch.tensor(mu), persistent=False)

    def forward(self, x):
        """Apply the activation element-wise to ``x``."""
        return TERLUFunction.apply(x, self.alpha, self.beta, self.mu)

    def extra_repr(self):
        """Return the hyper-parameter summary shown in the module repr."""
        return f'alpha={self.alpha}, beta={self.beta}, mu={self.mu}'
class Trainer:
def __init__(…):
….
def train_step(self, batch, step_count):
    """Run one adversarial training iteration on a single batch.

    Three optimizers are stepped: the discriminator (hinge loss on real
    versus detached fake images), the full generator (L1 + adversarial +
    perceptual loss) and the coarse generator (its own L1 loss).

    Two problems of a naive sequential implementation are avoided:

    * The refine L1 and adversarial terms are computed on the *attached*
      refine output. Computing them on ``refine_out_.detach()`` would make
      them constants that send no gradient to the generator.
    * ``opt_coarse`` updates parameters in place, and those parameters are
      part of the graph ``gen_loss`` backpropagates through. Stepping
      ``opt_coarse`` before ``gen_loss.backward()`` raises "one of the
      variables needed for gradient computation has been modified by an
      inplace operation". The coarse gradients are therefore computed up
      front with ``torch.autograd.grad`` and only applied after the
      generator backward pass is finished.

    Args:
        batch: Unpacks to ``(images, masks, ground_truths)``; each element
            is moved to ``self.device``.
        step_count: Not used by this method; kept for the caller's interface.

    Returns:
        None. The models are updated through their optimizers.
    """
    images, masks, ground_truths = batch
    images = images.to(self.device)
    masks = masks.to(self.device)
    ground_truths = ground_truths.to(self.device)

    # ---- Forward pass of the generator ---------------------------------
    with autocast(self.device):
        coarse_out, refine_out = self.generator(images, masks)
        # Blend: input weighted by (1 - masks), prediction weighted by masks.
        coarse_out_ = images * (1 - masks) + coarse_out * masks
        refine_out_ = images * (1 - masks) + refine_out * masks

    if torch.isnan(refine_out_).any() or torch.isnan(ground_truths).any():
        print("NaN detected in outputs or ground truths!")
        refine_out_ = torch.nan_to_num(refine_out_, nan=0.0)
        ground_truths = torch.nan_to_num(ground_truths, nan=0.0)

    # ---- Discriminator update ------------------------------------------
    # The fake batch is detached so this step never touches the generator.
    fake_images = refine_out_.detach()
    self.opt_discriminator.zero_grad()
    with autocast(self.device):
        ground_preds = self.discriminator(ground_truths, masks)
        fake_preds = self.discriminator(fake_images, masks)
        hinge_loss = (
            torch.mean(self.relu(1 - ground_preds))
            + torch.mean(self.relu(1 + fake_preds))
            + self.epsilon
        )
    disc_loss = hinge_loss
    # No retain_graph: this graph is independent of the generator graph and
    # is not used again.
    self.scaler_d.scale(disc_loss).backward()
    self.scaler_d.step(self.opt_discriminator)
    self.scaler_d.update()

    # ---- Generator losses (computed with the updated discriminator) -----
    with autocast(self.device):
        # NOTE(review): the L1 targets are `images`, while the perceptual and
        # discriminator targets are `ground_truths` -- confirm this is intended.
        coarse_loss = self.l1_loss(coarse_out_, images) + self.epsilon
        refine_loss = self.l1_loss(refine_out_, images) + self.epsilon
        adv_loss = -torch.mean(self.discriminator(refine_out_, masks)) + self.epsilon

        ground_features = self.vgg(ground_truths)
        fake_features = self.vgg(refine_out_)
        perceptual_loss = self.l1_loss(ground_features, fake_features)

        rec_loss = (
            0.5 * self.lambda_l1 * coarse_loss
            + self.lambda_l1 * refine_loss
            + self.epsilon
        )
        gen_loss = (
            rec_loss
            + self.beta * adv_loss
            + self.lambda_perceptual * perceptual_loss
        )

    # ---- Coarse gradients, taken before any generator weight changes ----
    coarse_params = [
        param
        for group in self.opt_coarse.param_groups
        for param in group["params"]
        if param.requires_grad
    ]
    coarse_grads = torch.autograd.grad(
        self.scaler_c.scale(coarse_loss),
        coarse_params,
        retain_graph=True,  # gen_loss still needs the shared graph
        allow_unused=True,
    )

    # ---- Full generator update -----------------------------------------
    self.opt_generator.zero_grad()
    self.scaler_g.scale(gen_loss).backward()
    self.scaler_g.step(self.opt_generator)
    self.scaler_g.update()

    # ---- Coarse generator update ---------------------------------------
    # Replace whatever the generator backward left in .grad with the
    # (scaler_c-scaled) coarse gradients computed above.
    self.opt_coarse.zero_grad()
    for param, grad in zip(coarse_params, coarse_grads):
        param.grad = grad
    self.scaler_c.step(self.opt_coarse)
    self.scaler_c.update()
import torch.optim as optim
def train(start_epoch=1, num_epochs=2, train_loader=train_loader, device='cuda'):
    """Build the models, optimizers, logger and trainer, then run training.

    Args:
        start_epoch: First epoch to run. ``1`` applies ``weights_init`` to the
            generator and discriminator; a larger value instead loads the
            checkpoint of ``start_epoch - 1`` through the logger.
        num_epochs: Forwarded to ``Trainer.train``.
        train_loader: Forwarded to ``Trainer``; defaults to the module-level
            ``train_loader`` captured when this function is defined.
        device: Device the models are moved to.
    """
    # NOTE(review): this CAM instance is never used below. It is kept because
    # constructing it may have side effects that cannot be ruled out from here
    # -- confirm and remove if it is dead.
    cam = CAM().to(device)
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    if start_epoch == 1:
        generator.apply(weights_init)
        discriminator.apply(weights_init)
        print("Weights initialized")

    # The coarse sub-network has its own optimizer in addition to the
    # optimizer covering the whole generator.
    opt_coarse = optim.Adam(generator.coarse_generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    opt_generator = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    opt_discriminator = optim.Adam(discriminator.parameters(), lr=3e-4, betas=(0.5, 0.999))

    logger = Logger(train_log_file=TRAIN_CSV_PATH, val_log_file=VAL_CSV_PATH, checkpoint_dir=CHKPT_DIR)

    if start_epoch > 1:
        logger.load_checkpoint(start_epoch - 1, generator, discriminator,
                               opt_coarse, opt_generator, opt_discriminator)

    trainer = Trainer(generator, discriminator, opt_coarse,
                      opt_generator, opt_discriminator,
                      train_loader, logger, device)
    try:
        trainer.train(start_epoch, num_epochs)
    finally:
        # Collect Python garbage first so tensors it frees can be released
        # from the CUDA caching allocator; run even if training raised.
        gc.collect()
        torch.cuda.empty_cache()
Glimpse of Error