Error: Implementing Custom Activation Function(TERLU) using this paper https://arxiv.org/pdf/2006.02797

import torch
import torch.nn as nn
import torch.nn.functional as F

class TERLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha, beta, mu):
        ctx.save_for_backward(x, alpha, beta, mu)
        
        # Create an out-of-place output tensor
        output = torch.empty_like(x)  # Create empty tensor with same size as x
        
        # Compute regions and assign them to the output tensor
        region_1 = (x <= 0)
        region_2 = (x > 0) & (x < mu)
        region_3 = (x >= mu)
        
        output[region_1] = alpha * (torch.exp(x[region_1]) - 1)
        output[region_2] = x[region_2]
        output[region_3] = beta * (mu - (torch.exp(-(x[region_3] - mu)) - 1))

        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha, beta, mu = ctx.saved_tensors
        grad_input = torch.zeros_like(x)

        region_1 = (x <= 0)
        grad_input[region_1] = grad_output[region_1] * (torch.exp(x[region_1]) - 1 + alpha)

        region_2 = (x > 0) & (x < mu)
        grad_input[region_2] = grad_output[region_2] * 1

        region_3 = (x >= mu)
        grad_input[region_3] = grad_output[region_3] * (-torch.exp(-(x[region_3] - mu)) + 1 + beta * mu + beta)

        grad_beta = torch.zeros_like(beta)
        grad_beta = grad_output[region_3] * (mu - (torch.exp(-(x[region_3] - mu)) - 1))

        return grad_input, None, grad_beta, None  # Return gradients for x, alpha, beta, mu

class TERLU(nn.Module):
    def __init__(self, alpha=1.0, mu=1.0):
        super(TERLU, self).__init__()
        self.alpha = torch.tensor(alpha)  # Make alpha a Parameter
        self.beta = nn.Parameter(torch.tensor(1.0))  # Trainable beta
        self.mu = torch.tensor(mu)  # Make mu a Parameter

    def forward(self, x):
        return TERLUFunction.apply(x, self.alpha, self.beta, self.mu)

    def extra_repr(self):
        return f'alpha={self.alpha}, beta={self.beta}, mu={self.mu}’
class Trainer:
    def __iniit__(…):
        ….
    def train_step(self, batch, step_count):
        images, masks, ground_truths = batch
        images, masks, ground_truths = images.to(self.device), masks.to(self.device), ground_truths.to(self.device)

        with autocast(self.device):
            coarse_out, refine_out = self.generator(images, masks)
            coarse_out_ = images * (1 - masks) + coarse_out * masks
            refine_out_ = images * (1 - masks) + refine_out * masks

            if torch.isnan(refine_out_).any() or torch.isnan(ground_truths).any():
                print("NaN detected in outputs or ground truths!")
                refine_out_ = torch.nan_to_num(refine_out_, nan=0.0)
                ground_truths = torch.nan_to_num(ground_truths, nan=0.0)

            # Discriminator Loss
            fake_images = refine_out_.detach()
            self.opt_discriminator.zero_grad()
            ground_preds = self.discriminator(ground_truths, masks)
            fake_preds = self.discriminator(fake_images, masks)

            # gradient_penalty = self.compute_gradient_penalty(self.discriminator, ground_truths, fake_images, masks)
            # disc_loss = -torch.mean(ground_preds) + torch.mean(fake_preds) + self.sigma * gradient_penalty # WGAN_LOSS

            hinge_loss = torch.mean(self.relu(1 - ground_preds)) + torch.mean(self.relu(1 + fake_preds)) + self.epsilon
            disc_loss = hinge_loss
            self.scaler_d.scale(disc_loss).backward(retain_graph=True)
            # torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 1.0)
            self.scaler_d.step(self.opt_discriminator)
            self.scaler_d.update()

            # Mask and Adverserial Loss
            fake_images = refine_out_.detach()
            coarse_loss = self.l1_loss(coarse_out_, images) + self.epsilon
            refine_loss = self.l1_loss(fake_images, images) + self.epsilon
            adv_loss = -torch.mean(self.discriminator(fake_images, masks)) + self.epsilon

            self.opt_coarse.zero_grad()
            self.scaler_c.scale(coarse_loss).backward(retain_graph=True)
            # torch.nn.utils.clip_grad_norm_(self.generator.coarse_generator.parameters(), 1.0)
            self.scaler_c.step(self.opt_coarse)
            self.scaler_c.update()

            # Preceptual Loss
            ground_features = self.vgg(ground_truths)
            fake_features = self.vgg(refine_out_)
            perceptual_loss = self.l1_loss(ground_features, fake_features)

            # Generator Loss
            rec_loss = 0.5 * self.lambda_l1 * coarse_loss + self.lambda_l1 * refine_loss + self.epsilon
            gen_loss = rec_loss + self.beta * adv_loss + self.lambda_perceptual * perceptual_loss

            self.opt_generator.zero_grad()
            self.scaler_g.scale(gen_loss).backward()
            # torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 1.0)
            self.scaler_g.step(self.opt_generator)
            self.scaler_g.update()
import torch.optim as optim

def train(start_epoch=1, num_epochs=2, train_loader=train_loader, device='cuda'):

    # Initialize Generator and Discriminator
    cam = CAM().to(device)
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    if start_epoch == 1:
        generator.apply(weights_init)
        discriminator.apply(weights_init)
        print("Weights initialized")

    # Optimizers with learning rates
    opt_coarse = optim.Adam(generator.coarse_generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    opt_generator = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    opt_discriminator = optim.Adam(discriminator.parameters(), lr=3e-4, betas=(0.5, 0.999))

    # Initialize Logger
    logger = Logger(train_log_file=TRAIN_CSV_PATH, val_log_file=VAL_CSV_PATH, checkpoint_dir=CHKPT_DIR)

    if start_epoch > 1:
        logger.load_checkpoint(start_epoch - 1, generator, discriminator,
                               opt_coarse, opt_generator, opt_discriminator)

    # Initialize Trainer and start training
    trainer = Trainer(generator, discriminator, opt_coarse,
                      opt_generator, opt_discriminator,
                      train_loader, logger, device)
    trainer.train(start_epoch, num_epochs)

    torch.cuda.empty_cache()
    gc.collect()

Glimpse of Error

A lot of the times this can be avoided by modifying the model code. A tool that might be helpful though is Automatic differentiation package - torch.autograd — PyTorch 2.6 documentation which will automatically insert a clone for you in any place that would otherwise produce this error.