Hi,
I’ve written a CRNN model, which is basically a stack of CNN layers followed by an LSTM and a fully connected (FC) layer. The model was written in PyTorch 0.4.1. I tested it with some toy data, like this:
import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel as par

gpu_dtype = torch.cuda.FloatTensor  # the GPU float type I cast the model and loss to

model = CRNN(rnn_hidden_size=32, bidirectional=True, classes_num=20).type(gpu_dtype)
loss_fn = nn.BCELoss().type(gpu_dtype)
x = torch.randn(128, 1, 31, 64).cuda()
y = torch.zeros(128, 20).cuda()
x = par(model, inputs=x, device_ids=[0])
loss = loss_fn(x, y)
loss.backward()
If I use only one GPU, it works fine. But when I use multiple GPUs, it reports this error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
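For completeness, the failing case is the same call replicated across more than one GPU; a minimal example of that call (the device ids here are just illustrative):

# Same toy test as above, but with the model replicated across two GPUs
# (device ids [0, 1] are only an example).
x = par(model, inputs=x, device_ids=[0, 1])
loss = loss_fn(x, y)
loss.backward()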
If I comment out the 2D batch normalization in Conv_GLU, it works fine on both one and multiple GPUs.
Can anyone spot where the in-place operation happens, and why the model only fails on multiple GPUs when batch normalization is present in the Conv_GLU module?
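For reference, the gating in Conv_GLU is equivalent to torch.nn.functional.glu (split the channels in half, multiply the first half by the sigmoid of the second half). I don’t know whether the manual slicing of the BatchNorm output matters here, but this is the drop-in forward I could try instead (just a sketch, not something I’ve confirmed fixes it):

import torch.nn.functional as F

def forward(self, x):
    # Drop-in replacement for Conv_GLU.forward: F.glu splits the channels
    # in half along dim=1 and returns first_half * sigmoid(second_half),
    # which matches the manual slicing version in the full model below.
    x = self.conv(x)
    x = self.bn(x)
    return F.glu(x, dim=1)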
The entire model is below. Thank you so much for any help.
class Conv_GLU(nn.Module):
    def __init__(self, in_channels, out_channels, stride_freq=1):
        super(Conv_GLU, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=2 * out_channels,
                              kernel_size=(3, 3), stride=(1, stride_freq),
                              padding=(1, 1), bias=False)
        self.bn = nn.BatchNorm2d(num_features=2 * out_channels)
        self.sigmoid = nn.Sigmoid()
        self.out_channels = out_channels

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        # Gated linear unit: the first half of the channels is gated by the
        # sigmoid of the second half.
        x1 = x[:, 0:self.out_channels, :, :]
        x2 = self.sigmoid(x[:, self.out_channels:2 * self.out_channels, :, :])
        x = x1 * x2
        return x
class VggishConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride_freq=1):
        super(VggishConvBlock, self).__init__()
        self.conv1 = Conv_GLU(in_channels, out_channels, stride_freq)
        self.conv2 = Conv_GLU(out_channels, out_channels, stride_freq=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x
class Vggish(nn.Module):
    def __init__(self):
        super(Vggish, self).__init__()
        self.conv_block1 = VggishConvBlock(in_channels=1, out_channels=64, stride_freq=1)    # N, 64, T, 64
        self.conv_block2 = VggishConvBlock(in_channels=64, out_channels=64, stride_freq=2)   # N, 64, T, 32
        self.conv_block3 = VggishConvBlock(in_channels=64, out_channels=128, stride_freq=2)  # N, 128, T, 16
        self.conv_block4 = VggishConvBlock(in_channels=128, out_channels=128, stride_freq=2) # N, 128, T, 8
        self.conv_block5 = VggishConvBlock(in_channels=128, out_channels=256, stride_freq=2) # N, 256, T, 4
        self.conv_block6 = VggishConvBlock(in_channels=256, out_channels=256, stride_freq=2) # N, 256, T, 2

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        x = self.conv_block6(x)
        return x
class Attention(nn.Module):
    def __init__(self, num_classes, num_in_features):
        super(Attention, self).__init__()
        self.bn = nn.BatchNorm1d(num_in_features)
        self.FC1 = nn.Linear(num_in_features, num_classes)
        self.FC2 = nn.Linear(num_in_features, num_classes)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x: N, T, D
        _, T, D = x.size()
        x = x.view(-1, D)     # NT, D
        x = self.bn(x)
        x = x.view(-1, T, D)  # N, T, D
        out = self.FC1(x)
        out = self.sigmoid(out)
        att = self.FC2(x)
        att = self.softmax(att)
        att = torch.clamp(att, 1e-7, 1.0)
        # Attention-weighted pooling over the time axis
        out = out * att
        out = torch.sum(out, dim=1, keepdim=False)        # N, num_classes
        out = out / torch.sum(att, dim=1, keepdim=False)  # N, num_classes
        return out
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, bidirectional=False):
        super(LSTM, self).__init__()
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                           bidirectional=bidirectional, bias=True, batch_first=True)

    def forward(self, x):
        # x is of size (N, T, D)
        self.rnn.flatten_parameters()
        x, _ = self.rnn(x)
        x = x.contiguous()
        return x
class CRNN(nn.Module):
    def __init__(self, rnn_hidden_size=256, bidirectional=True, classes_num=32):
        super(CRNN, self).__init__()
        self.vggish = Vggish()
        self.rnn = LSTM(input_size=512, hidden_size=rnn_hidden_size, bidirectional=bidirectional)
        self.att = Attention(num_classes=classes_num,
                             num_in_features=rnn_hidden_size * 2 if bidirectional else rnn_hidden_size)

    def forward(self, x):
        x = self.vggish(x)                      # N, C, T, F
        x = x.permute(0, 2, 1, 3).contiguous()  # N, T, C, F
        T, C, F = x.size(1), x.size(2), x.size(3)
        x = x.view(-1, T, C * F)                # N, T, CF
        x = self.rnn(x)
        x = self.att(x)
        return x