index_select with DataParallel: arguments are located on different GPUs

I get

RuntimeError: arguments are located on different GPUs at /pytorch/torch/lib/THC/generic/THCTensorIndex.cu:390

when I use the index_select function with torch.nn.DataParallel(net).

During the forward pass of my model, I use tensor.index_select() to flatten the output feature map of a conv layer so that it can be fed into a linear layer. Basically, I am implementing the next CONV layer with a linear layer, which I am going to use to increase the flexibility of my model. It works fine on one GPU, and I am able to reproduce the results of the next CONV layer using index_select and the linear layer on top. Here is the part of the code where I get this error.

import torch.nn as nn
import torch, math
from torch.autograd import Variable

class SpatialPool(nn.Module):
    def __init__(self, amd0=225, kd=3):
        super(SpatialPool, self).__init__()
        print('*** spatial_pooling.py : __init__() ***', amd0, kd)
        self.use_gpu = True
        self.amd0 = amd0 #225
        self.kd = kd
        self.padding = nn.ReplicationPad2d(1).cuda()

        ww = hh = int(math.sqrt(amd0)) ## 15
        counts = torch.LongTensor(amd0,kd*kd) ## size [225,9]
        v = [[(hh+2)*i + j for j in range(ww+2)] for i in range(hh+2)]
        count = 0
        for h in range(1,hh+1):
            for w in range(1,ww+1):
                counts[count,:] = torch.LongTensor([v[h - 1][w - 1], v[h - 1][w], v[h - 1][w + 1],
                                                    v[h][w - 1], v[h][w], v[h][w + 1],
                                                    v[h + 1][w - 1], v[h + 1][w], v[h + 1][w + 1]])
                count += 1

        self.counts = counts.cuda()

    def forward(self, fm):
        fm = self.padding(fm) ## FM is Variable of size[batch_size,512,15,15]
        fm = fm.permute(0, 2, 3, 1).contiguous()
        fm = fm.view(fm.size(0), -1, fm.size(3))
        print('fm size and max ', fm.size(), torch.max(self.counts))
        pfm = fm.index_select(1,Variable(self.counts[:,0]))
        for h in range(1,self.kd*self.kd):
            pfm = torch.cat((pfm,fm.index_select(1, Variable(self.counts[:, h]))),2)
        # print('pfm size:::::::: ', pfm.size()) #[batch_size,225,512*9]
        return pfm

Here fm is a Variable of size [batch_size, 512, 15, 15].
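For reference, here is roughly how I verify the equivalence on a single GPU. This is only a sketch: the conv layer below is illustrative rather than the exact layer from my model, and it assumes a recent PyTorch with one visible GPU.

import torch
import torch.nn.functional as F

# Single-GPU sanity check: feeding pfm into a linear layer whose weight is the
# reshaped 3x3 conv kernel should match the conv applied to the padded map.
B, C_in, C_out, H = 2, 512, 4, 15
conv = nn.Conv2d(C_in, C_out, kernel_size=3).cuda()   # no padding: SpatialPool already pads
pool = SpatialPool(amd0=H * H, kd=3)

x = Variable(torch.randn(B, C_in, H, H).cuda())
pfm = pool(x)                                          # [B, 225, C_in * 9]

# [C_out, C_in, 3, 3] -> [C_out, 9 * C_in], matching pfm's neighbourhood layout
w_lin = conv.weight.permute(0, 2, 3, 1).contiguous().view(C_out, -1)
linear_out = F.linear(pfm, w_lin, conv.bias)           # [B, 225, C_out]

conv_out = conv(nn.ReplicationPad2d(1)(x))             # [B, C_out, 15, 15]
conv_out = conv_out.permute(0, 2, 3, 1).contiguous().view(B, H * H, C_out)

print((linear_out - conv_out).abs().max())             # should be close to 0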

The full error message looks like this:

File "/home/gurkirt/Dropbox/sandbox/ssd-pytorch-linear/layers/modules/feat_pooling.py", line 48, in forward
    pfm = self.spatial_pool1(fm)  # pooled feature map
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/gurkirt/Dropbox/sandbox/ssd-pytorch-linear/layers/modules/spatial_pooling.py", line 36, in forward
    pfm = fm.index_select(1,Variable(self.counts[:,0]))
  File "/usr/local/lib/python3.5/dist-packages/torch/autograd/variable.py", line 681, in index_select
    return IndexSelect.apply(self, dim, index)
  File "/usr/local/lib/python3.5/dist-packages/torch/autograd/_functions/tensor.py", line 297, in forward
    return tensor.index_select(dim, index)
RuntimeError: arguments are located on different GPUs at /pytorch/torch/lib/THC/generic/THCTensorIndex.cu:390

I don’t know what version of pytorch you’re using, but I tested this on a build from source and it doesn’t error out.

Thanks for the reply. Did you test by applying DataParallel on multiple GPUs?

I am using the latest version of PyTorch available via pip.

Yes, I tested with dataparallel on multiple GPUs. You could try building from source to see if the problem goes away, or wait for the next release (should be out soon).

Hi Richard, thanks for the help. I seriously need it; I have been stuck at this point for a few days. I still get the same error. I installed the latest PyTorch source from GitHub by following the instructions at https://github.com/pytorch/pytorch#from-source.

Here is the complete script to reproduce the error. Please check the part in the main function where I create mynet and apply DataParallel. I run the snippet below using two GPUs.

import torch.nn as nn
import torch, math
from torch.autograd import Variable

def main():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    net = mynet()
    net = net.cuda()
    net = torch.nn.DataParallel(net)
    x = torch.autograd.Variable(torch.FloatTensor(64, 512, 15, 15).cuda())  # batch_sizexCxWxH
    out = net(x)
    print(out.size())

class SpatialPool(nn.Module):
    def __init__(self, amd0=225, kd=3):
        super(SpatialPool, self).__init__()
        print('*** spatial_pooling.py : __init__() ***', amd0,kd)
        self.use_gpu = True
        self.amd0 = amd0 #225
        self.kd = kd
        self.padding = nn.ReplicationPad2d(1).cuda()

        ww = hh = int(math.sqrt(amd0)) ## 15
        counts = torch.LongTensor(amd0,kd*kd) ## size [225,9]
        v = [[(hh+2)*i + j for j in range(ww+2)] for i in range(hh+2)]
        count = 0
        for h in range(1,hh+1):
            for w in range(1,ww+1):
                counts[count,:] = torch.LongTensor([v[h - 1][w - 1], v[h - 1][w], v[h - 1][w + 1],
                                                    v[h][w - 1], v[h][w], v[h][w + 1],
                                                    v[h + 1][w - 1], v[h + 1][w], v[h + 1][w + 1]])
                count += 1

        self.counts = counts.cuda()

    def forward(self, fm):
        fm = self.padding(fm) ## FM is Variable of size[batch_size,512,15,15]
        fm = fm.permute(0, 2, 3, 1).contiguous()
        fm = fm.view(fm.size(0), -1, fm.size(3))
        print('fm size and max ', fm.size(), torch.max(self.counts))
        pfm = fm.index_select(1,Variable(self.counts[:,0]))
        for h in range(1,self.kd*self.kd):
            pfm = torch.cat((pfm,fm.index_select(1, Variable(self.counts[:, h]))),2)
        # print('pfm size:::::::: ', pfm.size()) #[batch_size,225,512*9]
        return pfm


class mynet(nn.Module):
    def __init__(self):
        super(mynet, self).__init__()
        self.cl = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.featPool = SpatialPool(amd0=225)

    def forward(self, x):
        x = self.cl(x)
        x = self.featPool(x)
        return x

if __name__ == '__main__':
    main()

When I put the whole code above in a Python script named test.py and run it with CUDA_VISIBLE_DEVICES=0,1 python test.py, I get this:

File "dummy_test.py", line 61, in <module>
main()
File "dummy_test.py", line 12, in main
out = net(x)
File "/home/gurkirt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
result = self.forward(*input, **kwargs)
File "/home/gurkirt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 68, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/gurkirt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 78, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/gurkirt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 67, in parallel_apply
raise output
File "/home/gurkirt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 42, in _worker
output = module(*input, **kwargs)
File "/home/gurkirt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
result = self.forward(*input, **kwargs)
File "dummy_test.py", line 57, in forward
x = self.featPool(x)
File "/home/gurkirt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
result = self.forward(*input, **kwargs)
File "dummy_test.py", line 42, in forward
pfm = fm.index_select(1,Variable(self.counts[:,0]))
RuntimeError: arguments are located on different GPUs at /home/gurkirt/pytorch/aten/src/THC/generic/THCTensorIndex.cu:452

Okay, I reproduced the problem. You should use register_buffer to add the counts tensor to the module so that DataParallel picks it up. The following fixes it:

import torch.nn as nn
import torch
from torch.autograd import Variable
import math

def main():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    net = mynet()
    net = net.cuda()
    net = torch.nn.DataParallel(net)
    x = torch.autograd.Variable(torch.FloatTensor(64, 512, 15, 15).cuda())  # batch_sizexCxWxH
    out = net(x)
    print(out.size())

class SpatialPool(nn.Module):
    def __init__(self, amd0=225, kd=3):
        super(SpatialPool, self).__init__()
        print('*** spatial_pooling.py : __init__() ***', amd0,kd)
        self.use_gpu = True
        self.amd0 = amd0 #225
        self.kd = kd
        self.padding = nn.ReplicationPad2d(1).cuda()

        ww = hh = int(math.sqrt(amd0)) ## 15
        counts = torch.LongTensor(amd0,kd*kd) ## size [225,9]
        v = [[(hh+2)*i + j for j in range(ww+2)] for i in range(hh+2)]
        count = 0
        for h in range(1,hh+1):
            for w in range(1,ww+1):
                counts[count,:] = torch.LongTensor([v[h - 1][w - 1], v[h - 1][w], v[h - 1][w + 1],
                                                    v[h][w - 1], v[h][w], v[h][w + 1],
                                                    v[h + 1][w - 1], v[h + 1][w], v[h + 1][w + 1]])
                count += 1

        # self.counts = counts.cuda()
        self.register_buffer("counts", counts)

    def forward(self, fm):
        fm = self.padding(fm) ## FM is Variable of size[batch_size,512,15,15]
        fm = fm.permute(0, 2, 3, 1).contiguous()
        fm = fm.view(fm.size(0), -1, fm.size(3))
        print('fm size and max ', fm.size(), torch.max(self.counts))
        pfm = fm.index_select(1,Variable(self.counts[:,0]))
        for h in range(1,self.kd*self.kd):
            pfm = torch.cat((pfm,fm.index_select(1, Variable(self.counts[:, h]))),2)
        # print('pfm size:::::::: ', pfm.size()) #[batch_size,225,512*9]
        return pfm


class mynet(nn.Module):
    def __init__(self):
        super(mynet, self).__init__()
        self.cl = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.featPool = SpatialPool(amd0=225)

    def forward(self, x):
        x = self.cl(x)
        x = self.featPool(x)
        return x

if __name__ == '__main__':
    main()
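For what it's worth, the reason register_buffer helps: DataParallel copies parameters and registered buffers to every replica, while a plain tensor attribute assigned in __init__ stays on the device it was created on, so the replicas running on the other GPUs call index_select with an index tensor that still lives on GPU 0. A minimal sketch of the difference (assumes at least two visible GPUs and a recent PyTorch; replicate is the helper DataParallel uses internally, and the module names here are just for illustration):

import torch
import torch.nn as nn
from torch.nn.parallel import replicate

class WithBuffer(nn.Module):
    def __init__(self):
        super(WithBuffer, self).__init__()
        # registered buffers are copied to every replica by DataParallel
        self.register_buffer("idx", torch.arange(4))

class WithAttribute(nn.Module):
    def __init__(self):
        super(WithAttribute, self).__init__()
        # a plain attribute stays on whatever device it was created on
        self.idx = torch.arange(4).cuda()

buf_replicas = replicate(WithBuffer().cuda(), [0, 1])
attr_replicas = replicate(WithAttribute().cuda(), [0, 1])

print(buf_replicas[1].idx.device)   # cuda:1 -- the buffer followed its replica
print(attr_replicas[1].idx.device)  # cuda:0 -- still pinned to the first GPU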

Yes, it works. Thanks a lot, Richard.