Gradients of part of the model (the RNN, in my case) become zero after the model is wrapped with nn.DataParallel

When I use a single GPU, all parameters in the model receive non-zero gradients,
but after I wrap the model with nn.DataParallel, the gradients of part of the model (the RNN, in my case) become zero.

To Reproduce

Steps to reproduce the behavior:

import sys

import torch
from torch import nn
import torch.nn.functional as F
from torchexp.stat import RunningAvgDict
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

import numpy as np
import math

# CUDA device ids to use, given as one comma-separated command-line argument
# (e.g. `0` or `0,1`); defaults to GPU 0 when no argument is passed.
gpus = [int(gpu_id) for gpu_id in (sys.argv[1] if len(sys.argv) > 1 else '0').split(',')]

def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module): Torch module.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.
            If ``m`` has neither parameters nor buffers there is no device
            to match, and ``x`` is returned unchanged.

    Raises:
        TypeError: if ``m`` is not a ``torch.nn.Module``.

    """
    if not isinstance(m, torch.nn.Module):
        raise TypeError(f"expected a torch.nn.Module, got {type(m).__name__}")
    # The first parameter (or, failing that, the first buffer) decides the
    # device. Using a default for `next` avoids leaking StopIteration from
    # modules that hold no parameters.
    # NOTE(review): DataParallel replicas may expose no parameters in some
    # PyTorch releases; verify against the version in use.
    ref = next(m.parameters(), None)
    if ref is None:
        ref = next(m.buffers(), None)
    if ref is None:
        return x
    return x.to(ref.device)

def reset_backward_rnn_state(states):
    """Set backward BRNN states to zeroes (in place).

    Useful when processing sliding windows over the inputs. Every odd index
    along the first dimension (``[1::2]``) is zeroed. This assumes the
    PyTorch layout where that dimension interleaves forward/backward
    directions; confirm for custom state layouts.

    :param states: a state tensor, or a list/tuple of state tensors
        (e.g. an LSTM ``(h, c)`` pair); modified in place
    :return: the same ``states`` object
    """
    if isinstance(states, (list, tuple)):
        for state in states:
            state[1::2] = 0.
    else:
        # A single tensor; a tuple cannot be item-assigned, so this branch
        # must not run for the container case.
        states[1::2] = 0.
    return states

class RNNP(nn.Module):
    """Stack of bidirectional LSTMs, each followed by a linear projection.

    :param int idim: dimension of inputs, default: odim of CNN
    :param int nlayers: Number of layers
    :param int enc_dim: hidden dim of BLSTMP
    :param int proj_dim: proj. dim of BLSTMP
    :param int odim: output of encoder dimension (default: proj_dim)
    """

    def __init__(self, idim, nlayers, enc_dim, proj_dim, odim=None):
        super(RNNP, self).__init__()
        if odim is None:
            odim = proj_dim
        self.enc_dim = enc_dim
        self.proj_dim = proj_dim
        self.odim = odim
        self.nlayers = nlayers

        for i in range(nlayers):
            # Layer 0 reads the input features; later layers read the
            # previous layer's projection.
            in_dim = idim if i == 0 else proj_dim
            # Only the last layer projects to the encoder output size
            # (this also covers nlayers == 1).
            out_dim = odim if i == nlayers - 1 else proj_dim
            setattr(self, f"rnn{i}",
                    nn.LSTM(in_dim, enc_dim, num_layers=1, bidirectional=True, batch_first=True))
            # A bidirectional LSTM emits 2 * enc_dim features per frame.
            setattr(self, f"bt{i}", nn.Linear(2 * enc_dim, out_dim))

    def forward(self, xs_pad, enc_lens, prev_state=None):
        """RNNP forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param enc_lens: valid length of each sequence (list of int or 1-D tensor)
        :param prev_state: per-layer previous RNN states indexed by layer, or None
        :return: batch of hidden state sequences (B, Tmax, odim), the lengths
            moved to the module's device, and the final RNN state of every layer
        :rtype: tuple
        """
        hid_states = []
        # pad_packed_sequence restores this length, so it is loop-invariant.
        total_length = xs_pad.size(1)

        for i in range(self.nlayers):
            xs_pack = pack_padded_sequence(xs_pad, enc_lens, batch_first=True)
            rnn = getattr(self, f"rnn{i}")
            if prev_state is not None:
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[i])
            hid_states.append(states)

            ys_pad, enc_lens = pad_packed_sequence(ys, batch_first=True, total_length=total_length)
            # ys_pad: (B, T, 2 * enc_dim)

            projected = getattr(self, f"bt{i}")(ys_pad.contiguous().view(-1, ys_pad.size(2)))  # (B*T, out_dim)
            xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))

        return xs_pad, to_device(self, enc_lens), hid_states

class BlstmEncoder(nn.Module):
    """CNN front-end followed by a projected BLSTM (RNNP).

    :param int idim: dim of input
    :param int enc_dim: hidden dim for BLSTM
    :param int proj_dim: projection dim for BLSTMP
    :param int odim: encoder output dimension (usually set as proj_dim)
    """

    def __init__(self, idim, enc_dim, proj_dim, odim):
        super(BlstmEncoder, self).__init__()
        self.idim = idim
        self.enc_dim = enc_dim
        self.proj_dim = proj_dim
        self.odim = odim

        self.cnn_model = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
        )
        # Each 2x2 / stride-2 ceil-mode max-pool maps a length L to ceil(L / 2)
        # on both the time and the feature axis. Counting the pools keeps the
        # feature size below and the lengths in `forward` consistent with the
        # actual CNN. This assumes every MaxPool2d here has that configuration.
        self.n_pools = sum(isinstance(layer, nn.MaxPool2d) for layer in self.cnn_model)
        # 128 = out_channels of the last Conv2d, flattened into the features.
        cnn_o_dim = int(self._subsample(idim, self.n_pools)) * 128

        self.nlayers = 1
        self.blstm = RNNP(cnn_o_dim, self.nlayers, enc_dim, proj_dim, odim)

    @staticmethod
    def _subsample(lengths, n_pools):
        """Apply ``ceil(x / 2)`` once per pooling layer; returns int64 array."""
        out = np.asarray(lengths, dtype=np.float32)
        for _ in range(n_pools):
            out = np.ceil(out / 2)
        return out.astype(np.int64)

    def forward(self, xs_pad, ilens, prev_state=None):
        """Encode a padded batch.

        :param torch.Tensor xs_pad: padded features (B, T, idim)
        :param ilens: input lengths (tensor or sequence of numbers)
        :param prev_state: previous per-layer RNN states, or None
        :return: (encoder output, encoder lengths, per-layer RNN states)
        """
        # (B, T, idim) -> (B, 1, T, idim): a single input channel for Conv2d.
        xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), 1, xs_pad.size(2)).transpose(1, 2)
        xs_pad = self.cnn_model(xs_pad)

        if torch.is_tensor(ilens):
            ilens = ilens.cpu().numpy()
        enc_lens = self._subsample(ilens, self.n_pools).tolist()

        # (B, C, T', F') -> (B, T', C * F')
        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))

        out, enc_lens, cur_state = self.blstm(xs_pad, enc_lens, prev_state)

        return out, enc_lens, cur_state

class MultiHeadBLSTM(nn.Module):
    """BlstmEncoder followed by a linear output head.

    :param int idim: input feature dimension (default 20)
    :param int odim: output size of the linear head (default 10)
    :param int enc_o_dim: encoder hidden / projection / output size (default 128)
    """

    def __init__(self, idim=20, odim=10, enc_o_dim=128):
        super(MultiHeadBLSTM, self).__init__()

        self.idim = idim
        self.odim = odim

        self.encoder = BlstmEncoder(self.idim, enc_o_dim, enc_o_dim, enc_o_dim)
        self.head = nn.Linear(enc_o_dim, self.odim)

    def device(self):
        """Return the device of the model's first parameter."""
        return next(self.parameters()).device

    def forward(self, xs_pad, ilens):
        """Run encoder and head.

        :param torch.Tensor xs_pad: padded input features, batch first
        :param torch.Tensor ilens: per-utterance input lengths
        :return: (head output, encoder lengths)
        :raises ValueError: if the batch sizes of ``xs_pad`` and ``ilens`` differ
        """
        if xs_pad.size(0) != ilens.size(0):
            raise ValueError("Batch size mismatch")

        # Put data to the model's device. ilens is left where it is: the
        # encoder converts it to a NumPy array on the CPU anyway.
        xs_pad = to_device(self, xs_pad)

        enc, enc_lens, _ = self.encoder(xs_pad, ilens)
        out = self.head(enc)

        return out, enc_lens

def run_batch(asr_model, ctc_loss, x, ilens, y_true, olens):
    """Forward one batch, compute the CTC loss and back-propagate it.

    :param asr_model: the model, optionally wrapped in ``nn.DataParallel``
    :param ctc_loss: an ``nn.CTCLoss`` instance
    :param x: padded input features
    :param ilens: input lengths
    :param y_true: CTC targets
    :param olens: target lengths
    :return: dict with the scalar loss under ``'loss'``
    """
    pred, enc_lens = asr_model(x, ilens)
    # DataParallel keeps the wrapped model in `.module`; checking the wrapper
    # type (rather than the GPU count) works for every configuration.
    model = asr_model.module if isinstance(asr_model, nn.DataParallel) else asr_model
    olens = to_device(model, olens)
    pred = F.log_softmax(pred, dim=-1)

    # nn.CTCLoss takes (log_probs (T, B, C), targets, input_lengths, target_lengths).
    loss = ctc_loss(pred.transpose(0, 1).contiguous(), y_true, enc_lens, olens)
    print('ctc loss: ', loss)
    # Populate .grad on the model parameters so they can be inspected.
    loss.backward()

    info = {'loss': loss.item()}

    return info

def run_task(asr_model, ctc_loss, x, ilens, ys, olens):
    """Run one training step and print the gradient of every parameter.

    :param asr_model: the model, optionally wrapped in ``nn.DataParallel``
    :param ctc_loss: an ``nn.CTCLoss`` instance
    :param x: padded input features
    :param ilens: input lengths
    :param ys: CTC targets
    :param olens: target lengths
    :return: the info dict produced by ``run_batch``
    """
    info = run_batch(asr_model, ctc_loss, x, ilens, ys, olens)

    # DataParallel keeps the wrapped model in `.module`; report its
    # parameters so names match the single-GPU run.
    model = asr_model.module if isinstance(asr_model, nn.DataParallel) else asr_model
    num_nonzero = 0
    for idx, (name, param) in enumerate(model.named_parameters()):
        print('grad of param {}, {}: '.format(idx, name), param.grad)
        if param.grad is not None and bool(param.grad.ne(0).any()):
            num_nonzero += 1
    print('num of params with non-zero grad: ', num_nonzero)

    return info

if __name__ == '__main__':
    # nn.DataParallel requires the module's parameters to live on
    # device_ids[0], so place the model on the first requested GPU.
    asr_model = MultiHeadBLSTM().cuda(gpus[0])
    if len(gpus) > 1:
        asr_model = nn.DataParallel(asr_model, device_ids=gpus)

    ctc_loss = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

    # Inputs are created on the CPU; MultiHeadBLSTM.forward moves the
    # features to the model's device.
    x = torch.rand(4, 20, 20)  # (batch, time, feature)
    ilens = torch.tensor([20, 20, 20, 20])
    ys = [torch.tensor([1, 1, 1]),
          torch.tensor([1, 1, 1, 1]),
          torch.tensor([1, 1, 1, 1, 1]),
          torch.tensor([1, 1, 1, 1, 1, 1])]
    olens = torch.tensor([3, 4, 5, 6])

    # nn.CTCLoss accepts all targets concatenated into one 1-D tensor, with
    # `olens` giving the length of each utterance's target.
    y_true = torch.cat(ys)

    run_task(asr_model, ctc_loss, x, ilens, y_true, olens)

If I run `python3 <script>.py 0` (GPU devices = [0]), everything is fine;
however, if I run `python3 <script>.py 0,1` (GPU devices = [0, 1]), the RNN receives zero gradients.

Expected behavior

All parameters of both single- and multi-gpu models should have non-zero gradients.


PyTorch version: 1.4.1
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: Arch Linux
GCC version: (Arch Linux 9.3.0-1) 9.3.0
CMake version: version 3.17.0

Python version: 3.8
Is CUDA available: Yes
CUDA runtime version: 10.2.89
GPU models and configuration: 
GPU 1: GeForce GTX 1060 6GB

Nvidia driver version: 440.64
cuDNN version: /usr/lib/

Versions of relevant libraries:
[pip3] numpy==1.18.2
[pip3] torch==1.4.1
[pip3] torchexp==0.1.0
[pip3] torchvision==0.5.0a0+c81ac87
[conda] Could not collect

There is a bug in PyTorch 1.4: after switching to a nightly build, the issue is resolved.