How to customize the activation function of a torchvision model (ResNet-18)?

Hi all,
I hope that you are having a great day.
I am implementing a paper on uncertainty estimation and using the torchvision pre-trained model ResNet-18. However, I want to use my own customized activation function in the second-to-last layer of ResNet-18 instead of ReLU. How do I do that?
I searched online but found no solution.
Thank you.

By “second last layer” do you mean the second block in layer4?
If so, you could use the following code:

model = models.resnet18(pretrained=False)
model.layer4[1].relu = nn.LeakyReLU(inplace=True)  # your custom function here
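
To double-check that the swap took effect, you can print the submodule afterwards; it should show the new activation:

print(model.layer4[1].relu)  # e.g. LeakyReLU(negative_slope=0.01, inplace=True)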

Let me know if you meant another layer.

As a small addition to the answer, one can inspect the model by printing it:

print(model)
# which shows something like:
# ResNet(
#  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#  (relu): ReLU(inplace)
#  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
#  (layer1): Sequential(
#    (0): BasicBlock(
#      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#      (relu): ReLU(inplace)
#      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#    )
#    (1): BasicBlock(
#      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#      (relu): ReLU(inplace)
#      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#    )
#  )
# ... the rest of the layers

And now you can easily navigate to whichever layer you would like, using the attribute names and indices shown in this printout. It could be model.layer1[0].relu or model.layer2[1].conv1.weight, or any other part of the model.
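
If you are not sure about the exact attribute path, you can also iterate over all named submodules to find every ReLU, for example:

import torch.nn as nn
from torchvision import models

model = models.resnet18(pretrained=False)
# print the qualified name of every ReLU submodule
for name, module in model.named_modules():
    if isinstance(module, nn.ReLU):
        print(name)  # "relu", "layer1.0.relu", ..., "layer4.1.relu"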

Hi,
Thank you for your reply.
By “second last layer” I mean the layer before the first fully connected layer, so it would be conv2 in BasicBlock 1 of layer4.

By the way, I am using the rescaled Cauchy distribution PDF as a kernel function, f(x) = 1/(1 + x*x), which I have implemented like this:

import torch

def pdf_cauchy_distribution(tensor):
    '''
    This function takes the output of a neural network layer and applies a
    kernel function which acts as an activation function.

    input:
    tensor: output of the neural network layer's computation (w*x + b)

    output:
    a tensor transformed by the Cauchy distribution PDF kernel,
    f(x) = 1 / (1 + x^2)
    '''
    return 1 / (1 + torch.mul(tensor, tensor))
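
For a quick sanity check, it behaves as expected on a few values (f(0) = 1, f(1) = 0.5, f(2) = 0.2):

x = torch.tensor([0.0, 1.0, 2.0])
print(pdf_cauchy_distribution(x))  # tensor([1.0000, 0.5000, 0.2000])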

But when I use it as you described, it gives this error:

TypeError: cannot assign 'model.activation.pdf_cauchy_distribution' as child module 'relu' (torch.nn.Module or None expected)

Does PyTorch offer any functionality for custom activation functions such as the one I described?

I found this: https://pytorch.org/docs/stable/distributions.html#cauchy, but I am not sure how that serves my purpose.
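
From what I can tell, torch.distributions.Cauchy represents the distribution object itself (sampling, log_prob, and so on) rather than an activation function. Its PDF is 1/(pi * (1 + x^2)), so it differs from my function only by the 1/pi factor:

import math
import torch
from torch.distributions import Cauchy

x = torch.linspace(-3, 3, 7)
dist_pdf = Cauchy(loc=0.0, scale=1.0).log_prob(x).exp()  # standard Cauchy PDF
manual = 1.0 / (math.pi * (1.0 + x * x))                 # my kernel, divided by pi
print(torch.allclose(dist_pdf, manual))                  # True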

Could you try to define this activation as an nn.Module and try to assign it again?

Excuse my ignorance; do you mean like this?

class cauchy_activation(nn.Module):
    def __init__(self, x):
        super(cauchy_activation, self).__init__()
        self.inp = x

    def activation(self):
        return pdf_cauchy_distribution(self.inp)

def pdf_cauchy_distribution(tensor):
    '''
    This function takes the output of a neural network layer and applies a
    kernel function which acts as an activation function.

    input:
    tensor: output of the neural network layer's computation (w*x + b)

    output:
    a tensor transformed by the Cauchy distribution PDF kernel,
    f(x) = 1 / (1 + x^2)
    '''
    return 1 / (1 + torch.mul(tensor, tensor))

If it is like this, then when I instantiate this class, what should the input be? I mean, what do I feed to the class’s __init__()?

I did it like this:

import torch
import torch.nn as nn

def pdf_cauchy_distribution(tensor):
    '''
    This function takes the output of a neural network layer and applies a
    kernel function which acts as an activation function.

    input:
    tensor: output of the neural network layer's computation (w*x + b)

    output:
    a tensor transformed by the Cauchy distribution PDF kernel,
    f(x) = 1 / (1 + x^2)
    '''
    return 1 / (1 + torch.mul(tensor, tensor))

class cauchy_activation(nn.Module):
    def __init__(self):
        super(cauchy_activation, self).__init__()

    def activation(self, inp):
        return pdf_cauchy_distribution(inp)

and in the model file:

import argparse
import json

import torch.nn as nn
from torchvision import models

class Resnet18(BaseModel):
    def __init__(self, classes=10):
        super(Resnet18, self).__init__()
        par = argparse.ArgumentParser(description='Model_resnet18')
        par.add_argument('-c', '--config', default='config.json', type=str,
                         help='config file path (default: None)')
        args = par.parse_args()
        config = json.load(open(args.config))

        self.resnet = models.resnet18(pretrained=False)
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, classes, bias=config['arch']['last_layer_bias'])
        # this assignment raises the TypeError below
        self.resnet.layer4[1].relu = cauchy_activation.activation
        '''
        # freezing experiment, currently disabled
        ct = 0
        for child in self.resnet.children():
            #print("child ", ct, ": \n", child)
            for param in child.parameters():
                if(ct != 9):
                    param.requires_grad = False
            ct += 1
        '''

    def forward(self, x_input):
        output = self.resnet(x_input)
        return output

It still gives the same error:

TypeError: cannot assign 'model.activation.cauchy_activation.activation' as child module 'relu' (torch.nn.Module or None expected)

Could you try the following definition:

class cauchy_activation(nn.Module):
    def __init__(self):
        super(cauchy_activation, self).__init__()
        
    def forward(self, x):
        return pdf_cauchy_distribution(x)

model = models.resnet18(pretrained=False)
model.layer4[1].relu = cauchy_activation()
output = model(torch.randn(1, 3, 224, 224))
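
Since the module is built entirely from differentiable tensor operations, gradients flow through it as usual; a quick check:

output.mean().backward()
print(model.layer4[1].conv2.weight.grad is not None)  # True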

Wow, thank you sir, it works now!
Can you explain in one or two sentences why it works?

And is this activation function used for the whole of layer4? (It seems to me it is.) I wanted to apply it only to conv2 in BasicBlock 1 of layer4.

I just initialized a stateless module without assigning the input in its __init__.
Basically, the module now just calls your activation function without storing any parameters.
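
The same idea generalizes: a small wrapper module (a hypothetical Lambda helper, not something PyTorch provides) can turn any tensor-to-tensor function into a module you can assign as a child:

class Lambda(nn.Module):
    def __init__(self, fn):
        super(Lambda, self).__init__()
        self.fn = fn  # any tensor -> tensor function

    def forward(self, x):
        return self.fn(x)

model.layer4[1].relu = Lambda(pdf_cauchy_distribution)  # equivalent assignment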

Yes, it will be used twice inside that block: BasicBlock calls self.relu once after bn1 and once after the residual addition. To change that, you would need to modify the BasicBlock implementation in torchvision.
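
For completeness, one way to apply the custom activation only once, after the residual addition at the end of the block, is to subclass torchvision's BasicBlock and override its forward. This is just a sketch, assuming the pdf_cauchy_distribution function from above:

from torchvision import models
from torchvision.models.resnet import BasicBlock

class CauchyBasicBlock(BasicBlock):
    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)                 # first activation: unchanged ReLU
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out = out + identity
        return pdf_cauchy_distribution(out)  # custom activation, applied once

model = models.resnet18(pretrained=False)
# layer4[1] of ResNet-18 has 512 input and 512 output channels and no downsample
block = CauchyBasicBlock(512, 512)
block.load_state_dict(model.layer4[1].state_dict())
model.layer4[1] = block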

Alright, thank you very much for your help.
God bless you.