Depthwise convolution gives (seemingly) random results

I implemented a function that performs traditional image filtering, such as Gaussian smoothing, in PyTorch using depthwise convolutions with a two-dimensional kernel. To save some memory I used .expand() instead of .repeat() to get the kernel into the right shape. In my understanding this should do the same thing, since the kernel is static and used for every channel.

Consider the following code:

import numpy as np
print("numpy version: " + np.__version__)
import torch
print("torch version: " + torch.__version__)
import torch.nn.functional as F

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def tensor_stats(x):
    return ", ".join(["{}={:.3g}".format(stat, stat_fun(x).item())
                      for stat, stat_fun
                      in zip(("min", "max"), (torch.min, torch.max))])

def image_filter_repeat(image, kernel):
    num_channels = image.size()[1]
    # repeat() materializes one copy of the kernel per channel
    weight = kernel.unsqueeze(0).unsqueeze(0).repeat(num_channels, 1, 1, 1)
    return F.conv2d(image, weight, groups=num_channels,
                    bias=None, stride=1, padding=0, dilation=1)

def image_filter_expand(image, kernel):
    num_channels = image.size()[1]
    # expand() only broadcasts the kernel across channels (stride-0 view, no copy)
    weight = kernel.unsqueeze(0).unsqueeze(0).expand(num_channels, -1, -1, -1)
    return F.conv2d(image, weight, groups=num_channels,
                    bias=None, stride=1, padding=0, dilation=1)


np.random.seed(0)
torch.manual_seed(0)
kernel = torch.randn((5, 5))
print("kernel: " + tensor_stats(kernel))
input_image = torch.rand((1, 3, 512, 512))
print("input_image: " + tensor_stats(input_image))

np.random.seed(0)
torch.manual_seed(0)
for _ in range(10):
    output_image = image_filter_expand(input_image.clone(),
                                       kernel.clone())
    print("output_image (expand): " + tensor_stats(output_image))

    output_image = image_filter_repeat(input_image.clone(),
                                       kernel.clone())
    print("output_image (repeat): " + tensor_stats(output_image))

I got the following output:

numpy version: 1.14.5
torch version: 1.0.0
kernel: min=-2.32, max=2
input_image: min=1.79e-07, max=1
output_image (expand): min=nan, max=nan
output_image (repeat): min=-6.61, max=5.06
output_image (expand): min=-6.39, max=1.72e+25
output_image (repeat): min=-6.61, max=5.06
output_image (expand): min=nan, max=nan
output_image (repeat): min=-6.61, max=5.06
output_image (expand): min=nan, max=nan
output_image (repeat): min=-6.61, max=5.06
output_image (expand): min=nan, max=nan
output_image (repeat): min=-6.61, max=5.06
output_image (expand): min=nan, max=nan
output_image (repeat): min=-6.61, max=5.06
output_image (expand): min=-1.93e+04, max=1.8e+28
output_image (repeat): min=-6.61, max=5.06
output_image (expand): min=-1.93e+04, max=4.81
output_image (repeat): min=-6.61, max=5.06
output_image (expand): min=-1.93e+04, max=4.81
output_image (repeat): min=-6.61, max=5.06
output_image (expand): min=-2.64e+34, max=4.81
output_image (repeat): min=-6.61, max=5.06

As you can see, repeat() gives consistent results, while the expand() version varies (seemingly) at random between nan, exorbitantly high values, and realistic values that still deviate from the repeat() results.

I’m not even able to reproduce this behaviour consistently, because every run gives different results. I (unsuccessfully) tried to prevent this by

  1. setting the torch seed,
  2. setting the numpy random seed, although I don’t use any numpy code,
  3. setting the cuDNN backend to deterministic mode, although I’m operating only on the CPU,
  4. cloning the input_image and kernel in every iteration.

I have two questions:

  1. Can someone reproduce my results, or is this some weird behaviour happening only on my machine?
  2. Can someone explain to me why the expand() version produces these (seemingly) random results, and thus why I should not use it for this purpose?

That’s a bug in the MKLDNN convolutions: they need contiguous kernels but don’t say so.
Thank you for bringing this up here!

I filed an issue.
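In the meantime, a workaround along the lines of the proposal in the issue is to call .contiguous() on the expanded weight before handing it to the convolution. A minimal sketch based on your function above (the helper name is just for illustration):

def image_filter_expand_contiguous(image, kernel):
    num_channels = image.size()[1]
    # make the broadcasted weight contiguous so the convolution gets valid memory
    weight = (kernel.unsqueeze(0).unsqueeze(0)
                    .expand(num_channels, -1, -1, -1)
                    .contiguous())
    return F.conv2d(image, weight, groups=num_channels,
                    bias=None, stride=1, padding=0, dilation=1)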

Best regards

Thomas


Damn, I should have guessed memory layout as the error source. I was wondering why I sometimes got exactly the same results in consecutive runs, while something as trivial as commenting out a line changed the output. Thanks for the fast clarification.

To follow up your proposal within the issue on GitHub:

Does .contiguous() only align the data in memory or does it copy it? In other words, does x.expand(*).contiguous() have a lower memory requirement than x.repeat(*), or are they basically the same?

.contiguous() will copy here. It copies whenever the tensor isn’t contiguous, which is the case here, as contiguous is roughly defined as "lowest stride = 1, higher strides = previous_size * previous_stride". So x.expand(*).contiguous() ends up with essentially the same memory requirement as x.repeat(*).
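To see why, you can check the strides: .expand() produces a view with stride 0 in the expanded dimension, so all channels read the same memory, while .contiguous() copies it out. A quick check with a fresh 5x5 kernel like the one above:

import torch

kernel = torch.randn(5, 5)
weight_view = kernel.unsqueeze(0).unsqueeze(0).expand(3, -1, -1, -1)
print(weight_view.stride(), weight_view.is_contiguous())
# (0, 25, 5, 1) False -- stride 0: every channel points at the same memory

weight_copy = weight_view.contiguous()  # copies, ends up laid out like repeat()
print(weight_copy.stride(), weight_copy.is_contiguous())
# (25, 25, 5, 1) True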

Best regards

Thomas