Implementing an LSTM with LSTMCell: problem with DataParallel on multiple GPUs due to an in-place operation

I’m trying to implement a multi-layer LSTM out of nn.LSTMCell, carrying the hidden and cell states across forward calls in a registered buffer. Here is my implementation:

import torch
import torch.nn as nn

class lstm(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, n_layers, batch_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.n_layers = n_layers
        self.embed = nn.Linear(input_size, hidden_size)
        self.lstm = nn.ModuleList([nn.LSTMCell(hidden_size, hidden_size) for _ in range(n_layers)])
        self.output = nn.Sequential(
                nn.Linear(hidden_size, output_size),
                #nn.BatchNorm1d(output_size),
                nn.Tanh())
        # One (h, c) pair per layer, stored as a single buffer of shape
        # (n_layers, 2, batch_size, hidden_size) so that DataParallel
        # broadcasts it to every replica.
        self.register_buffer('hidden', torch.zeros(n_layers, 2, batch_size, hidden_size))

    def init_hidden(self):
        # Reset the stored states in place at the start of a new sequence.
        self.hidden.fill_(0)

    def forward(self, input):
        embedded = self.embed(input.view(-1, self.input_size))
        h_in = embedded
        for i in range(self.n_layers):
            hc = torch.unbind(self.hidden[i])   # (h, c) views for layer i
            lstm_out = self.lstm[i](h_in, hc)   # next (h, c) for layer i
            lstm_out = torch.stack(lstm_out)
            self.hidden[i].copy_(lstm_out)      # <-- the offending in-place write
            h_in = self.hidden[i][0]            # feed h into the next layer

        return self.output(h_in)
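
For context, this is roughly how I instantiate and drive the model (a sketch with made-up sizes; batch_size is meant as the per-GPU batch, so the input that gets scattered has batch_size * num_gpus rows):

import torch
import torch.nn as nn

model = lstm(input_size=10, output_size=5, hidden_size=32, n_layers=2, batch_size=4)
model = nn.DataParallel(model).cuda()      # 2 GPUs in my setup

model.module.init_hidden()                 # zero the states before a new sequence
for x in torch.randn(8, 8, 10).cuda():     # 8 time steps of (batch, input_size)
    out = model(x)                         # each step reads and overwrites self.hidden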

However, I found that hidden ends up with requires_grad set to True during training (I assume the in-place copy_ of the grad-tracking LSTMCell outputs pulls the buffer into the autograd graph). The in-place copy_ then raises the following error:

RuntimeError: Output 0 of SelectBackward0 is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
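
For what it's worth, the same class of error is easy to reproduce in isolation (a standalone sketch, not my training code; the grad_fn in the message comes out as UnbindBackward0 rather than SelectBackward0, but it is otherwise identical):

import torch

base = torch.zeros(2, 3, requires_grad=True).clone()  # non-leaf tensor tracked by autograd
h, c = torch.unbind(base)   # unbind returns multiple views of base
h.copy_(torch.ones(3))      # raises: "...is a view and is being modified inplace..."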

On the other hand, I believe in-place updates are the only way to carry the hidden and cell states over to the next time step under DataParallel with multiple GPUs: forward runs on per-GPU replicas of the module, so as far as I understand any state assigned out of place inside forward only lives on a replica and is discarded after the call.
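
For reference, the out-of-place variant I would naturally write looks something like the sketch below (new_hidden is just a local name I made up). It avoids the autograd complaint, but the assignment to self.hidden happens on the per-GPU replica that DataParallel creates for each forward call, so as far as I can tell the updated states are discarded as soon as forward returns and never reach the main module:

def forward(self, input):
    h_in = self.embed(input.view(-1, self.input_size))
    new_hidden = []
    for i in range(self.n_layers):
        h, c = self.lstm[i](h_in, torch.unbind(self.hidden[i]))
        new_hidden.append(torch.stack((h, c)))
        h_in = h
    # Out of place: rebuild the whole state tensor instead of writing
    # into slices of the buffer. Under DataParallel this only updates
    # the replica's copy, as far as I can tell.
    self.hidden = torch.stack(new_hidden)
    return self.output(h_in)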

I can’t figure out a workaround for this issue. Any help would be appreciated!