Custom activation functions?

New here, so… How do I implement and use an activation function that's based on another function in PyTorch, for example swish?

If your new function is differentiable, just write it as a Python function built from torch operations, and autograd will compute the backward pass for you. If it has learnable parameters, subclass nn.Module and implement __init__ and forward for your function.
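A minimal sketch of both cases, assuming a reasonably recent PyTorch. The names swish and LearnableSwish are hypothetical, and the trainable beta in x * sigmoid(beta * x) is only an illustration of "a function with parameters":

```python
import torch
from torch import nn

def swish(x: torch.Tensor) -> torch.Tensor:
    # Built only from differentiable torch ops, so autograd derives the backward pass.
    return x * torch.sigmoid(x)

class LearnableSwish(nn.Module):
    # Hypothetical example: Swish with a trainable beta, x * sigmoid(beta * x).
    def __init__(self, beta: float = 1.0):
        super().__init__()
        # Wrapping beta in nn.Parameter registers it, so optimizers will update it.
        self.beta = nn.Parameter(torch.tensor(beta))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(self.beta * x)

x = torch.randn(4, requires_grad=True)
swish(x).sum().backward()  # no hand-written backward needed
act = LearnableSwish()
y = act(x)                 # beta receives gradients like any other parameter
```

The plain function is handy inside your own module's forward; the nn.Module version is what you need when the activation holds state.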

If it is not differentiable, you will have to define the backward operation yourself with a custom torch.autograd.Function, using the information here.
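For reference, a rough sketch of what a hand-written backward looks like. Swish is actually differentiable, so this is purely illustrative; SwishFunction is a hypothetical name, and the gradient used is d/dx [x * sigmoid(x)] = s + x * s * (1 - s) with s = sigmoid(x):

```python
import torch

class SwishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save the input so backward can recompute sigmoid(x).
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        s = torch.sigmoid(x)
        # Chain rule: upstream gradient times the local derivative.
        return grad_output * (s + x * s * (1 - s))

x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = SwishFunction.apply(x)
# gradcheck compares the hand-written backward against finite differences
# (double precision keeps the numerical comparison reliable).
ok = torch.autograd.gradcheck(SwishFunction.apply, (x,))
```

For a genuinely non-differentiable function you would put whatever surrogate gradient you want into backward instead.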


@albanD, it keeps giving me a "__main__.Activation is not a Module subclass" error when I try to implement it as a Python function. It has no parameters except x, and it is differentiable.

Could you give a code sample that reproduces this, please?

@albanD

Code sample:

import torch
from torch import nn

def Swish(x):
    # Swish(x) = x * sigmoid(x), written with torch ops so autograd can differentiate it
    return x * torch.sigmoid(x)

# Copied from the GAN tutorial with changes to the activation function.

ngf =  64
nz = 100
nc = 3

main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            Swish,
            #nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            Swish, #Define x, maybe? 
            #nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            Swish,
            #nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            Swish,
            #nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

Does x in the Swish function count as a parameter? I'm new to PyTorch, so… I also tried using nn.Module and implementing __init__ and forward, but I got the same error.

Hello,

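I think you should rewrite the Swish function as a subclass of nn.Module and pass an instance, Swish(), instead of Swish in the Sequential. nn.Sequential only accepts Module instances, so passing a bare function, or the class itself without calling it, raises the "is not a Module subclass" error. And no, x is not a parameter: it is just the input tensor passed to forward.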
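A sketch of how that could look with the generator from the code sample above (the same layers, with the activation wrapped in a module). Recent PyTorch versions also ship nn.SiLU, which computes the same x * sigmoid(x), so you may not need a custom class at all:

```python
import torch
from torch import nn

class Swish(nn.Module):
    # No learnable parameters, so no __init__ is needed; x is the input, not a parameter.
    def forward(self, x):
        return x * torch.sigmoid(x)

ngf, nz, nc = 64, 100, 3

main = nn.Sequential(
    # input is Z, going into a convolution
    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
    nn.BatchNorm2d(ngf * 8),
    Swish(),  # an instance, not the class or a bare function
    # state size. (ngf*8) x 4 x 4
    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf * 4),
    Swish(),
    # state size. (ngf*4) x 8 x 8
    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf * 2),
    Swish(),
    # state size. (ngf*2) x 16 x 16
    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf),
    Swish(),
    # state size. (ngf) x 32 x 32
    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
    nn.Tanh(),
    # state size. (nc) x 64 x 64
)

z = torch.randn(2, nz, 1, 1)
out = main(z)  # shape (2, nc, 64, 64)
```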

@MariosOreo, that worked, thank you.

@ZeroMaxinumXZ, maybe a little late, but what was your final code? I'm trying to write an activation function myself…