I have a Variational Autoencoder model with the following forward function:
```python
def forward(self, x):
    y = self.encoder(x)
    means = self.linear_means(y)
    log_var = self.linear_log_var(y)
    z = reparameterize(means, log_var)
    recon_x = self.decoder(z)
    return VaeOutput(output=recon_x, mean=means, log_var=log_var)
```
`VaeOutput` is a dataclass:

```python
from dataclasses import dataclass

import torch


@dataclass
class VaeOutput:
    output: torch.Tensor
    log_var: torch.Tensor
    mean: torch.Tensor
```
When I train the model on a single GPU everything runs fine, but whenever I wrap it in `nn.DataParallel()` for multi-GPU training I get an error saying that the `VaeOutput` object is not iterable:
```
Traceback (most recent call last):
  File "train.py", line 135, in <module>
    main()
  File "train.py", line 131, in main
    model = model_runner.train(model, train_loader, learning_config, val_loader, feature_config)
  File "/buzz-based-anomaly/utils/model_runner.py", line 375, in train
    return self._train(model, train_dataloader, train_config, self._train_step, self._val_step,
  File "/buzz-based-anomaly/utils/model_runner.py", line 455, in _train
    train_epoch_loss = train_step_func(model, train_dataloader, optimizer, experiment, epoch, log_interval)
  File "/buzz-based-anomaly/utils/model_runner.py", line 671, in _train_step
    model_output = model(batch)
  File "/root/miniconda3/envs/pytorch-cuda-env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/buzz-based-anomaly/utils/sm_data_parallel.py", line 10, in forward
    return self.model(*input_data)
  File "/root/miniconda3/envs/pytorch-cuda-env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/envs/pytorch-cuda-env/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 169, in forward
    return self.gather(outputs, self.output_device)
  File "/root/miniconda3/envs/pytorch-cuda-env/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 181, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/root/miniconda3/envs/pytorch-cuda-env/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 78, in gather
    res = gather_map(outputs)
  File "/root/miniconda3/envs/pytorch-cuda-env/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 73, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: 'VaeOutput' object is not iterable
```
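The last frame points at `DataParallel`'s gather step: after each GPU replica runs `forward`, PyTorch recursively merges the per-replica results. As far as I can tell from the PyTorch 1.9-era source matching these paths, that recursion handles tensors, `None` and dicts explicitly, and otherwise falls back to zipping the replica outputs as if they were sequences, which is exactly where a plain dataclass fails. The pure-Python model below is a simplified sketch of that recursion, not PyTorch's actual code: `FakeTensor` is a hypothetical stand-in for a per-device tensor, and concatenating its rows stands in for `Gather.apply`. Other PyTorch versions may add further branches (for example for namedtuples).

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class FakeTensor:
    """Hypothetical stand-in for one replica's output tensor (a list of batch rows)."""
    rows: List[Any]


def gather_map(outputs):
    """Simplified model of the recursion in torch.nn.parallel.scatter_gather.gather.

    `outputs` holds one forward() result per replica (one per GPU).
    """
    out = outputs[0]
    if isinstance(out, FakeTensor):
        # PyTorch calls Gather.apply here, concatenating along `dim` on the output device.
        return FakeTensor([row for o in outputs for row in o.rows])
    if out is None:
        return None
    if isinstance(out, dict):
        # Dicts are merged key by key, so a dict of tensors gathers cleanly.
        return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
    # Fallback: treat `out` as a sequence and zip the replicas element-wise.
    # zip() calls iter() on each replica's output, which fails for a plain dataclass.
    return type(out)(map(gather_map, zip(*outputs)))


@dataclass
class VaeOutput:
    output: Any
    log_var: Any
    mean: Any


def replica_output(i):
    """Pretend replica `i` returned a one-row batch wrapped in the dataclass."""
    t = FakeTensor([i])
    return VaeOutput(output=t, log_var=t, mean=t)


def replica_dict(i):
    """Same result, but returned as a plain dict."""
    t = FakeTensor([i])
    return {"output": t, "log_var": t, "mean": t}


if __name__ == "__main__":
    try:
        gather_map([replica_output(0), replica_output(1)])
    except TypeError as exc:
        print("dataclass:", exc)  # the same kind of TypeError as in the traceback
    merged = gather_map([replica_dict(0), replica_dict(1)])
    print("dict:", merged["mean"].rows)  # prints: dict: [0, 1]
```

In this model, a dict (or a tuple) of tensors survives gathering, while the dataclass reaches the zip fallback and raises.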
How should I approach this kind of problem? Is it possible to use a custom model output class with multi-GPU (`nn.DataParallel`) training?
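One possible approach, sketched under the assumption that the rest of the training code should keep receiving a `VaeOutput`: let each replica return a dict (which `DataParallel` can gather) and rebuild the dataclass after gathering. `DictOutputAdapter`, `ParallelVae` and `TinyVae` are hypothetical names introduced only for this illustration; `TinyVae` is a toy stand-in for the real encoder/decoder model.

```python
from dataclasses import dataclass, fields

import torch
from torch import nn


@dataclass
class VaeOutput:
    output: torch.Tensor
    log_var: torch.Tensor
    mean: torch.Tensor


class DictOutputAdapter(nn.Module):
    """Hypothetical adapter: runs the model on one replica and returns a plain dict."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        out = self.model(*args, **kwargs)
        # Shallow conversion on purpose: dataclasses.asdict() deep-copies field values,
        # and deep-copying non-leaf tensors that are part of the autograd graph raises.
        return {f.name: getattr(out, f.name) for f in fields(out)}


class ParallelVae(nn.Module):
    """Hypothetical wrapper: DataParallel over the adapter, then rebuild the dataclass."""

    def __init__(self, model: nn.Module, device_ids=None):
        super().__init__()
        # On a multi-GPU machine, `model` must already live on device_ids[0] (e.g. cuda:0).
        self.parallel = nn.DataParallel(DictOutputAdapter(model), device_ids=device_ids)

    def forward(self, *args, **kwargs):
        gathered = self.parallel(*args, **kwargs)  # dict of tensors on the output device
        return VaeOutput(**gathered)


class TinyVae(nn.Module):
    """Toy stand-in for the real VAE; only its return type matters here."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 2)

    def forward(self, x):
        h = self.encoder(x)
        return VaeOutput(output=h, log_var=torch.zeros_like(h), mean=h)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ParallelVae(TinyVae().to(device))
    result = model(torch.randn(8, 4, device=device))
    print(type(result).__name__, result.mean.shape)
```

Without any GPU, `nn.DataParallel` simply calls the wrapped module, so the sketch also runs on CPU. Note that parameter names in `state_dict()` gain a `parallel.module.model.` prefix with this wrapper. A simpler alternative is to have `forward` return a dict or tuple directly; the adapter only exists to keep `VaeOutput` as the type the training loop sees.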