OOM when resuming with DistributedDataParallel

I’m hitting an out-of-memory (OOM) error when trying to resume training a seq2seq model (with batch norm and dropout) under DistributedDataParallel.

Although I can start training from scratch and run for hours on multiple GPUs, I can only resume with half the number of GPUs I used at the start of training.
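
For context, here is roughly how I restore state before resuming. This is a minimal sketch with placeholder names (`load_checkpoint`, `checkpoint_path`, the checkpoint keys), not my exact code. One thing I’m wondering about is the `map_location` argument: as far as I understand, `torch.load` without it restores each CUDA tensor onto the GPU it was saved from, so resuming with a different GPU count could stack several ranks’ checkpoints onto the same device.

```python
import torch

def load_checkpoint(checkpoint_path, model, optimizer, rank):
    # Remap every storage onto this rank's GPU. Without map_location,
    # torch.load puts tensors back on the device they were saved from,
    # which no longer matches when the GPU count changes.
    checkpoint = torch.load(
        checkpoint_path,
        map_location=lambda storage, loc: storage.cuda(rank))
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['iteration']
```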

The exact error is below. Every time I ran it, it crashed while computing the Bernoulli mask for dropout…

```
THCudaCheck FAIL file=/opt/pytorch/pytorch/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory
/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/_utils.py:86: UserWarning: 'async' is deprecated; use 'non_blocking'
  warnings.warn("'async' is deprecated; use 'non_blocking'")
Traceback (most recent call last):
  File "train.py", line 272, in <module>
    args.warm_start, args.n_gpus, args.rank, args.group_name, hparams)
  File "train.py", line 197, in train
    y_pred = model(x)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/workspace/distributed.py", line 101, in forward
    return self.module(*inputs, **kwargs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/workspace/model.py", line 512, in forward
    mel_outputs_postnet = self.postnet(mel_outputs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/workspace/model.py", line 144, in forward
    x = self.dropout(F.tanh(self.convolutions[i](x)))
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/dropout.py", line 46, in forward
    return F.dropout(input, self.p, self.training, self.inplace)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/functional.py", line 536, in dropout
    return _functions.dropout.Dropout.apply(input, p, training, inplace)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/_functions/dropout.py", line 41, in forward
    ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
RuntimeError: cuda runtime error (2) : out of memory at /opt/pytorch/pytorch/aten/src/THC/generic/THCStorage.cu:58
```
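
For what it’s worth, the failing frame in `torch/nn/_functions/dropout.py` is just filling an input-sized Bernoulli noise mask, so memory seems to be essentially exhausted before dropout even runs. Roughly paraphrasing what that code path does:

```python
# Dropout needs roughly one extra buffer the same size as its input:
# it allocates a noise mask, samples it, and rescales it in place.
noise = input.new().resize_as_(input)
noise.bernoulli_(1 - p).div_(1 - p)  # the line shown in the traceback
output = input * noise
```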

Any thoughts on this, @SimonW?