Hi all,
I've run into a very confusing problem.
Everything works when I run the model on a single GPU, but it fails when I use multiple GPUs (single machine, multiple GPUs) via model = torch.nn.DataParallel(model, device_ids=device_ids).
The puzzling thing is that the code runs fine at first; the error only occurs after 5 batches (batch_size = 100) have been processed.
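For reference, the wrapping itself is roughly the following (a minimal sketch; nn.Linear and device_ids = [0, 1] stand in for my actual model and configuration):

import torch
import torch.nn as nn

device_ids = [0, 1]
# Parameters must live on the first device in device_ids.
model = nn.Linear(128, 10).cuda(device_ids[0])
model = nn.DataParallel(model, device_ids=device_ids)

# Each forward call scatters the batch (batch_size = 100) across the GPUs,
# replicates the module on every device, runs the replicas in parallel,
# and gathers the outputs back onto device_ids[0].
x = torch.randn(100, 128).cuda(device_ids[0])
out = model(x)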
Traceback (most recent call last):
  File "exp\GSTEG.py", line 25, in <module>
    main()
  File ".\main.py", line 51, in main
    s_top1,s_top5,o_top1,o_top5,v_top1,v_top5, sov_top1 = trainer.train(train_loader, base_model, logits_model, criterion, base_optimizer, logits_optimizer, epoch, opt)
  File ".\train.py", line 172, in train
    # s_output, o_output, v_output, loss = criterion(*((s, o, v, so, ov, vs, ss, oo, vv, so_t, ov_t, vs_t, os_t, vo_t, sv_t) + (s_target_var, o_target_var, v_target_var, meta)))
  File "C:\Users\gorvinchen\Miniconda3\envs\rcenet\lib\site-packages\torch\nn\modules\module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\gorvinchen\Miniconda3\envs\rcenet\lib\site-packages\torch\nn\parallel\data_parallel.py", line 143, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "C:\Users\gorvinchen\Miniconda3\envs\rcenet\lib\site-packages\torch\nn\parallel\data_parallel.py", line 153, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "C:\Users\gorvinchen\Miniconda3\envs\rcenet\lib\site-packages\torch\nn\parallel\parallel_apply.py", line 83, in parallel_apply
    raise output
  File "C:\Users\gorvinchen\Miniconda3\envs\rcenet\lib\site-packages\torch\nn\parallel\parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "C:\Users\gorvinchen\Miniconda3\envs\rcenet\lib\site-packages\torch\nn\modules\module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File ".\models\layers\AsyncTFCriterion.py", line 244, in forward
    s_msg, o_msg, v_msg = self.get_msg(idtime, 'past')
  File ".\models\layers\AsyncTFCriterion.py", line 147, in get_msg
    return self.mget(idtime, self.ns, self.no, self.nv, s_storage, o_storage, v_storage, cond, kernel)
  File ".\models\layers\AsyncTFCriterion.py", line 127, in mget
    s_out = [meta(ids, time, s_size, s_storage) for ids, time in idtime]
  File ".\models\layers\AsyncTFCriterion.py", line 127, in <listcomp>
    s_out = [meta(ids, time, s_size, s_storage) for ids, time in idtime]
  File ".\models\layers\AsyncTFCriterion.py", line 124, in meta
    if cond(t, t0)), 1. / self.decay)
  File ".\models\layers\AsyncTFCriterion.py", line 43, in avg
    item, w = next(iterator)
  File ".\models\layers\AsyncTFCriterion.py", line 124, in <genexpr>
    if cond(t, t0)), 1. / self.decay)
  File ".\models\layers\AsyncTFCriterion.py", line 145, in <lambda>
    cond = lambda t, t0: t < t0 if time == 'past' else t > t0
RuntimeError: arguments are located on different GPUs at c:\a\w\1\s\windows\pytorch\aten\src\thc\generic/THCTensorMathCompareT.cu:7
So the error is raised from the cond lambda in get_msg, where two timestamps are compared. My code is here:
def avg(iterator, weight=1.):
    # compounding weight
    item, w = next(iterator)
    total = item.clone() * w
    n = 1.
    for i, (item, w) in enumerate(iterator):
        w1 = 1. * weight**(i + 1)
        total += item * w1 * w
        n += w1
    return total / n
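For context, avg consumes an iterator of (item, weight) pairs and reduces them to a single weighted average, with each later item additionally discounted by a compounding factor. A made-up call, just to show the behavior:

import torch

pairs = iter([(torch.ones(3), 1.0), (torch.full((3,), 2.0), 1.0)])
# With weight=0.5 the second item is discounted by 0.5**1:
# (1*1 + 2*0.5*1) / (1 + 0.5) = 1.333...
print(avg(pairs, weight=0.5))  # tensor([1.3333, 1.3333, 1.3333])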
class MessagePassing(object):
    # Class for keeping track of messages across frames
    def __init__(self, maxsize, w_temporal, w_spatio, decay, sigma, ns, no, nv):
        super(MessagePassing, self).__init__()
        self.maxsize = maxsize
        self.w_temporal = w_temporal
        self.w_spatio = w_spatio
        self.decay = decay
        self.sigma = sigma
        self.s_storage = {}
        self.s_storage_gt = {}
        self.o_storage = {}
        self.o_storage_gt = {}
        self.v_storage = {}
        self.v_storage_gt = {}
        self.training = self.training if hasattr(self, 'training') else True
        self.ns = ns
        self.no = no
        self.nv = nv

    def mget(self, idtime, s_size, o_size, v_size, s_storage, o_storage, v_storage,
             cond=lambda t, t0: True, kernel=lambda t, t0: 1):
        # get messages using a condition on the timestamps
        def meta(ids, t0, size, storage):
            try:
                return avg(((y, kernel(t, t0)) for t, y in storage[ids]
                            if cond(t, t0)), 1. / self.decay)
            except (StopIteration, KeyError):
                # no stored messages for this id/condition yet
                return torch.zeros(size)
        s_out = [meta(ids, time, s_size, s_storage) for ids, time in idtime]
        o_out = [meta(ids, time, o_size, o_storage) for ids, time in idtime]
        v_out = [meta(ids, time, v_size, v_storage) for ids, time in idtime]
        # .cuda() without arguments moves the stacked result to the
        # currently selected CUDA device
        return (Variable(torch.stack(s_out, 0).cuda()),
                Variable(torch.stack(o_out, 0).cuda()),
                Variable(torch.stack(v_out, 0).cuda()))

    def get_msg(self, idtime, time='past', s_storage=None, o_storage=None, v_storage=None):
        s_storage = self.s_storage if s_storage is None else s_storage
        o_storage = self.o_storage if o_storage is None else o_storage
        v_storage = self.v_storage if v_storage is None else v_storage
        cond = lambda t, t0: t < t0 if time == 'past' else t > t0
        kernel = lambda t, t0: math.exp(-float(t - t0)**2 / (2 * self.sigma**2))
        return self.mget(idtime, self.ns, self.no, self.nv, s_storage, o_storage, v_storage, cond, kernel)

    def get_gt_msg(self, idtime, time='past'):
        return self.get_msg(idtime, time, self.s_storage_gt, self.o_storage_gt, self.v_storage_gt)

    def mset(self, s_msg, o_msg, v_msg, idtime, s_storage, o_storage, v_storage):
        # keep a queue of size maxsize for each id
        # messages are stored in normal space
        # the queue for each id is stored in the order in which the messages were stored
        for s_m, o_m, v_m, (ids, time) in sorted(zip(s_msg, o_msg, v_msg, idtime), key=lambda x: random()):
            if ids not in s_storage:
                s_storage[ids] = []
            if ids not in o_storage:
                o_storage[ids] = []
            if ids not in v_storage:
                v_storage[ids] = []
            s_data = s_m if type(s_m) is not torch.Tensor else s_m.data.cpu()
            o_data = o_m if type(o_m) is not torch.Tensor else o_m.data.cpu()
            v_data = v_m if type(v_m) is not torch.Tensor else v_m.data.cpu()
            s_storage[ids].append((time, s_data))
            o_storage[ids].append((time, o_data))
            v_storage[ids].append((time, v_data))
            if len(s_storage[ids]) > self.maxsize:
                del s_storage[ids][0]
            if len(o_storage[ids]) > self.maxsize:
                del o_storage[ids][0]
            if len(v_storage[ids]) > self.maxsize:
                del v_storage[ids][0]

    def set_msg(self, qs, qo, qv, idtime):
        self.mset(qs, qo, qv, idtime, self.s_storage, self.o_storage, self.v_storage)

    def set_gt_msg(self, s_target, o_target, v_target, idtime):
        s_x = s_target.data.cpu()
        o_x = o_target.data.cpu()
        v_x = v_target.data.cpu()
        self.mset(s_x, o_x, v_x, idtime, self.s_storage_gt, self.o_storage_gt, self.v_storage_gt)
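To make the data flow concrete, here is a made-up round trip through the store (ids, sizes, and timestamps are invented; it assumes the module-level imports the class already relies on, i.e. torch, math, random and Variable, plus a visible CUDA device for the .cuda() call in mget):

import torch

mp = MessagePassing(maxsize=20, w_temporal=0.1, w_spatio=0.1,
                    decay=1.0, sigma=300, ns=4, no=4, nv=4)

idtime = [('video_a', 0.0), ('video_a', 1.0)]
qs, qo, qv = torch.rand(2, 4), torch.rand(2, 4), torch.rand(2, 4)
mp.set_msg(qs, qo, qv, idtime)  # rows are detached to the CPU and queued per id

# A later frame of 'video_a' reads a Gaussian-weighted average of everything
# stored strictly before its own timestamp; mget stacks the per-sample
# results and moves them to the current CUDA device.
s_msg, o_msg, v_msg = mp.get_msg([('video_a', 2.0)], 'past')
print(s_msg.shape)  # torch.Size([1, 4])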
class AsyncTFCriterion(nn.Module, MessagePassing):
    def __init__(self, args):
        memory_size = 20
        w_temporal = 0.1
        w_spatio = 0.1
        memory_decay = 1.0
        sigma = 300
        MessagePassing.__init__(self, memory_size, w_temporal, w_spatio, memory_decay,
                                sigma, args.s_class, args.o_class, args.v_class)
        nn.Module.__init__(self)
        self.msg_n = 5
        self.cross_loss = nn.CrossEntropyLoss()  # for s
        self.bce_loss = nn.BCEWithLogitsLoss()  # for c, o, v
        self.BalanceLabels = BalanceLabels()
        self.winsmooth = 1

    def forward(self, s, o, v, so, ov, vs, ss, oo, vv, so_t, ov_t, vs_t, os_t, vo_t, sv_t,
                s_target, o_target, v_target, id_time, n=1, synchronous=False):
        if o_target.dim() == 1:
            print('converting Nx1 target to NxC')
            o_target = Variable(gtmat(o.shape, o_target.data.long()))
        if v_target.dim() == 1:
            print('converting Nx1 target to NxC')
            v_target = Variable(gtmat(v.shape, v_target.data.long()))
        o_target = o_target.float()
        v_target = v_target.float()
        idtime = list(zip(id_time['id'], id_time['time']))
        s_msg, o_msg, v_msg = self.get_msg(idtime, 'past')
        s_fmsg, o_fmsg, v_fmsg = self.get_msg(idtime, 'future')
        s_loss = self.cross_loss(s, s_target)
        _qs = torch.nn.Softmax(dim=1)(s)
        o_loss = self.bce_loss(o, o_target)
        _qo = torch.nn.Sigmoid()(o)
        v_loss = self.bce_loss(v, v_target)
        _qv = torch.nn.Sigmoid()(v)
        qs_before_softmax = s.clone()