How to implement a custom convolutional layer and call it from your own network?

Hello!
I would like to implement a slightly different version of conv2d and use it inside my neural network.
I would like to take additional binary data into account during the convolution. For the sake of clarity, let’s consider the first layer of my network. From the input grayscale image, I compute a binary mask where the object is white and the background is black. Then, for the convolution, I consider a fixed-size window filter moving identically along the image and the mask. If the center of the considered window belongs to the object (i.e., is white), then only the pixels of the grayscale image that are white in the mask for that window should contribute to the filtering. The same reasoning applies to pixels belonging to the background.
Here is the code for my custom layer:

import torch
import torch.nn as nn
import torch.nn.functional as F

# `device` is assumed to be defined globally, e.g.:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MyConv2d(nn.Module):
    def __init__(self, n_channels, out_channels, kernel_size, dilation=1, padding=0, stride=1):
        super(MyConv2d, self).__init__()

        self.kernel_size = (kernel_size, kernel_size)
        self.kernel_size_number = kernel_size * kernel_size
        self.out_channels = out_channels
        self.dilation = (dilation, dilation)
        self.padding = (padding, padding)
        self.stride = (stride, stride)
        self.n_channels = n_channels
        self.weights = nn.Parameter(torch.Tensor(self.out_channels, self.n_channels, self.kernel_size_number)).data.uniform_(0, 1)

    def forward(self, x, mask):
        width = self.calculateNewWidth(x)
        height = self.calculateNewHeight(x)
        result = torch.zeros(
            [x.shape[0] * self.out_channels, width, height], dtype=torch.float32, device=device
        )
        windows_x = self.calculateWindows(x)
        # build the per-window mask: keep only pixels whose mask value matches the window center
        windows_mask = self.calculateWindows(mask)
        windows_mask[windows_mask < 1] = -1  # background -> -1, object stays +1
        windows_mask_centers = windows_mask[:, :, windows_mask.size()[2]//2].view(windows_mask.size()[0], windows_mask.size()[1], 1)
        windows_mask = windows_mask * windows_mask_centers  # +1 where a pixel matches its window center
        windows_mask[windows_mask < 1] = 0
        windows_x_seg = windows_x * windows_mask

        for channel in range(x.shape[1]):
            for i_convNumber in range(self.out_channels):
                xx = torch.matmul(windows_x_seg[channel], self.weights[i_convNumber][channel])
                xx = xx.view(-1, width, height)
                result[i_convNumber * xx.shape[0] : (i_convNumber + 1) * xx.shape[0]] += xx

        result = result.view(x.shape[0], self.out_channels, width, height)
        return result

    def calculateWindows(self, x):
        windows = F.unfold(
            x, kernel_size=self.kernel_size, padding=self.padding, dilation=self.dilation, stride=self.stride
        )

        windows = windows.transpose(1, 2).contiguous().view(-1, x.shape[1], self.kernel_size_number)
        windows = windows.transpose(0, 1)

        return windows

    def calculateNewWidth(self, x):
        return (
            (x.shape[2] + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1)
            // self.stride[0]
        ) + 1

    def calculateNewHeight(self, x):
        return (
            (x.shape[3] + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1)
            // self.stride[1]
        ) + 1

Then, I would like to call MyConv2d from my network.
Here is a snippet of my network:

class MyNetwork(nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.conv = MyConv2d(1, 64, 5, stride=2, padding=0)
        # etc
       
    def forward(self, x, mask):
        x = F.relu(self.conv(x, mask))
        # etc
        return x
       

First of all, I have a question regarding execution speed. MyConv2d is much slower than conv2d (because of the double for loop, I guess). Is there a way to speed it up?
Secondly, I have an issue at the very first iteration when I train my network on the GPU: once the input has gone through my first custom layer, I get back NaN values in the output. Do you have any idea why this happens? Is there something wrong with my implementation of MyConv2d?
Lastly, I recently got a weird error that came out of the blue while training my network:

copy_if failed to synchronize: cudaErrorIllegalAddress: an illegal memory access was encountered
This error occurs in MyConv2d when it runs into:

windows_mask[windows_mask < 1] = -1
Can you please help me fix this?

Many thanks in advance!

You could try to remove the loops and unfold the data, which could use more memory but might also be faster. Alternatively, you could also write a custom C++/CUDA extension, which could also yield a speedup.
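
For illustration, a minimal loop-free sketch (masked_conv_vectorized is just a hypothetical name, and it assumes windows_x_seg and weights have the shapes produced by your calculateWindows and __init__):

import torch

# windows_x_seg: (C, B*L, K) from calculateWindows (C = in_channels, L = width*height, K = k*k)
# weights:       (out_channels, C, K)
def masked_conv_vectorized(windows_x_seg, weights, batch, width, height):
    # sum over channels c and kernel positions k in a single einsum
    out = torch.einsum('cnk,ock->no', windows_x_seg, weights)  # (B*L, out_channels)
    out = out.view(batch, width * height, -1)                  # (B, L, out_channels)
    return out.permute(0, 2, 1).reshape(batch, -1, width, height)

This reproduces the accumulation over input channels that the double loop performs, in one batched operation.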

You could add debug print statements and check which part of your custom layer returns the NaN values to narrow it down further.
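
For example, a small helper like this (check_nan is a hypothetical name) could be called after each intermediate tensor in forward:

import torch

def check_nan(t, name):
    # warn as soon as a tensor contains NaNs
    if torch.isnan(t).any():
        print(f"NaN detected in {name}")

# e.g. inside forward:
# check_nan(windows_x_seg, "windows_x_seg")

torch.autograd.set_detect_anomaly(True) can also help pinpoint the operation that produces NaNs during the backward pass.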

If you are using an older PyTorch version, please update to the latest stable, since indexing errors should raise RuntimeErrors, not fail with illegal memory accesses.

Dear ptrblck,
Thank you for your reply.

If you are using an older PyTorch version, please update to the latest stable, since indexing errors should raise RuntimeErrors, not fail with illegal memory accesses.

I’ve tried to update PyTorch to the latest version, but I encountered difficulties. I tried several things:

  • First of all, I updated my cudatoolkit from 10.0 to 10.2.
  • Then I ran the command: conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
    Nothing happened (I kept my old version of PyTorch, 1.7.1).
  • Then I tried: pip install --upgrade torch torchvision torchaudio
    Here PyTorch was updated to 1.8.1, but I could not run my code in Jupyter Notebook because the kernel crashed at the very beginning when importing packages.
  • Then I decided to remove PyTorch entirely and reinstall it. I ran the following lines:
conda uninstall pytorch
pip uninstall torch
pip uninstall torch
conda uninstall torchvision
pip uninstall torchvision
pip uninstall torchvision
conda uninstall torchaudio
pip uninstall torchaudio
pip uninstall torchaudio
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

But the version I got back was the previous one (1.7.1).

  • Then I wanted to force PyTorch to update to 1.8.1, so I ran conda install pytorch=1.8.1 torchvision torchaudio cudatoolkit=10.2 -c pytorch
    But I got several package conflicts (openssl, mkl_fft, python_abi, msgpack-python, m2w64-gcc-libs, tornado, vs2008_runtime, libssh2, prompt_toolkit, vs2015_runtime, clyent, pywin32, sqlite, blas, statsmodels, libllvm9, jupyter_client, babel, zipp, libpng, hypothesis, ipython, xz, configparser, enum34, traitlets, sphinx, numpy, nbformat, jpeg, contextlib2, pycrypto, python-simplegeneric, matplotlib-inline, liblapacke, jupyterlab_widgets, snowballstemmer, backcall, urllib3, ipython_genutils, jupyterlab_server, notebook, rtree, pysocks, qtconsole, bkcharts, conda, pyopenssl, toolz, matplotlib-base, html5lib, mkl-service, zlib, spyder-kernels, imagesize, qtawesome, pandocfilters, ptyprocess, python-language-server, jedi, anaconda-project, … ).
    I am using Python 3.8.3 and Anaconda 4.10.1.

Could you help me with this?

Your local CUDA toolkit won’t be used if you install the conda binaries or pip wheels; you would only need to install the NVIDIA driver.
Since you are apparently running into env conflicts, you could try to create a new conda env and install the latest PyTorch version there.
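
For example (torch_latest is just a placeholder environment name):

conda create -n torch_latest python=3.8
conda activate torch_latest
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch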

Thank you for your answer,
Ok, I will try to set up a new conda env and install PyTorch. I will let you know.

Dear Patrick,
I have created a new conda environment and I’ve been able to install the latest version of PyTorch.
That said, I still had an error related to illegal memory access. I found in another post that it could be related to having variables on both the GPU and the CPU. So I checked the variable (i.e. weights) created in MyConv2d, and it turned out it was created on the CPU by default. To correct this, I added the following lines at the top of the forward function:

        if x.is_cuda:
            self.weights = self.weights.cuda()
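
A more idiomatic alternative (a minimal sketch) would be to register the weights as an nn.Parameter in __init__, so that model.to(device) moves them automatically and the optimizer can update them:

        # sketch: build the tensor with torch.empty and wrap it in nn.Parameter,
        # keeping it registered as a module parameter
        self.weights = nn.Parameter(torch.empty(out_channels, n_channels, kernel_size ** 2).uniform_(0, 1))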

I also had an exploding-gradient problem at the very beginning of training, so I normalized the output of MyConv2d to see if it improved things. For now, that issue seems solved, but I still have a training problem, this time with memory. Training goes well for a large number of epochs (~100), but at some point I get this error:

RuntimeError: CUDA out of memory. Tried to allocate 656.00 MiB (GPU 0; 5.00 GiB total capacity; 345.42 MiB already allocated; 590.35 MiB free; 1.10 GiB reserved in total by PyTorch)

And I don’t understand why it occurs at this point; memory should be freed at the end of each epoch, right? I also have another question: is there a tool (like the MATLAB profile viewer) to check the memory consumption of PyTorch code?
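
For reference, PyTorch exposes built-in memory counters that can be printed during training, e.g.:

import torch

# snapshot of GPU memory usage, e.g. once per epoch
print(torch.cuda.memory_allocated() / 1024**2, "MiB currently allocated")
print(torch.cuda.max_memory_allocated() / 1024**2, "MiB peak since program start")
print(torch.cuda.memory_summary())  # detailed report from the caching allocator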

I am posting the latest version of my custom convolutional layer below, in case it helps.

class MyConv2d(nn.Module):
    def __init__(self, n_channels, out_channels, kernel_size, dilation=1, padding=0, stride=1):
        super(MyConv2d, self).__init__()

        self.kernel_size = (kernel_size, kernel_size)
        self.kernel_size_number = kernel_size * kernel_size
        self.out_channels = out_channels
        self.dilation = (dilation, dilation)
        self.padding = (padding, padding)
        self.stride = (stride, stride)
        self.n_channels = n_channels
        self.weights = nn.Parameter(torch.Tensor(self.out_channels, self.n_channels, self.kernel_size_number)).data.uniform_(0, 1)

    def forward(self, x, mask):
        if x.is_cuda:
            self.weights = self.weights.cuda()
        width = self.calculateNewWidth(x)
        height = self.calculateNewHeight(x)
        result = torch.zeros(
            [x.shape[0] * self.out_channels, width, height], dtype=torch.float32, device=device
        )
        result_mask = torch.zeros(
            [x.shape[0] * self.out_channels, width, height], dtype=torch.float32, device=device
        )
        
        
        windows_x = self.calculateWindows(x)
        windows_mask = self.calculateWindows(mask)
        windows_mask[windows_mask < 1] = -1
        windows_mask_centers = windows_mask[:, :, windows_mask.size()[2]//2].view(windows_mask.size()[0], windows_mask.size()[1], 1)
        windows_mask = windows_mask * windows_mask_centers
        windows_mask[windows_mask < 1] = 0 
        windows_x_seg = windows_x * windows_mask
        
        # compute the result of x with mask-aware convolution
        for i_convNumber in range(self.out_channels):
            for channel in range(x.shape[1]):
                xx = torch.matmul(windows_x_seg[channel], self.weights[i_convNumber][channel].view(-1, 1))
                xx = xx.view(-1, width, height)/torch.sum(windows_mask[channel], 1).view(-1, width, height)
                result[i_convNumber * xx.shape[0] : (i_convNumber + 1) * xx.shape[0]] += xx
            result[i_convNumber * xx.shape[0] : (i_convNumber + 1) * xx.shape[0]] /= x.shape[1]

        result = result.view(x.shape[0], self.out_channels, width, height)
 
        # compute the result of mask with mask-aware convolution
        windows_mask_seg = self.calculateWindows(mask) * windows_mask
        for i_convNumber in range(self.out_channels):
            for channel in range(x.shape[1]):
                xx = torch.matmul(windows_mask_seg[channel], self.weights[i_convNumber][channel].view(-1, 1))
                xx = xx.view(-1, width, height)
                result_mask[i_convNumber * xx.shape[0] : (i_convNumber + 1) * xx.shape[0]] += xx

        result_mask = result_mask.view(mask.shape[0], self.out_channels, width, height)
        result_mask = torch.clamp(result_mask, min=0, max=1)
        
        return result, result_mask

    def calculateWindows(self, x):
        windows = F.unfold(
            x, kernel_size=self.kernel_size, padding=self.padding, dilation=self.dilation, stride=self.stride
        )

        windows = windows.transpose(1, 2).contiguous().view(-1, x.shape[1], self.kernel_size_number)
        windows = windows.transpose(0, 1)

        return windows

    def calculateNewWidth(self, x):
        return (
            (x.shape[2] + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1)
            // self.stride[0]
        ) + 1

    def calculateNewHeight(self, x):
        return (
            (x.shape[3] + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1)
            // self.stride[1]
        ) + 1

Many thanks in advance!

Typically, a gradual OOM after many epochs can be the result of something in the training loop unwittingly holding on to previous data that no longer needs to be stored (e.g., the loss). Can you share the training loop of the model?
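
A common culprit, as a minimal illustration:

# accumulating the loss tensor keeps each iteration's autograd graph alive
running_loss = running_loss + loss         # holds on to the whole graph
# accumulating the python float does not
running_loss = running_loss + loss.item()  # stores only a number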

Dear eqy,
Thank you for your reply.
Here is my training loop:

# transfer the model to the gpu
model.to(device)

# define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)

# define loss function
distortion = nn.MSELoss().cuda()
        
beta = 0.00001
n_epochs = 400

for epoch in range(0, n_epochs):
    running_loss = 0.0

    for i_batch, data in enumerate(dataloader):
        batch_images = data[0].to(device).float()
        batch_masks = data[1].to(device).float()
        [decoded_images, x_quantized] = model(batch_images, batch_masks, 1, True)
        optimizer.zero_grad()
        loss_dist = distortion(decoded_images, batch_images)
        loss_bit = entropy_dist(x_quantized, model.phi, model.var)
        loss = beta * loss_dist + loss_bit
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    running_loss = running_loss/len(dataloader)

I don’t know if I should prevent batch_masks from being stored for backpropagation (batch_masks.detach()). batch_masks does not appear in the loss but is an input to my custom convolutional layers, and is used to modify the main input batch_images.
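
That is, something like:

# hypothetical: detach the mask right after moving it to the device
batch_masks = data[1].to(device).float().detach()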