When I train on a single GPU, every parameter in the model receives a non-zero gradient;
but after wrapping the model with nn.DataParallel, the gradients of part of the model (the RNN in my case) become zero.
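The pattern boils down to the sketch below (illustrative only: MyModel stands for any module containing an nn.LSTM; x and ilens are arbitrary inputs, and the full script further down uses nn.CTCLoss):

model = nn.DataParallel(MyModel().cuda(), device_ids=[0, 1])
out, _ = model(x, ilens)
loss = out.sum()  # placeholder loss for the sketch; the real script uses nn.CTCLoss
loss.backward()
for name, p in model.module.named_parameters():
    print(name, None if p.grad is None else p.grad.abs().sum().item())

After backward, the entries belonging to the LSTM print 0.0 while the rest are non-zero.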
To Reproduce
Steps to reproduce the behavior:
test.py:
import sys

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

gpus = list(map(int, sys.argv[1].split(',')))

def to_device(m, x):
    """Send a tensor to the device of the given module.

    Args:
        m (torch.nn.Module): Torch module.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located on the same device as the module.
    """
    assert isinstance(m, torch.nn.Module)
    device = next(m.parameters()).device
    return x.to(device)

def reset_backward_rnn_state(states):
    """Set backward BRNN states to zero.

    Useful when processing sliding windows over the inputs.
    """
    if isinstance(states, (list, tuple)):
        for state in states:
            state[1::2] = 0.
    else:
        states[1::2] = 0.
    return states

class RNNP(nn.Module):
    """RNN with projection layer module.

    :param int idim: dimension of inputs (default: odim of the CNN)
    :param int nlayers: number of layers
    :param int enc_dim: hidden dim of the BLSTMP
    :param int proj_dim: projection dim of the BLSTMP
    :param int odim: output dimension of the encoder (default: proj_dim)
    """

    def __init__(self, idim, nlayers, enc_dim, proj_dim, odim):
        super(RNNP, self).__init__()
        self.enc_dim = enc_dim
        self.proj_dim = proj_dim
        self.odim = odim
        self.nlayers = nlayers
        self.rnn0 = nn.LSTM(idim, enc_dim, num_layers=1, bidirectional=True, batch_first=True)
        self.bt0 = nn.Linear(2 * enc_dim, proj_dim)
        for i in range(1, self.nlayers):
            rnn = nn.LSTM(proj_dim, enc_dim, num_layers=1, bidirectional=True, batch_first=True)
            setattr(self, f"rnn{i}", rnn)
            if i == self.nlayers - 1:
                setattr(self, f"bt{i}", nn.Linear(2 * enc_dim, odim))
            else:
                setattr(self, f"bt{i}", nn.Linear(2 * enc_dim, proj_dim))

    def forward(self, xs_pad, enc_lens, prev_state=None):
        """RNNP forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, cnn_o_dim)
        :param list of int enc_lens: lengths of the input sequences
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        hid_states = []
        for i in range(self.nlayers):
            total_length = xs_pad.size(1)
            xs_pack = pack_padded_sequence(xs_pad, enc_lens, batch_first=True)
            rnn = getattr(self, f"rnn{i}")
            rnn.flatten_parameters()
            if prev_state is not None:
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[i])
            hid_states.append(states)
            ys_pad, enc_lens = pad_packed_sequence(ys, batch_first=True, total_length=total_length)
            # ys_pad: (B, T, 2 * enc_dim)
            projected = getattr(self, f"bt{i}")(ys_pad.contiguous().view(-1, ys_pad.size(2)))  # (B * T, proj_dim)
            xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))
        return xs_pad, to_device(self, enc_lens), hid_states

class BlstmEncoder(nn.Module):
    """CNN front-end followed by a projected BLSTM (RNNP).

    :param int idim: dim of the input features
    :param int enc_dim: hidden dim of the BLSTM
    :param int proj_dim: projection dim of the BLSTMP
    :param int odim: encoder output dimension (usually set to proj_dim)
    """

    def __init__(self, idim, enc_dim, proj_dim, odim):
        super(BlstmEncoder, self).__init__()
        self.idim = idim
        self.enc_dim = enc_dim
        self.proj_dim = proj_dim
        self.odim = odim
        self.cnn_model = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
        )
        # The pooling halves the feature axis; the 128 CNN channels are flattened into it.
        cnn_o_dim = np.ceil(np.array(idim, dtype=np.float32) / 2)
        cnn_o_dim = int(cnn_o_dim) * 128
        self.nlayers = 1
        self.blstm = RNNP(cnn_o_dim, self.nlayers, enc_dim, proj_dim, odim)

    def forward(self, xs_pad, ilens, prev_state=None):
        # (B, T, idim) -> (B, 1, T, idim) for the 2D CNN
        xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), 1, xs_pad.size(2)).transpose(1, 2)
        xs_pad = self.cnn_model(xs_pad)
        if torch.is_tensor(ilens):
            ilens = ilens.cpu().numpy()
        else:
            ilens = np.array(ilens, dtype=np.float32)
        # Subsample the sequence lengths used for pack_padded_sequence
        enc_lens = np.array(np.ceil(ilens / 2), dtype=np.int64)
        enc_lens = np.array(np.ceil(np.array(enc_lens, dtype=np.float32) / 2), dtype=np.int64).tolist()
        # (B, C, T', F') -> (B, T', C * F')
        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
        out, enc_lens, cur_state = self.blstm(xs_pad, enc_lens, prev_state)
        return out, enc_lens, cur_state

class MultiHeadBLSTM(nn.Module):
    def __init__(self):
        super(MultiHeadBLSTM, self).__init__()
        self.idim = 20
        self.odim = 10
        enc_o_dim = 128
        self.encoder = BlstmEncoder(self.idim, enc_o_dim, enc_o_dim, enc_o_dim)
        self.head = nn.Linear(enc_o_dim, self.odim)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, xs_pad, ilens):
        assert xs_pad.size(0) == ilens.size(0), "Batch size mismatch"
        # Move the data to the device of this (possibly replicated) module
        xs_pad = to_device(self, xs_pad)
        ilens = to_device(self, ilens)
        enc, enc_lens, _ = self.encoder(xs_pad, ilens)
        out = self.head(enc)
        return out, enc_lens

def run_batch(asr_model, ctc_loss, x, ilens, y_true, olens):
    pred, enc_lens = asr_model(x, ilens)
    if len(gpus) == 1:
        olens = to_device(asr_model, olens)
    else:
        olens = to_device(asr_model.module, olens)
    pred = F.log_softmax(pred, dim=-1)
    # CTCLoss expects (T, B, C) log-probabilities
    loss = ctc_loss(pred.transpose(0, 1).contiguous(),
                    y_true.cuda().to(dtype=torch.long),
                    enc_lens.cpu().to(dtype=torch.long),
                    olens.cpu().to(dtype=torch.long))
    print('ctc loss: ', loss)
    info = {'loss': loss.item()}
    loss.backward()
    return info


def run_task(asr_model, ctc_loss, x, ilens, ys, olens):
    info = run_batch(asr_model, ctc_loss, x, ilens, ys, olens)
    # torch.set_printoptions(profile="full")
    # DataParallel keeps the real model in .module, so unwrap it before inspecting gradients
    model = asr_model if len(gpus) == 1 else asr_model.module
    count = 0
    for n, p in model.named_parameters():
        print(f'grad of param {count}, {n}: ', p.grad)
        count += 1
    print('num of params: ', count)

if __name__ == '__main__':
    asr_model = MultiHeadBLSTM().cuda()
    if len(gpus) > 1:
        asr_model = nn.DataParallel(asr_model, device_ids=gpus)
    ctc_loss = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    x = torch.rand(4, 20, 20)
    ilens = torch.tensor([20, 20, 20, 20])
    ys = [torch.tensor([1, 1, 1]),
          torch.tensor([1, 1, 1, 1]),
          torch.tensor([1, 1, 1, 1, 1]),
          torch.tensor([1, 1, 1, 1, 1, 1])]
    olens = torch.tensor([3, 4, 5, 6])
    y_true = torch.cat(ys)
    run_task(asr_model, ctc_loss, x, ilens, y_true, olens)
If I run python3 test.py 0 (GPU devices = [0]), everything is fine;
however, if I run python3 test.py 0,1 (GPU devices = [0, 1]), the RNN receives no gradient.
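A quick way to see which parameters are affected is to sum the absolute gradient of each parameter after the backward pass, e.g. (a condensed sketch of what run_task already prints):

model = asr_model.module if len(gpus) > 1 else asr_model
for n, p in model.named_parameters():
    # None would mean the parameter never received a gradient at all
    print(n, None if p.grad is None else p.grad.abs().sum().item())

With gpus = [0] every sum is non-zero; with gpus = [0, 1] the encoder.blstm.rnn0.* parameters all come out as 0.0.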
Expected behavior
All parameters should receive non-zero gradients in both the single-GPU and the multi-GPU (DataParallel) configurations.
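In other words, I would expect a check along these lines (illustrative sketch) to pass for both python3 test.py 0 and python3 test.py 0,1:

model = asr_model.module if isinstance(asr_model, nn.DataParallel) else asr_model
for n, p in model.named_parameters():
    assert p.grad is not None, f'no gradient for {n}'
    assert p.grad.abs().sum().item() > 0, f'all-zero gradient for {n}'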
Environment
PyTorch version: 1.4.1
Is debug build: No
CUDA used to build PyTorch: 10.2
OS: Arch Linux
GCC version: (Arch Linux 9.3.0-1) 9.3.0
CMake version: version 3.17.0
Python version: 3.8
Is CUDA available: Yes
CUDA runtime version: 10.2.89
GPU models and configuration:
GPU 0: GeForce GTX TITAN X
GPU 1: GeForce GTX 1060 6GB
Nvidia driver version: 440.64
cuDNN version: /usr/lib/libcudnn.so.7.6.5
Versions of relevant libraries:
[pip3] numpy==1.18.2
[pip3] torch==1.4.1
[pip3] torchexp==0.1.0
[pip3] torchvision==0.5.0a0+c81ac87
[conda] Could not collect