Padding_mode not working correctly in Conv2d

It seems like three of the padding_mode options, 'zeros', 'reflect', and 'replicate', all output the same zero padding; only 'circular' produces the padding its name suggests. I used the following code to test this.

import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.utils as utils
import numpy as np

def imshow(images):
    img = utils.make_grid(images)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

for images, labels, names in loader_eval:
    conv1 = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=10, padding_mode='replicate')
    conv1.state_dict()['weight'].copy_(torch.FloatTensor([[[[1.0]]]]))
    conv1.state_dict()['bias'].copy_(torch.FloatTensor([0.0]))
    img = conv1(images)
    imshow(img.detach())
    break

Am I doing something wrong or is there a bug in the implementation?
Thanks.

Hi,

I ran your code and everything works just fine. Could you please share the exact input and output that produce the wrong padding?

I changed your visualization method; this may help depict the different cases better.

import torch
import torch.nn as nn
from PIL import Image
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
import torchvision.utils as utils
import numpy as np

def imshow(img):
    npimg = img.numpy()
    # npimg = np.transpose(npimg, (1, 2, 0))
    df_cm = pd.DataFrame(npimg[0, 0])
    plt.figure(figsize = (10,7))
    sn.heatmap(df_cm, annot=True)
    plt.show()

# I assume you have a batch containing a single grayscale image.
img = torch.arange(0, 100, dtype=torch.float32).view(1, 1, 10, 10)
conv1 = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=(2,2), padding_mode='replicate')
conv1.state_dict()['weight'].copy_(torch.FloatTensor([[[[1.0]]]]))
conv1.state_dict()['bias'].copy_(torch.FloatTensor([0.0]))
img2 = conv1(img)
imshow(img2.detach())
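
As a sanity check, here is a minimal, self-contained sketch (assuming the same 1x1 identity kernel as above) that compares the conv output against torch.nn.functional.pad with mode='replicate'. If padding_mode is honored, the two should match exactly.

import torch
import torch.nn as nn
import torch.nn.functional as F

img = torch.arange(0, 100, dtype=torch.float32).view(1, 1, 10, 10)

conv1 = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=(2, 2), padding_mode='replicate')
with torch.no_grad():
    conv1.weight.fill_(1.0)  # 1x1 identity kernel
    conv1.bias.zero_()

out = conv1(img)
# F.pad pads the last two dims as (left, right, top, bottom);
# with the identity kernel the conv output should equal the padded input.
expected = F.pad(img, (2, 2, 2, 2), mode='replicate')
print(torch.allclose(out, expected))  # True if 'replicate' is actually applied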

Best

Hi,
Thanks for the reply. I have run your code and here are the outputs:
reflect (heatmap screenshot)

replicate (heatmap screenshot)

circular (heatmap screenshot)

Thanks.

Hi,
I think I found the issue. My PyTorch version is 1.4.0 and I was referring to the 1.5.0 docs ('reflect' and 'replicate' seem to have been added to Conv2d's padding_mode only in 1.5.0, so on 1.4.0 they silently fall back to zero padding).
Thanks again for the help.
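
In case it helps anyone else stuck on an older version, here is a small sketch of the workaround I am assuming works there: check torch.__version__, and apply the padding explicitly with torch.nn.functional.pad (which already supports 'replicate') before a convolution that uses padding=0.

import torch
import torch.nn as nn
import torch.nn.functional as F

print(torch.__version__)  # 'reflect'/'replicate' in padding_mode needs a newer release

img = torch.arange(0, 100, dtype=torch.float32).view(1, 1, 10, 10)

# Pad explicitly (left, right, top, bottom), then convolve with padding=0.
padded = F.pad(img, (2, 2, 2, 2), mode='replicate')
conv = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0)
out = conv(padded)
print(out.shape)  # torch.Size([1, 1, 14, 14])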


Oh, that is possible; I lost track of the versions.
I should mention that libraries like PyTorch are updated constantly, with thousands of issues and pull requests on GitHub. So the best way to keep our code stable and reliable (and often faster and better optimized) is to update to the latest stable version.

Good luck
