Illegal memory access in encoder-decoder network

I am implementing an encoder-decoder network in which the encoder takes a 5d input and compresses it to a 4d output, while the decoder takes a 4d input and up-samples to a 4d output. There are shortcut connections which pass 4d slices of encoder feature maps to the decoder. There are also 4d slices of MaxPool indices passed from encoder to decoder. I am getting “illegal memory access” errors. Here is a minimal example:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.utils.data

#---------------------------------------------------------------------------
# define encoder-decoder network and optimizer

class encoder_decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc_conv1 = nn.Conv3d(1,8, kernel_size=(1,3,3), stride=1, padding=(0,1,1))
        self.pool = nn.MaxPool3d(kernel_size=(1,2,2), stride=(1,2,2), return_indices=True)
        self.enc_conv2 = nn.Conv3d(8,8, kernel_size=(3,1,1), stride=1, padding=0)
        
        self.dec_conv1 = nn.Conv2d(8,8, kernel_size=3, stride=1, padding=1)
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.dec_conv2 = nn.Conv2d(16,1, kernel_size=3, stride=1, padding=1)

    def forward(self, net_input):
        out = net_input
        out = self.enc_conv1(out)          # 1x3x64x64 -> 8x3x64x64
        shortcut = out[:,:,1]              # reference to center slice of out
        out, indices = self.pool(out)      # 8x3x64x64 -> 8x3x32x32
        indices = indices[:,:,1]           # reference to center slice of indices
        out = self.enc_conv2(out)          # 8x3x32x32 -> 8x1x32x32
        
        out = out.squeeze(2)               # 8x1x32x32 -> 8x32x32
        out = self.dec_conv1(out)          # 8x32x32 -> 8x32x32
        out = self.unpool(out, indices)    # 8x32x32 -> 8x64x64
        out = torch.cat((shortcut,out),1)  # 8x64x64 -> 16x64x64
        out = self.dec_conv2(out)          # 16x64x64 -> 1x64x64
        
        return out

net = encoder_decoder()
net.cuda()

criterion = nn.MSELoss()
criterion.cuda()

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

#----------------------------------------------------------------------------
# define dataset and dataloader

class create_dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = [ (torch.rand(1,3,64,64),
                       torch.rand(1,64,64)) for i in range(100) ]

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

dataset = create_dataset()
print('Loaded ' + str(len(dataset)) + ' training examples')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=8)

#----------------------------------------------------------------------------
# training loop

print('Start training loop')
for epoch in range(4):
	
    print('Epoch: ' + str(epoch))
    net.train()

    for training_idx, (input_batch,target_batch) in enumerate(dataloader):
        print('Training batch: ' + str(training_idx))
        input_batch = Variable(input_batch.cuda())
        target_batch = Variable(target_batch.cuda())

        optimizer.zero_grad()
        output_batch = net(input_batch)
        err = criterion(output_batch, target_batch)
        err.backward()
        optimizer.step()

Here is the output:

$ CUDA_LAUNCH_BLOCKING=1 python error.py 
Loaded 100 training examples
Start training loop
Epoch: 0
Training batch: 0
Training batch: 1
THCudaCheck FAIL file=/py/conda-bld/pytorch_1490895093647/work/torch/lib/THCUNN/generic/SpatialMaxUnpooling.cu line=43 error=77 : an illegal memory access was encountered
Traceback (most recent call last):
  File "error.py", line 78, in <module>
    output_batch = net(input_batch)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "error.py", line 31, in forward
    out = self.unpool(out, indices)    # 8x32x32 -> 8x64x64
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/torch/nn/modules/pooling.py", line 304, in forward
    self.padding, output_size)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/torch/nn/functional.py", line 277, in max_unpool2d
    return f(input, indices)
  File "/home/ubuntu/anaconda3/lib/python3.5/site-packages/torch/nn/_functions/thnn/pooling.py", line 177, in forward
    self.output_size[1], self.output_size[0])
RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at /py/conda-bld/pytorch_1490895093647/work/torch/lib/THCUNN/generic/SpatialMaxUnpooling.cu:43

Hi Alex,

A very quick way to debug this is to run it in CPU mode.
In CPU mode you will see exactly why it failed right away.

If that does not reveal the problem, I suspect it might be a bug in the CUDA unpooling code, and I'm willing to spend a few hours tracking it down. If you can give me a script that I can run (with dummy data and so on) that reproduces the issue, I'm happy to run it and fix the issue within a week or less.
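Following the CPU-mode suggestion, here is a minimal sketch of just the pool, slice and unpool path from the reproducer, run on the CPU. It assumes a recent PyTorch release (where Variable is no longer needed and MaxPool3d returns valid indices). Because MaxPool3d indices are flattened over D*H*W, the indices for the center slice (d = 1) start at 1*64*64, so every one of them lies outside the 64x64 map that MaxUnpool2d writes into. The CPU kernel reports this as a RuntimeError rather than an illegal memory access.

```python
import torch
import torch.nn as nn

# Same pooling/unpooling layers as in the reproducer, run on the CPU.
pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), return_indices=True)
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

x = torch.rand(8, 8, 3, 64, 64)        # (N, C, D, H, W), as after enc_conv1
out, indices = pool(x)                 # out: (8, 8, 3, 32, 32)
center_idx = indices[:, :, 1]          # the slice the reproducer passes on

# 3D indices are d*H*W + h*W + w, so the d = 1 slice lies in [4096, 8191].
print(center_idx.min().item(), center_idx.max().item())

caught = False
try:
    unpool(out[:, :, 1], center_idx)   # 2D unpool needs indices < 64*64
except RuntimeError as err:
    caught = True
    print("CPU reports:", err)
```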

@smth I believe the minimal example provided in my original post is the script you are asking about. To get a CPU version, just remove all the .cuda() calls. That said, I’ve been playing around with that code this morning, and I’m finding all sorts of weird behavior.

First, the most serious problem is that nn.MaxPool3d outputs junk indices. Here is an example:

import torch
import torch.nn as nn
from torch.autograd import Variable
pool3d = nn.MaxPool3d(kernel_size=2,stride=2,return_indices=True)
img3d = Variable(torch.rand(1,1,4,4,4))
out, indices = pool3d(img3d)
print(indices)

The elements of indices should be in the range [0,63], but here is what is printed:

Variable containing:
(0 ,0 ,0 ,.,.) = 
  4.6117e+18 -9.2234e+18
  3.0065e+10  4.2950e+09

(0 ,0 ,1 ,.,.) = 
  6.5536e+04  8.5213e+14
 -3.4588e+18  4.6117e+18
[torch.LongTensor of size 1x1x2x2x2]
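For comparison, here is how the same call can be sanity-checked on a recent PyTorch release. This is an assumption about the version: the junk values above come from the 2017 build shown in the traceback. Every index must lie in [0, 63]. Gathering the flattened input at those indices must also reproduce the pooled output exactly.

```python
import torch
import torch.nn as nn

pool3d = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
img3d = torch.rand(1, 1, 4, 4, 4)
out, indices = pool3d(img3d)

# Each index should address one of the 4*4*4 = 64 voxels of its channel.
in_range = bool(indices.min() >= 0) and bool(indices.max() <= 63)

# Gathering the flattened input at those positions should give the pooled values.
gathered = img3d.flatten(2).gather(2, indices.flatten(2)).view_as(out)
print(in_range, torch.equal(gathered, out))
```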

Second, and this is the really “weird behavior” stuff, my minimal example stops producing an error when I change these lines of code:

input_batch = Variable(input_batch.cuda())
target_batch = Variable(target_batch.cuda())

to this:

input_batch = Variable(input_batch)
target_batch = Variable(target_batch)
input_batch = input_batch.cuda()
target_batch = target_batch.cuda()

I’m not sure what to make of that. When I run the code on the CPU, I get an error complaining about the pooling indices, which makes sense. However, when running on the GPU, I get no error at all, as long as I move the inputs and targets to the GPU in a very particular way (otherwise I get an illegal memory access error). That’s dangerous behavior, because I wouldn’t have even noticed that the pooling indices are all messed up if I had written my code in a slightly different way. (Another seemingly random thing is that removing the shortcut connection also eliminates the error, so this is all very weird.)
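One way to stop relying on the GPU to fail loudly is to validate the indices before unpooling. The wrapper below, checked_max_unpool2d, is a hypothetical helper (not part of PyTorch). It only calls F.max_unpool2d, and it raises if any index falls outside the output map. The .item() calls force a device sync, so treat it as a debugging aid rather than something to leave in a hot training loop.

```python
import torch
import torch.nn.functional as F


def checked_max_unpool2d(x, indices, kernel_size, stride, output_size):
    """Hypothetical helper: validate indices, then call F.max_unpool2d.

    output_size is the (H, W) of the unpooled map; every index must lie in
    [0, H*W - 1] because MaxUnpool2d treats indices as flat positions in it.
    """
    h, w = output_size
    lo, hi = indices.min().item(), indices.max().item()  # forces a device sync
    if lo < 0 or hi >= h * w:
        raise ValueError(
            f"unpool indices span [{lo}, {hi}], expected [0, {h * w - 1}]")
    return F.max_unpool2d(x, indices, kernel_size, stride,
                          output_size=output_size)


# Usage: indices from a 2D pool are accepted; shifted ones are rejected.
x = torch.rand(1, 1, 4, 4)
pooled, idx = F.max_pool2d(x, 2, 2, return_indices=True)
y = checked_max_unpool2d(pooled, idx, 2, 2, (4, 4))   # shape (1, 1, 4, 4)

try:
    checked_max_unpool2d(pooled, idx + 16, 2, 2, (4, 4))
except ValueError as err:
    print(err)
```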

Finally, for future reference: regarding the correctness of my minimal example, I need to be more careful when slicing the pooling indices. The pooling indices refer to positions in the input tensor, flattened over all of its spatial dimensions (D*H*W for MaxPool3d), whereas MaxUnpool2d expects positions flattened over H*W only. So what I really should have written is:

pre_pool_shape = out.size()[2:]                    # (D, H, W) before pooling
out, indices = self.pool(out)
indices = indices[:,:,pre_pool_shape[0]//2]        # only want the center slice
indices = indices % (pre_pool_shape[1]*pre_pool_shape[2])  # flattened coords for 2d slice

Of course this won’t work until nn.MaxPool3d is fixed.
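Assuming a PyTorch release with working MaxPool3d indices, the corrected slicing can be checked against MaxUnpool3d. Unpooling the full volume and then taking the center slice should give exactly what MaxUnpool2d produces from the center slice with indices taken modulo H*W. This only holds because the pooling kernel has depth 1 (kernel_size=(1,2,2)), so each slice's indices never point into a neighbouring slice.

```python
import torch
import torch.nn.functional as F

x = torch.rand(2, 8, 3, 64, 64)                  # (N, C, D, H, W)
pre_pool_shape = x.size()[2:]                    # (3, 64, 64)
out, indices = F.max_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2),
                            return_indices=True)

center = pre_pool_shape[0] // 2
# Turn 3D flat indices (d*H*W + h*W + w) into 2D flat indices (h*W + w).
indices2d = indices[:, :, center] % (pre_pool_shape[1] * pre_pool_shape[2])
unpooled2d = F.max_unpool2d(out[:, :, center], indices2d,
                            kernel_size=2, stride=2)

# Reference: unpool the whole volume in 3D, then take the same center slice.
unpooled3d = F.max_unpool3d(out, indices, kernel_size=(1, 2, 2),
                            stride=(1, 2, 2))
print(torch.equal(unpooled2d, unpooled3d[:, :, center]))
```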


Moving the call to .cuda() out of the Variable declaration solved it for me, too. And I’m using the latest version of PyTorch.