CUDA alignment error when using DataParallel

Hi @smth, I am trying to use the DataParallel routine, but I'm getting stuck on an error. Right now I am only trying to wrap a single conv module of my net in DataParallel. So previously I had:

self.conv1 = nn.Conv2d(N.inputChannels, N.outputChannels, N.kernelSquareSize, stride=(1, 1), padding=(1, 1))

and now, as per the DataParallel documentation, I have:

self.conv1 = nn.Conv2d(N.inputChannels, N.outputChannels, N.kernelSquareSize, stride=(1, 1), padding=(1, 1))
self.conv1 = torch.nn.DataParallel(self.conv1, device_ids=[1, 2])

With this new code however, I am unable to get it to run, and I get the following error:

File "/billly/sillyNet.py", line 67, in forward_prop
  x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2,2))
File "/data/venv/pytorch/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 210, in __call__
  result = self.forward(*input, **kwargs)
File "/data/venv/pytorch/local/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 45, in forward
  return self.gather(outputs, self.output_device)
File "/data/venv/pytorch/local/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 57, in gather
  return gather(outputs, output_device)
File "/data/venv/pytorch/local/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 25, in gather
  return gather_map(outputs)
File "/data/venv/pytorch/local/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 23, in gather_map
  return Gather(target_device)(*outputs)
File "/data/venv/pytorch/local/lib/python2.7/site-packages/torch/nn/parallel/functions.py", line 32, in forward
  return comm.gather(inputs, self.dim, self.target_device)
File "/data/venv/pytorch/local/lib/python2.7/site-packages/torch/cuda/comm.py", line 141, in gather
  result.narrow(dim, chunk_start, tensor.size(dim)).copy_(tensor, True)
RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at /home/soumith/local/builder/wheel/pytorch-src/torch/lib/THC/THCTensorCopy.cu:85

And it is occurring during the (attempted) forward prop in my code…

thanks…

Can you please give us the parameters of the conv so we can try to reproduce the issue?

Hi @apaszke,

Sure, here is my complete snippet:

class Net(nn.Module):

  def __init__(self):
    super(Net, self).__init__()

    self.bn1 = nn.BatchNorm2d(8)
    self.bn2 = nn.BatchNorm2d(16)

    self.conv1 = nn.Conv2d(1, 8, 3, stride=(1, 1), padding=(1, 1))
    self.conv1 = torch.nn.DataParallel(self.conv1, device_ids=[1, 2])
    self.conv2 = nn.Conv2d(8, 16, 3, stride=(1, 1), padding=(1, 1))
    # etc ...

  def forward_prop(self, x):

    x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
    x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), (2, 2))

This is my only change. I will also mention that if I simply remove the torch.nn.DataParallel line in the above, my code runs and trains fine.

Thanks,

I need some additional information. Why do you define forward_prop and not forward? What’s the input size? On which GPUs are the modules? On which GPU is the input?

Hi @apaszke,

Why do you define forward_prop and not forward?

For this question, I am not sure what the relevance of the name is. To be honest, I had defined it this way in my class before and it worked, so perhaps I am missing something deeper here? It's just the name I give to my forward propagation function…

What’s the input size?

This one is a 100x100 image, single channel, minibatch size of 16.

On which GPUs are the modules? On which GPU is the input?

I believe they are all on my GPU 1, and I only say this because when I run nvidia-smi, that GPU is the only one that ever seems to be used. Put another way, nowhere in my code have I explicitly specified that I want to use specific GPUs, except for the device_ids in the seemingly problematic statement.

@Kalamaya the name forward is important because the __call__ operator uses forward from the module. You need to define a forward function in every Module that you create; otherwise it will raise an error in __call__.
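For illustration, here is a minimal sketch of that pattern (the SmallNet class and the tensor sizes are made up for the example):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class SmallNet(nn.Module):
    def __init__(self):
        super(SmallNet, self).__init__()
        self.conv = nn.Conv2d(1, 8, 3, padding=(1, 1))

    # Must be named `forward`: Module.__call__ looks this method up by name.
    def forward(self, x):
        return F.relu(self.conv(x))

net = SmallNet()
x = Variable(torch.randn(16, 1, 100, 100))
out = net(x)  # goes through Module.__call__, which calls net.forward(x)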

Ah right. The problem is that we're numbering GPUs starting from 0. So all the modules and the data are on GPU 0, but you're telling DataParallel to run on GPU 1 and GPU 2 (i.e. the 2nd and 3rd GPUs). Can you change that and see if it helps? If that's it, then we need to improve the error message.
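Concretely, assuming the parameters and input really are on the first GPU, the wrapping would look something like this (a sketch using the sizes from this thread; it needs at least two GPUs to run):

import torch
import torch.nn as nn
from torch.autograd import Variable

conv1 = nn.Conv2d(1, 8, 3, stride=(1, 1), padding=(1, 1))
conv1 = nn.DataParallel(conv1, device_ids=[0, 1])  # 0-based: first and second GPU
conv1.cuda()  # parameters end up on device_ids[0], i.e. GPU 0

x = Variable(torch.randn(16, 1, 100, 100).cuda())  # input also on GPU 0
y = conv1(x)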

Also, unless forward_prop is just a helper used inside a forward that's defined somewhere else, you shouldn't be calling it directly like that; call the module itself. @fmassa wrote a nice explanation why.

@fmassa @apaszke I guess I am getting very confused here… To my knowledge, my forward_prop member function IS my definition of the forward function in my Net module… have I misunderstood something here? For completeness, I have pasted my Net class here:

class Net(nn.Module):
 
  def __init__(self):
    super(Net, self).__init__()

    # Define the network
    self.bn1 = nn.BatchNorm2d(8)
    self.bn2 = nn.BatchNorm2d(16)        
    self.conv1 = nn.Conv2d(1, 8, 3, stride = (1,1), padding = (1,1))
    self.conv2 = nn.Conv2d(8, 16, 3, stride = (1,1), padding = (1,1))    
    self.conv3 = nn.Conv2d(16, 32, 2, stride = (1,1), padding = (0,0))
    self.conv4 = nn.Conv2d(32, 64, 3, stride = (1,1), padding = (1,1))
    self.conv5 = nn.Conv2d(64, 64, 3, stride = (3,3), padding = (0,0))
    self.fc1   = nn.Linear(256, 32) 
    self.fc2   = nn.Linear(32, 16)
    self.fc3   = nn.Linear(16, 2)

  def forward_prop(self, x):
    
    # Conv1 with batch norm
    x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2,2))

    # Conv2 with batch norm.
    x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), (2,2))    

    # Conv3
    x = F.max_pool2d(F.relu(self.conv3(x)), (2,2))    

    # Conv4
    x = F.max_pool2d(F.relu(self.conv4(x)), (2,2))

    # Conv5
    x = F.relu(self.conv5(x))
    
    # Flatten the feature map thus far for use in the fully connected layer.
    x = x.view(-1, self.num_flat_features(x))
        
    # Fully connected 1    
    x = F.relu(self.fc1(x))

    # Fully connected 2    
    x = F.relu(self.fc2(x))

    # Final layer
    x = self.fc3(x)

    return x

  def num_flat_features(self, x):    
    # all dimensions except the batch dimension
    size = x.size()[1:] 
    num_features = 1
    for s in size:
        num_features *= s
    return num_features

So real quick, in the rest of my file, I basically have:

net = Net().cuda()
net.forward_prop(trainingBatch)

Now, given this setup, what I would simply like to do is use DataParallel. I looked at the DCGAN example, but I cannot use it at first glance because I do not know how to do a reshape operation inside nn.Sequential(), and as for the imagenet example, it is not clear to me what their "model" is vis-a-vis what I have here.
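(As an aside, I suspect a reshape inside nn.Sequential could be done with a tiny custom module, something like the sketch below, though I haven't tried it; the View class and the sizes are just made up for illustration:)

import torch.nn as nn

class View(nn.Module):
    """Reshapes each sample, so a reshape can sit inside an nn.Sequential."""
    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(x.size(0), *self.shape)

main = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=(1, 1)),
    nn.ReLU(),
    View(-1),                      # flatten each sample
    nn.Linear(8 * 100 * 100, 32),  # 100x100 input stays 100x100 after the padded conv
)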

So instead I wrote a simple net (as above), and would like to use the DataParallel capability with it. What would I do differently with the setup above?

thanks again!

Ok, I got this to work: apparently the indexing of the GPUs used by nvidia-smi is not the same as the indexing used by the program, and that was causing one of the issues. (There was another issue where it complained if the batch size was too small, but that seemed to go away once I made the batch size larger.) Thanks!!

You're using the Module incorrectly. You should never call the forward function directly to apply a module. You should only define the forward function and simply call the module like a function - module(input). This will run the __call__ method that we've implemented, which will call into your forward function with the inputs. It's necessary for some additional bookkeeping like hooks, etc.
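In other words, with the forward_prop above renamed to forward, the usage would be roughly (reusing the Net and trainingBatch names from the snippets in this thread):

net = Net().cuda()
output = net(trainingBatch)           # correct: goes through __call__ -> forward
# output = net.forward(trainingBatch) # skips hooks and other bookkeeping; avoid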

So the problem was 1- vs 0-based GPU indexing, right? We need to fix that; it should never give you an invalid memory access.

@apaszke thanks, yes, I now understand - I just need to define the forward function with whatever needs to be done, and then, with the object I instantiate, do something like:

myNet = Net()
myNet = torch.nn.DataParallel(myNet, device_ids=[0,1])
myNet.cuda()
output = myNet(input)

Is this correct?

However there is a subtlety that still confuses me: In the DCGAN example, we have:
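(roughly, from memory; the exact layers don't matter here:)

import torch
import torch.nn as nn

class _netG(nn.Module):
    def __init__(self, ngpu):
        super(_netG, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 64, 4, 1, 0, bias=False),
            nn.ReLU(True),
            # ... more layers ...
        )

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output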

The nn.parallel.data_parallel here is applied to a member of the _netG class, in this case self.main. This contrasts with what I have, where I apply torch.nn.DataParallel to the entire network object. I guess this works because the net object in my example and the nn.Sequential are both Modules? At any rate, some elucidation of this would help me greatly. Thanks again.

They're equivalent. The DCGAN example was written before we had the DataParallel module, when data parallelism only existed in that functional form, which you can use inside forward.
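To spell out the two equivalent forms side by side (a sketch using a single conv as a stand-in for the full network; it assumes at least two GPUs):

import torch
import torch.nn as nn
from torch.autograd import Variable

inner = nn.Conv2d(1, 8, 3, padding=(1, 1))  # stand-in for any Module
x = Variable(torch.randn(16, 1, 100, 100).cuda())

# 1) Module form: wrap once, then call it like any other module.
wrapped = nn.DataParallel(inner, device_ids=[0, 1]).cuda()
y1 = wrapped(x)

# 2) Functional form (what the DCGAN example uses): call data_parallel inside forward.
class Wrapper(nn.Module):
    def __init__(self, module):
        super(Wrapper, self).__init__()
        self.module = module

    def forward(self, x):
        return nn.parallel.data_parallel(self.module, x, device_ids=[0, 1])

y2 = Wrapper(inner).cuda()(x)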