Multiple GPUs cause an error with BatchNorm2d

Hi,

I’ve written a CRNN model, which is basically a stack of CNN layers followed by an LSTM and an FC layer, in PyTorch 0.4.1. I tested the model with some toy data, like this:

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel as par

gpu_dtype = torch.cuda.FloatTensor  # the CUDA float type I use throughout

model = CRNN(rnn_hidden_size = 32, bidirectional = True, classes_num = 20).type(gpu_dtype)
loss_fn = nn.BCELoss().type(gpu_dtype)
x = torch.randn(128, 1, 31, 64).cuda()
y = torch.zeros(128, 20).cuda()
x = par(model, inputs = x, device_ids = [0])
loss = loss_fn(x, y)
loss.backward()

With a single GPU it works fine, but when I use multiple GPUs, it reports this error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
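
For reference, the multi-GPU run is identical except for device_ids (assuming two GPUs here, ids 0 and 1):

x = par(model, inputs = x, device_ids = [0, 1])
loss = loss_fn(x, y)
loss.backward()   # this backward call raises the RuntimeError above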

If I comment out the 2D batch normalization in Conv_GLU, it works fine with both one GPU and multiple GPUs.
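
Concretely, the only change that makes multi-GPU training run is skipping self.bn in Conv_GLU.forward:

def forward(self, x):
    x = self.conv(x)
    # x = self.bn(x)   # with this line commented out, multiple GPUs work fine
    x1 = x[:, 0 : self.out_channels, :, :]
    x2 = self.sigmoid(x[:, self.out_channels : 2 * self.out_channels, :, :])
    x = x1 * x2
    return x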

Can anyone spot where an in-place operation might be happening, and why it only fails on multiple GPUs when batch normalization is present in the Conv_GLU module?

The entire model is below. Thank you so much for any help.

class Conv_GLU(nn.Module):
    def __init__(self, in_channels, out_channels, stride_freq = 1):
        super(Conv_GLU, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=2*out_channels,
                              kernel_size=(3, 3), stride=(1, stride_freq),
                              padding=(1, 1), bias=False)

        self.bn = nn.BatchNorm2d(num_features = 2*out_channels)
        self.sigmoid = nn.Sigmoid()
        self.out_channels = out_channels

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        # GLU: the first half of the channels is gated by the sigmoid of the second half
        x1 = x[:, 0 : self.out_channels, :, :]
        x2 = self.sigmoid(x[:, self.out_channels : 2 * self.out_channels, :, :])
        x = x1 * x2

        return x

class VggishConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride_freq = 1):
        super(VggishConvBlock, self).__init__()

        self.conv1 = Conv_GLU(in_channels, out_channels, stride_freq)
        self.conv2 = Conv_GLU(out_channels, out_channels, stride_freq = 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)

        return x

class Vggish(nn.Module):
    def __init__(self):
        super(Vggish, self).__init__()

        self.conv_block1 = VggishConvBlock(in_channels=1, out_channels=64, stride_freq = 1)    # N, 64, T, 64
        self.conv_block2 = VggishConvBlock(in_channels=64, out_channels=64, stride_freq = 2)   # N, 64, T, 32
        self.conv_block3 = VggishConvBlock(in_channels=64, out_channels=128, stride_freq = 2)  # N, 128, T, 16
        self.conv_block4 = VggishConvBlock(in_channels=128, out_channels=128, stride_freq = 2) # N, 128, T, 8
        self.conv_block5 = VggishConvBlock(in_channels=128, out_channels=256, stride_freq = 2) # N, 256, T, 4
        self.conv_block6 = VggishConvBlock(in_channels=256, out_channels=256, stride_freq = 2) # N, 256, T, 2

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        x = self.conv_block6(x)
        return x

class Attention(nn.Module):
    def __init__(self, num_classes, num_in_features):
        super(Attention, self).__init__()
        self.bn = nn.BatchNorm1d(num_in_features)
        self.FC1 = nn.Linear(num_in_features, num_classes)
        self.FC2 = nn.Linear(num_in_features, num_classes)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim = -1)

    def forward(self, x):
        # x: N, T, D
        _, T, D = x.size()
        x = x.view(-1, D)     # NT, D
        x = self.bn(x)
        x = x.view(-1, T, D)  # N, T, D
        out = self.FC1(x)
        out = self.sigmoid(out)
        att = self.FC2(x)
        att = self.softmax(att)
        att = torch.clamp(att, 1e-7, 1.0)
        out = out * att
        out = torch.sum(out, dim = 1, keepdim = False)        # N, num_classes
        out = out / torch.sum(att, dim = 1, keepdim = False)  # N, num_classes

        return out

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, bidirectional=False):
        super(LSTM, self).__init__()
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                           bidirectional=bidirectional, bias=True, batch_first = True)

    def forward(self, x):
        # x is of size (N, T, D)
        self.rnn.flatten_parameters()
        x, _ = self.rnn(x)
        x = x.contiguous()

        return x

class CRNN(nn.Module):
    def __init__(self, rnn_hidden_size = 256, bidirectional = True, classes_num = 32):
        super(CRNN, self).__init__()

        self.vggish = Vggish()
        self.rnn = LSTM(input_size = 512, hidden_size = rnn_hidden_size, bidirectional = bidirectional)
        self.att = Attention(num_classes = classes_num,
                             num_in_features = rnn_hidden_size * 2 if bidirectional else rnn_hidden_size)

    def forward(self, x):
        x = self.vggish(x)                      # N, C, T, F
        x = x.permute(0, 2, 1, 3).contiguous()  # N, T, C, F
        T, C, F = x.size(1), x.size(2), x.size(3)
        x = x.view(-1, T, C * F)                # N, T, CF

        x = self.rnn(x)
        x = self.att(x)
        return x