yhz_yhz
(yhz yhz)
October 19, 2020, 4:29pm
I used LSTMCell for the decoders, and my decoder module looks like this:
decoders = nn.ModuleList([Decoder(args, gpu) for i in range(args.max_len)])
I changed it to run in parallel using:
decoders = nn.parallel.DistributedDataParallel(decoders, device_ids=[gpu])
Then, when I wrote this:
output = decoders[i](input)
the following error was raised:
TypeError: 'DistributedDataParallel' object does not support indexing
How can I fix this?
mrshenli
(Shen Li)
October 19, 2020, 5:53pm
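Hey @yhz_yhz, if you would like to access the original module that you passed to the DistributedDataParallel constructor, you can use decoders.module (e.g., decoders.module[i](input)). See the excerpt from the DistributedDataParallel constructor below, which stores the wrapped module as self.module.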
# excerpt from DistributedDataParallel.__init__
output_device = device_ids[0]
self.output_device = _get_device_index(output_device, True)
if process_group is None:
    self.process_group = _get_default_group()
else:
    self.process_group = process_group
self.dim = dim
self.module = module  # the wrapped module, reachable as ddp_model.module
self.device = list(self.module.parameters())[0].device
self.broadcast_buffers = broadcast_buffers
self.find_unused_parameters = find_unused_parameters
self.require_backward_grad_sync = True
self.require_forward_param_sync = True
self.ddp_join_enabled = False
self.gradient_as_bucket_view = gradient_as_bucket_view
if hasattr(module, '_ddp_params_and_buffers_to_ignore'):
    self.parameters_to_ignore = module._ddp_params_and_buffers_to_ignore
else:
    ...  # excerpt ends here
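Here is a minimal sketch of the fix, plus one alternative structure. It assumes a single-process, CPU-only run with the gloo backend, so it works without GPUs or a launcher, and it uses nn.LSTMCell as a stand-in for the thread's Decoder class. The sizes and the DecoderStack name are hypothetical, not taken from the original code.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# One-process, CPU-only "gloo" group so the sketch runs on a single machine.
# A real job launches one process per GPU and passes device_ids=[gpu].
store_path = os.path.join(tempfile.mkdtemp(), "ddp_store")
dist.init_process_group("gloo", init_method=f"file://{store_path}",
                        rank=0, world_size=1)

max_len, input_size, hidden_size = 3, 4, 8  # hypothetical sizes
x = torch.randn(2, input_size)              # batch of 2

# 1) Wrapping the ModuleList directly, as in the question.
#    nn.LSTMCell stands in for the thread's Decoder class.
decoders = nn.ModuleList(
    [nn.LSTMCell(input_size, hidden_size) for _ in range(max_len)])
ddp_decoders = DDP(decoders)  # CPU module, so no device_ids
# ddp_decoders[0] raises TypeError; the wrapped ModuleList is ddp_decoders.module
h, c = ddp_decoders.module[0](x)
print(h.shape)  # torch.Size([2, 8])


# 2) Alternative: move the per-step loop into a parent module's forward()
#    and wrap the parent, so every call goes through DDP's own forward().
class DecoderStack(nn.Module):  # hypothetical name
    def __init__(self, max_len, input_size, hidden_size):
        super().__init__()
        self.decoders = nn.ModuleList(
            [nn.LSTMCell(input_size, hidden_size) for _ in range(max_len)])

    def forward(self, x):
        state, outputs = None, []
        for decoder in self.decoders:   # equivalent to decoders[i] for each i
            state = decoder(x, state)   # LSTMCell returns (h, c)
            outputs.append(state[0])
        return torch.stack(outputs, dim=1)  # (batch, max_len, hidden_size)


ddp_stack = DDP(DecoderStack(max_len, input_size, hidden_size))
out = ddp_stack(x)
out.sum().backward()  # DDP all-reduces gradients across ranks during backward
print(out.shape)  # torch.Size([2, 3, 8])

dist.destroy_process_group()
```

Part 1 is the decoders.module fix. Part 2 is worth considering if you train through the decoders. As far as I understand DDP, the gradient synchronization is prepared inside DDP's forward(). Calling decoders.module[i](input) directly skips that step, so those gradients may not be averaged across processes.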