Why is my implementation of SELU() not equivalent to nn.SELU()?

As an exercise, I tried to implement my own SELU activation function. When I compare its output to the output of nn.SELU(), the two agree to within 1e-9. However, when I train with my SELU implementation, I get nan after a few iterations, whereas nn.SELU() gives a normally decreasing loss. It seems like the gradient is computed differently, but I don't know why. The data I am using is from here (put it in a folder called input). The code for reproducing this is below (the important module is at the end):
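To take the training loop out of the picture, here is a minimal sketch that compares the two gradients directly (it reuses the MySelu module defined at the end of the listing; the input 90.0 is just an arbitrary value large enough that exp() overflows in float32):

import torch
import torch.nn as nn

# gradient of the built-in SELU at a large input
x1 = torch.tensor([90.0], requires_grad=True)
nn.SELU()(x1).sum().backward()
print(x1.grad)

# gradient of my implementation at the same input
x2 = torch.tensor([90.0], requires_grad=True)
MySelu()(x2).sum().backward()
print(x2.grad)

If the gradients already disagree here even though the forward outputs match, that would point at the backward pass of my module rather than at the training code.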

import numpy as np
import pandas as pd
import warnings
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.optim as optim


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count



class PulsarDataset(Dataset):

    def __init__(self, x, y):
        self.x = x
        self.y = y


    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return (self.x[idx].astype(np.float32), self.y[idx].astype(np.int64))


def adjust_learning_rate(optimizers, init_lr, epoch):
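    """Halve the learning rate every 30 epochs."""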
    lr = init_lr * (0.5 ** (epoch // 30))

    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr



class Model(object):
    def __init__(self):
        super(Model, self).__init__()
        self.init_model()

        # NLLLoss on log-probabilities; combined with the LogSoftmax output
        # layer of the network below, this is equivalent to cross-entropy.
        self.cross_entropy_loss = torch.nn.NLLLoss()

    def init_model(self):
        self.net = NN().to(device)

    def load_data(self, x_train, y_train, x_validation, y_validation, x_test, y_test, batch_size = 16):

        trainset = PulsarDataset(x_train, y_train)
        self.train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=10)

        validationset = PulsarDataset(x_validation, y_validation)
        self.validation_loader = torch.utils.data.DataLoader(validationset, batch_size=batch_size, shuffle=True, num_workers=10)

        testset = PulsarDataset(x_test, y_test)
        self.test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=10)

    def init_optimizer(self, lr = 2e-3, weight_decay = 5e-3):
        self.lr = lr
        self.weight_decay = weight_decay
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)



    def main_compute_step(self, x, y):
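        """Forward pass only; the graph-attached loss is stored in self.loss for backward()."""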

        # FORWARD
        y_pred = self.net(x)
        cross_entropy = self.cross_entropy_loss(y_pred, y)
        self.loss = cross_entropy
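        # convert log-probabilities (LogSoftmax output) back to probabilities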
        y_pred = torch.exp(y_pred)
        _, predicted = torch.max(y_pred, dim=1)
        
        accuracy = (y.detach().cpu().numpy().astype(np.int64) == predicted.detach().cpu().numpy()).astype(np.int64).mean()

        return y_pred.detach().cpu().numpy(), cross_entropy.detach().cpu().numpy(), accuracy


    def _train(self, epoch):
        adjust_learning_rate([self.optimizer], self.lr, epoch)

        cross_entropy_meter = AverageMeter()
        accuracy_meter = AverageMeter()
        cross_entropy_list = list()
        accuracy_list = list()
        for i, (x, y) in enumerate(self.train_loader):

            x,y = x.type(torch.FloatTensor).to(device), y.type(torch.LongTensor).to(device)


            self.optimizer.zero_grad()

            y_pred, cross_entropy, accuracy = self.main_compute_step(x, y)
            cross_entropy_list.append(cross_entropy)
            accuracy_list.append(accuracy)
            print(str(i) + '/' + str(len(self.train_loader)), 'cross_entropy: ' + str(cross_entropy_list[-1]), 'accuracy: ' + str(accuracy))
                

            cross_entropy_meter.update(cross_entropy)
            accuracy_meter.update(accuracy)

            # Update parameters with optimizers
            self.loss.backward()
            #torch.nn.utils.clip_grad_norm(self.net.parameters(),1)
            self.optimizer.step()

        return cross_entropy_meter.avg, accuracy_meter.avg


    def test(self, use_test_set = True):
        cross_entropy_meter = AverageMeter()
        accuracy_meter = AverageMeter()
        accuracy_list = list()

        if use_test_set:
            loader = self.test_loader
        else:
            loader = self.validation_loader

        for i, (x,y) in enumerate(loader):
            if i % 10 == 0:
                print(str(i) + '/' + str(len(loader)))

            x,y = x.type(torch.FloatTensor).to(device), y.type(torch.LongTensor).to(device)

            with torch.no_grad():
                y_pred, cross_entropy, accuracy = self.main_compute_step(x,y)

            accuracy_list.append(accuracy)
            cross_entropy_meter.update(cross_entropy)
            accuracy_meter.update(accuracy)


        return cross_entropy_meter.avg, accuracy_meter.avg



    def train(self, lr = 2e-3, weight_decay = 5e-3, n_epochs = 1000, test_interval = 5):
        self.n_epochs = n_epochs
        self.test_interval = test_interval

        self.init_optimizer(lr, weight_decay)

        self.last_epoch = 0
        self.train_loss_list = list()
        self.validation_loss_list = list()
        self.test_loss_list = list()

        self.train_acc_list = list()
        self.validation_acc_list = list()
        self.test_acc_list = list()

        for epoch in range(self.last_epoch + 1, self.n_epochs + 1):
            cross_entropy, accuracy = self._train(epoch)

            print('Epoch: [%d/%d], cross_entropy: %.9f, acc: %.9f' % (epoch, self.n_epochs, cross_entropy, accuracy))

            self.train_loss_list.append(cross_entropy)
            self.train_acc_list.append(accuracy)
            

            if epoch % self.test_interval == 0:
                validation_loss, validation_accuracy = self.test(use_test_set=False)
                self.validation_loss_list.append(validation_loss)
                self.validation_acc_list.append(validation_accuracy)
                
                print('-----------------------VALIDATION-----------------------')
                print('Epoch: [%d/%d], cross_entropy: %.9f, acc: %.9f' % (epoch, self.n_epochs, validation_loss, validation_accuracy))
                print('--------------------------------------------------------')

                test_loss, test_accuracy = self.test()
                self.test_loss_list.append(test_loss)
                self.test_acc_list.append(test_accuracy)
                
                print('-----------------------TEST-----------------------')
                print('Epoch: [%d/%d], cross_entropy: %.9f, acc: %.9f' % (epoch, self.n_epochs, test_loss, test_accuracy))
                print('---------------------------------------------------')


        return self.train_loss_list, self.validation_loss_list, self.test_loss_list, self.train_acc_list, self.validation_acc_list, self.test_acc_list


class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        
        self.linear1 = nn.Linear(8, 128)
        weights_init(self.linear1)
        self.linear2 = nn.Linear(128, 128)
        weights_init(self.linear2)
        self.linear3 = nn.Linear(128, 2)
        weights_init(self.linear3)


        ##################### WORKS #####################
        self.act1 = nn.SELU()
        self.act2 = MySelu()

        ##################### NAN #######################
        #self.act2 = nn.SELU()
        #self.act1 = MySelu()

        self.act_f = nn.LogSoftmax(dim=1)


    def forward(self, x):
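        # run both activations on the same input and print the mean difference as a sanity check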
        x = self.linear1(x)
        x_1 = self.act1(x)
        x_2 = self.act2(x)
        print((x_1 - x_2).mean())

        x = x_1

        x = self.linear2(x)
        x_1 = self.act1(x)
        x_2 = self.act2(x)
        print((x_1 - x_2).mean())

        x = x_1

        x = self.linear3(x)
        x = self.act_f(x)

        return x


def random_weight(shape, is_relu):
    """
    Kaiming normalization: sqrt(2 / fan_in)
    """
    if len(shape) == 2:  # FC weight [out_features, in_features]
        fan_in = shape[1]
    else:
        fan_in = np.prod(shape[1:])  # conv weight [out_channel, in_channel, kH, kW]

    if is_relu:
        w = torch.randn(shape) * np.sqrt(2. / fan_in)
    else:
        w = torch.randn(shape) * np.sqrt(1. / fan_in)
    w.requires_grad = True
    return w

def zero_weight(shape):
    return torch.zeros(shape, requires_grad=True)
 

def weights_init(m, is_relu = False):
    if type(m) in [nn.Conv2d, nn.Linear]:
        m.weight.data = random_weight(m.weight.data.size(), is_relu)
        m.bias.data = zero_weight(m.bias.data.size())



class MySelu(nn.Module):

    def __init__(self):
        super(MySelu, self).__init__()

    def forward(self, x):
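        # SELU(x) = scale * (max(0, x) + min(0, alpha * (exp(x) - 1)))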
        
        term1 = torch.clamp(x, min=0)
        term2 = torch.clamp(1.6732632423543772848170429916717 * (torch.exp(x) - 1), max=0)
        return 1.0507009873554804934193349852946 * (term1 + term2)





warnings.filterwarnings("ignore")
DataFrame = pd.read_csv("./input/pulsar_stars.csv")  
labels = DataFrame.target_class.values

DataFrame.drop(["target_class"],axis=1,inplace=True)

features = DataFrame.values
scaler = MinMaxScaler(feature_range=(0,1))

features_scaled = scaler.fit_transform(features)

x_train, x_validation, y_train, y_validation = train_test_split(features_scaled,labels,test_size=0.2)
x_validation, x_test, y_validation, y_test = train_test_split(x_validation,y_validation,test_size=0.8)
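# after the two splits: 80% train, 4% validation, 16% test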



batch_size = 128
n_epochs = 50
test_interval = 5
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


model = Model()
model.load_data(x_train, y_train, x_validation, y_validation, x_test, y_test, batch_size=batch_size)
train_loss, validation_loss, test_loss, train_acc, validation_acc, test_acc = model.train(n_epochs = n_epochs, lr = 0.1, weight_decay = 5e-4)