Custom nn.Conv2d

Hello, this is my first week as a PyTorch user.

How can we define a custom Conv2d function that works like nn.Conv2d, but with the multiplications and additions used inside nn.Conv2d replaced by mymult(num1, num2) and myadd(num1, num2)? It's for an HDL simulation. Even a slow implementation is fine. We would like to test our theory at inference time by loading already trained weights.

I think the easiest way would be to use unfold to create the patches and apply your custom functions to each patch.
This post might be a good starting point.
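A minimal sketch of that idea, assuming mymult and myadd take two Python numbers and return a number (both are hypothetical placeholders below). It only handles groups=1, dilation=1 and integer stride/padding, and it is deliberately slow: every scalar product and every sum goes through the custom functions.

```python
import torch
import torch.nn.functional as F

def mymult(num1, num2):
    # hypothetical placeholder: replace with your own scalar multiplication
    return num1 * num2

def myadd(num1, num2):
    # hypothetical placeholder: replace with your own scalar addition
    return num1 + num2

def slow_custom_conv2d(x, weight, bias=None, stride=1, padding=0):
    """Convolution (groups=1, dilation=1) where every scalar * and + goes through mymult/myadd."""
    n, _, h, w = x.shape
    c_out, _, kh, kw = weight.shape
    out_h = (h + 2 * padding - kh) // stride + 1
    out_w = (w + 2 * padding - kw) // stride + 1

    # unfold creates the patches: [N, C_in*kh*kw, L] with L = out_h * out_w
    cols = F.unfold(x, (kh, kw), padding=padding, stride=stride)
    # flatten each filter in the same (C_in, kh, kw) order as the patch columns
    filt = weight.reshape(c_out, -1)

    out = torch.zeros(n, c_out, cols.shape[2])
    for b in range(n):
        for o in range(c_out):
            for pos in range(cols.shape[2]):
                acc = 0.0  # accumulator starts at zero
                for k in range(cols.shape[1]):
                    acc = myadd(acc, mymult(cols[b, k, pos].item(), filt[o, k].item()))
                if bias is not None:
                    acc = myadd(acc, bias[o].item())
                out[b, o, pos] = acc
    return out.view(n, c_out, out_h, out_w)
```

Comparing the result with F.conv2d on the same weights is a quick sanity check before the real HDL-model functions are plugged in.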

Can I do it like this?

If that fits your use case better, go for it.
unfold is basically the im2col operation, which creates the patches that the kernel is applied to.
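To see what that means concretely, a small sketch with torch.nn.functional.unfold (the functional form of nn.Unfold) on a toy input; the shapes and values are arbitrary:

```python
import torch
import torch.nn.functional as F

# A toy batch: N=2 images, C=3 channels, 5x5 pixels
x = torch.arange(2 * 3 * 5 * 5, dtype=torch.float32).reshape(2, 3, 5, 5)

# im2col: every 3x3 window (stride 1, no padding) becomes one column
cols = F.unfold(x, kernel_size=3)
print(cols.shape)  # torch.Size([2, 27, 9]): C*kh*kw = 27 values per patch, 3*3 = 9 patches

# Column 0 is the top-left 3x3 window of every channel, flattened channel by channel
first_patch = x[0, :, 0:3, 0:3].reshape(-1)
print(torch.equal(cols[0, :, 0], first_patch))  # True
```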


Dear Sami,

The linked post “Custom a new convolution layer in cnn” is useful, but it only covers the forward pass of a convolution layer. The backward pass has its own problems; for example, you have to manage the dimensions of the gradients you produce yourself.
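To illustrate the gradient-dimension issue: a custom torch.autograd.Function has to return one gradient per forward input, each with that input's shape, even when the forward broadcast a small filter against many patches. A minimal sketch, where BroadcastMul is a hypothetical name and Tensor.sum_to_size undoes the broadcasting:

```python
import torch

class BroadcastMul(torch.autograd.Function):
    """Element-wise a * b, where b may be broadcast against a (e.g. a filter against patches)."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # one gradient per forward input, each reduced back to that input's shape
        grad_a = (grad_out * b).sum_to_size(*a.shape)
        grad_b = (grad_out * a).sum_to_size(*b.shape)
        return grad_a, grad_b

patches = torch.randn(5, 3, 3, 3, dtype=torch.double, requires_grad=True)  # [num_patches, C, kh, kw]
filt = torch.randn(3, 3, 3, dtype=torch.double, requires_grad=True)        # [C, kh, kw]

# gradcheck compares the hand-written backward with finite differences
print(torch.autograd.gradcheck(BroadcastMul.apply, (patches, filt)))  # True
```

For inference only, as in the original question, running the model under torch.no_grad() means backward is never called.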

I suggest you use PyTorch functions. Since convolution is equivalent to unfold + matrix multiplication + fold, what ptrblck recommended may be more useful and simpler.
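A quick check of that equivalence, as a sketch with arbitrary shapes (no padding, stride 1). The fold with kernel_size=1 only rearranges the output positions back into an image, so a plain view would work just as well:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)        # [N, C_in, H, W]
weight = torch.randn(4, 3, 3, 3)   # [C_out, C_in, kh, kw]
bias = torch.randn(4)

# 1) Unfold: [N, C_in*kh*kw, L], where L = 6*6 output positions
cols = F.unfold(x, kernel_size=3)                      # [2, 27, 36]

# 2) Matrix multiplication with the flattened filters: [4, 27] @ [2, 27, 36]
out = weight.view(4, -1) @ cols + bias.view(1, -1, 1)  # [2, 4, 36]

# 3) Fold with a 1x1 "kernel" puts the 36 positions back on the 6x6 output grid
out = F.fold(out, output_size=(6, 6), kernel_size=1)   # [2, 4, 6, 6]

ref = F.conv2d(x, weight, bias)
print(torch.allclose(out, ref, atol=1e-4))  # True
```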

In other words, use ready-made PyTorch functions wherever you can; most of the time there is a way to avoid customizing layers.

If you do want to customize a layer, I suggest checking the open-source implementations of the multi-layer perceptron and convolution layers, along with the theory behind them. These two tutorials could also help you:

Of course, what ptrblck says is more professional than what I say.

Good luck


Oh, I wouldn’t say that. :wink:
I’m just posting some suggestions and different approaches always help. :slight_smile:


Thanks a lot for your replies. I am a little bit confused; apologies in advance, as I am just a beginner. Do I have to override some function in nn.Conv2d() and use the unfolding part in it? Can you please point me to an example?

You don’t have to override the conv functionality; you could directly implement your custom convolution using unfold.
The link in my first post shows an example of a vanilla convolution. Or are you looking for something else?

Let me try to elaborate on my goal. Consider the example of LeNet-5:

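import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # self.conv1 and self.conv2 (used in forward) must also be defined here,
        # with the same shapes as the ones stored in lenet5.pth
        self.fc1 = nn.Linear(400,120)
        self.fc2 = nn.Linear(120,84)
        self.classifier = nn.Linear(84,10)
        self.features = nn.Sequential(*list(self.children()))


    def forward(self,x):
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        #x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def num_flat_features(self,x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features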

model = torch.load('lenet5.pth')

model_new = Net()    # before loading trained weights

model_new.load_state_dict(torch.load('lenet5.pth'))  # after loading trained weights

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        #images = images.reshape(-1,28*28)
        out = model_new(images)
        _, predicted = torch.max(out, 1)  # index of the highest score per image
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

I am supposed to check the accuracy of the network after replacing the multiplications/additions inside nn.Conv2d with my two functions mymult(num1, num2) and myadd(num1, num2). The same has to be done for some other already trained neural networks. This is just a short example; we could also pick pretrained networks from inside PyTorch.
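For the other pretrained networks, one option is to walk the module tree and replace every nn.Conv2d after the trained weights have been loaded. A sketch under those assumptions; replace_convs and ConvStandIn are hypothetical names, and ConvStandIn just calls F.conv2d so the swap can be verified before any custom arithmetic is plugged in:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvStandIn(nn.Module):
    """Hypothetical placeholder for the custom convolution.

    It reuses the trained parameters of the nn.Conv2d it replaces.
    """

    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        self.weight = conv.weight
        self.bias = conv.bias
        self.stride, self.padding = conv.stride, conv.padding
        self.dilation, self.groups = conv.dilation, conv.groups

    def forward(self, x):
        # swap this call for a mymult/myadd-based convolution
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

def replace_convs(module: nn.Module, make_replacement) -> nn.Module:
    """Recursively replace every nn.Conv2d inside `module` with make_replacement(conv)."""
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d):
            setattr(module, name, make_replacement(child))
        else:
            replace_convs(child, make_replacement)
    return module
```

Swapping after load_state_dict means the replacement modules simply pick up the already trained weight and bias of the layers they replace.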

You could use my linked code snippet to implement your custom convolution and substitute nn.Conv2d with your class.
The easiest way would be to use my code in your forward method and set the weight and bias in the __init__ of your MyConv2d class.
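Along those lines, a sketch of such a MyConv2d, assuming the custom operations can be written to work element-wise on tensors (torch.mul and torch.add stand in for them by default). It takes the trained weight and bias from an existing nn.Conv2d and only supports groups=1, dilation=1 and numeric padding; the constructor signature is an assumption, not an existing PyTorch API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyConv2d(nn.Module):
    """Sketch: convolution via unfold with pluggable element-wise mult/add."""

    def __init__(self, conv: nn.Conv2d, mult=torch.mul, add=torch.add):
        super().__init__()
        if conv.groups != 1 or conv.dilation != (1, 1) or isinstance(conv.padding, str):
            raise ValueError("this sketch only handles groups=1, dilation=1, numeric padding")
        # reuse the trained parameters instead of creating new ones
        self.weight = conv.weight
        self.bias = conv.bias
        self.stride, self.padding = conv.stride, conv.padding
        self.mult, self.add = mult, add

    def forward(self, x):
        n, _, h, w = x.shape
        c_out, _, kh, kw = self.weight.shape
        out_h = (h + 2 * self.padding[0] - kh) // self.stride[0] + 1
        out_w = (w + 2 * self.padding[1] - kw) // self.stride[1] + 1

        cols = F.unfold(x, (kh, kw), padding=self.padding, stride=self.stride)  # [N, K, L]
        filt = self.weight.reshape(1, c_out, -1, 1)                              # [1, C_out, K, 1]
        prods = self.mult(cols.unsqueeze(1), filt)                               # [N, C_out, K, L]

        # accumulate the K products with the custom add, one term at a time
        out = prods[:, :, 0]
        for k in range(1, prods.shape[2]):
            out = self.add(out, prods[:, :, k])
        if self.bias is not None:
            out = self.add(out, self.bias.view(1, -1, 1))
        return out.reshape(n, c_out, out_h, out_w)
```

For the LeNet example, this could be used as model_new.conv1 = MyConv2d(model_new.conv1) after load_state_dict (and likewise for conv2), passing your own element-wise functions as mult and add.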

This is how it looks, based on what I understood from all of your posts:

class Myconv2D(torch.autograd.Function):

    # Note that both forward and backward are @staticmethods
    # bias is an optional argument
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        # input dim [batch_size, channels, height, width]

        batch_size = len(input)
        channels = input[0][0]
        h, w = input[0][0][0], input [0][0][0]
        image =  torch.tensor(batch_size, channels, h, w) # input image
        #print (image)
        #print (image.size())
        #print (image[0])

        kh, kw = 3, 3# kernel size
        dh, dw = 2, 2 # stride

        filt = nn.parameter(channels, kh, kw) # filter (create this as nn.Parameter if you want to train it)
        #print (filt)
        patches = image.unfold(2, kh, dh).unfold(3, kw, dw)
        patches = patches.contiguous().view(batch_size, channels, -1, kh, kw)
        patches = patches.permute(0, 2, 1, 3, 4)
        patches = patches.view(-1, channels, kh, kw)
        #print (patches.shape)
        #print ("Filter shape",filt.shape)
        #print (patches,"\n","\n","\n",filt)
        #print ("Builtin multiplication")
        dummy = patches

        patches = patches * filt # same is done below with 4 nested loops with custom operation
        #print ("Custom Multiplication starts here:")
        for b in range(batch_size* int(h/3)* int(h/3)):

            for c in range(channels):

                for height in range(kh):

                    for width in range(kw):

                        # replace this multiplication (*) with my function mymult(num1, num2)
                        patches[b][c][height][width] = \
                            dummy[b][c][height][width] * filt[c][height][width]

        patches = patches.sum(1) # the builtin and the loop multiplication give the same result; what is this sum doing?
        #print (patches)
        output = patches # is this right?
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
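            grad_input = None  # TODO: not implemented; not needed for inference under torch.no_grad()
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)  # taken from the nn.Linear example; a convolution needs a different formula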
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias

class Myconv2DLayer(nn.Module):
    def __init__(self):
        super(Myconv2DLayer, self).__init__()
        self.fn = Myconv2D.apply  # the autograd.Function defined above
        # weight tensor = out_channels × in_channels/groups × kH × kW
        self.weight = nn.Parameter(torch.randn(1, 1, 2, 2)) # when groups=1

    def forward(self, x):
        x = self.fn(x, self.weight)
        return x

The results of patches = patches * filt and of the custom 4-nested loop in the forward method of class Myconv2D(torch.autograd.Function) are the same.
After that there is an addition, “patches = patches.sum(1)”. I am not sure what it is doing, and I would like to replace that addition as well.

Can you please have a look at it?

I will be using this in the previous code.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(400,120)
        self.fc2 = nn.Linear(120,84)
        self.classifier = nn.Linear(84,10)
        self.features = nn.Sequential(*list(self.children()))

Is this the right direction? I am really thankful for your time and effort. I wouldn’t even have been able to start on this without your comments.

There seem to be a few minor mistakes in the code:

  • h and w should probably be defined as input.size(2) and input.size(3), respectively.
    Currently you are assigning the same input value to both. However, since you are passing input directly, they can also be removed completely.
  • As already mentioned, you are passing input, so I’m not sure what image should be used for.
  • You are also passing the weight to forward, which would make filt unnecessary. Also, you are currently re-initializing filt in each forward pass, which won’t store the already learned weights, so use weight instead for your operations.
  • Are you using the nested loop to compare the outputs of patches * filt or what is it used for?
  • patches.sum(1) calculates the sum over dim1. I’m currently not sure which shape patches would have at that point.
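To make the shapes concrete, a sketch that follows the posted pipeline with made-up sizes (batch of 2, 3 channels, 8x8 images, a single 3x3 filter, stride 2). It uses reshape after permute, because view on the permuted tensor generally fails unless it is made contiguous first:

```python
import torch

batch_size, channels, h, w = 2, 3, 8, 8
kh, kw, dh, dw = 3, 3, 2, 2  # kernel size and stride, as in the posted code

image = torch.randn(batch_size, channels, h, w)
filt = torch.randn(channels, kh, kw)  # a single filter, i.e. one output channel

patches = image.unfold(2, kh, dh).unfold(3, kw, dw)                    # [B, C, out_h, out_w, kh, kw] = [2, 3, 3, 3, 3, 3]
patches = patches.contiguous().view(batch_size, channels, -1, kh, kw)  # [B, C, L, kh, kw]
patches = patches.permute(0, 2, 1, 3, 4)                               # [B, L, C, kh, kw]
patches = patches.reshape(-1, channels, kh, kw)                        # [B*L, C, kh, kw] = [18, 3, 3, 3]

prods = patches * filt       # [B*L, C, kh, kw]: all element-wise products
print(prods.sum(1).shape)    # torch.Size([18, 3, 3]): only the channel dim is summed

# a convolution output sums each patch's products over C, kh and kw
out = prods.sum(dim=(1, 2, 3)).view(batch_size, -1)  # [B, L] = [2, 9]
```

So sum(1) only reduces over the channels; each patch is still left with kh*kw partial results, which a convolution would add up as well.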

How would you like to replace the addition?
Also, could you print the shape of patches before calculating sum(1) and name all dimensions, so that we can understand the use case better? :slight_smile:
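One way to replace the addition, assuming myadd works element-wise on tensors (a hypothetical placeholder below): flatten each patch's products into one dimension and reduce it pairwise, which mirrors a binary adder tree. A sequential accumulation (acc = myadd(acc, term)) is the other obvious choice; the two can round differently in floating or fixed point, so the one that matches the hardware is the one to use.

```python
import torch

def myadd(num1, num2):
    # hypothetical placeholder: must work element-wise on tensors here
    return num1 + num2

def tree_reduce(prods, add=myadd):
    """Sum over the last dim with `add`, pairing terms like a binary adder tree."""
    terms = prods
    while terms.shape[-1] > 1:
        carry = None
        if terms.shape[-1] % 2:  # odd count: pass the last term on to the next level
            carry = terms[..., -1:]
            terms = terms[..., :-1]
        terms = add(terms[..., 0::2], terms[..., 1::2])
        if carry is not None:
            terms = torch.cat([terms, carry], dim=-1)
    return terms[..., 0]

prods = torch.randn(18, 3 * 3 * 3)  # [B*L, C*kh*kw]: products of each patch with the filter
print(torch.allclose(tree_reduce(prods), prods.sum(-1), atol=1e-4))  # True
```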

Thanks a ton :slight_smile: It did the job

Dear Sami, could you please share your final code? I am working on a similar problem. Thanks in advance.

Kindly look at the following code. I am trying to vectorize it, as it is dead slow.