I implement a custom layer and loss layer, when calling loss.backward(), it shows an error that “tensor of a(4) doesn’t match with tensor b(22) at non-singleton dimension” or “arguments are on different gpus”, depends on whether I’m using DataParallel module.
Here is my layer
class hamming(nn.Module):
def __init__(self,args):
super(hamming,self).__init__()
self.binary_dim = args.binary_dim
self.channels = 512
self.num_centers = args.num_centers
self.linear_project = nn.Linear(self.channels,self.binary_dim,False)
self.cluster_prob = nn.Conv2d(self.channels,self.num_centers,1)
self.prob_softmax = nn.Softmax2d()
self.cluster_median = nn.Parameter(th.rand(self.binary_dim,self.num_centers),True) # B x K
self.binarization = binary()
self.thres = nn.Threshold(0,0,True)
nn.init.xavier_uniform(self.cluster_median,1)
def forward(self,feature_map):
# project
size = feature_map.size()
_bs,_d,_h,_w = size[0],size[1],size[2],size[3]
perm = feature_map.permute(0, 2, 3, 1) # N x H x W x D
proj = self.linear_project(perm) # N x H x W x B
h_proj = proj.permute(0, 3, 1, 2).contiguous().view(_bs, self.binary_dim, _h * _w) # N x B x H*W
# assign
prob = self.cluster_prob(feature_map)
prob_softmax = self.prob_softmax(prob) # N X K X H X W
# find the NN clusters, is not differentiable
nn_max,nn_idx = th.max(prob_softmax,1,False) # N x H x W
nn_max_val = nn_max.unsqueeze(1).expand(-1,self.binary_dim,-1,-1).contiguous().view(_bs,self.binary_dim,_h*_w) # N x B x H*W
# median values
cen_idx = nn_idx.view(_bs,_h*_w)
median = th.cat([th.index_select(self.cluster_median,1,cen_idx[i,:]).unsqueeze(0) for i in range(_bs)]) # N x B x H*W
weighted_median = median * nn_max_val
# binary signature
binary_sig = self.thres(self.binarization(h_proj - weighted_median)).view(_bs,self.binary_dim,_h,_w) # N x B x H x W
return binary_sig,nn_idx
Basically it turns a feature map of size N x C x H x W to a binarized feature map N x B x H x W, plus a cluster id which nearest neighbor search would give.
Here is my loss layer:
def forward(self,q_b,p_b,n_b,q_cen,p_cen,n_cen,tf,idf,p_signal,n_signal,_nPos,_nNegs):
"""
:param q_b: _N x B x H x W, binary signature for queries, Variable
:param p_b: sum(_nPos) x B x H x W, binary signature for positives, Variable
:param n_b: sum(_nNegs) x B x H x W, binary signature for negatives, Variable
:param q_cen: _N x H x W, cluster idx each binary signature belongs to
Variable contains th.cuda.LongTensor
:param p_cen: sum(_nPos) x H x W cluster idx each binary signature belongs to
Variable contains th.cuda.LongTensor
:param n_cen: sum(_nNegs) x H x W cluster idx each binary signature belongs to
Variable contains th.cuda.LongTensor
:param tf: (_N+sum(_nPos)+sum(_nNegs),num_centers) numpy array
:param idf: (num_centers,) numpy array
:param p_signal: (sum(_nPos),), numpy array
:param n_signal: (sum(_nNegs),), numpy array
:param _nPos: list with length N, each entry stands for the number
of actual positives for each query
:param _nNegs: list with length N, each entry stands for the number
of actual violated negatives for each query
:return: triplet match degree
"""
_bs = q_b.size(0) # _bs may not equal to args.batch_size
_pos = p_b.size(0)
_neg = n_b.size(0)
h,w = q_b.size(2),q_b.size(3)
q_cen = q_cen.data
p_cen = p_cen.data
n_cen = n_cen.data
tf = th.from_numpy(tf).cuda()
idf = th.from_numpy(idf).cuda()
qp_b = th.cat([q_b[i].unsqueeze(0).expand(_nPos[i],-1,-1,-1) for i in range(_bs)])
assert qp_b.size(0) == p_b.size(0)
qp_hamming = th.abs(qp_b-p_b).sum(1) # sum(_nPos) x H x W
qp_weight = th.exp(-1.0 * (th.pow(qp_hamming, 2.0) / (self.sigma ** 2))) # _pos x H x W, Variable
qp_cen = th.cat([q_cen[i].unsqueeze(0).expand(_nPos[i],-1,-1) for i in range(_bs)])
assert qp_cen.size(0) == p_cen.size(0)
eqp_bit = th.eq(qp_cen,p_cen) # _pos x H x W, ByteTensor
smallp_bit = th.le(qp_hamming.data,self.thres) # _pos x H x W, ByteTensor
pos_idf = th.zeros(_pos,h,w).cuda()
pos_tf = th.ones(_pos,h,w).cuda()
for i in range(_pos):
for j in range(h):
for k in range(w):
pos_idf[i,j,k] = idf[p_cen[i,j,k]]
pos_tf[i,j,k] = tf[i,p_cen[i,j,k]]
eqp_bit = Variable(eqp_bit.float(),requires_grad=False)
smallp_bit = Variable(smallp_bit.float(),requires_grad=False)
pos_idf = Variable(pos_idf,requires_grad=False)
pos_tf = Variable(pos_tf,requires_grad=False)
qp_mask = eqp_bit * smallp_bit
qp_survive_weight = qp_weight * qp_mask
qp_match = (qp_survive_weight * th.pow(pos_idf, 2.0) / th.sqrt(pos_tf)).sum(-1).sum(-1) # _pos, Variable
qn_b = th.cat([q_b[i].unsqueeze(0).expand(_nNegs[i], -1, -1, -1) for i in range(_bs)])
assert qn_b.size(0) == n_b.size(0)
qn_hamming = th.abs(qn_b - n_b).sum(1) # _neg x H x W
qn_weight = th.exp(-1.0 * (th.pow(qn_hamming, 2.0) / (self.sigma ** 2))) # _neg x H x W, Variable
qn_cen = th.cat([q_cen[i].unsqueeze(0).expand(_nNegs[i],-1,-1) for i in range(_bs)]) # _neg x H x W
assert qn_cen.size(0) == n_cen.size(0)
eqn_bit = th.eq(qn_cen,n_cen) # _neg x H x W, ByteTensor
smalln_bit = th.le(qn_hamming.data,self.thres) # _neg x H x W, ByteTensor
neg_idf = th.zeros(_neg, h, w).cuda()
neg_tf = th.ones(_neg, h, w).cuda()
for i in range(_neg):
for j in range(h):
for k in range(w):
neg_idf[i,j,k] = idf[n_cen[i,j,k]]
neg_tf[i,j,k] = tf[i,n_cen[i,j,k]]
eqn_bit = Variable(eqn_bit.float(),requires_grad=False)
smalln_bit = Variable(smalln_bit.float(),requires_grad=False)
neg_idf = Variable(neg_idf,requires_grad=False)
neg_tf = Variable(neg_tf,requires_grad=False)
qn_mask = eqn_bit * smalln_bit
qn_survive_weight = qn_mask * qn_weight
qn_match = (qn_survive_weight * th.pow(neg_idf, 2.0) / th.sqrt(neg_tf)).sum(-1).sum(-1) # _neg, Variable
pos_signal = Variable(th.from_numpy(p_signal).cuda(),requires_grad=False)
neg_signal = Variable(th.from_numpy(n_signal).cuda(),requires_grad=False)
n_base = [0 for i in range(_bs + 1)]
p_base = [0 for i in range(_bs + 1)]
for i in range(_bs):
n_base[i + 1] = n_base[i] + _nNegs[i]
p_base[i + 1] = p_base[i] + _nPos[i]
pos_ = pos_signal * qp_match
pos_loss = th.cat([pos_[p_base[i]:p_base[i+1]].unsqueeze(1).expand(-1,_nNegs[i]).contiguous().view(-1) for i in range(_bs)])
neg_ = neg_signal * qn_match
neg_loss = th.cat([neg_[n_base[i]:n_base[i+1]].unsqueeze(0).expand(_nPos[i],-1).contiguous().view(-1) for i in range(_bs)])
dist_hinge = th.clamp(self.h_margin + neg_loss - pos_loss, min=0.0)
loss = th.mean(dist_hinge)
return loss
My main script looks like this:
import torch as th
from torch.nn import DataParallel
net = mynet().cuda()
net = DataParallel(net,device_ids=list(range(num_gpus)))
Forward pass succeeds, while using DataParallel would give arguments are on different gpus, disenabling DataParallel and use batch_size > 1 gives the error “the size of tensor a(4) doesn’t match with the size of tensor b(22) at non-singleton dimension”, using batch_size = 1, everything works fine.
Any help would be appreciated !