I created a model that returns a dictionary. However, when I try to train it with DataParallel, the output dictionary cannot be gathered properly:
line 62, in gather_map
return type(out)(map(gather_map, zip(*outputs)))
TypeError: zip argument #1 must support iteration
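For reference, here is a minimal sketch of one way to reproduce this traceback (the model and shapes are invented for illustration): a plain Python number inside the output dict is neither a Tensor nor a dict, so gather falls through to zip(*outputs) and fails.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def forward(self, x):
        return {"logits": x * 2,     # Tensor: gathered fine
                "count": x.size(0)}  # plain int: triggers the TypeError

model = nn.DataParallel(ToyModel().cuda())
out = model(torch.randn(8, 4).cuda())  # with >= 2 GPUs this raises the TypeError during gather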
I realized that since 1.0.0, model outputs that are dictionaries are supported on multiple GPUs, thanks to this:
# From torch/nn/parallel/scatter_gather.py. gather_map is a closure inside
# gather(); target_device and dim come from the enclosing
# gather(outputs, target_device, dim=0).
import torch
from torch.nn.parallel._functions import Gather

def gather_map(outputs):
    out = outputs[0]
    # Tensors are gathered directly onto the target device
    if isinstance(out, torch.Tensor):
        return Gather.apply(target_device, dim, *outputs)
    if out is None:
        return None
    # dicts are gathered key by key, recursively
    if isinstance(out, dict):
        if not all((len(out) == len(d) for d in outputs)):
            raise ValueError('All dicts must have the same number of keys')
        return type(out)(((k, gather_map([d[k] for d in outputs]))
                          for k in out))
    # anything else is assumed to be an iterable (list/tuple/namedtuple)
    # and is gathered element-wise; non-iterables raise the TypeError above
    return type(out)(map(gather_map, zip(*outputs)))
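With that in place, a dict of Tensors is gathered per key. A small sketch (the model and shapes are made up for illustration):

import torch
import torch.nn as nn

class DictModel(nn.Module):
    def forward(self, x):
        h = x * 2
        return {"logits": h, "norm": h.sum(dim=1)}

model = nn.DataParallel(DictModel().cuda())
out = model(torch.randn(8, 4).cuda())
# On 2 GPUs each replica sees a (4, 4) slice; gather concatenates the
# per-replica dicts key by key, so out["logits"] is (8, 4) and
# out["norm"] is (8,), both on the target device.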
However, gather only supports Tensors, dictionaries, and iterables (lists/tuples) of them. Be aware that plain Python numbers are not supported: they fall through to the final zip(*outputs) line and produce the TypeError above.
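A workaround I'd suggest (my own, not from the source above) is to wrap scalars in zero-dimensional Tensors before returning them; Gather then unsqueezes them and returns a vector with one entry per replica. Fixing the hypothetical ToyModel from earlier:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def forward(self, x):
        return {"logits": x * 2,
                # 0-dim Tensor instead of a plain int; gather turns it
                # into a 1-D tensor with one entry per GPU replica
                "count": torch.tensor(x.size(0), device=x.device)}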