Backward call raises "arguments are on different GPUs"

I implemented a custom layer and a custom loss layer. When I call loss.backward(), it raises either "the size of tensor a (4) doesn't match the size of tensor b (22) at non-singleton dimension" or "arguments are on different GPUs", depending on whether I'm using the DataParallel module.

Here is my layer:

import torch as th
import torch.nn as nn

class hamming(nn.Module):

    def __init__(self, args):
        super(hamming, self).__init__()
        self.binary_dim = args.binary_dim
        self.channels = 512
        self.num_centers = args.num_centers
        self.linear_project = nn.Linear(self.channels, self.binary_dim, False)
        self.cluster_prob = nn.Conv2d(self.channels, self.num_centers, 1)
        self.prob_softmax = nn.Softmax2d()
        self.cluster_median = nn.Parameter(th.rand(self.binary_dim, self.num_centers), True) # B x K
        self.binarization = binary()  # custom binarization op, defined elsewhere
        self.thres = nn.Threshold(0, 0, True)

        nn.init.xavier_uniform(self.cluster_median, 1)

    def forward(self,feature_map):

        # project

        size = feature_map.size()
        _bs,_d,_h,_w = size[0],size[1],size[2],size[3]
        perm = feature_map.permute(0, 2, 3, 1) # N x H x W x D
        proj = self.linear_project(perm) # N x H x W x B
        h_proj = proj.permute(0, 3, 1, 2).contiguous().view(_bs, self.binary_dim, _h * _w) # N x B x H*W

        # assign

        prob = self.cluster_prob(feature_map)
        prob_softmax = self.prob_softmax(prob) # N X K X H X W

        # find the nearest cluster for each location (the argmax is not differentiable)

        nn_max,nn_idx = th.max(prob_softmax,1,False)  # N x H x W
        nn_max_val = nn_max.unsqueeze(1).expand(-1,self.binary_dim,-1,-1).contiguous().view(_bs,self.binary_dim,_h*_w) # N x B x H*W

        # median values

        cen_idx = nn_idx.view(_bs,_h*_w)
        median = th.cat([th.index_select(self.cluster_median,1,cen_idx[i,:]).unsqueeze(0) for i in range(_bs)]) # N x B x H*W
        weighted_median = median * nn_max_val

        # binary signature

        binary_sig = self.thres(self.binarization(h_proj - weighted_median)).view(_bs,self.binary_dim,_h,_w) # N x B x H x W

        return binary_sig,nn_idx

Basically, it turns a feature map of size N x C x H x W into a binarized feature map of size N x B x H x W, plus the cluster index that the nearest-neighbor assignment gives for each location.
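
For example, a quick shape check could look like the sketch below (binary isn't shown in the post, so a sign-based stand-in is used here, and args is just a stand-in namespace with made-up values):

from argparse import Namespace

import torch as th
import torch.nn as nn
from torch.autograd import Variable

class binary(nn.Module):
    # stand-in for the custom binarization op used by hamming (not shown in the post)
    def forward(self, x):
        return th.sign(x)

args = Namespace(binary_dim=64, num_centers=16)       # made-up values
layer = hamming(args).cuda()
feature_map = Variable(th.rand(2, 512, 7, 7).cuda())  # N x C x H x W
binary_sig, nn_idx = layer(feature_map)
print(binary_sig.size())                              # (2, 64, 7, 7) -> N x B x H x W
print(nn_idx.size())                                  # (2, 7, 7)     -> N x H x W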

Here is my loss layer:

    def forward(self,q_b,p_b,n_b,q_cen,p_cen,n_cen,tf,idf,p_signal,n_signal,_nPos,_nNegs):

        """
        :param q_b: _N x B x H x W, binary signature for queries, Variable
        :param p_b: sum(_nPos) x B x H x W, binary signature for positives, Variable
        :param n_b: sum(_nNegs) x B x H x W, binary signature for negatives, Variable
        :param q_cen: _N x H x W, cluster idx each binary signature belongs to,
        Variable containing a th.cuda.LongTensor
        :param p_cen: sum(_nPos) x H x W, cluster idx each binary signature belongs to,
        Variable containing a th.cuda.LongTensor
        :param n_cen: sum(_nNegs) x H x W, cluster idx each binary signature belongs to,
        Variable containing a th.cuda.LongTensor
        :param tf: (_N+sum(_nPos)+sum(_nNegs), num_centers) numpy array
        :param idf: (num_centers,) numpy array
        :param p_signal: (sum(_nPos),) numpy array
        :param n_signal: (sum(_nNegs),) numpy array
        :param _nPos: list of length _N; each entry is the number of actual
        positives for the corresponding query
        :param _nNegs: list of length _N; each entry is the number of actual
        violated negatives for the corresponding query
        :return: triplet match degree
        """

        _bs = q_b.size(0) # _bs may not equal args.batch_size
        _pos = p_b.size(0)
        _neg = n_b.size(0)
        h,w = q_b.size(2),q_b.size(3)

        q_cen = q_cen.data
        p_cen = p_cen.data
        n_cen = n_cen.data
        tf = th.from_numpy(tf).cuda()
        idf = th.from_numpy(idf).cuda()

        qp_b = th.cat([q_b[i].unsqueeze(0).expand(_nPos[i],-1,-1,-1) for i in range(_bs)])
        assert qp_b.size(0) == p_b.size(0)
        qp_hamming = th.abs(qp_b-p_b).sum(1) # sum(_nPos) x H x W
        qp_weight = th.exp(-1.0 * (th.pow(qp_hamming, 2.0) / (self.sigma ** 2)))  # _pos x H x W, Variable

        qp_cen = th.cat([q_cen[i].unsqueeze(0).expand(_nPos[i],-1,-1) for i in range(_bs)])
        assert qp_cen.size(0) == p_cen.size(0)

        eqp_bit = th.eq(qp_cen,p_cen) # _pos x H x W, ByteTensor
        smallp_bit = th.le(qp_hamming.data,self.thres) # _pos x H x W, ByteTensor

        pos_idf = th.zeros(_pos,h,w).cuda()
        pos_tf = th.ones(_pos,h,w).cuda()
        for i in range(_pos):
            for j in range(h):
                for k in range(w):
                    pos_idf[i,j,k] = idf[p_cen[i,j,k]]
                    pos_tf[i,j,k] = tf[i,p_cen[i,j,k]]

        eqp_bit = Variable(eqp_bit.float(),requires_grad=False)
        smallp_bit = Variable(smallp_bit.float(),requires_grad=False)
        pos_idf = Variable(pos_idf,requires_grad=False)
        pos_tf = Variable(pos_tf,requires_grad=False)

        qp_mask = eqp_bit * smallp_bit
        qp_survive_weight = qp_weight * qp_mask
        qp_match = (qp_survive_weight * th.pow(pos_idf, 2.0) / th.sqrt(pos_tf)).sum(-1).sum(-1) # _pos, Variable

        qn_b = th.cat([q_b[i].unsqueeze(0).expand(_nNegs[i], -1, -1, -1) for i in range(_bs)])
        assert qn_b.size(0) == n_b.size(0)
        qn_hamming = th.abs(qn_b - n_b).sum(1)  # _neg x H x W
        qn_weight = th.exp(-1.0 * (th.pow(qn_hamming, 2.0) / (self.sigma ** 2)))  # _neg x H x W, Variable

        qn_cen = th.cat([q_cen[i].unsqueeze(0).expand(_nNegs[i],-1,-1) for i in range(_bs)]) # _neg x H x W
        assert qn_cen.size(0) == n_cen.size(0)

        eqn_bit = th.eq(qn_cen,n_cen) # _neg x H x W, ByteTensor
        smalln_bit = th.le(qn_hamming.data,self.thres) # _neg x H x W, ByteTensor

        neg_idf = th.zeros(_neg, h, w).cuda()
        neg_tf = th.ones(_neg, h, w).cuda()
        for i in range(_neg):
            for j in range(h):
                for k in range(w):
                    neg_idf[i,j,k] = idf[n_cen[i,j,k]]
                    neg_tf[i,j,k] = tf[i,n_cen[i,j,k]]

        eqn_bit = Variable(eqn_bit.float(),requires_grad=False)
        smalln_bit = Variable(smalln_bit.float(),requires_grad=False)
        neg_idf = Variable(neg_idf,requires_grad=False)
        neg_tf = Variable(neg_tf,requires_grad=False)

        qn_mask = eqn_bit * smalln_bit
        qn_survive_weight = qn_mask * qn_weight
        qn_match = (qn_survive_weight * th.pow(neg_idf, 2.0) / th.sqrt(neg_tf)).sum(-1).sum(-1) # _neg, Variable

        pos_signal = Variable(th.from_numpy(p_signal).cuda(),requires_grad=False)
        neg_signal = Variable(th.from_numpy(n_signal).cuda(),requires_grad=False)

        n_base = [0 for i in range(_bs + 1)]
        p_base = [0 for i in range(_bs + 1)]
        for i in range(_bs):
            n_base[i + 1] = n_base[i] + _nNegs[i]
            p_base[i + 1] = p_base[i] + _nPos[i]

        pos_ = pos_signal * qp_match
        pos_loss = th.cat([pos_[p_base[i]:p_base[i+1]].unsqueeze(1).expand(-1,_nNegs[i]).contiguous().view(-1) for i in range(_bs)])

        neg_ = neg_signal * qn_match
        neg_loss = th.cat([neg_[n_base[i]:n_base[i+1]].unsqueeze(0).expand(_nPos[i],-1).contiguous().view(-1) for i in range(_bs)])

        dist_hinge = th.clamp(self.h_margin + neg_loss - pos_loss, min=0.0)
        loss = th.mean(dist_hinge)

        return loss
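
Side note, unrelated to the error: the nested pos_idf / pos_tf loops above (and the matching neg_idf / neg_tf loops) only look up idf[p_cen[i,j,k]] and tf[i, p_cen[i,j,k]] per location, so a vectorized sketch of the positive branch, mirroring the loops and assuming the same CUDA tensors as above, would be:

pos_idf = th.index_select(idf, 0, p_cen.view(-1)).view(_pos, h, w)      # idf[p_cen[i, j, k]]
pos_tf = tf[:_pos].gather(1, p_cen.view(_pos, h * w)).view(_pos, h, w)  # tf[i, p_cen[i, j, k]]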

My main script looks like this:

import torch as th
from torch.nn import DataParallel

net = mynet().cuda()
net = DataParallel(net,device_ids=list(range(num_gpus)))

The forward pass succeeds. With DataParallel, the backward pass gives "arguments are on different GPUs"; if I disable DataParallel and use batch_size > 1, I get "the size of tensor a (4) doesn't match the size of tensor b (22) at non-singleton dimension"; with batch_size = 1, everything works fine.
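
For context, DataParallel chunks tensor (Variable) arguments along dim 0, one chunk per replica, while plain Python lists of numbers (like _nPos / _nNegs here) are, as far as I understand, copied whole to every replica. A minimal sketch of that splitting, on a hypothetical 2-GPU setup:

import torch as th
import torch.nn as nn
from torch.autograd import Variable

class EchoShapes(nn.Module):
    # prints the batch size each replica actually receives
    def forward(self, x):
        print(th.cuda.current_device(), x.size())
        return x

echo = nn.DataParallel(EchoShapes().cuda(), device_ids=[0, 1])
_ = echo(Variable(th.rand(4, 512, 7, 7).cuda()))  # batch of 4 -> two chunks of 2 along dim 0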

Any help would be appreciated!